avatar

segment-tree-beats

课件

这是个学习笔记,主要是给自己看的啦。

下面所述的不特别说明皆为O(nlogn)O(nlogn)的,大大拓宽了线段树的应用范围。

 

区间取maxmax,区间加,区间求和

维护最小值,次小值,和。

当介于最小值和次小值之间修改,复杂度是对的。

uoj515

http://uoj.ac/problem/515

单点修改,询问 ax,,ana_{x},\cdots,a_{n} 的不同的后缀最小值个数。

倒着按坐标扫描线,按询问时间建立线段树。

要支持区间取minmin,单点询问这个点被取minmin了多少次。

这个问题其实比区间求和简单,因为标记很容易下传,使用相同的递归方式就行了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include <bits/stdc++.h>
#define ld double
#define ull unsigned long long
#define ll long long
#define pii pair <int, int>
#define iiii pair <int, pii >
#define mp make_pair
#define INF 1000000000
#define rep(i, x) for(int (i) = 0; (i) < (x); (i)++)
#define getchar() (*input_pos++)
const int TT = 5e7;
char input_buffer[TT], output_buffer[TT];
char *input_pos = input_buffer, *output_pos = output_buffer;
inline int getint() {
int x = 0, p = 1; char c = getchar();
while (c <= 32) c = getchar();
if (c == 45) p = -p, c = getchar();
while (c > 32) x = x * 10 + c - 48, c = getchar();
return x * p;
}
void write(int x) {
if (!x) return;
write(x / 10); *output_pos++ = '0' + x % 10;
}
void writeln(int x) {
write(x);
*output_pos++ = '\n';
}
using namespace std;
const int mod = 1e9 + 7;
inline void reduce(int &x) { x += x >> 31 & mod; }
inline int mul(int x, int y) { return 1ll * x * y % mod; }
//ruogu_alter
const int N = 1e6 + 5;
int n, q, res[N], mx[N << 2], mx2[N << 2], tag[N << 2];
vector<iiii> a[N];
vector<int> vq[N];
bool fg[N];
pii lst[N];
//
void pd(int k) {
if (tag[k]) {
if (mx[2 * k + 1] > mx[k]) tag[2 * k + 1] += tag[k], mx[2 * k + 1] = mx[k];
if (mx[2 * k + 2] > mx[k]) tag[2 * k + 2] += tag[k], mx[2 * k + 2] = mx[k];
tag[k] = 0;
}
}
void up(int k) {
mx[k] = max(mx[2 * k + 1], mx[2 * k + 2]);
if (mx[2 * k + 1] == mx[2 * k + 2]) {
mx2[k] = max(mx2[2 * k + 1], mx2[2 * k + 2]);
}
else if (mx[2 * k + 1] > mx[2 * k + 2]) {
mx2[k] = max(mx2[2 * k + 1], mx[2 * k + 2]);
}
else {
mx2[k] = max(mx[2 * k + 1], mx2[2 * k + 2]);
}
}
void modify(int l, int r, int x, int y, int k, int v) {
if (l >= y || x >= r || v >= mx[k]) return;
if (x <= l && r <= y && v > mx2[k]) {
++tag[k]; mx[k] = v;
return;
}
int mid = (l + r) >> 1; pd(k);
modify(l, mid, x, y, 2 * k + 1, v);
modify(mid, r, x, y, 2 * k + 2, v);
up(k);
}
int qry(int l, int r, int p, int k) {
if (r - l == 1) return tag[k];
int mid = (l + r) >> 1; pd(k);
if (p < mid) return qry(l, mid, p, 2 * k + 1);
else return qry(mid, r, p, 2 * k + 2);
}
int main() {
fread(input_buffer, 1, TT, stdin);
memset(mx2, -1, sizeof(mx2));
rep(i, N << 2) mx[i] = INF;
n = getint(); q = getint();
for (int i = 0; i < n; i++) lst[i] = mp(getint(), 0);
rep(i, q) {
int op = getint(), x = getint() - 1;
if (op == 1) {
int y = getint();
a[x].emplace_back(lst[x].first, mp(lst[x].second, i));
lst[x] = mp(y, i);
}
else {
vq[x].emplace_back(i);
fg[i] = true;
}
}
rep(i, n) a[i].emplace_back(lst[i].first, mp(lst[i].second, q));
for (int i = n - 1; i >= 0; i--) {
for (auto &u : a[i]) {
modify(0, q, u.second.first, u.second.second, 0, u.first);
}
for (auto &u : vq[i]) res[u] = qry(0, q, u, 0);
}
rep(i, q) if (fg[i]) writeln(res[i]);
fwrite(output_buffer, 1, output_pos - output_buffer, stdout);
return 0;
}

区间取maxmax,区间加,区间求和,求区间最大值/区间历史最大值

维护一个形如(x,y)(x,y)的标记,表示ai=max(ai+x,y)a_i=max(a_i+x,y)。这个标记是很好合并的,如果(x,y)(x,y)附加上一个(z,w)(z,w)的标记,那么就变成(x+z,max(y+z,w))(x+z, max(y+z,w))了。

uoj164

http://uoj.ac/problem/164

这题只要查询单点值/单点历史最大值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include <bits/stdc++.h>
#define ld double
#define ull unsigned long long
#define ll long long
#define pii pair <int, int>
#define iiii pair <int, pii >
#define mp make_pair
#define INF 1000000000
#define rep(i, x) for(int (i) = 0; (i) < (x); (i)++)
inline int getint() {
int x = 0, p = 1; char c = getchar();
while (c <= 32) c = getchar();
if (c == 45) p = -p, c = getchar();
while (c > 32) x = x * 10 + c - 48, c = getchar();
return x * p;
}
using namespace std;
const int mod = 1e9 + 7;
inline void reduce(int &x) { x += x >> 31 & mod; }
inline int mul(int x, int y) { return 1ll * x * y % mod; }
//ruogu_alter
const int N = 5e5 + 5;
int n, q, a[N];
const ll inf = 2e18;
struct node {
ll x, y, tx, ty;
node() { x = y = tx = ty = 0; }
inline void append(const node b) {
node ans;
ans.x = max(-inf, x + b.x);
ans.y = max(y + b.x, b.y);
ans.tx = max(tx, x + b.tx);
ans.ty = max(ty, max(y + b.tx, b.ty));
(*this) = ans;
}
inline bool c() { return x || y || tx || ty; }
} dat[N << 2], I;
//
inline void pd(int k) {
if (dat[k].c()) {
dat[2 * k + 1].append(dat[k]);
dat[2 * k + 2].append(dat[k]);
dat[k] = I;
}
}
void modify(int l, int r, int x, int y, int k, node v) {
if (l >= y || x >= r) return;
if (x <= l && r <= y) {
dat[k].append(v);
return;
}
int mid = (l + r) >> 1; pd(k);
modify(l, mid, x, y, 2 * k + 1, v);
modify(mid, r, x, y, 2 * k + 2, v);
}
pair<ll, ll> qry(int l, int r, int p, int k) {
if (r - l == 1) return mp(max(dat[k].x + a[l], dat[k].y), max(dat[k].tx + a[l], dat[k].ty));
int mid = (l + r) >> 1; pd(k);
if (p < mid) return qry(l, mid, p, 2 * k + 1);
else return qry(mid, r, p, 2 * k + 2);
}
int main() {
n = getint(); q = getint();
rep(i, n) a[i] = getint();
rep(qqq, q) {
int op = getint();
if (op <= 3) {
int l = getint() - 1, r = getint() - 1, x = getint();
node u;
if (op == 1) u.x = u.tx = x, u.y = u.ty = -inf;
if (op == 2) u.x = u.tx = -x, u.y = u.ty = 0;
if (op == 3) u.x = u.tx = -inf, u.y = u.ty = x;
modify(0, n, l, r + 1, 0, u);
}
else {
pair<ll, ll> res = qry(0, n, getint() - 1, 0);
if (op == 4) printf("%lld\n", res.first);
if (op == 5) printf("%lld\n", res.second);
}
}
return 0;
}

区间取maxmax,区间加,区间求和,求区间最小值/区间历史最小值

这个时候不能像上述一样直接合并了。

我们用初始的方法维护该区间取最小值的tagtag和不取最小值的tagtag。根据初始的方法,区间取maxmax的时候只有这个值介于(mn,se)(mn,se)的时候才会修改,就相当于给该区间所有取mnmn的区间加。下传的时候可以这样:

1
2
3
4
5
6
7
8
9
inline void pd(int k) {
if (dat[k].c()) {
bool fx = (dat[2 * k + 1].mn <= dat[2 * k + 2].mn);
bool fy = (dat[2 * k + 1].mn >= dat[2 * k + 2].mn);
pd(dat[2 * k + 1], dat[k], fx);
pd(dat[2 * k + 2], dat[k], fy);
dat[k].cl();
}
}

uoj169

http://uoj.ac/problem/169

区间加/区间取maxmax/区间求最小值/区间历史最小值。

要维护七个东西:

最小值,次小值,历史最小值,最小值加多少标记,非最小值加多少标记,最小值历史最小加多少标记,非最小值历史最小加多少标记。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include <bits/stdc++.h>
#define ld double
#define ull unsigned long long
#define ll long long
#define pii pair <int, int>
#define iiii pair <int, pii >
#define mp make_pair
#define INF 1000000000
#define rep(i, x) for(int (i) = 0; (i) < (x); (i)++)
inline int getint() {
int x = 0, p = 1; char c = getchar();
while (c <= 32) c = getchar();
if (c == 45) p = -p, c = getchar();
while (c > 32) x = x * 10 + c - 48, c = getchar();
return x * p;
}
using namespace std;
const int mod = 1e9 + 7;
inline void reduce(int &x) { x += x >> 31 & mod; }
inline int mul(int x, int y) { return 1ll * x * y % mod; }
//ruogu_alter
const int N = 5e5 + 5;
int n, q, a[N];
const int inf = 0x7f7f7f7f;
struct node {
int mn, se, mn2;
int tgmn, tgmn2;
int tgse, tgse2;
inline void cl() { tgmn = tgmn2 = tgse = tgse2 = 0; }
inline bool c() { return tgmn || tgmn2 || tgse || tgse2; }
node() {
cl(); se = inf;
}
} dat[N << 2];
//
inline void up(int k) {
dat[k].mn = min(dat[2 * k + 1].mn, dat[2 * k + 2].mn);
dat[k].mn2 = min(dat[2 * k + 1].mn2, dat[2 * k + 2].mn2);
if (dat[2 * k + 1].mn == dat[2 * k + 2].mn) {
dat[k].se = min(dat[2 * k + 1].se, dat[2 * k + 2].se);
}
else if (dat[2 * k + 1].mn < dat[2 * k + 2].mn) {
dat[k].se = min(dat[2 * k + 1].se, dat[2 * k + 2].mn);
}
else {
dat[k].se = min(dat[2 * k + 1].mn, dat[2 * k + 2].se);
}
}
inline void pd(node &x, node y, bool fg) {
if (!fg) {
y.tgmn = y.tgse;
y.tgmn2 = y.tgse2;
}
x.mn += y.tgmn;
x.mn2 = min(x.mn2, x.mn - y.tgmn + y.tgmn2);
x.se = min(inf, x.se + y.tgse);
x.tgmn += y.tgmn;
x.tgmn2 = min(x.tgmn2, x.tgmn - y.tgmn + y.tgmn2);
x.tgse += y.tgse;
x.tgse2 = min(x.tgse2, x.tgse - y.tgse + y.tgse2);
}
inline void pd(int k) {
if (dat[k].c()) {
bool fx = (dat[2 * k + 1].mn <= dat[2 * k + 2].mn);
bool fy = (dat[2 * k + 1].mn >= dat[2 * k + 2].mn);
pd(dat[2 * k + 1], dat[k], fx);
pd(dat[2 * k + 2], dat[k], fy);
dat[k].cl();
}
}
void build(int l, int r, int k) {
if (r - l == 1) {
dat[k].mn = dat[k].mn2 = a[l];
return;
}
int mid = (l + r) >> 1;
build(l, mid, 2 * k + 1);
build(mid, r, 2 * k + 2);
up(k);
}
void modify(int l, int r, int x, int y, int k, int v) {
if (l >= y || x >= r) return;
if (x <= l && r <= y) {
dat[k].mn += v;
dat[k].mn2 = min(dat[k].mn, dat[k].mn2);
dat[k].se = min(dat[k].se + v, inf);
dat[k].tgmn += v;
dat[k].tgmn2 = min(dat[k].tgmn, dat[k].tgmn2);
dat[k].tgse += v;
dat[k].tgse2 = min(dat[k].tgse, dat[k].tgse2);
return;
}
int mid = (l + r) >> 1; pd(k);
modify(l, mid, x, y, 2 * k + 1, v);
modify(mid, r, x, y, 2 * k + 2, v);
up(k);
}
void modify2(int l, int r, int x, int y, int k, int v) {
if (l >= y || x >= r || v <= dat[k].mn) return;
if (x <= l && r <= y && v < dat[k].se) {
dat[k].tgmn += v - dat[k].mn;
dat[k].mn = v;
return;
}
int mid = (l + r) >> 1; pd(k);
modify2(l, mid, x, y, 2 * k + 1, v);
modify2(mid, r, x, y, 2 * k + 2, v);
up(k);
}
int qry(int l, int r, int x, int y, int k) {
if (l >= y || x >= r) return inf;
if (x <= l && r <= y) return dat[k].mn;
int mid = (l + r) >> 1; pd(k);
return min(qry(l, mid, x, y, 2 * k + 1), qry(mid, r, x, y, 2 * k + 2));
}
int qry2(int l, int r, int x, int y, int k) {
if (l >= y || x >= r) return inf;
if (x <= l && r <= y) return dat[k].mn2;
int mid = (l + r) >> 1; pd(k);
return min(qry2(l, mid, x, y, 2 * k + 1), qry2(mid, r, x, y, 2 * k + 2));
}
int main() {
n = getint(); q = getint();
rep(i, n) a[i] = getint();
build(0, n, 0);
rep(qq, q) {
int op = getint(), l = getint() - 1, r = getint() - 1;
if (op == 1) modify(0, n, l, r + 1, 0, getint());
if (op == 2) modify2(0, n, l, r + 1, 0, getint());
if (op == 3) printf("%d\n", qry(0, n, l, r + 1, 0));
if (op == 4) printf("%d\n", qry2(0, n, l, r + 1, 0));
}
return 0;
}

有两个数组,aabib_imaxmaxaabb区间加,求aibia_i-b_i最小值。

维护222^2aibia_i-b_i最小值:aia_i取不取最小值,bib_i取不取最小值。

可以拓展到kk个数组,复杂度O(nlogn2k)O(nlogn2^k)

文章作者: ruogu
文章链接: http://ruogu-alter.github.io/2020/04/07/segment-tree-beats/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 ruogu's blog

评论