一、什么是KD-Tree?
KD-Tree,又称(k-dimensional tree),是一种基于二叉树的数据结构。它可以用来高效地处理多维空间搜索问题,例如 最近邻搜索(nearest neighbor search) 和 范围搜索(range search) 等。
二、k-d树的结构
KD-Tree 是每个节点都为 k kk 维点的二叉树。所有非叶子节点可以视作用一个超平面把空间分割成两个半空间。节点左边的子树代表在超平面左边的点,节点右边的子树代表在超平面右边的点。
选择超平面的方法如下:每个节点都与 k kk 维中垂直于超平面的那一维有关。因此,如果选择按照 x xx 轴划分,所有 x xx 值小于指定值的节点都会出现在左子树,所有 x xx 值大于指定值的节点都会出现在右子树。这样,超平面可以用该 x xx 值来确定,其法线为 x xx 轴的单位向量。
三、k-d树的创建
有很多种方法可以选择轴垂直分割面( axis-aligned splitting planes ),所以有很多种创建 KD-Tree 的方法。
最典型的方法如下:
随着树的深度轮流选择轴当作分割面。(例如:在三维空间中根节点是 x 轴垂直分割面,其子节点皆为 y 轴垂直分割面,其孙节点皆为 z 轴垂直分割面,其曾孙节点则皆为 x 轴垂直分割面,依此类推。)
点由垂直分割面之轴座标的中位数区分并放入子树
这个方法产生一个平衡的k-d树。每个叶节点的高度都十分接近。然而,平衡的树不一定对每个应用都是最佳的。
四、k-d树的应用
最邻近搜索(Nearest Neighbor Search)
最邻近搜索是一种简单的分类或回归方法,它的基本思想是找到与待分类样本最接近的已知类别的样本,并将待分类样本归为该类别。最邻近搜索可以应用于各种不同的数据类型,例如文本、图像、音频等。
最邻近搜索用来找出在树中与输入点最接近的点。
k-d树最邻近搜索的过程如下:
从根节点开始,递归的往下移。往左还是往右的决定方法与插入元素的方法一样(如果输入点在分区面的左边则进入左子节点,在右边则进入右子节点)。
一旦移动到叶节点,将该节点当作"当前最佳点"。
解开递归,并对每个经过的节点运行下列步骤:
(1)如果当前所在点比当前最佳点更靠近输入点,则将其变为当前最佳点。
(2)检查另一边子树有没有更近的点,如果有则从该节点往下找。
当根节点搜索完毕后完成最邻近搜索。
范围查询(range searches)
范围查询就是给定查询点和查询距离的阈值,从数据集中找出所有与查询点距离小于阈值的数据。
k-d 树范围查询的过程如下:
从根节点开始,递归地往下移,直到叶节点。
如果当前节点所代表的区域与查询范围没有交集,则返回。
如果当前节点所代表的区域完全包含在查询范围内,则将该节点下所有的数据点全部加入结果集中。
如果当前节点所代表的区域与查询范围有交集,则分别对左右子树递归执行上述步骤。
K近邻搜索(K-Nearest Neighbor Search)
K近邻查询是一种基于距离度量的搜索算法,它可以查找与给定点最近的 k 个数据点。当 k=1 时,就是最近邻查询(nearest neighbor searches)
五、KD-Tree的优缺点
优点
KD-Tree可以高效地处理多维空间搜索问题,例如最近邻搜索和范围搜索等。
KD-Tree的构建和搜索时间复杂度均为O(log n),其中n为数据点的数量。
KD-Tree的空间复杂度比朴素的暴力搜索算法要小很多。
缺点
KD-Tree的构建和搜索过程都需要大量的计算,对于高维数据集来说,效率可能会变得很低。
KD-Tree的查询结果可能会受到数据分布的影响,例如如果数据点都集中在某个区域,那么查询结果可能会偏向该区域。
KD-Tree需要占用较大的内存空间,因为每个节点都需要存储多个数据点。
例题
JZPFAR
P2093 [国家集训队]JZPFAR
思路:
KD-Tree 模板题。
存储每个节点的信息:二维坐标:x[2]表示x、y坐标,id对应节点编号。
struct Point { int x[2], id; bool operator<(const Point& A) const { return x[type] < A.x[type]; } } a[N];
k-d 树的节点:ls、rs表示当前节点的左右孩子,maxp 和 minp 分别表示该节点所代表的区域在每个维度上的最大和最小值,id 表示该节点所代表的数据点的编号,v 表示该节点所代表的在 k 维空间中的数据点。
#define ls tr[rt].ls #define rs tr[rt].rs struct kdtree { int ls, rs; int maxp[2], minp[2]; int id; Point v; } tr[N];
答案结构:维护优先队列的小根堆,id 表示查询点的编号,val 表示查询点与待查找点之间的距离。
struct ask { int id, val; bool operator<(const ask& A) const { if (val == A.val) return id < A.id; return val > A.val; } };
build 用于构建 k-d 树:rt 表示当前节点的编号,l 和 r 分别表示当前区间的左右端点,d 表示当前处理的维度。
具体实现过程如下:
如果当前区间为空,则返回。
计算当前区间的中间位置 mid。
根据当前处理的维度 d,将 a[l] 到 a[r] 中第 mid - l + 1 小的元素(即中位数)放在 a[mid] 的位置上。
创建一个新节点,将其坐标设置为 a[mid],id 设置为 a[mid].id。
递归地构建左子树,区间为 [l, mid - 1],维度为 d ^ 1。
递归地构建右子树,区间为 [mid + 1, r],维度为 d ^ 1。
更新当前节点的 maxp 和 minp,即将左右子树的 maxp 和 minp 合并到当前节点上。
通过这样的方式,我们可以构建出一棵 k-d 树来进行 KNN 算法的查询。
void build(int &rt, int l, int r, int d) { if(l > r) return ; rt = ++cnt; int mid = l + r >> 1; type = d; nth_element(a + l, a + mid, a + r + 1); tr[rt].v = a[mid]; tr[rt].id = a[mid].id; build(ls, l, mid - 1, d ^ 1); build(rs, mid + 1, r, d ^ 1); update(rt); }
query 用于在 k-d 树中查找与给定点 v 最近的 k 个数据点:rt 表示当前节点的编号,v 表示待查找的点。
由于是求距离给定点最大的第 k kk 个点,所以从根节点开始询问,遇到更大距离的点即 if(t.val > q.top().val),就更新小根堆,动态维护着最大的 k kk 个点。
这里做 if (l < r) 的判断,然后区分先递归左右子树,尽可能地缩小搜索范围,可以大大减少查询的次数,以减少不必要的计算。
因此,如果当前节点的左子树在目标点的某个维度上比当前节点更接近目标点,那么我们应该先遍历右子树,再遍历左子树;否则,应该先遍历左子树,再遍历右子树。(因为是对q.top()的比较增改,先查询会先压入更大的值,减少回溯后其他分支的查询概率)
void query(int rt, Point v) { ask t; t.id = tr[rt].id; t.val = dis(v.x[0], v.x[1], tr[rt].v.x[0], tr[rt].v.x[1]); if(t.val > q.top().val) q.pop(), q.push(t); int l = -2e18, r = -2e18; if (ls) l = getdis(ls, v); if (rs) r = getdis(rs, v); if (l < r) { if(r >= q.top().val) query(rs, v); if(l >= q.top().val) query(ls, v); } else{ if(l >= q.top().val) query(ls, v); if(r >= q.top().val) query(rs, v); } }
代码:
#include <bits/stdc++.h> using namespace std; #define int long long const int N = 100010; int root, type, cnt; struct Point { int x[2], id; bool operator<(const Point& A) const { return x[type] < A.x[type]; } } a[N]; struct kdtree { int ls, rs; int maxp[2], minp[2]; int id; Point v; } tr[N]; struct ask { int id, val; bool operator<(const ask& A) const { if (val == A.val) return id < A.id; return val > A.val; } }; priority_queue<ask> q; #define ls tr[rt].ls #define rs tr[rt].rs void update(int rt){ for(int i = 0; i < 2; i++) { tr[rt].maxp[i] = tr[rt].minp[i] = tr[rt].v.x[i]; if(ls) { tr[rt].maxp[i] = max(tr[rt].maxp[i], tr[ls].maxp[i]); tr[rt].minp[i] = min(tr[rt].minp[i], tr[ls].minp[i]); } if(rs){ tr[rt].maxp[i] = max(tr[rt].maxp[i], tr[rs].maxp[i]); tr[rt].minp[i] = min(tr[rt].minp[i], tr[rs].minp[i]); } } } void build(int &rt, int l, int r, int d) { if(l > r) return ; rt = ++cnt; int mid = l + r >> 1; type = d; nth_element(a + l, a + mid, a + r + 1); tr[rt].v = a[mid]; tr[rt].id = a[mid].id; build(ls, l, mid - 1, d ^ 1); build(rs, mid + 1, r, d ^ 1); update(rt); } int getdis(int rt, Point v) { int res = 0; for(int i = 0; i < 2; i++) { int t = max(abs(v.x[i] - tr[rt].maxp[i]), abs(v.x[i] - tr[rt].minp[i])); res += t * t; } return res; } int dis(int x, int y, int xx, int yy) { return (x - xx) * (x - xx) + (y - yy) * (y - yy); } void query(int rt, Point v) { ask t; t.id = tr[rt].id; t.val = dis(v.x[0], v.x[1], tr[rt].v.x[0], tr[rt].v.x[1]); if(t.val > q.top().val) q.pop(), q.push(t); int l = -2e18, r = -2e18; if (ls) l = getdis(ls, v); if (rs) r = getdis(rs, v); if (l < r) { if(r >= q.top().val) query(rs, v); if(l >= q.top().val) query(ls, v); } else{ if(l >= q.top().val) query(ls, v); if(r >= q.top().val) query(rs, v); } } signed main(){ int n; cin >> n; for (int i = 1; i <= n; i++) { cin >> a[i].x[0] >> a[i].x[1]; a[i].id = i; } build(root, 1, n, 0); int m; cin >> m; while(m--) { int k; Point v; cin >> v.x[0] >> v.x[1] >> k; while(!q.empty()) q.pop(); while(k--) q.push(ask{0, -1}); query(root, v); cout << q.top().id << endl; } return 0; }
以上是二维的 KD-Tree
例题,后续有时间在多更新几题。