题解:[SCOI2010] 序列操作

洛谷P2572

Posted by TH911 on December 21, 2024

题目传送门

题意分析

我们需要维护一种数据结构来维护 $0,1$ 区间的信息,使其能够高效的获取区间内 $1$ 的数量和最多的连续的 $1$ 的数量,并能够高效的将指定区间赋为指定值和实现区间翻转(注意是值取反,而不是类似于字符串的翻转)。

维护区间信息

线段树节点维护信息

维护区间信息,自然而然地想到线段树

那么需要维护哪些信息呢?

一般地,区间边界 $l,r$。

对于操作 $3$,显然我们需要维护区间内 $1$ 的数量。

考虑到需要进行“区间翻转”,因此我们也要维护 $0$ 的信息(见下文)。

维护 $cnt_0,cnt_1$ 表示区间内 $0,1$ 分别出现的次数,当然,你可以使用 $(r-l+1)-cnt_1$ 来获取 $cnt_0$,但是为了方便和美观,我们不这么做。

维护两个懒标记 $set$ 和 $reverse$,$set$ 表示区间是否被赋值,初始值为 $-1$;$reverse$ 是一个 bool 变量,用于标记是否翻转。

维护 $len_1$ 表示区间内最多的连续的 $1$ 的数量,考虑到需要翻转,因此也维护 $len_0$。

同时还需要维护 $pre_0,pre_1,suf_0,suf_1$,具体见下文。

参考代码

1
2
3
4
5
6
struct node{
    int l,r;
    int cnt[2],len[2],pre[2],suf[2];
    int set;
    bool reverse;
}t[4*N+1];

区间合并与上传更新

这是一定需要考虑的,在 $up(p)$ 中合并区间(上传更新),$cnt_0,cnt_1$ 都好合并,直接相加即可;然而 $len_0,len_1$ 却不好合并,因为我们需要考虑跨越中线的情况。

比如说,合并 $\texttt{1101}$ 和 $\texttt{1100}$。如果直接取最大值,就会发现合并后新的 $len_1=2$。然而合并后为 $\texttt{11011100}$,明显答案为 $3$;这就是所谓“跨越中线”。

我们可以效仿平衡树维护区间信息,维护 $pre_0,pre_1,suf_0,suf_1$ 分别表示最长前缀 $0$ 的长度、最长前缀 $1$ 的长度、最长后缀 $0$ 的长度、最长后缀 $1$ 的长度。

那么,对于 $pre_0$ 的转移明显就是 $t[p].pre_0\leftarrow t[p\times 2].pre_0$。

一个特殊的情况就是,$t[p\times 2]$ 所维护的区间全是 $0$。这样,还要加上 $t[p\times 2+1].pre_0$。

对于剩下三个,同理可得。

此时我们再来维护 $len_0,len_1$。

以 $len_0$ 为例,仅仅需要用原来取的最大值再与 $t[p\times 2].suf_0+t[p\times 2+1].pre_0$ 取最大值即可。

参考代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
void up(int p){
    t[p].cnt[0]=t[p<<1].cnt[0]+t[p<<1|1].cnt[0];
    t[p].cnt[1]=t[p<<1].cnt[1]+t[p<<1|1].cnt[1];
    t[p].len[0]=max(t[p<<1].len[0],t[p<<1|1].len[0]);
    t[p].len[0]=max(t[p].len[0],t[p<<1].suf[0]+t[p<<1|1].pre[0]);
    t[p].len[1]=max(t[p<<1].len[1],t[p<<1|1].len[1]);
    t[p].len[1]=max(t[p].len[1],t[p<<1].suf[1]+t[p<<1|1].pre[1]);

    t[p].suf[0]=t[p<<1|1].suf[0];
    if(t[p<<1|1].suf[0]==size(p<<1|1))t[p].suf[0]+=t[p<<1].suf[0];
    t[p].suf[1]=t[p<<1|1].suf[1];
    if(t[p<<1|1].suf[1]==size(p<<1|1))t[p].suf[1]+=t[p<<1].suf[1];

    t[p].pre[0]=t[p<<1].pre[0];
    if(t[p<<1].pre[0]==size(p<<1))t[p].pre[0]+=t[p<<1|1].pre[0];
    t[p].pre[1]=t[p<<1].pre[1];
    if(t[p<<1].pre[1]==size(p<<1))t[p].pre[1]+=t[p<<1|1].pre[1];
}

懒标记与下传更新

上文已经提到,设计了一个 $set$ 和一个 $reverse$,现在展开详细解释。

我们定义 $down(p)$ 来下传更新(懒标记)。

那么需要传递两种操作:区间推平和区间翻转。

为了方便,我们先考虑推平再考虑翻转

注意:在区间推平后应当取消区间翻转标记。

因为当 $reverse=1$ 的时候,实际上是先翻转、再赋为指定值,因此实际上翻转已经完成然后被推平了,因此不应该翻转。

但这又是个问题了,为什么不能是 $set$ 先有值然后再有 $reverse\leftarrow 1$ 呢?

其实是有的,但是这也不需要翻转啊……大家都一样,翻转完和没翻没有区别,因此不做考虑

其实还有一些问题,见下文的区间推平

关于 $set$ 和 $reverse$ 标记传递的顺序问题

update at $2025/1/25$:转载自我的回答


先推平再翻转

你可以这么理解为什么先传递推平标记要取消翻转标记:

假设有一个区间是 $\texttt{11001}$,那么将其打上翻转标记应当变为:$\texttt{00110}$。但是因为懒标记的原因,可能会变为:$\color{red}\texttt{11}\color{black}\texttt{110}$。红色部分表示因为懒标记而实际上没有被翻转的部分

此时我们再推平整个区间为 $\texttt{00000}$,那么前面两个 $\color{red}\texttt{11}$ 的翻转懒标记仍然存在

于是就会导致整个区间变为 $\color{red}\texttt{00}\color{black}\texttt{000}$,此时你再传递翻转懒标记就会把区间变为 $\texttt{11000}$,就不对了。

所以在区间推平的时候要取消翻转标记。


先翻转再推平

你当然也可以先传递翻转标记再传递推平标记。但是这样的话就可以不需要取消翻转标记了。(这种需要注意标记顺序问题 ,见下文)

因为你翻转之后推平了,无所谓。

比如说区间 $\texttt{10101}$,打上翻转标记变成 $\color{red}\texttt{10}\color{black}\texttt{010}$。此时你再推平会先传递翻转标记,整个区间会变成 $\texttt{01010}$,是正确的。然后你再传递推平标记即可。


标记顺序问题

其实还涉及到顺序问题

  • 先推平再翻转

    你在推平时直接取消翻转标记即可。传递时因为是先传递推平标记,再传递翻转标记,也符合标记的顺序,能够保证答案的正确性。

  • 先翻转再推平

    其实这种我没写过。你在翻转的时候需要先传递推平标记,即将推平标记的传递同时写入 pushdown() 和你的翻转函数。

    不然就会出现类似于标记先后顺序不对进而导致出错的情况。

    比如说有一个区间是 $\texttt{10110}$,区间推平为 $\texttt{00000}$,实际变为 $\color{red}\texttt{10}\color{black}\texttt{000}$。

    此时翻转函数将此区间翻转,就会变成 $\color{red}\texttt{01}\color{black}\texttt{111}$,然后你再传递之前的标记,于是区间就变成了:$\texttt{00111}$,显然不对。(应当变为 $\texttt{11111}$。)

    上文中之所以是正确的,是因为没有考虑在翻转之前有没有推平标记。

总而言之,就是在进行某一操作时,需要先传递另一操作的标记来保证正确性。

而区间推平的好处就是推平之后无论之前是否翻转,推平之后都一样,因此可以直接取消区间翻转的标记。


所以说题解里大多数都是建议你先传递推平标记再传递翻转标记,因为你推平时取消翻转标记仅仅需要一行:

1
t[p].reverse = false;

但是你先翻转再推平的时候在翻转函数内传递推平标记就会要写一堆(比如说下文我写的 set 函数)。


update at $2025/1/25\ 23:59$:先翻转再推平的话,可以考虑翻转推平标记,这样会简单一些。

对于赋值,没什么好说的,见代码。

对于翻转,注意不要去交换左右子树即可。

参考代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
void down(int p){
    if(t[p].set!=-1){
        t[p<<1].cnt[t[p].set]=size(p<<1);
        t[p<<1].cnt[t[p].set^1]=0;
        t[p<<1].len[t[p].set]=size(p<<1);
        t[p<<1].len[t[p].set^1]=0;
        t[p<<1].pre[t[p].set]=size(p<<1);
        t[p<<1].pre[t[p].set^1]=0;
        t[p<<1].suf[t[p].set]=size(p<<1);
        t[p<<1].suf[t[p].set^1]=0;
        t[p<<1].set=t[p].set;
        t[p<<1].reverse=false;

        t[p<<1|1].cnt[t[p].set]=size(p<<1|1);
        t[p<<1|1].cnt[t[p].set^1]=0;
        t[p<<1|1].len[t[p].set]=size(p<<1|1);
        t[p<<1|1].len[t[p].set^1]=0;
        t[p<<1|1].pre[t[p].set]=size(p<<1|1);
        t[p<<1|1].pre[t[p].set^1]=0;
        t[p<<1|1].suf[t[p].set]=size(p<<1|1);
        t[p<<1|1].suf[t[p].set^1]=0;
        t[p<<1|1].set=t[p].set;
        t[p<<1|1].reverse=false;

        t[p].set=-1;
    }

    if(t[p].reverse){
        swap(t[p<<1].cnt[0],t[p<<1].cnt[1]);
        swap(t[p<<1].len[0],t[p<<1].len[1]);
        swap(t[p<<1].pre[0],t[p<<1].pre[1]);
        swap(t[p<<1].suf[0],t[p<<1].suf[1]);
        t[p<<1].reverse=!t[p<<1].reverse;

        swap(t[p<<1|1].cnt[0],t[p<<1|1].cnt[1]);
        swap(t[p<<1|1].len[0],t[p<<1|1].len[1]);
        swap(t[p<<1|1].pre[0],t[p<<1|1].pre[1]);
        swap(t[p<<1|1].suf[0],t[p<<1|1].suf[1]);
        t[p<<1|1].reverse=!t[p<<1|1].reverse;

        t[p].reverse=false;
    }
}

区间推平

注意清除翻转标记。

然后也没什么好说的,见代码。

参考代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
void set(int p,int l,int r,int k){
    if(l<=t[p].l&&t[p].r<=r){
        t[p].cnt[k]=size(p);
        t[p].cnt[k^1]=0;
        t[p].len[k]=size(p);
        t[p].len[k^1]=0;
        t[p].pre[k]=size(p);
        t[p].pre[k^1]=0;
        t[p].suf[k]=size(p);
        t[p].suf[k^1]=0;

        t[p].set=k;
        t[p].reverse=false;//注意清除!
        return;
    }
    down(p);
    if(l<=t[p<<1].r)set(p<<1,l,r,k);
    if(t[p<<1|1].l<=r)set(p<<1|1,l,r,k);
    up(p);
}

区间翻转

唯一需要注意的就是,$reverse$ 是 bool 变量,逻辑运算符 ! 取反即可。

然后就是直接交换四个信息。

参考代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void reverse(int p,int l,int r){
    if(l<=t[p].l&&t[p].r<=r){
        swap(t[p].cnt[0],t[p].cnt[1]);
        swap(t[p].len[0],t[p].len[1]);
        swap(t[p].pre[0],t[p].pre[1]);
        swap(t[p].suf[0],t[p].suf[1]);
        t[p].reverse=!t[p].reverse;
        return;
    }
    down(p);
    if(l<=t[p<<1].r)reverse(p<<1,l,r);
    if(t[p<<1|1].l<=r)reverse(p<<1|1,l,r);
    up(p);
}

查询区间内 $1$ 的个数

就是求 $cnt_1$ 意义上的区间和。

参考代码

1
2
3
4
5
6
7
8
int query(int p,int l,int r,int k){
    if(l<=t[p].l&&t[p].r<=r)return t[p].cnt[k];
    down(p);
    int ans=0;
    if(l<=t[p<<1].r)ans=query(p<<1,l,r,k);
    if(t[p<<1|1].l<=r)ans+=query(p<<1|1,l,r,k);
    return ans;
}

查询区间内最多连续 $1$ 的个数

首先在不考虑“跨越中线”的情况下,即求 $len$ 意义下的区间最大值。

然后我们考虑跨越中线,类似的也就是把左子树的后缀与右子树的前缀长度加起来就行,注意记得取 $\min$,防止超出给定区间

参考代码

1
2
3
4
5
6
7
8
9
int query_len(int p,int l,int r,int k){
    if(l<=t[p].l&&t[p].r<=r)return t[p].len[k];
    down(p);
    int ans=0;
    if(l<=t[p<<1].r)ans=query_len(p<<1,l,r,k);
    if(t[p<<1|1].l<=r)ans=max(ans,query_len(p<<1|1,l,r,k));
    ans=max(ans,min(t[p<<1].suf[k],t[p<<1].r-l+1) + min(t[p<<1|1].pre[k],r-t[p<<1|1].l+1));
    return ans;
}

AC 代码

整整 $5\text{KB}$……

但是实话实说,虽然看起来长而且实际上也长,但是很多都是“复制粘贴 + 微调”可以解决的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
//#include<bits/stdc++.h>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<iomanip>
#include<cstdio>
#include<string>
#include<vector>
#include<cmath>
#include<ctime>
#include<deque>
#include<queue>
#include<stack>
#include<list>
using namespace std;
const int N=1e5;
int n,m,a[N+1];
struct seg{
	struct node{
		int l,r;
		int cnt[2],len[2],pre[2],suf[2];
		int set;
		bool reverse;
	}t[4*N+1];
	
	void up(int p){
		t[p].cnt[0]=t[p<<1].cnt[0]+t[p<<1|1].cnt[0];
		t[p].cnt[1]=t[p<<1].cnt[1]+t[p<<1|1].cnt[1];
		t[p].len[0]=max(t[p<<1].len[0],t[p<<1|1].len[0]);
		t[p].len[0]=max(t[p].len[0],t[p<<1].suf[0]+t[p<<1|1].pre[0]);
		t[p].len[1]=max(t[p<<1].len[1],t[p<<1|1].len[1]);
		t[p].len[1]=max(t[p].len[1],t[p<<1].suf[1]+t[p<<1|1].pre[1]);
		
		t[p].suf[0]=t[p<<1|1].suf[0];
		if(t[p<<1|1].suf[0]==size(p<<1|1))t[p].suf[0]+=t[p<<1].suf[0];
		t[p].suf[1]=t[p<<1|1].suf[1];
		if(t[p<<1|1].suf[1]==size(p<<1|1))t[p].suf[1]+=t[p<<1].suf[1];
		
		t[p].pre[0]=t[p<<1].pre[0];
		if(t[p<<1].pre[0]==size(p<<1))t[p].pre[0]+=t[p<<1|1].pre[0];
		t[p].pre[1]=t[p<<1].pre[1];
		if(t[p<<1].pre[1]==size(p<<1))t[p].pre[1]+=t[p<<1|1].pre[1];
	}
	void build(int p,int l,int r){
		t[p].l=l,t[p].r=r;
		t[p].set=-1;
		if(l==r){
			t[p].cnt[a[l]]=1;
			t[p].len[a[l]]=1;
			t[p].pre[a[l]]=1;
			t[p].suf[a[l]]=1;
			return;
		}
		int mid=l+r>>1;
		build(p<<1,l,mid);
		build(p<<1|1,mid+1,r);
		up(p);
	}
	int size(int p){
		return t[p].r-t[p].l+1;
	}
	void down(int p){
		if(t[p].set!=-1){
			t[p<<1].cnt[t[p].set]=size(p<<1);
			t[p<<1].cnt[t[p].set^1]=0;
			t[p<<1].len[t[p].set]=size(p<<1);
			t[p<<1].len[t[p].set^1]=0;
			t[p<<1].pre[t[p].set]=size(p<<1);
			t[p<<1].pre[t[p].set^1]=0;
			t[p<<1].suf[t[p].set]=size(p<<1);
			t[p<<1].suf[t[p].set^1]=0;
			t[p<<1].set=t[p].set;
			t[p<<1].reverse=false;
			
			t[p<<1|1].cnt[t[p].set]=size(p<<1|1);
			t[p<<1|1].cnt[t[p].set^1]=0;
			t[p<<1|1].len[t[p].set]=size(p<<1|1);
			t[p<<1|1].len[t[p].set^1]=0;
			t[p<<1|1].pre[t[p].set]=size(p<<1|1);
			t[p<<1|1].pre[t[p].set^1]=0;
			t[p<<1|1].suf[t[p].set]=size(p<<1|1);
			t[p<<1|1].suf[t[p].set^1]=0;
			t[p<<1|1].set=t[p].set;
			t[p<<1|1].reverse=false;
			
			t[p].set=-1;
		}
		
		if(t[p].reverse){
			swap(t[p<<1].cnt[0],t[p<<1].cnt[1]);
			swap(t[p<<1].len[0],t[p<<1].len[1]);
			swap(t[p<<1].pre[0],t[p<<1].pre[1]);
			swap(t[p<<1].suf[0],t[p<<1].suf[1]);
			t[p<<1].reverse=!t[p<<1].reverse;
			
			swap(t[p<<1|1].cnt[0],t[p<<1|1].cnt[1]);
			swap(t[p<<1|1].len[0],t[p<<1|1].len[1]);
			swap(t[p<<1|1].pre[0],t[p<<1|1].pre[1]);
			swap(t[p<<1|1].suf[0],t[p<<1|1].suf[1]);
			t[p<<1|1].reverse=!t[p<<1|1].reverse;
			
			t[p].reverse=false;
		}
	}
	void set(int p,int l,int r,int k){
		if(l<=t[p].l&&t[p].r<=r){
			t[p].cnt[k]=size(p);
			t[p].cnt[k^1]=0;
			t[p].len[k]=size(p);
			t[p].len[k^1]=0;
			t[p].pre[k]=size(p);
			t[p].pre[k^1]=0;
			t[p].suf[k]=size(p);
			t[p].suf[k^1]=0;
			
			t[p].set=k;
			t[p].reverse=false;
			return;
		}
		down(p);
		if(l<=t[p<<1].r)set(p<<1,l,r,k);
		if(t[p<<1|1].l<=r)set(p<<1|1,l,r,k);
		up(p);
	}
	void reverse(int p,int l,int r){
		if(l<=t[p].l&&t[p].r<=r){
			swap(t[p].cnt[0],t[p].cnt[1]);
			swap(t[p].len[0],t[p].len[1]);
			swap(t[p].pre[0],t[p].pre[1]);
			swap(t[p].suf[0],t[p].suf[1]);
			t[p].reverse=!t[p].reverse;
			return;
		}
		down(p);
		if(l<=t[p<<1].r)reverse(p<<1,l,r);
		if(t[p<<1|1].l<=r)reverse(p<<1|1,l,r);
		up(p);
	}
	int query(int p,int l,int r,int k){
		if(l<=t[p].l&&t[p].r<=r)return t[p].cnt[k];
		down(p);
		int ans=0;
		if(l<=t[p<<1].r)ans=query(p<<1,l,r,k);
		if(t[p<<1|1].l<=r)ans+=query(p<<1|1,l,r,k);
		return ans;
	}
	int query_len(int p,int l,int r,int k){
		if(l<=t[p].l&&t[p].r<=r){
			return t[p].len[k];
		}
		down(p);
		int ans=0;
		if(l<=t[p<<1].r){
			ans=query_len(p<<1,l,r,k);
		}
		if(t[p<<1|1].l<=r){
			ans=max(ans,query_len(p<<1|1,l,r,k));
		}
		ans=max(ans,min(t[p<<1].suf[k],t[p<<1].r-l+1) + min(t[p<<1|1].pre[k],r-t[p<<1|1].l+1));
		return ans;
	}
}t;
void solve0(int l,int r){
	t.set(1,l,r,0);
}
void solve1(int l,int r){
	t.set(1,l,r,1);
}
void solve2(int l,int r){
	t.reverse(1,l,r);
}
void solve3(int l,int r){
	printf("%d\n",t.query(1,l,r,1));
}
void solve4(int l,int r){
	printf("%d\n",t.query_len(1,l,r,1));
}
typedef void (*fp)(int,int);
fp solve[5]={solve0,solve1,solve2,solve3,solve4};//函数指针
int main(){
	/*freopen("test.in","r",stdin);
	freopen("test.out","w",stdout);*/
	
	scanf("%d %d",&n,&m);
	for(int i=1;i<=n;i++)scanf("%d",a+i);
	t.build(1,1,n);
	while(m--){
		int op,l,r;
		scanf("%d %d %d",&op,&l,&r);
		l++,r++;
		solve[op](l,r);
	}
	
	/*fclose(stdin);
	fclose(stdout);*/
	return 0;
}