BUET Inter-University Programming Contest - 2011 F, UVa-12429 : Finding Magic Triplets

問題概要

a + b^2 = c^3 (mod K(<10^5))となる1≦a≦b≦c≦N(<10^5)なる組の個数を求める問題。

解法

bをNから1に減らしながらやる。BITのb^3 (K)のindexに1を加えて、[b^2+1, b^2+b]までのBITの区間和を足し上げる。このとき、K

acceptされたコード

計算量O(N*log K)。

#include <cstdio>
#include <cstring>
using namespace std;

typedef long long int64;

const int MAX_K = (int)(1e5);

int N, K;

struct BinaryIndexedTree{

	typedef int64 bit_t;

	static const int MAX_BIT = 3*MAX_K + 1;
	bit_t data[MAX_BIT+1];
	int SIZE;

	void init(int size){
		memset(data, 0, sizeof(data));
		SIZE = size;
	}

	bit_t sum(int n){
		bit_t ret = 0;
		for(;n;n-=n&-n){
			ret += data[n];
		}
		return ret;
	}

	bit_t sum(int from, int to){
		return sum(to)-sum(from);
	}

	void add(int n, bit_t x){
		for(n++;n<=SIZE;n+=n&-n){
			data[n]+=x;
		}
	}
};

BinaryIndexedTree bitree;


void init(){
	scanf("%d%d", &N, &K);
}

int64 solve(){
	bitree.init(2*K+1);

	int64 ans = 0;
	for(int64 i=N; i>=1; i--){
		int64 b = i * i % K, c = i * i * i % K;
		bitree.add(c, 1);
		bitree.add(c+K, 1);
		bitree.add(c+2*K, 1);
		int64 len = i;
		if(len >= K){
			ans += (len / K) * bitree.sum(K);
			len %= K;
		}
		if(len > 0){
			ans += bitree.sum(b + 1, b + len + 1);
		}
	}

	return ans;
}

int main(){
	int T;
	scanf("%d", &T);
	for(int i=0; i<T; i++){
		init();
		printf("Case %d: %lld\n", i+1, solve());
	}

	return 0;
}