SRM545 500pt : Spacetsk

問題概要

L*H(L,H<2000)のグリッドがあり、その格子点の中からK(<2000)点選ぶ。ただし、K個の点は全て同一直線上にあり、かつその直線は傾き0でなく、(i,0) (0<=i<=L)を通らなくてはならない。そのような選び方が何通りあるか求める問題。

考えたこととか

  • 2000だしO(N^2)かなあ、setとかの重いlogはくっつけないようにしないと。
  • dpで考えてみるとdp[k段目][x座標][これまでに使った個数]でこれだけでO(N^3)なので無理げ。
  • dpでなく、線を決めて考えると…?
  • 傾きの候補はO(N^2)個なので全部列挙できる。原点を通ることにすれば、gcdで中にある点の個数も求められる。きっとこれだな…。
  • まずK=1の場合は例外処理で答を返せばよい。これはチャレンジポイントにはならないか…。
  • 傾きは正のものだけ考えて後で2倍しよう。x軸に垂直な線は例外処理。
  • 後は傾き全列挙して、その傾きごとにいくつあるかamortizeでO(1)に求めれば良い。
  • amortizeとか言ったけど差分情報とか使いにくそうだし単純にO(1)だろうなあ。
  • 右端で切れるのと上端で切れるのあるけど分けて考えるのちょっと面倒だなあ。
  • mod dxで考えたりすると上手く行くんだろうけど植木算的な+1とかが出てきて苦手だ…。
  • しかも二項係数の和を効率的に計算する必要があるなあ。
  • あっ、これは単に前処理で累積和計算しておくだけか。
  • とはいえ方針はあってそう。よし書くか。
  • サンプル通らない。しかも足りなかったり余ったりしていて原因の特定が難しい。
  • どこだ?どこが間違ってるんだ?時間が減っていてヤバい。さすがにこのレベルの問題は通せないとひどい。
  • 二項係数の計算間違ってるじゃねーか。というかライブラリにあるのに何故スクラッチで書いたんだ自分は。
  • サンプル通った。一応提出前に最大ケーステストしよう。
  • メモリ確保で死んでる。どこだ?
  • vectorにpairを2000^2個突っ込むのが危なそう。時間ないから気持ち程度の枝刈り入れて、テストする時間もないからそのまま出しちゃえ。
    • 結局これが原因で落ちた。後で固定長配列に書き直したらちゃんと通った(多分reserveするだけでもよかった)。実は互いに素かどうか調べるだけだから必要なかった処理ではあるけど。
    • もったいない。

acceptされたコード

#include <cstdio>
#include <algorithm>
using namespace std;

typedef long long int64;

const int MAX_N = 2002;
const int64 MOD = (int64)(1e9 + 7);

int64 comb[MAX_N + 1][MAX_N + 1];
int64 accum[MAX_N + 1];
pair<int,int> lines[MAX_N * MAX_N];
int LL;

struct Spacetsk {

	int countsets(int L, int H, int K) {
		if(K == 1){
			return (L + 1) * (H + 1);
		}

		int64 ans = 0;
		for(int i=0; i<=MAX_N; ++i){
			comb[i][0] = comb[i][i] = 1;
			for(int j=1; j<i; ++j){
				comb[i][j] = (comb[i-1][j-1] + comb[i-1][j]) % MOD;
			}
		}
		for(int i=0; i<=MAX_N; ++i){
			if(i < K){
				accum[i+1] = accum[i];
			}
			else{
				accum[i+1] = (accum[i] + comb[i][K]) % MOD;
			}
		}

		for(int i=1; i<=H; ++i){
			for(int j=1; j<=L; ++j){
				int d = __gcd(i, j);
				lines[LL++] =  make_pair(j/d, i/d) ;
			}
		}

		sort(lines, lines + LL);
		LL = unique(lines, lines + LL) - lines;

		for(int k=0; k<LL; ++k){
			int dx = lines[k].first, dy = lines[k].second;
			int d = min(H / dy, L / dx);
			if(d + 1 < K){
				continue;
			}
			int64 res = (int64)(L - dx*d + 1) * comb[d+1][K] % MOD;
			ans = (ans + res) % MOD;
			ans = (ans + dx * accum[d + 1]) % MOD;
		}

		ans = ans * 2 % MOD;
		ans = (ans + (L + 1) * comb[H + 1][K]) % MOD;
		return (int)ans;
	}

};