「QOJ 9904」最小生成树

Description

Link:QOJ 9904

给出一个包含 nn 个点的完全图,节点编号为 1n1 \sim n,其中 i,j(ij)i, j(i \neq j) 之间的边权为 ai+ja_{i + j},求这张图的最小生成树。

数据范围:2n2×1052 \leq n \leq 2 \times 10^51ai1091 \leq a_i \leq 10^9

时空限制:33s / 256256MiB。

Solution

算法一:二分 + 线段树维护哈希

考虑 kruskal 的想法,按照边权从小到大考虑每条边,能连就连。假设现在考虑的是 apa_p

  • pnp \leq n 时,需要连接的边为 (1,p1),(2,p2),(1, p - 1), (2, p - 2), \cdots
  • p>np > n 时,需要连接的边为 (pn,n),(pn+1,n1),(p - n, n), (p - n + 1, n - 1), \cdots

发现这形似一个回文的区间 [l,r][l, r],需要连接的边为

(l,r),(l+1,r1),,(l+i,ri),(l, r), (l + 1, r - 1), \cdots, (l + i, r - i), \cdots

但其实有效连边只有 n1n - 1 次,这么多边很多都是无用的,我们希望快速地寻找有效边 (l+i,ri)(l + i, r - i)

启发式合并维护连通块,设 fax\mathrm{fa}_x 表示 xx 所在的连通块编号。

考虑每次找出第一个有效边,相当于是要找出一个最小的 ii 使得 fal+ifari\mathrm{fa}_{l + i} \neq \mathrm{fa}_{r - i}。可以使用线段树维护 fa\mathrm{fa} 的区间哈希值,二分寻找最小的 ii 即可。

时间复杂度 O(nlog2n)\mathcal{O}(n \log^2 n)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
#include <bits/stdc++.h>

#define debug(a) std::cout << #a << "=" << (a) << ' '

using s64 = long long;
using u64 = unsigned long long;

/* 取 min */
template <class T>
inline void chmin(T &x, const T &y) {
if (x > y) {
x = y;
}
}
/* 取 max */
template <class T>
inline void chmax(T &x, const T &y) {
if (x < y) {
x = y;
}
}

/* ----- ----- ----- 正文 ----- ----- ----- */

const int N = 200100;

int n;

int fa[N];
std::vector<int> hold[N];

const int P = 13331;
const int mod1 = 1e9 + 7, mod2 = 1e9 + 9;

int power1[N], power2[N];

struct hash1 {
int l, h;
hash1() {}
hash1(int _l, int _h) : l(_l), h(_h) {}
hash1 operator + (const hash1 &rhs) const {
return hash1(l + rhs.l, (1ll * h * power1[rhs.l] + rhs.h) % mod1);
}
};
struct hash2 {
int l, h;
hash2() {}
hash2(int _l, int _h) : l(_l), h(_h) {}
hash2 operator + (const hash2 &rhs) const {
return hash2(l + rhs.l, (1ll * h * power2[rhs.l] + rhs.h) % mod2);
}
};

namespace SGT {
struct node {
hash1 pre1, suf1;
hash2 pre2, suf2;
} t[N * 4];

void upd(int p) {
t[p].pre1 = t[p * 2].pre1 + t[p * 2 + 1].pre1;
t[p].suf1 = t[p * 2 + 1].suf1 + t[p * 2].suf1;
t[p].pre2 = t[p * 2].pre2 + t[p * 2 + 1].pre2;
t[p].suf2 = t[p * 2 + 1].suf2 + t[p * 2].suf2;
}

void build(int p, int l, int r) {
if (l == r) {
t[p].pre1 = t[p].suf1 = hash1(1, l);
t[p].pre2 = t[p].suf2 = hash2(1, l);
return;
}
int mid = (l + r) >> 1;
build(p * 2, l, mid), build(p * 2 + 1, mid + 1, r);
upd(p);
}

void change(int p, int l, int r, int x, int y) {
if (l == r) {
t[p].pre1 = t[p].suf1 = hash1(1, y);
t[p].pre2 = t[p].suf2 = hash2(1, y);
return;
}
int mid = (l + r) >> 1;
if (x <= mid) {
change(p * 2, l, mid, x, y);
} else {
change(p * 2 + 1, mid + 1, r, x, y);
}
upd(p);
}

hash1 ask1(int p, int l, int r, int s, int e, int type) { // 0 pre, 1 suf
if (s <= l && r <= e) {
return type == 0 ? t[p].pre1 : t[p].suf1;
}
int mid = (l + r) >> 1;
if (s <= mid && mid < e) {
if (type == 0) {
return ask1(p * 2, l, mid, s, e, type) + ask1(p * 2 + 1, mid + 1, r, s, e, type);
} else {
return ask1(p * 2 + 1, mid + 1, r, s, e, type) + ask1(p * 2, l, mid, s, e, type);
}
}
if (s <= mid) {
return ask1(p * 2, l, mid, s, e, type);
} else {
return ask1(p * 2 + 1, mid + 1, r, s, e, type);
}
}
hash2 ask2(int p, int l, int r, int s, int e, int type) { // 0 pre, 1 suf
if (s <= l && r <= e) {
return type == 0 ? t[p].pre2 : t[p].suf2;
}
int mid = (l + r) >> 1;
if (s <= mid && mid < e) {
if (type == 0) {
return ask2(p * 2, l, mid, s, e, type) + ask2(p * 2 + 1, mid + 1, r, s, e, type);
} else {
return ask2(p * 2 + 1, mid + 1, r, s, e, type) + ask2(p * 2, l, mid, s, e, type);
}
}
if (s <= mid) {
return ask2(p * 2, l, mid, s, e, type);
} else {
return ask2(p * 2 + 1, mid + 1, r, s, e, type);
}
}
};

void link(int p, int q) {
p = fa[p], q = fa[q];
if (hold[p].size() > hold[q].size()) {
std::swap(p, q);
}
for (int x : hold[p]) {
fa[x] = q, hold[q].push_back(x);
SGT::change(1, 1, n, x, q);
}
}

int main() {
std::ios::sync_with_stdio(0);
std::cin.tie(0);

std::cin >> n;

std::vector< std::pair<int, int> > seq(2 * n - 3);
for (int i = 3; i <= 2 * n - 1; i ++) {
int v;
std::cin >> v;
seq[i - 3] = {v, i};
}

std::sort(seq.begin(), seq.end());

for (int i = 1; i <= n; i ++) {
fa[i] = i;
hold[i].push_back(i);
}

power1[0] = 1;
for (int i = 1; i <= n; i ++) power1[i] = 1ll * power1[i - 1] * P % mod1;

power2[0] = 1;
for (int i = 1; i <= n; i ++) power2[i] = 1ll * power2[i - 1] * P % mod2;

SGT::build(1, 1, n);

s64 ans = 0;
for (auto [v, p] : seq) {
int le, rg;
if (p <= n) {
le = 1, rg = p - 1;
} else {
le = p - n, rg = n;
}

int ok =
(SGT::ask1(1, 1, n, le, rg, 0).h == SGT::ask1(1, 1, n, le, rg, 1).h) &&
(SGT::ask2(1, 1, n, le, rg, 0).h == SGT::ask2(1, 1, n, le, rg, 1).h);

if (ok) {
continue;
}

while (1) {
int l = 0, r = (rg - le + 1) / 2;
while (l < r) {
int mid = (l + r + 1) >> 1;
int ok =
(SGT::ask1(1, 1, n, le, le + mid - 1, 0).h == SGT::ask1(1, 1, n, rg - mid + 1, rg, 1).h) &&
(SGT::ask2(1, 1, n, le, le + mid - 1, 0).h == SGT::ask2(1, 1, n, rg - mid + 1, rg, 1).h);
if (ok) {
l = mid;
} else {
r = mid - 1;
}
}

if (l == (rg - le + 1) / 2) {
break;
} else {
link(le + l, rg - l);
ans += v;
}
}
}

std::cout << ans << '\n';

return 0;
}

/**
* 心中无女人
* 比赛自然神
* 模板第一页
* 忘掉心上人
**/

算法二:倍增并查集

基本的讨论同算法一。现在考虑如何快速连边。

回顾一下 「SCOI2016」萌萌哒 的做法

Link:「SCOI2016」萌萌哒

简单的想法是,直接将两个区间的对应位置连边,数出最后有几个连通块。

对于普通的并查集,若 x,yx, y 在同一个集合,则代表 x,yx, y 相同。但普通并查集只能处理单点之间的连通信息。

考虑“倍增并查集”,若 x,yx, y 在第 k(k=0,1,,log2n)k(k = 0, 1, \cdots, \log_2 n) 个并查集中在同一个集合,则代表区间 [x,x+2k1],[y,y+2k1][x, x + 2^k - 1], [y, y + 2^k - 1] 相同。相当于是区间与区间合并

每次将两个区间二进制分解成不超过 log2n\log_2 n 个长度为二的幂次的区间,然后将这些对应区间合并。

但最后要处理的是单点之间的连通信息。我们就按照 k=log2n,,1,0k = \log_2 n, \cdots, 1, 0 的顺序,依次将区间的连通信息下放(将一个长度为 2k2^k 的区间,拆成前后两个长度为 2k12^{k - 1} 的区间)。

一共 log2n\log_2 n 层,每层 O(n)O(n) 个节点。时间复杂度 O(nlogn)\mathcal{O}(n \log n)

沿用该做法。只不过我们每次要合并的区间是呈回文状的,即一个区间从前往后,与另一个区间从后往前合并。

考虑启用扩展域并查集,若 x,y+nx, y + n 在第 kk 个并查集中在同一个集合,则代表区间 [x,x+2k1][x, x + 2^k - 1] 从前往后,与区间 [y2k+1,y][y - 2^k + 1, y] 从后往前相匹配。

每次合并的时候,还需要得到连了几条边。这可以递归处理:若 x,yx, y 在第 kk 个并查集中不连通,先将当前区间合并,然后将当前区间拆成前后两个长度为 2k12^{k - 1} 的区间,继续递归合并。若递归到 k=0k = 0,则说明需要连边。

时间复杂度 O(nlogn)\mathcal{O}(n \log n)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include <bits/stdc++.h>

#define debug(a) std::cout << #a << "=" << (a) << ' '

using s64 = long long;
using u64 = unsigned long long;

/* 取 min */
template <class T>
inline void chmin(T &x, const T &y) {
if (x > y) {
x = y;
}
}
/* 取 max */
template <class T>
inline void chmax(T &x, const T &y) {
if (x < y) {
x = y;
}
}

/* ----- ----- ----- 正文 ----- ----- ----- */

const int N = 200100;

int n, t;

struct DSU {
std::vector<int> fa;

DSU() {}
DSU(int n) {
fa.resize(n + 1);
std::iota(fa.begin(), fa.end(), 0);
}

int get(int x) {
return fa[x] == x ? x : fa[x] = get(fa[x]);
}

void merge(int x, int y) {
fa[get(x)] = get(y);
}
};

std::vector<DSU> dsu;

int merge(int k, int x, int y) {
if (dsu[k].get(x) == dsu[k].get(y + n)) return 0;
dsu[k].merge(x, y + n);
if (k == 0) {
return 1;
} else {
return merge(k - 1, x, y) + merge(k - 1, x + (1 << (k - 1)), y - (1 << (k - 1)));
}
}

int main() {
std::ios::sync_with_stdio(0);
std::cin.tie(0);

std::cin >> n, t = std::__lg(n);

dsu.assign(t + 1, DSU(2 * n));
for (int i = 1; i <= n; i ++) {
dsu[0].fa[i + n] = i;
}

std::vector< std::pair<int, int> > seq(2 * n - 3);
for (int i = 3; i <= 2 * n - 1; i ++) {
int v;
std::cin >> v;
seq[i - 3] = {v, i};
}

std::sort(seq.begin(), seq.end());

s64 ans = 0;
for (auto [v, p] : seq) {
int l, r;
if (p <= n) {
l = 1, r = p - 1;
} else {
l = p - n, r = n;
}

int k = std::__lg((r - l + 1) / 2);
ans += 1ll * v * (merge(k, l, r) + merge(k, l + (1 << k), r - (1 << k)));
}

std::cout << ans << '\n';

return 0;
}

/**
* 心中无女人
* 比赛自然神
* 模板第一页
* 忘掉心上人
**/