文章目录
- 1. 简介
- 2. 单点更新,区间查询
- 3. 区间更新,单点查询
- 4. 区间更新,区间查询
- 5. 区间最值模板
- 6. 参考
有这样一类问题,给定一个数列,让你求某段区间内和。如果对某个值或某段区间内的值进行修改后,如何快速的求和。如果线性执行更新操作或求和操作,无疑时间复杂度太大了。 那么借助分治的思想,在执行更新区间的操作时,把它转化为几段区间的更新,同样求和操作时,也通过维护分段区间的和来达到快速求区间和的问题。线段树就是利用二叉树这种数据结构,来维护区间信息的一种数据结构。
简介
二叉树的每个结点,都代表一段区间。考虑到二叉树的结构,他的根结点就维护从1~n这段区间的信息,根结点的左子树维护1~mid这段区间,右子树维护mid 1~n这段区间,以此递归向下。
一般每个结点需要维护区间修改的信息,以及区间和的信息。
二叉树的叶子结点(从左到右)储存数列的1~n。 修改操作分为两类,一种是在区间的原数值基础上进行修改:加或减去val、乘以val、开根号、、、等;一种是将该区间的值改为val;不同的操作在维护区间和时,相应的有些变化。下面以区间和问题为例,对线段树的实现进行讲解。 如果实现线段树一般需要以下几种操作:
代码语言:javascript复制build(start,end,vals) //o(n)
update(index,value) //o(logn)
rangeQuery(start,end) //o(logn k)
另外线段树可以用结构体指针来索引左右孩子,也可以用数组来存储(申请的长度至少要4n),本文选用前者。
单点更新,区间查询
307.Range Sum Query - Mutable 如果做过一些二叉树递归类的题,这个应该就挺好理解了。 几年前我尝试学习线段树的时候,感觉好难。后来刷了一些二叉树类的题,现在再来学习线段树,发现还是挺好理解的。所以如果有些算法学起来困难,可能是前置知识的掌握还不到位。 二叉树的每个结点需要用start、end存储线段起止号,sum存储该段区间的和,另外left、right索引左右子树。 建树过程用buildTree()递归创建就好了,从根节点开始创建,终止条件是线段的start==end(到达叶子节点了,从左到右看就是原数列)。 单点更新:由于是单点更新,所以一定会从根节点往下找,直到相应的叶子节点。然后更新叶子节点。最后还要在回溯的过程中更新每一个包涵该点的线段。 区间查询:对于要查询的区间,如果都被包涵在左子树,就去左子树查询;如果被包涵在右子树,就去右子树查询;如果要查询的区间在左右子树标示的线段中都有一部分,那就分别将左右子树查询的结果加起来。
代码语言:javascript复制//线段树是利用二分思想解决区间问题
class SegmentTreeNode{
public:
SegmentTreeNode(int start,int end,int sum,
SegmentTreeNode *left=nullptr,SegmentTreeNode *right=nullptr):
start(start),end(end),sum(sum),left(left),right(right) {}
//禁用赋值构造和拷贝构造函数
SegmentTreeNode(const SegmentTreeNode&)=delete;
SegmentTreeNode& operator=(const SegmentTreeNode&)=delete;
~SegmentTreeNode(){
delete left;
delete right;
left=right=nullptr;
}
public:
int start;
int end;
int sum; //可以是max,min
SegmentTreeNode *left;
SegmentTreeNode *right;
}; //end class SegmentTreeNode
class NumArray {
public:
NumArray(vector<int>& nums) {
nums_.swap(nums);
if(!nums_.empty()){
root_.reset(buildTree(0,nums_.size()-1));
}
}
void update(int i, int val) {
updateTree(root_.get(),i,val-nums_[i]);
}
int sumRange(int i, int j) {
return sumRange(root_.get(),i,j);
}
private:
//创建线段树
SegmentTreeNode *buildTree(int start,int end){
if(start==end){
return new SegmentTreeNode(start,end,nums_[start]);
}
int mid=start ((end-start)>>1);
SegmentTreeNode *left=buildTree(start,mid);
SegmentTreeNode *right=buildTree(mid 1,end);
return new SegmentTreeNode(start,end,left->sum right->sum,left,right);
}
//更新线段树,将i处的值增加addval
void updateTree(SegmentTreeNode *root,int i,int addval){
if(root->start==i && root->end==i){
root->sum =addval;
nums_[i] =addval;
return ;
}
int mid=root->start ((root->end-root->start)>>1);
if(i<=mid){
updateTree(root->left,i,addval);
}else{
updateTree(root->right,i,addval);
}
root->sum =addval;
}
//计算区间i到j的和
int sumRange(SegmentTreeNode *root,int i,int j){
if(root->start==i && root->end==j){
return root->sum;
}
int mid=root->start ((root->end-root->start)>>1);
if(i>mid){
return sumRange(root->right,i,j);
}else if(j<=mid){
return sumRange(root->left,i,j);
}else{
return sumRange(root->left,i,mid) sumRange(root->right,mid 1,j);
}
}
/* 打印叶子节点,用于调试
void printTree(SegmentTreeNode *root){
if(root->left==nullptr && root->right==nullptr){
cout<<root->sum<<" ";
return ;
}
printTree(root->left);
printTree(root->right);
}
*/
private:
vector<int> nums_;
std::unique_ptr<SegmentTreeNode> root_;
}; //end class NumArray
区间更新,单点查询
hdu 1556 Color the ball 对于这类问题,算法的思想是在区间更新的时候不用全部实施到该区间的每个点上,只将该区间分为几部分,然后实施到分开的几个区间上就好。等到单点查询的时候将单点的值加上所有对该点的更新就好。 由于对区间进行更新,所以二叉树每个节点上需要多一个updateval来维护对区间的更新。 区间更新函数,跟上一类问题中的区间查询有点相似。 单点更新:从根节点向下找到目标点,然后在回溯的时候直接加上每个每个包涵该点的区间维护的updateval。
代码语言:javascript复制#include <bits/stdc .h>
using namespace std;
class SegmentTreeNode{
public:
SegmentTreeNode(int start,int end,int sum,int val=0,SegmentTreeNode *left=nullptr,SegmentTreeNode *right=nullptr):
start(start),end(end),sum(sum),updateval(val),left(left),right(right) {}
//禁用赋值构造和拷贝构造函数
SegmentTreeNode(const SegmentTreeNode&)=delete;
SegmentTreeNode& operator=(const SegmentTreeNode&)=delete;
~SegmentTreeNode(){
delete left;
delete right;
left=right=nullptr;
}
public:
int start;
int end;
int sum; //可以是max,min
int updateval; //用来记录当前区间上update过的数值
SegmentTreeNode *left;
SegmentTreeNode *right;
}; //end class SegmentTreeNode
class NumArray {
public:
NumArray(vector<int>& nums) {
nums_.swap(nums);
if(!nums_.empty()){
root_.reset(buildTree(0,nums_.size()-1));
}
}
void update(int s, int e, int val) {
updateTree(root_.get(),s,e,val);
}
int query(int i) {
return queryTree(root_.get(),i);
}
private:
//创建线段树
SegmentTreeNode *buildTree(int start,int end){
if(start==end){
return new SegmentTreeNode(start,end,nums_[start]);
}
int mid=start ((end-start)>>1);
SegmentTreeNode *left=buildTree(start,mid);
SegmentTreeNode *right=buildTree(mid 1,end);
return new SegmentTreeNode(start,end,left->sum right->sum,0,left,right);
}
//区间更新线段树,将区间s~e处的值增加addval
void updateTree(SegmentTreeNode *root,int s,int e,int val){
if(root->start==s && root->end==e){
root->updateval =val;
return ;
}
int mid=root->start ((root->end-root->start)>>1);
if(s>mid){
updateTree(root->right,s,e,val);
}else if(e<=mid){
updateTree(root->left,s,e,val);
}else{
updateTree(root->left,s,mid,val);
updateTree(root->right,mid 1,e,val);
}
}
//单点查询
int queryTree(SegmentTreeNode *root,int i){
if(root->start==i && root->end==i){
return root->sum root->updateval;
}
int mid=root->start ((root->end-root->start)>>1);
if(i<=mid){
return queryTree(root->left,i) root->updateval;
}else{
return queryTree(root->right,i) root->updateval;
}
}
private:
vector<int> nums_;
std::unique_ptr<SegmentTreeNode> root_;
}; //end class NumArray
int main()
{
std::ios::sync_with_stdio(0);
int N;
int a,b;
while(cin>>N){
if(N==0) break;
vector<int> tmp(N 1,0);
NumArray numarry(tmp);
for(int i=0;i<N;i ){
cin>>a>>b;
numarry.update(a,b,1);
}
if(N==1){
cout<<numarry.query(1);
return 0;
}
for(int i=0;i<N;i ){
cout<<numarry.query(i 1);
if(i!=N-1){
cout<<" ";
}else{
cout<<endl;
}
}
}
return 0;
}
区间更新,区间查询
洛谷oj:P3372【模板】线段树1
以下有两个版本,第一个是pushdown版本。 添加pushdown()后,如果一个数列1~8, 第一次更新1~4,就先将该操作实施到根节点的左孩子上就可以了(有的实现专门用个lazyflag标记,其实不用,如果updateval不为0,则说明lazyflag为1),然后更新根结点的sum。 如果第二次再更新3~4,在向下寻找线段3~4的过程中,要将之前的更新操作往下落实。于是就将1~4上的updateval清零,然后将该更新操作往下分别实施到1~2和3~4上。将寻找3~4的路径上的更新操作都落实到3~4上之后,再执行3~4的更新操作。然后回溯的过程中更新每个结点上的sum。 在查询的时候,如果查询3~3区间,也是需要依次pushdown(),将之前的区间更新落实到3~3区间上,然后返回区间3~3那个结点的sum就可以了。
代码语言:javascript复制#include <bits/stdc .h>
using namespace std;
class SegmentTreeNode{
public:
SegmentTreeNode(int start,int end,long long sum,long long val=0,SegmentTreeNode *left=nullptr,SegmentTreeNode *right=nullptr):
start(start),end(end),sum(sum),updateval(val),left(left),right(right) {}
//禁用赋值构造和拷贝构造函数
SegmentTreeNode(const SegmentTreeNode&)=delete;
SegmentTreeNode& operator=(const SegmentTreeNode&)=delete;
~SegmentTreeNode(){
delete left;
delete right;
left=right=nullptr;
}
public:
int start;
int end;
long long sum; //可以是max,min
long long updateval; //用来记录当前区间上update过的数值
SegmentTreeNode *left;
SegmentTreeNode *right;
}; //end class SegmentTreeNode
class NumArray {
public:
NumArray(vector<long long>& nums) {
nums_.swap(nums);
if(!nums_.empty()){
root_.reset(buildTree(0,nums_.size()-1));
}
}
void update(int s, int e, int val) {
updateTree(root_.get(),s,e,val);
}
long long query(int s,int e) {
return queryTree(root_.get(),s,e);
}
private:
//创建线段树
SegmentTreeNode *buildTree(int start,int end){
if(start==end){
return new SegmentTreeNode(start,end,nums_[start]);
}
int mid=start ((end-start)>>1);
SegmentTreeNode *left=buildTree(start,mid);
SegmentTreeNode *right=buildTree(mid 1,end);
return new SegmentTreeNode(start,end,left->sum right->sum,0,left,right);
}
//区间更新线段树,将区间s~e处的值增加addval
void updateTree(SegmentTreeNode *root,int s,int e,int val){
if(root->start==s && root->end==e){
root->sum =val*(e-s 1);
root->updateval =val;
return ;
}
pushdown(root);
int mid=root->start ((root->end-root->start)>>1);
if(s>mid){
updateTree(root->right,s,e,val);
}else if(e<=mid){
updateTree(root->left,s,e,val);
}else{
updateTree(root->left,s,mid,val);
updateTree(root->right,mid 1,e,val);
}
root->sum=root->left->sum root->right->sum;
}
//区间查询
long long queryTree(SegmentTreeNode *root,int s,int e){
if(root->start==s && root->end==e){
return root->sum;
}
pushdown(root);
int mid=root->start ((root->end-root->start)>>1);
if(e<=mid){
return queryTree(root->left,s,e);
}else if(s>mid){
return queryTree(root->right,s,e);
}else{
return queryTree(root->left,s,mid) queryTree(root->right,mid 1,e);
}
}
void pushdown(SegmentTreeNode *root){
if(root->updateval){
root->left->updateval =root->updateval;
root->right->updateval =root->updateval;
int mid=root->start ((root->end-root->start)>>1);
root->left->sum =root->updateval*(mid-root->start 1);
root->right->sum =root->updateval*(root->end-mid);
root->updateval=0;
}
}
private:
vector<long long> nums_;
std::unique_ptr<SegmentTreeNode> root_;
}; //end class NumArray
int main()
{
std::ios::sync_with_stdio(0);
long long n,m;
long long tmp,oper,x,y,k;
vector<long long> vi;
cin>>n>>m;
vi.resize(n 1);
for(int i=1;i<=n;i ){
cin>>vi[i];
}
NumArray numarry(vi);
for(int i=0;i<m;i ){
cin>>oper;
if(oper==1){
cin>>x>>y>>k;
numarry.update(x,y,k);
}else{
cin>>x>>y;
cout<<numarry.query(x,y)<<endl;
}
}
return 0;
}
标记永久化版本,去掉了pushdown函数,比上一版本有一常数优化。 pushdown版本的是每一次更新区间时,都顺带着将之前的更新向下落实。但是我们其实可以采取”区间更新,单点查询”时的做法,每次更新时实施到相应区间上,不用落实到最下面。然后在每次查询完,回溯的时候,把每个区间上的更新都加上。
代码语言:javascript复制#include <bits/stdc .h>
using namespace std;
class SegmentTreeNode{
public:
SegmentTreeNode(int start,int end,long long sum,long long val=0,SegmentTreeNode *left=nullptr,SegmentTreeNode *right=nullptr):
start(start),end(end),sum(sum),updateval(val),left(left),right(right) {}
//禁用赋值构造和拷贝构造函数
SegmentTreeNode(const SegmentTreeNode&)=delete;
SegmentTreeNode& operator=(const SegmentTreeNode&)=delete;
~SegmentTreeNode(){
delete left;
delete right;
left=right=nullptr;
}
public:
int start;
int end;
long long sum; //可以是max,min
long long updateval; //用来记录当前区间上update过的数值
SegmentTreeNode *left;
SegmentTreeNode *right;
}; //end class SegmentTreeNode
class NumArray {
public:
NumArray(vector<long long>& nums) {
nums_.swap(nums);
if(!nums_.empty()){
root_.reset(buildTree(0,nums_.size()-1));
}
}
void update(int s, int e, int val) {
updateTree(root_.get(),s,e,val);
}
long long query(int s,int e) {
return queryTree(root_.get(),s,e);
}
private:
//创建线段树
SegmentTreeNode *buildTree(int start,int end){
if(start==end){
return new SegmentTreeNode(start,end,nums_[start]);
}
int mid=start ((end-start)>>1);
SegmentTreeNode *left=buildTree(start,mid);
SegmentTreeNode *right=buildTree(mid 1,end);
return new SegmentTreeNode(start,end,left->sum right->sum,0,left,right);
}
//区间更新线段树,将区间s~e处的值增加addval
void updateTree(SegmentTreeNode *root,int s,int e,int val){
root->sum =val*(e-s 1); //每次调用该函数,只有整棵线段树的根节点到目标结点的sum值会被更新
if(root->start==s && root->end==e){
root->updateval =val;
return ;
}
int mid=root->start ((root->end-root->start)>>1);
if(s>mid){
updateTree(root->right,s,e,val);
}else if(e<=mid){
updateTree(root->left,s,e,val);
}else{
updateTree(root->left,s,mid,val);
updateTree(root->right,mid 1,e,val);
}
}
//区间查询
long long queryTree(SegmentTreeNode *root,int s,int e){
if(root->start==s && root->end==e){
return root->sum;
}
int mid=root->start ((root->end-root->start)>>1);
if(e<=mid){
return queryTree(root->left,s,e) root->updateval*(e-s 1);
}else if(s>mid){
return queryTree(root->right,s,e) root->updateval*(e-s 1);
}else{
return queryTree(root->left,s,mid) queryTree(root->right,mid 1,e) root->updateval*(e-s 1);
}
}
private:
vector<long long> nums_;
std::unique_ptr<SegmentTreeNode> root_;
}; //end class NumArray
int main(){
std::ios::sync_with_stdio(0);
long long n,m;
long long tmp,oper,x,y,k;
vector<long long> vi;
cin>>n>>m;
vi.resize(n 1);
for(int i=1;i<=n;i ){
cin>>vi[i];
}
NumArray numarry(vi);
for(int i=0;i<m;i ){
cin>>oper;
if(oper==1){
cin>>x>>y>>k;
numarry.update(x,y,k);
}else{
cin>>x>>y;
cout<<numarry.query(x,y)<<endl;
}
}
return 0;
}
区间最值模板
代码语言:javascript复制class SegmentTreeNode2{
public:
SegmentTreeNode2(int start,int end,int max,int min,
SegmentTreeNode2 *left=nullptr,SegmentTreeNode2 *right=nullptr):
start(start),end(end),maxx(max),minn(min),left(left),right(right) {}
//禁用赋值构造和拷贝构造函数
SegmentTreeNode2(const SegmentTreeNode2&)=delete;
SegmentTreeNode2& operator=(const SegmentTreeNode2&)=delete;
~SegmentTreeNode2(){
delete left;
delete right;
left=right=nullptr;
}
public:
int start;
int end;
int maxx;
int minn;
SegmentTreeNode2 *left;
SegmentTreeNode2 *right;
}; //end class SegmentTreeNode2
class NumArray {
public:
NumArray(vector<int>& nums) {
nums_.swap(nums);
if(!nums_.empty()){
root_.reset(buildTree(0,nums_.size()-1));
}
}
int getMax(int i, int j) {
return getMax(root_.get(),i,j);
}
int getMin(int i,int j){
return getMin(root_.get(),i,j);
}
private:
//创建线段树
SegmentTreeNode2 *buildTree(int start,int end){
if(start==end){
return new SegmentTreeNode2(start,end,nums_[start],nums_[start]);
}
int mid=start ((end-start)>>1);
SegmentTreeNode2 *left=buildTree(start,mid);
SegmentTreeNode2 *right=buildTree(mid 1,end);
return new SegmentTreeNode2(start,end,max(left->maxx,right->maxx),min(left->minn,right->minn),left,right);
}
int getMax(SegmentTreeNode2 *root,int i,int j){
if(root->start==i && root->end==j){
return root->maxx;
}
int mid=root->start ((root->end-root->start)>>1);
if(i>mid){
return getMax(root->right,i,j);
}else if(j<=mid){
return getMax(root->left,i,j);
}else{
return max(getMax(root->left,i,mid),getMax(root->right,mid 1,j));
}
}
int getMin(SegmentTreeNode2 *root,int i,int j){
if(root->start==i && root->end==j){
return root->minn;
}
int mid=root->start ((root->end-root->start)>>1);
if(i>mid){
return getMin(root->right,i,j);
}else if(j<=mid){
return getMin(root->left,i,j);
}else{
return min(getMin(root->left,i,mid),getMin(root->right,mid 1,j));
}
}
private:
vector<int> nums_;
std::unique_ptr<SegmentTreeNode2> root_;
}; //end class NumArray
class Solution {
public:
/**
* @param num: array of num
* @param ask: Interval pairs
* @return: return the sum of xor
*/
int Intervalxor(vector<int> &num, vector<vector<int>> &ask) {
// write your code here
NumArray na(num);
int res=na.getMax(ask[0][0]-1,ask[0][1]-1) na.getMin(ask[0][2]-1,ask[0][3]-1);
for(int i=1;i<ask.size();i ){
res^=(na.getMax(ask[i][0]-1,ask[i][1]-1) na.getMin(ask[i][2]-1,ask[i][3]-1));
}
return res;
}
};
参考
- 花花酱 LeetCode Segment Tree - 刷题找工作 SP14
- 线段树 从入门到进阶
- 线段树标记永久化 学习笔记【线段树】
- 使用线段树实现简单的内存管理
- 线段树详解