hdu5287Fast wyh2000 Transform

题目

$C_i=sum_{j+k==i}(^i_j)A_jB_k(mod 3)$,已知$A_i,B_i$,求$C_i$

Solution

我一看,诶,$frac{C_i}{i!}=sum_{j+k==i}frac{A_j}{j!}frac{B_k}{k!}​$,这不是FFT裸题(模数自动忽略,看题目时还以为是减小难度的)?

写了以后发现,woc,$i!%3$全是$0$!!!

模拟赛上只好写了个$n^2$暴力,假设$f(i)$表示$i!$中有多少个$3$,

那么$frac{C_i}{i!}=sum_{j+k==i,f(j)+f(k)==f(i)}frac{A_j}{j!}frac{B_k}{k!}$(这里$i!$已经把质因子中所有$3​$都删掉了)

下面讲正解

我们先把$A,B$的长度凑成$3^k$,然后分治

每次把$A,B$分成三段,分别为$a_0,a_1,a_2$,$b_0,b_1,b_2​$

考虑两个约束条件(第一个约束条件其实可以用第二个解释,我是打完这篇博客才发现的。。。)

$1.f(j)+f(k)==f(i)$

我们本来要把$a_0,a_1,a_2$,$b_0,b_1,b_2$两两相乘的,但是$a_2b_1$,$a_1b_2$,$a_2b_2$可以不用乘

我以$a_2b_1$来举例子好了,其他两个同理

假设我们分成的三部分分别为$0…k-1$,$k…2k-1$,$2k…3k-1$

$a_2$对应$2k…3k-1$,$b_1$对应$k…2k-1$

那么$a_2b_1$中每个数都在$[3k,5k-2]$的范围内

由于$b_1$中一定有$k$,那么$a_2b_1$中的$3k$与之抵消(因为消掉的都是第一个数,所以不影响别的东西),平白无故多出一个$3$,所以$f(j)+f(k)$不可能等于$f(i)​$

$2.(^i_j)$

现在我们要算的是$a_0b_0​$,$a_0b_1+a_1b_0​$,$a_0b_2+a_1b_1+a_2b_0​$,至于为什么要分三部分,看看它们每部分下标的和就知道了

可是我们还没考虑$(^i_j)​$这个系数

想到组合数,就一定会想到卢卡斯定理:$(^i_j)$=$(^{i/3}{j/3})$ $(^{i%3}{j%3})$(不知道为什么公式打不出来,这句可以忽略)

我们惊奇地发现,这和我们分治的结构非常的相似,$imod3​$,$jmod3​$就对应了$a,b​$的下标

系数乘上,就变成了$a_0b_0​$,$a_0b_1+a_1b_0​$,$a_0b_2+2a_1b_1+a_2b_0​$

此时的时间复杂度为$O(n^{log_3^6})​$

事实上,我们只需要求$(a_0+a_1)(b_0+b_1)​$,$(a_0+a_2)(b_0+b_2)​$,$a_0b_0​$,$a_1b_1​$,$a_2b_2​$,就可以把上面六个表示出来

这样,时间复杂度就变成了$O(n^{log_3^5})$

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

using namespace std;
inline char (){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int rd(){
int x=0,fl=1;char ch=gc();
for(;ch<48||ch>57;ch=gc())if(ch=='-')fl=-1;
for(;48<=ch&&ch<=57;ch=gc())x=(x<<3)+(x<<1)+(ch^48);
return x*fl;
}
char pbuf[100000],*pp=pbuf;
inline void pc(char ch){if(pp==pbuf+100000)fwrite(pbuf,1,100000,stdout),pp=pbuf;*pp++=ch;}
inline void flush(){fwrite(pbuf,1,pp-pbuf,stdout);}
const int N=200002;
int f[12][N],g[12][N],ans[12][N],L,n,i,T;
void solve(int d,int l,int L){
if (L==1){
ans[d][l]=f[d][l]*g[d][l];
return;
}
L/=3;
for (int i=l;i<l+L;i++){
f[d+1][i]=f[d][i]+f[d][i+L];
g[d+1][i]=g[d][i]+g[d][i+L];//b0+b1
f[d+1][i+L]=f[d][i]+f[d][i+L+L];//a0+a2
g[d+1][i+L]=g[d][i]+g[d][i+L+L];//b0+b2
}
solve(d+1,l,L);
solve(d+1,l+L,L);
for (int i=l;i<l+L;i++){
ans[d][i+L]=ans[d+1][i];
ans[d][i+L+L]=ans[d+1][i+L];
}
for (int i=l;i<l+L*3;i++){
f[d+1][i]=f[d][i];
g[d+1][i]=g[d][i];
}
solve(d+1,l,L);
solve(d+1,l+L,L);
solve(d+1,l+L+L,L);
for (int i=l;i<l+L;i++){
ans[d][i]=ans[d+1][i];//a0b0
ans[d][i+L]-=ans[d+1][i]+ans[d+1][i+L];//a0b1+a1b0=(a0+a1)(b0+b1)-a0b0-a1b1
ans[d][i+L+L]+=ans[d+1][i+L]*2-ans[d+1][i]-ans[d+1][i+L+L];//a0b2+2a1b1+a2b0=(a0+a2)(b0+b2)-a0b0-a2b2+2a1b1
}
}
int main(){
for (T=rd();T--;){
n=rd();
for (L=1;L<n;L*=3);
memset(f[0],0,L<<2);
memset(g[0],0,L<<2);
memset(ans[0],0,max(L,n<<1)<<2);
for (i=0;i<n;i++) f[0][i]=rd();
for (i=0;i<n;i++) g[0][i]=rd();
solve(0,0,L);
for (i=0;i<2*n-1;i++) pc(ans[0][i]%3|48),pc(' ');
pc(10);
}
flush();
}