KMP 算法详解

例题:洛谷P3375

Posted by TH911 on November 18, 2024

例题链接

$\text{Upd on 2025/11/25}$:全文重写以前写的什么鬼东西

本文中,字符串下标统一从 $1$ 开始

设字符串 $s$,$s[i]$ 表示 $s$ 的第 $i$ 个字符,以 $s[l,r]$ 表示 $s[l]s[l+1]\cdots s[r]$ 构成的子串。

前缀函数

设长度为 $n$ 的字符串 $s$,设其前缀函数 $\pi(i)$,满足 $\pi(i)$ 为 $s$ 的前缀 $s[1,i]$ 的最长 border 的长度。

定义字符串 $s$ 的 border 定义为一个不等于 $s$ 的字符串 $t$,满足 $t$ 同时是 $s$ 的前缀和后缀。

例如对于字符串 $\texttt{ABABA}$,有 $\pi(1)=\pi(2)=0,\pi(3)=1,\pi(4)=2,\pi(5)=3$。

显然 $\pi(1)=0$。

前缀函数的计算

考虑 $\pi(i)$ 满足什么条件。

维护指针 $j$,表示成功匹配到了 $s[1,j]=s[i-j,i-1]$,且 $j$ 最大。

  • 若 $s[i]=s[j+1]$,那么有 $\pi(i)=j+1$,令 $i\leftarrow i+1,j\leftarrow j+1$。

  • 否则当 $s[i]\neq s[j+1]$ 时,令 $j\leftarrow\pi(j)$,这样可以保留一个最长的 border。

    如果一直找不到 $s[i]=s[j+1]$ 直到 $j=0$,则说明无法匹配,令 $i\leftarrow i+1$,跳过这一位。

前缀函数计算的时间复杂度

考虑 $i$ 单调不降,$j$ 增加的上界同 $i$,均为 $\mathcal O(n)$。故总时间复杂度 $\mathcal O(n)$。

KMP 算法

设模式串 $s$,文本串 $t$,在 $t$ 中匹配 $s$。

KMP 算法引入了一个概念——fail 指针(失配指针)。

如图所示:

设指针 $i,j$,表示匹配到了 $t[i]$,已经成功找到了 $t[i-j+1,i]=s[1,j]$。

设已经找到了 $t[i-j,i-1]=s[1,j]$:

  • 若 $t[i]=s[j+1]$,则代表这一位匹配成功,可以继续匹配,令 $i\leftarrow i+1,j\leftarrow j+1$。

  • 否则当 $t[i]\neq s[j+1]$ 时,说明匹配失败

    那么我们需要充分利用已经匹配过的信息,从而避免无用计算,提升算法效率。

    KMP 算法设计 fail 指针 $\operatorname{fail}(j)$ 表示 $j+1$ 失配(成功匹配到 $j$)时,$j$ 应当跳到哪里。

    当 $t[i]\neq s[j+1]$ 时,令 $j\leftarrow\mathrm{fail}(j)$。

$\operatorname{fail}(j)$ 满足 $s[1,j-1]$ 的 border 最长。

那么在 KMP 中,取 $\operatorname{fail}(j)=\pi(j)$,这可以保证匹配无误,保留一个最长的 border。


其实从这里也可以看出来,求前缀函数 $\pi(i)$ 的过程本质上也可以看作让 $s$ 自己匹配自己。

时间复杂度同前缀函数分析,$\mathcal O(n+m)$,$n,m$ 为字符串长度。

例题 AC 代码

luogu P3375【模板】KMP

给定字符串 $s_1,s_2$,求 $s_2$ 在 $s_1$ 中所有出现位置和 $s_2$ 的前缀函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
//#include<bits/stdc++.h>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<iomanip>
#include<cstdio>
#include<string>
#include<vector>
#include<cmath>
#include<ctime>
#include<deque>
#include<queue>
#include<stack>
#include<list>
using namespace std;
constexpr const int N=1e6;
int n,m,fail[N+1+1];
char s1[N+1],s2[N+1];
int main(){
	/*freopen("test.in","r",stdin);
	freopen("test.out","w",stdout);*/
	
	ios::sync_with_stdio(false);
	cin.tie(0);cout.tie(0);
	
	cin>>(s1+1)>>(s2+1);
	n=strlen(s1+1);m=strlen(s2+1);
	for(int i=2,j=0;i<=m;){
		if(s2[i]==s2[j+1]){
			fail[i++]=++j;
		}else if(j==0){
			fail[i++]=0;
		}else{
			j=fail[j];
		}
	}
	for(int i=1,j=0;i<=n;){
		if(s1[i]==s2[j+1]){
			i++,j++;
			if(j==m){
				cout<<i-m<<'\n';
				j=fail[j];
			}
		}else if(j==0){
			i++;
		}else{
			j=fail[j];
		}
	}
	for(int i=1;i<=m;i++){
		cout<<fail[i]<<' ';
	}
	
	cout.flush();
	
	/*fclose(stdin);
	fclose(stdout);*/
	return 0;
}

存档

放一段以前的 Markdown 源码,想看的自己看看好了:

```markdown ## 什么是 KMP 算法 ### 命名 首先,你要了解“KMP”的命名由来。 其实这仅仅是因为 KMP 算法由三个叫 Donald E. **K**nuth、James H. **M**orris, Jr. 和 Vaughan R. **P**ratt 的人共同提出而已。 ### 作用 参考[例题(洛谷P3375)](https://www.luogu.com.cn/problem/P3375)。 在一个字符串 $s1$(通常称之为“文本串”)中查找另一个字符串 $s2$(通常称之为“模式串”)的出现次数和出现位置。 (以下 $n,m$ 分别为 $s1,s2$ 的长度) ## 朴素算法 最优时间复杂度:$\mathcal O\left(n+m\right)$。 最坏时间复杂度:$\mathcal O(nm)$。 遍历 $s$,然后同时如果 $s_1[i]=s_2[i]$ 就继续判断 $s_1[i+1]=s_2[i+1],s_1[i+2]=s_2[i+2],\cdots,s_1[i+m-1]=s_2[m]$。在中途如果判断出 $s_1[j]\ne s_2[j]$,就跳出循环 $i$ 继续遍历。 朴素代码代码如下: ​```cpp //#include<bits/stdc++.h> #include #include #include #include #include #include #include #include #include #include #include #include #include using namespace std; const int N=1e6; char s1[N+1],s2[N+1]; int n,m; int main(){ /*freopen("test.in","r",stdin); freopen("test.out","w",stdout);*/ scanf("%s%s",s1,s2); n=strlen(s1),m=strlen(s2); int ans=0; for(int i=0;i<n;i++){ for(int j=0;j<m;j++){ if(s1[i+j]!=s2[j])break; if(j==m-1){ //匹配到了 } } }printf("%d\n",ans); /*fclose(stdin); fclose(stdout);*/ return 0; } ​``` ## KMP 算法 ### 策略 先看看朴素算法的匹配策略: ![007](https://img2024.cnblogs.com/blog/3541769/202507/3541769-20250720134015891-75142289.webp) 再看看 KMP 算法的匹配策略: ![008](https://img2024.cnblogs.com/blog/3541769/202507/3541769-20250720134023529-932666720.webp)

图片来源:见参考链接

可以发现,KMP 算法在失配时”将模式串移到了适配位置的后方“。 显然,我们不可能真的去这么做,因为太费时了。 因此我们考虑使用两个指针 $i,j$,分别指向文本串 $s_1$ 和模式串 $s_2$。 ### 求出 $pre$ 后匹配 使 $i$ 遍历 $[1,n]$,$j$ 如果能够匹配就增加,否则就挪到另一个位置 $pre_j$。 假设我们已经求出了 $pre_j$,那么 KMP 算法将会变得无比简单。 先上代码: ​```cpp pre[0]=-1; for(int i=0,j=0;i<n;){ if(j==-1||s1[i]==s2[j])i++,j++; else j=pre[j]; if(j==m){ //匹配到了 } } ​``` 其中,$pre_0=-1$ 仅仅是一个特殊值(详见下文)。 现在我们考虑如何求出 $pre_j$,以及**怎样的 $pre_j$ 能最大程度上减少重复运算、提高效率。** ### $pre$ 数组是什么 引入一个概念:border。 定义一个字符串 $s$ 的 border 为 $s$ 的一个**非 $s$ 本身**的子串 $t$,满足 $t$ **既是 $s$ 的前缀,又是 $s$ 的后缀**。(***border 可以为空串***) 那么 $pre_i$ 便是 $s_2[1,i]$ 的**最长 border 的长度**。 为什么? 看个例子:在 $\texttt{CDACDBCD}$ 中匹配 $\texttt{CDBCD}$。 最开始长这样:$\begin{aligned}&\texttt{CDA}\color{red}\texttt{CD}\color{black}\texttt{BCD}\\ &\texttt{CDB}\color{red}\texttt{CD}\end{aligned}$。 就会把模式串位移成:$\begin{aligned}\texttt{CDA}&\color{red}\texttt{CD}\color{black}\texttt{BCD}\\ &\color{red}\texttt{CD}\color{black}\texttt{BCD}\end{aligned}$。 可以发现:此时会存在**公共部分**($\color{red}\texttt{CD}$)。 那么我们不难发现,这个公共部分必然是 $s_2[1,i]$ 的一个 border(***可以为空串!!***)。 那么,为什么要使 border 最长呢? 其实也很简单,就是**让公共部分最长**(向前跳的尽量少),因为**这样能够防止漏掉漏掉可能的匹配**。 ### 求解 $pre$ 数组 知道了这些,现在开始考虑求出 $pre_i$。 我们直接让 $s_2$ 匹配自己即可。 先放代码: ​```cpp for(int i=0,j=-1;i<m;){ if(j==-1||s2[i]==s2[j])pre[++i]=++j; else j=pre[j]; } ​``` 一个明显的事实:$i\geq j$ 恒成立。 由 $pre_j$ 的定义可得,$pre_j\leq j$ 恒成立。 则每次循环中,要么 $i,j$ 同时自增,差值不变,要么 $j$ 减少;因此,$i\geq j$ 恒成立。 因此无需担心不能够进行“自己匹配自己”。 ~~实在不能理解可以手推,毕竟我推了六张草稿纸,大部分都是画例子。。。。。。~~ 然后其实就是一个 $i$ 指针在右边,$j$ 指针在左边,如果 $s_2[i]=s_2[j]$,那么 border 的长度自然增加,否则就是“**失配**”。 “失配”了,自然就有 $j\leftarrow pre_j$。 ### 最长 border 为空 特殊值 $-1$ 的用途,这种情况 $pre$ 指向第一个字符 $s_2[0]$ 即可。 注意:**如果你的字符串下标从 $1$ 开始,特殊值请设置为 $0$**。 因为在 `pre[++i]=++j` 中,若是 $j=-1$,则 $pre_{i+1}=0$,但是 $0$ 是一个空值。 ## 例题 AC 代码 ​```cpp //#include<bits/stdc++.h> #include #include #include #include #include #include #include #include #include #include #include #include #include using namespace std; const int N=1e6; char s1[N+1],s2[N+1]; int n,m,pre[N+1]={-1}; int main(){ /*freopen("test.in","r",stdin); freopen("test.out","w",stdout);*/ scanf("%s%s",s1,s2); n=strlen(s1),m=strlen(s2); for(int i=0,j=-1;i<m;){ if(j==-1||s2[i]==s2[j])pre[++i]=++j; else j=pre[j]; } for(int i=0,j=0;i<n;){ if(j==-1||s1[i]==s2[j])i++,j++; else j=pre[j]; if(j==m)printf("%d\n",i-m+1); } for(int i=1;i<=m;i++)printf("%d ",pre[i]); putchar(10); /*fclose(stdin); fclose(stdout);*/ return 0; } ​``` ## 参考链接 <https://zhuanlan.zhihu.com/p/83334559>(图片来源) <https://www.cnblogs.com/fswly/p/17959786> <https://www.cnblogs.com/zzuuoo666/p/9028287.html> ``` </details>