数位DP 求K进制下0~N的每个数每位上出现的数的总和

时间:2022-05-01 19:55:26

好久没写博客了,因为感觉时间比较紧,另一方面没有心思,做的题目比较浅也是另一方面。

热身赛第二场被血虐了好不好,于是决定看看数位DP吧。

进入正题:

如题是一道经(简)典(单)的数位dp。

第一步,对于数K^n-1这种形式的数,位数为n,它的各个位上,每个数0~K-1出现过的次数是一样的。

于是对于数B=K^n-1,有f(B)=(B+1)*n*(0+1+2+...+K-1)/K=(B+1)*n*(K-1)/2;

程序为:

 LL sum1(int pre,int n,int k)
{
LL ret=;
LL pw=;
for(int i=;i<n;i++) pw*=k;
ret=pre*pw+pw*n*(k-)/;
return ret;
}

其中pre在这种情况下为0,pre是什么?我们立刻进入下一步讨论。

第二步,由第一步的结论,我们可以引申一下。为了更形象一点,我们不妨在十进制的情况下讨论。

现在我提出一个问题:如何计算0~49999的数它们各个位上数字之和?(K=10的前提下)

我们根据第一步可以很容易求出[0,9999]=(9999+1)*4*(10-1)/2。

那么还剩下[10000,19999],[20000,29999],[30000,39999],[40000,49999]该怎么求?

仔细观察发现[10000,19999]不过是每个数都比[0,9999]多了一个为1的万位,[20000,29999]不过是每个数都比[0,9999]多了一个为2的万位,[30000,39999]不过是每个数都比[0,9999]多了一个为3的万位,依次类推...就发现了规律。

所以此时这个与后面的数位都无关的万位,我们用i表示,万位之前没有其他的位,所以pre=0(如果对pre有点不理解,看完第三步就知道了),于是对于[i0000,i9999]这样的解就是((pre+i)*10000)+(9999+1)*4*(10-1)/2。

那么,不难得知,求解通式即为((pre+i)*K^n)+(K^n)*n*(K-1)/2。

第三步,基于第一步和第二步的结论,已经可以求出类似于999(K=10),39999(K=10),49999(K=10)的解。

现在又提出一个问题,对于[0,54321]我们怎么解?

当然,先延续之前“区间划分”+“前缀”的思路,先划分为[0,9999],[10000,19999],[20000,29999],[30000,39999],[40000,49999],[50000,54321]。

对于[0,9999],[10000,19999],[20000,29999],[30000,39999],[40000,49999]已经讨论过了,接下来讨论如何求[50000,54321]。

这时把万位的5看作一个前缀,区间就变为了[0,4321],于是只要求前缀pre=5的[0,4321]的解,也就是递归调用第二步的方法,这样就可以求到[0,321],[0,21],[0,1]这样把所有的解相加,就是需要的答案了。

 LL sum2(int pre,LL n,int k)
{
if(n<k){
LL ret=;
for(int i=;i<=n;i++) ret+=pre+i;
return ret;
}
LL tn=n,pw=,ret=;
int mi=;
while(tn>=k){
pw*=k;
mi++;
tn/=k;
}
for(int i=;i<tn;i++)
ret+=sum1(pre+i,mi,k);
ret+=sum2(pre+tn,n-tn*pw,k);
return ret;
}

为了验证跑出来的数据对不对,再写一个暴力求[0,n]的程序,这查错的办法。

 LL check(int n,int k)
{
LL ret=;
int t;
for(int i=;i<=n;i++){
t=i;
while(t){
ret+=t%k;
t/=k;
}
}
return ret;
}

完整程序:

 #include <stdio.h>
typedef long long LL; LL sum1(int pre,int n,int k)
{
LL ret=;
LL pw=;
for(int i=;i<n;i++) pw*=k;
ret=pre*pw+pw*n*(k-)/;
return ret;
} LL check(int n,int k)
{
LL ret=;
int t;
for(int i=;i<=n;i++){
t=i;
while(t){
ret+=t%k;
t/=k;
}
}
return ret;
} LL sum2(int pre,LL n,int k)
{
if(n<k){
LL ret=;
for(int i=;i<=n;i++) ret+=pre+i;
return ret;
}
LL tn=n,pw=,ret=;
int mi=;
while(tn>=k){
pw*=k;
mi++;
tn/=k;
}
for(int i=;i<tn;i++)
ret+=sum1(pre+i,mi,k);
ret+=sum2(pre+tn,n-tn*pw,k);
return ret;
} int main()
{
LL n;
int k;
while(~scanf("%I64d %d",&n,&k)){
printf("%I64d\n",sum2(,n,k));
printf("%I64d\n",check(n,k));
}
return ;
}