HDU 5785 Interesting【Manacher】

题目链接:

http://acm.hdu.edu.cn/showproblem.php?pid=5785

题意:

给定长度为$n$的序列,求所有相邻回文串$[i,j]和[j + 1, k]$中$i times k$的和。
数据范围:$0 le n le 1000000$

分析:

我们设$l[i]$表示以$i$为右端点的所有回文串的左端点之和,$r[i]$表示以$i$为左端点的所有回文串的右端点之和,那么最终答案即为$sum_{i = 1}^n l[i] times r[i + 1]$。
我们先跑个马拉车求出数组$p$,然后枚举中心$i$,得到以$i$为中心的新串中的回文串左右端点为$L = i - p[i] + 1, R = i + p[i] + 1$
先考虑右端点,则有$r[L] += R, r[L + 1] += R - 1…r[i] += i$,我们发现每一项都比前一项少加一个$1$,
利用差分思想,设$dr[i]$表示后一项与前一项之差。初始情况$r[L] += R, dr[L + 1] += -1$,然后在区间结束处打个标记,即$dr[i + 1] += 1, r[i + 1] -= i$,在统计答案的时候不断累加并更新$dl$和$l$即可。
$r$数组同理可得。过程类似于树状数组。
最后从头到尾扫一遍,累加一下答案即可。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
> File Name: 5785.cpp
> Author: jiangyuzhu
> Created Time: 2016/10/4 16:25:08
************************************************************************/
#include<cstdio>
#include<cstring>
#include<queue>
#include<vector>
#include<set>
#include<map>
#include<algorithm>
using namespace std;
const int maxn = 1e6 + 5, mod = 1e9 + 7;
typedef long long ll;
char s[maxn], str[maxn << 1];
int p[maxn << 1];
int k, n;
void ()
{
str[0] = '$';
str[1] = '#';
for(int i = 0; s[i]; i++){
str[i * 2 + 2] = s[i];
str[i * 2 + 3] = '#';
}
n = k * 2 + 2;
str[n] = 0;
}
int manacher()
{
prepare();
int maxx = 0;
int id;
for(int i = 1; i < n; i++){
if(maxx > i){
p[i] = min(p[2 * id - i], maxx - i);
}else p[i] = 1;
while(str[i - p[i]] == str[i + p[i]]) p[i]++;
if(p[i] + i > maxx){
maxx = p[i] + i;
id = i;
}
}
int ans = 0;
for(int i = 0; i < n; i++){
p[i]--;
ans = max(ans , p[i]);
}
return ans;
}
ll dr[maxn << 1], dl[maxn << 1];
ll r[maxn << 1], l[maxn << 1];
inline void MOD(ll &x)
{
if(x < 0) x += mod;
if(x >= mod) x %= mod;
}
ll inv = 500000004;
int main (void)
{
while(~scanf("%s", s)){
k = strlen(s);
memset(p, 0, sizeof(p));
manacher();
memset(l, 0, sizeof(l));
memset(r, 0, sizeof(r));
memset(dr, 0, sizeof(dr));
memset(dl, 0, sizeof(dl));
for(int i = 1; i < n; ++i){
int L = i - p[i] + 1;
int R = i + p[i] - 1;
r[L] += R; MOD(r[L]);
dr[L + 1] -= 1; MOD(dr[L + 1]);
r[i + 1] -= i; MOD(r[i + 1]);
dr[i + 1] += 1; MOD(dr[i + 1]);
l[i] += i; MOD(l[i]);
dl[i + 1] -= 1; MOD(dl[i + 1]);
l[R + 1] -= L; MOD(l[R + 1]);
dl[R + 1] += 1; MOD(dl[R + 1]);
}
for(int i = 1; i < n; ++i){
dr[i] += dr[i - 1]; MOD(dr[i]);
r[i] += r[i - 1] + dr[i]; MOD(r[i]);
dl[i] += dl[i - 1]; MOD(dl[i]);
l[i] += l[i - 1] + dl[i]; MOD(l[i]);
}
ll ans = 0;
for(int i = 2; i < n - 2; i += 2){
(ans += l[i] * inv % mod * r[i + 2] % mod * inv % mod) %= mod;
}
printf("%lldn", ans);
}
return 0;
}