线段树模板

2023-07-25 19:50:39 浏览数 (2)

概述:

线段树是算法竞赛中常用的数据结构(虽然考场中很少用,毕竟调起来麻烦,区间求和用树状树组还是更加方便代码也短)。

线段树可以在O(logN)的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。简略的描述一下算法思路,线段树是一个二叉树,树的每一个节点存储的都是一个区间内的值(根据具体的题目而定),每个父结点的值由两个子结点的值决定。

但是普通的二分思想并不能体现线段树的精髓所在,线段树的精髓就在于它的懒标记,具体往下看。

算法的实现:

//建议初学者先看无懒标记版,在最下面。

这里以洛谷P3372的区间求和为例

个人习惯:
代码语言:javascript复制
#define pl tr<<1 //左儿子
#define pr tr<<1|1 //右儿子
建树(build)
代码语言:javascript复制
struct segmentTree{
	int l,r; //查询的区间范围
	long long sum ,lz; //区间和,懒标记
}t[N<<2];//要开4*N的大小

void build(int l,int r,int tr){
	t[tr].l=l;t[tr].r=r;
	if(l==r) {t[tr].sum=a[l];return;} //如果区间内只有一个树,则赋值,返回
	int mid=(l r)>>1;
	build(l,mid,pl); //建左区间
	build(mid 1,r,pr); //建右区间
	pushup(tr); //关键操作,通过最下层来更新到上层
}
上放(pushup)
代码语言:javascript复制
void pushup(int tr){
	t[tr].sum=t[pl].sum t[pr].sum; //由两个子结点的值更新父结点的值
}
下放(pushdown)

懒标记解释:带有懒标记的值是已经处理完成的确认的值。

代码语言:javascript复制
void pushdown(int tr){
	if(t[tr].lz){
		t[pl].sum =t[tr].lz*(t[pl].r-t[pl].l 1);//左儿子的值加上懒标记的值*区间内数的个数
		t[pr].sum =t[tr].lz*(t[pr].r-t[pr].l 1);//右儿子的值加上懒标记的值*区间内树的个数
		t[pl].lz =t[tr].lz;//懒标记下放
		t[pr].lz =t[tr].lz;//懒标记下放
		t[tr].lz=0;//将父结点的懒标记清零
	}
}
更新(update)

update中的pushup()是我当时学习该算法时的没理解的一个地方,并不是直接更新每个结点的值,而是最后通过pushup()来更新父结点

代码语言:javascript复制
void update(int l,int r,int tr,int num){
	if(l<=t[tr].l&&t[tr].r<=r) {t[tr].sum =num*(t[tr].r-t[tr].l 1);t[tr].lz =num;return;}
	pushdown(tr);//上一行是指如果该区间在查询区间内,则更新该区间值即懒标记,并且返回。(因为有懒标记),如果不包含则懒标记下放
	int mid=(t[tr].l t[tr].r)>>1;//二分
	if(l<=mid) update(l,r,pl,num); //如果左儿子一部分在查询区间内,更新左儿子
	if(mid<r) update(l,r,pr,num); //如果右儿子一部分在查询区间内,更新右儿子
	pushup(tr);//关键的一步
}
查询(query)
代码语言:javascript复制
long long query(int l,int r,int tr){
	long long ans=0;
	if(l<=t[tr].l&&t[tr].r<=r) return t[tr].sum; 
	pushdown(tr);
	int mid=(t[tr].l t[tr].r)>>1;
	if(l<=mid) ans =query(l,r,pl);
	if(mid<r) ans =query(l,r,pr);
	return ans;
}

例题与示例程序:

1.区间求和

洛谷P3372

代码语言:javascript复制
#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
#define pl tr<<1
#define pr tr<<1|1
using namespace std;
const int N=1e5 10;
int n,m,a[100010],x,y,k,q;
struct segmentTree{
	int l,r,lz;
	long long sum;
}t[N<<2];
void pushup(int tr){
	t[tr].sum=t[pl].sum t[pr].sum;
}
void pushdown(int tr){
	if(t[tr].lz){
		t[pl].sum =t[tr].lz*(t[pl].r-t[pl].l 1);
		t[pr].sum =t[tr].lz*(t[pr].r-t[pr].l 1);
		t[pl].lz =t[tr].lz;
		t[pr].lz =t[tr].lz;
		t[tr].lz=0;
	}
}
void build(int l,int r,int tr){
	t[tr].l=l,t[tr].r=r;
	if(l==r){t[tr].sum=a[r];return;}
	int mid=(l r)>>1;
	build(l,mid,pl);
	build(mid 1,r,pr);
	pushup(tr);
}
void update(int l,int r,int tr,int num){
	if(l<=t[tr].l&&t[tr].r<=r) {t[tr].sum =num*(t[tr].r-t[tr].l 1);t[tr].lz =num;return;}
	pushdown(tr);
	int mid=(t[tr].r t[tr].l)>>1;
	if(l<=mid)update(l,r,pl,num);
	if(mid<r)update(l,r,pr,num);
	pushup(tr);
}
long long query(int l,int r,int tr){
	long long ans=0;
	if(l<=t[tr].l&&t[tr].r<=r) return t[tr].sum;
	pushdown(tr);
	int mid=(t[tr].r t[tr].l)>>1;
	if(l<=mid) ans =query(l,r,pl);
	if(mid<r) ans =query(l,r,pr);
	return ans;
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i  )scanf("%d",&a[i]);
		build(1,n,1);
	for(int i=1;i<=m;i  ){
		scanf("%d%d%d",&q,&x,&y);
		if(q==1){
			scanf("%d",&k);
			update(x,y,1,k);
		}
		else{
			printf("%lldn",query(x,y,1));
		}
	}
    return 0;
}
2.区间求乘积

洛谷P3373

代码语言:javascript复制
#include <iostream>
#include <stdio.h>
#include <algorithm>
#define pl tr<<1
#define pr tr<<1|1

using namespace std;
const int N=1e5 10;

int n,m,p,x,y,k,q;
int a[N];
struct segmentTree{
	int l,r;
	long long sum,add=0,mul=1;//add=加,mul=乘
}t[N<<2];
void pushup(int tr){
	t[tr].sum=(t[pl].sum t[pr].sum)%p;
}
void pushdown(int tr){
	t[pl].sum=(t[tr].add*(t[pl].r-t[pl].l 1)%p (t[pl].sum*t[tr].mul)%p)%p;
	t[pr].sum=(t[tr].add*(t[pr].r-t[pr].l 1)%p (t[pr].sum*t[tr].mul)%p)%p;
	t[pl].add=(t[tr].mul*t[pl].add%p t[tr].add)%p;
	t[pl].mul=t[tr].mul*t[pl].mul%p;
	t[pr].add=(t[tr].mul*t[pr].add%p t[tr].add)%p;
	t[pr].mul=t[tr].mul*t[pr].mul%p;
	t[tr].add=0;t[tr].mul=1;
}
void build(int l,int r,int tr){
	t[tr].l=l;t[tr].r=r;
	if(l==r) {t[tr].sum=a[l];return;}
	else{
		int mid=(l r)>>1;
		build(l,mid,pl);
		build(mid 1,r,pr);
		pushup(tr);
	}
}
void update1(int l,int r,int tr,int k){//add
	if(l<=t[tr].l&&t[tr].r<=r){
		t[tr].sum=(t[tr].sum k*(t[tr].r-t[tr].l 1)%p)%p;
		t[tr].add=(t[tr].add k%p)%p;
		return;
	}
	pushdown(tr);
	int mid=(t[tr].l t[tr].r)>>1;
	if(l<=mid) update1(l,r,pl,k);
	if(mid<r) update1(l,r,pr,k);
	pushup(tr);
}
void update2(int l,int r,int tr,int k){//mul
	if(l<=t[tr].l&&t[tr].r<=r){
		t[tr].sum=(t[tr].sum*k)%p;
		t[tr].add=(t[tr].add*k)%p;
		t[tr].mul=(t[tr].mul*k)%p;
		return;
	}
	pushdown(tr);
	int mid=(t[tr].l t[tr].r)>>1;
	if(l<=mid) update2(l,r,pl,k);
	if(mid<r) update2(l,r,pr,k);
	pushup(tr);
}
long long query(int l,int r,int tr){
	long long ans=0;
	if(l<=t[tr].l&&t[tr].r<=r) return t[tr].sum;
	int mid=(t[tr].l t[tr].r)>>1;
	pushdown(tr);
	if(l<=mid) ans =query(l,r,pl);
	if(mid<r) ans =query(l,r,pr);
	return ans%p;
}
int main(){
	scanf("%d%d%d",&n,&m,&p);
	for(int i=1;i<=n;i  ) scanf("%d",&a[i]);
	build(1,n,1);
	for(int i=1;i<=m;i  ){
		scanf("%d%d%d",&q,&x,&y);
		if(q==1){
			scanf("%d",&k);
			update2(x,y,1,k);
		}
		else if(q==2){
			scanf("%d",&k);
			update1(x,y,1,k);
		}
		else {
			printf("%lldn",query(x,y,1));
		}
	}
    return 0;
}

无懒标记版本:

代码语言:c 复制
#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>

using namespace std;
const int N=1e5 10;
int n,m,q,x,y,k;
int a[N];
struct segmenttree{
	int l,r,sum;
}t[N<<2];
void pushup(int tr){
	t[tr].sum=t[tr<<1].sum t[tr<<1|1].sum;
}
void build(int l,int r,int tr){
	t[tr].l=l;t[tr].r=r;
	if(l==r) {t[tr].sum=a[r];return;}
	int mid=(l r)>>1;
	build(l,mid,tr<<1);
	build(mid 1,r,tr<<1|1);
	pushup(tr);
}
void update(int l,int r,int tr,int num){
	int mid=(t[tr].l t[tr].r)>>1;
	if(t[tr].l==t[tr].r) {
		t[tr].sum =num;return;
	}
	if(l<=mid)update(l,r,tr<<1,num);
	if(mid<r)update(l,r,tr<<1|1,num);
	pushup(tr);
}
int query(int l,int r,int tr){
	int ans=0;
	if(t[tr].l>=l&&t[tr].r<=r) {return t[tr].sum;}
	int mid=(t[tr].l t[tr].r)>>1;
	if(l<=mid) ans =query(l,r,tr<<1);
	if(mid<r) ans =query(l,r,tr<<1|1);
	return ans;
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i  ){
		scanf("%d",&a[i]);
	}
	build(1,n,1);
	for(int i=1;i<=m;i  ){
		scanf("%d",&q);
		if(q==1){
			scanf("%d%d%d",&x,&y,&k);
			update(x,y,1,k);
		}
		else {
			scanf("%d%d",&x,&y);
			cout<<query(x,y,1)<<endl;
		}
	}
}

0 人点赞