KDTree

时间:2023-03-10 01:10:24
KDTree

学习链接:http://www.cnblogs.com/eyeszjwang/articles/2429382.html

下面实现的kdtree支持以下操作:
(1) 插入一个节点
(2) 插入n个节点
(3) 查找距离某个给定点距离最小的K个节点。

插入的时候可能会导致树严重不平衡,这个时候会重建某个子树。

 #include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#include <vector> /***
_NodeType: 节点类型
_CompareFuncType: 比较函数类型,它应该接受两个_NodeType并返回0或者1
_UpdateFuncType: 更新函数类型,它有三个参数,分别表示Father,leftSon,rightSon,参数的类型为 _NodeType*
_DIMENSION: 维数
_DISTANCE_TYPE: 两个_NodeType之间距离的类型
_REBUILD_ALPHA: 这个参数用来维持树大致平衡,它应该大于50小于等于100
***/
/**
 * K-d tree storing copies of user supplied nodes.
 *
 * Supported operations: insert one node, insert a batch of nodes, and find
 * the K stored nodes that are "nearest" to a query node.  A single insert may
 * leave some subtree badly unbalanced; that subtree is then rebuilt.
 *
 * Template parameters
 *   _NodeType         node type (must be copyable).
 *   _CompareFuncType  function type taking two _NodeType and returning 0 or 1;
 *                     one function per dimension is supplied to the constructor.
 *   _UpdateFuncType   function type taking (father, leftSon, rightSon) as
 *                     _NodeType*; the sons may be nullptr.  It is called every
 *                     time the children of a node may have changed.
 *   _DIMENSION        number of dimensions (must be > 0).
 *   _DISTANCE_TYPE    type of the distance between two _NodeType.
 *   _REBUILD_ALPHA    balance factor in percent; it should be greater than 50
 *                     and at most 100.
 *                     NOTE(review): the default value was lost from the
 *                     original listing; 80 is an assumed default - confirm.
 */
template<class _NodeType,
         class _CompareFuncType,
         class _UpdateFuncType,
         int _DIMENSION,
         class _DISTANCE_TYPE,
         int _REBUILD_ALPHA=80>
class KDTree
{
    static_assert(_DIMENSION>0,"KDTree needs at least one dimension");

private:
    struct TreeNode
    {
        // Initializers are listed in declaration order; a fresh node is a
        // subtree of size 1 with no children.
        explicit TreeNode(const _NodeType& _Node)
            :m_Node(_Node),m_Left(nullptr),m_Right(nullptr),m_Size(1) {}

        _NodeType m_Node;
        TreeNode* m_Left;
        TreeNode* m_Right;
        unsigned int m_Size;   // number of nodes in the subtree rooted here
    };

    TreeNode* NewNode(const _NodeType& _Node)
    {
        return new TreeNode(_Node);
    }

public:
    /**
     * @param _CompareFuncs array of _DIMENSION comparison functions, one per
     *                      dimension; the pointers are copied.
     * @param _UpdateFunc   function invoked to refresh a node from its sons.
     */
    KDTree(_CompareFuncType** _CompareFuncs,_UpdateFuncType* _UpdateFunc)
        :m_Root(nullptr),m_UpdateFunc(_UpdateFunc)
    {
        for(unsigned int Idx=0;Idx<static_cast<unsigned int>(_DIMENSION);++Idx)
        {
            m_CompareFuncs[Idx]=_CompareFuncs[Idx];
        }
    }

    /** Frees every node owned by the tree. */
    ~KDTree()
    {
        ClearSubTrees(m_Root);
    }

    // The tree owns raw nodes, so a member-wise copy would double free them.
    KDTree(const KDTree&)=delete;
    KDTree& operator=(const KDTree&)=delete;

    /** Inserts a copy of _Value, rebuilding one unbalanced subtree if needed. */
    void insert(const _NodeType& _Value)
    {
        TreeNode* BadTreeNode=nullptr;
        TreeNode* BadTreeNodeParent=nullptr;
        unsigned int BadTreeNodeDimension=0;
        m_Root=Insert(m_Root,NewNode(_Value),0,
                      BadTreeNode,BadTreeNodeParent,BadTreeNodeDimension);

        if(BadTreeNode==nullptr) return;

        // The rebuilt subtree has to be stored in the link that pointed at the
        // old one; the old nodes are deleted by RebuildTree.
        if(BadTreeNodeParent==nullptr)
        {
            m_Root=RebuildTree(BadTreeNode,BadTreeNodeDimension);
        }
        else if(BadTreeNodeParent->m_Left==BadTreeNode)
        {
            BadTreeNodeParent->m_Left=RebuildTree(BadTreeNode,BadTreeNodeDimension);
        }
        else
        {
            BadTreeNodeParent->m_Right=RebuildTree(BadTreeNode,BadTreeNodeDimension);
        }
    }

    /**
     * Inserts copies of _ValueArray[0 .. _ValueArraySize-1].
     *
     * When the current tree is small compared to the batch (at most
     * _ValueArraySize*_REBUILD_ALPHA/100 nodes, which includes the empty
     * tree) the whole tree is rebuilt from the old and new nodes together;
     * otherwise the new nodes are inserted one by one.  The caller's array is
     * only read, never reordered.
     */
    template<class _NodeTypeBegin>
    void insert(_NodeTypeBegin _ValueArray,const unsigned int _ValueArraySize)
    {
        if(_ValueArraySize==0) return;

        const unsigned int OldSize=size();
        // 64-bit product so that a large batch cannot overflow the threshold.
        const unsigned long long RebuildThreshold=
            static_cast<unsigned long long>(_ValueArraySize)*_REBUILD_ALPHA/100;

        if(OldSize>RebuildThreshold)
        {
            for(unsigned int Idx=0;Idx<_ValueArraySize;++Idx)
            {
                insert(_ValueArray[Idx]);
            }
            return;
        }

        std::vector<_NodeType> Pool;
        Pool.reserve(static_cast<std::size_t>(OldSize)+_ValueArraySize);
        CollectSubtreeNodes(m_Root,Pool);
        for(unsigned int Idx=0;Idx<_ValueArraySize;++Idx)
        {
            Pool.push_back(_ValueArray[Idx]);
        }

        TreeNode* NewRoot=
            BuildGroup(Pool.data(),0,static_cast<unsigned int>(Pool.size())-1,0);
        ClearSubTrees(m_Root);
        m_Root=NewRoot;
    }

    /**
     * Finds the _SearchNumber stored nodes "nearest" to _Value and writes
     * them to _StoreAnswerArray, nearest first.
     *
     * @param _StoreAnswerArray       output buffer with room for at least
     *                                _SearchNumber nodes.
     * @param _CompareDistanceFunc    takes two _DISTANCE_TYPE (first,second)
     *                                and returns 1 when first is smaller than
     *                                second, otherwise 0.
     * @param _ComputeMinDistanceFunc takes two _NodeType (first,second) and
     *                                returns the smallest distance from first
     *                                to any point in the range covered by
     *                                second (i.e. second's subtree).
     * @param _ComputeDistanceFunc    takes two _NodeType and returns their
     *                                distance.
     * @return number of nodes written (may be less than _SearchNumber).
     */
    template<class _ComputeDistanceFuncType,
             class _CompareDistanceFuncType>
    unsigned int searchKNear(const _NodeType& _Value,_NodeType* _StoreAnswerArray,
        const unsigned int _SearchNumber,_CompareDistanceFuncType* _CompareDistanceFunc,
        _ComputeDistanceFuncType* _ComputeMinDistanceFunc,
        _ComputeDistanceFuncType* _ComputeDistanceFunc)
    {
        if(_SearchNumber==0) return 0;
        unsigned int AnswerArrayElementNumber=0;
        SearchKNear(m_Root,_Value,_StoreAnswerArray,AnswerArrayElementNumber,_SearchNumber,
                    _CompareDistanceFunc,_ComputeMinDistanceFunc,_ComputeDistanceFunc);
        return AnswerArrayElementNumber;
    }

    /** Number of nodes currently stored. */
    unsigned int size() const
    {
        return m_Root?m_Root->m_Size:0;
    }

private:
    /**
     * Inserts _Candidate into the answer array, which is kept sorted from
     * nearest to farthest and holds at most _SearchNumber entries.  A
     * candidate that is not nearer than any entry of a full array is dropped.
     */
    template<class _ComputeDistanceFuncType,
             class _CompareDistanceFuncType>
    void OfferCandidate(
        const _NodeType& _Candidate,const _NodeType& _Value,_NodeType* _StoreAnswerArray,
        unsigned int& _CurAnswerArrayElementNumber,
        const unsigned int _SearchNumber,_CompareDistanceFuncType* _CompareDistanceFunc,
        _ComputeDistanceFuncType* _ComputeDistanceFunc)
    {
        // Pos ends up as the slot of the candidate: every entry before it is
        // not farther away than the candidate.
        unsigned int Pos=_CurAnswerArrayElementNumber;
        if(Pos>0)
        {
            const _DISTANCE_TYPE CandidateDis=_ComputeDistanceFunc(_Value,_Candidate);
            while(Pos>0&&_CompareDistanceFunc(
                      CandidateDis,_ComputeDistanceFunc(_Value,_StoreAnswerArray[Pos-1])))
            {
                --Pos;
            }
        }
        if(Pos>=_SearchNumber) return;

        // Shift the tail right by one; when the array is already full its
        // last (farthest) entry falls off.
        const unsigned int Last=_CurAnswerArrayElementNumber<_SearchNumber
                                ?_CurAnswerArrayElementNumber:_SearchNumber-1;
        for(unsigned int Idx=Last;Idx>Pos;--Idx)
        {
            _StoreAnswerArray[Idx]=_StoreAnswerArray[Idx-1];
        }
        _StoreAnswerArray[Pos]=_Candidate;
        if(_CurAnswerArrayElementNumber<_SearchNumber)
        {
            ++_CurAnswerArrayElementNumber;
        }
    }

    /**
     * Descends into _Child unless the answer array is full and nothing in the
     * child's range can beat the current farthest answer.  Requires at least
     * one answer to be present already.
     */
    template<class _ComputeDistanceFuncType,
             class _CompareDistanceFuncType>
    void SearchChildIfPromising(
        TreeNode* _Child,const _DISTANCE_TYPE& _ChildMinDis,
        const _NodeType& _Value,_NodeType* _StoreAnswerArray,
        unsigned int& _CurAnswerArrayElementNumber,
        const unsigned int _SearchNumber,_CompareDistanceFuncType* _CompareDistanceFunc,
        _ComputeDistanceFuncType* _ComputeMinDistanceFunc,
        _ComputeDistanceFuncType* _ComputeDistanceFunc)
    {
        const bool Promising=
            _CurAnswerArrayElementNumber<_SearchNumber
            ||_CompareDistanceFunc(
                   _ChildMinDis,
                   _ComputeDistanceFunc(
                       _Value,_StoreAnswerArray[_CurAnswerArrayElementNumber-1]));
        if(Promising)
        {
            SearchKNear(_Child,_Value,_StoreAnswerArray,_CurAnswerArrayElementNumber,
                        _SearchNumber,_CompareDistanceFunc,_ComputeMinDistanceFunc,
                        _ComputeDistanceFunc);
        }
    }

    /**
     * Recursive K-nearest search: offers the subtree root as a candidate and
     * then visits the children, the one with the smaller minimum distance
     * first so that the other one is more likely to be pruned.
     */
    template<class _ComputeDistanceFuncType,
             class _CompareDistanceFuncType>
    void SearchKNear(
        TreeNode* _Root,const _NodeType& _Value,_NodeType* _StoreAnswerArray,
        unsigned int& _CurAnswerArrayElementNumber,
        const unsigned int _SearchNumber,_CompareDistanceFuncType* _CompareDistanceFunc,
        _ComputeDistanceFuncType* _ComputeMinDistanceFunc,
        _ComputeDistanceFuncType* _ComputeDistanceFunc)
    {
        if(_Root==nullptr) return;

        OfferCandidate(_Root->m_Node,_Value,_StoreAnswerArray,_CurAnswerArrayElementNumber,
                       _SearchNumber,_CompareDistanceFunc,_ComputeDistanceFunc);

        TreeNode* const LSon=_Root->m_Left;
        TreeNode* const RSon=_Root->m_Right;

        if(LSon&&RSon)
        {
            const _DISTANCE_TYPE LSonMinDis=_ComputeMinDistanceFunc(_Value,LSon->m_Node);
            const _DISTANCE_TYPE RSonMinDis=_ComputeMinDistanceFunc(_Value,RSon->m_Node);
            const bool LeftFirst=_CompareDistanceFunc(LSonMinDis,RSonMinDis);

            TreeNode* const First=LeftFirst?LSon:RSon;
            TreeNode* const Second=LeftFirst?RSon:LSon;
            const _DISTANCE_TYPE& FirstMinDis=LeftFirst?LSonMinDis:RSonMinDis;
            const _DISTANCE_TYPE& SecondMinDis=LeftFirst?RSonMinDis:LSonMinDis;

            SearchChildIfPromising(First,FirstMinDis,_Value,_StoreAnswerArray,
                                   _CurAnswerArrayElementNumber,_SearchNumber,
                                   _CompareDistanceFunc,_ComputeMinDistanceFunc,
                                   _ComputeDistanceFunc);
            SearchChildIfPromising(Second,SecondMinDis,_Value,_StoreAnswerArray,
                                   _CurAnswerArrayElementNumber,_SearchNumber,
                                   _CompareDistanceFunc,_ComputeMinDistanceFunc,
                                   _ComputeDistanceFunc);
        }
        else if(LSon||RSon)
        {
            TreeNode* const OnlySon=LSon?LSon:RSon;
            const _DISTANCE_TYPE OnlySonMinDis=_ComputeMinDistanceFunc(_Value,OnlySon->m_Node);
            SearchChildIfPromising(OnlySon,OnlySonMinDis,_Value,_StoreAnswerArray,
                                   _CurAnswerArrayElementNumber,_SearchNumber,
                                   _CompareDistanceFunc,_ComputeMinDistanceFunc,
                                   _ComputeDistanceFunc);
        }
    }

    /** Appends a copy of every node of the subtree to _Out, breadth first. */
    void CollectSubtreeNodes(TreeNode* _Root,std::vector<_NodeType>& _Out)
    {
        if(_Root==nullptr) return;
        std::queue<TreeNode*> Pending;
        Pending.push(_Root);
        while(!Pending.empty())
        {
            TreeNode* Cur=Pending.front();
            Pending.pop();
            _Out.push_back(Cur->m_Node);
            if(Cur->m_Left) Pending.push(Cur->m_Left);
            if(Cur->m_Right) Pending.push(Cur->m_Right);
        }
    }

    /** Deletes every node of the subtree rooted at _Root (nullptr allowed). */
    void ClearSubTrees(TreeNode* _Root)
    {
        if(_Root==nullptr) return;
        ClearSubTrees(_Root->m_Left);
        ClearSubTrees(_Root->m_Right);
        delete _Root;
    }

    /**
     * Replaces the subtree rooted at _Root by a balanced subtree holding the
     * same nodes.  _Dimension is the splitting dimension of _Root's layer.
     * The old nodes are deleted; the caller must store the returned root in
     * the link that pointed at _Root.
     */
    TreeNode* RebuildTree(TreeNode* _Root,const unsigned int _Dimension)
    {
        if(_Root==nullptr) return nullptr;

        std::vector<_NodeType> Pool;
        Pool.reserve(_Root->m_Size);
        CollectSubtreeNodes(_Root,Pool);

        TreeNode* Rebuilt=
            BuildGroup(Pool.data(),0,static_cast<unsigned int>(Pool.size())-1,_Dimension);
        ClearSubTrees(_Root);
        return Rebuilt;
    }

    /**
     * Builds a balanced subtree from _TmpPool[_Left .. _Right] (inclusive),
     * splitting on _CurLayerDimention at this layer.  The pool is reordered.
     */
    template<class _NodeTypeBegin>
    TreeNode* BuildGroup(_NodeTypeBegin _TmpPool,const unsigned int _Left,const unsigned int _Right,const unsigned int _CurLayerDimention)
    {
        if(_Left>_Right) return nullptr;

        const unsigned int MidPos=_Left+((_Right-_Left)>>1);
        std::nth_element(_TmpPool+_Left,_TmpPool+MidPos,_TmpPool+_Right+1,
                         m_CompareFuncs[_CurLayerDimention]);

        TreeNode* CurNode=NewNode(_TmpPool[MidPos]);
        const unsigned int NextDimension=(_CurLayerDimention+1)%_DIMENSION;

        // Guard: MidPos-1 would wrap around when MidPos is 0.
        if(MidPos>_Left)
        {
            CurNode->m_Left=BuildGroup(_TmpPool,_Left,MidPos-1,NextDimension);
        }
        CurNode->m_Right=BuildGroup(_TmpPool,MidPos+1,_Right,NextDimension);

        PushUp(CurNode);
        return CurNode;
    }

    /** Recomputes m_Size and lets the user callback refresh the node data. */
    void PushUp(TreeNode* _Root)
    {
        if(_Root==nullptr) return;
        _Root->m_Size=1+SonSize(_Root->m_Left)+SonSize(_Root->m_Right);
        _NodeType* LsonNode=_Root->m_Left?&_Root->m_Left->m_Node:nullptr;
        _NodeType* RsonNode=_Root->m_Right?&_Root->m_Right->m_Node:nullptr;
        m_UpdateFunc(&_Root->m_Node,LsonNode,RsonNode);
    }

    /**
     * Inserts _InsertNode below _Root and returns the (unchanged) subtree
     * root.  While unwinding, the first node found to be badly unbalanced is
     * reported through _BadTreeNode together with its parent (nullptr when it
     * is the tree root) and its splitting dimension.
     */
    TreeNode* Insert(
        TreeNode* _Root,
        TreeNode* _InsertNode,
        const unsigned int _CurLayerDimention,
        TreeNode* &_BadTreeNode,
        TreeNode* &_BadTreeNodeParent,
        unsigned int& _BadTreeNodeDimension)
    {
        if(nullptr==_Root)
        {
            PushUp(_InsertNode);
            return _InsertNode;
        }

        const unsigned int NextDimension=(_CurLayerDimention+1)%_DIMENSION;
        if(m_CompareFuncs[_CurLayerDimention](_InsertNode->m_Node,_Root->m_Node))
        {
            _Root->m_Left=Insert(_Root->m_Left,_InsertNode,NextDimension,
                                 _BadTreeNode,_BadTreeNodeParent,_BadTreeNodeDimension);
        }
        else
        {
            _Root->m_Right=Insert(_Root->m_Right,_InsertNode,NextDimension,
                                  _BadTreeNode,_BadTreeNodeParent,_BadTreeNodeDimension);
        }
        PushUp(_Root);

        if(_BadTreeNode==nullptr)
        {
            if(IsSeriousBadTree(_Root))
            {
                _BadTreeNode=_Root;
                _BadTreeNodeDimension=_CurLayerDimention;
            }
        }
        else if(_BadTreeNode==_Root->m_Left||_BadTreeNode==_Root->m_Right)
        {
            _BadTreeNodeParent=_Root;
        }
        return _Root;
    }

    unsigned int SonSize(TreeNode* _Node)
    {
        return _Node?_Node->m_Size:0;
    }

    /**
     * A subtree is badly unbalanced when its larger child holds more than
     * _REBUILD_ALPHA percent of its nodes.
     * NOTE(review): the constant added to the threshold was lost from the
     * original listing; 1 is assumed - confirm.
     */
    bool IsSeriousBadTree(TreeNode* _Root)
    {
        if(SonSize(_Root)==0) return false;
        return std::max(SonSize(_Root->m_Left),SonSize(_Root->m_Right))
               >=(unsigned int)(SonSize(_Root)*_REBUILD_ALPHA/100)+1;
    }

    TreeNode* m_Root;
    _CompareFuncType* m_CompareFuncs[_DIMENSION];
    _UpdateFuncType* m_UpdateFunc;
};

下面是一个简单的测试代码,两个点之间的距离定义为曼哈顿距离。

 // Sample program for KDTree: 2-D integer points, distance = Manhattan distance.
// NOTE(review): the numeric literals of this sample were lost from the
// original listing.  The coordinates, the query point and K below were chosen
// so that the program prints the expected output documented at the end of
// main; they are not necessarily the original values.

// A point plus the bounding box of the subtree rooted at the point.
struct node
{
    int x,y;
    int MinX,MaxX,MinY,MaxY;

    node(int _x=0,int _y=0)
        :x(_x),y(_y),MinX(_x),MaxX(_x),MinY(_y),MaxY(_y) {}
};

typedef int Func(const node&,const node&);
typedef void Func1(node*,node*,node*);

// Ordering on the first dimension (x).
int cmp0(const node &a,const node &b)
{
    return a.x<b.x;
}

// Ordering on the second dimension (y).
int cmp1(const node &a,const node &b)
{
    return a.y<b.y;
}

// Recomputes the bounding box of Fa from its own point and the boxes of its
// sons; either son may be null.
void pushUp(node *Fa,node *lson,node *rson)
{
    Fa->MinX=Fa->MaxX=Fa->x;
    Fa->MinY=Fa->MaxY=Fa->y;

    node* const Sons[2]={lson,rson};
    for(int i=0;i<2;++i)
    {
        const node* p=Sons[i];
        if(!p) continue;

        Fa->MinX=std::min(Fa->MinX,p->MinX);
        Fa->MaxX=std::max(Fa->MaxX,p->MaxX);
        Fa->MinY=std::min(Fa->MinY,p->MinY);
        Fa->MaxY=std::max(Fa->MaxY,p->MaxY);
    }
}

// Manhattan distance between two points.
int caldis(const node &a,const node &b)
{
    return std::abs(a.x-b.x)+std::abs(a.y-b.y);
}

// Smallest Manhattan distance from point a to the bounding box stored in b
// (0 when a lies inside the box).
int calMinDis(const node &a,const node &b)
{
    int xx=0;
    int yy=0;
    if(a.x<b.MinX) xx=b.MinX-a.x;
    else if(a.x>b.MaxX) xx=a.x-b.MaxX;

    if(a.y<b.MinY) yy=b.MinY-a.y;
    else if(a.y>b.MaxY) yy=a.y-b.MaxY;

    return xx+yy;
}

// Returns 1 when distance x is smaller than distance y.
int cmp(int x,int y)
{
    return x<y;
}

int main()
{
    Func* funs[]={cmp0,cmp1};
    KDTree<node,Func,Func1,2,int> T(funs,pushUp);

    node a[]={node(5,6),node(-1,3),node(2,3),node(5,9)};
    T.insert(a,4);
    T.insert(node(8,1));

    // The array a is reused as the output buffer for the 3 nearest points.
    unsigned int Num=T.searchKNear(node(5,7),a,3,cmp,calMinDis,caldis);
    for(unsigned int Idx=0;Idx<Num;++Idx)
    {
        printf("%d %d\n",a[Idx].x,a[Idx].y);
    }
    /**
    Expected output:
    5 6
    5 9
    2 3
    **/
    return 0;
}