CJOJ 【DP合集】最长上升序列2 — LIS2

时间:2022-02-20 22:47:01

题面

已知一个 1 ∼ N 的排列的最长上升子序列长度为 K ,求合法的排列个数。

好题(除了我想不出来我应该找不到缺点), 想一想最长上升子序列的二分做法, 接在序列后面或者替换.

所以对于每一个位置的数, 都只有三种状态, 分别是: 在我们的替换的序列中, 使用过却被序列抛弃的, 未使用过的.

考虑到数据范围比较小, 我们可以使用3进制的状压dp, 0代表未使用过, 1代表被抛弃的, 2表示在替换序列中的.

但是由于这是一个1到n的排列, 所以每个数都只能使用一次, 我们还要记录每个数是否使用过, 继续考虑状压dp, 0代表使用过, 1代表未使用.

因为有两个状压dp, 个人觉得记忆化搜索比较好写, 记忆化搜索有三个参数, \(dfs(s, t, len)\), \(s\)代表替换的序列的状态, \(t\)代表可以使用的数的状态, len表示当前可替换序列的长度, 下面根据dfs代码分析:

long long dfs(int s, int t, int len)//状态已经在上面讲过了, 初始s等于0, 代表替换序列中没有数, 毕竟还没有开始选对不对, t初始为(1 << n) - 1, 既然还没开始选当然每个数都可以选啦, len为长度, 初始为0, 理由仍然是还没有开始选
{
if(t == 0) return len == k; //由于每个位置都只能选一个数, 所以选完了就代表当前序列构造完了
if(f[s] != -1) return f[s]; //当前状态已出现过, 即曾经选过的某些数, 抛弃过某些数的情况曾经出现过, 可以直接用了, 这也是记忆化搜索的精髓
f[s] = 0; //开始选了, 上面写-1而不是0的原因是有些情况并没有合法的序列, 此时f[s]为0, 会重复搜索
int rt = t;//复制一下, 你不可能直接减t对不
int pos = 0; //可以替换的位置
while(rt)
{
int p = rt & (-rt); //取可以使用的数
rt -= p; //减一下, 不然下次减掉还是这个数, 取取弹弹无穷尽也......
while(pos < len && stack[pos + 1] <= num[p]) pos++; //由于你是从t的二进制表示的低位往高位走, 即取的数是从小到大的, 那么当前放的位置肯定不会比上次放的位置前面, 注意题面, 为上升子序列就要放在更后面才能满足上升, 自己想一下吧, 可能解释的不太清楚
int rem = stack[pos + 1]; //记录一下
stack[pos + 1] = num[p]; //替换
f[s] += dfs(s + 2 * pow[num[p]] - pow[rem], t - p, len + (pos == len)); //往下搜索
stack[pos + 1] = rem; //回溯
}
return f[s]; //return
}

还有一些小细节, 比如说由于二进制的第一位是(1 or 0 << 0), 所以我们将1到n的序列变为0到n-1的序列, 这里在代码中会讲.

代码

#include <iostream>
#include <cstring>
#include <cstdio>
using namespace std; int n, k, pow[20], num[1 << 15], stack[20];
long long f[15000000]; inline int read()
{
int x = 0, w = 1;
char c = getchar();
while(c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
while(c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
return x * w;
} long long dfs(int s, int t, int len)
{
if(t == 0) return len == k;
if(f[s] != -1) return f[s];
f[s] = 0;
int rt = t;
int pos = 0;
while(rt)
{
int p = rt & (-rt);
rt -= p;
while(pos < len && stack[pos + 1] <= num[p]) pos++;
int rem = stack[pos + 1];
stack[pos + 1] = num[p];
f[s] += dfs(s + 2 * pow[num[p]] - pow[rem], t - p, len + (pos == len));
stack[pos + 1] = rem;
}
return f[s];
} int main()
{
n = read(); k = read();
memset(f, -1, sizeof(f));
for(int i = 0; i < n; i++) { num[1 << i] = i; pow[i] = i ? pow[i - 1] * 3 : 1; }//这里就是上面所说的小技巧
printf("%lld\n", dfs(0, (1 << n) - 1, 0));
return 0;
}

还有哪些不懂的可以在下面问一下...