SRM 541 550pt : AkariDaisukiDiv1

問題概要

f(X)=A+X+B+X+C (各変数は文字列)がある。f^k(X) (k<10^7)に対してパターンPが何回現れるかをMOD 10^9+7で求める問題。A,B,C,X,Pは50文字以下。

解法

Xの中にPがk回現れるとすると、f(X)の中でPは2*k+(AとBとCに一部が含まれるようなPの個数)となる。( .. )の部分は文字列のprefixとsuffixを持っておけばよい。毎回prefixとsuffixを計算し直すと計算量O(k*|P|)とかでTLEするけど、prefixとsuffixはしばらくすると変化しなくなる。なので、変化が生じる序盤だけprefixとsuffixを計算して、prefixとsuffixが動かないようになったら後は残り回数反復させる。
ただし、こういうのは小さいときに色々面倒なので、文字列が短いところでは真面目に計算して持っておく。
もちろん最後は行列累乗とかすれば計算量落とせるけど、どうせ間に合うんだから愚直にやってよい。

acceptされたコード

#include <string>
#include <algorithm>
using namespace std;

typedef long long int64;

const int64 MOD = (int64)(1e9 + 7);

struct AkariDaisukiDiv1 {

	int countF(string A, string B, string C, string X, string P, int iter) {
		const int L = P.length();
		string cur = A + X + B + X + C;
		int64 m = 0;
		for(int i=0; i + L <=(int)cur.length(); ++i){
			if(cur.substr(i, L) == P){
				++m;
			}
		}
		string st, ed;
		st = cur.substr(0, L);
		ed = (int)cur.length() >= L ? cur.substr((int)cur.length() - L) : cur;
		bool remain = false;

		int64 cBa = 0, Aa = 0, cC = 0;
		for(int _=1; _<iter; ++_){
			if((int)cur.length() <= L){
				cur = A + cur + B + cur + C;
				st = cur.substr(0, L);
				ed = (int)cur.length() >= L ? cur.substr((int)cur.length() - L) : cur;
				m = 0;
				for(int i=0; i + L <=(int)cur.length(); ++i){
					if(cur.substr(i, L) == P){
						++m;
					}
				}
			}
			else if(remain){
				m = (2 * m + cBa + Aa + cC) % MOD;
			}
			else{
				int64 nm = 2 * m % MOD;
				string nst = (A + st).substr(0, L);
				string ned = ed + C;
				if((int)ned.length() > L){
					ned = ned.substr((int)ned.length() - L);
				}

				cBa = Aa = cC = 0;

				const int sl = st.length();
				const int el = ed.length();

				string edBst = ed + B + st;
				for(int i=(sl==L?1:0), j=(el==L?-1:0) + (int)edBst.length(); i + L <= j; ++i){
					if(edBst.substr(i, L) == P){
						++cBa;
					}
				}
				string Ast = A + st;
				for(int i=0, j=(el==L?-1:0) + (int)Ast.length(); i + L <= j; ++i){
					if(Ast.substr(i, L) == P){
						++Aa;
					}
				}

				string edC = ed + C;
				for(int i=(sl==L?1:0), j=(int)edC.length(); i + L <= j; ++i){
					if(edC.substr(i, L) == P){
						++cC;
					}
				}

				nm = (nm + cBa + Aa + cC) % MOD;
				if(st == nst && ed == ned){
					remain = true;
				}
				st = nst;
				ed = ned;
				m = nm;
			}
		}

		return (int)((m%MOD + MOD)%MOD);
	}

};