Python 绘制一个二叉树实际上是一个比较简单的需求,比如我们可以使用控制台直接分层打印出来,那么这个问题实际上就转化为了对二叉树的层次遍历,实际上一个二叉树,为了让人能够很直观理解他的结构,我们通常表达出来,就是一个有层次感的结构。
在Python中,绘制二叉树,我们可以利用一些绘制图形的库,比如学校里面,我们基本都或多或少接触过matplotlib
,这个其实就比较适合用到我们这个例子中来。但是在这之前,我们必须思考,我们使用什么样的方式来表达一颗二叉树,通常二叉树的结构体,我们定义为:
class TreeNode:
def __init__(self, value):
self.val = value
self.left = None
self.right = None
但是用户输入的时候,肯定不是太方便输入一个这样的结构给到你,我们在学习算法的时候,构造二叉树的时候一般都是使用一个数组来表示的,因此,这里我们也将采用这样的方式来整。
为了简单,我们需要一个函数来将字符数组转换为二叉树。这里我们假设数组是按层序遍历的结果,即数组的第一个元素是树的根,接下来的两个元素是根的左右子节点,以此类推。如果某个位置是None
,则表示该位置没有节点。那,ok,我们写一个这样的方法来将一维数组转化为二叉树。
# 示例使用
# 假如说 array = ['a', 'b', 'c', None, 'd', 'e', None] 是这样的
# 我们就可以构建如下的二叉树
# a
# / \
# b c
# \ / \
# d e None
# 我们期望后面借助于matplotlib 绘制 图形也大概长这个样子
代码语言:javascript复制from collections import deque
def array_to_bst(array):
if not array:
return None
iter_array = iter(array)
root = TreeNode(next(iter_array))
queue = deque([root])
while True:
current_node = queue.popleft()
try:
left_value = next(iter_array)
if left_value is not None:
current_node.left = TreeNode(left_value)
queue.append(current_node.left)
right_value = next(iter_array)
if right_value is not None:
current_node.right = TreeNode(right_value)
queue.append(current_node.right)
except StopIteration:
break
return root
ok,一旦我们使用array_to_bst方法构建了这颗二叉树之后,我们就可以开始遍历这颗二叉树了。
二叉树的遍历方式,想必不用过多的介绍,当然,为了简单起见,我们还是采用递归
的方式来遍历了。下面是通过这个二叉树递归去绘制的流程图:
对这个过程进行一点点简单的解释,主要就是遍历到当前节点,进行一些检查,如果存在左子节点,就使用matplotlib
的api 进行绘制,然后在看右子树,这个过程是 遍历的过程,当然我们可以采用深度优先的方式去遍历这颗树。
下面就是对整个上述过程的描述,为了比较清晰标记子节点是左子树的,还是右子树的,我们使用 L,和 R 进行标记一下。整个实现的过程如下所示:
代码语言:javascript复制import matplotlib.pyplot as plt
def plot_tree(node, parent_name, node_name, edge_label, pos=None, x=0, y=0, layer=1):
if pos is None:
pos = {}
pos[node_name] = (x, y)
plt.text(x, y, str(node.val), fontsize=12, ha='center')
if parent_name is not None:
plt.plot([x, pos[parent_name][0]], [y, pos[parent_name][1]], 'k-')
plt.text((x pos[parent_name][0])/2, (y pos[parent_name][1])/2, edge_label, fontsize=8, ha='center')
if node.left:
plot_tree(node.left, node_name, node_name "L", 'L', pos, x-1/2**layer, y-1, layer 1)
if node.right:
plot_tree(node.right, node_name, node_name "R", 'R', pos, x 1/2**layer, y-1, layer 1)
return pos
def draw_bst(root):
fig, ax = plt.subplots()
ax.axis('off')
plot_tree(root, None, 'Root', None)
plt.show()
# 示例使用
array = ['a', 'b', 'c', None, 'd', 'e', None]
root = array_to_bst(array)
draw_bst(root)
其中,我们的plot_tree 函数就是递归之所在,既然是一个递归方法,那么,就会有开始和终止条件,我们开始节点是 root,即我们的根节点作为初始节点,然后,我们的终止条件是所有的左右子树的节点都遍历完毕了。我们使用的plt.plot这个绘制函数是相当基础的,它假设所有的节点值都是单个字符,而且没有考虑到节点值可能会重叠的情况。对于树的节点比较多的话,我估计可能会出现重叠的情况。
总结
至此,我们就完成了使用 Python 绘制一颗二叉树的小玩意,虽然这个需求完成了,但是我们不妨去思考一下,我们将二叉树绘制出来有哪些更深层的意义,比如,做算法可视化来辅助我们学习理解树的算法,甚至做一个教学工具来帮助我们进行算法教学,绘图的过程我们使用递归的方式偷懒了,相信还有更加高效的办法,那就是树的遍历上的一些问题了,你能想到怎么提高绘制效率吗?
或者说,如果不限于 Python,要绘制一颗二叉树,你将会采用什么方式来绘制呢,即高效又简洁的呢?
提示,可以了解一下:
- VisuAlgo
- Data Structure Visualizations
- D3.js
我正在参与2023腾讯技术创作特训营第四期有奖征文,快来和我瓜分大奖!