线段树模板

概述:

线段树是算法竞赛中常用的数据结构(虽然考场中很少用,毕竟调起来麻烦,区间求和用树状树组还是更加方便代码也短)。

线段树可以在O(logN)的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。简略的描述一下算法思路,线段树是一个二叉树,树的每一个节点存储的都是一个区间内的值(根据具体的题目而定),每个父结点的值由两个子结点的值决定。

但是普通的二分思想并不能体现线段树的精髓所在,线段树的精髓就在于它的懒标记,具体往下看。

算法的实现:

//建议初学者先看无懒标记版,在最下面。

这里以洛谷P3372的区间求和为例

个人习惯:

#define pl tr<<1 //左儿子
#define pr tr<<1|1 //右儿子

建树(build)

struct segmentTree{
	int l,r; //查询的区间范围
	long long sum ,lz; //区间和,懒标记
}t[N<<2];//要开4*N的大小

void build(int l,int r,int tr){
	t[tr].l=l;t[tr].r=r;
	if(l==r) {t[tr].sum=a[l];return;} //如果区间内只有一个树,则赋值,返回
	int mid=(l+r)>>1;
	build(l,mid,pl); //建左区间
	build(mid+1,r,pr); //建右区间
	pushup(tr); //关键操作,通过最下层来更新到上层
}

上放(pushup)

void pushup(int tr){
	t[tr].sum=t[pl].sum+t[pr].sum; //由两个子结点的值更新父结点的值
}

下放(pushdown)

懒标记解释:带有懒标记的值是已经处理完成的确认的值。


void pushdown(int tr){
	if(t[tr].lz){
		t[pl].sum+=t[tr].lz*(t[pl].r-t[pl].l+1);//左儿子的值加上懒标记的值*区间内数的个数
		t[pr].sum+=t[tr].lz*(t[pr].r-t[pr].l+1);//右儿子的值加上懒标记的值*区间内树的个数
		t[pl].lz+=t[tr].lz;//懒标记下放
		t[pr].lz+=t[tr].lz;//懒标记下放
		t[tr].lz=0;//将父结点的懒标记清零
	}
}

更新(update)

update中的pushup()是我当时学习该算法时的没理解的一个地方,并不是直接更新每个结点的值,而是最后通过pushup()来更新父结点


void update(int l,int r,int tr,int num){
	if(l<=t[tr].l&&t[tr].r<=r) {t[tr].sum+=num*(t[tr].r-t[tr].l+1);t[tr].lz+=num;return;}
	pushdown(tr);//上一行是指如果该区间在查询区间内,则更新该区间值即懒标记,并且返回。(因为有懒标记),如果不包含则懒标记下放
	int mid=(t[tr].l+t[tr].r)>>1;//二分
	if(l<=mid) update(l,r,pl,num); //如果左儿子一部分在查询区间内,更新左儿子
	if(mid<r) update(l,r,pr,num); //如果右儿子一部分在查询区间内,更新右儿子
	pushup(tr);//关键的一步
}

查询(query)


long long query(int l,int r,int tr){
	long long ans=0;
	if(l<=t[tr].l&&t[tr].r<=r) return t[tr].sum; 
	pushdown(tr);
	int mid=(t[tr].l+t[tr].r)>>1;
	if(l<=mid) ans+=query(l,r,pl);
	if(mid<r) ans+=query(l,r,pr);
	return ans;
}

例题与示例程序:

1.区间求和

洛谷P3372

#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
#define pl tr<<1
#define pr tr<<1|1
using namespace std;
const int N=1e5+10;
int n,m,a[100010],x,y,k,q;
struct segmentTree{
	int l,r,lz;
	long long sum;
}t[N<<2];
void pushup(int tr){
	t[tr].sum=t[pl].sum+t[pr].sum;
}
void pushdown(int tr){
	if(t[tr].lz){
		t[pl].sum+=t[tr].lz*(t[pl].r-t[pl].l+1);
		t[pr].sum+=t[tr].lz*(t[pr].r-t[pr].l+1);
		t[pl].lz+=t[tr].lz;
		t[pr].lz+=t[tr].lz;
		t[tr].lz=0;
	}
}
void build(int l,int r,int tr){
	t[tr].l=l,t[tr].r=r;
	if(l==r){t[tr].sum=a[r];return;}
	int mid=(l+r)>>1;
	build(l,mid,pl);
	build(mid+1,r,pr);
	pushup(tr);
}
void update(int l,int r,int tr,int num){
	if(l<=t[tr].l&&t[tr].r<=r) {t[tr].sum+=num*(t[tr].r-t[tr].l+1);t[tr].lz+=num;return;}
	pushdown(tr);
	int mid=(t[tr].r+t[tr].l)>>1;
	if(l<=mid)update(l,r,pl,num);
	if(mid<r)update(l,r,pr,num);
	pushup(tr);
}
long long query(int l,int r,int tr){
	long long ans=0;
	if(l<=t[tr].l&&t[tr].r<=r) return t[tr].sum;
	pushdown(tr);
	int mid=(t[tr].r+t[tr].l)>>1;
	if(l<=mid) ans+=query(l,r,pl);
	if(mid<r) ans+=query(l,r,pr);
	return ans;
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)scanf("%d",&a[i]);
		build(1,n,1);
	for(int i=1;i<=m;i++){
		scanf("%d%d%d",&q,&x,&y);
		if(q==1){
			scanf("%d",&k);
			update(x,y,1,k);
		}
		else{
			printf("%lld\n",query(x,y,1));
		}
	}
    return 0;
}

2.区间求乘积

洛谷P3373

#include <iostream>
#include <stdio.h>
#include <algorithm>
#define pl tr<<1
#define pr tr<<1|1

using namespace std;
const int N=1e5+10;

int n,m,p,x,y,k,q;
int a[N];
struct segmentTree{
	int l,r;
	long long sum,add=0,mul=1;//add=加,mul=乘
}t[N<<2];
void pushup(int tr){
	t[tr].sum=(t[pl].sum+t[pr].sum)%p;
}
void pushdown(int tr){
	t[pl].sum=(t[tr].add*(t[pl].r-t[pl].l+1)%p+(t[pl].sum*t[tr].mul)%p)%p;
	t[pr].sum=(t[tr].add*(t[pr].r-t[pr].l+1)%p+(t[pr].sum*t[tr].mul)%p)%p;
	t[pl].add=(t[tr].mul*t[pl].add%p+t[tr].add)%p;
	t[pl].mul=t[tr].mul*t[pl].mul%p;
	t[pr].add=(t[tr].mul*t[pr].add%p+t[tr].add)%p;
	t[pr].mul=t[tr].mul*t[pr].mul%p;
	t[tr].add=0;t[tr].mul=1;
}
void build(int l,int r,int tr){
	t[tr].l=l;t[tr].r=r;
	if(l==r) {t[tr].sum=a[l];return;}
	else{
		int mid=(l+r)>>1;
		build(l,mid,pl);
		build(mid+1,r,pr);
		pushup(tr);
	}
}
void update1(int l,int r,int tr,int k){//add
	if(l<=t[tr].l&&t[tr].r<=r){
		t[tr].sum=(t[tr].sum+k*(t[tr].r-t[tr].l+1)%p)%p;
		t[tr].add=(t[tr].add+k%p)%p;
		return;
	}
	pushdown(tr);
	int mid=(t[tr].l+t[tr].r)>>1;
	if(l<=mid) update1(l,r,pl,k);
	if(mid<r) update1(l,r,pr,k);
	pushup(tr);
}
void update2(int l,int r,int tr,int k){//mul
	if(l<=t[tr].l&&t[tr].r<=r){
		t[tr].sum=(t[tr].sum*k)%p;
		t[tr].add=(t[tr].add*k)%p;
		t[tr].mul=(t[tr].mul*k)%p;
		return;
	}
	pushdown(tr);
	int mid=(t[tr].l+t[tr].r)>>1;
	if(l<=mid) update2(l,r,pl,k);
	if(mid<r) update2(l,r,pr,k);
	pushup(tr);
}
long long query(int l,int r,int tr){
	long long ans=0;
	if(l<=t[tr].l&&t[tr].r<=r) return t[tr].sum;
	int mid=(t[tr].l+t[tr].r)>>1;
	pushdown(tr);
	if(l<=mid) ans+=query(l,r,pl);
	if(mid<r) ans+=query(l,r,pr);
	return ans%p;
}
int main(){
	scanf("%d%d%d",&n,&m,&p);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	build(1,n,1);
	for(int i=1;i<=m;i++){
		scanf("%d%d%d",&q,&x,&y);
		if(q==1){
			scanf("%d",&k);
			update2(x,y,1,k);
		}
		else if(q==2){
			scanf("%d",&k);
			update1(x,y,1,k);
		}
		else {
			printf("%lld\n",query(x,y,1));
		}
	}
    return 0;
}

无懒标记版本:

#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>

using namespace std;
const int N=1e5+10;
int n,m,q,x,y,k;
int a[N];
struct segmenttree{
	int l,r,sum;
}t[N<<2];
void pushup(int tr){
	t[tr].sum=t[tr<<1].sum+t[tr<<1|1].sum;
}
void build(int l,int r,int tr){
	t[tr].l=l;t[tr].r=r;
	if(l==r) {t[tr].sum=a[r];return;}
	int mid=(l+r)>>1;
	build(l,mid,tr<<1);
	build(mid+1,r,tr<<1|1);
	pushup(tr);
}
void update(int l,int r,int tr,int num){
	int mid=(t[tr].l+t[tr].r)>>1;
	if(t[tr].l==t[tr].r) {
		t[tr].sum+=num;return;
	}
	if(l<=mid)update(l,r,tr<<1,num);
	if(mid<r)update(l,r,tr<<1|1,num);
	pushup(tr);
}
int query(int l,int r,int tr){
	int ans=0;
	if(t[tr].l>=l&&t[tr].r<=r) {return t[tr].sum;}
	int mid=(t[tr].l+t[tr].r)>>1;
	if(l<=mid) ans+=query(l,r,tr<<1);
	if(mid<r) ans+=query(l,r,tr<<1|1);
	return ans;
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++){
		scanf("%d",&a[i]);
	}
	build(1,n,1);
	for(int i=1;i<=m;i++){
		scanf("%d",&q);
		if(q==1){
			scanf("%d%d%d",&x,&y,&k);
			update(x,y,1,k);
		}
		else {
			scanf("%d%d",&x,&y);
			cout<<query(x,y,1)<<endl;
		}
	}
}
暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇