【HDOJ】5657 CA Loves Math

时间:2023-09-09 11:56:20

1. 题目描述
对于给定的$a, n, mod, a \in [2,11], n \in [0, 10^9], mod \in [1, 10^9]$求出在$[1, a^n]$内的所有$a$进制下的数并且不含重复数字。

2. 基本思路
这题比赛的时候,没人做出来,但是基本思路大家都有。显然可以直接将$n$改写为$\min(n,a)$。
我比赛的代码TLE,思路是这样的:首先$mod$很小时,可以数位DP解;当$mod$很大时,可以先找到所有的排列然后,然后令$delta = fact(a)/fact(a-n)$,然后以这个作为循环间隔找到满足不重复的数字,然后再判断是否是$mod$的倍数。
hack的时候,其实可以直接以$mod$作为阈值解。题解也提到了这个思路。
这样原问题可以分两种情况:
(1) 大于阈值,枚举$mod$的倍数,然后判断是否包含重复数字;
(2) 小于等于阈值,数位DP。
然后,赛后交还是wa了几次。这里有几个特殊情况需要单独考虑:
(1) n = 0时,只能取1,直接判断是否为$mod$倍数。
(2) n = 1时,可以取[1, a],同样需要判断是否为$mod$倍数。
并且,数位DP是累加DP的。即长度为$[1,n]$的满足条件的总和。

3. 代码

 /*  */
#include <iostream>
#include <sstream>
#include <string>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <vector>
#include <deque>
#include <bitset>
#include <algorithm>
#include <cstdio>
#include <cmath>
#include <ctime>
#include <cstring>
#include <climits>
#include <cctype>
#include <cassert>
#include <functional>
#include <iterator>
#include <iomanip>
using namespace std;
//#pragma comment(linker,"/STACK:102400000,1024000") #define sti set<int>
#define stpii set<pair<int, int> >
#define mpii map<int,int>
#define vi vector<int>
#define pii pair<int,int>
#define vpii vector<pair<int,int> >
#define rep(i, a, n) for (int i=a;i<n;++i)
#define per(i, a, n) for (int i=n-1;i>=a;--i)
#define clr clear
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define all(x) (x).begin(),(x).end()
#define SZ(x) ((int)(x).size())
#define lson l, mid, rt<<1
#define rson mid+1, r, rt<<1|1
#define INF 0x3f3f3f3f
#define mset(a, val) memset(a, (val), sizeof(a)) #define LL __int64 const int bound = ;
const int maxn = ;
int dp[<<][bound];
int Bits[<<];
vector<int> St[maxn];
int Sz[maxn];
int a, n, mod; void solve();
void _solve(); inline int lowest(int x) {
return -x & x;
} inline int getBits(int x) {
int ret = ; while (x) {
++ret;
x -= lowest(x);
}
return ret;
} void init() {
int mst = << ; rep(i, , mst) {
Bits[i] = getBits(i);
St[Bits[i]].pb(i);
} rep(i, , maxn)
Sz[i] = SZ(St[i]);
} bool vis[];
inline bool judge(LL x) {
if (x == ) return false; memset(vis, false, sizeof(vis));
while (x) {
int tmp = x % a;
if (vis[tmp])
return false;
x /= a;
vis[tmp] = true;
}
return true;
} void solve() {
if (n == ) {
printf("%d\n", %mod== ? :);
return ;
}
if (n == ) {
int ans = ;
rep(i, , a+)
ans += i%mod == ? :;
printf("%d\n", ans);
return ;
} n = min(n, a);
if (mod > bound) {
_solve();
return ;
} int mst = <<a;
memset(dp, , sizeof(dp)); rep(j, , a)
++dp[<<j][j%mod]; rep(l, , n) {
rep(j, , Sz[l]) {
const int st = St[l][j];
if (st >= mst)
continue;
rep(k, , mod) {
const int& cnt = dp[st][k];
if (cnt == )
continue; rep(i, , a) {
if (st & (<<i))
continue; int nst = st | (<<i);
int nk = (k * a + i) % mod;
dp[nst][nk] += cnt;
}
}
}
} int ans = ; rep(l, , n+) {
rep(j, , Sz[l]) {
const int& st = St[l][j];
ans += dp[st][];
}
} printf("%d\n", ans);
} LL Pow(LL base, int n) {
LL ret = ; while (n) {
if (n & )
ret = ret * base;
base = base * base;
n >>= ;
} return ret;
} void _solve() {
LL tmp = mod, ubound = Pow(a, n);
int ans = ; while (tmp <= ubound) {
if (judge(tmp))
++ans;
tmp += mod;
} printf("%d\n", ans);
} int main() {
ios::sync_with_stdio(false);
#ifndef ONLINE_JUDGE
freopen("data.in", "r", stdin);
freopen("data.out", "w", stdout);
#endif int t; init();
scanf("%d", &t);
while (t--) {
scanf("%d%d%d",&a,&n,&mod);
solve();
} #ifndef ONLINE_JUDGE
printf("time = %ldms.\n", clock());
#endif return ;
}