Kattis - nvwls (AC自动机last优化 + dp)
Posted acerkoo
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Kattis - nvwls (AC自动机last优化 + dp)相关的知识,希望对你有一定的参考价值。
题意
给出一个字典,每个单词去掉元音字母 A、E、I、O、U
之后形成一个新字典。
先给出一个只有辅音组成的串,用原字典中的单词还原该串,若存在多种还原方式,输出还原后元音字母数量最多的那种,若依旧多种,则任意输出。
思路
ac自动机fail树上跑dp的一眼套路题。
总结一下遇到的坑:
- 多个单词去掉元音字母之后形成的新单词相同。
- 若原单词只有辅音字母。
- dp过程中未保证完全还原辅音串。
- 特殊样例将 跳fail的过程卡成了 (n^2) 。
解决办法:
- 在字典树的结尾节点保存编号时,保存最大值。
- 将初值从 $0 $ 改为 (-1) 。
- dp数组初值同样改为 (-1),(dp[0] = 0) ,dp只能由非(-1)的状态转移。
- 对fail数组进行 last优化。
last[u] = len[fail[u]]? fail[u]: last[fail[u]];
。
last优化
普通方法将建图+匹配的复杂度成功优化为了 ??(∑??+??) ,但是,匹配成功时的计数也是需要跳fail边的。然而,为了跳到一个结束节点,我们可能需要中途跳到很多没用的伪结束节点:
如果一个节点的fail指向一个结尾节点,那么这个点也成为一个(伪)结尾节点。在匹配时,如果遇到结尾节点,就进行相应的计数处理。
这里面就又有优化的余地了:对于不是真正结束节点的伪结束点,直接跳过它就好了。我们用一个last指针表示“在它顶上的fail边所指向的一串节点中,第一个真正的结束节点”。于是,每次计数处理时,我们不跳fail边,改为跳last边,省去了很多冗余操作。
获得last指针的方法也十分简单,就是在void build()
中加一句话:
last[u] = len[fail[u]]? fail[u]: last[fail[u]];
Code
#include <bits/stdc++.h>
using namespace std;
const int maxn = 3e5+10;
char a[maxn], str[maxn], S[maxn], tmp[maxn];
int res[maxn], pos[maxn], L[maxn];
int n, pn;
int val[maxn];
int trie[maxn][26], fail[maxn], last[maxn];
int len[maxn], dep[maxn];
int e[maxn];
int que[maxn],h, t;
int tot;
inline void insert(string t, int id) {
register int p = 0;
for (register int c, i = 0; t[i]; ++i) {
c = t[i]-'A';
if(!trie[p][c]) {
trie[p][c] = ++tot;
dep[tot] = dep[p] + 1;
}
p = trie[p][c];
}
if(val[e[p]] <= val[id]) e[p] = id;
len[p] = dep[p];
}
inline void build() {
h = 1, t = 0;
for (register int i = 0; i < 26; ++i) {
if(trie[0][i])
que[++t] = trie[0][i];
}
while(h <= t) {
register int u = que[h++];
for (register int i = 0; i < 26; ++i) {
if(trie[u][i]) fail[trie[u][i]] = trie[fail[u]][i], que[++t] = trie[u][i];
else trie[u][i] = trie[fail[u]][i];
}
last[u] = len[fail[u]]? fail[u]: last[fail[u]];
}
}
int pre[maxn], dp[maxn], sta[maxn];
inline void count(char* str) {
dp[0] = 0;
register int p = 0, LL = 0;
for (register int i = 1; str[i]; ++i, ++LL) {
register int c = str[i]-'A';
p = trie[p][c];
dp[i] = -1;
for (register int j = p; j; j = last[j]) {
if(e[j] && dp[i-len[j]]!=-1 && dp[i] < dp[i-len[j]]+val[e[j]]) {
dp[i] = dp[i-len[j]]+val[e[j]];
sta[i] = e[j];
pre[i] = i-len[j];
}
}
}
for (register int i = LL; i > 0; i = pre[i]) res[++pn] = sta[i];
for (register int i = pn; i >= 1; --i) {
for (register int j = pos[res[i]]; j < pos[res[i]] + L[res[i]]; ++j) putchar(S[j]);
if(i==1) putchar('
');
else putchar(' ');
}
}
int main() {
val[0] = -1;
scanf("%d", &n);
register int ll = 0;
for (register int i = 1; i <= n; ++i) {
scanf("%s", a);
pos[i] = ll; L[i] = strlen(a);
strcat(S+ll, a); ll += L[i];
val[i] = 0;
register int LLL = 0;
for (register int j = 0; a[j]; ++j) {
char ch = a[j];
if(ch=='A'||ch=='E'||ch=='I'||ch=='O'||ch=='U') ++val[i];
else tmp[LLL++] = a[j];
}
tmp[LLL] = '