[MtOI2019] T6 Solution

upd:之前式子有一点锅,现已修复。
这其实是一个很水的套路题。。
首先容易看出来 f(x,0) 是个线性递推的形式,要求的是其 k 阶前缀积。
要求乘积不太好搞,可以对 2 取一下对数,化乘为加。
于是问题转化为:
一个数列 a :\large a_n=n\space(n\le42)
\large a_n=\sum\limits_{i=1}^{42}ia_{n-i}\space(n\ge 43)求它 k 阶前缀和的第 n 项。

关于线性递推式的高阶前缀和有一个优美的性质。
设数列 a 的递推系数为 f ,那么在 f 前面加个 -1 ,然后做 k 阶差分得到的序列即 a 的 k 阶前缀和的递推式。( 当然要在后面扩展 k 项,同时最后去掉 -1 )
在此简短证明一下,设:
\large a_n=\sum\limits_{i=1}^kf_ia_{n-i}\large b_n=\sum\limits_{i=1}^na_i\large b_n=b_{n-1}+a_n=b_{n-1}+\sum\limits_{i=1}^kf_ia_{n-i}\large = b_{n-1}+\sum\limits_{i=1}^kf_i(b_{n-i}-b_{n-i-1})后面的那个求和展开,可以得到很多形如 (f_i-f_{i-1})b_{n-i} 的式子,这就是一个很明显的差分形式,接下来的证明就很容易了。
得到 k 阶前缀和的递推式后,把 a 的 k 阶前缀和的前几项也求出来,然后直接上线性递推板子即可。
不过要注意的是我们刚才取了个 \log ,所以以上运算都要对 \color{red} 998244352 取模,这需要用到任意模数。 7 次 FFT 的做法常数过大,不能通过;需要使用 4 次 FFT 的做法。
还有就是模数不是素数时,算组合数很麻烦,所以直接用倍增多项式快速幂计算高阶差分或前缀和即可。
时间复杂度 \Theta(k\log^2 k+k\log k\log n) 。
std:
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 65539
#define LIM 65536
#define ll long long
#define reg register
#define p 998244352
#define pi 3.141592653589793
using namespace std;

struct complex{
    double x,y;
    inline complex(double x=0,double y=0):x(x),y(y){}
    inline complex operator + (const complex& b) const{
        return complex(x+b.x,y+b.y);
    }
    inline complex operator - (const complex& b) const{
        return complex(x-b.x,y-b.y);
    }
    inline complex operator * (const complex& b) const{
        return complex(x*b.x-y*b.y,x*b.y+y*b.x);
    }
    inline complex operator / (const int& b) const{
        return complex(x/b,y/b);
    }
    inline complex operator ~ () const{
        return complex(x,-y);
    }
};

namespace polynomial{ 
    int lg2[N],rev[N];
    complex rt[N];
    struct poly{
        int a[N];
        int t;
    };

    poly ig;

    inline int add(int a,int b){
        return a+b>=p?a+b-p:a+b;
    }

    inline int dec(int a,int b){
        return a<b?a-b+p:a-b;
    }

    inline void init(){
        for(reg int i=2;i<=LIM;++i) lg2[i] = lg2[i>>1]+1;
        rt[0] = complex(1,0);
        for(reg int i=1;i<=LIM;++i) rt[i] = complex(cos(pi/LIM*i),sin(pi/LIM*i));
    }

    inline void FFT(complex *a,int type,int lim){
        for(reg int i=1;i<=lim;++i){
            if(i>=rev[i]) continue;
            swap(a[i],a[rev[i]]);
        }
        reg complex w,y;
        reg int l = LIM;
        for(reg int mid=1;mid!=lim;mid<<=1){
            for(reg int j=0;j!=lim;j+=(mid<<1)){
                for(reg int k=0;k!=mid;++k){
                    w = type==1?rt[l*k]:(~rt[l*k]);
                    y = w*a[j|k|mid];
                    a[j|k|mid] = a[j|k]-y;
                    a[j|k] = a[j|k]+y;
                }
            }
            l >>= 1;
        }
        if(type==1) return;
        for(reg int i=0;i!=lim;++i) a[i] = a[i]/lim;
    }

    poly multiply(poly A,poly B,int len){
        complex f[N],g[N],h[N],q[N];
        complex t,f0,f1,g0,g1;
        poly r;
        ll x,y,z;
        int lim = 1;
        while(lim<=A.t+B.t) lim <<= 1;
        int l = lg2[lim]-1;
        for(reg int i=1;i<=lim;++i)
            rev[i] = (rev[i>>1]>>1)|((i&1)<<l);
        for(reg int i=0;i<=lim;++i){
            f[i] = complex(A.a[i]>>15,A.a[i]&32767);
            g[i] = complex(B.a[i]>>15,B.a[i]&32767);
        }
        FFT(f,1,lim);
        bool flag = A.t==B.t;
        if(flag){
            for(reg int i=0;i<=A.t;++i){
                if(A.a[i]==B.a[i]) continue;
                flag = false;
                break;
            }
        }
        if(flag) for(reg int i=0;i!=lim;++i) g[i] = f[i];
        else FFT(g,1,lim); 
        for(reg int i=0;i!=lim;++i){
            t = ~f[i?lim-i:0];
            f0 = (f[i]-t)*complex(0,-0.5),f1 = (f[i]+t)*0.5;
            t = ~g[i?lim-i:0];
            g0 = (g[i]-t)*complex(0,-0.5),g1 = (g[i]+t)*0.5;
            h[i] = f1*g1;
            q[i] = f1*g0 + f0*g1 + f0*g0*complex(0,1);
        }
        FFT(h,-1,lim),FFT(q,-1,lim);
        for(reg int i=0;i<=len;++i){
            x = (ll)(h[i].x+0.5)%p<<30;
            y = (ll)(q[i].x+0.5)<<15;
            z = q[i].y+0.5;
            r.a[i] = (x+y+z)%p;
        }
        r.t = len;
        for(reg int i=len+1;i!=LIM;++i) r.a[i] = 0;
        return r;
    }

    poly inverse(poly f){ 
        poly g,h,q;
        memset(g.a,0,sizeof(g.a));
        int top = 0,n = f.t;
        int s[30];
        while(n){
            s[++top] = n;
            n >>= 1;
        }
        g.a[0] = g.t = 1;
        while(top--){
            n = s[top+1];
            h = f,q = g;
            for(reg int i=n+1;i<=f.t;++i) h.a[i] = 0;
            g = multiply(multiply(g,g,n),h,n);
            for(reg int i=0;i<=n;++i)
                g.a[i] = dec(add(q.a[i],q.a[i]),g.a[i]);
        }
        return g;
    }

    inline void reverse(poly &f){
        int l = f.t>>1;
        for(reg int i=0;i<=l;++i)
            swap(f.a[i],f.a[f.t-i]);
    }

    poly divide(poly f,poly g){
        int n = f.t,m = g.t;
        reverse(f);
        g = ig;
        g.t = f.t = n-m+1;
        for(reg int i=f.t+1;i<=n;++i) f.a[i] = g.a[i] = 0;
        f = multiply(f,g,n-m);
        reverse(f);
        return f;
    }

    poly mod(poly f,poly g){
        if(f.t<g.t) return f;
        while(!f.a[f.t]) --f.t;
        poly q = divide(f,g);
        int l,m = g.t;
        for(reg int i=m+1;i<=f.t;++i) f.a[i] = 0;
        f.t = m;
        l = min(q.t,m);
        for(reg int i=l+1;i<=q.t;++i) q.a[i] = 0;
        q.t = l;
        g = multiply(g,q,m);
        for(reg int i=0;i<m;++i) f.a[i] = dec(f.a[i],g.a[i]);
        for(reg int i=m;i<=f.t;++i) f.a[i] = 0;
        f.t = m-1;
        return f;
    }

    poly mod_power(poly f,ll t,poly G){
        ig = G;
        reverse(ig);
        ig = inverse(ig);
        poly g = f;
        --t;
        while(t){
            if(t&1) g = mod(multiply(f,g,f.t+g.t),G);
            f = mod(multiply(f,f,f.t<<1),G);
            t >>= 1;
        }
        return g;
    }

    poly power(poly f,ll t,int lim){
        poly g = f;
        --t;
        while(t){
            if(t&1) g = multiply(f,g,min(f.t+g.t,lim));
            f = multiply(f,f,min(f.t<<1,lim));
            t >>= 1;
        }
        return g;
    }
};
using namespace polynomial;

inline int poww(int a,int t,int m){
    int res = 1;
    while(t){
        if(t&1) res = (ll)res*a%m;
        a = (ll)a*a%m;
        t >>= 1;
    }
    return res;
}

poly F,G,H;
int k,lim;
int a[N],f[N];
ll n;

int main(){
    init();
    int ans = 0;
    scanf("%lld%d",&n,&k);
    if(n==1){
        putchar('2');
        return 0;
    }
    for(reg int i=0;i<=41;++i) F.a[i] = i+1;
    lim = k+42;
    for(reg int i=42;i<lim;++i)
        for(reg int j=1;j<=42;++j)
            F.a[i] = (F.a[i]+(ll)j*F.a[i-j])%p;
    for(reg int i=1;i<=42;++i) G.a[i] = i;
    G.a[0] = p-1;
    F.t = lim-1,G.t = 42;
    if(k>0){
        H.a[0] = 1,H.a[1] = p-1;
        H.t = 1;
        H = power(H,k,lim);
        G = multiply(G,H,lim);
        H.t = lim-1;
        H.a[lim] = 0;
        F = multiply(F,inverse(H),lim-1);
    }
    for(reg int i=0;i<lim;++i){
        a[i] = F.a[i];
        f[i] = G.a[i];
    }
    f[lim] = G.a[lim];
    memset(F.a,0,sizeof(F.a));
    memset(G.a,0,sizeof(G.a));
    F.a[1] = F.t = 1;
    for(reg int i=0;i<=lim;++i) G.a[lim-i] = f[i]?p-f[i]:0;
    F = mod_power(F,n-1,G);
    for(reg int i=0;i<lim;++i)
        ans = (ans+(ll)a[i]*F.a[i])%p;
    ans = poww(2,ans,998244353);
    printf("%d",ans);
    return 0;
}

评论

此博客中的热门博文

高中地理必修一知识点总结

【CF961F】k-substrings

偷税与漏税的区别