AOJ-1171: レーザー光の反射 (Laser Beam Reflections)

keyword

平面幾何 C++

問題概要

レーザー光を上手く反射させて最短距離で目標物に当てる問題。

解法

鏡の枚数が少なく、反射の回数に制限があるので全ての反射の場合を総当たりする。反射をまともにシミュレーションするのは大変なので、鏡にぶつかる度に鏡を軸として世界を反転させる。そうするとレーザー光と(移動した後の)目標物を直線で結んで順番に鏡にぶつかっているかどうか調べるだけでよい。反射を線対称で処理するのは定石だと言っていいと思う。

感想

昔書いたコードを見直すのは精神的に来るものがある。ソースコードは一部省略している。

#define REP(i,n) for(i=0; i < (n); i++)
#define REPONE(i, n) for(i=1; i <= (n); i++)
#define LOOP(n) for(int loopCount=1; loopCount <= (n); loopCount++)
#define ITER(c) __typeof((c).begin())
#define EACH(c,it) for(ITER(c) it =(c).begin(); it!=(c).end(); it++)
#define SZ(c) ((int) (c).size())
#define ALL(c) c.begin(), c.end()
#define SUM(c) accumulate(ALL(c), 0)
#define EXIST(c,v) (find(ALL(c), (v)) != (c).end())
#define PB push_back
#define MP make_pair
#define INF (1e99)

using namespace std;
static const double PI = 3.141592653589793238462643383279;
typedef long long ll;

#define P complex<T> 
#define EPS (1e-12)

//スカラー値の符号を返す。+:1, -:-1, 0:0。
template<typename T>
int sig(T a){
    return a==0?0:a>0?1:-1;
}

//ベクトルp1, p2の内積を返す
template<typename T>
T iprod(P p1, P p2){
    return (p1.real()*p2.real()) + (p1.imag()*p2.imag());
}

//ベクトルp1, p2の外積(スカラー値,z座標)を返す
template<typename T>
T oprod(P p1, P p2){
    return (p1.real()*p2.imag()) - (p1.imag()*p2.real());
}


//線分(p1, p2)と点qの距離を返す
template<typename T>
T disSP(P p1, P p2, P q){

}

//直線(p1, p2)と点qの距離を返す
template<typename T>
T disLP(P p1, P p2, P q){

}

//線分(p1, p2)と線分(q1, q2)の交点が存在するときにtrueを返す
template<typename T>
bool crsSS(P p1, P p2, P q1, P q2){

}


//点qから直線(p1, p2)に下ろした垂線の足を返す
template<typename T>
P proj(P p1, P p2, P q){

}

//直線(p1,p2)と直線(q1, q2)の交点を返す
template<typename T>
P isLL(P p1, P p2, P q1, P q2){

}


//直線(p1,p2)に関して点qと対象な点を返す
template<typename T>
P mirrorLP(P p1, P p2, P q){
    return (T)2*proj(p1,p2,q) - q;
}

//直線(p1,p2)に関して線分(q1, q2)と対象な線分を返す
template<typename T>
vector<P> mirrorLS(P p1, P p2, P q1, P q2){
    P r1=mirrorLP(p1, p2, q1);
    P r2=mirrorLP(p1, p2, q2);
    vector<P> ans;
    ans.PB(r1);
    ans.PB(r2);
    return ans;
}

typedef complex<long double> cd;

int main(){
    int i, j, n, x, y, z, w, m;
    while(cin>>n){
        if(n==0) break;

        //input read
        vector< vector<cd> > mirrors;
        vector<cd> hoge;
        cd start, goal;
        LOOP(n){
            cin >> x >> y >> z >> w;
            hoge.clear();
            hoge.PB(cd(x,y));
            hoge.PB(cd(z,w));
            mirrors.PB(hoge);
        }
        cin >> x >> y;
        goal = cd(x,y);
        cin >> x >> y;
        start  = cd(x,y);

        //generate pattern
        set< vector<int> > pats;
        vector<int> pat;
        int i1, i2, i3, i4, i5;
        REP(i1,n+1)REP(i2,n+1)REP(i3,n+1)REP(i4,n+1)REP(i5,n+1){
            pat.clear();
            if(i1<n) pat.PB(i1);
            if(i2<n) pat.PB(i2);
            if(i3<n) pat.PB(i3);
            if(i4<n) pat.PB(i4);
            if(i5<n) pat.PB(i5);
            if(SZ(pat)>1){
                REP(i,SZ(pat)-1)if(pat[i]==pat[i+1]) goto esc;
            }
            pats.insert(pat);
esc:;
        }
//        EACH(pats,it){EACH(*it, itt) cout << *itt <<","; cout << endl;}

        //solve
        double ans = INF;
        EACH(pats,it){
            pat = *it;
            m = SZ(pat);
            vector< vector< vector<cd> > > ms;
            vector< vector<cd> > tmp;
            vector<cd> tmp2, tmp3;
            ms.PB(mirrors);
            cd mg=goal;
            REP(i,m){
                tmp = ms.back();
                tmp3 = tmp[pat[i]];
                REP(j,n)if(j!=pat[i]){
                    tmp2 = mirrorLS(tmp3[0], tmp3[1], tmp[j][0], tmp[j][1]);
                    tmp[j] = tmp2;
                }
                mg = mirrorLP(tmp3[0], tmp3[1], mg);
                ms.PB(tmp);
            }

            //debug
/*            REP(i,m+1){
                REP(j,n){
                    cout << ms[i][j][0] <<":"<< ms[i][j][1] << endl;
                }
                cout << endl;
            }
            cout << mg << endl;
            cout << abs(start - mg) << endl;
            cout << disSP(ms[1][0][0], ms[1][0][1], ms[1][4][0]) << endl;
            */

            if(abs(start-mg) < ans){
                cd prev=start, next;
                vector< vector<cd> > segs;
                REP(i,m){
                    next = isLL(start, mg, ms[i][pat[i]][0], ms[i][pat[i]][1]);
                    tmp2.clear();
                    tmp2.PB(prev);
                    tmp2.PB(next);
                    segs.PB(tmp2);
                    prev = next;
                }
                tmp2.clear();
                tmp2.PB(prev);
                tmp2.PB(mg);
                segs.PB(tmp2);

                REP(i,m){
                    if(iprod(segs[i][0]-segs[i][1], segs[i+1][0]-segs[i+1][1])<0) goto end;
                }

                //交差判定
                if(m==0){
                    REP(j,n)if(crsSS(mirrors[j][0], mirrors[j][1], start, goal)) goto end;
                }
                else if(m>0){
                REP(j,n){
                    if(j != pat[0] && crsSS(segs[0][0], segs[0][1], ms[0][j][0], ms[0][j][1]))
                        goto end;
                    if(j == pat[0] && !crsSS(start, mg, ms[0][j][0], ms[0][j][1]))
                        goto end;
                }
                REPONE(i,m-1){
                    REP(j,n){
                        if( (j!=pat[i-1] && j!=pat[i]) && crsSS(segs[i][0], segs[i][1], ms[i][j][0], ms[i][j][1]))
                            goto end;
                        if( (j==pat[i-1] || j==pat[i]) && !crsSS(start, mg, ms[i][j][0], ms[i][j][1]))
                            goto end;
                    }
                }
                REP(j,n){
                    if(j != pat.back() && crsSS(segs[m][0], segs[m][1], ms[m][j][0], ms[m][j][1]))
                        goto end;
                    if(j == pat.back() && !crsSS(start, mg, ms[m][j][0], ms[m][j][1]))
                        goto end;
                }
                }

                ans = abs(start-mg);
//                EACH(pat,ittt)cout<<*ittt<<",";cout<<endl;
            }
end:;
        }
        printf("%.4f\n", ans);
    }
    return 0;
}