POJ-2155 : Matrix
問題概要
N*N(N<1000)の行列があり、初期値は全て0である。クエリがQ(<5*10^4)個与えられる。一つは矩形区間内の値を全てフリップする。もう一つは指定された要素を答える。
解法
2次元のRange Sum Queryを利用すれば解ける。RSQはsegtreeでやってもよいが、練習がてらBITを用いて解いてみた。蟻本で紹介されてるのはBIT2本で1次元のRSQを処理しているが、今回は2次元で4本必要になる。とはいえ基本的な考え方は変わらず、sum(y,x)=b_xy(y,x)*yx + b_y(y,x)*y + b_x(y,x)*x + b(y,x)となるようにb_xy, b_y, b_x, bにいわゆるいもす法っぽく値を加算していく。どれだけの値を加えるかは頑張って計算する。
acceptされたコード
#include <cstdio> #include <cstring> using namespace std; const int MAX_N = 1000; typedef long long int64; typedef int64 bit_t; struct BinaryIndexedTree{ static const int MAX_BIT = ::MAX_N; bit_t data[MAX_BIT+1][MAX_BIT+1]; int SIZE; //doubleのときはmemsetはダメ void init(int size){ memset(data, 0, sizeof(data)); SIZE = size; } bit_t sum(int y, int x){ bit_t ret = 0; for(int ty=y;ty;ty-=ty&-ty){ for(int tx=x;tx;tx-=tx&-tx){ ret += data[ty][tx]; } } return ret; } bit_t sum(int y1, int y2, int x1, int x2){ return sum(y2,x2) + sum(y1,x1) - (sum(y2,x1) + sum(y1,x2)); } void add(int y, int x, bit_t v){ for(int ty=y+1;ty<=SIZE;ty+=ty&-ty){ for(int tx=x+1;tx<=SIZE;tx+=tx&-tx){ data[ty][tx]+=v; } } } }; BinaryIndexedTree b_xy, b_x, b_y, b; //[y1, y2) * [x1, x2) void add(int y1, int y2, int x1, int x2, bit_t v){ b_xy.add(y1,x1,v); b_xy.add(y1,x2,-v); b_xy.add(y2,x1,-v); b_xy.add(y2,x2,v); b_x.add(y1,x1,-y1*v); b_x.add(y1,x2,y1*v); b_x.add(y2,x1,y2*v); b_x.add(y2,x2,-y2*v); b_y.add(y1,x1,-x1*v); b_y.add(y1,x2,x2*v); b_y.add(y2,x1,x1*v); b_y.add(y2,x2,-x2*v); b.add(y1,x1,y1*x1*v); b.add(y1,x2,-x2*y1*v); b.add(y2,x1,-y2*x1*v); b.add(y2,x2,y2*x2*v); } int64 sum(int y, int x){ return b_xy.sum(y,x)*y*x + b_x.sum(y,x)*x + b_y.sum(y,x)*y + b.sum(y,x); } int64 sum(int y1, int y2, int x1, int x2){ return sum(y1,x1) + sum(y2,x2) - (sum(y1,x2) + sum(y2,x1)); } void proc(){ int n, q; scanf("%d%d", &n, &q); b_xy.init(n); b_x.init(n); b_y.init(n); b.init(n); for(int i=0; i<q; ++i){ char op; scanf(" %c", &op); if(op == 'Q'){ int y, x; scanf("%d%d", &x, &y); printf("%d\n", (int)(sum(y-1,y,x-1,x)&1)); } else{ int y1, y2, x1, x2; scanf("%d%d%d%d", &x1, &y1, &x2, &y2); add(y1-1,y2,x1-1,x2,1); } } } int main(){ int T; scanf("%d", &T); for(int _=0; _<T; ++_){ proc(); if(_ < T - 1){ puts(""); } } return 0; }