「bzoj5210」最大连通子块和,动态DP

Description

给出一棵n个点、以1为根的有根树,点有点权。要求支持如下两种操作:

M x y:将点x的点权改为y;

Q x:求以x为根的子树的最大连通子块和。

其中,一棵子树的最大连通子块和指的是:该子树所有子连通块的点权和中的最大值

(本题中子连通块包括空连通块,点权和为0)。

Input

第一行两个整数n、m,表示树的点数以及操作的数目。

第二行n个整数,第i个整数w_i表示第i个点的点权。

接下来的n-1行,每行两个整数x、y,表示x和y之间有一条边相连。

接下来的m行,每行输入一个操作,含义如题目所述。保证操作为M x y或Q x之一。

1≤n,m≤200000 ,任意时刻 |w_i|≤10^9 。

Output

对于每个Q操作输出一行一个整数,表示询问子树的最大连通子块和。

Sample Input

5 4
3 -2 0 3 -1
1 2
1 3
4 2
2 5
Q 1
M 4 1
Q 1
Q 2

Sample Output

4
3
1

Solution

对于没有修改的版本,一个DP就可以解决

$f[x] = v[x] + sum max{0, f[u]}$

对于带修改的版本,使用动态DP

定义以下状态:

$F[x]$表示以$x$为根的联通块的答案

$H[x]$表示$x$子树中答案$max$

$LH[x]$表示$x$轻儿子的答案$max$

$LF[x]$表示$sum F[u]$

动态DP上,重链的转移(将重链表示成序列$v_1, v_2…v_n$, 按深度从小到大排序,$v_n$一定没有子节点)

对于链上的DP,构建矩阵转移:(定义加法为取$max$,乘法为加法)

化简一下转移矩阵的相乘:

从转移矩阵中提取答案,事实上就是用向量$[0,0,0]$去乘这一块的转移矩阵,也就是$f=max(a,d),h=max(b,c)$

于是对于一个链,用线段树记录一下一个区间的转移矩阵,再用链底的来乘一下就好了

对于子树的DP,直接线段树加和/取max即可

写的时候注意一下链上线段树的乘法要用右边乘左边(也就是树上下面乘上面)

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185

#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 200005;
const ll INF = 4e17;

char buf[1 << 20], *p1, *p2;
#define GC (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1000000, stdin), p1 == p2) ? 0 : *p1 ++)
inline int _R() {
int d = 0;char t;bool ty = 1;
while (t = GC, (t < '0' || t > '9') && t != '-');
t == '-' ? (ty = 0) : (d = t - '0');
while (t = GC, t >= '0' && t <= '9') d = (d << 3) + (d << 1) + t - '0';
return ty ? d : -d;
}
inline void _S(char *c) {
char *t = c, ch;
while (ch = GC, ch == ' ' || ch == 'n' || ch == 'r') ;
*t ++ = ch;
while (ch = GC, ch != ' ' && ch == 'n' && ch == 'r') *t ++ = ch;
*t = 0;
}

int n, m, w[N];
int Tote, Last[N], Next[N << 1], End[N << 1];
void (int x, int y) {
End[++ Tote] = y, Next[Tote] = Last[x], Last[x] = Tote;
End[++ Tote] = x, Next[Tote] = Last[y], Last[y] = Tote;
}

struct sb_data;
struct tr_matrix;

struct sb_data {
ll f, h;
sb_data operator + (const sb_data& rhs) const {
return (sb_data) {f + rhs.f, max(h, rhs.h)};
}
tr_matrix to_tr_matrix();
} ;

struct tr_matrix {
ll a, b, c, d;
tr_matrix operator + (const tr_matrix& rhs) const {
return (tr_matrix) {a + rhs.a, max(a + rhs.b, b),
max(d + rhs.b, max(c, rhs.c)), max(rhs.d, d + rhs.a)};
}
sb_data to_sb_data();
} ;

tr_matrix sb_data :: to_tr_matrix() { return (tr_matrix) {f, f, h, 0}; }
sb_data tr_matrix :: to_sb_data() { return (sb_data) {max(a, d), max(b, c)}; }

namespace Point {
struct Node {
Node *Son[2];
sb_data val;
} pool[N * 6], *null, *tl, *root[N];

void _init() {
null = tl = pool;
null -> Son[0] = null -> Son[1] = null;
null -> val = (sb_data) {0, 0};
for (int i = 0; i <= n; i ++) root[i] = null;
}

void _pushup(Node *p) {
p -> val = p -> Son[0] -> val + p -> Son[1] -> val;
}
void Modify(Node *&p, int l, int r, int k, const sb_data& d) {
if (p == null) p = ++ tl, p -> Son[0] = p -> Son[1] = null;
if (l == r) { p -> val = d; return; }
int mid = l + r >> 1;
if (k <= mid) Modify(p -> Son[0], l, mid, k, d);
else Modify(p -> Son[1], mid + 1, r, k, d);
_pushup(p);
}
}

namespace Chain {
struct Node {
Node *Son[2];
tr_matrix val;
} pool[N * 6], *null, *tl, *root[N];

void _init() {
null = tl = pool;
null -> Son[0] = null -> Son[1] = null;
null -> val = (tr_matrix) {0, -INF, -INF, -INF};
for (int i = 0; i <= n; i ++) root[i] = null;
}

void _pushup(Node *p) {
p -> val = p -> Son[1] -> val + p -> Son[0] -> val;
}
void Modify(Node *&p, int l, int r, int k, const tr_matrix& d) {
if (p == null) p = ++ tl, p -> Son[0] = p -> Son[1] = null;
if (l == r) { p -> val = d; return; }
int mid = l + r >> 1;
if (k <= mid) Modify(p -> Son[0], l, mid, k, d);
else Modify(p -> Son[1], mid + 1, r, k, d);
_pushup(p);
}
tr_matrix Query(Node *p, int l, int r, int x, int y) {
if (x <= l && r <= y) return p -> val;
int mid = l + r >> 1;
tr_matrix lm = (tr_matrix) {0, -INF, -INF, -INF}, rm = (tr_matrix) {0, -INF, -INF, -INF};
if (x <= mid) lm = Query(p -> Son[0], l, mid, x, y);
if (mid < y) rm = Query(p -> Son[1], mid + 1, r, x, y);
return rm + lm;
}
}


int hson[N], dep[N], Cnt[N];
int totln, Bln[N], Pos_ch[N], Pos_ln[N], Pos[N], Len[N];
int dfs_init(int x, int fa) {
int sz = 1, maxx = 0, tmp, i, u;
dep[x] = dep[fa] + 1;
Cnt[x] = 2;
for (i = Last[x]; i; i = Next[i])
if (u = End[i], Cnt[x] ++, u != fa)
if (tmp = dfs_init(u, x), sz += tmp, tmp > maxx)
maxx = tmp, hson[x] = u;
return sz;
}

int dfs_make(int x, int tp, int fa) {
Bln[x] = totln;
Pos[x] = dep[x] - dep[tp] + 1;
Len[totln] ++;

if (hson[x]) dfs_make(hson[x], tp, x);

Point :: Modify(Point :: root[x], 1, Cnt[x], 1, (sb_data) {w[x], 0});
for (int u, j = 2, i = Last[x]; i; i = Next[i], j ++)
if (u = End[i], u != fa && u != hson[x]) {
++ totln;
Pos_ch[totln] = j;
Pos_ln[totln] = x;
dfs_make(u, u, x);
Point :: Modify(Point :: root[x], 1, Cnt[x], j, Chain :: root[Bln[u]] -> val.to_sb_data());
}
Chain :: Modify(Chain :: root[Bln[x]], 1, Len[Bln[x]], Pos[x], Point :: root[x] -> val.to_tr_matrix());
}

int main() {
int i, j, k, x, y, u, v;
n = _R(), m = _R();
for (i = 1; i <= n; i ++) w[i] = _R();
for (i = 1; i < n; i ++) {
j = _R(), k = _R();
ADDE(j, k);
}

Point :: _init();
Chain :: _init();
dfs_init(1, 0);
dfs_make(1, 1, 0);


char opt[6];
sb_data ans;
for (i = 1; i <= m; i ++) {
_S(opt);
if (opt[0] == 'Q') {
x = _R();
k = Bln[x];
ans = (Chain :: Query(Chain :: root[k], 1, Len[k], Pos[x], Len[k])).to_sb_data();
printf("%lldn", ans.h);
}
else {
x = _R(), w[x] = _R();
Point :: Modify(Point :: root[x], 1, Cnt[x], 1, (sb_data) {w[x], 0});
Chain :: Modify(Chain :: root[Bln[x]], 1, Len[Bln[x]], Pos[x], Point :: root[x] -> val.to_tr_matrix());
for (k = Bln[x]; k; k = v) {
u = Pos_ln[k];
v = Bln[u];
Point :: Modify(Point :: root[u], 1, Cnt[u], Pos_ch[k], Chain :: root[k] -> val.to_sb_data());
Chain :: Modify(Chain :: root[v], 1, Len[v], Pos[u], Point :: root[u] -> val.to_tr_matrix());
}
}
}
}