【算法】快速数论变换(NTT)初探
【简介】
快速傅里叶变换(FFT)运用了单位复根的性质减少了运算,但是每个复数系数的实部和虚部是一个余弦和正弦函数,因此系数都是浮点数,而浮点数的运算速度较慢且可能产生误差等精度问题,因此提出了以数论为基础的具有循环卷积性质的快速数论变换(NTT)。
在FFT中,通过次单位复根即的来运算,而对于NTT来说,则是运用了素数的原根来运算。
【原根】
【定义】
对于两个正整数满足,由欧拉定理可知,存在正整数,如,使得。
因此,在时,定义对模的指数为使成立的最小正整数。若,则称是模的原根。
【性质/定义2】
若一个数是对于的原根,那么的结果互不相同。
【求原根方法】
对质数分解质因数得到不同的质因子,对于任何,判定是否为的原根,只需要检验这n个数中,是否存在一个数为,若存在,则不是的原根,否则是的原根。
【正确性证明】
假设存在一个使得且。
由裴蜀定理得,一定存在一组使得
由欧拉定理/费马小定理得,
于是
故
又,于是必整除中至少一个,设,则。
故假设不成立。
【用途】
我们可以发现原根拥有所有FFT所需的的性质,于是如果我们用来代替,就能把复数对应成一个整数,在意义下做快速变换了。
【NTT模数】
显然在上述的用途中,必须是素数且必须是的因数,因为是的幂,所以可以构造形如的素数。
常见的形如的素数有,它们的原根都为
如果题目的模数任意怎么办?我们取的模数必须超过
那么我们可以取多个模数(乘积)做完NTT之后用CRT合并...
【例题】
原来的配方,熟悉的味道...(怎么比FFT慢了
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <algorithm>
#define ll long long
#define MOD(x) ((x) >= mod ? (x)-mod : (x))
using namespace std;
const int maxn = 4000010, inf = 1e9, mod = 998244353;
int n, N;
int a[maxn], b[maxn], c[maxn];
char s[maxn];
inline int power(int a, int b)
{
int ans = 1;
for (; b; b >>= 1, a = 1ll * a * a % mod)
if (b & 1)
ans = 1ll * ans * a % mod;
return ans;
}
inline void ntt(int *a, int f)
{
for (int i = 0, j = 0; i < N; i++)
{
if (i < j)
swap(a[i], a[j]);
for (int k = N >> 1; (j ^= k) < k; k >>= 1)
;
}
for (int i = 2; i <= N; i <<= 1)
{
int nw = power(3, (mod - 1) / i);
if (f == -1)
nw = power(nw, mod - 2);
for (int j = 0, m = i >> 1; j < N; j += i)
for (int k = 0, w = 1; k < m; k++)
{
int t = 1ll * a[j + k + m] * w % mod;
a[j + k + m] = MOD(a[j + k] - t + mod);
a[j + k] = MOD(a[j + k] + t);
w = 1ll * w * nw % mod;
}
}
if (f == -1)
for (int i = 0, inv = power(N, mod - 2); i < N; i++)
a[i] = 1ll * a[i] * inv % mod;
}
int main()
{
scanf("%d", &n);
for (N = 1; N < (n << 1); N <<= 1)
;
scanf("%s", s);
for (int i = 0; i < n; i++)
a[n - i - 1] = s[i] - '0';
scanf("%s", s);
for (int i = 0; i < n; i++)
b[n - i - 1] = s[i] - '0';
ntt(a, 1);
ntt(b, 1);
for (int i = 0; i < N; i++)
c[i] = 1ll * a[i] * b[i] % mod;
ntt(c, -1);
for (int i = 0; i < N; i++)
if (c[i] >= 10)
{
c[i + 1] += c[i] / 10;
c[i] %= 10;
if (i == N - 1)
N++;
}
N--;
while (!c[N] && N > 0)
N--;
for (int i = N; ~i; i--)
printf("%d", c[i]);
}
DP。
为前个数总和后为的方案数,为集合内数后为的数的个数。
这个式子显然是可以矩阵快速幂优化的,效率,还是会TLE。
对所有数,设,则将映射为,即取在模意义下的离散对数,为,为。
原式改为:
卷积形式,上NTT
注意乘出来超过项,而次数超过项的贡献应加回次数减去的项中,即上方式子的取模操作。
NTT次数较多预处理出的所有次幂会快一些,而且因为不明原因直接取模会比快速取模快,CPU你怎么这么难伺候...
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#define MOD(x) ((x) >= mod ? (x)-mod : (x))
#include <algorithm>
#define ll long long
using namespace std;
const int maxn = 20010, inf = 1e9, mod = 1004535809;
int n, m, x, s, N, g, X;
int a[maxn], b[maxn], A[maxn], B[maxn], ind[maxn], p[maxn], g1[maxn], g2[maxn];
inline void read(int &k)
{
int f = 1;
k = 0;
char c = getchar();
while (c < '0' || c > '9')
c == '-' && (f = -1), c = getchar();
while (c <= '9' && c >= '0')
k = k * 10 + c - '0', c = getchar();
k *= f;
}
inline int power(int a, int b, int p)
{
int ans = 1;
for (; b; b >>= 1, a = 1ll * a * a % p)
if (b & 1)
ans = 1ll * ans * a % p;
return ans;
}
inline bool judge(int x)
{
for (int i = 1; i <= p[0]; i++)
if (power(x, (m - 1) / p[i], m) == 1)
return 0;
return 1;
}
inline int yg(int x)
{
x--;
for (int i = 2; i * i <= x; i++)
if (x % i == 0)
{
p[++p[0]] = i;
while (x % i == 0)
x /= i;
}
if (x > 1)
p[++p[0]] = x;
for (int i = 2;; i++)
if (judge(i))
return i;
}
inline void init()
{
int base1 = power(3, (mod - 1) / N, mod), base2 = power(base1, mod - 2, mod);
g1[0] = g2[0] = 1;
for (int i = 1; i <= N; i++)
g1[i] = 1ll * g1[i - 1] * base1 % mod, g2[i] = 1ll * g2[i - 1] * base2 % mod;
}
inline void ntt(int *a, int *w, int f)
{
for (int i = 0, j = 0; i < N; i++)
{
if (i < j)
swap(a[i], a[j]);
for (int k = N >> 1; (j ^= k) < k; k >>= 1)
;
}
for (int i = 2; i <= N; i <<= 1)
for (int j = 0, m = i >> 1; j < N; j += i)
for (int k = 0; k < m; k++)
{
int t = 1ll * a[j + k + m] * w[N / i * k] % mod;
a[j + k + m] = (a[j + k] - t + mod) % mod;
a[j + k] = (a[j + k] + t) % mod;
}
if (f == -1)
for (int i = 0, inv = power(N, mod - 2, mod); i < N; i++)
a[i] = 1ll * a[i] * inv % mod;
}
inline void mul(int *a, int *b, int *c)
{
memcpy(A, a, N << 2);
memcpy(B, b, N << 2);
ntt(A, g1, 1);
ntt(B, g1, 1);
for (int i = 0; i < N; i++)
c[i] = 1ll * A[i] * B[i] % mod;
ntt(c, g2, -1);
for (int i = 0; i < m - 1; i++)
c[i] = MOD(c[i] + c[i + m - 1]), c[i + m - 1] = 0;
}
int main()
{
read(n);
read(m);
read(X);
read(s);
g = yg(m);
for (int i = 0, j = 1; i < m - 1; i++, j = 1ll * j * g % m)
ind[j] = i;
X = ind[X];
for (N = 1; N < m; N <<= 1)
;
N <<= 1;
init();
for (int i = 1; i <= s; i++)
read(x), x && (a[ind[x]]++);
b[0] = 1;
for (; n; n >>= 1, mul(a, a, a))
if (n & 1)
mul(a, b, b);
printf("%d\n", b[X]);
}
评论
发表评论