/// Aho–Corasick multi-pattern matcher over the character range [L, H].
///
/// Every character `c` of a pattern or text is mapped to a child slot via
/// `c - L`; callers must guarantee L <= c <= H (checked only with `assert`,
/// so out-of-range characters are undefined behavior under NDEBUG).
/// Node 0 is the trie root. Patterns are identified by their index in the
/// vector passed to the constructor; duplicate patterns keep separate ids.
///
/// @tparam L smallest allowed character value (inclusive).
/// @tparam H largest allowed character value (inclusive).
template <int L, int H>
struct AhoCorasick {
    static_assert(L <= H);

    struct Node {
        // Node of the longest proper suffix of this node's string that is
        // also in the trie (0 = root).
        int fail = 0;
        // Nearest node on this node's fail chain (excluding itself) that ends
        // at least one pattern; 0 when there is none.
        int near = 0;
        // Number of pattern occurrences (with multiplicity) that end when the
        // automaton is in this node: own ids plus chain_sz of `near`.
        int chain_sz = 0;
        // Child index per character slot; 0 means "no child" (the root is
        // never anyone's child, so 0 is a safe sentinel).
        std::array<int, H - L + 1> to{};
        // Ids of patterns that end exactly at this node; nullopt if none.
        std::optional<std::vector<int>> ids;

        int operator[](int idx) const {
            return to[idx];
        }
        int& operator[](int idx) {
            return to[idx];
        }
    };

    int n;                    // number of patterns
    std::vector<int> lens;    // lens[i] = length of pattern i
    std::vector<Node> nodes;  // trie nodes; nodes[0] is the root

    /// Builds the automaton from `s`.
    /// @param s non-empty patterns (asserted); every character must lie in [L, H].
    template <typename S>
    AhoCorasick(const std::vector<S>& s) : n(int(s.size())), lens(n), nodes(1) {
        for (int i = 0; i < n; ++i) {
            lens[i] = int(s[i].size());
            assert(lens[i] != 0);
            insert(s[i], i);
        }
        build_links();
    }

    /// Total number of pattern occurrences in `t`, counting overlapping
    /// occurrences and duplicate patterns separately. Amortized O(|t|).
    /// @param t any range whose elements lie in [L, H].
    template <typename S>
    [[nodiscard]] long long count(const S& t) const {
        long long cnt = 0;
        int cur = 0;
        for (const auto& c : t) {
            assert(L <= c && c <= H);
            cur = step(cur, c - L);
            cnt += nodes[cur].chain_sz;
        }
        return cnt;
    }

    /// For each pattern id, the 0-based start positions of all its
    /// occurrences in `t`, in increasing order.
    /// Runs in O(|t| + number of reported occurrences) amortized.
    /// @param t any range whose elements lie in [L, H].
    template <typename S>
    [[deprecated("not covered by tests; verify before relying on it")]] [[nodiscard]]
    std::vector<std::vector<int>> findEqualPositions(const S& t) const {
        std::vector<std::vector<int>> pos(n);
        int cur = 0;
        int i = 0;  // index of the current character in t
        for (const auto& c : t) {
            assert(L <= c && c <= H);
            cur = step(cur, c - L);
            // Only `cur` itself may lack ids; every `near` link targets a node with ids.
            for (int j = cur; j != 0; j = nodes[j].near) {
                if (!nodes[j].ids) {
                    continue;
                }
                for (int id : *nodes[j].ids) {
                    pos[id].push_back(i - lens[id] + 1);
                }
            }
            ++i;
        }
        return pos;
    }

private:
    // Adds `pattern` (with id `id`) to the trie, creating missing nodes.
    template <typename S>
    void insert(const S& pattern, int id) {
        int cur = 0;
        for (const auto& c : pattern) {
            assert(L <= c && c <= H);
            const int k = c - L;
            if (nodes[cur][k] == 0) {
                // Record the index before emplace_back: it may reallocate `nodes`.
                nodes[cur][k] = int(nodes.size());
                nodes.emplace_back();
            }
            cur = nodes[cur][k];
        }
        Node& leaf = nodes[cur];
        if (!leaf.ids) {
            leaf.ids.emplace();
        }
        leaf.ids->push_back(id);
        leaf.chain_sz += 1;
    }

    // Computes fail/near links and finalizes chain_sz in BFS order, so every
    // shallower node (all fail/near targets) is complete before it is read.
    void build_links() {
        std::vector<int> que(1, 0);
        que.reserve(nodes.size());
        for (int head = 0; head < int(que.size()); ++head) {
            const int cur = que[head];
            for (int k = 0; k < H - L + 1; ++k) {
                const int nxt = nodes[cur][k];
                if (nxt == 0) {
                    continue;
                }
                int f = nodes[cur].fail;
                while (f != 0 && nodes[f][k] == 0) {
                    f = nodes[f].fail;
                }
                // For root children nodes[0][k] == nxt; their fail must stay 0.
                if (nodes[f][k] != nxt) {
                    f = nodes[f][k];
                }
                nodes[nxt].fail = f;
                nodes[nxt].near = nodes[f].ids ? f : nodes[f].near;
                nodes[nxt].chain_sz += nodes[nodes[nxt].near].chain_sz;
                que.push_back(nxt);
            }
        }
    }

    // Goto function: follows fail links from `cur` until slot `k` has a child
    // (or the root is reached) and returns the resulting state.
    int step(int cur, int k) const {
        while (cur != 0 && nodes[cur][k] == 0) {
            cur = nodes[cur].fail;
        }
        return nodes[cur][k];
    }
};