三分探索の精度

凸関数の最大値は精度良く求められるけど、最大値を与える値の精度はよくないというお話。

  • 具体的な問題で考えてみよう。
  • 関数f(x) = x*(1-x)の最大値と、最大値を与える値を誤差1e-9以下で計算しなさい。
  • もちろん高校数学の範囲で解けて、厳密解はx=0.5で最大値0.25をとる。
  • これを三分探索で解いてみよう。
const double EPS = 1e-10;

double f(double x){
	return x*(1.0-x);
}

int main(){
	double low = 0.0, high = 1.0;
	while(high - low > EPS){
		double width = high - low, ml = low + width/3.0, mr = low + 2.0*width/3.0;
		if(f(ml) >= f(mr)){
			high  = mr;
		}
		else{
			low = ml;
		}
	}
	puts(" ------- ");
	printf("x: %.9f\n",(high+low)*0.5);
	printf("f(x): %.9f\n", f((high+low)*0.5));
	return 0;
}
  • 実行結果は
 ------- 
x: 0.499999996
f(x): 0.250000000

となり、f(x)は求まっているもののxは誤差が1e-9を越えてしまっている。

  • この原因は丸め誤差で、x=0.5の近傍では丸め誤差のせいでf(x+ε)=f(x) (εは1e-9より少し大きい値を想定した微小量)になってしまう可能性があるから。
    • もう少し詳しく書くと、f(x+ε)はf(x)+f'(x)*ε+0.5*f''(x)*ε^2と近似されるが、
    • f'(x)の値が0で、かつεも小さいので、それらを掛け合わせた第2項、第3項が桁落ちしてしまう可能性がある。
      • fをxの周りでTaylor展開するともう少し詳しく分かるかも。
    • なので、下手をするとlong doubleにしても通らない可能性がある。
      • (ちなみにこの問題ではfをlong doubleにすると通る)
    • 最大値そのものは正確に求まることも上の話から明らか。
  • 上の話は実は二分探索するときも注意しなければならない。
    • 二分探索ではf'が必ず正(or 負)になると保証されているけど、
    • |f'(x)|が0に近い部分で零点を持つ場合はやはり桁落ちしてしまう可能性がある。
  • 話をもとの問題に戻そう。
  • 丸め誤差が入ったらf(x+ε)が最大値をとるというのは数値計算的に間違っていないんだから、普通に探索やっても上手くいきそうに無い。
  • じゃあどうすれば良いかというと、これは数学の力に頼ってf'(x)=0となるxを求めるしかない。
  • 幸いなことに、凸関数はf'が単調なので二分探索で零点を求めることができる。
    • もしくは、手計算で厳密解を出すか。
  • とはいえ、今までの議論よりf'(x)をf(x)から数値的に近似するのは駄目だと分かる。
  • 以下のコードで、gはf'(x)*Δxの近似値。
double g(double x){
	return f(x+1e-10)-f(x);
}

int main(){
	double low = 0.0, high = 1.0;
	while(high - low  > EPS){
		double mid = (high + low)*0.5;
		if(g(mid)<0){
			high = mid;
		}
		else{
			low = mid;
		}
	}

	puts(" ------- ");
	printf("x: %.9f\n",(high+low)*0.5);
	printf("f(x): %.9f\n", f((high+low)*0.5));

	return 0;
}

実行結果は

 ------- 
x: 0.500000067
f(x): 0.250000000

となり失敗。なお、これはgに関してみると二分探索の失敗例にもなっている。

  • なので、f'(x)は厳密な奴を使う。
double h(double x){
	return 1.0 - 2.0*x;
}

int main(){
	double low = 0.0, high = 1.0;
	while(high - low  > EPS){
		double mid = (high + low)*0.5;
		if(h(mid)<0){
			high = mid;
		}
		else{
			low = mid;
		}
	}

	puts(" ------- ");
	printf("x: %.9f\n",(high+low)*0.5);
	printf("f(x): %.9f\n", f((high+low)*0.5));

	return 0;
}

実行結果は

 ------- 
x: 0.500000000
f(x): 0.250000000
  • やっと答えがでた。
  • fの導関数が分からないときはどうすればいいのかって?
  • どうすればいいんでしょうね。
  • マニアックな話の様だけど、ICPCの国内予選で出たSecrets in Shadowsはこれに近い話のせいで解けなかった。