「CF2222G」Statistics on Tree

Description

Link:CF2222G

给出一个大小为 nn 的树,定义点对 (u,v)(u, v) (1uvn1 \leq u \leq v \leq n) 的价值为:删除从 uuvv 的路径上的所有边之后,最大连通块的大小。

对于每个 ii (1in1 \leq i \leq n),求出价值等于 ii 的点对 (u,v)(u, v) 的数量。

数据范围:1n1051 \leq n \leq 10^5

时空限制:22s / 512512MiB。

Solution

在定根的情况下,枚举点对 (u,v)(u, v)lca(u,v)=z\mathrm{lca}(u, v) = z,定义:

  • aua_u:删除 zzuu 的所有边之后,不包含 zz 的最大连通块大小。
  • bub_uzzuu 路径上第二个点的子树大小。

那么点对 (u,v)(u, v) 的价值为 max(au,av,nbubv)\max(a_u, a_v, n - b_u - b_v)。发现直接做的话非常难统计,因为这都已经比卷积更加严格了。

考虑以树的重心为根,树的重心有一个很好的性质:所有子树的大小都不超过 n2\frac{n}{2}

这样的话,当点对的 lca(u,v)\mathrm{lca}(u, v) 不等于根时,外子树(包含 zz 的连通块)的大小必定超过 n2\frac{n}{2},所以此时一定是外子树更大。所以当 lca(u,v)\mathrm{lca}(u, v) 不等于根时,点对 (u,v)(u, v) 的价值为 nbubvn - b_u - b_v

这是一个卷积的形式。对于 zz 的每一个儿子 xx,相当于是给 szxsz_x 次项加上 szxsz_x,然后进行卷积。NTT 显然过不去,但是这个多项式比较稀疏,可以直接进行暴力多项式乘法(前提是要先去重)!可以证明时间复杂度是 O(nlogn)\mathcal{O}(n \log n) 的:

  • 对整棵树进行重链剖分,显然轻儿子的子树大小之和不超过 O(nlogn)\mathcal{O}(n \log n)
  • 轻儿子与轻儿子之间的复杂度:设 zz 的所有轻儿子的子树大小之和为 kk,则本质不同的子树大小不超过 O(k)\mathcal{O}(\sqrt{k}),于是暴力多项式乘法的复杂度为 O(k)\mathcal{O}(k)。于是总和不超过 O(nlogn)\mathcal{O}(n \log n)
  • 重儿子与轻儿子之间的复杂度:不超过点 zz 的度数。于是总和不超过 O(n)\mathcal{O}(n)

还要额外统计一下 lca(u,v)\mathrm{lca}(u, v) 等于根的贡献。

先考虑 max(au,av)\max(a_u, a_v) 决定了点对价值的情况,根据定义显然有 aubua_u \leq b_u,假设 aunbubva_u \geq n - b_u - b_v(不妨设 auava_u \geq a_v)那么有 2bu+bvn2b_u + b_v \geq n,所以必有 bun3b_u \geq \frac{n}{3}bvn3b_v \geq \frac{n}{3}

也就是说,当 max(au,av)\max(a_u, a_v) 决定了点对 (u,v)(u, v) 的价值时,u,vu, v 的其中一个必定位于大小超过 n3\frac{n}{3} 的子树内,然而这样的子树不超过两个。可以枚举这样的大子树去统计贡献,相当于是做一次二维数点。

再考虑 nbubvn - b_u - b_v 决定了点对价值的情况,对于每一种 bub_u,用一个 std::vector 存放 bub_u 相同的所有的 aa 值。依然去重之后暴力去枚举 bu,bvb_u, b_v 的值,那么此时要求 au,av<nbubva_u, a_v < n - b_u - b_v,可以在各自的 std::vector 里面二分数出有多少个符合条件的 aa

注意来自同一子树的贡献还要额外排除。

细节较多。

时间复杂度 O(nlogn)\mathcal{O}(n \log n)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
#include <bits/stdc++.h>

using i64 = long long;

#define debug(a) std::cout << #a << '=' << (a) << ' '

template <class T>
inline void chmin(T &x, const T &y) {
if (x > y) {
x = y;
}
}
template <class T>
inline void chmax(T &x, const T &y) {
if (x < y) {
x = y;
}
}

const int N = 100100;

int n;
std::vector<std::vector<int>> G;

int sz[N], mp[N];
int rt;

void getSize(int u, int fu) {
sz[u] = 1;
for (int v : G[u]) {
if (v == fu) {
continue;
}
getSize(v, u);
sz[u] += sz[v];
}
}
void getRoot(int u, int fu) {
mp[u] = 0;
for (int v : G[u]) {
if (v == fu) {
continue;
}
getRoot(v, u);
chmax(mp[u], sz[v]);
}
chmax(mp[u], n - sz[u]);
if (!rt || mp[u] < mp[rt]) {
rt = u;
}
}

i64 ans[N];

void solve(int u, int fu) {
std::vector<std::pair<int, int>> tmp, seq;
for (int v : G[u]) {
if (v == fu) {
continue;
}
solve(v, u);

if (u == rt) {
continue;
}

ans[n - sz[v]] += sz[v]; // v 在 u 的子树内
ans[n - 2 * sz[v]] -= 1ll * sz[v] * (sz[v] - 1) / 2; // 容斥
tmp.push_back({sz[v], sz[v]});
}

if (u == rt) {
return;
}

std::sort(tmp.begin(), tmp.end());
for (auto [v, c] : tmp) {
if (!seq.size() || seq.back().first != v) {
seq.push_back({v, c});
} else {
seq.back().second += c;
}
}

for (int i = 0; i < seq.size(); i ++) {
for (int j = i; j < seq.size(); j ++) {
if (i == j) {
ans[n - 2 * seq[i].first] += 1ll * seq[i].second * (seq[i].second - 1) / 2;
} else {
ans[n - seq[i].first - seq[j].first] += 1ll * seq[i].second * seq[j].second;
}
}
}
}

int a[N], b[N];
void dfs_init(int u, int fu, int mp, int bel) {
a[u] = std::max(mp, sz[u]), b[u] = bel;
for (int v : G[u]) {
if (v == fu) {
continue;
}
dfs_init(v, u, std::max(mp, sz[u] - sz[v]), bel);
}
}

std::vector<int> sa[N], tmp;
void find(int u, int fu) {
tmp.push_back(a[u]);
for (int v : G[u]) {
if (v == fu) {
continue;
}
find(v, u);
}
}

struct BIT {
int c[N];

void init() {
for (int i = 1; i <= n; i ++) {
c[i] = 0;
}
}

void add(int x, int y) {
for (; x; x -= x & -x) {
c[x] += y;
}
}

int ask(int x) {
int ans = 0;
for (; x <= n; x += x & -x) {
ans += c[x];
}
return ans;
}
} bit[2];

std::vector<std::pair<int, int>> id;
void get(int u, int fu, int type) {
id.push_back({u, type});
for (int v : G[u]) {
if (v == fu) {
continue;
}
get(v, u, type);
}
}

void work() {
std::cin >> n;

G.assign(n + 1, {});
for (int i = 1; i < n; i ++) {
int x, y;
std::cin >> x >> y;
G[x].push_back(y);
G[y].push_back(x);
}

getSize(1, 0);
rt = 0;
getRoot(1, 0);
getSize(rt, 0);

for (int i = 1; i <= n; i ++) {
ans[i] = 0;
}

ans[n] += n; // u = v;

solve(rt, 0);
// debug(rt) << '\n';

for (int v : G[rt]) {
dfs_init(v, rt, 0, sz[v]);
}
for (int i = 1; i <= n; i ++) {
if (i != rt) {
ans[std::max(a[i], n - b[i])] ++; // u = rt, v 在 rt 的子树内
// debug(i), debug(a[i]), debug(b[i]) << '\n';
}
}

for (int i = 1; i <= n; i ++) {
sa[i].clear();
}
for (int i = 1; i <= n; i ++) {
if (i != rt) {
sa[b[i]].push_back(a[i]);
}
}
for (int i = 1; i <= n; i ++) {
std::sort(sa[i].begin(), sa[i].end());
}

std::vector<int> seq;
for (int i = 1; i <= n; i ++) {
if (sa[i].size()) {
seq.push_back(i);
}
}
for (int i = 0; i < seq.size(); i ++) {
for (int j = i; j < seq.size(); j ++) {
int v1 = seq[i], v2 = seq[j];
int x = n - v1 - v2;

int c1 = std::lower_bound(sa[v1].begin(), sa[v1].end(), x) - sa[v1].begin();
int c2 = std::lower_bound(sa[v2].begin(), sa[v2].end(), x) - sa[v2].begin();
if (v1 != v2) {
ans[x] += 1ll * c1 * c2;
} else {
ans[x] += 1ll * c1 * (c1 - 1) / 2;
}
}
}

for (int v : G[rt]) {
tmp.clear();
find(v, rt);

std::sort(tmp.begin(), tmp.end());
int x = n - 2 * sz[v];

int c = std::lower_bound(tmp.begin(), tmp.end(), x) - tmp.begin();
ans[x] -= 1ll * c * (c - 1) / 2; // 容斥
}

std::sort(G[rt].begin(), G[rt].end(), [&] (int x, int y) -> bool {
return sz[x] > sz[y];
});

if (G[rt].size() > 1) {
bit[0].init(), bit[1].init();
id.clear();

get(G[rt][0], rt, 0);
for (int i = 1; i < G[rt].size(); i ++) {
get(G[rt][i], rt, 1);
}

std::sort(id.begin(), id.end(), [&] (auto x, auto y) -> bool {
return a[x.first] < a[y.first];
});
for (auto [u, t] : id) {
ans[a[u]] += bit[t ^ 1].ask(std::max(1, n - a[u] - b[u]));
bit[t].add(b[u], 1);
}
}

if (G[rt].size() > 2) {
bit[0].init(), bit[1].init();
id.clear();

get(G[rt][1], rt, 0);
for (int i = 2; i < G[rt].size(); i ++) {
get(G[rt][i], rt, 1);
}

std::sort(id.begin(), id.end(), [&] (auto x, auto y) -> bool {
return a[x.first] < a[y.first];
});
for (auto [u, t] : id) {
ans[a[u]] += bit[t ^ 1].ask(std::max(1, n - a[u] - b[u]));
bit[t].add(b[u], 1);
}
}

for (int i = 1; i <= n; i ++) {
std::cout << ans[i] << " \n"[i == n];
}
}

int main() {
std::ios::sync_with_stdio(0);
std::cin.tie(0);

int T;
std::cin >> T;

while (T --) {
work();
}

return 0;
}

/*
1
5
1 2
2 3
2 4
2 5

1
7
3 4
1 5
2 3
2 6
5 2
7 3
*/