SRM 546 500pt : FavouriteDigits

問題概要

整数N(<10^15)が与えられる。N以上の数で、10進表記したときdigit1がcount1回以上でて、digit2がcount2回以上でてくるようなものの最小値を求める問題。count1+count2<=15。

考えたこととか

  • ???桁数少ないし、3^n*nが枝刈り込みで間に合うような気がする。え、250よりはるかに簡単では。
  • 考え直して見るけど、どうやってもこれで行けそうにしか見えない。
  • もう少し詳しく。3^n*nはn=15で3*10^8くらいだから危ないけど、実際はcount1やcount2の数が多くても少なくても枝は刈れる。直感では少なくとも10倍くらいは余裕で刈れるはず。なので計算量の面では問題ない。
  • で、各桁にdigit1を埋めるかdigit2を埋めるかそれ以外を埋めるかを決めたら、後は辞書順最小の定石にしたがうだけ。それ以外のところを最初全部9で埋めておき、先頭の桁からなるべく小さい方に更新していく。このとき調べるのは0、もとの数字、もとの数字+1、9くらい調べておけば十分。
  • あからさまに繰り上げがコーナーケースっぽいし実際サンプルに入っている。でもこれは前処理でカバーできる。
  • 例えばN=3456とかのとき、10000より大きい解を最初に作っておく、これは貪欲にdigit1とdigit2と0を詰めるだけ。
  • よし、実装始めよう。
  • 前処理で時間かかる。後は再帰で書くだけ。
  • 合わない…。何故だ…。
  • あっ、count1 + count2が元の桁数より大きくなることがあるのか。修正…まだ合わない。
  • タイムアップ。0完でかなり悲しかった。
  • 後で見たら、for文のループの向きが逆だった。毎回補完に頼っていたのでつい手癖で昇順を書いてしまっていた。

別の解法

DPで下から決めていく。これだと計算量が多項式時間に抑えられる。そしてこっちの方が場合分けは減らせてシンプルに書ける…。が、これは別に見えなくても仕方ないかな…。

acceptされたコード

探索版、汚い…。これではミスって当然。

#include <cstdio>
#include <cassert>
#include <algorithm>
#include <cstring>
using namespace std;

typedef long long int64;

const int64 INF = 1LL<<62;

int L, d1, d2, C1, C2;
int64 ANS;
int64 U, N;
int64 tens[20];
int ts[20];
int ns[20];

struct FavouriteDigits {

	int64 findNext(int64 N_, int digit1, int count1, int digit2, int count2) {
		if(digit1 < digit2){
			return findNext(N_, digit2, count2, digit1, count1);
		}
		d1 = digit1;
		d2 = digit2;
		C1 = count1;
		C2 = count2;
		N = N_;
		tens[0] = 1;
		for(int i=0; i<18; ++i){
			tens[i+1] = tens[i] * 10;
		}
		int64 ans = INF;
		ANS = INF;
		char buf[100];
		sprintf(buf, "%lld", N);
		L = strlen(buf);
		for(int i=0; i<L; ++i){
			ns[i] = buf[L-1-i]&15;
		}
		U = 1;
		for(int _=0; _<L; ++_){
			U *= 10;
		}
		if(C1 + C2 > L){
			N = U;
			return findNext(N, digit1, count1, digit2, count2);
		}


		{
			char buf2[100];
			sprintf(buf2, "%lld", U);
			int l = strlen(buf2);
			int c1 = count1, c2 = count2;
			if(digit1 == 1){
				--c1;
			}
			else if(digit2 == 1){
				--c2;
			}
			for(int i=l-1; i>=1; --i){
				if(c1 > 0){
					--c1;
					buf2[i] = digit1 + '0';
				}
				else if(c2 > 0){
					--c2;
					buf2[i] = digit2 + '0';
				}
				else{
					buf2[i] = '0';
				}
			}
			sscanf(buf2, "%lld", &ans);
		}

		rec(0, 0, 0);

		return min(ANS, ans);
	}

	void rec(int depth, int c1, int c2){
		if(depth == L){
			if(!(c1 >= C1 && c2 >= C2)){
				return ;
			}

			int64 t = 0;
			for(int i=0; i<L; ++i){
				if(ts[i] == 1){
					t += tens[i] * d1;
				}
				else if(ts[i] == 2){
					t += tens[i] * d2;
				}
				else{
					t += tens[i] * 9;
				}
			}

			if(t >= N){
				ANS = min(ANS, t);
				for(int i=L-1; i>=0; --i)if(!ts[i]){
					t -= tens[i] * 9;
					if(t >= N){
						continue;
					}

					t += tens[i] * ns[i];
					if(t >= N){
						continue;
					}

					if(ns[i] != 9){
						t += tens[i] * 1;
					}
				}
				ANS = min(ANS, t);
			}

			return ;
		}

		if((C1 - c1 + C2 - c2) < L - depth + 2){
			ts[depth] = 0;
			rec(depth + 1, c1, c2);
		}
		if(c1 < C1){
			ts[depth] = 1;
			rec(depth + 1, c1+1, c2);
		}
		if(c2 < C2){
			ts[depth] = 2;
			rec(depth + 1, c1, c2+1);
		}
	}

};