Python data structures: tree

2020-01-08 16:42:12

Tree node class: TreeNode

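For the simplest possible tree node, we only need three basic attributes:

  • name: the name of the current node (stored as a str)
  • parent: the parent node object (None for the root node)
  • child: the child node objects (stored in a dict, keyed by name)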

The code is as follows:

class TreeNode(object):
    """The basic node of tree structure"""

    def __init__(self, name, parent=None):
        super(TreeNode, self).__init__()
        self.name = name
        self.parent = parent
        self.child = {}

    def __repr__(self):
        return 'TreeNode(%s)' % self.name
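A minimal sketch of how the three attributes tie nodes together, using only the constructor above. Since add_child() has not been introduced yet, the child is registered in the parent's dict by hand; the names root and a1 are only for illustration.

```python
class TreeNode(object):
    """The basic node of tree structure"""

    def __init__(self, name, parent=None):
        super(TreeNode, self).__init__()
        self.name = name
        self.parent = parent
        self.child = {}

    def __repr__(self):
        return 'TreeNode(%s)' % self.name


# The child points up to its parent ...
root = TreeNode('root')
a1 = TreeNode('a1', parent=root)
# ... and the parent keeps the child in its dict, keyed by name.
root.child[a1.name] = a1

print(root.parent)   # None: the root has no parent
print(a1.parent)     # TreeNode(root)
print(root.child)    # {'a1': TreeNode(a1)}
```

Keeping both links in sync by hand is easy to get wrong, which is exactly what add_child() takes care of.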

Tree node methods

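Operations on individual tree nodes, for example:

  • get_child(name): get a child node (direct children of the current node only)
  • find_child(name/path): find a descendant node (a child, a child's child, and so on down the tree)
  • add_child(name, obj): add a child node
  • del_child(name): remove a child node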
class TreeNode(object):

    def get_child(self, name, defval=None):
        """get a child node of current node"""
        return self.child.get(name, defval)

    def add_child(self, name, obj=None):
        """add a child node to current node"""
        if obj is not None and not isinstance(obj, TreeNode):
            raise ValueError('TreeNode can only add another TreeNode object as a child')
        if obj is None:
            obj = TreeNode(name)
        obj.parent = self
        self.child[name] = obj
        return obj

    def del_child(self, name):
        """remove a child node from current node"""
        if name in self.child:
            del self.child[name]

    def find_child(self, path, create=False):
        """find child node by path/name, return None if not found"""
        # convert path to a list if input is a string
        path = path if isinstance(path, list) else path.split()
        cur = self
        for sub in path:
            # search
            obj = cur.get_child(sub)
            if obj is None and create:
                # create new node if need
                obj = cur.add_child(sub)
            # stop if the node does not exist
            if obj is None:
                return None
            cur = obj
        # return the last node reached (self if the path is empty)
        return cur
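A self-contained sketch of these methods in use. It repeats get_child(), add_child(), del_child() and find_child(); here find_child() returns the last node reached, so an empty path yields the node itself (an assumption about the desired behaviour). It also shows that the path may be a space-separated string or a list of names.

```python
class TreeNode(object):
    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent
        self.child = {}

    def __repr__(self):
        return 'TreeNode(%s)' % self.name

    def get_child(self, name, defval=None):
        return self.child.get(name, defval)

    def add_child(self, name, obj=None):
        if obj is not None and not isinstance(obj, TreeNode):
            raise ValueError('TreeNode can only add another TreeNode object as a child')
        if obj is None:
            obj = TreeNode(name)
        obj.parent = self
        self.child[name] = obj
        return obj

    def del_child(self, name):
        if name in self.child:
            del self.child[name]

    def find_child(self, path, create=False):
        path = path if isinstance(path, list) else path.split()
        cur = self
        for sub in path:
            obj = cur.get_child(sub)
            if obj is None and create:
                obj = cur.add_child(sub)   # create missing levels on demand
            if obj is None:
                return None                # a level is missing: give up
            cur = obj
        return cur


root = TreeNode('')
b1 = root.find_child('a1 b1', create=True)  # creates a1, then a1's child b1
print(root.find_child('a1 b1'))             # TreeNode(b1)
print(root.find_child(['a1', 'b1']))        # TreeNode(b1): a list also works
print(root.find_child('a1 x9'))             # None: nothing is created without create=True
root.get_child('a1').del_child('b1')
print(root.find_child('a1 b1'))             # None: b1 was removed
```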

Tree node properties

Besides the existing name, child and parent attributes, we can define extra properties for convenience.

For example:

  • path: the path from the root to the current node
class TreeNode(object):

    @property
    def path(self):
        """return path string (from root to current node)"""
        if self.parent is not None:
            # strip() removes the leading space left by a root named ''
            return ('%s %s' % (self.parent.path, self.name)).strip()
        else:
            return self.name

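NOTE: The code above uses a space as the path separator; / or . would work as well. If you switch to /, the path.split() call in find_child() must be replaced with code that splits on the new separator, and the '%s %s' format in path must use it too.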
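A sketch of the / variant, assuming a hypothetical class attribute SEP holds the separator. Both sides change: find_child() splits on SEP instead of whitespace, and path joins with SEP. Empty pieces are filtered out because, unlike the argument-less str.split(), str.split('/') keeps them ('/a1//b1'.split('/') gives ['', 'a1', '', 'b1']).

```python
class TreeNode(object):
    """Tree node whose paths use a configurable separator (sketch)"""

    SEP = '/'  # hypothetical class attribute holding the path separator

    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent
        self.child = {}

    def add_child(self, name):
        """simplified add_child(): always creates a new node"""
        obj = TreeNode(name, parent=self)
        self.child[name] = obj
        return obj

    @property
    def path(self):
        """path from the root joined with SEP; an empty parent path adds no SEP"""
        if self.parent is None:
            return self.name
        parent_path = self.parent.path
        return parent_path + self.SEP + self.name if parent_path else self.name

    def find_child(self, path, create=False):
        """find a descendant by SEP-separated path or list of names"""
        if not isinstance(path, list):
            # str.split(SEP) keeps empty pieces, so drop them to tolerate
            # leading, trailing or doubled separators
            path = [p for p in path.split(self.SEP) if p]
        cur = self
        for sub in path:
            obj = cur.child.get(sub)
            if obj is None:
                if not create:
                    return None
                obj = cur.add_child(sub)
            cur = obj
        return cur


root = TreeNode('')
f1 = root.find_child('/a1/b1/f1', create=True)
print(f1.path)                                  # a1/b1/f1
print(root.find_child('a1/b1') is f1.parent)    # True
```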

Other

To loop over the child nodes with for ... in ... (as in for name, obj in node.items()), we can implement an items() method:

class TreeNode(object):

    def items(self):
        return self.child.items()
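A small sketch of items() in a loop. The __iter__() method is an extra that is not in this article (marked as such in the code): with it, for name in node walks the child names directly, the same way iterating a dict yields its keys.

```python
class TreeNode(object):
    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent
        self.child = {}

    def add_child(self, name):
        """simplified add_child(): always creates a new node"""
        obj = TreeNode(name, parent=self)
        self.child[name] = obj
        return obj

    def items(self):
        """(name, node) pairs of the direct children, as in the article"""
        return self.child.items()

    def __iter__(self):
        """extra, not in the article: iterate over child names like a dict"""
        return iter(self.child)


a1 = TreeNode('a1')
a1.add_child('b1')
a1.add_child('b2')

for name, obj in a1.items():
    print(name, obj.parent.name)   # b1 a1, then b2 a1
print(sorted(a1))                  # ['b1', 'b2'] (uses __iter__)
```

Defining __iter__() also makes the in operator work by scanning the names, but a dedicated __contains__() is a direct dict lookup.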

To use Python's built-in in operator to check whether a child named name exists, implement __contains__():

class TreeNode(object):

    def __contains__(self, item):
        return item in self.child
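Because __contains__() simply delegates to the child dict, in only looks one level down: a grandchild is not "in" its grandparent, and nested names need find_child() instead. A short sketch with a simplified add_child():

```python
class TreeNode(object):
    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent
        self.child = {}

    def add_child(self, name):
        """simplified add_child(): always creates a new node"""
        obj = TreeNode(name, parent=self)
        self.child[name] = obj
        return obj

    def __contains__(self, item):
        # delegates to the dict, so only direct children are checked
        return item in self.child


root = TreeNode('')
a1 = root.add_child('a1')
a1.add_child('b1')

print('a1' in root)   # True
print('b1' in root)   # False: b1 is a grandchild of root
print('b1' in a1)     # True
```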

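To get the number of child nodes of the current node with the built-in len() function, all we need to do is implement __len__().

Note: if you implement __len__(), it is best to implement __bool__() as well. When Python evaluates an object in a boolean context and finds no __bool__() (__nonzero__() in Python 2), it falls back to __len__(). As a result, a node without children would evaluate to False.

Here we define __bool__() to always return True, so that a truth test on a node (if node: ...) only tells us whether the node exists, i.e. is not None.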

class TreeNode(object):

    def __len__(self):
        """return number of children node"""
        return len(self.child)

    def __bool__(self):
        """always return True for an existing node"""
        return True

    # Python 2 looks up __nonzero__() instead of __bool__()
    __nonzero__ = __bool__
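A quick sketch of the pitfall described above, using two hypothetical demo classes that each hold an empty child dict. Without __bool__(), the empty object is falsy; with __bool__() returning True it stays truthy while len() still reports 0. The __nonzero__ alias covers Python 2, which does not look up __bool__().

```python
class WithLenOnly(object):
    """demo class: defines __len__() but no __bool__()"""

    def __init__(self):
        self.child = {}

    def __len__(self):
        return len(self.child)


class WithLenAndBool(WithLenOnly):
    """same demo class, plus __bool__() that always returns True"""

    def __bool__(self):
        return True

    __nonzero__ = __bool__  # Python 2 looks up __nonzero__() instead


empty1 = WithLenOnly()
empty2 = WithLenAndBool()
print(bool(empty1))   # False: truth test falls back to len() == 0
print(bool(empty2))   # True: __bool__() wins over __len__()
print(len(empty2))    # 0
```

Where only existence matters, writing if node is not None: avoids depending on either method.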

To print out the tree structure, we can create a dump() method.

class TreeNode(object):

    def dump(self, indent=0):
        """print the tree, one node per line"""
        tab = '    ' * (indent - 1) + ' |- ' if indent > 0 else ''
        print('%s%s' % (tab, self.name))
        for name, obj in self.items():
            obj.dump(indent + 1)
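dump() prints as it goes. If the drawing is needed as a string (for logging or tests), one option is a variant that collects the lines and joins them. dumps() is a hypothetical name; the indentation rule is the same as in dump(). The expected output assumes Python 3.7+, where dicts keep insertion order.

```python
class TreeNode(object):
    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent
        self.child = {}

    def add_child(self, name):
        """simplified add_child(): always creates a new node"""
        obj = TreeNode(name, parent=self)
        self.child[name] = obj
        return obj

    def dumps(self, indent=0):
        """hypothetical variant of dump(): return the drawing as a string"""
        tab = '    ' * (indent - 1) + ' |- ' if indent > 0 else ''
        lines = [tab + self.name]
        for obj in self.child.values():
            lines.append(obj.dumps(indent + 1))
        return '\n'.join(lines)


root = TreeNode('root')
a1 = root.add_child('a1')
a1.add_child('b1')
root.add_child('a2')
print(root.dumps())
# root
#  |- a1
#      |- b1
#  |- a2
```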

To save the tree structure to a file, please refer to my separate article on serialization (to be published later).
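That article is not reproduced here. As one possible approach (not necessarily the author's), a tree can be converted to nested dicts and stored with the standard json module; to_dict() and from_dict() are hypothetical names. The parent links are left out because they would make the structure circular, and they are rebuilt when loading.

```python
import json


class TreeNode(object):
    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent
        self.child = {}

    def add_child(self, name):
        """simplified add_child(): always creates a new node"""
        obj = TreeNode(name, parent=self)
        self.child[name] = obj
        return obj

    def to_dict(self):
        """hypothetical: nested dicts; parent links are implied by nesting"""
        return {'name': self.name,
                'child': [obj.to_dict() for obj in self.child.values()]}

    @classmethod
    def from_dict(cls, data, parent=None):
        """hypothetical: rebuild nodes and their parent links from to_dict()"""
        node = cls(data['name'], parent=parent)
        for sub in data['child']:
            node.child[sub['name']] = cls.from_dict(sub, parent=node)
        return node


root = TreeNode('')
root.add_child('a1').add_child('b1')

text = json.dumps(root.to_dict())   # or json.dump(root.to_dict(), fp) for a file
loaded = TreeNode.from_dict(json.loads(text))
print(loaded.child['a1'].child['b1'].parent.name)   # a1
```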

Source code

The complete class and test code are below (they run on Python 2.7 and Python 3):

#!/usr/bin/python

from __future__ import unicode_literals  # at top of module
from __future__ import division, print_function, with_statement



class TreeNode(object):
    """The basic node of tree structure"""

    def __init__(self, name, parent=None):
        super(TreeNode, self).__init__()
        self.name = name
        self.parent = parent
        self.child = {}

    def __repr__(self):
        return 'TreeNode(%s)' % self.name


    def __contains__(self, item):
        return item in self.child


    def __len__(self):
        """return number of children node"""
        return len(self.child)

    def __bool__(self):
        """always return True for an existing node"""
        return True

    # Python 2 looks up __nonzero__() instead of __bool__()
    __nonzero__ = __bool__

    @property
    def path(self):
        """return path string (from root to current node)"""
        if self.parent is not None:
            # strip() removes the leading space left by a root named ''
            return ('%s %s' % (self.parent.path, self.name)).strip()
        else:
            return self.name

    def get_child(self, name, defval=None):
        """get a child node of current node"""
        return self.child.get(name, defval)

    def add_child(self, name, obj=None):
        """add a child node to current node"""
        if obj is not None and not isinstance(obj, TreeNode):
            raise ValueError('TreeNode can only add another TreeNode object as a child')
        if obj is None:
            obj = TreeNode(name)
        obj.parent = self
        self.child[name] = obj
        return obj

    def del_child(self, name):
        """remove a child node from current node"""
        if name in self.child:
            del self.child[name]

    def find_child(self, path, create=False):
        """find child node by path/name, return None if not found"""
        # convert path to a list if input is a string
        path = path if isinstance(path, list) else path.split()
        cur = self
        for sub in path:
            # search
            obj = cur.get_child(sub)
            if obj is None and create:
                # create new node if need
                obj = cur.add_child(sub)
            # stop if the node does not exist
            if obj is None:
                return None
            cur = obj
        # return the last node reached (self if the path is empty)
        return cur

    def items(self):
        return self.child.items()

    def dump(self, indent=0):
        """print the tree, one node per line"""
        tab = '    ' * (indent - 1) + ' |- ' if indent > 0 else ''
        print('%s%s' % (tab, self.name))
        for name, obj in self.items():
            obj.dump(indent + 1)


if __name__ == '__main__':
    # local test
    print('test add_child()')
    root = TreeNode('') # root name is ''
    a1 = root.add_child('a1')
    a1.add_child('b1')
    a1.add_child('b2')
    a2 = root.add_child('a2')
    b3 = a2.add_child('b3')
    b3.add_child('c1')
    root.dump()
    # (root)
    #  |- a1
    #      |- b1
    #      |- b2
    #  |- a2
    #      |- b3
    #          |- c1


    print('test items()')
    for name, obj in a1.items():
        print(name, obj)
    # b1 TreeNode(b1)
    # b2 TreeNode(b2)


    print('test operator "in"')
    print("b2 is a1's child = %s" % ('b2' in a1))
    # b2 is a1's child = True


    print('test del_child()')
    a1.del_child('b2')
    root.dump()
    print("b2 is a1's child = %s" % ('b2' in a1))
    # (root)
    #  |- a1
    #      |- b1
    #  |- a2
    #      |- b3
    #          |- c1
    # b2 is a1's child = False


    print('test find_child()')
    obj = root.find_child('a2 b3 c1')
    print(obj)
    # TreeNode(c1)

    print('test find_child() with create')
    obj = root.find_child('a1 b1 c2 b1 e1 f1', create=True)
    print(obj)
    root.dump()
    # TreeNode(f1)
    # (root)
    #  |- a1
    #      |- b1
    #          |- c2
    #              |- b1
    #                  |- e1
    #                      |- f1
    #  |- a2
    #      |- b3
    #          |- c1

    print('test attr path')
    print(obj.path)
    # a1 b1 c2 b1 e1 f1
