一颗nn个节点的树,可以花pip_i代价把ii点染色,要求任2个相邻点至少有1个被染色。

mm组询问,每次强制两个点的状态(染/不染),求出每次的最小花费。

n,m,pi100000n,m,p_i\le 100000

Luogu P5024

Solution

暴力dp:f[i][0/1]f[i][0/1]表示以ii为根的子树中所有节点,ii号点不选/选,所花费的最小代价

考虑在此基础上再多记录一些东西

  • g[i][0/1]g[i][0/1]表示以ii号点不选/选,整棵树的最小代价

    这个东西可以通过ff求出

  • dp[i][j][0/1][0/1]dp[i][j][0/1][0/1]表示ii号点向上跳2j2^j的父亲ff的子树去掉以ii为根的子树后,ii号点不选/选,ff号点不选/选,所花费的最小代价(这里的去掉是指当作这个子树不存在,而并不仅仅是把子树的答案减掉)

    这个东西可以通过倍增预处理出来

接下来,询问就很好处理了,分两种情况讨论:

  • 一个点为另一个点的祖先:直接把较深的节点aa跳到另一个点bb的儿子处,然后讨论一下bb的选择情况

  • 否则,先把两个点都跳到它们lcalca的儿子节点处,然后讨论lcalca的选择情况

然后用预处理出的ffggdpdp算答案即可(具体见代码)

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#include <bits/stdc++.h>

#define x first
#define y second
#define y1 Y1
#define y2 Y2
#define mp make_pair
#define pb push_back

using namespace std;

inline void proc_status ()
{
ifstream t ("/proc/self/status");
cerr << string (istreambuf_iterator <char> (t), istreambuf_iterator <char> ()) << endl;
}

typedef long long LL;
typedef pair <int, int> pii;

template <typename T> inline int Chkmin (T &a, T b) {return a > b ? a = b, 1 : 0; }
template <typename T> inline int Chkmax (T &a, T b) {return a < b ? a = b, 1 : 0; }

inline int read ()
{
int sum = 0, fl = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') fl = -1;
for (; isdigit(ch); ch = getchar()) sum = (sum << 3) + (sum << 1) + ch - '0';
return sum * fl;
}

const int Maxn = 1e5 + 100;
const LL inf = 11000000000;

int N, M, A[Maxn];
int e, Begin[Maxn], To[Maxn << 1], Next[Maxn << 1];

inline void add_edge (int x, int y) { To[++e] = y; Next[e] = Begin[x]; Begin[x] = e; }

int a, val1, b, val2;
int anc[Maxn][22], dep[Maxn];
LL f[Maxn][2], g[Maxn][2];

struct mat
{
LL a[2][2];
}Dp[Maxn][20];

inline void dfs1 (int x, int father)
{
anc[x][0] = father;
dep[x] = dep[father] + 1;
for (int i = 1; i <= 17; ++i) anc[x][i] = anc[anc[x][i - 1]][i - 1];
for (int i = Begin[x]; i; i = Next[i])
{
int y = To[i];
if (y == father) continue;
dfs1(y, x);
f[x][0] += f[y][1];
f[x][1] += min(f[y][0], f[y][1]);
}
f[x][1] += A[x];
}

inline void dfs2 (int x)
{
for (int i = Begin[x]; i; i = Next[i])
{
int y = To[i];
if (y == anc[x][0]) continue;
Dp[y][0].a[0][0] = inf;
Dp[y][0].a[1][0] = f[x][0] - f[y][1];
Dp[y][0].a[0][1] = Dp[y][0].a[1][1] = f[x][1] - min(f[y][0], f[y][1]);
LL sum = g[x][1] - min(f[y][0], f[y][1]);
g[y][0] = f[y][0] + sum;
g[y][1] = f[y][1] + min(sum, g[x][0] - f[y][1]);
dfs2(y);
}
}

inline mat merge (const mat &A, const mat &B)
{
mat C;
memset(C.a, 0x3f, sizeof C.a);
for (int a = 0; a < 2; ++a)
for (int b = 0; b < 2; ++b)
for (int c = 0; c < 2; ++c)
Chkmin(C.a[a][b], A.a[a][c] + B.a[c][b]);
return C;
}

inline void Solve ()
{
dfs1(1, 0);
g[1][0] = f[1][0];
g[1][1] = f[1][1];
dfs2(1);
for (int j = 1; j <= 17; ++j)
{
for (int i = 1; i <= N; ++i)
Dp[i][j] = merge (Dp[i][j - 1], Dp[anc[i][j - 1]][j - 1]);
}
while (M--)
{
a = read(), val1 = read(), b = read(), val2 = read();
if (dep[a] < dep[b]) swap(a, b), swap(val1, val2);
if (!val1 && !val2 && anc[a][0] == b) { puts("-1"); continue; }

mat A, B;
memset(A.a, 0x3f, sizeof A.a), memset(B.a, 0x3f, sizeof B.a);
A.a[val1][val1] = f[a][val1], B.a[val2][val2] = f[b][val2];

for (int i = 17; i >= 0; --i)
if (dep[anc[a][i]] > dep[b])
A = merge (A, Dp[a][i]), a = anc[a][i];

if (anc[a][0] == b)
{
if (!val2) printf("%lld\n", g[b][0] - f[a][1] + A.a[val1][1]);
else printf("%lld\n", g[b][1] - min(f[a][0], f[a][1]) + min(A.a[val1][0], A.a[val1][1]));
}
else
{
if (dep[a] > dep[b]) A = merge (A, Dp[a][0]), a = anc[a][0];
for (int i = 17; i >= 0; --i)
if (anc[a][i] != anc[b][i])
{
A = merge (A, Dp[a][i]), B = merge (B, Dp[b][i]);
a = anc[a][i], b = anc[b][i];
}

int lca = anc[a][0];
LL sum0 = g[lca][0] - f[a][1] - f[b][1];
LL sum1 = g[lca][1] - min(f[a][0], f[a][1]) - min(f[b][0], f[b][1]);
sum0 += A.a[val1][1] + B.a[val2][1];
sum1 += min(A.a[val1][0], A.a[val1][1]) + min(B.a[val2][0], B.a[val2][1]);
printf("%lld\n", min(sum0, sum1));
}
}
}

inline void Input ()
{
N = read(), M = read();
char type[3];
scanf("%s", type);
for (int i = 1; i <= N; ++i) A[i] = read();
for (int i = 1; i < N; ++i)
{
int x = read(), y = read();
add_edge (x, y);
add_edge (y, x);
}
}

int main()
{
#ifdef hk_cnyali
freopen("defense.in", "r", stdin);
freopen("defense.out", "w", stdout);
#endif
Input();
Solve();
return 0;
}