后缀自动机SAM练习笔记 (一)
文章目录
- 算法
- 练习题
- P2408 不同子串个数
- [JSOI2012] 玄武密码
- 最长公共子串
- SDOI2016 生成魔咒
- TJOI2015 弦论
- BJOI2020 封印
- AHOI 2013 差异
- 2017 山东一轮集训 Day5 字符串
- P5161 WD与数列
算法
参考博客:点我
参考资料:OI-WIKI
学习笔记
练习题
P2408 不同子串个数
建出
S
A
M
SAM
SAM,由于
S
A
M
SAM
SAM 上任意一条从源点出发到某个节点的路径对应了一个子串,因此只需要对路径计数。
由于是
D
A
G
DAG
DAG,拓扑排序即可。
[JSOI2012] 玄武密码
实际上就是给你一个模板串,每次问你一个模式串的最长前缀满足这个前缀是模板串的子串。
根据
S
A
M
SAM
SAM 的性质:任意一个子串都对应了一条从源点开始的路径。由于
S
A
M
SAM
SAM 的边的含义是每次往后加一个字符,因此直接在
S
A
M
SAM
SAM 上游走就好了。
最长公共子串
题意:给你
n
n
n 个字符串,求这些字符串的最长公共子串。
n
≤
10
n \leq 10
n≤10,字符串长度不超过
1
0
5
10^5
105。
分析:
是 后缀自动机与
A
C
AC
AC 自动机相似性 的应用。
我们知道
A
C
AC
AC 自动机每个节点的失配指针是指向最长的前缀满足这个前缀是当前字符串的后缀。
实际上就是将当前字符串去掉最少的一段前缀。
而
S
A
M
SAM
SAM 的
p
a
r
e
n
t
parent
parent 树上每个点的父亲就是这个等价类去掉一些前缀字符得到的等价类,并且
e
n
d
p
o
s
endpos
endpos 集合是扩大的,那么就更容易匹配上当前的字符。
回顾
A
C
AC
AC 自动机上多模式串的匹配方式,考虑的是模板串的每一个位置为结尾能匹配哪些串。
那么在后缀自动机上我们类似的确定某个模板串的每个位置为结尾的串满足能成为模式串的子串的最长长度。
任意钦定一个字符串为模板串,假设为第
n
n
n 个串。那么对前
n
−
1
n - 1
n−1 个串都建出
S
A
M
SAM
SAM。
考虑枚举前
n
−
1
n - 1
n−1 个串,依次确定模板串每个位置为结尾的最长匹配长度,假设当前位置是
i
i
i,字符是
c
c
c,上一个位置匹配的最大长度是
l
e
n
len
len,所在
S
A
M
SAM
SAM 上的节点为
p
p
p。
那么如果
p
p
p 没有
c
c
c 儿子,就一直跳父亲,并把
l
e
n
len
len 修改为跳到节点的等价类最大长度。如果跳到
0
0
0 那么
l
e
n
=
0
len = 0
len=0,否则就跳到停下来的
p
p
p 节点的
c
c
c 儿子上,并把
l
e
n
len
len 修改为
l
e
n
+
1
len + 1
len+1。这时
l
e
n
len
len 就是
i
i
i 位置的最大匹配长度。
正确性证明是每次都是扩大最少的
e
n
d
p
o
s
endpos
endpos 集合,保证第一次停下来一定是最长的匹配长度。
时间复杂度证明:每次跳父亲
l
e
n
len
len 都会减少,匹配上
l
e
n
len
len 只会增加
1
1
1,总增量
O
(
L
)
O(L)
O(L) 因此只会跳
O
(
L
)
O(L)
O(L) 次。
那么每个位置上的最长匹配长度需要对每个模式串取
m
i
n
min
min,因此做
n
n
n 遍即可。
由于可以挑长度最小的串做模板串,因此总复杂度可以做到
O
(
∑
L
i
n
×
n
)
=
O
(
∑
L
i
)
O(\frac{\sum L_i}{n} \times n) = O(\sum L_i)
O(n∑Li×n)=O(∑Li)。
好像还可以 SA + 单调队列 去做。
CODE:
// 最长公共子串
#include<bits/stdc++.h>
using namespace std;
const int N = 11;
const int M = 1e5 + 10;
int n, m, len[N][M];
int ans[N][M];
char str[N][M];
struct SAM {
struct Node {
int len, fa;
int ch[26];
} node[M * 2];
int tot = 1, last = 1;
inline void extend(int c) {
int p = last; int np = last = ++ tot;
node[np].len = node[p].len + 1;
for(; p && !node[p].ch[c]; p = node[p].fa) node[p].ch[c] = np;
if(!p) node[np].fa = 1;
else {
int q = node[p].ch[c];
if(node[q].len == node[p].len + 1) node[np].fa = q;
else {
int nq = ++ tot;
node[nq] = node[q]; node[nq].len = node[p].len + 1;
node[q].fa = node[np].fa = nq;
for(; p && node[p].ch[c] == q; p = node[p].fa) node[p].ch[c] = nq;
}
}
}
} Sam[N];
int main() {
int n = 1;
while(~scanf("%s", str[n] + 1)) n ++;
n --;
for(int i = 1; i <= n; i ++ ) {
scanf("%s", str[i] + 1);
if(i < n) {
int l = strlen(str[i] + 1);
for(int j = 1; j <= l; j ++ ) Sam[i].extend(str[i][j] - 'a');
}
else m = strlen(str[i] + 1);
}
for(int i = 1; i < n; i ++ ) {
int p = 1, len = 0;
for(int j = 1; j <= m; j ++ ) {
while(p != 0 && !Sam[i].node[p].ch[str[n][j] - 'a']) p = Sam[i].node[p].fa, len = Sam[i].node[p].len;
if(p == 0) p = 1;
else p = Sam[i].node[p].ch[str[n][j] - 'a'], len ++;
ans[i][j] = len;
}
}
int res = 0;
for(int i = 1, tmp = m + 1; i <= m; i ++, tmp = m + 1) {
for(int j = 1; j < n; j ++ )
tmp = min(tmp, ans[j][i]);
res = max(res, tmp);
}
cout << res << endl;
return 0;
}
SDOI2016 生成魔咒
题意:每次往一个字符串末尾加一个数字,查询本质不同的子串数量。
1 ≤ n ≤ 1 0 5 , 1 ≤ x ≤ 1 0 9 1 \leq n \leq 10^5, 1\leq x \leq 10^9 1≤n≤105,1≤x≤109。
分析:
如果使用后缀数组
S
A
SA
SA 去做,那么就是把字符串倒过来,相当于每次在开头加一个字符,那么只需要往后缀数组里加入一个后缀即可。可以先求出完整的每个后缀的排名以及
h
e
i
g
h
t
height
height 数组,然后用
s
e
t
set
set 动态维护添加过程时的排名
h
e
i
g
h
t
height
height 数组的变化。
但是由于后缀自动机的建立过程本来就是每次拓展一位,因此考虑起来更加自然:
我们考虑一个字符串的本质不同子串数怎么用后缀自动机求:
第一个做法是等价于路径计数然后拓扑排序求。
第二个做法是计算每个等价类中的字符串数量然后求和。
一个等价类
u
u
u 中的字符串数量等于
l
e
n
u
−
l
e
n
f
a
u
len_u - len_{fa_u}
lenu−lenfau。这个就是由于
p
a
r
e
n
t
parent
parent 树上每个节点的儿子都是由这个节点的最长字符串往前添加字符得到的,并且是一段区间。那么儿子中字符串的最小值就是
l
e
n
f
a
t
u
+
1
len_{fat_u} + 1
lenfatu+1,最大值就是
l
e
n
u
len_u
lenu。
那么我们只需要将新加入的节点的贡献累加到答案就好了。
CODE:
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10;
int n, num[N];
namespace SAM {
struct Node {
int len, fa;
unordered_map< int, int > ch;
} node[N * 2];
int tot = 1, last = 1;
inline int extend(int c) {
int p = last, np = last = ++ tot;
node[np].len = node[p].len + 1;
for(; p && !node[p].ch[c]; p = node[p].fa) node[p].ch[c] = np;
if(!p) {node[np].fa = 1; return node[np].len;}
else {
int q = node[p].ch[c];
if(node[q].len == node[p].len + 1) {node[np].fa = q; return node[np].len - node[q].len;}
else {
int nq = ++ tot;
node[nq] = node[q]; node[nq].len = node[p].len + 1;
node[q].fa = node[np].fa = nq;
for(; p && node[p].ch[c] == q; p = node[p].fa) node[p].ch[c] = nq;
return node[np].len - node[nq].len;
}
}
}
}
int main() {
scanf("%d", &n); LL sum = 0;
for(int i = 1; i <= n; i ++ ) {
scanf("%d", &num[i]);
sum += SAM::extend(num[i]);
printf("%lld\n", sum);
}
return 0;
}
TJOI2015 弦论
题意:给你一个长度为 n n n 字符串 S S S,字符集只有小写字母。你需要求出字典序第 K K K 小的非空子串。 输入还会给你一个 T T T:当 T = 0 T = 0 T=0 时,本质相同但是位置不同的子串看作一个。当 T = 1 T = 1 T=1 时,本质相同但是位置不同的子串看作多个。
1 ≤ n ≤ 5 × 1 0 5 1 \leq n \leq 5 \times 10^5 1≤n≤5×105, 1 ≤ K ≤ 1 0 9 1 \leq K \leq 10^9 1≤K≤109。
分析:
额外考虑一个问题:如何计算某个子串的排名。
当 T = 0 T = 0 T=0 时, S A SA SA 也可以做这两个问题,我们来比较一下:
- 求排名第
K
K
K 的字符串
对 S A SA SA:在后缀数组上依次将每个后缀去掉前 h e i g h t i height_i heighti 个前缀后将剩下的字符串写下来拍成一行就是字典序从小到大排序。因此只需要扫一遍 h e i g h t height height 数组即可。预处理前缀和可以单次 O ( log n ) O(\log n) O(logn) 查询。
对 S A M SAM SAM:考虑依次确定每一位上的字符。那么相当于在当前节点 p p p 上每次考虑下一步该走那条边 c c c,满足 不走 与 走 0 ∼ c − 1 0 \sim c - 1 0∼c−1 的路径条数和小于 K K K。 拓扑出自动机上每个点作为起点的路径条数和即可。每次查询 O ( n ) O(n) O(n)。 - 求某个子串的排名
对 S A SA SA:考虑在 s a sa sa 数组向上二分得到这个子串第一次被计算到的位置,预处理前缀和可以单次 O ( log n ) O(\log n) O(logn) 查询。
对 S A M SAM SAM:跟上面类似,只需要在 S A M SAM SAM 上游走一遍即可。单次 O ( n ) O(n) O(n)。
发现这种情况下 S A SA SA 的复杂度较优一些。
当
T
=
1
T = 1
T=1 时,
S
A
SA
SA 不能做了。考虑
S
A
M
SAM
SAM:
发现和上面唯一的不同点在于每一条路径带权,这个权值就是路径终点的
e
n
d
p
o
s
endpos
endpos 大小。那么将初始化变一下和上面是一模一样的。复杂度
O
(
n
)
O(n)
O(n)。
u
p
d
:
upd:
upd: 刚知道
S
A
M
SAM
SAM 也可以
P
o
l
y
l
o
g
Polylog
Polylog 求子串排名以及定位子串,需要用到 DAG 剖分,感觉太nb了。具体做法放下面了:
CODE:
#include<bits/stdc++.h>
#define pb emplace_back
using namespace std;
typedef long long LL;
const int N = 5e5 + 10;
bool vis[N * 2];
LL c[N * 2], g[N * 2], f[N * 2];
vector< char > ans;
int n, k, t;
char str[N];
struct edge {
int v, last;
}E[N * 2];
int head[N * 2], tot;
inline void add(int u, int v) {
E[++ tot] = (edge) {v, head[u]};
head[u] = tot;
}
namespace SAM {
struct Node {
int len, fa;
int ch[26];
} node[N * 2];
int tot = 1, last = 1;
inline void extend(int c) {
int p = last, np = last = ++ tot;
g[tot] = 1;
node[np].len = node[p].len + 1;
for(; p && !node[p].ch[c]; p = node[p].fa) node[p].ch[c] = np;
if(!p) node[np].fa = 1;
else {
int q = node[p].ch[c];
if(node[q].len == node[p].len + 1) node[np].fa = q;
else {
int nq = ++ tot;
node[nq] = node[q]; node[nq].len = node[p].len + 1;
node[q].fa = node[np].fa = nq;
for(; p && node[p].ch[c] == q; p = node[p].fa) node[p].ch[c] = nq;
}
}
}
void dfs(int x) {
for(int i = head[x]; i; i = E[i].last) {int v = E[i].v; dfs(v); g[x] += g[v];} // 有多少个 endpos
}
void Dfs(int x) {
if(vis[x]) return ;
vis[x] = 1; f[x] = c[x];
for(int i = 0; i < 26; i ++ ) {
if(node[x].ch[i]) Dfs(node[x].ch[i]);
f[x] += f[node[x].ch[i]];
}
}
inline void get(int k) { // 找到字典序第 k 小的字符串
if(f[1] < k) return ;
int p = 1;
while(1) {
if(c[p] >= k) break;
k -= c[p];
for(int i = 0; i < 26; i ++ )
if(node[p].ch[i])
if(f[node[p].ch[i]] >= k) {p = node[p].ch[i]; ans.pb(i + 'a'); break;}
else k -= f[node[p].ch[i]];
}
return ;
}
}
int main() {
scanf("%s", str + 1); n = strlen(str + 1);
scanf("%d%d", &t, &k);
for(int i = 1; i <= n; i ++ ) SAM::extend(str[i] - 'a');
for(int i = 1; i <= SAM::tot; i ++ ) add(SAM::node[i].fa, i);
SAM::dfs(1);
for(int i = 1; i <= SAM::tot; i ++ ) {
if(t == 0) c[i] = 1;
else c[i] = g[i];
}
c[1] = 0; // 不算空子符串
SAM::Dfs(1);
SAM::get(k);
if(ans.empty()) puts("-1");
else for(auto v : ans) putchar(v);
return 0;
}
BJOI2020 封印
题意:给出只包含小写字母 a , b a,b a,b 的两个字符串 s , t s,t s,t, q q q 次询问,每次询问 s [ l … r ] s[l \dots r] s[l…r] 和 t t t 的最长公共子串长度。
∣ s ∣ , ∣ t ∣ ≤ 2 × 1 0 5 , q ≤ 2 × 1 0 5 |s|,|t| \leq 2 \times 10^5, q \leq 2 \times 10^5 ∣s∣,∣t∣≤2×105,q≤2×105。
分析:
看到了 最长公共子串,可以按照上面的套路用
S
A
M
SAM
SAM 在线性复杂度内求出
s
s
s 每个位置为结尾的最长公共子串长度。
设每个位置上的答案为
a
n
s
i
ans_i
ansi。那么答案就是
max
i
=
l
r
{
m
i
n
(
i
−
l
+
1
,
a
n
s
i
)
}
\max\limits_{i = l}^{r}\{min(i - l + 1, ans_i)\}
i=lmaxr{min(i−l+1,ansi)}。
考虑到
a
n
s
i
ans_i
ansi 的意义:最长往前匹配多长。那么一旦有一个位置
p
o
s
pos
pos 往前延伸不到
l
l
l,
p
o
s
pos
pos 后面的位置肯定也延伸不到
l
l
l。因此可以二分出
p
o
s
pos
pos,那么答案就是
p
o
s
−
l
+
1
pos - l + 1
pos−l+1 与
[
p
o
s
+
1
,
r
]
[pos + 1, r]
[pos+1,r] 的
a
n
s
i
ans_i
ansi 取
m
a
x
max
max。建
s
t
st
st 表可以
O
(
1
)
O(1)
O(1) 查询。
CODE:
#include<bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10;
char s[N], t[N];
int n, q, ans[N];
namespace SAM {
struct Node {
int len, fa;
int ch[26];
} node[N * 2];
int tot = 1, last = 1;
inline void extend(int c) {
int p = last, np = last = ++ tot;
node[np].len = node[p].len + 1;
for(; p && !node[p].ch[c]; p = node[p].fa) node[p].ch[c] = np;
if(!p) node[np].fa = 1;
else {
int q = node[p].ch[c];
if(node[q].len == node[p].len + 1) node[np].fa = q;
else {
int nq = ++ tot;
node[nq] = node[q]; node[nq].len = node[p].len + 1;
node[q].fa = node[np].fa = nq;
for(; p && node[p].ch[c] == q; p = node[p].fa) node[p].ch[c] = nq;
}
}
}
}
int st[18][N];
inline void build_st() {
for(int i = 1; i <= n; i ++ ) st[0][i] = ans[i];
for(int i = 1; (1 << i) <= n; i ++ )
for(int j = 1; j + (1 << i) - 1 <= n; j ++ )
st[i][j] = max(st[i - 1][j], st[i - 1][j + (1 << i - 1)]);
}
int query(int l, int r) {
if(l > r) return 0;
int k = log2(r - l + 1);
return max(st[k][l], st[k][r - (1 << k) + 1]);
}
int main() {
scanf("%s", s + 1); scanf("%s", t + 1);
n = strlen(t + 1);
for(int i = 1; i <= n; i ++ ) SAM::extend(t[i] - 'a');
n = strlen(s + 1);
int p = 1, len = 0;
for(int i = 1; i <= n; i ++ ) {
while(p && !SAM::node[p].ch[s[i] - 'a']) p = SAM::node[p].fa, len = SAM::node[p].len;
if(!p) p = 1;
else p = SAM::node[p].ch[s[i] - 'a'], len ++;
ans[i] = len;
}
build_st();
scanf("%d", &q);
while(q -- ) {
int l, r; scanf("%d%d", &l, &r);
int ll = l, rr = r, mid, res = l - 1;
while(ll <= rr) {
mid = (ll + rr >> 1);
if(ans[mid] >= mid - l + 1) res = mid, ll = mid + 1;
else rr = mid - 1;
}
int ans = max(res - l + 1, query(res + 1, r));
printf("%d\n", ans);
}
return 0;
}
AHOI 2013 差异
题意:给定长度为
n
n
n 的字符串
S
S
S,令
T
i
T_i
Ti 表示它从第 i 个字符开始的后缀。求:
∑
1
≤
i
<
j
≤
n
l
e
n
(
T
i
)
+
l
e
n
(
T
j
)
−
2
×
l
c
p
(
T
i
,
T
j
)
\sum\limits_{1 \leq i < j \leq n} len(T_i) + len(T_j) - 2 \times lcp(T_i, T_j)
1≤i<j≤n∑len(Ti)+len(Tj)−2×lcp(Ti,Tj)
其中,
l
e
n
(
a
)
len(a)
len(a) 表示字符串 a 的长度,
l
c
p
(
a
,
b
)
lcp(a,b)
lcp(a,b) 表示字符串 a 和字符串 b 的最长公共前缀。
2 ≤ n ≤ 500000 2 \leq n \leq 500000 2≤n≤500000,字符集为小写字母。
分析:
用后缀数组
S
A
SA
SA 就是将
h
e
i
g
h
t
height
height 从大到小排序后并查集合并。比较板子。
只会
S
A
M
SAM
SAM 怎么办??
从这道题来分析
S
A
M
SAM
SAM 与
S
A
SA
SA 和后缀树之间的联系:
首先先来介绍 后缀树:
我们定义一个字符串
s
s
s 后缀
t
r
i
e
trie
trie 为将
s
s
s 的所有后缀插入
t
r
i
e
trie
trie 树后得到的字典树。
定义 后缀节点 为插入某个后缀时的终止节点。
后缀
t
r
i
e
trie
trie 有很好的性质:
- 两个后缀节点的 l c a lca lca 的深度就是它们的 l c p lcp lcp 长度。
- 按照字典序从小到大遍历树边得到的 d f s dfs dfs 序可以将后缀排序。
但是暴力插入后缀的时空复杂度都是 O ( n 2 ) O(n^2) O(n2) 的。
令后缀 t r i e trie trie 中所有拥有多于一个儿子的节点和后缀节点为关键点。定义 后缀树 为只保留关键点,一条链上的非关键点压缩成一条边后得到的树。 定义 隐式后缀树 为将 关键点 定义改为拥有多于一个儿子的节点和叶子节点,然后压缩 t r i e trie trie 得到的树。不难看出 隐式后缀树是后缀树的进一步压缩,因为后缀节点不一定是叶子,叶子一定是后缀节点。
下图为字符串
c
a
b
a
b
cabab
cabab 对应的后缀
t
r
i
e
trie
trie 树,后缀树和隐式后缀树。
隐式后缀树没什么用处,我们主要考虑后缀树。
后缀树每条边对应了一个字符串。每个非根节点
x
x
x 对应了一个字符串集合,为从根走到
x
x
x 的父亲
f
a
x
fa_x
fax 得到的字符串拼上
(
f
a
x
,
x
)
(fa_x, x)
(fax,x) 上字符串上的任意一个非空前缀。
有一个重要的结论:
反串 S A M 的 p a r e n t 树就是原串的后缀树 {\LARGE 反串 \ SAM\ 的\ parent\ 树就是原串的后缀树} 反串 SAM 的 parent 树就是原串的后缀树
严谨证明不会。感性理解:
S
A
M
SAM
SAM 中每个节点
x
x
x 都表示一个字符串集合,设最长的字符串为
m
x
S
x
mxS_x
mxSx,最短的为
m
n
S
x
mnS_x
mnSx。
那么节点
x
x
x 的儿子对应的字符串集合就是在
m
x
S
x
mxS_x
mxSx 前面添加一段字符,使得得到的新字符串与
x
x
x 的字符串在原串的
e
n
d
p
o
s
endpos
endpos 集合不同。
对应到反串上,就是在
m
x
S
x
mxS_x
mxSx 的后面添加一段字符,使得得到的新字符与
x
x
x 的字符串的
b
e
g
i
n
p
o
s
beginpos
beginpos 集合不同。
只有
b
e
g
i
n
p
o
s
beginpos
beginpos 相同的字符串在某段上才可以压缩,因此可以认为每个儿子是
x
x
x 节点在
t
i
r
e
tire
tire 树上的分岔。
也由此可以得到后缀树的节点数量是 O ( n ) O(n) O(n) 的。
那么每个节点里存的字符串集合,对应到原串里我们都应该反过来看。
用后缀树解决下面两个问题:
-
用后缀树求两个子串的 l c p lcp lcp。
首先需要定位两个子串,我们应该先将区间翻转过来: [ l , r ] → [ n − r + 1 , n − l + 1 ] [l, r] \to [n - r + 1, n - l + 1] [l,r]→[n−r+1,n−l+1],然后在 p a r e n t parent parent 树查找这个子串所在的节点:首先找到 [ 1 , r ] [1, r] [1,r] 的节点是简单的,只需要构建 S A M SAM SAM 时记录即可。然后不断跳 f a fa fa 直到最小长度小于等于 r − l + 1 r - l + 1 r−l+1,这个倍增一下就好了。然后只需要查两个点的 l c a lca lca, l c a lca lca 的 l e n len len 也就是最大长度就是它们的 l c p lcp lcp。 -
用后缀树对后缀排序。
每个后缀对应到反串上是一个前缀,那么只需要找到所有前缀的 d f s dfs dfs 序就好了。我们需要记录每一个点 x x x 与它的儿子之间的字符串 s s s 在原串上的第一个字符。相当于儿子在父亲 m x S mxS mxS 基础上往前补的第一个字符,这个可以在建 S A M SAM SAM 的过程中记录下来。
由此得出: S A M SAM SAM 基本解决 S A SA SA 的所有问题,因为它可以通过后缀树完成 S A SA SA 最重要的工作:后缀排序。
另外补充一下:后缀树的构建还可以用 U k k o n e n Ukkonen Ukkonen 算法。但是没啥用。
回到本题:建出后缀树后只需要在每个 l c a lca lca 处计算贡献即可。因此树形 D P DP DP 一下就行。
CODE:
#include<bits/stdc++.h>
#define pb emplace_back
using namespace std;
const int N = 5e5 + 10;
typedef long long LL;
int n;
LL ans, del, f[N * 2];
char str[N];
vector< int > E[N * 2];
namespace SAM {
struct Node {
int len, fa;
int ch[26];
} node[N * 2];
int tot = 1, last = 1;
inline void extend(int c) {
int p = last, np = last = ++ tot;
f[tot] = 1; // 前缀对应到原串上是后缀
node[np].len = node[p].len + 1;
for(; p && !node[p].ch[c]; p = node[p].fa) node[p].ch[c] = np;
if(!p) node[np].fa = 1;
else {
int q = node[p].ch[c];
if(node[q].len == node[p].len + 1) node[np].fa = q;
else {
int nq = ++ tot;
node[nq] = node[q]; node[nq].len = node[p].len + 1;
node[q].fa = node[np].fa = nq;
for(; p && node[p].ch[c] == q; p = node[p].fa) node[p].ch[c] = nq;
}
}
}
void dfs(int x) {
for(auto v : E[x]) dfs(v), del += 2LL * f[x] * f[v] * node[x].len, f[x] += f[v];
}
}
int main() {
scanf("%s", str + 1); n = strlen(str + 1);
for(int i = 1; i <= n; i ++ ) ans += 1LL * (n - 1) * (n - i + 1);
reverse(str + 1, str + n + 1);
for(int i = 1; i <= n; i ++ ) SAM::extend(str[i] - 'a');
for(int i = 1; i <= SAM::tot; i ++ ) E[SAM::node[i].fa].pb(i);
SAM::dfs(1);
cout << ans - del << endl;
return 0;
}
2017 山东一轮集训 Day5 字符串
题意:
给定
n
n
n 个字符集为小写字母的字符串
s
i
s_i
si,一个串
t
t
t 是可接受的,当且仅当
t
t
t 可以表示成
p
1
+
p
2
⋯
+
p
n
p_1+p_2 \dots +p_n
p1+p2⋯+pn,其中
p
i
p_i
pi 为
s
i
s_i
si 的一个子串(可以为空)。问有多少种本质不同的字符串
t
t
t 是可接受。答案对
1
0
9
+
7
10^9 + 7
109+7 取模。
1 ≤ n ≤ 1 0 6 , ∑ i = 1 n ∣ s i ∣ ≤ 1 0 6 1 \leq n \leq 10^6, \sum\limits_{i = 1}^{n}|s_i| \leq 10^6 1≤n≤106,i=1∑n∣si∣≤106。
分析:
状态机
d
p
dp
dp 好题。
首先每个
s
i
s_i
si 只有本质不同的子串有用。考虑对每个串建出
S
A
M
SAM
SAM。
那么有一个显然错误的想法是把每个串本质不同的子串数量乘起来就是答案。错误原因是会重复。
先假设所有
p
i
p_i
pi 非空:
考虑一组方案中的
p
i
+
p
i
+
1
p_{i} + p_{i + 1}
pi+pi+1,设
p
i
+
1
p_{i + 1}
pi+1 的开头字符为
c
c
c。如果
p
i
p_i
pi 后面接上
c
c
c 后也是
s
i
s_i
si 的子串,那么这种方案就会和
(
p
i
+
c
)
+
(
−
c
+
p
i
+
1
)
(p_i + c) + (-c +p_{i + 1})
(pi+c)+(−c+pi+1) 重复。
反过来,如果任意
p
i
+
c
p_i + c
pi+c 都不是
s
i
s_i
si 的子串,那么我们把它统计到答案中,这时候答案就是不重不漏的。
证明的话考虑对于一组方案从前往后按照上面的方式调整,即不断将
p
i
+
1
p_{i + 1}
pi+1 的开头移到
p
i
p_i
pi 的末尾,那么我们要统计的就是所有没法调整的方案。那么需要证明任意两个无法调整的方案拼出来的字符串不相等。反证法:假设相等,那么其中一个一定能调整到后一个,证毕。
如果存在
p
i
p_i
pi 为空怎么办?
假设
p
i
p_i
pi 为空,那么只需要
p
i
+
1
p_{i + 1}
pi+1 的首字母
c
c
c 不是
p
i
p_i
pi 的子串并且
p
i
−
1
+
c
p_{i - 1} + c
pi−1+c 不是
s
i
−
1
s_{i - 1}
si−1 的子串即可。相当于对限制求交得到
c
c
c 的取值集合。
那么可以
d
p
dp
dp,只需要记录上一个状态的限制下每种字母可以作为开头的方案即可。
设
f
i
,
c
f_{i, c}
fi,c 表示考虑了前
i
i
i 个字符串,第
i
+
1
i + 1
i+1 个字符串开头是
c
c
c 的话能提供合法方案数。
d
p
i
,
x
,
c
dp_{i, x, c}
dpi,x,c 表示第
i
i
i 个字符串
S
A
M
SAM
SAM 上的
x
x
x 节点开始游走,路径的最后一个节点没有
c
c
c 边的路径数。
那么转移有:
p
i
p_i
pi 不为空:
f
i
,
c
←
f
i
−
1
,
y
×
d
p
i
,
s
o
n
1
,
y
,
c
f_{i, c} \gets f_{i - 1, y} \times dp_{i, son_{1, y}, c}
fi,c←fi−1,y×dpi,son1,y,c
p
i
p_i
pi 为空:
f
i
,
c
←
f
i
−
1
,
c
×
[
s
o
n
1
,
c
=
0
]
f_{i, c} \gets f_{i - 1, c} \times [son_{1, c} = 0]
fi,c←fi−1,c×[son1,c=0]
如何计算答案:
转移到当前
i
i
i 是钦定
i
i
i 是最后一个非空的
p
i
p_i
pi。
那么有:
a
n
s
←
f
i
−
1
,
x
×
g
i
,
s
o
n
1
,
x
ans \gets f_{i - 1, x} \times g_{i, son_{1, x}}
ans←fi−1,x×gi,son1,x
g
i
,
x
g_{i, x}
gi,x 表示从
i
i
i 号
S
A
M
SAM
SAM 的 第
x
x
x 个点开始的路径条数。
最后答案加一因为可以全部为空。复杂度 O ( 26 ∑ ∣ s i ∣ ) O(26 \sum|s_i|) O(26∑∣si∣)。
CODE:
#include<bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10;
typedef long long LL;
const int mod = 1e9 + 7;
int n, tim[N * 2];
char str[N];
int g[26], h[26], f[N * 2][26], u[N * 2], ret; // f[i][c] 表示从 i 开始走一条路径最后中点没有 c 边的方案数
inline int Add(int x, int y) {
return x + y >= mod ? x + y - mod : x + y;
}
struct Node {
int len, fa;
int ch[26];
} node[N * 2];
struct SAM {
int tot = 1, last = 1;
inline int newnode() {
tot ++; node[tot] = node[0];
return tot;
}
inline void extend(int c) {
int p = last, np = last = newnode();
node[np].len = node[p].len + 1;
for(; p && !node[p].ch[c]; p = node[p].fa) node[p].ch[c] = np;
if(!p) node[np].fa = 1;
else {
int q = node[p].ch[c];
if(node[q].len == node[p].len + 1) node[np].fa = q;
else {
int nq = newnode();
node[nq] = node[q]; node[nq].len = node[p].len + 1;
node[q].fa = node[np].fa = nq;
for(; p && node[p].ch[c] == q; p = node[p].fa) node[p].ch[c] = nq;
}
}
}
inline void build(char *str) {
int n = strlen(str + 1);
node[1] = node[0];
for(int i = 1; i <= n; i ++ ) extend(str[i] - 'a');
}
void dfs(int x, int t) {
if(tim[x] == t) return ;
tim[x] = t; u[x] = 1;
for(int i = 0; i < 26; i ++ ) f[x][i] = 0;
for(int i = 0; i < 26; i ++ )
if(node[x].ch[i]) {
int y = node[x].ch[i];
dfs(y, t); u[x] = Add(u[x], u[y]);
for(int j = 0; j < 26; j ++ ) f[x][j] = Add(f[x][j], f[y][j]);
}
for(int i = 0; i < 26; i ++ ) if(!node[x].ch[i]) f[x][i] = Add(f[x][i], 1);
}
} sam[N];
int main() {
scanf("%d", &n);
for(int i = 0; i < 26; i ++ ) g[i] = 1;
for(int i = 1; i <= n; i ++ ) {
scanf("%s", str + 1);
sam[i].build(str);
sam[i].dfs(1, i);
// x先求贡献, 假设后面都空
for(int j = 0; j < 26; j ++ )
if(node[1].ch[j]) ret = Add(ret, 1LL * g[j] * u[node[1].ch[j]] % mod); // 有长度
memset(h, 0, sizeof h);
for(int j = 0; j < 26; j ++ ) { // 先来考虑非空的贡献
if(node[1].ch[j]) {
for(int k = 0; k < 26; k ++ ) {
h[k] = Add(h[k], 1LL * g[j] * f[node[1].ch[j]][k] % mod);
}
}
}
// 接着考虑填空串的贡献
for(int j = 0; j < 26; j ++ )
if(!node[1].ch[j]) h[j] = Add(h[j], g[j]);
memcpy(g, h, sizeof h);
}
ret = Add(ret, 1);
cout << ret << endl;
return 0;
}
P5161 WD与数列
题意:
定义两个整数序列
A
,
B
A,B
A,B 是 匹配的 当且仅当满足
∣
A
∣
=
∣
B
∣
|A| = |B|
∣A∣=∣B∣ 且
∀
1
≤
i
<
j
≤
∣
A
∣
,
A
i
−
B
i
=
A
j
−
B
j
\forall 1 \leq i < j \leq |A|, A_{i} - B_i = A_j - B_j
∀1≤i<j≤∣A∣,Ai−Bi=Aj−Bj。
现在给你一个长度为
n
n
n 的整数序列
C
C
C,问有多少对不相交连续子序列是匹配的。
1 ≤ n ≤ 3 × 1 0 5 , 1 ≤ ∣ C i ∣ ≤ 1 0 9 1 \leq n \leq 3 \times 10^5, 1 \leq |C_i| \leq 10^9 1≤n≤3×105,1≤∣Ci∣≤109。
分析:
将式子变形:
A
i
−
A
j
=
B
i
−
B
j
A_{i} - A_{j} = B_{i} - B_{j}
Ai−Aj=Bi−Bj。
得到两个长度大于等于
2
2
2 的序列是匹配的充要条件为 差分数组相同。
首先将答案加上
n
(
n
−
1
)
2
\frac{n(n - 1)}{2}
2n(n−1) 表示长度为
1
1
1 的匹配的答案。 然后将序列差分,那么问题变成了 有多少个不交且不相邻的连续子序列相等。
考虑一个暴力的想法:枚举两个前缀
[
1
,
i
]
,
[
1
,
j
]
[1, i],[1, j]
[1,i],[1,j], 求出它们的
l
c
s
lcs
lcs(最长公共后缀),然后将答案加上
min
(
l
c
s
,
j
−
i
−
1
)
\min(lcs, j - i - 1)
min(lcs,j−i−1)。
由于两个前缀
i
,
j
i, j
i,j 的
l
c
s
lcs
lcs 就是它们在
S
A
M
SAM
SAM 对应节点的
l
c
a
lca
lca 的最长长度,因此考虑建出
S
A
M
SAM
SAM 在
l
c
a
lca
lca 处计算答案。
那么只需要线段树维护区间下标和以及下标个数,线段树合并。启发式合并计算贡献即可。
下面写的是
S
A
SA
SA 版本:枚举前缀变成了枚举后缀,按照
h
e
i
g
h
t
height
height 从小到大合并就知道了
l
c
p
lcp
lcp,剩下的都一样。
复杂度
O
(
n
log
2
n
)
O(n \log^2 n)
O(nlog2n)。
CODE:
// 注意到一个性质:两个匹配的序列差分数组是相等的,变成了最长公共前缀问题
#include<bits/stdc++.h>
#define pb emplace_back
#define MP make_pair
using namespace std;
const int N = 3e5 + 10;
typedef long long LL;
typedef pair< int, LL > PII;
int n, tot;
LL a[N], b[N], d[N], ans;
int sa[N], height[N], rk[N];
int odr[N];
namespace SA {
int m, x[N * 2], y[N * 2], c[N];
void get_sa() {
m = tot;
for(int i = 1; i <= n; i ++ ) c[x[i] = b[i]] ++;
for(int i = 1; i <= m; i ++ ) c[i] += c[i - 1];
for(int i = n; i >= 1; i -- ) sa[c[x[i]] --] = i;
for(int k = 1; k <= n; k <<= 1) {
int num = 0;
for(int i = n - k + 1; i <= n; i ++ ) y[++ num] = i;
for(int i = 1; i <= n; i ++ )
if(sa[i] > k) y[++ num] = sa[i] - k;
for(int i = 0; i <= m; i ++ ) c[i] = 0;
for(int i = 1; i <= n; i ++ ) c[x[i]] ++;
for(int i = 1; i <= m; i ++ ) c[i] += c[i - 1];
for(int i = n; i >= 1; i -- ) sa[c[x[y[i]]] --] = y[i], y[i] = 0;
swap(x, y);
x[sa[1]] = 1, num = 1;
for(int i = 2; i <= n; i ++ )
x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? num : ++ num;
if(num == n) break;
m = num;
}
}
void get_height() {
for(int i = 1; i <= n; i ++ ) rk[sa[i]] = i;
for(int i = 1, k = 0; i <= n; i ++ ) {
if(rk[i] == 1) continue;
if(k) k --;
int j = sa[rk[i] - 1];
while(i + k <= n && j + k <= n && b[i + k] == b[j + k]) k ++;
height[rk[i]] = k;
}
}
}
bool cmp(int x, int y) {return height[x] > height[y];}
struct SegmentTree {
int ls, rs, cnt;
LL sum;
#define ls(x) t[x].ls
#define rs(x) t[x].rs
#define cnt(x) t[x].cnt
#define sum(x) t[x].sum
}t[N * 20];
int node, root[N];
void update(int p) {cnt(p) = cnt(ls(p)) + cnt(rs(p)); sum(p) = sum(ls(p)) + sum(rs(p));}
void ins(int &p, int lp, int rp, int pos) {
if(!p) p = ++ node;
if(lp == rp) {cnt(p) ++; sum(p) += lp; return ;}
int mid = (lp + rp >> 1);
if(pos <= mid) ins(ls(p), lp, mid, pos);
else ins(rs(p), mid + 1, rp, pos);
update(p);
}
int Merge(int p, int q, int lp, int rp) {
if(!p || !q) return p ^ q;
if(lp == rp) {cnt(p) += cnt(q); sum(p) += sum(q); return p;}
int mid = (lp + rp >> 1);
ls(p) = Merge(ls(p), ls(q), lp, mid);
rs(p) = Merge(rs(p), rs(q), mid + 1, rp);
update(p);
return p;
}
inline PII Add(PII x, PII y) {return MP(x.first + y.first, x.second + y.second);}
PII ask(int p, int lp, int rp, int l, int r) {
if(l > r) return MP(0, 0);
if(!p) return MP(0, 0);
if(l <= lp && r >= rp) return MP(cnt(p), sum(p));
int mid = (lp + rp >> 1);
if(r <= mid) return ask(ls(p), lp, mid, l, r);
else if(l > mid) return ask(rs(p), mid + 1, rp, l, r);
return Add(ask(ls(p), lp, mid, l, r), ask(rs(p), mid + 1, rp, l, r));
}
int bin[N];
int Find(int x) {return x == bin[x] ? x : bin[x] = Find(bin[x]);}
vector< int > V[N];
void Mg(int x, int y, int lcp) {
int f1 = Find(x), f2 = Find(y);
if(V[f1].size() > V[f2].size()) swap(f1, f2);
int sz1 = V[f1].size(), sz2 = V[f2].size();
for(auto v : V[f1]) {
int num = 0;
PII rs = ask(root[f2], 1, n, v + 1, min(v + lcp + 1, n));
ans += rs.second - 1LL * (v + 1) * rs.first; num += rs.first;
rs = ask(root[f2], 1, n, max(1, v - lcp - 1), v - 1);
ans += 1LL * (v - 1) * rs.first - rs.second; num += rs.first;
ans += 1LL * lcp * (sz2 - num);
V[f2].pb(v);
}
V[f1].clear(); bin[f1] = f2;
root[f2] = Merge(root[f2], root[f1], 1, n);
}
int main() {
scanf("%d", &n);
if(n == 1) {puts("0"); return 0;}
ans += 1LL * n * (n - 1) / 2LL;
for(int i = 1; i <= n; i ++ ) scanf("%lld", &a[i]);
for(int i = 1; i < n; i ++ ) b[i] = a[i + 1] - a[i];
n --; for(int i = 1; i <= n; i ++ ) d[++ tot] = b[i];
sort(d + 1, d + tot + 1);
tot = unique(d + 1, d + tot + 1) - (d + 1);
for(int i = 1; i <= n; i ++ ) b[i] = lower_bound(d + 1, d + tot + 1, b[i]) - (d);
SA::get_sa(); SA::get_height();
for(int i = 1; i <= n; i ++ ) odr[i] = i;
sort(odr + 1, odr + n + 1, cmp);
for(int i = 1; i <= n; i ++ ) {
bin[i] = i, V[i].pb(i), ins(root[i], 1, n, i);
}
for(int i = 1; i <= n; i ++ ) {
int u = odr[i];
if(u == 1) continue;
int x = sa[u], y = sa[u - 1];
Mg(x, y, height[u]);
}
cout << ans << endl;
return 0;
}