HDU3065 (病毒侵袭持续中)[AC自动机]

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=3065

这题是AC自动机模板题,统计每个模板串在文本串中的出现次数。下面贴上模板。

需要注意的是找到一个单词时,可以是直接找到了这个单词的单词节点,也可以是通过这个节点经由后缀链接跳转到了其它单词节点。因为有类似这种情况:10是1010的后缀,但如果匹配到1010显然也应该匹配一次10,这时应该通过后缀链接跳转过去修改10对应的匹配数。
注:后缀链接指向当前节点能通过fail指针跳转到的上一个单词节点的标号。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#include <stack>
using namespace std;
typedef long long ll;
const int maxn=50009;

struct Trie {
int ch[maxn][26];
int f[maxn];
int last[maxn];
int val[maxn];
ll num[maxn];
int index[1009];
int sz;
void init (){
sz=1;
memset(ch[0],0,sizeof(ch[0]));
memset(f,0,sizeof(f));
memset(last,0,sizeof(last));
memset(val,0,sizeof(val));
memset(num,0,sizeof(num));
memset(index,0,sizeof(index));
}
int idx(char c) { return c-'A'; }
void insert(char *s,int v) {
int u=0,n=strlen(s);
for(int i=0;i<n;i++) {
int c=idx(s[i]);
if(!ch[u][c]) {
memset(ch[sz],0,sizeof(ch[sz]));
val[sz]=0;
ch[u][c]=sz++;
}
u=ch[u][c];
}
val[u]+=1;
index[v]=u;
}
void find(char *T) {
int n=strlen(T);
int j=0;
for(int i=0;i<n;i++) {
int c=idx(T[i]);
if(c<0 || c>=26) {
j=0;
continue;
}
while(j && !ch[j][c]) j=f[j];
j=ch[j][c];
if(val[j]) num[j]++;
int k=j;
while(last[k]) { //通过后缀链接修改还能匹配到的单词节点
k=last[k];
num[k]++;
}
}
}
int getFail() {
queue<int> q;
while(!q.empty()) q.pop();
f[0]=0;
for(int c=0;c<26;c++) {
int u=ch[0][c];
if(u) { f[u]=0;q.push(u);last[u]=0; }
}
while(!q.empty()) {
int r=q.front(); q.pop();
for(int c=0;c<26;c++) {
int u=ch[r][c];
if(!u) continue;
q.push(u);
int v=f[r];
while(v && !ch[v][c]) v=f[v];
f[u]=ch[v][c];
last[u]=val[f[u]] ? f[u] : last[f[u]]; //预处理后缀链接
}
}
}
}tr;

char s[1009][60];
char str[2000009];
int main() {
int n;
while(~scanf("%d",&n)) {
tr.init();
for(int i=1;i<=n;i++) {
scanf("%s",s[i]);
tr.insert(s[i],i);
}
tr.getFail();
scanf("%s",str);
tr.find(str);

for(int i=1;i<=n;i++) {
if(tr.num[tr.index[i]]>0)
printf("%s: %I64d\n",s[i],tr.num[tr.index[i]]);
}
}
}