連立一次方程式ソルバ
確率や期待値のDPで式を立てたとき、連立一次方程式
Ax = b
を解きたくなることはよくあります。競技プログラミングで使うようなものに限っていくつかメモしておきます。
クラメルの公式
行列式を計算することで厳密解を求めます。nが小さい(2や3)ときに式を手で展開して埋め込むのはありかもしれません。
ガウスの消去法
蟻本などで紹介されている通称掃き出し法。O(N^3)で解を求めます。分かりやすいけど、遅いし誤差にも弱いことが知られています。ピボットルールとか気にしないといけません。ライツアウトみたいにMOD Mで解きたいときなどは誤差を無視できますが。
LU分解
O(N^3)の前処理でA=LU(or LDU)の形に分解しておくことで、bを取り替えたときにO(N^2)で答えを求めることができます。
三重対角行列1
A[i,j] = 0 (|i-j|>=2)であるような行列に対してはO(N)で解を計算できます。前進消去と後退代入を必要な部分だけするようにしているだけで本質的には掃き出し法です。状態間の依存関係が線になっているときに現れます。
三重対角行列2
三重対角行列の右上と左下にぽつんと非ゼロ要素がくっついた行列です。A[i,j] = 0 (|i-j|>=2 and (i,j)!=(0,n-1), (n-1,0))であるような行列に対しても、やはりO(N)で解を求めることができます。やはりアイデアは前進消去と後退代入で、部分問題として右上と左下が0の三重対角行列のを解きます。状態間の依存関係が環になっているときに現れます。結構出現頻度は高いんじゃないでしょうか。
反復法
厳密解を諦めて(丸め誤差があるので「厳密解」というのは殆ど意味がない概念です)、近似解を作りにいく方法です。適当な初期値からスタートしてよい解が得られるまで何度も改良を繰り返す方法を反復法といいます。真の解に「確実に、速く」収束するのがよい反復法です。ヤコビ法やSOR法などがよく知られています。
共役勾配法
数値計算に詳しい方に「最強の教えてください」と尋ねたら共役勾配法(CG法)を紹介してもらいました。
この方法は反復法のように係数行列に対して行列ベクトル積をとる操作を何度か繰り返して解を改良していきます。反復法と違うのは、n回改良を繰り返すと丸め誤差を無視すれば厳密解が得られることです。ただし実際にn回繰り返すことは少なく、それよりも速く真の解に収束させる(近づける)ことを狙います。
しかしこの方法は丸め誤差の影響を受けやすく、ランダムに作った係数行列などで試してみると解はほとんど収束しません。しかし、実際に問題を解く際に現れるような係数行列に対してはうまくいくことが多いようです。
共役勾配法はさまざまなバリエーションがありますが、BiCGSTABというのが非対称行列でも使えて実装もシンプルなので手頃だと思います。
係数行列が疎な場合(実用上疎なことが多い)には行列の持ち方を変えることによって空間計算量と行列ベクトル積の時間計算量をO(N)に落とすことができます。
実際にどれくらいの反復回数で解が十分近似できるかは係数行列によってばらつきが大きいですが、nの平方根程度繰り返せば十分なことが多いように感じました。テストが十分ではないので嘘かもしれません。n=10^4くらいだったら間に合うことが多い様に感じました(最適化オプションでだいぶ速度が変わりました)。
とはいえ、収束しないことがあるというのは極めて大きいデメリットなので使うのは最後の手段にしようと思いました。
追記:色々試してみたら収束しないケースが多すぎてホント最後の手段にしかなっていない。BiCGSTAB以外にLSQRというのがあって、多少数値的に安定になるそうだけどそれを組んでも収束しないことは多々あった(収束も遅くなった)。
テストはSRM504.5 TheTicketsDivOneにて。係数行列が特殊なのでちょっと不十分。
三重対角行列
- n=1のときやn=2のときの動作は結構怪しい。
- D[i] = A[i,i], L[i] = A[i,i-1], R[i] = A[i,i+1]とする。
const int MAX_EQ = 100000; //元の方程式が解を持たない場合の動作は不明 //破壊的、xに解が入る以外は何の保証もない //n=1の場合も結構怪しい void solve_tridiagonal(int n, double d[], double l[], double r[], double b[], double x[]){ //右上と左下が0の場合 if(l[0] == 0.0 && r[n-1] == 0.0){ for(int i=1; i<n; ++i){ d[i] -= r[i-1] * (l[i] / d[i-1]); b[i] -= b[i-1] * (l[i] / d[i-1]); } for(int i=n-2; i>=0; --i){ b[i] -= b[i+1] * (r[i] / d[i+1]); } for(int i=0; i<n; ++i){ x[i] = b[i] / d[i]; } return ; } static double d2[MAX_EQ], l2[MAX_EQ], r2[MAX_EQ], b2[MAX_EQ], a[MAX_EQ]; for(int i=0; i<n-1; ++i){ d2[i] = d[i]; l2[i] = (i ? r[i-1] : 0.0); r2[i] = (i < n-2 ? l[i+1] : 0.0); } fill(b2, b2+n-1, 0.0); b2[0] = r[n-1]; b2[n-2] = l[n-1]; solve_tridiagonal(n-1, d2, l2, r2, b2, a); for(int i=0; i<n-1; ++i){ b[n-1] -= a[i] * b[i]; } d[n-1] -= a[n-2] * r[n-2] + a[0] * l[0]; b[0] -= b[n-1] * l[0] / d[n-1]; b[n-2] -= b[n-1] * r[n-2] / d[n-1]; l[0] = r[n-2] = 0.0; solve_tridiagonal(n-1, d, l, r, b, x); x[n-1] = b[n-1] / d[n-1]; }
BiCGSTAB
- nが小さいときは非常に怪しい。
- ゼロ割でnanが返ってくることもしばしば。
- ノルムのとり方は適当に取り替えてもよい。
- r0やxの初期値も適当に変えてよい。
- 係数行列がある程度良い性質を持っていないと全然収束しない。使うのは最後の手段で。
- 蟻本の例題「ランダムウォーク」でW=100, H=100が解けたりするけどW=1, H=2が解けなかったりする。
- 小さい場合は掃き出し法やクラメルの公式で解くのもあり。
const int MAX_EQ = 10000; const double TOL = 1e-15; struct SparseMatrix{ int length; vector<double> vals[MAX_EQ]; vector<int> ids[MAX_EQ]; SparseMatrix(){ for(int i=0; i<length; ++i){ vals[i].clear(); ids[i].clear(); } } SparseMatrix(int n):length(n){ SparseMatrix(); } //A[i,j] += v void add(int i, int j, double v){ ids[i].push_back(j); vals[i].push_back(v); } }; void mat_vec_prod(const SparseMatrix& A, const double x[], double dst[]){ const int n = A.length; for(int i=0; i<n; ++i){ dst[i] = 0.0; for(int j=0; j<(int)A.ids[i].size(); ++j){ dst[i] += A.vals[i][j] * x[A.ids[i][j]]; } } } double iprod(int n, const double a[], const double b[]){ double ret = 0.0; for(int i=0; i<n; ++i){ ret += a[i] * b[i]; } return ret; } double norm(int n, const double a[]){ double ret = 0.0; for(int i=0; i<n; ++i){ ret += abs(a[i]); } return ret; } void BiCGSTAB(const SparseMatrix& A, const double b[], double x[]){ const int n = A.length; static double r0[MAX_EQ], r[MAX_EQ], p[MAX_EQ], t[MAX_EQ], nr[MAX_EQ], Ax[MAX_EQ], Ap[MAX_EQ], At[MAX_EQ]; fill(x, x+n, 0.0); mat_vec_prod(A, x, Ax); for(int i=0; i<n; ++i){ r0[i] = p[i] = r[i] = b[i] - Ax[i]; } double normb = norm(n, b); for(int _=0; _<n; ++_){ //for(;;){ mat_vec_prod(A, p, Ap); double alpha = iprod(n, r0, r) / iprod(n, r0, Ap); for(int i=0; i<n; ++i){ t[i] = r[i] - alpha * Ap[i]; } mat_vec_prod(A, t, At); double omega = iprod(n, t, At) / iprod(n, At, At); for(int i=0; i<n; ++i){ x[i] += alpha * p[i] + omega * t[i]; } for(int i=0; i<n; ++i){ nr[i] = t[i] - omega * At[i]; } double beta = alpha / omega * iprod(n, r0, nr) / iprod(n, r0, r); for(int i=0; i<n; ++i){ r[i] = nr[i]; } for(int i=0; i<n; ++i){ p[i] = r[i] + beta * (p[i] - omega * Ap[i]); } if(norm(n, r) < TOL * normb){ break; } } }