BZOJ2806: [Ctsc2012]Cheat熟悉的文章【二分+后缀自动机+dp】

二分+后缀自动机还是比较好想的。dp优化只能膜拜题解了。==

题意:给定一些01字符串作为标准作文库。然后给你几篇作文。对于每篇作文,把这篇作文分成若干段后,如果一段的长度不小于L,且是标准作文库中某个字符串的连续子串,这样这一段作文就是匹配的。求L的最大值,使得这篇作文的90%是匹配的。

 

一眼看出是二分+判定。因为如果L0可行,那小于L0的L都可行。

然后就可以考虑如何判定。分段的话应该是要dp了。设这个作文是S。

先预处理出一个数组d[i],d[i]=max{i-j+1|S[j..i]是标准作文库中某个字符串的连续子串}

这个用后缀自动机。把标准作文库里的所有字符串拼起来。中间加一个分隔符2好了。然后O(n)建后缀自动机。

然后让S在自动机上跑。

跑的过程具体就是:从root开始,如果下一个字符可以转移就转移,更新len值,如果不能就一直回退直到可以转移,回到根了那len=0,如果还没到根,len=当前结点的l值,并且继续转移。

定义数组f[i]为S[1..i]匹配的最大长度。则f[i]=max{f[j]+i-j|i-d[i]<=j<=i-L}。可惜这样dp是O(n^2)的承受不起。

 

优化:

f[i]=i+max{f[j]-j|i-d[i]<=j<=i-L}。然后令g[j]=f[j]-j。

考虑它的决策区间。

首先i-d[i]是非严格递增的。

 

如果有某个i,使得i-d[i]>i+1-d[i+1],由d[i]的意义知S[i-d[i]+1..i]和S[i+1-d[i+1]+1..i+1]是字典中某个串的子串。
由于i-d[i]+1>i+1-d[i+1]+1,知S[i+1-d[i+1]+1..i]也是字典中某个串的子串且比S[i-d[i]+1..i]更长,这与d[i]的定义中的max违背。

 

i-L是严格递增的。显然。

那么[i-d[i],i-L]只会右移。那么用g[]递减的一个单调队列维护决策区间即可。

 

最后的复杂度应该是O(nlogn)。

 

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
const int maxn=1100010;
int l[maxn],ch[maxn][3],d[maxn],f[maxn],g[maxn],q[maxn],fa[maxn];
int rt=1,tail=1,tot=1,n,len;    char s[maxn];
void add(int c){
    int p=tail,np=++tot,r,q;
    l[np]=l[p]+1;   tail=np;
    for(;p&&!ch[p][c];p=fa[p])  ch[p][c]=np;
    if(!p){ fa[np]=rt;  return;}
    if(l[p]+1==l[q=ch[p][c]]){  fa[np]=q;   return;}
    fa[r=++tot]=fa[q];  memcpy(ch[r],ch[q],sizeof(ch[r]));
    l[r]=l[p]+1;    fa[np]=fa[q]=r;
    for(;p&&ch[p][c]==q;p=fa[p])    ch[p][c]=r;
}void init(){
    int p=rt,k=0,x;
    for(int i=1;i<=n;i++){
        if(ch[p][x=s[i]-48])    k++,    p=ch[p][x]; else{
            while(p&&!ch[p][x]) p=fa[p];
            if(!p)  p=rt,k=0;   else    k=l[p]+1,p=ch[p][x];
        }d[i]=k;
    }
}int mt(int r){
    int s=1,e=0,t;
    for(int i=1;i<=n;i++){
        if((t=i-r)>=0){
            while(s<=e&&g[q[e]]<g[t])   e--;
            q[++e]=t;
        }while(s<=e&&q[s]<i-d[i])    s++;
        f[i]=f[i-1];
        if(s<=e&&f[i]<g[q[s]]+i)    f[i]=g[q[s]]+i;
        g[i]=f[i]-i;
    }return f[n]>=len;
}int main(){
    int tc,m;   scanf("%d%d",&tc,&m);
    for(int i=1;i<=m;i++){
        scanf("%s",s+1);
        int k=strlen(s+1);
        for(int j=1;j<=k;j++)    add(s[j]-48);
        if(i!=m)add(2);
    }while(tc--){
        scanf("%s",s+1);    n=strlen(s+1);
        len=(int)(ceil((double)n*0.9));
        init(); if(!mt(1)){   printf("0\n");  continue;}
        int L=2,R=n,mid;
        while(L<=R){
            mid=(L+R)>>1;
            if(mt(mid)) L=mid+1;    else    R=mid-1;
        }printf("%d\n",R);
    }return 0;
}