给出一棵 nn 个点的无根树,在这棵树上选三个互不相同的节点,使得这个三个节点两两之间距离相等,输出方案数

n105n\le 10^5

BZOJ4543

Solution

不难发现,三个点两两的距离相同只有两种情况:

  1. 存在一个三个点公共的lcalca

  2. 存在一个点,使得这个点到另外两个子树中距离它为dd的点以及这个点的dd次祖先

本质上来说这其实也是同一类

考虑树形dp,通过枚举中心点来统计答案

f[x][i]f[x][i]表示在xx的子树中,距离xxii的点数

g[x][i]g[x][i]表示在xx的子树中,两个点的lcalcaxx的距离为did-i的点数,其中dd为两个点到lcalca的距离

计算答案的话就是ans+=f[y][i1]g[x][i]+f[x][i1]g[y][i]ans += f[y][i-1]*g[x][i]+f[x][i-1]*g[y][i]

更新 的话就是

f[x][i]+=f[y][i1]g[x][i1]+=g[y][i]g[x][i+1]+=f[x][i+1]f[y][i] \begin{aligned} f[x][i] &+= f[y][i - 1]\\ g[x][i-1] &+= g[y][i]\\ g[x][i+1] &+= f[x][i+1]*f[y][i] \end{aligned}

注意先更新答案再转移 ,这样就获得了一个O(n2)O(n^2)的做法


接下来介绍一下长链剖分:

观察dpdp式子可以发现,第二维状态只与深度有关,于是可以考虑使用长链剖分来维护这个dp

具体地,类似于轻重链剖分/dsu on tree的方法,每个点的重儿子定为子树内叶子深度最大的儿子

不难发现,如果只有一个儿子,我们可以直接将数组赋值

用指针实现的话实际上就是f[x]=f[y]1,g[x]=g[y]+1f[x] = f[y] - 1, g[x] = g[y] + 1

于是我们就做到了O(1)O(1)继承重儿子信息,而对于轻儿子则暴力合并

这样做总复杂度是O(n)O(n)

下面是对复杂度的证明:

不难发现所有轻儿子都是某一条重链的顶部,转移时的复杂度是重链长度

那么复杂度拆分成两个部分:直接从重儿子转移O(1)O(1),从轻儿子转移O(heavy_len)O(heavy\_len)

因为每个点有且仅有一个父亲,所以一条重链只被一个点暴力转移,而每次转移复杂度是重链长

所以总复杂度就是O(heavy_len)=O(n)O(\sum heavy\_len)=O(n)


所以如果需要快速合并这一类与深度相关的信息,就可以使用长链剖分了

如果不是与深度相关的话就可以考虑dsu on tree/线段树合并

具体写代码的时候有一些关于指针、动态分配内存的操作,需要熟练

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include <bits/stdc++.h>

#define x first
#define y second
#define y1 Y1
#define y2 Y2
#define mp make_pair
#define pb push_back

using namespace std;

typedef long long LL;
typedef pair<int, int> pii;

template <typename T> inline int Chkmax (T &a, T b) { return a < b ? a = b, 1 : 0; }
template <typename T> inline int Chkmin (T &a, T b) { return a > b ? a = b, 1 : 0; }

inline void proc_status()
{
ifstream t ("/proc/self/status");
cerr << string (istreambuf_iterator <char> (t), istreambuf_iterator <char> ()) <<endl;
}

template <typename T> T read ()
{
T sum = 0, fl = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') fl = -1;
for (; isdigit(ch); ch = getchar()) sum = (sum << 3) + (sum << 1) + ch - '0';
return sum * fl;
}

const int Maxn = 1e5 + 100;

int N, e, Begin[Maxn], To[Maxn << 1], Next[Maxn << 1];

inline void add_edge (int x, int y) { To[++e] = y; Next[e] = Begin[x]; Begin[x] = e; }

LL Pool[Maxn << 2];
LL *f[Maxn], *g[Maxn], *now(Pool + 1), ans;
int dep[Maxn], maxdep[Maxn], son[Maxn];

inline void new_node (int x)
{
/**/f[x] = now, now += ((maxdep[x] - dep[x] + 1) << 1);
/**/g[x] = now, now += ((maxdep[x] - dep[x] + 1) << 1);
}

inline void dfs_pre (int x, int fa)
{
maxdep[x] = dep[x] = dep[fa] + 1;
for (int i = Begin[x]; i; i = Next[i])
{
int y = To[i];
if (y == fa) continue;
dfs_pre (y, x);
Chkmax(maxdep[x], maxdep[y]);
if (maxdep[y] > maxdep[son[x]]) son[x] = y;
}
}

inline void dfs (int x, int fa)
{
if (son[x])
{
f[son[x]] = f[x] + 1, g[son[x]] = g[x] - 1;
dfs(son[x], x);
}
f[x][0] = 1, ans += g[x][0];
for (int i = Begin[x]; i; i = Next[i])
{
int y = To[i];
if (y == fa || y == son[x]) continue;
new_node(y);
dfs(y, x);
for (int j = 1; j <= maxdep[y] - dep[x]; ++j)
ans += 1ll * f[y][j - 1] * g[x][j] + 1ll * f[x][j - 1] * g[y][j];
for (int j = 0; j <= maxdep[y] - dep[x]; ++j)
{
if (j) f[x][j] += f[y][j - 1];
if (j) g[x][j - 1] += g[y][j];
g[x][j + 1] += 1ll * f[x][j + 1] * f[y][j];
}
}
}

inline void Solve ()
{
ans = 0;
dfs_pre(1, 0);
new_node(1);
dfs(1, 0);
cout<<ans<<endl;
}

inline void Input ()
{
for (int i = 1; i < N; ++i) { int x = read<int>(), y = read<int>(); add_edge (x, y), add_edge (y, x); }
}

inline void Init ()
{
e = 0;
memset(Begin, 0, sizeof Begin);
memset(son, 0, sizeof son);
while (now != Pool) *now = 0, --now; *now = 0; now = Pool + 1;
}

int main()
{
#ifndef ONLINE_JUDGE
freopen("three.in", "r", stdin);
freopen("three.out", "w", stdout);
#endif
while (scanf("%d", &N) != EOF)
{
if (!N) break;
Init();
Input();
Solve();
}
return 0;
}

Debug

  • 44, 45L: 动态分配内存的时候只要分配(maxdep[x]dep[x]+1)2(maxdep[x] - dep[x] + 1) * 2,一开始照着yyb博客里分配了maxdep[x]2maxdep[x] * 2,结果RE了