平衡二叉树(AVL)C++实现

2022-02-25 08:38:50 浏览数 (1)

代码语言:javascript复制
#include <bits/stdc  .h>
using namespace std;



class AVL {
private:
  struct node {
    int val;
    int height;
    node *left, *right; 

    node (int x) {
      this->val = x;
      this->height = 1;
      this->left = this->right = nullptr;
    }
  };


  node *root;
  int size;

  void updateHeight(node *root) {
    root->height = max(getHeight(root->left), getHeight(root->right))   1;
  }

  int getHeight(node *root) {
    return !root?0:root->height;
  }

  int getBalanceFactor(node* root) {
    return !root?0:(getHeight(root->left) - getHeight(root->right));
  }

  node* add(node *root, int x) {
    if (!root) {
      size   ;
      return new node(x);
    }

    if (x < root->val) {
      root->left = add(root->left, x);
    } else if (x > root->val) {
      root->right = add(root->right, x);
    } else {
      root->val = x;
    }


    // 更新高度
    updateHeight(root);

    int balanceFactor = getBalanceFactor(root);
    // 维护平衡
    // LL
    if (balanceFactor==2 && getBalanceFactor(root->left)>=0) {
      return rightRotate(root);
    }
    // LR
    //                 y
    //                / 
    //               x   t4
    //              /         leftRotate(x)   rightRotate(y)
    //             t1  z
    //                / 
    //               t2 t3
    if (balanceFactor==2 && getBalanceFactor(root->left)<0) {
      root->left = leftRotate(root->left);
      return rightRotate(root);
    }
    // RR
    if (balanceFactor==-2 && getBalanceFactor(root->right)<=0) {
      return leftRotate(root);
    }
    // RL
    //                y
    //               / 
    //              t1  x 
    //                 /        rightRotate(x)  leftRotate(y)
    //                z   t4 
    //               /    
    //              t2  t3
    if (balanceFactor==-2 && getBalanceFactor(root->right)>0) {
      root->right = rightRotate(root->right);
      return leftRotate(root);
    }
    return root;
  }

public:
  AVL() {
    this->root = nullptr;
    this->size = 0;
  }


  ~AVL() {
    destroy(root);
  }


  // 对节点y进行向左旋转操作,返回旋转后新的根节点x
  //    y                             x
  //  /                            /   
  // T1   x      向左旋转 (y)       y     z
  //     /    - - - - - - - ->   /    / 
  //   T2   z                    T1 T2 T3 T4
  //       / 
  //      T3 T4
  node* leftRotate(node *y) {
    node *x = y->right;
    node *t2 = x->left;
    x->left = y;
    y->right = t2;
    updateHeight(y);
    updateHeight(x);
    return x;
  }

  // 对节点y进行向右旋转操作,返回旋转后新的根节点x
  //        y                              x
  //       /                            /   
  //      x   T4     向右旋转 (y)        z     y
  //     /        - - - - - - - ->    /    / 
  //    z   T3                       T1  T2 T3 T4
  //   / 
  // T1   T2
  node* rightRotate(node *y) {
    node *x = y->left;
    node *t3 = x->right;
    x->right = y;
    y->left = t3;
    updateHeight(y);
    updateHeight(x);
    return x;
  }

  void add(int x) {
    root = add(root, x);
  }


  void destroy(node *root) {
    if (root) {
      destroy(root->left);
      destroy(root->right);
      delete root;
    }
  }

  void bfs() {
    if (!root) return;
    queue<node*> q;
    q.push(root);
    while (!q.empty()) {
      node *f = q.front();q.pop();
      cout << f->val << " ";
      if (f->left) q.push(f->left);
      if (f->right) q.push(f->right);
    }
    cout << endl;
  }

  bool isBST() {
    // 判断中序遍历序列
    vector<int> res;
    inOrder(root, res);
    for (int i=1; i<res.size();   i) {
      if (res[i-1] >= res[i]) return false;
    }
    return true;
  }

  bool isBalanced() {
    return isBalanced(root);
  }

  bool isBalanced(node *root) {
    if (!root) return 1;
    if (abs(getHeight(root->left) - getHeight(root->right)) > 1) return false;
    return isBalanced(root->left) && isBalanced(root->right);
  }


  void inOrder(node *root, vector<int> &res) {
    if (root) {
      inOrder(root->left, res);
      res.push_back(root->val);
      inOrder(root->right, res);
    }
  }
};


int main() {
  int a[] = {1,2,3,4,5};
  int n = sizeof(a)/sizeof(int);
  AVL avl;
  for (int i=0; i<n;   i) avl.add(a[i]);
  avl.bfs();
  cout << avl.isBalanced() << endl;
  return 0;
}

0 人点赞