学习笔记|k近邻分类算法 指出k近邻分类算法通过kd树的构造和搜索来实现。
1. 构建二叉树类
为了实现kd树的构造和搜索算法,我们先构建一个二叉树类。首先,申明类,初始化根结点和左、右子结点。
代码语言:javascript复制class binary_tree(object):
def __init__(self, root_obj):
self.key = root_obj
self.left_child = None
self.right_child = None
其次,构造插入左、右子树方法。
代码语言:javascript复制 def insert_left_child(self, new_obj):
new_tree = binary_tree(new_obj)
if self.left_child == None:
self.left_child = new_tree
else:
new_tree.left_child = self.left_child
self.left_child = new_tree
def insert_right_child(self, new_obj):
new_tree = binary_tree(new_obj)
if self.right_child == None:
self.right_child = new_tree
else:
new_tree.right_child = self.right_child
self.right_child = new_tree
再次,构造读、写根节点方法。
代码语言:javascript复制 def get_root_value(self):
return self.key
def set_root_value(self, root_obj):
self.key = root_obj
最后,构造读取左、右子树方法。
代码语言:javascript复制 def get_left_child(self):
return self.left_child
def get_right_child(self):
return self.right_child
2. 构造kd树
在二叉树的基础上构造kd树,事实上kd树只需要实现二叉树的部分功能。
代码语言:javascript复制class kd_tree(binary_tree):
def set_child(self, new_obj, lr='l'):
if lr == 'l':
self.left_child = new_obj
else:
self.right_child = new_obj
然后生成kd树。
代码语言:javascript复制def generate_kd_tree(x, d=0):
if len(x) > 1:
x = x[np.argsort(-x[:,d])]
mi = np.argmin(np.abs(x[:, d] - np.median(x[:, d])))
kd_tree1 = kd_tree(x[mi])
d = 1
if len(x[:mi]):
kd_tree1.set_child(generate_kd_tree(x[:mi], d), lr='r')
if len(x[mi 1:]):
kd_tree1.set_child(generate_kd_tree(x[mi 1:], d), lr='l')
return kd_tree1
return kd_tree(x[0])
根据书本(参考文献2)上的案例对生成的kd树进行简单验证。
代码语言:javascript复制if __name__ == "__main__":
kd_tree = generate_kd_tree(np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]))
print(kd_tree.key, kd_tree.left_child.key, kd_tree.right_child.key, kd_tree.left_child.left_child.key, kd_tree.left_child.right_child.key, kd_tree.right_child.left_child.key)
得到: [7 2] [5 4] [9 6] [2 3] [4 7] [8 1]
3. kd树搜索
kd树的搜索可以通过递归的方法来实现。当然,有点偷懒,这里的代码比较冗长。
代码语言:javascript复制def search_kd_tree(kd_tree1, s, d=0):
if kd_tree1.left_child == None:
if kd_tree1.right_child == None:
return {'p': kd_tree1.key, 'r': np.linalg.norm(kd_tree1.key - s)}
if np.linalg.norm(kd_tree1.right_child.key - s) < np.linalg.norm(kd_tree1.key - s):
return {'p': kd_tree1.right_child.key, 'r': np.linalg.norm(kd_tree1.right_child.key - s)}
return {'p': kd_tree1.key, 'r': np.linalg.norm(kd_tree1.key - s)}
elif kd_tree1.right_child == None:
if np.linalg.norm(kd_tree1.left_child.key - s) < np.linalg.norm(kd_tree1.key - s):
return {'p': kd_tree1.left_child.key, 'r': np.linalg.norm(kd_tree1.left_child.key - s)}
return {'p': kd_tree1.key, 'r': np.linalg.norm(kd_tree1.key - s)}
if s[d] < kd_tree1.key[d]:
d = (d 1) % len(s)
c = search_kd_tree(kd_tree1.left_child, s, d=d)
lr = 'l'
else:
d = (d 1) % len(s)
c = search_kd_tree(kd_tree1.right_child, s, d=d)
lr = 'r'
if np.linalg.norm(kd_tree1.key - s) < c['r']:
c['p'] = kd_tree1.key
c['r'] = np.linalg.norm(c['p'] - 1)
s1 = s.copy()
s1[d] = kd_tree1.key[d]
if np.linalg.norm(s1 - s) < c['r']:
if lr == 'l':
c1 = search_kd_tree(kd_tree1.right_child, s, d=d)
else:
c1 = search_kd_tree(kd_tree1.left_child, s, d=d)
if c1['r'] < c['r']:
c = c1
return c
4. kd树搜索举例
代码语言:javascript复制if __name__ == "__main__":
kd_tree1 = generate_kd_tree(np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]))
print(kd_tree1.key, kd_tree1.left_child.key, kd_tree1.right_child.key, kd_tree1.left_child.left_child.key, kd_tree1.left_child.right_child.key, kd_tree1.right_child.left_child.key)
print(search_kd_tree(kd_tree1, np.array([2, 4.5])))
print(search_kd_tree(kd_tree1, np.array([9, 7])))
结合前面的例子,做个简单的验证,可以看到当s=(2,4.5)时,最近邻是点(2,3),距离是1.5;当s=(9,7)时,最近邻是点(9,6),距离是1.0。其中,s=(9,7)是具有一定的特殊性的,这里不再赘述。 [7 2] [5 4] [9 6] [2 3] [4 7] [8 1] {'p': array([2, 3]), 'r': 1.5} {'p': array([9, 6]), 'r': 1.0}
参考文献
[1]https://blog.csdn.net/m0_37324740/article/details/79435814 [2]统计学习方法(第2版),李航著,清华大学出版社 [3]https://zhuanlan.zhihu.com/p/104758420