SRM 534 500pt : EllysNumbers

問題概要

整数n(<10^18)と重複を持たない数列A(|A|=:N<500, A[i]<10^9)が与えられる。互いに素な部分集合を選んで積がnに等しくなるのは何通りか求める問題。

解法

nを素因数分解してp_i^e_iが出てきたとき、これと同じ素数のべきで構成された数だけを使う。なので、まずそれぞれの要素を素因数分解してその約数でnを割り、割りきれなかったら即座に0を返す。割り切れないときはどのp_i^e_iが満たされていないかを状態にしてビットDPする。p_i^e_iの個数はたかだか15なので間に合う。
前処理は面倒だけど複雑ではなく、ビットDPパートは簡単。ビットDPの部分は、1を特別扱いすれば必ずスーパーセットに配る形になってトポロジカル順序が自明に出てきてよい(空間計算量も自然に抑えられる)。ただし自分のコードは1を特別扱いしていない。

acceptされたコード

計算量O(2^15*500)くらい

#include <sstream>
#include <vector>
#include <string>
#include <numeric>
#include <map>
#include <algorithm>
#include <cstring>
using namespace std;

typedef long long int64;

const int MAX_N = 500;
const int MAX_M = 16;

int64 cur[1<<MAX_M], nxt[1<<MAX_M];
int bits[MAX_N];
bool valid[MAX_N];

struct EllysNumbers {

	int64 getSubsets(int64 n, vector <string> special) {
		vector<int64> xs;
		stringstream ss(accumulate(special.begin(), special.end(), string()));
		for(int64 x; ss>>x; ){
			xs.push_back(x);
		}

		const int N = xs.size();
		vector< map<int64,int> > divs(N);
		vector<int64> as;

		for(int i=0; i<N; i++){
			int64 m = xs[i];
			for(int j=2; j*j <= m; j++)if(m%j==0){
				for(;m%j==0;){
					as.push_back(j);
					m /= j;
					divs[i][j]++;
				}
			}
			if(m != 1){
				as.push_back(m);
				divs[i][m]++;
			}
		}

		sort(as.begin(), as.end());
		as.erase(unique(as.begin(), as.end()), as.end());

		map<int64, int> div_n;

		int64 nn = n;
		for(int i=0; i<(int)as.size(); i++){
			int64 x = as[i];
			for(;nn%x == 0;){
				nn /= x;
				div_n[x]++;
			}
		}
		if(nn != 1){
			return 0;
		}

		const int M = div_n.size();

		map<int64, int> appear;
		int p = 0;
		for(map<int64, int>::iterator itr=div_n.begin(); itr!=div_n.end(); itr++){
			appear[itr->first] = p++;
		}

		for(int i=0; i<N; i++){
			bool ok = true;
			for(map<int64,int>::iterator itr=divs[i].begin(); itr!=divs[i].end(); itr++){
				if(div_n.find(itr->first)==div_n.end() || div_n[itr->first] != itr->second){
					ok = false;
				}
				else {
					bits[i] |= (1<<appear[itr->first]);
				}
			}

			valid[i] = ok;
		}

		cur[0] = 1;
		for(int i=0; i<N; i++)if(valid[i]){
			memset(nxt, 0, sizeof(nxt));

			for(int bit=0; bit<(1<<M); bit++){
				nxt[bit] += cur[bit];
				if( (bit ^ bits[i]) == (bit | bits[i]) ){
					nxt[bit | bits[i]] += cur[bit];
				}
			}
			memcpy(cur, nxt, sizeof(cur));
		}

		return cur[(1<<M)-1];
	}

};