概述:
线段树是算法竞赛中常用的数据结构(虽然考场中很少用,毕竟调起来麻烦,区间求和用树状树组还是更加方便代码也短)。
线段树可以在O(logN)的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。简略的描述一下算法思路,线段树是一个二叉树,树的每一个节点存储的都是一个区间内的值(根据具体的题目而定),每个父结点的值由两个子结点的值决定。
但是普通的二分思想并不能体现线段树的精髓所在,线段树的精髓就在于它的懒标记,具体往下看。
算法的实现:
//建议初学者先看无懒标记版,在最下面。
这里以洛谷P3372的区间求和为例
个人习惯:
代码语言:javascript复制#define pl tr<<1 //左儿子
#define pr tr<<1|1 //右儿子
建树(build)
代码语言:javascript复制struct segmentTree{
int l,r; //查询的区间范围
long long sum ,lz; //区间和,懒标记
}t[N<<2];//要开4*N的大小
void build(int l,int r,int tr){
t[tr].l=l;t[tr].r=r;
if(l==r) {t[tr].sum=a[l];return;} //如果区间内只有一个树,则赋值,返回
int mid=(l r)>>1;
build(l,mid,pl); //建左区间
build(mid 1,r,pr); //建右区间
pushup(tr); //关键操作,通过最下层来更新到上层
}
上放(pushup)
代码语言:javascript复制void pushup(int tr){
t[tr].sum=t[pl].sum t[pr].sum; //由两个子结点的值更新父结点的值
}
下放(pushdown)
懒标记解释:带有懒标记的值是已经处理完成的确认的值。
代码语言:javascript复制void pushdown(int tr){
if(t[tr].lz){
t[pl].sum =t[tr].lz*(t[pl].r-t[pl].l 1);//左儿子的值加上懒标记的值*区间内数的个数
t[pr].sum =t[tr].lz*(t[pr].r-t[pr].l 1);//右儿子的值加上懒标记的值*区间内树的个数
t[pl].lz =t[tr].lz;//懒标记下放
t[pr].lz =t[tr].lz;//懒标记下放
t[tr].lz=0;//将父结点的懒标记清零
}
}
更新(update)
update中的pushup()是我当时学习该算法时的没理解的一个地方,并不是直接更新每个结点的值,而是最后通过pushup()来更新父结点
代码语言:javascript复制void update(int l,int r,int tr,int num){
if(l<=t[tr].l&&t[tr].r<=r) {t[tr].sum =num*(t[tr].r-t[tr].l 1);t[tr].lz =num;return;}
pushdown(tr);//上一行是指如果该区间在查询区间内,则更新该区间值即懒标记,并且返回。(因为有懒标记),如果不包含则懒标记下放
int mid=(t[tr].l t[tr].r)>>1;//二分
if(l<=mid) update(l,r,pl,num); //如果左儿子一部分在查询区间内,更新左儿子
if(mid<r) update(l,r,pr,num); //如果右儿子一部分在查询区间内,更新右儿子
pushup(tr);//关键的一步
}
查询(query)
代码语言:javascript复制long long query(int l,int r,int tr){
long long ans=0;
if(l<=t[tr].l&&t[tr].r<=r) return t[tr].sum;
pushdown(tr);
int mid=(t[tr].l t[tr].r)>>1;
if(l<=mid) ans =query(l,r,pl);
if(mid<r) ans =query(l,r,pr);
return ans;
}
例题与示例程序:
1.区间求和
洛谷P3372
代码语言:javascript复制#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
#define pl tr<<1
#define pr tr<<1|1
using namespace std;
const int N=1e5 10;
int n,m,a[100010],x,y,k,q;
struct segmentTree{
int l,r,lz;
long long sum;
}t[N<<2];
void pushup(int tr){
t[tr].sum=t[pl].sum t[pr].sum;
}
void pushdown(int tr){
if(t[tr].lz){
t[pl].sum =t[tr].lz*(t[pl].r-t[pl].l 1);
t[pr].sum =t[tr].lz*(t[pr].r-t[pr].l 1);
t[pl].lz =t[tr].lz;
t[pr].lz =t[tr].lz;
t[tr].lz=0;
}
}
void build(int l,int r,int tr){
t[tr].l=l,t[tr].r=r;
if(l==r){t[tr].sum=a[r];return;}
int mid=(l r)>>1;
build(l,mid,pl);
build(mid 1,r,pr);
pushup(tr);
}
void update(int l,int r,int tr,int num){
if(l<=t[tr].l&&t[tr].r<=r) {t[tr].sum =num*(t[tr].r-t[tr].l 1);t[tr].lz =num;return;}
pushdown(tr);
int mid=(t[tr].r t[tr].l)>>1;
if(l<=mid)update(l,r,pl,num);
if(mid<r)update(l,r,pr,num);
pushup(tr);
}
long long query(int l,int r,int tr){
long long ans=0;
if(l<=t[tr].l&&t[tr].r<=r) return t[tr].sum;
pushdown(tr);
int mid=(t[tr].r t[tr].l)>>1;
if(l<=mid) ans =query(l,r,pl);
if(mid<r) ans =query(l,r,pr);
return ans;
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i )scanf("%d",&a[i]);
build(1,n,1);
for(int i=1;i<=m;i ){
scanf("%d%d%d",&q,&x,&y);
if(q==1){
scanf("%d",&k);
update(x,y,1,k);
}
else{
printf("%lldn",query(x,y,1));
}
}
return 0;
}
2.区间求乘积
洛谷P3373
代码语言:javascript复制#include <iostream>
#include <stdio.h>
#include <algorithm>
#define pl tr<<1
#define pr tr<<1|1
using namespace std;
const int N=1e5 10;
int n,m,p,x,y,k,q;
int a[N];
struct segmentTree{
int l,r;
long long sum,add=0,mul=1;//add=加,mul=乘
}t[N<<2];
void pushup(int tr){
t[tr].sum=(t[pl].sum t[pr].sum)%p;
}
void pushdown(int tr){
t[pl].sum=(t[tr].add*(t[pl].r-t[pl].l 1)%p (t[pl].sum*t[tr].mul)%p)%p;
t[pr].sum=(t[tr].add*(t[pr].r-t[pr].l 1)%p (t[pr].sum*t[tr].mul)%p)%p;
t[pl].add=(t[tr].mul*t[pl].add%p t[tr].add)%p;
t[pl].mul=t[tr].mul*t[pl].mul%p;
t[pr].add=(t[tr].mul*t[pr].add%p t[tr].add)%p;
t[pr].mul=t[tr].mul*t[pr].mul%p;
t[tr].add=0;t[tr].mul=1;
}
void build(int l,int r,int tr){
t[tr].l=l;t[tr].r=r;
if(l==r) {t[tr].sum=a[l];return;}
else{
int mid=(l r)>>1;
build(l,mid,pl);
build(mid 1,r,pr);
pushup(tr);
}
}
void update1(int l,int r,int tr,int k){//add
if(l<=t[tr].l&&t[tr].r<=r){
t[tr].sum=(t[tr].sum k*(t[tr].r-t[tr].l 1)%p)%p;
t[tr].add=(t[tr].add k%p)%p;
return;
}
pushdown(tr);
int mid=(t[tr].l t[tr].r)>>1;
if(l<=mid) update1(l,r,pl,k);
if(mid<r) update1(l,r,pr,k);
pushup(tr);
}
void update2(int l,int r,int tr,int k){//mul
if(l<=t[tr].l&&t[tr].r<=r){
t[tr].sum=(t[tr].sum*k)%p;
t[tr].add=(t[tr].add*k)%p;
t[tr].mul=(t[tr].mul*k)%p;
return;
}
pushdown(tr);
int mid=(t[tr].l t[tr].r)>>1;
if(l<=mid) update2(l,r,pl,k);
if(mid<r) update2(l,r,pr,k);
pushup(tr);
}
long long query(int l,int r,int tr){
long long ans=0;
if(l<=t[tr].l&&t[tr].r<=r) return t[tr].sum;
int mid=(t[tr].l t[tr].r)>>1;
pushdown(tr);
if(l<=mid) ans =query(l,r,pl);
if(mid<r) ans =query(l,r,pr);
return ans%p;
}
int main(){
scanf("%d%d%d",&n,&m,&p);
for(int i=1;i<=n;i ) scanf("%d",&a[i]);
build(1,n,1);
for(int i=1;i<=m;i ){
scanf("%d%d%d",&q,&x,&y);
if(q==1){
scanf("%d",&k);
update2(x,y,1,k);
}
else if(q==2){
scanf("%d",&k);
update1(x,y,1,k);
}
else {
printf("%lldn",query(x,y,1));
}
}
return 0;
}
无懒标记版本:
代码语言:c 复制#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
using namespace std;
const int N=1e5 10;
int n,m,q,x,y,k;
int a[N];
struct segmenttree{
int l,r,sum;
}t[N<<2];
void pushup(int tr){
t[tr].sum=t[tr<<1].sum t[tr<<1|1].sum;
}
void build(int l,int r,int tr){
t[tr].l=l;t[tr].r=r;
if(l==r) {t[tr].sum=a[r];return;}
int mid=(l r)>>1;
build(l,mid,tr<<1);
build(mid 1,r,tr<<1|1);
pushup(tr);
}
void update(int l,int r,int tr,int num){
int mid=(t[tr].l t[tr].r)>>1;
if(t[tr].l==t[tr].r) {
t[tr].sum =num;return;
}
if(l<=mid)update(l,r,tr<<1,num);
if(mid<r)update(l,r,tr<<1|1,num);
pushup(tr);
}
int query(int l,int r,int tr){
int ans=0;
if(t[tr].l>=l&&t[tr].r<=r) {return t[tr].sum;}
int mid=(t[tr].l t[tr].r)>>1;
if(l<=mid) ans =query(l,r,tr<<1);
if(mid<r) ans =query(l,r,tr<<1|1);
return ans;
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i ){
scanf("%d",&a[i]);
}
build(1,n,1);
for(int i=1;i<=m;i ){
scanf("%d",&q);
if(q==1){
scanf("%d%d%d",&x,&y,&k);
update(x,y,1,k);
}
else {
scanf("%d%d",&x,&y);
cout<<query(x,y,1)<<endl;
}
}
}