HDU - 5401 Persistent Link/cut Tree dp

HDU - 5401

可以发现这个问题能够递归处理：先把递推式列出来，再用记忆化搜索求解即可。

#include<bits/stdc++.h>
// Shorthand names for the integer and floating-point types used in this file.
#define LL long long
#define LD long double
#define ull unsigned long long
// Shorthand for pair members and pair construction.
#define fi first
#define se second
#define mk make_pair
// Shorthand for commonly used pair types.
#define PLL pair<LL, LL>
#define PLI pair<LL, int>
#define PII pair<int, int>
// Container size as int. The argument is not parenthesized, so pass a simple expression.
#define SZ(x) ((int)x.size())
// Expands to the begin/end iterator pair of a container.
#define ALL(x) (x).begin(), (x).end()
// Expands to two statements that unsync C++ streams from C stdio and untie cin.
#define fio ios::sync_with_stdio(false); cin.tie(0);

using namespace std;

// Capacity of the per-record arrays and memo tables (valid indices 0..N-1).
const int N = 60 + 7;
// Large sentinel values; not referenced in the visible code.
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3f;
// Modulus used by add/sub and by the arithmetic on distances and answers.
const int mod = 1e9 + 7;
// Floating-point tolerance and pi; not referenced in the visible code.
const double eps = 1e-8;
const double PI = acos(-1);

// Modular accumulate: a += b followed by at most one subtraction of mod.
// Only a single correction is applied, so the result is below mod only
// when the sum a + b is below 2 * mod.
template<class T, class S>
inline void add(T &a, S b) {
    a += b;
    if(mod <= a) {
        a -= mod;
    }
}
// Modular subtract: a -= b followed by at most one addition of mod.
// Only a single correction is applied, so the result is non-negative only
// when a - b is at least -mod.
template<class T, class S>
inline void sub(T &a, S b) {
    a -= b;
    if(0 > a) {
        a += mod;
    }
}
// Raises a to b when b is strictly larger.
// Returns true if a was updated, false if a was left unchanged.
template<class T, class S>
inline bool chkmax(T &a, S b) {
    if(a < b) {
        a = b;
        return true;
    }
    return false;
}
// Lowers a to b when b is strictly smaller.
// Returns true if a was updated, false if a was left unchanged.
template<class T, class S>
inline bool chkmin(T &a, S b) {
    if(a > b) {
        a = b;
        return true;
    }
    return false;
}

// Mersenne Twister generator seeded from the steady clock's time-since-epoch count.
// Not referenced in the visible code.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// Number of records in the current test case (read in main).
LL m;
// Per-record inputs read in main, indexed 1..m: record i references earlier
// indices a[i] and b[i], positions c[i] and d[i] inside them, and a length l[i].
// NOTE(review): presumably tree i is tree a[i] and tree b[i] joined by an edge of
// length l[i] between vertex c[i] and vertex d[i] — confirm against the problem statement.
LL a[N], b[N], c[N], d[N], l[N];
// ans[i]: result printed for record i, kept reduced modulo mod.
LL ans[N];
// sz[0] = 1 and sz[i] = sz[a[i]] + sz[b[i]] (set in main); not reduced modulo mod.
LL sz[N];

// memo1[id][u]: cached results of getAllToPoint(id, u).
map<LL, LL> memo1[N];
// memo2[id][(u, v)] with u < v: cached results of getPointDis(id, u, v).
map<PLL, LL> memo2[N];

// Memoized recursive query on (id, u, v); presumably the distance between
// vertices u and v of tree id, reduced modulo mod.
//
// The pair is normalized so that u < v, then resolved by where the two
// positions fall relative to sz[a[id]]:
//   - both below it:      delegate to a[id] unchanged;
//   - both at/above it:   delegate to b[id] with both positions shifted down;
//   - one on each side:   (u to c[id] in a[id]) + (shifted v to d[id] in b[id]) + l[id].
// Returns 0 when u == v. Results are cached in memo2[id].
LL getPointDis(LL id, LL u, LL v) {
    if(u == v) return 0;
    if(v < u) swap(u, v);

    const PLL key = mk(u, v);
    map<PLL, LL> &cache = memo2[id];
    map<PLL, LL>::iterator hit = cache.find(key);
    if(hit != cache.end()) return hit->second;

    const LL partA = a[id];
    const LL partB = b[id];
    const LL shift = sz[partA];  // positions >= shift belong to partB

    LL dist;
    if(v < shift) {
        dist = getPointDis(partA, u, v);
    } else if(u >= shift) {
        dist = getPointDis(partB, u - shift, v - shift);
    } else {
        // u lies in partA and v in partB: route through the c[id]--d[id] link.
        const LL sideA = getPointDis(partA, u, c[id]);
        const LL sideB = getPointDis(partB, v - shift, d[id]);
        dist = (sideA + sideB + l[id]) % mod;
    }

    cache[key] = dist;
    return dist;
}

// Memoized recursive query on (id, u); presumably the sum of distances from
// every vertex of tree id to vertex u, reduced modulo mod.
//
// Returns 0 for id == 0. Otherwise u falls either in a[id] (u < sz[a[id]]) or
// in b[id] (position shifted down by sz[a[id]]). Call the part containing u
// "own" and the other part "other"; the result is
//   getAllToPoint(other, otherPort)
//   + sz[other] * (l[id] + getPointDis(own, local, ownPort))
//   + getAllToPoint(own, local)
// where the ports are c[id] for a[id] and d[id] for b[id].
// Results are cached in memo1[id].
LL getAllToPoint(LL id, LL u) {
    if(id == 0) return 0;

    map<LL, LL> &cache = memo1[id];
    map<LL, LL>::iterator hit = cache.find(u);
    if(hit != cache.end()) return hit->second;

    const LL shift = sz[a[id]];
    const bool inA = u < shift;

    const LL own = inA ? a[id] : b[id];
    const LL other = inA ? b[id] : a[id];
    const LL ownPort = inA ? c[id] : d[id];
    const LL otherPort = inA ? d[id] : c[id];
    const LL local = inA ? u : u - shift;

    // Cost of leaving the own part: the link length plus the walk to the port.
    const LL bridge = l[id] + getPointDis(own, local, ownPort);

    LL total = (getAllToPoint(other, otherPort) + sz[other] % mod * bridge % mod) % mod;
    add(total, getAllToPoint(own, local));

    cache[u] = total;
    return total;
}

void init() {
    for(int i = 0; i < N; i++) {
        memo1[i].clear();
        memo2[i].clear();
    }
    memset(sz, 0, sizeof(sz));
    memset(ans, 0, sizeof(ans));
}

int main() {
    while(scanf("%lld", &m) != EOF) {
        init();
        sz[0] = 1;
        for(int i = 1; i <= m; i++) {
            scanf("%lld%lld%lld%lld%lld", &a[i], &b[i], &c[i], &d[i], &l[i]);
            sz[i] = sz[a[i]] + sz[b[i]];
        }
        for(int i = 1; i <= m; i++) {
            ans[i] = (ans[a[i]] + ans[b[i]]) % mod;
            add(ans[i], sz[b[i]] % mod * getAllToPoint(a[i], c[i]) % mod);
            add(ans[i], sz[a[i]] % mod * getAllToPoint(b[i], d[i]) % mod);
            add(ans[i], (sz[a[i]] % mod) * (sz[b[i]] % mod) % mod * l[i] % mod);
        }
        for(int i = 1; i <= m; i++) {
            printf("%lld
", ans[i]);
        }
    }
    return 0;
}

/*
*/

 

你可能感兴趣的