「2017 山东一轮集训 Day1」Sum

题目大意

求有多少 n 位十进制数是 p 的倍数且每位之和小于等于 m (对0..mm都要求),允许前导 0 ,答案对 998244353 取模。

输入格式

三个数n, mm, p

输出格式

mm + 1个数 分别表示m = 0 .. mm的时候的答案

样例输入

2 3 3

样例输出

1 1 1 5

数据范围

对于测试点 1,1≤n≤1000,1≤p≤50.1≤m≤5;
对于测试点 2、3,1≤n≤10​^9,1≤p≤50.1≤m≤5;
对于测试点 4、5、6,1≤n≤10​^9,1≤p≤50.1≤m≤50;
对于测试点 7、8、9、10,1≤n≤10​^9,1≤p≤16.1≤m≤1000。

解题报告

暴力dp:dp[i][j][k] -> dp[i+1][(j*10+a)%p][k+a]
然后发现可以优化转移。

以前我一般会采用矩阵快速幂来优化转移,这道题不是,是用的倍增思想。我们考虑dp[i] -> dp[2i]的过程,相当于是把序列分为两部分的dp[i],那么我们发现只需合并两部分信息就可以了:
dp[a][b] * dp[d][e][f] -> dp[a+d][(b+10^e)%p]

倍增+暴力转移,常数不要写的太大的话可以通过前60分。

对于后面的四组数据,c和f非常大,我们可以发现如果我们枚举b和e,这相当于对两个序列做卷积。

对每个状态,我们需要维护这p个序列,然后做NTT即可。注意一点常数优化,如果我们要求AB+CD(这里为卷积)的结果,我们可以不用把AB的结果idft回去再相加,可以直接对点值相加。最后一起idft即可。

我的代码

#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long LL;
const LL P = 998244353;
const int G = 3;
int n, m, p;
LL ttt[55][3005], ret[55][3005], tmp[55][3005], ret_dft[55][3005], tmp_dft[55][3005], R[100005], L;
int _wn[100005], __wn[100005];
LL ksm(LL a, LL b, LL c) {
    LL r = 1, t = a;
    while (b) {
        if (b & 1ll) r = (r * t) % c;
        t = (t * t) % c;
        b >>= 1ll;
    }
    return r;
}
LL NP;
int N;
void ntt(LL *a, bool inv) {
    for (int i = 1; i < N; ++i) if (R[i] > i) swap(a[i], a[R[i]]);
    for (int i = 1; i < N; i <<= 1) {
        LL wn = _wn[i];
        if (inv) wn = __wn[i];
        for (int j = 0; j < N; j += (i << 1)) {
            LL w = 1;
            for (int k = 0; k < i; ++k) {
                LL x = a[j + k], y = LL(a[j + k + i]) * w % P;
                a[j + k] = (x + y) % P;
                a[j + k + i] = (x - y + P) % P;
                w = LL(w) * wn % P;
            }
        }
    }
    if (inv) {
        for (int i = 0; i < N; ++i)
            a[i] = (a[i] * NP) % P;
    }
}
void ntt_ksm(int b) {
    int L = 1;
    while (b) {
        int po = ksm(10, L, p);
        if (b & 1) {
            for (int i = 0; i < p; ++i)
                for (int j = 0; j < N; ++j)
                    ttt[i][j] = 0;
            for (int i = 0; i < p; ++i) {
                for (int j = 0; j < N; ++j) {
                    ret_dft[i][j] = ret[i][j];
                    tmp_dft[i][j] = tmp[i][j];
                }
                ntt(ret_dft[i], false);
                ntt(tmp_dft[i], false);
            }
            for (int i = 0; i < p; ++i)
                for (int k = 0; k < p; ++k)
                    for (int j = 0; j < N; ++j)
                        (ttt[(i+k*po)%p][j] += tmp_dft[i][j] * ret_dft[k][j]) %= P;
            for (int i = 0; i < p; ++i) {
                ntt(ttt[i], true);
                for (int j = 0; j <= m; ++j)
                    ret[i][j] = ttt[i][j];
            }
        }
            for (int i = 0; i < p; ++i)
                for (int j = 0; j < N; ++j)
                    ttt[i][j] = 0;
            for (int i = 0; i < p; ++i) {
                for (int j = 0; j < N; ++j)
                    tmp_dft[i][j] = tmp[i][j];
                ntt(tmp_dft[i], false);
            }
            for (int i = 0; i < p; ++i)
                for (int k = 0; k < p; ++k)
                    for (int j = 0; j < N; ++j)
                        (ttt[(i+k*po)%p][j] += tmp_dft[i][j] * tmp_dft[k][j]) %= P;
            for (int i = 0; i < p; ++i) {
                ntt(ttt[i], true);
                for (int j = 0; j <= m; ++j)
                    tmp[i][j] = ttt[i][j];
            }
        b >>= 1;
        L <<= 1;
    }
}
int main() {
    freopen("sum.in", "r", stdin);
    freopen("sum.out", "w", stdout);
    scanf("%d %d %d", &n, &p, &m);
    L = -1;
    for (L = -1, N = 1; N <= m*2; N <<= 1, ++L);
    NP = ksm(N, P-2, P);
    for (int i = 1; i < N; i <<= 1) {
        _wn[i] = ksm(G, (P-1)/(i<<1), P);
        __wn[i] = ksm(_wn[i], P-2, P);
    }
    for (int i = 1; i < N; ++i)
        R[i] = (R[i >> 1] >> 1) | ((i & 1) << L);
    for (int i = 0; i < 10 && i <= m; ++i) {
        ++ret[i%p][i];
        ++tmp[i%p][i];
    }
    --n;
    ntt_ksm(n);
    printf("%lld", ret[0][0]);
    for (int i = 1; i <= m; ++i) {
        (ret[0][i] += ret[0][i-1]) %= P;
        printf(" %lld", ret[0][i]);
    }
    fclose(stdin);
    fclose(stdout);
    return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据