【知识总结】动态 DP

时间:2023-12-31 15:35:56

勾起了我悲伤的回忆 —— NOIP2018 316pts ……

主要思想:将 DP 过程分解为方便单点修改和一个区间合并的操作(通常类似矩阵乘法),然后用数据结构(通常为线段树)维护。

例:给定一个长为 \(n\) 的整数序列,相邻两个数最多选一个,有 \(m\) 次修改序列中的一个数,求每次修改后选出数之和的最大值。

\(n,m\leq 10^5\) 。

如果不会做不带修改的情况,请默默摁 Ctrl + w 然后去学 DP 入门

如果不带修改,明显设 \(f_{i,0/1}\) 表示当第 \(i\) 个点选 (0) / 不选 (1) 时,前 \(i\) 个点的和的最大值。于是有如下转移方程:

\[f_{i,0}=f_{i-1,1}
\]

\[f_{i,1}=\max(f_{i-1,0},f_{i-1,1})+a_i
\]

如果加入修改操作呢?只有这两个 DP 方程比较难办,因为修改一个值就要重新计算后面的所有答案。GG

接下来是「动态 DP 」中最巧妙的部分:考虑用一个矩阵来表示从 \(i-1\) 点向 \(i\) 点转移,用某个表示「初始状态」的矩阵依次乘上每个点的转移就是答案。因为矩阵乘法有结合律,所以可以把答案表示成「初始状态」乘上「修改点前面的矩阵乘积」乘上「当前位置修改后的矩阵」乘上「修改点后面的矩阵乘积」。这样只需要用线段树单点修改和查询区间乘积(事实上这道题只需要查全局乘积)即可。

然而,这道题中转移的运算并不是加和乘,尤其是其中还有一个碍眼的求最大值。但我们可以把矩阵乘法的定义稍加修改,把原来两个整数的「乘法」改为两个整数的加法,「加法」改为对两个整数取最大值。这样我们就构造如下转移矩阵:

\[\begin{bmatrix}
f_{i-1,0}&f_{i-1,1}
\end{bmatrix}
\begin{bmatrix}
0&a_i\\
0&-\infty\\
\end{bmatrix}=
\begin{bmatrix}
f_{i,0}&f_{i,1}\\
\end{bmatrix}\]

还有一个很多人没考虑过的细节 (可能是大佬们认为这个问题太显然不需要考虑) :这个「初始状态」是什么呢?对于这道题,前一个数如果不选是不影响当前决策的,而如果选了的话就会造成一个当前点不能选的「约束」。而第一个点无论如何都不会受到这种「约束」,所以第一个点的「前一个点」应该被看作「没有选」,即初始状态为 \(\begin{bmatrix}0&-\infty\end{bmatrix}\) 。

我们把这个问题扩展到树上,即每条边的两端点中至少选一个点(洛谷 4719【模板】动态 DP )。考虑树链剖分来转化成序列问题。设 \(f_{i,0/1}\) 表示 \(i\) 点选 / 不选时 \(i\) 点子树中的最大权值和,\(g_{i,0/1}\) 表示 \(i\) 点选 / 不选时 \(i\) 点子树除 \(s_i\) 的子树以外的部分中的最大权值和,其中 \(s_i\) 是 \(i\) 的重儿子。对于一条重链有如下方程:

\[\begin{bmatrix}
f_{s_i,0}&f_{s_i,1}
\end{bmatrix}
\begin{bmatrix}
g_{i,0}&g_{i,1}\\
g_{i,0}&-\infty\\
\end{bmatrix}=
\begin{bmatrix}
f_{i,0}&f_{i,1}\\
\end{bmatrix}\]

这样,每个点的答案是「初始状态」乘上它到所在重链末尾的矩阵乘积。

至于具体实现,可以开始先一遍 DP 算出所有的 \(f\) 和 \(g\) 。每次修改时沿着重链向上爬,暴力修改链首父亲的 \(g\) 值。链首到链首父亲的边是一条轻边,所以这样每次修改一个点时要更新 \(g\) 值的点的数量约等于当前点到根的路径上的轻边数量(可能有加一减一之类的细节),是 \(O(\log n)\) 。因此总复杂度 \(O(mlog^2n)\) 。

和上面类似的分析,初始状态(叶子节点那个不存在的重儿子的 \(f\) 值)是 \(\begin{bmatrix}0&-\infty\end{bmatrix}\) 。用这个东西去乘相当于取原矩阵的第一行,所以不需要「显式」地乘。

代码:

很抱歉我代码里的矩阵行列和上文是反的,所有矩阵乘法的顺序也是反的我也不知道怎么回事 QAQ 。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cctype>
using namespace std; namespace zyt
{
template<typename T>
inline bool read(T &x)
{
char c;
bool f = false;
x = 0;
do
c = getchar();
while (c != EOF && c != '-' && !isdigit(c));
if (c == EOF)
return false;
if (c == '-')
f = true, c = getchar();
do
x = x * 10 + c - '0', c = getchar();
while (isdigit(c));
if (f)
x = -x;
return true;
}
template<typename T>
inline void write(T x)
{
static char buf[20];
char *pos = buf;
if (x < 0)
putchar('-'), x = -x;
do
*pos++ = x % 10 + '0';
while (x /= 10);
while (pos > buf)
putchar(*--pos);
}
const int N = 1e5 + 10, INF = 0x3f3f3f3f;
int n, m, head[N], ecnt, w[N], size[N], son[N], fa[N], dfn[N], dfncnt, top[N], f[N][2], g[N][2], end[N], pos[N];
struct edge
{
int to, next;
}e[N << 1];
void add(const int a, const int b)
{
e[ecnt] = (edge){b, head[a]}, head[a] = ecnt++;
}
void dfs(const int u, const int f)
{
fa[u] = f, size[u] = 1;
for (int i = head[u]; ~i; i = e[i].next)
{
int v = e[i].to;
if (v == f)
continue;
dfs(v, u);
size[u] += size[v];
if (size[v] > size[son[u]])
son[u] = v;
}
}
void dfs2(const int u, const int t)
{
top[u] = t, dfn[u] = ++dfncnt, pos[dfncnt] = u, end[t] = u;
if (son[u])
dfs2(son[u], t);
for (int i = head[u]; ~i; i = e[i].next)
{
int v = e[i].to;
if (v == fa[u] || v == son[u])
continue;
dfs2(v, v);
}
}
void dfs3(const int u)
{
g[u][0] = 0, g[u][1] = w[u];
for (int i = head[u]; ~i; i = e[i].next)
{
int v = e[i].to;
if (v == fa[u] || v == son[u])
continue;
dfs3(v);
g[u][0] += max(f[v][0], f[v][1]);
g[u][1] += f[v][0];
}
f[u][0] = g[u][0], f[u][1] = g[u][1];
if (son[u])
{
dfs3(son[u]);
f[u][0] += max(f[son[u]][0], f[son[u]][1]);
f[u][1] += f[son[u]][0];
}
}
struct Matrix
{
int data[2][2], n, m;
Matrix(const int _n = 0, const int _m = 0)
: n(_n), m(_m)
{
for (int i = 0; i < n; i++)
for (int j = 0; j < m; j++)
data[i][j] = -INF;
}
Matrix operator * (const Matrix &b) const
{
Matrix ans(n, b.m);
for (int i = 0; i < n; i++)
for (int k = 0; k < m; k++)
for (int j = 0; j < b.m; j++)
ans.data[i][j] = max(ans.data[i][j], data[i][k] + b.data[k][j]);
return ans;
}
}val[N];
namespace Segment_Tree
{
struct node
{
Matrix m;
}tree[N << 2];
void update(const int rot)
{
tree[rot].m = tree[rot << 1].m * tree[rot << 1 | 1].m;
}
void build(const int rot, const int lt, const int rt)
{
tree[rot].m = Matrix(2, 2);
if (lt == rt)
return void(tree[rot].m = val[pos[lt]]);
int mid = (lt + rt) >> 1;
build(rot << 1, lt, mid), build(rot << 1 | 1, mid + 1, rt);
update(rot);
}
void change(const int rot, const int lt, const int rt, const int p)
{
if (lt == rt)
return void(tree[rot].m = val[pos[p]]);
int mid = (lt + rt) >> 1;
if (p <= mid)
change(rot << 1, lt, mid, p);
else
change(rot << 1 | 1, mid + 1, rt, p);
update(rot);
}
Matrix query(const int rot, const int lt, const int rt, const int ls, const int rs)
{
if (ls <= lt && rt <= rs)
return tree[rot].m;
int mid = (lt + rt) >> 1;
if (rs <= mid)
return query(rot << 1, lt, mid, ls, rs);
else if (ls > mid)
return query(rot << 1 | 1, mid + 1, rt, ls, rs);
else
return query(rot << 1, lt, mid, ls, rs) * query(rot << 1 | 1, mid + 1, rt, ls, rs);
}
}
int work()
{
using namespace Segment_Tree;
read(n), read(m);
memset(head, -1, sizeof(int[n + 1]));
for (int i = 1; i <= n; i++)
read(w[i]), val[i] = Matrix(2, 2);
for (int i = 1; i < n; i++)
{
int a, b;
read(a), read(b);
add(a, b), add(b, a);
}
dfs(1, 0), dfs2(1, 1), dfs3(1);
for (int i = 1; i <= n; i++)
val[i].data[0][0] = val[i].data[0][1] = g[i][0], val[i].data[1][0] = g[i][1], val[i].data[1][1] = -INF;
build(1, 1, n);
while (m--)
{
int u, x;
read(u), read(x);
val[u].data[1][0] += x - w[u];
w[u] = x;
Matrix a, b;
while (u)
{
a = query(1, 1, n, dfn[top[u]], dfn[end[top[u]]]);
change(1, 1, n, dfn[u]);
b = query(1, 1, n, dfn[top[u]], dfn[end[top[u]]]);
u = fa[top[u]];
val[u].data[0][0] += max(b.data[0][0], b.data[1][0]) - max(a.data[0][0], a.data[1][0]);
val[u].data[0][1] = val[u].data[0][0];
val[u].data[1][0] += b.data[0][0] - a.data[0][0];
}
Matrix ans = query(1, 1, n, dfn[1], dfn[end[1]]);
write(max(ans.data[0][0], ans.data[1][0])), putchar('\n');
}
return 0;
}
}
int main()
{
#ifdef BlueSpirit
freopen("4719.in", "r", stdin);
#endif
return zyt::work();
}