POJ-2155 : Matrix

問題概要

N*N(N<1000)の行列があり、初期値は全て0である。クエリがQ(<5*10^4)個与えられる。一つは矩形区間内の値を全てフリップする。もう一つは指定された要素を答える。

解法

2次元のRange Sum Queryを利用すれば解ける。RSQはsegtreeでやってもよいが、練習がてらBITを用いて解いてみた。蟻本で紹介されてるのはBIT2本で1次元のRSQを処理しているが、今回は2次元で4本必要になる。とはいえ基本的な考え方は変わらず、sum(y,x)=b_xy(y,x)*yx + b_y(y,x)*y + b_x(y,x)*x + b(y,x)となるようにb_xy, b_y, b_x, bにいわゆるいもす法っぽく値を加算していく。どれだけの値を加えるかは頑張って計算する。

acceptされたコード

#include <cstdio>
#include <cstring>
using namespace std;

const int MAX_N = 1000;

typedef long long int64;
typedef int64 bit_t;

struct BinaryIndexedTree{

	static const int MAX_BIT = ::MAX_N;
	bit_t data[MAX_BIT+1][MAX_BIT+1];
	int SIZE;

	//doubleのときはmemsetはダメ
	void init(int size){
		memset(data, 0, sizeof(data));
		SIZE = size;
	}

	bit_t sum(int y, int x){
		bit_t ret = 0;
		for(int ty=y;ty;ty-=ty&-ty){
			for(int tx=x;tx;tx-=tx&-tx){
				ret += data[ty][tx];
			}
		}
		return ret;
	}

	bit_t sum(int y1, int y2, int x1, int x2){
		return sum(y2,x2) + sum(y1,x1) - (sum(y2,x1) + sum(y1,x2));
	}

	void add(int y, int x, bit_t v){
		for(int ty=y+1;ty<=SIZE;ty+=ty&-ty){
			for(int tx=x+1;tx<=SIZE;tx+=tx&-tx){
				data[ty][tx]+=v;
			}
		}
	}
};

BinaryIndexedTree b_xy, b_x, b_y, b;


//[y1, y2) * [x1, x2)
void add(int y1, int y2, int x1, int x2, bit_t v){
	b_xy.add(y1,x1,v);
	b_xy.add(y1,x2,-v);
	b_xy.add(y2,x1,-v);
	b_xy.add(y2,x2,v);

	b_x.add(y1,x1,-y1*v);
	b_x.add(y1,x2,y1*v);
	b_x.add(y2,x1,y2*v);
	b_x.add(y2,x2,-y2*v);

	b_y.add(y1,x1,-x1*v);
	b_y.add(y1,x2,x2*v);
	b_y.add(y2,x1,x1*v);
	b_y.add(y2,x2,-x2*v);

	b.add(y1,x1,y1*x1*v);
	b.add(y1,x2,-x2*y1*v);
	b.add(y2,x1,-y2*x1*v);
	b.add(y2,x2,y2*x2*v);
}

int64 sum(int y, int x){
	return b_xy.sum(y,x)*y*x + b_x.sum(y,x)*x + b_y.sum(y,x)*y + b.sum(y,x);
}

int64 sum(int y1, int y2, int x1, int x2){
	return sum(y1,x1) + sum(y2,x2) - (sum(y1,x2) + sum(y2,x1));
}

void proc(){
	int n, q;
	scanf("%d%d", &n, &q);
	b_xy.init(n);
	b_x.init(n);
	b_y.init(n);
	b.init(n);

	for(int i=0; i<q; ++i){
		char op;
		scanf(" %c", &op);
		if(op == 'Q'){
			int y, x;
			scanf("%d%d", &x, &y);
			printf("%d\n", (int)(sum(y-1,y,x-1,x)&1));
		}
		else{
			int y1, y2, x1, x2;
			scanf("%d%d%d%d", &x1, &y1, &x2, &y2);
			add(y1-1,y2,x1-1,x2,1);
		}
	}
}

int main(){
	int T;
	scanf("%d", &T);
	for(int _=0; _<T; ++_){
		proc();
		if(_ < T - 1){
			puts("");
		}
	}
	return 0;
}