// 区间DP标准模板
// 状态: dp[i][j] 表示区间 [i, j] 上的最优解
// 初始化: 长度为1的区间
for (int i = 0; i < n; i++) {
dp[i][i] = 0; // 根据具体问题初始化
}
// 阶段: 枚举区间长度 len
for (int len = 2; len <= n; len++) {
// 状态: 枚举区间起点 i
for (int i = 0; i + len - 1 < n; i++) {
int j = i + len - 1; // 区间终点
dp[i][j] = INF; // 初始化为较大值(求最小值时)
// 决策: 枚举分割点 k
for (int k = i; k < j; k++) {
// 合并左右子区间的代价
int cost = mergeCost(i, k, j);
dp[i][j] = min(dp[i][j], dp[i][k] + dp[k+1][j] + cost);
}
}
}
// 答案: dp[0][n-1] 为整个序列的最优解
模板的要点分析:
外层循环 len:从 2 到 n,确保按区间长度递增计算,这是区间DP最核心的遍历顺序。
中层循环 i:枚举区间起点,j = i + len - 1 自动计算终点。注意边界条件 i + len - 1 < n。
内层循环 k:枚举分割点,k 的取值范围是 [i, j-1]。
初始化:长度为 1 的区间 dp[i][i] 通常初始化为 0(不合并时代价为零)或初始值。
INF 设置:根据求最大值还是最小值,初始化为 -INF 或 INF。
在 Python 中同样可以方便地实现区间DP,下面给出对应的 Python 版本:
# Python 区间DP模板
n = len(arr)
dp = [[0] * n for _ in range(n)]
# 初始化长度为1的区间
for i in range(n):
dp[i][i] = 0
# 按区间长度从小到大遍历
for length in range(2, n + 1): # 区间长度
for i in range(0, n - length + 1): # 区间起点
j = i + length - 1 # 区间终点
dp[i][j] = float('inf')
# 枚举分割点
for k in range(i, j):
cost = merge_cost(i, k, j)
dp[i][j] = min(dp[i][j],
dp[i][k] + dp[k+1][j] + cost)
print(dp[0][n-1]) # 整个区间的最优解
提示:在某些问题中,分割点 k 的枚举范围可以进行优化(如四边形不等式优化),k 不一定需要从 i 到 j-1 全部枚举。但对于初学者,建议先从完整枚举开始理解。
# 括号匹配 - Python实现
def longest_valid_bracket(s):
n = len(s)
dp = [[0] * n for _ in range(n)]
# 判断是否匹配
def match(a, b):
return (a == '(' and b == ')') or \
(a == '[' and b == ']') or \
(a == '{' and b == '}')
for length in range(2, n + 1):
for i in range(0, n - length + 1):
j = i + length - 1
if match(s[i], s[j]):
inner = dp[i+1][j-1] if i+1 <= j-1 else 0
dp[i][j] = inner + 2
for k in range(i, j):
dp[i][j] = max(dp[i][j], dp[i][k] + dp[k+1][j])
return dp[0][n-1]
# 示例
s = "([{}])()"
print(longest_valid_bracket(s)) # 输出: 8 (全部匹配)
s2 = "([)]"
print(longest_valid_bracket(s2)) # 输出: 2 (最长合法子序列为 "()" 或 "[]")
# 回文子串分割 - Python实现
def min_cut(s):
n = len(s)
if n == 0:
return 0
# 预处理: is_pal[i][j] 表示 s[i..j] 是否为回文串
is_pal = [[False] * n for _ in range(n)]
for i in range(n):
is_pal[i][i] = True
for i in range(n - 1):
is_pal[i][i+1] = (s[i] == s[i+1])
for length in range(3, n + 1):
for i in range(0, n - length + 1):
j = i + length - 1
is_pal[i][j] = (s[i] == s[j] and is_pal[i+1][j-1])
# 区间DP: dp[i] 表示前缀 s[0..i] 的最少分割次数
dp = [float('inf')] * n
for i in range(n):
if is_pal[0][i]:
dp[i] = 0
else:
for j in range(i):
if is_pal[j+1][i]:
dp[i] = min(dp[i], dp[j] + 1)
return dp[n-1]
# 示例
s = "aab"
print(min_cut(s)) # 输出: 1 -> "aa" + "b"
s2 = "abccba"
print(min_cut(s2)) # 输出: 0 -> 本身就是回文串
利用决策单调性,我们在枚举分割点时,不需要从 i 枚举到 j-1,只需从 opt[i][j-1] 枚举到 opt[i+1][j]。
# 四边形不等式优化后的区间DP模板
# 适用于满足决策单调性的问题(如石子合并)
n = len(arr)
dp = [[0] * n for _ in range(n)]
opt = [[0] * n for _ in range(n)] # 记录最优分割点
# 初始化
for i in range(n):
dp[i][i] = 0
opt[i][i] = i # 长度为1的区间,分割点为自己
# 按区间长度遍历
for length in range(2, n + 1):
for i in range(0, n - length + 1):
j = i + length - 1
dp[i][j] = float('inf')
# 利用决策单调性缩小枚举范围
for k in range(opt[i][j-1], opt[i+1][j] + 1):
if k < j: # 确保分割点有效
cost = dp[i][k] + dp[k+1][j] + range_sum(i, j)
if cost < dp[i][j]:
dp[i][j] = cost
opt[i][j] = k
print(dp[0][n-1])
# 四边形不等式优化 - 石子合并完整实现
def stone_merge_optimized(stones):
n = len(stones)
prefix = [0] * (n + 1)
for i in range(1, n + 1):
prefix[i] = prefix[i-1] + stones[i-1]
def range_sum(i, j):
return prefix[j+1] - prefix[i]
dp = [[0] * n for _ in range(n)]
opt = [[0] * n for _ in range(n)]
for i in range(n):
dp[i][i] = 0
opt[i][i] = i
for length in range(2, n + 1):
for i in range(0, n - length + 1):
j = i + length - 1
dp[i][j] = float('inf')
# 决策单调性: k 的范围被缩小
left = opt[i][j-1]
right = opt[i+1][j]
for k in range(left, right + 1):
if k < j:
cost = dp[i][k] + dp[k+1][j] + range_sum(i, j)
if cost < dp[i][j]:
dp[i][j] = cost
opt[i][j] = k
return dp[0][n-1]
# 大数据量测试
import random, time
stones = [random.randint(1, 100) for _ in range(500)]
start = time.time()
result = stone_merge_optimized(stones)
elapsed = time.time() - start
print(f"结果: {result}, 耗时: {elapsed:.3f}秒")
# O(n²) 优化后,500个元素只需约0.1秒
算法思路:通过树形DP计算每个节点的子树大小,然后计算删除每个节点后最大连通块的大小。对于节点 u,删除 u 后会产生若干连通块:每个子节点 v 所在的子树(大小为 sz[v]),以及"上方"的部分(大小为 n - sz[u])。这些连通块大小的最大值即为删除 u 后的代价,取代价最小的节点即为重心。
# 树的重心 - Python实现
def find_centroid(n, edges):
from collections import defaultdict
adj = defaultdict(list)
for u, v in edges:
adj[u].append(v)
adj[v].append(u)
sz = [0] * n # 子树大小
max_part = [0] * n # 删除节点后最大连通块大小
visited = [False] * n
def dfs(u):
visited[u] = True
sz[u] = 1
max_part[u] = 0
for v in adj[u]:
if not visited[v]:
dfs(v)
sz[u] += sz[v]
# 子树的连通块大小
max_part[u] = max(max_part[u], sz[v])
# "上方"部分的大小
max_part[u] = max(max_part[u], n - sz[u])
dfs(0)
# 找到 max_part 最小的节点
centroid = min(range(n), key=lambda x: max_part[x])
return centroid, max_part[centroid]
# 示例: 星形图, 0为中心
n = 6
edges = [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5)]
centroid, cost = find_centroid(n, edges)
print(f"重心: {centroid}, 最大子树大小: {cost}")
# 输出: 重心: 0, 最大子树大小: 5
# 示例: 链 0-1-2-3-4
n = 5
edges2 = [(0, 1), (1, 2), (2, 3), (3, 4)]
centroid2, cost2 = find_centroid(n, edges2)
print(f"重心: {centroid2}, 最大子树大小: {cost2}")
# 输出: 重心: 2, 最大子树大小: 2