2019 Multi-University Training Contest 2 - 1009 - 回文自动机

时间:2022-02-01 10:39:47

http://acm.hdu.edu.cn/showproblem.php?pid=6599

有好几种实现方式,首先都是用回文自动机统计好回文串的个数。

记得把每个节点的cnt加到他的fail上,因为他既然出现了那么他的fail也当然会出现。

这里需要一直从fail向上找到一个长度恰好一半的节点,这也是TLE的来源,重复跳了大量的指针。dalao给了一种实现。

额外维护一个x数组,假如x[i]==0,说明还没人动过它,x[i]=i。

然后沿着fail把x[i]向上移动到第一个>=它的长度的一半的fail祖先,验证是不是长度恰好一半。

最后把i的fail父亲的x指针指向x[i],也就是把它的fail父亲直接x指向i的fail祖先,这样它父亲就不需要经过一系列没有用的fail转移了。

这个是个“不那么彻底的”路径压缩。源于一个点s的fail父f,f需要用的fail祖先必定是s的祖先,s可以每次把它父亲f一起给移动了。

当然也有hash大法、manacher大法。

hash大法好!

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

struct Node {
    int len, ch[26], fail;
    ll cnt, x, y;
    //string str;
    Node(int len = 0) : len(len), fail(0) {
        memset(ch, 0, sizeof(ch));
        //下面是维护额外信息
        cnt = 0;
        x = 0;
        y = 0;
        //str = "";
    }
    /*void show() {
        printf("  str=\"%s\"\n", str.c_str());
        printf("  len=%d cnt=%d\n", len, cnt);
    }*/
};

const int MAXN = 300000;

ll ans2[MAXN + 5];

//PalindromicAutomaton
struct PAM {
    Node nd[MAXN + 5];

    int len, top, last;     // len为字符串长度mtop为节点个数,last为最后插入字符所对应的节点
    char s[MAXN + 5];
    //string ls;  //用来展示的辅助字符串

    int getfail(int x) {        //沿着fail指针找到第一个回文后缀
        while(s[len - nd[x].len - 1] != s[len])
            x = nd[x].fail;
        return x;
    }

    void init() {
        len = 0, top = 0, last = 0;
        nd[top] = Node(0);
        nd[top].fail = 1;
        nd[++top] = Node(-1);
        nd[top].fail = 0;
        s[0] = '$';
    }

    void extend(char c) {
        s[++len] = c;
        int now = getfail(last);     //找到插入的位置
        //ls = nd[now].str + c;   //用来展示的辅助字符串
        if(!nd[now].ch[c - 'a']) {     //若没有这个节点,则新建并求出它的fail指针
            nd[++top] = Node(nd[now].len + 2);
            nd[top].fail = nd[getfail(nd[now].fail)].ch[c - 'a'];
            nd[now].ch[c - 'a'] = top;
        }
        last = nd[now].ch[c - 'a'];
        //nd[last].str = ls;  //用来展示的辅助字符串
        //下面是维护额外信息
        ++nd[last].cnt;
    }

    /*void show() {
        for(int i = top; i >= 0; --i) {
            printf("node:  id=%d\n", i);
            nd[i].show();
            printf("fail:  id=%d\n", nd[i].fail);
            nd[nd[i].fail].show();
            puts("");
        }
    }*/

    void count() {
        for(int i = top; i >= 2; --i)
            nd[nd[i].fail].cnt += nd[i].cnt;
    }

    void solve() {
        for(int i = top; i >= 2; --i) {
            if(nd[i].x == 0)
                nd[i].x  = i;
            while(nd[nd[i].x].len > (nd[i].len + 1) / 2 )
                nd[i].x = nd[nd[i].x].fail;
            if(nd[nd[i].x].len == (nd[i].len + 1) / 2)
                nd[i].y = 1;
            nd[nd[i].fail].x = nd[i].x;
        }
    }

    void ans(int n) {
        memset(ans2, 0, sizeof(ans2[1]) * (n + 1));
        for(int i = top; i >= 2; --i) {
            ans2[nd[i].len] += nd[i].cnt * nd[i].y;
        }
        for(int i = 1; i <= n; ++i) {
            printf("%lld%c", ans2[i], " \n"[i == n]);
        }
    }
} pam;

char s[MAXN + 5];

int main() {
#ifdef Yinku
    freopen("Yinku.in", "r", stdin);
#endif // Yinku
    while(~scanf("%s", s)) {
        int n = strlen(s);
        pam.init();
        for(int i = 0; s[i] != '\0'; ++i)
            pam.extend(s[i]);
        //pam.show();
        pam.count();
        pam.solve();
        pam.ans(n);
    }
    return 0;
}