232-aka-pretty-pi 2024. 11. 4. 19:25
template <int L, int H>
/// Aho–Corasick multi-pattern matcher over the contiguous alphabet [L, H].
///
/// Every pattern and text element c must satisfy L <= c <= H (checked with
/// assert); a character is stored at child slot c - L. Patterns must be
/// non-empty. Pattern i keeps index i in every result.
///
/// @tparam L smallest alphabet value (inclusive).
/// @tparam H largest alphabet value (inclusive).
struct AhoCorasick {
  static_assert(L <= H);

  /// One trie node; nodes[0] is the root.
  struct Node {
    // Fail link: node for the longest proper suffix of this node's string
    // that is also in the trie (0 = root).
    int fail;
    // Nearest node strictly down the fail chain that ends at least one
    // pattern; 0 when there is none.
    int near;
    // Number of patterns (duplicates counted separately) ending at this node
    // plus the chain_sz of `near`, i.e. all matches ending here.
    int chain_sz;
    // Child per character slot c - L; 0 means "no child" (the root is never
    // anyone's child, so 0 is a safe sentinel).
    std::array<int, H - L + 1> to;
    // Indices of patterns ending exactly at this node; engaged only if at
    // least one pattern ends here.
    std::optional<std::vector<int>> ids;

    int operator[](int idx) const {
      return to[idx];
    }

    int& operator[](int idx) {
      return to[idx];
    }

    Node() : fail(0), near(0), chain_sz(0), to{} {}
  };

  int n;                    // number of patterns
  std::vector<int> lens;    // lens[i] = length of pattern i
  std::vector<Node> nodes;  // trie / automaton; nodes[0] is the root

  /// Builds the automaton for the given patterns.
  /// @param s non-empty patterns; each must support `.size()` and iteration.
  template <typename S>
  AhoCorasick(const std::vector<S>& s) : n(int(s.size())), lens(n), nodes(1) {
    for (int i = 0; i < n; ++i) {
      lens[i] = int(s[i].size());
      // An empty pattern would mark the root as terminal, which the
      // fail/near bookkeeping below does not handle.
      assert(lens[i] != 0);
    }

    // Phase 1: insert every pattern into the trie.
    for (int i = 0; i < n; ++i) {
      int cur = 0;
      for (const auto& c : s[i]) {
        assert(L <= c && c <= H);
        const int idx = int(c - L);
        if (nodes[cur][idx] == 0) {
          nodes[cur][idx] = int(nodes.size());
          nodes.emplace_back();
        }
        cur = nodes[cur][idx];
      }
      if (!nodes[cur].ids) {
        nodes[cur].ids.emplace();
      }
      nodes[cur].ids->push_back(i);
      nodes[cur].chain_sz += 1;
    }

    // Phase 2: BFS from the root. A node's fail target is strictly shallower,
    // so its fail/near/chain_sz are final before the node is processed.
    std::vector<int> que(1, 0);
    que.reserve(nodes.size());
    for (int head = 0; head < int(que.size()); ++head) {
      const int cur = que[head];
      for (int i = 0; i < H - L + 1; ++i) {
        const int nxt = nodes[cur][i];
        if (nxt == 0) {
          continue;
        }
        int f = nodes[cur].fail;
        while (f != 0 && nodes[f][i] == 0) {
          f = nodes[f].fail;
        }
        // For children of the root, nodes[0][i] is nxt itself; keep f = 0
        // instead of making the node its own fail target.
        if (nxt != nodes[f][i]) {
          f = nodes[f][i];
        }
        nodes[nxt].fail = f;
        nodes[nxt].near = nodes[f].ids ? f : nodes[f].near;
        nodes[nxt].chain_sz += nodes[nodes[nxt].near].chain_sz;
        que.push_back(nxt);
      }
    }
  }

  /// Moves from state `cur` on text element `c`, falling back along fail
  /// links until a matching child exists or the root is reached.
  /// @return the new state (0 = root).
  template <typename C>
  [[nodiscard]] int advance(int cur, const C& c) const {
    assert(L <= c && c <= H);
    const int idx = int(c - L);
    while (cur != 0 && nodes[cur][idx] == 0) {
      cur = nodes[cur].fail;
    }
    return nodes[cur][idx];
  }

  /// Counts all pattern occurrences in `t`.
  /// Duplicate patterns are counted once per copy; overlapping matches are
  /// all counted.
  /// @param t any iterable sequence of alphabet values.
  template <typename S>
  [[nodiscard]] long long count(const S& t) const {
    long long cnt = 0;
    int cur = 0;
    for (const auto& c : t) {
      cur = advance(cur, c);
      cnt += nodes[cur].chain_sz;
    }
    return cnt;
  }

  /// Reports every occurrence of every pattern in `t`.
  /// @param t any iterable sequence of alphabet values (random access is not
  ///          required).
  /// @return pos with pos[id] holding the 0-based start index in `t` of each
  ///         occurrence of pattern `id`, in increasing order.
  template <typename S>
  [[deprecated("findEqualPositions is not covered by upstream tests; verify results before relying on it")]]
  [[nodiscard]] std::vector<std::vector<int>> findEqualPositions(const S& t) const {
    std::vector<std::vector<int>> pos(n);
    int cur = 0;
    int end = 0;  // index in t of the element just consumed
    for (const auto& c : t) {
      cur = advance(cur, c);
      // Walk this state and every terminal suffix state reachable via `near`.
      for (int j = cur; j != 0; j = nodes[j].near) {
        if (nodes[j].ids) {
          for (int id : *nodes[j].ids) {
            pos[id].push_back(end - lens[id] + 1);
          }
        }
      }
      ++end;
    }
    return pos;
  }
};