(树形 dp、数学)AT_dp_v Subtree 题解
题意
给一棵 nnn 个节点的树,对每一个节点染成黑色或白色。
对于每一个节点,求强制把这个节点染成黑色的情况下,所有的黑色节点组成一个联通块的染色方案数,答案对 ppp 取模。
1≤n≤1051\le n\le 10^51≤n≤105.
思路
由于钦定选哪一个点强制涂色已经要 O(n)O(n)O(n),不可能把 nnn 提到根去算贡献的,那总共需要 O(n2)O(n^2)O(n2)。
因此考虑先做 111 为根,每个子树内的答案——这个好做。设 fif_ifi 表示,iii 强制涂色,子树 iii 的方案数。对于边 (u,v)(u,v)(u,v) 合并子树 uuu 和 vvv。uuu 即可以在 vvv 涂色合并所有 vvv 为根的连通块(fvf_vfv),也可以舍弃它不要(111)。于是:
fu=∏v=sonu(fv+1)f_u=\prod_{v=son_u}(f_v+1)fu=v=sonu∏(fv+1)
(+1+1+1 也可以类比计算因数和、幂次 +1+1+1)
我们发现,节点 uuu 作为根面向的所有节点,除了以 111 为根的子树(已经计算了 fuf_ufu)剩下的就是子树之外的所有点,剩下的点必然是连通的,不妨计算 gug_ugu 表示,uuu 强制选,uuu 子树之外的方案数。
对于边 (u,v)(u,v)(u,v),因为计算子树外的方案数,所以 gvg_vgv 由 gug_ugu 推过来。上面这幅图,呈现了 gvg_vgv 的推导过程。即在 gug_ugu 基础上,合并上除去 vvv 以 111 为根子树的贡献 fufv+1\dfrac{f_u}{f_v+1}fv+1fu:
gv=gu×fufv+1g_v=g_u\times\dfrac{f_u}{f_v+1}gv=gu×fv+1fu
这里出现了除法,但是因为模数 ppp 不一定为质数,很难搞逆元,因此转化为纯乘法:
gv=gu×∏r=sonu,r≠v(fr+1)g_v=g_u\times \prod_{r=son_u,r\neq v} (f_r+1)gv=gu×r=sonu,r=v∏(fr+1)
用动态数组预存储前后缀积即可。
代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const ll N=1e5+9;
ll n,mod;
struct edge
{ll to,next;
}e[N<<1];
ll idx,head[N];
void addedge(ll u,ll v)
{idx++;e[idx].to=v;e[idx].next=head[u];head[u]=idx;
}
ll f[N],g[N];
ll tot[N];
vector<ll>pre[N],suf[N];
void dfs1(ll u,ll fa)
{f[u]=1;for(int i=head[u];i;i=e[i].next){ll v=e[i].to;if(v==fa)continue;dfs1(v,u);f[u]=f[u]*(f[v]+1)%mod;pre[u].push_back(f[v]+1);suf[u].push_back(f[v]+1);tot[u]++;}for(int i=1;i<pre[u].size();i++)pre[u][i]=pre[u][i]*pre[u][i-1]%mod;for(int i=suf[u].size()-2;i>=0;i--)suf[u][i]=suf[u][i]*suf[u][i+1]%mod;
}
void dfs2(ll u,ll fa)
{ll dfn=0;//遍历儿子的顺序相同 for(int i=head[u];i;i=e[i].next){ll v=e[i].to;if(v==fa)continue;dfn++;if(tot[u]==1)g[v]=g[u];else if(dfn==1)g[v]=g[u]*suf[u][dfn]%mod;//suf[u][dfn+1]else if(dfn==tot[u])g[v]=g[u]*pre[u][dfn-2]%mod;//pre[u][dfn-1]else g[v]=g[u]*pre[u][dfn-2]%mod*suf[u][dfn]%mod;g[v]++;//不合并子树外部分,子树内自称1个连通块dfs2(v,u);}
}
int main()
{scanf("%lld%lld",&n,&mod);for(int i=1;i<n;i++){ll u,v;scanf("%lld%lld",&u,&v);addedge(u,v);addedge(v,u);}dfs1(1,0);g[1]=1;dfs2(1,0);//子树外 for(int i=1;i<=n;i++)printf("%lld\n",f[i]*g[i]%mod);return 0;
}