POJ-1759 : Garland

問題概要

長さN(<1000)の数列Aがある。A[1]は固定されている。それ以外の部分について、A[i]>=0, A[i] = (A[i-1]+A[i+1])/2 - 1 (i in (1,N))となる。A[N]のとりうる最小値を求める問題。

解法

最小値はA[N]について単調なので答について二分探索する。答を決めたらそれ以外の値は連立一次方程式を解くことで求めることができる。

acceptされたコード

#include <cstdio>
#include <algorithm>
using namespace std;

const int MAX_N = 1000;
int N;
double x0;

void init() {
	scanf("%d%lf", &N, &x0);
}

template<typename numType>
inline bool updateMin(numType& lhs, const numType& rhs) {
	if (lhs > rhs) {
		lhs = rhs;
		return true;
	}
	return false;
}

const int MAX_EQ = 1000;

void solveTridiagonal(int n, double d[], double l[], double r[], double b[], double x[]) {
	//右上と左下が0の場合
	if (l[0] == 0.0 && r[n-1] == 0.0) {
		for (int i = 1; i < n; ++i) {
			d[i] -= r[i-1] * (l[i] / d[i-1]);
			b[i] -= b[i-1] * (l[i] / d[i-1]);
		}
		for (int i = n-2; i >= 0; --i) {
			b[i] -= b[i+1] * (r[i] / d[i+1]);
		}
		for (int i = 0; i < n; ++i) {
			x[i] = b[i] / d[i];
		}
		return ;
	}

	static double d2[MAX_EQ], l2[MAX_EQ], r2[MAX_EQ], b2[MAX_EQ], a[MAX_EQ];

	for (int i = 0; i < n-1; ++i) {
		d2[i] = d[i];
		l2[i] = (i ? r[i-1] : 0.0);
		r2[i] = (i < n-2 ? l[i+1] : 0.0);
	}

	fill(b2, b2+n-1, 0.0);
	b2[0] = r[n-1];
	b2[n-2] = l[n-1];
	solveTridiagonal(n-1, d2, l2, r2, b2, a);

	for (int i = 0; i < n-1; ++i) {
		b[n-1] -= a[i] * b[i];
	}
	d[n-1] -= a[n-2] * r[n-2] + a[0] * l[0];
	b[0] -= b[n-1] * l[0] / d[n-1];
	b[n-2] -= b[n-1] * r[n-2] / d[n-1];
	l[0] = r[n-2] = 0.0;
	solveTridiagonal(n-1, d, l, r, b, x);
	x[n-1] = b[n-1] / d[n-1];
}

double d[MAX_N], l[MAX_N], r[MAX_N], b[MAX_N], x[MAX_N];

bool check(double test) {
	for (int i = 0; i < N; ++i) {
		d[i] = 1.0;
		if (i == 0 || i == N - 1) {
			l[i] = r[i] = 0.0;
		}
		else {
			l[i] = r[i] = -0.5;
			b[i] = -1.0;
		}
	}
	b[0] = x0;
	b[N-1] = test;

	solveTridiagonal(N, d, l, r, b, x);

	double mini = 1e100;
	for (int i = 0; i < N; ++i) {
		updateMin(mini, x[i]);
	}
	return mini > 0;
}

double solve() {
	double low = 0.0, high = 1e20;
	for (int _ = 0; _ < 100; ++_) {
		double mid = (low + high) / 2.0;
		if (check(mid)) {
			high = mid;
		}
		else {
			low = mid;
		}
	}
	return high;
}

int main() {
	init();
	printf("%.2f\n", solve() + 1e-11);
	return 0;
}