学习笔记|k近邻法的实现

2021-10-25 10:19:02 浏览数 (1)

学习笔记|k近邻分类算法 指出k近邻分类算法通过kd树的构造和搜索来实现。

1. 构建二叉树类

为了实现kd树的构造和搜索算法,我们先构建一个二叉树类。首先,申明类,初始化根结点和左、右子结点。

代码语言:javascript复制
class binary_tree(object):
    def __init__(self, root_obj):
        self.key = root_obj
        self.left_child = None
        self.right_child = None

其次,构造插入左、右子树方法。

代码语言:javascript复制
    def insert_left_child(self, new_obj):
        new_tree = binary_tree(new_obj)
        if self.left_child == None:
            self.left_child = new_tree
        else:
            new_tree.left_child = self.left_child
            self.left_child = new_tree

    def insert_right_child(self, new_obj):
        new_tree = binary_tree(new_obj)
        if self.right_child == None:
            self.right_child = new_tree
        else:
            new_tree.right_child = self.right_child
            self.right_child = new_tree

再次,构造读、写根节点方法。

代码语言:javascript复制
    def get_root_value(self):
        return self.key

    def set_root_value(self, root_obj):
        self.key = root_obj

最后,构造读取左、右子树方法。

代码语言:javascript复制
    def get_left_child(self):
        return self.left_child

    def get_right_child(self):
        return self.right_child

2. 构造kd树

在二叉树的基础上构造kd树,事实上kd树只需要实现二叉树的部分功能。

代码语言:javascript复制
class kd_tree(binary_tree):
    def set_child(self, new_obj, lr='l'):
        if lr == 'l':
            self.left_child = new_obj
        else:
            self.right_child = new_obj

然后生成kd树。

代码语言:javascript复制
def generate_kd_tree(x, d=0):
    if len(x) > 1:
        x = x[np.argsort(-x[:,d])]
        mi = np.argmin(np.abs(x[:, d] - np.median(x[:, d])))
        kd_tree1 = kd_tree(x[mi])
        d  = 1
        if len(x[:mi]):
            kd_tree1.set_child(generate_kd_tree(x[:mi], d), lr='r')
        if len(x[mi 1:]):
            kd_tree1.set_child(generate_kd_tree(x[mi 1:], d), lr='l')
        return kd_tree1
    return kd_tree(x[0])

根据书本(参考文献2)上的案例对生成的kd树进行简单验证。

代码语言:javascript复制
if __name__ == "__main__":
    kd_tree = generate_kd_tree(np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]))
    print(kd_tree.key, kd_tree.left_child.key, kd_tree.right_child.key, kd_tree.left_child.left_child.key, kd_tree.left_child.right_child.key, kd_tree.right_child.left_child.key)

得到: [7 2] [5 4] [9 6] [2 3] [4 7] [8 1]

3. kd树搜索

kd树的搜索可以通过递归的方法来实现。当然,有点偷懒,这里的代码比较冗长。

代码语言:javascript复制
def search_kd_tree(kd_tree1, s, d=0):
    if kd_tree1.left_child == None:
        if kd_tree1.right_child == None:
            return {'p': kd_tree1.key, 'r': np.linalg.norm(kd_tree1.key - s)}
        if np.linalg.norm(kd_tree1.right_child.key - s) < np.linalg.norm(kd_tree1.key - s):
            return {'p': kd_tree1.right_child.key, 'r': np.linalg.norm(kd_tree1.right_child.key - s)}
        return {'p': kd_tree1.key, 'r': np.linalg.norm(kd_tree1.key - s)}
    elif kd_tree1.right_child == None:
        if np.linalg.norm(kd_tree1.left_child.key - s) < np.linalg.norm(kd_tree1.key - s):
            return {'p': kd_tree1.left_child.key, 'r': np.linalg.norm(kd_tree1.left_child.key - s)}
        return {'p': kd_tree1.key, 'r': np.linalg.norm(kd_tree1.key - s)}
    if s[d] < kd_tree1.key[d]:
        d = (d   1) % len(s)
        c = search_kd_tree(kd_tree1.left_child, s, d=d)
        lr = 'l'
    else:
        d = (d   1) % len(s)
        c = search_kd_tree(kd_tree1.right_child, s, d=d)
        lr = 'r'
    if np.linalg.norm(kd_tree1.key - s) < c['r']:
        c['p'] = kd_tree1.key
        c['r'] = np.linalg.norm(c['p'] - 1)
    s1 = s.copy()
    s1[d] = kd_tree1.key[d]
    if np.linalg.norm(s1 - s) < c['r']:
        if lr == 'l':
            c1 = search_kd_tree(kd_tree1.right_child, s, d=d)
        else:
            c1 = search_kd_tree(kd_tree1.left_child, s, d=d)
        if c1['r'] < c['r']:
            c = c1
    return c

4. kd树搜索举例

代码语言:javascript复制
if __name__ == "__main__":
    kd_tree1 = generate_kd_tree(np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]))
    print(kd_tree1.key, kd_tree1.left_child.key, kd_tree1.right_child.key, kd_tree1.left_child.left_child.key, kd_tree1.left_child.right_child.key, kd_tree1.right_child.left_child.key)
    print(search_kd_tree(kd_tree1, np.array([2, 4.5])))
    print(search_kd_tree(kd_tree1, np.array([9, 7])))

结合前面的例子,做个简单的验证,可以看到当s=(2,4.5)时,最近邻是点(2,3),距离是1.5;当s=(9,7)时,最近邻是点(9,6),距离是1.0。其中,s=(9,7)是具有一定的特殊性的,这里不再赘述。 [7 2] [5 4] [9 6] [2 3] [4 7] [8 1] {'p': array([2, 3]), 'r': 1.5} {'p': array([9, 6]), 'r': 1.0}

参考文献

[1]https://blog.csdn.net/m0_37324740/article/details/79435814 [2]统计学习方法(第2版),李航著,清华大学出版社 [3]https://zhuanlan.zhihu.com/p/104758420

0 人点赞