题解:[BJOI2018] 链上二次求和

洛谷P4458

Posted by TH911 on February 3, 2026

题目传送门

题意分析

原来线段树可以维护单点加多项式。

原链即序列,设为 $a_1,a_2,a_3,\cdots,a_n$。

记 $\displaystyle s_n=\sum_{i=1}^na_i,t_n=\sum_{i=1}^ns_i$。

那么容易发现,答案为: \(\begin{aligned} \sum_{d=l}^r\sum_{i=1}^{n-d+1}\sum_{j=i}^{i+d-1}a_j&=\sum_{d=l}^r\sum_{i=1}^{n-d+1}(s_{i+d-1}-s_{i-1})\\ &=\sum_{d=l}^r\left(\sum_{i=d}^ns_i-\sum_{i=1}^{n-d}s_i\right)\\ &=\sum_{d=l}^r(t_n-t_{d-1}-t_{n-d})\\ &=(r-l+1)t_n-\sum_{i=l-1}^{r-1}t_i-\sum_{i=n-r}^{n-l}t_i \end{aligned}\)

考虑线段树维护 $t_i$ 区间和。

那么在区间 $[l,r]$ 加上 $d$ 时,可以发现:

$i$ $l$ $l+1$ $l+2$ $\cdots$ $r$ $r+1$ $r+2$ $\cdots$
$\Delta a_i$ $+d$ $+d$ $+d$ $\cdots$ $+d$ $+d$ $+d$ $\cdots$
$\Delta s_i$ $+d$ $+2d$ $+3d$ $\cdots$ $+(r-l+1)d$ $+(r-l+1)d$ $+(r-l+1)d$ $\cdots$
$\Delta t_i$ $+d$ $+3d$ $+6d$ $\cdots$ $\small+\dfrac{(r-l+1)(r-l+2)}{2}d$ $\small+\dfrac{(r-l+1)(r-l+2)}{2}d+(r-l+1)d$ $\small+\dfrac{(r-l+1)(r-l+2)}{2}d+2(r-l+1)d$ $\cdots$
  • 若 $l\leq i\leq r$,则 $s_i$ 加上了 $(i-l+1)d$。

    那么:

    \[\begin{aligned} \Delta t_i&=\sum_{j=l}^i\Delta s_i\\\\ &=\sum_{j=l}^i(j-l+1)d\\ &=\dfrac{(i-l+1)(i-l+2)}{2}d\\ &=\dfrac12d\cdot i^2+\dfrac12d(-2l+3)\cdot i+\dfrac12d(l-1)(l-2) \end{aligned}\]
  • 若 $r<i$,则 $s_i$ 加上了 $(r-l+1)d$。

    那么:

    \[\begin{aligned} \Delta t_i&=\sum_{j=l}^i\Delta s_i\\ &=\Delta t_r+\sum_{j=r+1}^i\Delta s_i\\ &=\dfrac{(r-l+1)(r-l+2)}2d+(i-r)(r-l+1)d\\ &=(r-l+1)d\cdot i+\dfrac{(r-l+1)(r-l+2)}2d-r(r-l+1)d \end{aligned}\]

线段树维护每个点加关于 $i$ 的多项式即可,打三个懒标记 $a,b,c$,表示子区间 $[l,r]$ 还需要加上 $\displaystyle\sum_{i=l}^r\left(ai^2+bi+c\right)$。

同时需要计算区间平方和,可以预先计算前缀和后差分,也可以套公式 $\displaystyle\sum_{i=1}^ni^2=\dfrac{n(n+1)(2n+1)}6$。

AC 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
//#include<bits/stdc++.h>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<iomanip>
#include<cstdio>
#include<string>
#include<vector>
#include<cmath>
#include<ctime>
#include<deque>
#include<queue>
#include<stack>
#include<list>
using namespace std;
constexpr const int N=200000,P=1000000007,inv2=500000004,inv6=166666668;
int n,a[N+1],pre[N+1];
struct segTree{
	struct node{
		int l,r;
		int value,a,b,c;
		
		int quadratic(){
			return (pre[r]-pre[l-1])%P;
		}
		int linear(){
			return (l+r)*(r-l+1ll)%P*inv2%P;
		}
		int size(){
			return r-l+1;
		}
	}t[N<<2|1];
	
	void up(int p){
		t[p].value=(t[p<<1].value+t[p<<1|1].value)%P;
	}
	void build(int p,int l,int r){
		t[p]={l,r};
		if(l==r){
			t[p].value=a[l];
			return;
		}
		int mid=l+r>>1;
		build(p<<1,l,mid);
		build(p<<1|1,mid+1,r);
		up(p);
	}
	void down(int p){
		if(t[p].a){
			t[p<<1].value=(t[p<<1].value+1ll*t[p<<1].quadratic()*t[p].a)%P;
			t[p<<1].a=(t[p<<1].a+t[p].a)%P;
			t[p<<1|1].value=(t[p<<1|1].value+1ll*t[p<<1|1].quadratic()*t[p].a)%P;
			t[p<<1|1].a=(t[p<<1|1].a+t[p].a)%P;
			t[p].a=0;
		}
		if(t[p].b){
			t[p<<1].value=(t[p<<1].value+1ll*t[p<<1].linear()*t[p].b)%P;
			t[p<<1].b=(t[p<<1].b+t[p].b)%P;
			t[p<<1|1].value=(t[p<<1|1].value+1ll*t[p<<1|1].linear()*t[p].b)%P;
			t[p<<1|1].b=(t[p<<1|1].b+t[p].b)%P;
			t[p].b=0;
		}
		if(t[p].c){
			t[p<<1].value=(t[p<<1].value+1ll*t[p<<1].size()*t[p].c)%P;
			t[p<<1].c=(t[p<<1].c+t[p].c)%P;
			t[p<<1|1].value=(t[p<<1|1].value+1ll*t[p<<1|1].size()*t[p].c)%P;
			t[p<<1|1].c=(t[p<<1|1].c+t[p].c)%P;
			t[p].c=0;
		}
	}
	void add(int p,int l,int r,int a,int b,int c){
		if(l>r){
			return;
		}
		if(l<=t[p].l&&t[p].r<=r){
			t[p].value=(t[p].value+1ll*t[p].quadratic()*a)%P;
			t[p].a=(t[p].a+a)%P;
			t[p].value=(t[p].value+1ll*t[p].linear()*b)%P;
			t[p].b=(t[p].b+b)%P;
			t[p].value=(t[p].value+1ll*t[p].size()*c)%P;
			t[p].c=(t[p].c+c)%P;
			return;
		}
		down(p);
		if(l<=t[p<<1].r){
			add(p<<1,l,r,a,b,c);
		}
		if(t[p<<1|1].l<=r){
			add(p<<1|1,l,r,a,b,c);
		}
		up(p);
	}
	int query(int p,int l,int r){
		if(l<=t[p].l&&t[p].r<=r){
			return t[p].value;
		}
		down(p);
		int ans=0;
		if(l<=t[p<<1].r){
			ans=query(p<<1,l,r);
		}
		if(t[p<<1|1].l<=r){
			ans=(ans+query(p<<1|1,l,r))%P;
		}
		return ans;
	}
}t;
int main(){
	/*freopen("test.in","r",stdin);
	freopen("test.out","w",stdout);*/
	
	ios::sync_with_stdio(false);
	cin.tie(0);cout.tie(0);
	
	int m;
	cin>>n>>m;
	for(int i=1;i<=n;i++){
		cin>>a[i];
		pre[i]=(pre[i-1]+1ll*i*i)%P;
	}
	for(int i=1;i<=n;i++){
		a[i]=(a[i-1]+a[i])%P;
	}
	for(int i=1;i<=n;i++){
		a[i]=(a[i-1]+a[i])%P;
	}
	t.build(1,1,n);
	while(m--){
		int op,l,r,d,pl,pl2;
		cin>>op>>l>>r;
		if(l>r){
			swap(l,r);
		}
		switch(op){
			case 1:
				int d;
				cin>>d;
				pl=1ll*d*inv2%P;
				t.add(1,l,r,pl,pl*(-2ll*l+3)%P,pl*(l-1ll)%P*(l-2)%P);
				pl2=(r-l+1ll)*d%P;
				t.add(1,r+1,n,0,pl2,(-1ll*r*pl2%P+(r-l+1ll)*(r-l+2)%P*pl%P)%P);
				break;
			case 2:
				int ans=(((r-l+1ll)*t.query(1,n,n)%P-t.query(1,l-1,r-1))%P-t.query(1,n-r,n-l))%P;
				if(ans<0){
					ans+=P;
				}
				cout<<ans<<'\n';
				break; 
		}
	}
	
	cout.flush();
	
	/*fclose(stdin);
	fclose(stdout);*/
	return 0;
}