给定一个长度为 n 的字符串 s,Z函数(Z数组)是一个长度为 n 的数组,记为 Z[0..n-1]。对于每个位置 i (i > 0),Z[i] 定义为从位置 i 开始的子串 s[i..n-1] 与整个字符串 s 的最长公共前缀(LCP)的长度。约定 Z[0] = 0(或 n,取决于具体实现习惯,但 Z[0] 在计算中通常不被使用)。
如果采用朴素方法计算Z数组——对每个位置 i 都从头开始逐个字符比较——那么总时间复杂度是 O(n^2)。Z算法的精妙之处在于维护一个"匹配区间" [l, r],使得 s[l..r] 是某个前缀 s[0..r-l] 的复制,且 r 尽可能大。利用这个区间信息,可以将大量位置的Z值直接推导出来,而无需重新比较。
区间维护的核心思想
算法维护两个指针 l 和 r,表示当前已经匹配到的最右区间 [l, r],满足 s[l..r] = s[0..r-l](即该区间是前缀的一个副本)。初始时 l = r = 0。
对于每个 i (i > 0),分两种情况处理:
情况A:i ≤ r(在已有区间内)
此时可以利用已知的 Z[i-l] 来初始化 Z[i]。
由于 s[l..r] = s[0..r-l],所以 s[i..r] = s[i-l..r-l]。
令 Z[i] = min(Z[i-l], r - i + 1),然后尝试向右扩展。
情况B:i > r(在已知区间之外)
没有任何已知信息可用,需要从 i 位置重新开始逐个字符比较。
设置 Z[i] = 0,然后暴力扩展。
无论哪种情况,在获得初步的 Z[i] 值之后,都需要尝试向右扩展:只要 s[Z[i]] == s[i + Z[i]],就将 Z[i] 加1。扩展完成后,如果 i + Z[i] - 1 > r,就更新区间 l = i, r = i + Z[i] - 1。
Python实现
defz_function(s):
n = len(s)
Z = [0] * n
l = r = 0
for i inrange(1, n):
if i <= r:
Z[i] = min(r - i + 1, Z[i - l])
while i + Z[i] < n and s[Z[i]] == s[i + Z[i]]:
Z[i] += 1
if i + Z[i] - 1 > r:
l, r = i, i + Z[i] - 1
return Z
算法正确性理解:关键在于 Z[i-l] 的含义——它告诉我们从位置 i-l 开始能匹配多长的前缀。由于 s[l..r] = s[0..r-l],位置 i 处的匹配情况应该和位置 i-l 处完全一致——至少到 r 为止。如果 Z[i-l] < r - i + 1,说明匹配在区间内就已经结束了,可以直接取用;否则需要继续向右扩展。
时间复杂度证明
Z算法的时间复杂度是 O(n)。关键在于证明每个字符最多被比较一次(严格说,每个字符最多引发一次成功比较和一次失败比较)。观察 r 指针的变化:每次成功比较都会使 r 增加至少1,而 r 最多从0增加到 n-1。因此成功比较的总次数是 O(n)。失败比较每次循环最多发生一次(即 s[Z[i]] != s[i + Z[i]] 的那一刻),而循环共有 n-1 次,所以失败比较也是 O(n)。总比较次数为 O(n)。
四、字符串匹配:Z算法的核心应用
Z算法最经典的应用是线性时间字符串匹配。给定一个模式串 P 和一个文本串 T,要找出 P 在 T 中的所有出现位置。使用Z算法解决此问题的思路非常巧妙:构建一个拼接字符串 S = P + "$" + T,其中 "$" 是一个在 P 和 T 中都不会出现的分隔符。然后计算 S 的Z数组。
为何这样做有效?对于拼接串 S 中的每个位置 i(位于 T 部分),Z[i] 表示从 i 开始的子串与 P 的最长公共前缀长度(因为是跨过分隔符与前缀进行比较)。如果 Z[i] = len(P),就说明在文本串 T 的位置 i - len(P) - 1 处找到了一个完整的模式匹配。
defz_match(text, pattern):
"""在 text 中查找 pattern 的所有出现位置"""
s = pattern + "$" + text
Z = z_function(s)
m = len(pattern)
result = []
for i inrange(m + 1, len(s)):
if Z[i] == m:
result.append(i - m - 1)
return result
# 示例
text = "abcabcabc"
pattern = "abc"print(z_match(text, pattern)) # 输出: [0, 3, 6]
如果一个字符串 s 是由某个长度为 p 的子串重复多次构成的(即 s 是周期串),那么 p 就是 s 的一个周期。利用Z数组,可以快速判断并在 O(1) 时间内求出最小周期:
defmin_period(s):
n = len(s)
Z = z_function(s)
for p inrange(1, n):
if n % p == 0 and Z[p] == n - p:
return p
return n
# 示例print(min_period("abcabcabc")) # 3print(min_period("aaaaa")) # 1print(min_period("abcdef")) # 6 (无周期)
原理:如果 p 是 s 的一个周期,那么 s[p..n-1] = s[0..n-p-1],也就是说从位置 p 开始的子串与长度为 n-p 的前缀完全相同。这恰恰意味着 Z[p] >= n - p。再加上 n % p == 0(确保周期完整覆盖),即可判断。
6.2 求每个位置的最短前缀
给定一个字符串 s,对于每个位置 i,我们希望找到以 i 结尾的最短子串,使得该子串同时也是 s 的前缀。这个问题在文本压缩和模式发现中有重要应用。借助Z数组可以高效求解:
defshortest_prefix_ending_at(s):
n = len(s)
rev = s[::-1]
Z_rev = z_function(rev)
# Z_rev[i] 表示 rev[0..] 与 rev[i..] 的LCP,# 对应原串中 suffix 与 prefix 的关系
result = [0] * n
for i inrange(n):
match_len = Z_rev[n - 1 - i]
if match_len > i + 1:
match_len = i + 1
result[i] = match_len
return result
6.3 字符串压缩
给定一个字符串 s,找到它的最小压缩表示。即找到最短的字符串 t,使得 s 可以由 t 重复若干次得到(t 是 s 的最小周期子串)。这个问题其实就是求最小周期,在上面已经给出解法。
defcompress(s):
p = min_period(s)
return s[:p]
6.4 求每个前缀的出现次数
对于字符串 s 的每个前缀 s[0..k],计算它在 s 中作为子串出现的总次数(包括自身)。利用Z数组可以高效统计:
defprefix_occurrences(s):
n = len(s)
Z = z_function(s)
cnt = [0] * (n + 1)
for i inrange(n):
cnt[Z[i]] += 1
for i inrange(n - 1, -1, -1):
cnt[i] += cnt[i + 1]
result = {}
for i inrange(n):
length = i + 1
result[length] = cnt[length] + 1 # +1 计自身return result
# 示例:s = "abcabc"# 前缀 "a": 出现2次("a"在位置0和3)# 前缀 "ab": 出现2次# 前缀 "abc": 出现2次# 前缀 "abca": 出现1次# ...
defcount_distinct_substrings(s):
"""统计字符串 s 的不同子串数量"""
n = len(s)
total = 0
cur = ""for ch in s:
cur += ch
rev = cur[::-1]
Z = z_function(rev)
max_prefix = max(Z) if Z else 0
# 新增的不同子串数 = 当前长度 - 最大LCP
total += len(cur) - max_prefix
return total
defz_to_pi(Z):
n = len(Z)
pi = [0] * n
for i inrange(1, n):
if Z[i] > 0:
for j inrange(i + Z[i] - 1, i - 1, -1):
if pi[j] > 0:
break
pi[j] = j - i + 1
return pi
# 更高效的 O(n) 实现defz_to_pi_fast(Z):
n = len(Z)
pi = [0] * n
for i inrange(1, n):
pi[i] = pi[i - 1]
while pi[i] > 0 and Z[i] <= pi[i]:
pi[i] = pi[pi[i] - 1]
if Z[i] > pi[i]:
pi[i] = Z[i]
return pi
# -*- coding: utf-8 -*-"""Z Algorithm - 线性时间字符串匹配与模式分析工具箱"""defz_function(s):
"""计算字符串 s 的 Z 数组(Z函数)"""
n = len(s)
Z = [0] * n
l = r = 0
for i inrange(1, n):
if i <= r:
Z[i] = min(r - i + 1, Z[i - l])
while i + Z[i] < n and s[Z[i]] == s[i + Z[i]]:
Z[i] += 1
if i + Z[i] - 1 > r:
l, r = i, i + Z[i] - 1
return Z
defz_match(text, pattern):
"""返回 pattern 在 text 中所有匹配的起始位置列表"""
s = pattern + "\x00" + text
Z = z_function(s)
m = len(pattern)
return [i - m - 1 for i inrange(m + 1, len(s)) if Z[i] == m]
defmin_period(s):
"""返回字符串的最小周期长度"""
n = len(s)
Z = z_function(s)
for p inrange(1, n):
if n % p == 0 and Z[p] == n - p:
return p
return n
defz_match_count(text, pattern):
"""返回 pattern 在 text 中的出现次数"""returnlen(z_match(text, pattern))
if __name__ == "__main__":
# 测试 Z 函数
s1 = "aaabaaab"print(f"s = {s1!r}")
print(f"Z = {z_function(s1)}")
# 测试字符串匹配
text = "AABAACAADAABAABA"
pattern = "AABA"
matches = z_match(text, pattern)
print(f"模式 '{pattern}' 在文本中的位置: {matches}")
print(f"出现次数: {z_match_count(text, pattern)}")
# 测试周期
s2 = "abcabcabcabc"print(f"'{s2}' 的最小周期: {min_period(s2)}")
# 处理大字符串时的优化写法defz_function_fast(s):
n = len(s)
Z = [0] * n
l = r = 0
# 将字符串转为列表以加快索引速度(Python优化技巧)
arr = list(s)
for i inrange(1, n):
if i <= r:
Z[i] = min(r - i + 1, Z[i - l])
while i + Z[i] < n and arr[Z[i]] == arr[i + Z[i]]:
Z[i] += 1
if i + Z[i] - 1 > r:
l, r = i, i + Z[i] - 1
return Z
九、核心要点总结
Z算法核心要点:
Z数组定义:Z[i] 表示从位置 i 开始的子串与整个字符串前缀的最长公共长度。Z[0] 通常约定为0。