KUPC 2011 practice A: 秘境の呪文

問題概要

長さM(<50)の文字列がN(<10)個与えられる。全ての文字列に含まれる最長の部分文字列の長さを求める問題。

考えたこと

  • practiceのA問題なのに敷居高すぎじゃないか?初参加の人がA問題見て、「難しそうだし参加見送ろうかな」とかならないことを祈る。
  • 一番素朴な方法を考えよう。長さをM..0と降順に調べる。先頭の文字列から得られる部分文字列は大体M個。それが他の文字列に含まれるかどうかの判定ひとつにつきM^2。これだと全体でO(M^4*N)。
  • もちろん長さに関しては二分探索が使えてO(M^3*N*logM)に落ちるし、実際は定数も小さいので余裕で間に合うはず。
  • しかしせっかくなので別の解法でやろう。完全に好みの問題で、ラビンカープを使うことにする。
  • 例によってハッシュ値の一致だけで判定する。今回の問題ではNが大きいほどハッシュ値の衝突を考えなくてよくなる。
  • やり方は、各文字列に対して長さKの部分文字列のハッシュ値をローリングハッシュで計算し、配列に突っ込む。その後N個のハッシュ値配列の共通部分をとって空でなければ長さKの部分文字列が存在すると高い確率で言えるだろう。
  • Kについては二分探索が使えるので、全体の計算量はO(N * M * log M)。
  • 実装のときは手抜きでKについて降順に調べたのでO(N*M^2)となっている。
  • 空間計算量はO( (N+M)*M )。

practiceで通ったコード

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

typedef unsigned long long uint64;

const int MAX_N = 10;
const int MAX_M = 50;
const int P = 1000000007;

int N, M;
char strs[MAX_N][MAX_M+1];

//ローリングハッシュ
vector<uint64> hash_func(const char str[MAX_M+1], const int len){
	vector<uint64> ret;

	uint64 hash = 0, base = 1;
	for(int i=0; i<len; i++){
		hash = hash*P + str[i];
		base *= P;
	}
	ret.push_back(hash);

	for(int i=len; i<M; i++){
		hash = hash * P + str[i] - str[i-len] * base;
		ret.push_back(hash);
	}

	sort(ret.begin(), ret.end());

	return ret;
}

int solve(){
	
	//とりあえず全探索で。ダメなら二分探索
	for(int len = M; len > 0; len--){
		vector<uint64> hashes = hash_func(strs[0], len);
		for(int i=1; i<N; i++){
			vector<uint64> next_hashes = hash_func(strs[i], len), tmp;
			set_intersection(hashes.begin(), hashes.end(),
								next_hashes.begin(), next_hashes.end(),
								back_inserter( tmp ));
			tmp.swap(hashes);
		}

		if(!hashes.empty()) return len;
	}

	return 0;
}

int main(){

	//入力
	scanf("%d%d\n", &N, &M);
	for(int i=0; i<N; i++){
		scanf("%[^\n]%*c", strs[i]);
	}

	//出力
	printf("%d\n", solve());

	return 0;
}

二分探索

int solve(){
	
	//答えに関して二分探索
	int low = 0, high = M+1;
	while(high - low > 1){
		int len = (high + low) / 2;
		vector<uint64> hashes = hash_func(strs[0], len);
		for(int i=1; i<N; i++){
			vector<uint64> next_hashes = hash_func(strs[i], len), tmp;
			set_intersection(hashes.begin(), hashes.end(),
								next_hashes.begin(), next_hashes.end(),
								back_inserter( tmp ));
			tmp.swap(hashes);
		}

		if(!hashes.empty()){
			low = len;
		}
		else{
			high = len;
		}
	}

	return low;
}

バグを減らすための工夫

  • 整数の二分探索は苦手で、よく間違える。
  • 条件を満たすときlowを更新するのかhighを更新するのか、最後に何を返すのか、lowとhighの初期値は何にすべきなのか…等々。
  • これに関しては何か暗記して覚えるというよりも、毎回数直線書いて確認するのがいい気がする。