2024杭电钉耙1-1003 HDOJ7435 树

Problem

给一棵根为 1 的有根树,点 \(i\) 具有一个权值 \(A_i\)

定义一个点对的值 \(f(u, v)=\max \left(A_u, A_v\right) \times\left|A_u-A_v\right|\)

你需要对于每个节点 \(i\) ,计算 \(a n s_i=\sum_{u \in \operatorname{subtree}(i), v \in \operatorname{subtree}(i)} f(u, v)\) ,其中 \(\operatorname{subtree}(i)\) 表示 \(i\) 的子树。

请你输出 \(\oplus\left(a n s_i \bmod 2^{64}\right)\) ,其中 \(\oplus\) 表示 XOR。

\(n \leq 5 \times 10^5, 1 \leq A_i \leq 10^6\)

Solution

先来愉快的推式子。

其实 \(\max \left(A_u, A_v\right) \times\left|A_u-A_v\right|\) 其实就是 \(\max^2-\max \cdot \min\),这两部分可以分开思考。

对于 \(\max\cdot\min\),其实就是在 \(i\) 的子树内任选两个点 \(u,v\in \operatorname{subtree}(i)\) 相乘

\[ \begin{align} &\sum_{u} \sum_{v } A_u\times A_v\\ =&\sum_{u }A_u\sum_{v}A_v\\ =&(\sum_{u}A_u)^2 \end{align} \] 对于 \(\max ^2\),也就是 \(\sum_{u,v\in \operatorname{subtree}(i)}(\max(A_u,A_v))^2\),我们需要思考子树合并的情况。

假设我们已经计算了节点 \(u\) 的所有子节点的子树的内部信息,\(v\)\(u\) 的某个儿子,此时我们需要计算

  • \(u\)\(\operatorname{subtree}(v)\) 之间的贡献
  • \(\operatorname{subtree}(v_i)\)\(\operatorname{subtree}(v_j)\) 之间的贡献(即跨点 \(u\) 的两点之间的贡献)

我们按照以下方式合并的同时计算贡献(以下步骤来自于题解)

  • \(\operatorname{subtree}(u)\) 初始为 \(\{u\}\)
  • 计算 \(\operatorname{subtree}(v)\) 和当前 \(\operatorname{subtree}(u)\) 之间点对的答案。(跨越 \(u\) 节点的部分)。
  • \(\operatorname{subtree}(v)\) 子树内的答案直接累加。(不跨越 \(u\) 节点的部分)。
  • \(\operatorname{subtree}(u) \leftarrow \operatorname{subtree}(u)+\operatorname{subtree}(v)\) (将 \(v\) 的子树加入到 \(u\) 中)。

我们需要维护两个变量:一个子树内的权值出现次数 \(cnt\) 与权值平方和 \(sum\)

当前子树 \(\operatorname{subtree}(u)\) 内加入一个权重为 \(w\) 的点,对于答案贡献多少呢?

  • 对于 \(\operatorname{subtree}(u)\) 中每个权值小于 \(w\) 的点,贡献 \(1\times w^2\),总计 \(2\times\sum_{i=1}^{w-1} cnt_i\times w^2\)

  • 对于 \(\operatorname{subtree}(u)\) 中每个权值大于等于 \(w\) 的点(权重为 \(w^\prime\)),贡献 \(1\times {w^\prime}^2\),总计 \(2\times\sum_{i=w}^{10^6}sum_i\)

对于每个节点,我们开一颗线段树。初始时,每个节点的线段树只包含其本身点权。计算完某个点所有儿子的 \(ans\) 之后,我们将所有儿子的线段树合并到其自己上,同时计算贡献。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#define N 500010
#define M 1000000
ULL a[N];
int n;

namespace Tree
{
int head[N],nxt[N*2],ver[N*2],f[N];
int cnt;
void insert(int x,int y)
{
nxt[++cnt]=head[x];
head[x]=cnt;
ver[cnt]=y;
}

};

using Tree::insert;
using Tree::head;
using Tree::nxt;
using Tree::ver;

namespace Seg
{
struct Node
{
int ls,rs,l,r;
ULL sum,cnt;
#define ls(x) a[x].ls
#define rs(x) a[x].rs
#define l(x) a[x].l
#define r(x) a[x].r
#define sum(x) a[x].sum
#define cnt(x) a[x].cnt
}a[N*40];
int cnt;
int root[N];
int new_node(int l,int r)
{
cnt++;
l(cnt)=l;
r(cnt)=r;
return cnt;
}

void add(int &p,int x)
{
debug
if(p==0) p=new_node(1,M);
if(l(p)==r(p))
{
debug
sum(p)+=(ULL)(x)*x;
cnt(p)++;
return;
}
int mid=(l(p)+r(p))/2;
if(x<=mid)
{
if(!ls(p)) ls(p)=new_node(l(p),mid);
add(ls(p),x);
}
else
{
if(!rs(p)) rs(p)=new_node(mid+1,r(p));
add(rs(p),x);
}
cnt(p)=cnt(ls(p))+cnt(rs(p));
sum(p)=sum(ls(p))+sum(rs(p));
}

int merge(int x,int y,ULL &ans)
{
if(!x) return y;
if(!y) return x;
if(l(x)==r(x))
{
ans+=2*cnt(x)*sum(y);
sum(x)+=sum(y);
cnt(x)+=cnt(y);
return x;
}

sum(x)+=sum(y);
cnt(x)+=cnt(y);

ans+=2*cnt(ls(x))*sum(rs(y));
ans+=2*cnt(ls(y))*sum(rs(x));

ls(x)=merge(ls(x),ls(y),ans);
rs(x)=merge(rs(x),rs(y),ans);

return x;
}

};

using Seg::add;
using Seg::merge;
using Seg::root;

ULL sq[N],ans[N],sum[N];

void dfs(int x,int f)
{
sum[x]=a[x];
sq[x]=a[x]*a[x];
add(root[x],a[x]);
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];
if(y==f) continue;
dfs(y,x);
sum[x]+=sum[y];
sq[x]+=sq[y];
root[x]=merge(root[x],root[y],sq[x]);
}
ans[x]=sq[x]-sum[x]*sum[x];
}

int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cout.precision(10);
int t=1;
// cin>>t;
while(t--)
{
cin>>n;
for(int i=1;i<n;i++)
{
int x,y;
cin>>x>>y;
Tree::insert(x,y);
Tree::insert(y,x);
}
for(int i=1;i<=n;i++)
{
cin>>a[i];
}

dfs(1,0);

ULL out=0;
for(int i=1;i<=n;i++)
{
out^=ans[i];
// cout<<ans[i]<<" ";
}
cout<<out<<endl;

}
return 0;
}