给你一棵 nn 个结点的二叉树,定义结点 xx 的权值为:

  1. xx 没有子结点,那么它的权值会在输入里给出,保证这类点中每个结点的权值互不相同

  2. xx 有子结点,那么它的权值有 pxp_x 的概率是它的子结点的权值的最大值,有 1px1-p_x的概率是它的子结点的权值的最小值。

假设 11 号结点的权值有 mm 种可能性,权值第 ii的可能性的权值是 ViV_i,它的概率为 DiD_i,求:

(i=1miViDi2) mod 998244353 (\sum_{i=1}^{m}i\cdot V_i\cdot D_i^2)~ \mathrm{mod}~998244353

n3105,wi109n\le 3*10^5, w_i\le10^9

LOJ2537

Solution

首先考虑n2n^2dp,设dp[x][i]dp[x][i]表示xx号点取到排名为ii权值的概率

通过枚举两个儿子的权值直接转移是O(n3)O(n^3)的,用前缀和优化就变成O(n2)O(n^2)的了

转移式为

dp[x][i]=son=01dp[ch[x][son]][i]×Pre[son1][i1]×p[x]+dp[ch[x][son]][i]×Suf[son1][i+1]×(1p[x]) \begin{aligned} dp[x][i] = &\sum_{son=0}^{1}dp[ch[x][son]][i] \times Pre[son\oplus 1][i-1]\times p[x]\\ +&dp[ch[x][son]][i]\times Suf[son\oplus1][i + 1]\times(1-p[x]) \end{aligned}

然后考虑优化

通过观察dp方程可以发现这是一个很经典的套路

注意到转移实际上有两个限制,一是合并子树信息,二是与权值排名大小有关的一个限制

因此可以用线段树合并来优化,需要维护一个区间乘标记

代码中的sumx,sumysumx,sumy就是分别维护两棵子树的Pre[][]×p[x]+Suf[][]×(1p[x])Pre[][]\times p[x]+Suf[][]\times(1-p[x])

一个这么傻逼的地方我理解了好久,对线段树合并的理解还是不太够啊。。。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include <bits/stdc++.h>

#define x first
#define y second
#define y1 Y1
#define y2 Y2
#define mp make_pair
#define pb push_back

using namespace std;

typedef long long LL;
typedef pair<int, int> pii;

template <typename T> inline int Chkmax (T &a, T b) { return a < b ? a = b, 1 : 0; }
template <typename T> inline int Chkmin (T &a, T b) { return a > b ? a = b, 1 : 0; }

inline void proc_status()
{
ifstream t ("/proc/self/status");
cerr << string (istreambuf_iterator <char> (t), istreambuf_iterator <char> ()) <<endl;
}

template <typename T> T read ()
{
T sum = 0, fl = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') fl = -1;
for (; isdigit(ch); ch = getchar()) sum = (sum << 3) + (sum << 1) + ch - '0';
return sum * fl;
}

const int Maxn = 3e5 + 100, Mod = 998244353;

inline int Pow (int a, int b)
{
int ans = 1;
while (b) { if (b & 1) ans = 1ll * ans * a % Mod; a = 1ll * a * a % Mod; b >>= 1; }
return ans;
}

int N, inv = Pow(10000, Mod - 2);
int ch[Maxn][3], son[Maxn], fa[Maxn], P[Maxn], A[Maxn], Hash[Maxn];
int Root[Maxn * 20], sum_leaf;

namespace SEG
{
#define ls Tree[root].ch[0]
#define rs Tree[root].ch[1]
#define lson ls, l, mid
#define rson rs, mid + 1, r
int node_cnt;
struct tree
{
int ch[2], sum, tag;
}Tree[Maxn * 20];

inline void push_up (int root) { Tree[root].sum = (Tree[ls].sum + Tree[rs].sum) % Mod; }

inline void push_down (int root)
{
if (Tree[root].tag <= 1) return ;
Tree[ls].tag = 1ll * Tree[ls].tag * Tree[root].tag % Mod;
Tree[ls].sum = 1ll * Tree[ls].sum * Tree[root].tag % Mod;
Tree[rs].tag = 1ll * Tree[rs].tag * Tree[root].tag % Mod;
Tree[rs].sum = 1ll * Tree[rs].sum * Tree[root].tag % Mod;
Tree[root].tag = 1;
}

inline void update (int &root, int l, int r, int x, int val)
{
if (!root) root = ++node_cnt, Tree[root].tag = 1;
if (l == r) { Tree[root].sum = val; return ; }
int mid = l + r >> 1;
push_down(root);
if (x <= mid) update (lson, x, val);
else update (rson, x, val);
push_up(root);
}

inline int merge (int x, int y, int sumx, int sumy, int p)
{
if (!y) { Tree[x].sum = 1ll * Tree[x].sum * sumy % Mod; Tree[x].tag = 1ll * Tree[x].tag * sumy % Mod; return x; }
if (!x) { Tree[y].sum = 1ll * Tree[y].sum * sumx % Mod; Tree[y].tag = 1ll * Tree[y].tag * sumx % Mod; return y; }
push_down(x), push_down (y);
int xl = Tree[Tree[x].ch[0]].sum, xr = Tree[Tree[x].ch[1]].sum, yl = Tree[Tree[y].ch[0]].sum, yr = Tree[Tree[y].ch[1]].sum;
Tree[x].ch[0] = merge (Tree[x].ch[0], Tree[y].ch[0], (sumx + 1ll * (1 - p + Mod) % Mod * xr % Mod) % Mod, (sumy + 1ll * (1 - p + Mod) % Mod * yr % Mod) % Mod, p);
Tree[x].ch[1] = merge (Tree[x].ch[1], Tree[y].ch[1], (sumx + 1ll * p * xl % Mod) % Mod, (sumy + 1ll * p * yl % Mod) % Mod, p);
push_up(x);
return x;
}

inline int query (int root, int l, int r)
{
if (l == r) return 1ll * l * Hash[l] % Mod * Tree[root].sum % Mod * Tree[root].sum % Mod;
int mid = l + r >> 1;
push_down(root);
return (query (lson) + query(rson)) % Mod;
}
}

inline void dfs (int x)
{
if (ch[x][0]) dfs(ch[x][0]);
if (ch[x][1]) dfs(ch[x][1]);
if (!son[x]) SEG :: update (Root[x], 1, sum_leaf, A[x], 1);
else if (son[x] == 1) Root[x] = Root[ch[x][0]];
else Root[x] = SEG :: merge (Root[ch[x][0]], Root[ch[x][1]], 0, 0, P[x]);
}

inline void Solve ()
{
dfs(1);
printf("%d\n", SEG :: query (Root[1], 1, sum_leaf));
}

inline void Input ()
{
N = read<int>();
for (int i = 1; i <= N; ++i) fa[i] = read<int>(), ch[fa[i]][son[fa[i]]++] = i;
for (int i = 1; i <= N; ++i)
{
if (!son[i]) A[i] = Hash[++sum_leaf] = read<int>();
else P[i] = 1ll * read<int>() * inv % Mod;
}

sort(Hash + 1, Hash + sum_leaf + 1);
for (int i = 1; i <= N; ++i) A[i] = lower_bound(Hash + 1, Hash + sum_leaf + 1, A[i]) - Hash;
}

int main()
{
#ifndef ONLINE_JUDGE
freopen("A.in", "r", stdin);
freopen("A.out", "w", stdout);
#endif
Input();
Solve();
return 0;
}