BZOJ4650 [Noi2016]优秀的拆分

感觉这道题的方法在SA里还是很常见的
考虑枚举AA和BB的分界点,对答案的贡献是以当前为结尾的XX的个数和以当前为结束的XX的个数
因此实际上要计算以每个字符开头的XX(两个相同字符串)的个数,同理也要算以他结尾的
枚举X字符串的长度$len$,每隔长度$len$设一个关键点
那么每一个XX一定过两个相邻的关键点
枚举每一对相邻的关键点,考虑其可能对那些位置产生贡献(所谓对某个位置产生贡献就是所这个位置开头的XX字符串个数+1)
请看下图,我们求出两个位置的最长公共前缀和最长公共后缀

此时两者没有交点,也就是长度之和$le len$,不可能产生贡献(因为总会每截断)
但是下面这种情况

两者有了交点,就会产生贡献
两条橙色线都是可以产生的XX,在两者之间的每一条同样长度的字符串也是
因此紫色的那一段向后的XX会$+1$,棕色的那一段有向前的XX会$+1$
区间加一可以直接差分
至于LCP和LCS就是后缀数组的常见操作了,前后要构造两个

1.枚举 $len$ ,每隔 $len$ 设置关键点:这个的复杂度是调和级数 $O(n log n)$
2.求 后缀LCP,前缀LCS:使用后缀数组 + st 表 做到 $O(1)$ 查询
3.区间加上 1 : 差分维护就可以了。

记住这个方法,我叫他“调和级数法”

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89

using namespace std;
const int maxn = 300010;
char s[maxn];
int sa[2][maxn], t1[maxn], t2[maxn], h[2][maxn], rnk[2][maxn], c[maxn], f[2][maxn][30], lg2[maxn], n;
int st[maxn], ed[maxn];
void (int m, int w) {
int *x = t1, *y = t2;
for(int i = 0; i <= m; i++) c[i] = 0;
for(int i = 1; i <= n; i++) c[x[i] = s[i]]++;
for(int i = 1; i <= m; i++) c[i] += c[i-1];
for(int i = n; i; i--) sa[w][c[x[i]]--] = i;
for(int k = 1; k <= n; k <<= 1) {
int p = 0;
for(int i = 0; i <= m; i++) c[i] = 0;
for(int i = n - k + 1; i <= n; i++) y[++p] = i;
for(int i = 1; i <= n; i++) if(sa[w][i] > k) y[++p] = sa[w][i] - k;
for(int i = 1; i <= n; i++) c[x[y[i]]]++;
for(int i = 1; i <= m; i++) c[i] += c[i-1];
for(int i = n; i; i--) sa[w][c[x[y[i]]]--] = y[i];
swap(x, y);
p = 1; x[sa[w][1]] = 1;
for(int i = 2; i <= n; i++)
x[sa[w][i]] = y[sa[w][i]] == y[sa[w][i-1]] && y[sa[w][i] + k] == y[sa[w][i-1] + k] ? p : ++p;
if(p > n) break;
m = p;
}
for(int i = 1; i <= n; i++) rnk[w][sa[w][i]] = i;
for(int i = 1, k = 0; i <= n; i++) {
if(k) k--;
if(rnk[w][i] == 1) continue;
int j = sa[w][rnk[w][i] - 1];
while(s[i + k] == s[j + k]) k++;
h[w][rnk[w][i]] = k;
}
for(int i = 1; i <= n; i++) f[w][i][0] = h[w][i];
for(int k = 1; k <= 20; k++)
for(int i = 1; i + (1 << k) - 1 <= n; i++) f[w][i][k] = min(f[w][i][k-1], f[w][i + (1 << (k-1))][k-1]);
}
int calc(int w, int l, int r) {
l = rnk[w][l];
r = rnk[w][r];
if(l > r) swap(l, r);

l++;
int k = lg2[r - l + 1];
return min(f[w][l][k], f[w][r - (1 << k) + 1][k]);
}

int main() {

//freopen("testdata.ans", "w", stdout);
int T; scanf("%d", &T);
lg2[2] = 1; for(int i = 3; i <= 300000; i++) lg2[i] = lg2[i >> 1] + 1;
while(T--) {
memset(sa, 0, sizeof(sa));
memset(h, 0, sizeof(h));
memset(f, 0, sizeof(f));
memset(st, 0, sizeof(st));
memset(ed, 0, sizeof(ed));
memset(t1, 0, sizeof(t1));
memset(t2, 0, sizeof(t2));
memset(rnk, 0, sizeof(rnk));
scanf("%s", s + 1); n = strlen(s + 1);
build(128, 0);
reverse(s + 1, s + 1 + n);
build(128, 1);
//for(int i = 1; i <= n; i++) cout << sa[0][i] << ' '; cout << endl;
//for(int i = 1; i <= n; i++) cout << rnk[0][i] << ' '; cout << endl;
for(int len = 1; len + len <= n; len++) {
for(int i = 1; (i + 1) * len <= n; i++) {
int x = i * len, y = x + len;
int tail = calc(0, x, y), head = calc(1, n + 1 - x, n + 1 - y);
if(head + tail - 1 < len) continue;
tail = min(tail, len); head = min(head, len);
//printf("%d %d %d %dn", x, y, head, tail);
st[x - head + 1]++; st[x + tail - len + 1]--;
ed[y - head + len]++; ed[y + tail]--;
}
}
for(int i = 1; i <= n; i++) st[i] += st[i-1], ed[i] += ed[i-1];
//for(int i = 1; i <= n; i++) cout << st[i] << ' '; cout << endl;
//for(int i = 1; i <= n; i++) cout << ed[i] << ' '; cout << endl;
long long ans = 0;
for(int i = 3; i < n; i++) ans += st[i] * ed[i - 1];
printf("%lldn", ans);
}
return 0;
}