なんか狐につままれたような気分。

問題:No.196 典型DP (1) - yukicoder

解法:解説のところに書いてあるとおり,O(N^3)であるような気がする計算量が,実はO(N^2)であるというトリックで時間内に解くことができる,という感じ。ソースコード参照しながらそのことについて説明します。

まず「何をやっているのか」ということですが,主にdfs(v, p)の部分で計算しています。vは今見ている頂点でpがその親という例のアレです。
中ではvの子供ごとにdp[v][k] := (頂点vを根として,黒のノードの数をk個にするような塗り方は何通りあるか)というものを計算しています(なので求める答えはdfs(0,-1)[K]ですね)。vの各 子ノードごとにこのdpを計算していくと,dp[v]はdfs中のnretでやっているように計算できます。


これの計算量は,一見すると,

N(各頂点を1回ずつ見る) * N^2(dfs中でnretを計算している部分は2重のfor文) = O(N^3)

になっているように思われます。ですがこれは実はO(N^2)の計算量で計算されています。まず簡単な例として各頂点に2つの子しかいないとしましょう。それぞれの子のサイズをそれぞれl, rとし,サイズnの頂点のdfsを計算するのにかかる時間をT(n)とすると,

T(l+r) = T(l) + T(r) + l*r

となります(T(l), T(r)はそれぞれの子ノードを探索するのにかかる計算量,l*rはfor文を回すのにかかる計算量)。ところで,(l+r)^2 = l^2+r^2+2*l*rであるので,T(n)=O(n^2)です。要するにdfs(0, -1)はO(n^2)でできるのでこの問題での計算量的に間に合っているということです。


次に一般の場合を考えましょう。頂点vに子ノードch0, ch1, ..., chmがあるとして,それぞれのサイズがs0, s1, ..., smであるとします。この時の頂点vにおける計算量は

T(s0+s1+...+sm) = T(s0) + T(s1) + ... + T(sm) + s0 + s0*s1 + (s0+s1)*s2 + ... + (s0+s1+...+s(m-1))*sm

となります。これについても,

(s0+s1+...+sm)^2 = s0^2 + s1^2 + ... + sm^2 + 2*sum(si*sj)

となるので,やはりT(n) = O(n^2)となります。

以下ソースコード

const int MAXN = 2222;
const ll MOD = 1e9+7;
vector<int> G[MAXN];
ll dp[MAXN][MAXN];

vll dfs(int v, int p) {
    vll ret(1, 1);
    for (int next : G[v]) {
        if (next == p) continue;
        vll tmp = dfs(next, v);
        int n = ret.size();
        int m = tmp.size();
        vll nret(n+m-1);
        for (int i = 0; i < n; i++) for (int j = 0; j < m; j++) {
            (nret[i+j] += ret[i]*tmp[j]) %= MOD;
        }
        ret = nret;
    }
    ret.push_back(1);
    return ret;
}

int main() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    int N, K;
    cin >> N >> K;
    for (int i = 0; i < N-1; i++) {
        int a, b;
        cin >> a >> b;
        G[a].push_back(b);
        G[b].push_back(a);
    }
    cout << dfs(0, -1)[K] << endl;
    return 0;
}