这篇文章主要聚焦一下各种奇奇怪怪的树查询算法。虽然我主要关注最近邻检索的图方法,但是因为很多图方法实际上都使用了各种 tree 作为辅助索引,所以有必要简单了解一下所有的树查询算法。

有请第一位 受害者

# KD 树

kd 树算是用的很广泛的一种最近邻检索树了,它的思想实际上和二叉搜索树很像。

看这张图就够了:

image-20220423164038307

首先思考动机,为什么我们需要构建 KD 树?树结构是一种非常高效的数据结构,对于二叉搜索树,它的查询复杂度只有 h,也就是 log (n)。如果直接暴力检索最近邻,那么复杂度会是 n。

KD 树是什么?与搜索二叉树类似,通过比较数据,把大于当前节点的数据插入到树右侧,把小于节点的数据放到左侧。而 KD 树比较大小是在不同维度进行的。

例如,对于当前节点 root。我们比较所有数据点(p Ed\in E^d) 在维度 0 的数据,对于某个向量 [1.1,3.4,...,0.1] 在 dim=0 的数据即 1.1,如果 1.1 小于 root(dim=0),那么就把这个向量插入到左侧,如果小于就插入到右侧。同样在 KD 树下一层,我们比较 dim=1 的数据,按照同样方式进行插入。

一个很重要的优化方式就是,我们要保证 KD 树左右两边数据点数量差不多,这样进行插入时候可以保证得到的 KD 树尽可能平衡。于是我们构建 KD 树时候选择当前点集在当前维度下处于 中位数 的那个点作为 base,然后比较插入。

第二个优化的点,维度选择。我们尽量选择数据分布方差很大的维度。

所以构建方法:

假设我们已经知道了 维空间内的 个不同的点的坐标,要将其构建成一棵 k-D Tree,步骤如下:

  1. 若当前超长方体中只有一个点,返回这个点。
  2. 选择一个维度,将当前超长方体按照这个维度分成两个超长方体。
  3. 选择切割点:在方差大的维度上选择中位数那个点,这一维度上的值小于这个点的归入一个超长方体(左子树),其余的归入另一个超长方体(右子树)。
  4. 将选择的点作为这棵子树的根节点,递归对分出的两个超长方体构建左右子树,维护子树的信息。

这里定义可以参考:https://oi-wiki.org/ds/kdt/

代码实现我看的是:https://github.com/crvs/KDTree

构建树的算法:

KDNodePtr KDTree::make_tree(const pointIndexArr::iterator &begin,  //
                            const pointIndexArr::iterator &end,    //
                            const size_t &length,                  //
                            const size_t &level                    //
) {
    if (begin == end) {
        return NewKDNodePtr();  // empty tree
    }
    size_t dim = begin->first.size();
    if (length > 1) {
        sort_on_idx(begin, end, level); // 类似于快排,就是把中位数放到正确的位置,同时左边的数都比他小,右边的数都比中位数大
    }
    auto middle = begin + (length / 2);
    auto l_begin = begin;
    auto l_end = middle;
    auto r_begin = middle + 1;
    auto r_end = end;
    size_t l_len = length / 2;
    size_t r_len = length - l_len - 1;
    KDNodePtr left;
    if (l_len > 0 && dim > 0) {
        left = make_tree(l_begin, l_end, l_len, (level + 1) % dim);
    } else {
        left = leaf;
    }
    KDNodePtr right;
    if (r_len > 0 && dim > 0) {
        right = make_tree(r_begin, r_end, r_len, (level + 1) % dim);
    } else {
        right = leaf;
    }
    // KDNode result = KDNode();
    return std::make_shared< KDNode >(*middle, left, right);
}

他的代码没有用到第二个优化,选择划分的 dim 只是简单的递增取模。

查询算法:

KDNodePtr KDTree::nearest_(   //
    const KDNodePtr &branch,  //
    const point_t &pt,        //
    const size_t &level,      //
    const KDNodePtr &best,    //
    const double &best_dist   //
) {
    double d, dx, dx2;
    if (!bool(*branch)) { // 走到了一个叶子结点
        return NewKDNodePtr();  // basically, null
    }
    point_t branch_pt(*branch); // 之间重载了操作符 ()
    size_t dim = branch_pt.size();
    d = dist2(branch_pt, pt);
    dx = branch_pt.at(level) - pt.at(level);
    dx2 = dx * dx;
    KDNodePtr best_l = best;
    double best_dist_l = best_dist;
    if (d < best_dist) {
        best_dist_l = d;
        best_l = branch;
    }
    size_t next_lv = (level + 1) % dim;
    KDNodePtr section;
    KDNodePtr other;
    // select which branch makes sense to check
    if (dx > 0) {
        section = branch->left;
        other = branch->right;
    } else {
        section = branch->right;
        other = branch->left;
    }
    // keep nearest neighbor from further down the tree
    KDNodePtr further = nearest_(section, pt, next_lv, best_l, best_dist_l);
    if (!further->x.empty()) { //best_l 实际上是一个指针,这一步 check 并没有必要。
        double dl = dist2(further->x, pt);
        if (dl < best_dist_l) {
            best_dist_l = dl;
            best_l = further;
        }
    }
    // only check the other branch if it makes sense to do so
    if (dx2 < best_dist_l) { // 说明 other 这个 branch 也有可能能够找到最近邻
        further = nearest_(other, pt, next_lv, best_l, best_dist_l);
        if (!further->x.empty()) {
            double dl = dist2(further->x, pt);
            if (dl < best_dist_l) {
                best_dist_l = dl;
                best_l = further;
            }
        }
    }
    return best_l;
};

查询的思想也和二叉树更类似,不多赘述。

# ball tree

ball tree 的思想更简单了。

将当前点集的质心作为 root,搜索离当前质心最远的节点 p,搜索离节点 p 最远的节点 q。通过 p,q 将当前点集进行划分为两个 cluster,看点离 p,q 哪个更近,如果离 p 更近就加入 p cluster,如果离 q 更近就加入 q cluster。反复迭代,知道到达最大深度。

image-20220423175340044

首先取灰色的点作为 root,用 3 和 9 来划分两个 cluster。然后递归进行划分。

image-20220423175455261

# VP treee

这里有一篇文章,对于我理解 VP tree 有很大帮助:http://stevehanov.ca/blog/?id=130

so,what is vp tree?

image-20220423194326540

对于每一个节点 p,我们设置一个半径 r,把和该节点距离小于 r 的所有点插入到节点的左子树,把距离大于 r 的所有点插入到右子树。实际代码中,我们取点集中位数之前的所有点作为左子树,中位数之后的点作为右子树。

查询时候,需要去确定向左还是向右去查询。

image-20220423195255100

如果查询的点 x 在当前也就是距离小于 tau,把当前 p 加入到 result 中(因为需要查 k 个最近邻), tau 始终等于结果 result 中距离 x 的最大距离。

因为 x 在圆内,于是我们向左子树进行查询,蓝色的点,就是我们在左子树中查询得到的点,并且更新得到了 tau。

如果 tau > distance to shell,说明外部还可能有节点,需要向右子树进行查询。

image-20220423195812891

#include <stdlib.h>
#include <algorithm>
#include <vector>
#include <stdio.h>
#include <queue>
#include <limits>
template<typename T, double (*distance)( const T&, const T& )>
class VpTree
{
public:
    VpTree() : _root(0) {}
    ~VpTree() {
        delete _root;
    }
    void create( const std::vector& items ) {
        delete _root;
        _items = items;
        _root = buildFromPoints(0, items.size());
    }
    void search( const T& target, int k, std::vector* results, 
        std::vector<double>* distances) 
    {
        std::priority_queue<HeapItem> heap;
        _tau = std::numeric_limits::max();
        search( _root, target, k, heap );
        results->clear(); distances->clear();
        while( !heap.empty() ) {
            results->push_back( _items[heap.top().index] );
            distances->push_back( heap.top().dist );
            heap.pop();
        }
        std::reverse( results->begin(), results->end() );
        std::reverse( distances->begin(), distances->end() );
    }
private:
    std::vector<T> _items;
    double _tau;
    struct Node 
    {
        int index;
        double threshold;
        Node* left;
        Node* right;
        Node() :
            index(0), threshold(0.), left(0), right(0) {}
        ~Node() {
            delete left;
            delete right;
        }
    }* _root;
    struct HeapItem {
        HeapItem( int index, double dist) :
            index(index), dist(dist) {}
        int index;
        double dist;
        bool operator<( const HeapItem& o ) const {
            return dist < o.dist;   
        }
    };
    struct DistanceComparator
    {
        const T& item;
        DistanceComparator( const T& item ) : item(item) {}
        bool operator()(const T& a, const T& b) {
            return distance( item, a ) < distance( item, b );
        }
    };
    Node* buildFromPoints( int lower, int upper )
    {
        if ( upper == lower ) {
            return NULL;
        }
        Node* node = new Node();
        node->index = lower;
        if ( upper - lower > 1 ) {
            // choose an arbitrary point and move it to the start
            int i = (int)((double)rand() / RAND_MAX * (upper - lower - 1) ) + lower;
            std::swap( _items[lower], _items[i] );
            int median = ( upper + lower ) / 2;
            // partitian around the median distance
            std::nth_element( 
                _items.begin() + lower + 1, 
                _items.begin() + median,
                _items.begin() + upper,
                DistanceComparator( _items[lower] ));
            // what was the median?
            node->threshold = distance( _items[lower], _items[median] );
            node->index = lower;
            node->left = buildFromPoints( lower + 1, median );
            node->right = buildFromPoints( median, upper );
        }
        return node;
    }
    void search( Node* node, const T& target, int k,
                 std::priority_queue& heap )
    {
        if ( node == NULL ) return;
        double dist = distance( _items[node->index], target );
        //printf("dist=%g tau=%gn", dist, _tau );
        if ( dist < _tau ) {
            if ( heap.size() == k ) heap.pop();
            heap.push( HeapItem(node->index, dist) );
            if ( heap.size() == k ) _tau = heap.top().dist;
        }
        if ( node->left == NULL && node->right == NULL ) {
            return;
        }
        if ( dist < node->threshold ) {
            search( node->left, target, k, heap );
            if ( dist + _tau >= node->threshold ) { // 说明外部还有可能有节点。
                search( node->right, target, k, heap );
            }
        } else {
            search( node->right, target, k, heap );
            if ( dist - _tau <= node->threshold ) {
                search( node->left, target, k, heap );
            }
        }
    }
};

# Brute Force

说白了就是暴力检索。

更新于

请我喝[茶]~( ̄▽ ̄)~*

Kalice 微信支付

微信支付

Kalice 支付宝

支付宝