# description
来源:洛谷 P4242
给你一棵树,每个节点都有颜色,有两种操作:
- 指定其中 个节点,对于每个节点,计算它 到每个给定的节点的路径上颜色的段数 的和。
- 修改一条路径上所有节点的颜色。
# solution
首先看描述就可以识算法了。
看到 显然是建虚树出来。
看到修改路径上的颜色,显然可以用树剖维护。
建虚树时,刚好可以利用树剖得到两个点之间的所有颜色。
然后问题是建虚树出来之后考虑怎么做。
首先看到路径上这三个字,可以考虑点分治。
# 点分治做法
描述约定:
- 对于题目中给出的 个点,我们叫他关键点。
- 对于虚树上的点,叫它虚树点。(显然关键点都是虚树点)
然后点分治步骤:
- 找到树重心,将其作为根节点
- 依次遍历根节点的每棵子树,计算子树中所有关键点到根节点的路径中颜色的段数,并
sum += ColorCount, nodeCount += 1
。 - 依次遍历根节点的每棵子树
- 删除这棵子树的贡献
- 计算每个关键点的答案(我们将经过根节点的答案拆分成两段,一段从根根节点到其它关键点,一段从当前点到根节点,答案显然应该是:
- 加回子树贡献
- 依次递归每棵子树
时间复杂度 。
除了点分治以外,还可以考虑虚树题目的常用解法:树形 dp。
# 树形 dp 做法
首先考虑有根树的做法。
以节点 为根,考虑虚树上这样一条边 :
对于子树 中的每个点,到 的路径都必须经过 ,所以子树 中的每个点到 的贡献都会多上 ,其中 表示 到 的路径上(包含 )颜色段数。
记 为以 为根的子树的答案, 为以 为根的子树中关键点的个数,得到下面这个式子:
对于之后的换根,同样按照这个式子再减去子树贡献就行(记 为节点父亲):
想展开就展开吧,总之我懒得展开 不过不展开也有个好处,就是逻辑更加清晰。
dp 的复杂度是 的,但是树剖和虚树的复杂度是 的,总复杂度 。
这个做法贴个 Code:
/* | |
1. 修改一条路径上所有节点的颜色 | |
2. 计算所有节点到其他节点的颜色总数 | |
*/ | |
#include <cstdio> | |
#include <vector> | |
#include <algorithm> | |
using namespace std; | |
#define rep(i, s, e) for(int i=s;i<=e;++i) | |
#define pb push_back | |
const int mxn = 1e5+10; | |
int n, q, c[mxn], dfn[mxn], ed[mxn], map[mxn]; | |
vector <int> G[mxn]; | |
struct node { | |
int lC, rC, sum; | |
inline node operator + (const node& rhs) const { | |
if(!rhs.sum) return *this; if(!sum) return rhs; | |
return (node) {lC, rhs.rC, sum + rhs.sum - (rC == rhs.lC)}; | |
} | |
inline node reverse() { swap(lC, rC); return *this; } | |
}; | |
const node E0 = {0, 0, 0}; | |
#define lc (o<<1) | |
#define rc (o<<1|1) | |
#define mid ((L+R)>>1) | |
struct Segtr { | |
node s[mxn<<2]; int tag[mxn<<2]; | |
inline void cover(int o, int tg) { s[o] = {tg, tg, 1}; tag[o] = tg; } | |
inline void psdn(int o) { if(tag[o]) { cover(lc, tag[o]), cover(rc, tag[o]); tag[o] = 0; } } | |
inline void psup(int o) { s[o] = s[lc] + s[rc]; } | |
void build(int o=1, int L=1, int R=n) { | |
if(L == R) return (void)(s[o] = {c[map[L]], c[map[L]], 1}); | |
build(lc, L, mid); build(rc, mid+1, R); psup(o); | |
} | |
node query(int cl, int cr, int o=1, int L=1, int R=n) { | |
if(cl <= L && R <= cr) return s[o]; psdn(o); | |
node res = (cl <= mid ? query(cl, cr, lc, L, mid) : E0) + (cr > mid ? query(cl, cr, rc, mid+1, R) : E0); | |
return res; | |
} | |
void set(int cl, int cr, int p, int o=1, int L=1, int R=n) { | |
if(cl <= L && R <= cr) return cover(o, p); psdn(o); | |
if(cl <= mid) set(cl, cr, p, lc, L, mid); | |
if(cr > mid) set(cl, cr, p, rc, mid+1, R); psup(o); | |
} | |
}; | |
#undef lc | |
#undef rc | |
#undef mid | |
struct tr_cut { | |
int sz[mxn], son[mxn], fa[mxn], dep[mxn], dfc, top[mxn]; | |
Segtr T; | |
void dfs1(int u, int fat) { | |
fa[u] = fat; dep[u] = dep[fat] + 1; sz[u] = 1; | |
for(auto v: G[u]) if(v != fat) { | |
dfs1(v, u); sz[u] += sz[v]; if(sz[v] > sz[son[u]]) son[u] = v; | |
} | |
} | |
void dfs2(int u) { | |
dfn[u] = ++dfc; map[dfc] = u; | |
if(son[u]) { top[son[u]] = top[u]; dfs2(son[u]); } | |
for(auto v: G[u]) if(v != fa[u] && v != son[u]) { top[v] = v; dfs2(v); } | |
ed[u] = dfc; | |
} | |
inline int lca(int u, int v) { | |
while(top[u] != top[v]) dep[top[u]] > dep[top[v]] ? u = fa[top[u]] : v = fa[top[v]]; | |
return dep[u] > dep[v] ? v : u; | |
} | |
inline void modify(int u, int v, int p) { | |
while(top[u] != top[v]) { | |
if(dep[top[u]] < dep[top[v]]) swap(u, v); | |
T.set(dfn[top[u]], dfn[u], p); | |
u = fa[top[u]]; | |
} | |
if(dep[u] > dep[v]) swap(u, v); | |
T.set(dfn[u], dfn[v], p); | |
} | |
inline node query(int u, int v) { | |
// 从浅到深 | |
if(dep[u] > dep[v]) swap(u, v); | |
node res = E0; | |
while(top[v] != top[u]) { | |
res = T.query(dfn[top[v]], dfn[v]) + res; | |
v = fa[top[v]]; | |
} | |
if(u != v) res = T.query(dfn[u]+1, dfn[v]) + res; | |
return res; | |
} | |
inline void init() { dfs1(1, 0); top[1] = 1; dfs2(1); T.build(); } | |
} cut; | |
inline bool cmp(int x, int y) { return dfn[x] < dfn[y]; } | |
struct fake_tree { | |
int pt[mxn<<1], m, M, tot, key[mxn], stk[mxn], top, sz[mxn], f[mxn], fa[mxn], ask[mxn]; | |
struct edge { int v; node w; }; | |
vector <edge> G[mxn]; node s[mxn]; | |
inline void init(int GETIN) { | |
m = M = GETIN; | |
rep(i, 1, m) { int x; scanf("%d", &x); pt[i] = ask[i] = x; key[x] = 2; } | |
if(!key[1]) key[1] = 1, pt[++M] = 1; | |
sort(pt + 1, pt + M + 1, cmp); | |
tot = M; | |
rep(i, 2, M) { | |
int C = cut.lca(pt[i], pt[i-1]); | |
if(!key[C]) key[C] = 1, pt[++tot] = C; | |
} | |
sort(pt + 1, pt + tot + 1, cmp); | |
stk[top = 1] = 1; | |
rep(i, 1, tot) c[pt[i]] = cut.T.query(dfn[pt[i]], dfn[pt[i]]).lC; | |
#define T (stk[top]) | |
rep(i, 2, tot) { | |
int cur = pt[i]; | |
while(dfn[cur] > ed[T]) --top; node w = cut.query(T, cut.fa[cur]); | |
G[T].pb({cur, w}); stk[++top] = pt[i]; | |
} | |
#undef T | |
} | |
void dfs1(int u, int fat) { | |
s[u] = {c[u], c[u], 1}; fa[u] = fat; f[u] = sz[u] = (key[u] == 2); | |
for(auto e: G[u]) { | |
dfs1(e.v, u); sz[u] += sz[e.v]; | |
f[u] += f[e.v] + sz[e.v] * ( (s[u] + e.w + s[e.v]).sum - 1); | |
} | |
} | |
void dfs2(int u, node from) { | |
if(fa[u]) f[u] = f[fa[u]] - sz[u] * (from.sum - 1) + (m - sz[u]) * (from.sum - 1); | |
for(auto e: G[u]) dfs2(e.v, s[u] + e.w + s[e.v]); | |
} | |
inline void solve() { | |
dfs1(1, 0); dfs2(1, E0); | |
rep(i, 1, m) printf("%d ", f[ask[i]]); puts(""); | |
rep(i, 1, tot) { | |
int cur = pt[i]; G[cur].clear(); s[cur] = E0; | |
key[cur] = f[cur] = fa[cur] = sz[cur] = ask[i] = pt[i] = 0; | |
} | |
} | |
} sol; | |
int main() { | |
scanf("%d%d", &n, &q); | |
rep(i, 1, n) scanf("%d", c+i); | |
rep(i, 2, n) { | |
int u, v; scanf("%d%d", &u, &v); | |
G[u].pb(v); G[v].pb(u); | |
} | |
cut.init(); | |
while(q--) { | |
int op; scanf("%d", &op); | |
if(op&1) { | |
int u, v, w; scanf("%d%d%d", &u, &v, &w); | |
cut.modify(u, v, w); | |
} else { | |
int m; scanf("%d", &m); sol.init(m); sol.solve(); | |
} | |
} | |
return 0; | |
} |