2042 (AOJ 1241): Lagrange's Four-Square Theorem

keyword

組合せ C++

概要

n(<2^15)を4つ以下の正の平方数の和で表す方法は何通りあるか求める問題。クエリの数は2^8以下。
硬貨の支払い方の問題だと思うとDPが思い浮かぶ。素朴にやるとdp[n][平方数の個数][一番大きな硬貨]で、計算量は(2^15*4*2^8)*2^8となって解けない。DPではなく、各クエリに対して丁寧に場合分けして数え上げる。4枚で4つとも異なる数を求めるときは多少工夫が必要。1個のクエリに対して2^8 * 2^8 * log(2^8)程度の計算量で答えを出すことができる。

bool sq1[1<<15];
bool sq2[1<<15];
bool sq3[1<<15];
bool sq4[1<<15];
vector< vector<int> > ns;

inline int solve(int n){
    int ret = 0;

//1
    //1
    if(sq1[n]) ret++;
//2
    //2
    if(sq2[n]) ret++;

    //1-1
    for(int i=1; 2*i*i<n; i++){
        int i2 = i*i;
        if(sq1[n-i2]) ret++;
    }
//3
    //3
    if(sq3[n]) ret++;

    //2-1
    for(int i=1; i*i<n; i++){
        int i2 = i*i;
        if(i2*3 != n && sq2[n-i2]) ret++;
    }

    //1-1-1
    for(int i=1; i*i<n; i++){
        int i2 = i*i;
        for(int j=i+1; i2+j*j<n; j++){
            int j2 = j*j;
            int k2 = n - i2 - j2;
            if(j2<k2 && sq1[k2] ) ret++;
        }
    }
//4
    //4
    if(sq4[n]) ret++;

    //1-3
    for(int i=1; i*i < n; i++){
        int i2 = i*i;
        if(i2*4 != n && sq3[n-i2]) ret++;
    }

    //2-2
    for(int i=1; i*i < n; i++){
        int i2 = 2*i*i;
        for(int j=i+1; i2 + j*j <= n; j++){
            if(i2 + 2*j*j == n) ret++;
        }
    }

    //1-1-2
    for(int i=1; i*i < n; i++){
        int i2 = i*i;
        for(int j=i+1; i2 + j*j < n; j++){
            int j2 = j*j, k = n-i2-j2;
            if( sq2[k] && (k>>1)!=i2 && (k>>1)!=j2 ) ret++;
        }
    }

    //1-1-1-1
    for(int i=1; i*i < n; i++){
        int i2 = i*i;
        for(int j=i+1; i2 + j*j < n; j++){
            int j2 = j*j;
            ret += distance(lower_bound(ALL(ns[j]),n-j2-i2),
                            upper_bound(ALL(ns[j]),n-j2-i2));
        }
    }

    return ret;
}

int main(){
    int i, j, k, n;

    REP(i,1<<15) sq1[i] = false;
    REP(i,1<<15) sq2[i] = false;
    REP(i,1<<15) sq3[i] = false;
    REP(i,1<<15) sq4[i] = false;

    for(i=1;i*i<(1<<15);i++){
        sq1[i*i] = true;
    }
    for(i=1;2*i*i<(1<<15);i++){
        sq2[2*i*i] = true;
    }
    for(i=1;3*i*i<(1<<15);i++){
        sq3[3*i*i] = true;
    }
    for(i=1;4*i*i<(1<<15);i++){
        sq4[4*i*i] = true;
    }

    ns.PB(vector<int>());
    for(int i=1; i*i<(1<<15); i++){
        ns.PB(vector<int>());
        for(int j=i+1; j*j<(1<<15); j++){
            int j2 = j*j;
            for(int k=j+1; j2+k*k<(1<<15); k++){
                ns[i].PB(j2+k*k);
            }
        }
        sort(ALL(ns[i]));
    }

    while(scanf("%d",&n) && n){
        printf("%d\n",solve(n));
    }

    return 0;
}