Problem
给一棵根为 1 的有根树,点 \(i\) 具有一个权值 \(A_i\) 。
定义一个点对的值 \(f(u, v)=\max \left(A_u, A_v\right) \times\left|A_u-A_v\right|\) 。
你需要对于每个节点 \(i\) ,计算 \(a n s_i=\sum_{u \in \operatorname{subtree}(i), v \in \operatorname{subtree}(i)} f(u, v)\) ,其中 \(\operatorname{subtree}(i)\) 表示 \(i\) 的子树。
请你输出 \(\oplus\left(a n s_i \bmod 2^{64}\right)\) ,其中 \(\oplus\) 表示 XOR。
\(n \leq 5 \times 10^5, 1 \leq A_i \leq 10^6\)
Solution
先来愉快的推式子。
其实 \(\max \left(A_u, A_v\right) \times\left|A_u-A_v\right|\) 其实就是 \(\max^2-\max \cdot \min\),这两部分可以分开思考。
对于 \(\max\cdot\min\),其实就是在 \(i\) 的子树内任选两个点 \(u,v\in \operatorname{subtree}(i)\) 相乘
\[
\begin{align}
&\sum_{u} \sum_{v } A_u\times A_v\\
=&\sum_{u }A_u\sum_{v}A_v\\
=&(\sum_{u}A_u)^2
\end{align}
\] 对于 \(\max ^2\),也就是 \(\sum_{u,v\in \operatorname{subtree}(i)}(\max(A_u,A_v))^2\),我们需要思考子树合并的情况。
假设我们已经计算了节点 \(u\) 的所有子节点的子树的内部信息,\(v\) 是 \(u\) 的某个儿子,此时我们需要计算
- \(u\) 与 \(\operatorname{subtree}(v)\) 之间的贡献
- \(\operatorname{subtree}(v_i)\) 与 \(\operatorname{subtree}(v_j)\) 之间的贡献(即跨点 \(u\) 的两点之间的贡献)
我们按照以下方式合并的同时计算贡献(以下步骤来自于题解)
- \(\operatorname{subtree}(u)\) 初始为 \(\{u\}\) 。
- 计算 \(\operatorname{subtree}(v)\) 和当前 \(\operatorname{subtree}(u)\) 之间点对的答案。(跨越 \(u\) 节点的部分)。
- 把 \(\operatorname{subtree}(v)\) 子树内的答案直接累加。(不跨越 \(u\) 节点的部分)。
- \(\operatorname{subtree}(u) \leftarrow \operatorname{subtree}(u)+\operatorname{subtree}(v)\) (将 \(v\) 的子树加入到 \(u\) 中)。
我们需要维护两个变量:一个子树内的权值出现次数 \(cnt\) 与权值平方和 \(sum\)。
当前子树 \(\operatorname{subtree}(u)\) 内加入一个权重为 \(w\) 的点,对于答案贡献多少呢?
对于 \(\operatorname{subtree}(u)\) 中每个权值小于 \(w\) 的点,贡献 \(1\times w^2\),总计 \(2\times\sum_{i=1}^{w-1} cnt_i\times w^2\)。
对于 \(\operatorname{subtree}(u)\) 中每个权值大于等于 \(w\) 的点(权重为 \(w^\prime\)),贡献 \(1\times {w^\prime}^2\),总计 \(2\times\sum_{i=w}^{10^6}sum_i\)
对于每个节点,我们开一颗线段树。初始时,每个节点的线段树只包含其本身点权。计算完某个点所有儿子的 \(ans\) 之后,我们将所有儿子的线段树合并到其自己上,同时计算贡献。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
| #define N 500010 #define M 1000000 ULL a[N]; int n;
namespace Tree { int head[N],nxt[N*2],ver[N*2],f[N]; int cnt; void insert(int x,int y) { nxt[++cnt]=head[x]; head[x]=cnt; ver[cnt]=y; } };
using Tree::insert; using Tree::head; using Tree::nxt; using Tree::ver;
namespace Seg { struct Node { int ls,rs,l,r; ULL sum,cnt; #define ls(x) a[x].ls #define rs(x) a[x].rs #define l(x) a[x].l #define r(x) a[x].r #define sum(x) a[x].sum #define cnt(x) a[x].cnt }a[N*40]; int cnt; int root[N]; int new_node(int l,int r) { cnt++; l(cnt)=l; r(cnt)=r; return cnt; } void add(int &p,int x) { debug if(p==0) p=new_node(1,M); if(l(p)==r(p)) { debug sum(p)+=(ULL)(x)*x; cnt(p)++; return; } int mid=(l(p)+r(p))/2; if(x<=mid) { if(!ls(p)) ls(p)=new_node(l(p),mid); add(ls(p),x); } else { if(!rs(p)) rs(p)=new_node(mid+1,r(p)); add(rs(p),x); } cnt(p)=cnt(ls(p))+cnt(rs(p)); sum(p)=sum(ls(p))+sum(rs(p)); } int merge(int x,int y,ULL &ans) { if(!x) return y; if(!y) return x; if(l(x)==r(x)) { ans+=2*cnt(x)*sum(y); sum(x)+=sum(y); cnt(x)+=cnt(y); return x; } sum(x)+=sum(y); cnt(x)+=cnt(y); ans+=2*cnt(ls(x))*sum(rs(y)); ans+=2*cnt(ls(y))*sum(rs(x)); ls(x)=merge(ls(x),ls(y),ans); rs(x)=merge(rs(x),rs(y),ans); return x; } };
using Seg::add; using Seg::merge; using Seg::root;
ULL sq[N],ans[N],sum[N];
void dfs(int x,int f) { sum[x]=a[x]; sq[x]=a[x]*a[x]; add(root[x],a[x]); for(int i=head[x];i;i=nxt[i]) { int y=ver[i]; if(y==f) continue; dfs(y,x); sum[x]+=sum[y]; sq[x]+=sq[y]; root[x]=merge(root[x],root[y],sq[x]); } ans[x]=sq[x]-sum[x]*sum[x]; }
int main() { ios::sync_with_stdio(false); cin.tie(0); cout.tie(0); cout.precision(10); int t=1;
while(t--) { cin>>n; for(int i=1;i<n;i++) { int x,y; cin>>x>>y; Tree::insert(x,y); Tree::insert(y,x); } for(int i=1;i<=n;i++) { cin>>a[i]; } dfs(1,0); ULL out=0; for(int i=1;i<=n;i++) { out^=ans[i];
} cout<<out<<endl; } return 0; }
|