【算法】快速数论变换(NTT)初探

【简介】
  快速傅里叶变换(FFT)运用了单位复根的性质减少了运算,但是每个复数系数的实部和虚部是一个余弦和正弦函数,因此系数都是浮点数,而浮点数的运算速度较慢且可能产生误差等精度问题,因此提出了以数论为基础的具有循环卷积性质的快速数论变换(NTT)。
  在FFT中,通过n次单位复根即ωn=1ω来运算,而对于NTT来说,则是运用了素数的原根来运算。
【原根】
【定义】
  对于两个正整数a,m满足gcd(a,m)=1,由欧拉定理可知,存在正整数dm1,如d=φ(m),使得ad1(mod m)
  因此,在gcd(a,m)=1时,定义a对模m的指数δm(a)为使ad1(mod m)成立的最小正整数d。若δm(a)=φ(m),则称a是模m的原根。
【性质/定义2】
  若一个数g是对于P的原根,那么gi mod P,1i<P的结果互不相同。
【求原根方法】
  对质数P1分解质因数得到不同的质因子p1,p2,p3,...,pn,对于任何2aP1,判定a是否为P的原根,只需要检验aP1p1,aP1p2,...,aP1pn这n个数中,是否存在一个数mod P1,若存在,则a不是P的原根,否则aP的原根。
【正确性证明】
  假设存在一个t<φ(P)=P1使得at1(mod P)i[1,P),aP1pi mod P1
  由裴蜀定理得,一定存在一组k,x使得kt=(P1)x+gcd(t,P1)
  由欧拉定理/费马小定理得,aP11(mod P)
  于是1akta(P1)x+gcd(t,P1)agcd(t,P1)(mod P)
  t<P1gcd(t,P1)<P1
  又gcd(t,P1)|P1,于是gcd(t,P1)必整除aP1p1,aP1p2,...,aP1pn中至少一个,设gcd(t,P1)|aP1pi,则aP1piagcd(t,P1)1(mod P)
  故假设不成立。
【用途】
  我们可以发现原根g拥有所有FFT所需的ω的性质,于是如果我们用gP1N(mod P)来代替ωn=e2πiN,就能把复数对应成一个整数,在(mod P)意义下做快速变换了。
【NTT模数】
  显然在上述的用途中,P必须是素数且N必须是P1的因数,因为N2的幂,所以可以构造形如P=c2k+1的素数。
  常见的形如P=c2k+1的素数有998244353=119223+1,1004535809=479221+1,它们的原根都为3
  如果题目的模数P任意怎么办?我们取的模数必须超过n(P1)2
  那么我们可以取多个模数(乘积>n(P1)2)做完NTT之后用CRT合并...
【例题】
  原来的配方,熟悉的味道...(怎么比FFT慢了

#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <algorithm>
#define ll long long
#define MOD(x) ((x) >= mod ? (x)-mod : (x))
using namespace std;
const int maxn = 4000010, inf = 1e9, mod = 998244353;
int n, N;
int a[maxn], b[maxn], c[maxn];
char s[maxn];
inline int power(int a, int b)
{
    int ans = 1;
    for (; b; b >>= 1, a = 1ll * a * a % mod)
        if (b & 1)
            ans = 1ll * ans * a % mod;
    return ans;
}
inline void ntt(int *a, int f)
{
    for (int i = 0, j = 0; i < N; i++)
    {
        if (i < j)
            swap(a[i], a[j]);
        for (int k = N >> 1; (j ^= k) < k; k >>= 1)
            ;
    }
    for (int i = 2; i <= N; i <<= 1)
    {
        int nw = power(3, (mod - 1/ i);
        if (f == -1)
            nw = power(nw, mod - 2);
        for (int j = 0, m = i >> 1; j < N; j += i)
            for (int k = 0, w = 1; k < m; k++)
            {
                int t = 1ll * a[j + k + m] * w % mod;
                a[j + k + m] = MOD(a[j + k] - t + mod);
                a[j + k] = MOD(a[j + k] + t);
                w = 1ll * w * nw % mod;
            }
    }
    if (f == -1)
        for (int i = 0, inv = power(N, mod - 2); i < N; i++)
            a[i] = 1ll * a[i] * inv % mod;
}
int main()
{
    scanf("%d"&n);
    for (N = 1; N < (n << 1); N <<= 1)
        ;
    scanf("%s", s);
    for (int i = 0; i < n; i++)
        a[n - i - 1= s[i] - '0';
    scanf("%s", s);
    for (int i = 0; i < n; i++)
        b[n - i - 1= s[i] - '0';
    ntt(a, 1);
    ntt(b, 1);
    for (int i = 0; i < N; i++)
        c[i] = 1ll * a[i] * b[i] % mod;
    ntt(c, -1);
    for (int i = 0; i < N; i++)
        if (c[i] >= 10)
        {
            c[i + 1+= c[i] / 10;
            c[i] %= 10;
            if (i == N - 1)
                N++;
        }
    N--;
    while (!c[N] && N > 0)
        N--;
    for (int i = N; ~i; i--)
        printf("%d"c[i]);
}

  DP。
  f[i][j]为前i个数总和mod m后为j的方案数,B[i]为集合内数mod m后为i的数的个数。f[i][jk%m]=f[i1][j]B[k]
  这个式子显然是可以矩阵快速幂优化的,效率O(m2logn),还是会TLE。
  对所有数x,设gi(mod m)=x,则将x映射为i,即取x在模m意义下的离散对数,cf[L]af[L1]
  原式改为:c[(i+j)%(φ(m)=m1)]=k=0i+ja[i]b[i+jk]
  卷积形式,上NTT
  注意乘出来超过m1项,而次数超过m1项的贡献应加回次数减去m1的项中,即上方式子的取模操作。
  NTT次数较多预处理出g的所有次幂会快一些,而且因为不明原因直接取模会比快速取模快,CPU你怎么这么难伺候...

#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#define MOD(x) ((x) >= mod ? (x)-mod : (x))
#include <algorithm>
#define ll long long
using namespace std;
const int maxn = 20010, inf = 1e9, mod = 1004535809;
int n, m, x, s, N, g, X;
int a[maxn], b[maxn], A[maxn], B[maxn], ind[maxn], p[maxn], g1[maxn], g2[maxn];
inline void read(int &k)
{
    int f = 1;
    k = 0;
    char c = getchar();
    while (c < '0' || c > '9')
        c == '-' && (f = -1), c = getchar();
    while (c <= '9' && c >= '0')
        k = k * 10 + c - '0', c = getchar();
    k *= f;
}
inline int power(int a, int b, int p)
{
    int ans = 1;
    for (; b; b >>= 1, a = 1ll * a * a % p)
        if (b & 1)
            ans = 1ll * ans * a % p;
    return ans;
}
inline bool judge(int x)
{
    for (int i = 1; i <= p[0]; i++)
        if (power(x, (m - 1/ p[i], m) == 1)
            return 0;
    return 1;
}
inline int yg(int x)
{
    x--;
    for (int i = 2; i * i <= x; i++)
        if (x % i == 0)
        {
            p[++p[0]] = i;
            while (x % i == 0)
                x /= i;
        }
    if (x > 1)
        p[++p[0]] = x;
    for (int i = 2;; i++)
        if (judge(i))
            return i;
}
inline void init()
{
    int base1 = power(3, (mod - 1/ N, mod), base2 = power(base1, mod - 2, mod);
    g1[0= g2[0= 1;
    for (int i = 1; i <= N; i++)
        g1[i] = 1ll * g1[i - 1* base1 % mod, g2[i] = 1ll * g2[i - 1* base2 % mod;
}
inline void ntt(int *a, int *w, int f)
{
    for (int i = 0, j = 0; i < N; i++)
    {
        if (i < j)
            swap(a[i], a[j]);
        for (int k = N >> 1; (j ^= k) < k; k >>= 1)
            ;
    }
    for (int i = 2; i <= N; i <<= 1)
        for (int j = 0, m = i >> 1; j < N; j += i)
            for (int k = 0; k < m; k++)
            {
                int t = 1ll * a[j + k + m] * w[N / i * k] % mod;
                a[j + k + m] = (a[j + k] - t + mod) % mod;
                a[j + k] = (a[j + k] + t) % mod;
            }
    if (f == -1)
        for (int i = 0, inv = power(N, mod - 2, mod); i < N; i++)
            a[i] = 1ll * a[i] * inv % mod;
}
inline void mul(int *a, int *b, int *c)
{
    memcpy(A, a, N << 2);
    memcpy(B, b, N << 2);
    ntt(A, g1, 1);
    ntt(B, g1, 1);
    for (int i = 0; i < N; i++)
        c[i] = 1ll * A[i] * B[i] % mod;
    ntt(c, g2, -1);
    for (int i = 0; i < m - 1; i++)
        c[i] = MOD(c[i] + c[i + m - 1]), c[i + m - 1= 0;
}
int main()
{
    read(n);
    read(m);
    read(X);
    read(s);
    g = yg(m);
    for (int i = 0, j = 1; i < m - 1; i++, j = 1ll * j * g % m)
        ind[j] = i;
    X = ind[X];
    for (N = 1; N < m; N <<= 1)
        ;
    N <<= 1;
    init();
    for (int i = 1; i <= s; i++)
        read(x), x && (a[ind[x]]++);
    b[0= 1;
    for (; n; n >>= 1mul(a, a, a))
        if (n & 1)
            mul(a, b, b);
    printf("%d\n"b[X]);
}

评论

此博客中的热门博文

将博客部署到星际文件系统(IPFS)

高中地理必修一知识点总结

一场CF的台前幕后(下)