SRM 502 Div2 Hard TheCowDivTwo

今まではDiv2 EasyとMediumしか解いていなかったけど,今後はHardも解いていく!

問題

牧場に0..N-1の数字がついた牛がN頭いて,そのうちK頭が牧場から逃げていった.逃げていった牛の数字を足した合計がNで割り切れることがわかっている時,逃げ出した牛の集合の数を答えよ

解答

ものすごい数になることが予想できたので,ぼんやりDPを使うんだろうなと思っていたけど,どう実装するか詰めることが出来なかった.
今回はTopCoderの解説がとてもわかり易くなったので,大体の流れを記す.

Login - TopCoder Wiki

問題設定を見て,まず思うのはDPを使うことだ,なぜならこの問題は数え上げ問題だし,制約がとても大きい.
もし,DPを使おうと思うなら,まずは繰り返し(or 再帰?)が必要だ.まずは全体像を考えてみよう,0..N-1からなるリストを持っている,リストから合計がNで割り切れる数をK個選ぶ.今後,作る関数はあるリストと何個取り出すかを受け取るとしよう.
繰り返しを始める時,なんらかの数のリストを持っている,ここでList[0]を取り出すK個に含めるかどうかの2つの場合分けを考える.List[0]を含める場合List[1:]からK-1個取り出し,List[0]を含めない場合List[1:]からK個取り出す.前者の場合は注意が必要だ,なぜならList[0]を使うのでList[1:]からK-1個を取り出した合計S+List[0]がNで割り切れるように選ばなくてはいけなくなるからだ.次の式からSは(N-List[0]) mod Nになるようにしなくてはいけない.

S+List[0] = 0 mod N
S = -List[0] mod N
S = N - List[0] mod N

K-1個取り出し,合計が(N-List[0]) mod Nとなる取り出し方を考える必用がある.このような繰り返しのためには,取り出しの合計 mod Nとなるという制約が与えられる必要がある,これは最初のmod Nが0であるという制約条件と同じである.与えられた値Sは次の繰り返しに必要な新たな制約条件S'を次のように導出する.

S' = (S - List[0]) mod N

以上のことを実現するため,関数をcountSets(List, K, S)と設定しよう.countSets(List, K, S)はListの中から合計数mod NがSとなるようにK個選んだ時の集合の個数である.
とすると,この問題の答えはcountSets([0..N-1], K, 0)と表すことが出来る.

function countSets(List, K, S) {
    List' = remove List[0] from List.
    return countSets(List', K-1, (S-List[0]) mod N) + countSets(List', K, S)
}

以下が初期値を考慮したソースコードだが,問題点が一つ.入力値がN<=1000, K<=47なのでO(N*K*N)の以下のアルゴリズムは2秒以内に計算が収まるが,メモ化を行うための配列は4*N*K*Nでだいたい183MBほどかかってしまう.

 table[N+1][K+1][N] // Arguments are: [p][k][S]
    for p = N to 0 {
      for k = 0 to K {
        for S = 0 to N-1 {
          if (k == 0) {
            if (S==0) {
                table[p][k][s] = 1
            } else {
                table[p][k][s] = 0
            }
          } else if (p == N) {
            table[p][k][s] = 0
          } else {
            table[p][k][s] = table[p+1][k-1][(S-p) mod N] + table[p+1][k][S] 
          }
        }
      }
    }

その問題点を解決したのが以下のコード.重要なのはtable[p][k][s]をけいさんするためにはtable[p+1]の値が求まってさえいればいいということ.これでメモ化に必要な配列が2*K*Nに抑えることが出来る.
## p%2で一つ前の配列を指すのは面白いアイディアだから覚えておこう.pの一個前の値は(p+1) % 2.

int find(int N, int K)
{
    //Note that memory is now only O(K*N).
    int table[2][K+1][N];
    for (int p = N; p>=0; p--) {
      for (int k = 0; k<=K; k++) {
        for (int s = 0; s<N; s++) {
          if (k == 0) {
            if (s==0) {
                table[p%2][k][s] = 1;
            } else {
                table[p%2][k][s] = 0;
            }
          } else if (p == N) {
            table[p%2][k][s] = 0;
          } else {
            table[p%2][k][s] = table[(p+1)%2][k-1][ ( s-p+N )%N ] //Use p
                             + table[(p+1)%2][k][s];        //Don't use p
                             
            //We only need the result Modulo 1000000007
            table[p%2][k][s] %= 1000000007;
          }
        }
      }
    }
    
    // The result is to call the recurrence when
    //   p = 0, k = K, and S=0
    return table[0 % 2][K][ 0 ];

}