SRM 504.5 900pt: TheTicketsDivOne

keyword

メモ化再帰 確率 C++

問題概要

N(<1000)人が一列に並んでいる。N=1のときか、1/6の確率で先頭の人が勝利する。1/3の確率で先頭の人は列から抜け、1/2の確率で列の後ろに並び直す。M番目の人が勝利する確率を求める問題。

解法

確率と言えば解法は

  • 高校数学っぽい場合の数
  • 二分探索
  • 動的計画法
  • (連立)方程式

だけど、今回は状態数が1000^2しか無いしDPでいけそう。ただし状態がループする可能性があるので方程式っぽくもある。これまたよくあるパターン。
ただし、状態がループする可能性は小さい(底が1/2以下で指数的に減少する)ので確率が十分小さくなったら打ちきれば良さそう。本番中に提出したコードはN,Mと深さでメモ化再帰した。ただしメモリやスタックオーバーフローの関係上Nが大きい場合と小さい場合で深さの上限を変えたりした。
別解として打ち切り無しのDP解法を考える。こっちのほうが綺麗。まず、m>1のときはdp[n][m]=1/3*dp[n-1][m-1]+1/2*dp[n][m-1]。m=1のときは方程式を解くことによって解を求められる。詳細はソースコード参照のこと。ソースコード中のrec(n,k)はnCk*p^k*q^(n-k)を返す(n回の独立な試行で丁度k回確率pの事象が起こりそれ以外は確率qの事象が起こる確率)。この関数はいつか使う機会がありそう。

感想

こういう値が十分小さくなったら打ちきるみたいな解法は割と好み。あと本番中に提出したコードはスタックオーバーフローが結構怪しかった。深さが10^6だったので危険領域もいいところ。局所変数が多かったらアウトだったと思う。
打ち切り解法

double memo[1009][1009];
bool visited[1009][1009];
double memo2[109][109][109];
bool vis[109][109][109];

double find(int n, int m) {
	if(n>100){
		return solve(n,m);
	}
	else{
		return solve2(n,m,0);
	}
}

double solve(int n, int m){
	if(visited[n][m]) return memo[n][m];
	visited[n][m] = true;
	if(n==1){
		return  memo[n][m] = 1.0;
	}
	if(m==1){
		return memo[n][m] = (1.0/6.0 + 0.5*solve(n,n));
	}
	return memo[n][m] = (1.0/3.0*solve(n-1,m-1) + 1.0/2.0*solve(n,m-1));
}

double solve2(int n, int m, int d){
	if(d==100) return 0.0;
	if(vis[n][m][d]) return memo2[n][m][d];
	vis[n][m][d] = true;
	if(n==1){
		return memo2[n][m][d] = 1.0;
	}
	if(m==1){
		return memo2[n][m][d] = (1.0/6.0 + 0.5*solve2(n,n,d+1));
	}
	return memo2[n][m][d] = (1.0/3.0*solve2(n-1,m-1,d) + 1.0/2.0*solve2(n,m-1,d));
}

DP解法

double dp[1009][1009];
double memo[1009][1009];
bool visited[1009][1009];

class TheTicketsDivOne {
public:
double find(int n, int m) {
	dp[1][1] = 1.0;
	for(int i=2; i<=n; i++){
		double p = pow(1.0/2.0, i);
		dp[i][1] = 1.0/6.0;
		for(int k=1; k<i; k++){
			dp[i][1] += 1.0/2.0*(rec(i-1,k) * dp[i-k][1]);
		}
		dp[i][1] /= 1-p;
		for(int j=2; j<=i; j++){
			dp[i][j] = (1.0/3.0*dp[i-1][j-1] + 1.0/2.0*dp[i][j-1]);
		}
	}
	return dp[n][m];
}


//calculate (n,k) * (0.33)^k * (0.5)^(n-k)
double rec(int n, int k){
	if(visited[n][k]) return memo[n][k];
	visited[n][k] = true;
	if(n==1){
		return memo[n][k] = (k==0?1.0/2.0:1.0/3.0);
	}
	if(k==0){
		return memo[n][k] = (pow(1.0/2.0, n));
	}
	if(n==k){
		return memo[n][k] = (pow(1.0/3.0, n));
	}
	return memo[n][k] = (1.0/3.0*rec(n-1,k-1) + 1.0/2.0*rec(n-1,k));
}