SRM501 Div1 Medium, Div2 Hard FoxAverageSequence

問題

配列Aがビューティフルとは次のような制約条件を満たす時である
1. 0<=A[i]<=40
2. A[i] <= (A[0] + A[1] + ... + A[i-1]) / i
3. A[i] > A[i+1] > A[i+2]となる列が存在しない
今-1..40からなる数列が与えられた時-1を0..40の数字に変換して得られるビューティフルな数列の個数を答えよ

解答

DPで解く.0..i番目の配列までで得られる個数をそれまでの総和,最後に使われた数,減少傾向にあるかでそれぞれ保存しておく.i番目のDP表を埋めるのにi-1番目の表までで十分なのでM[2][40*40][40][2]のDP表を作る.M[i%2]の値を埋めるのにM[(i-1)%2]の値を利用することで容量を節約(2*40**4でも容量的に十分だったので,容量面で頑張らなくても良かったのかも)
後は桁あふれと余計な処理を行わないように,条件に合う値を埋めていく.計算量的には40**5なので結構危うい.(実際最適化をする前だとTLEになってしまった)

配列は基本vectorで記述するようにしているけど,DP表を書く時は見栄えが悪くなるなぁ.他の人のコードを見るとほぼ全員といっていいほどDP表にはvector使ってない.

int mod = 1000000007;
int m;
class FoxAverageSequence
{
public:
  // void calc(vector<vector<vector<int> > > & M2, vector<vector<vector<int> > > & M, int j, int v, int i){
  void calc(vector<vector<vector<vector<int> > > > & M, int j, int v, int i){
    int k;
    long long tmp = 0;
    if (j < i*v){
      M[i%2][j][v][0] = 0;
      M[i%2][j][v][1] = 0;
      return;
    }
    for(k = v+1; k < m && j-k >= 0; k++) {
      if ((j-v)/i >= v) tmp += M[(i-1)%2][j-v][k][0] % mod;
    }
    if ((j-v)/i >= v) M[i%2][j][v][1] = tmp % mod;
    tmp = 0;
    for(k = v; k >= 0 && j-k >= 0; k--){
      if ((j-v)/i >= v) {
        tmp += (M[(i-1)%2][j-v][k][0] + M[(i-1)%2][j-v][k][1]) % mod;
      }
    }
    if ((j-v)/i >= v) M[i%2][j][v][0] = tmp % mod;
  }
  int theCount(vector <int> seq)
    {
      int n = seq.size();
      m = 41;
      int i, j, k;
      vector<vector<vector<vector<int> > > > M(2, vector<vector<vector<int> > >(m*n, vector<vector<int> >(m, vector<int>(2, 0))));
      vector<vector<vector<int> > > init(vector<vector<vector<int> > >(m*n, vector<vector<int> >(m, vector<int>(2, 0))));
      // cout << "M.size()=" << M.size()
      //      << " M[0].size()=" << M[0].size()
      //      << " M[0][0].size()=" << M[0][0].size()<< endl;
      if (seq[0] == -1){
        for(i = 0; i < m; i++){
          M[0][i][i][0] = 1;
        }
      }else{
        M[0][seq[0]][seq[0]][0] = 1;
      }
      for(i = 1; i < n; i++){
        for(j = 0; j < m*n; j++){
          if (seq[i] != -1){
            calc(M, j, seq[i], i);
          }else{
            for(k = 0; k < m; k++){
              if (j-k<0) continue;
              calc(M, j, k, i);
            }
          }
        }
        M[(i-1)%2] = init;
      }
      long long res = 0;
      for(j = 0; j < m*n; j++){
        for(k = 0; k < m; k++){
          res = (res + M[(n-1)%2][j][k][0] + M[(n-1)%2][j][k][1]) % mod;
        }
      }
      return res;
    }
};