树链剖分学习笔记

树链剖分是一种解决树上操作问题的有力工具,其本质思想是把树上问题转化为序列问题,再套用序列的数据结构(如线段树,树状数组)进行解决。

$1.$ 预处理

树剖的预处理主要由两次 dfs 组成。

先定义一些东西:

1.一个非叶子节点的重儿子指的是:该节点所有儿子中,子树大小最大的那个儿子的编号。

2.重边指的是:任意一个非叶子节点连向其重儿子的边。

3.轻边是除了重边外的所有边。

4.一条或多条重边相连,组成重链,特别的,一个节点也是一条重链。

例如以下这棵树:
oE7Ye.md.png

我们可以着重画出它的重边:
oEzcO.md.png
那么在这棵树中,有 $5$ 条重链,分别为:

这些东西有什么用呢?这些东西有一些很好的性质,可以帮助我们完成从树上问题到序列问题的转化。

性质 $1$:如果边 $(u,v)$ 是轻边,那么 $\operatorname{size}(v)\le\dfrac{\operatorname{size}(u)}{2}$。

证明:可以反证,若存在一条轻边 $(u,v)$ 使 $\operatorname{size}(v)>\operatorname{size}(u)/2$ 成立,那么 $\operatorname{size}(v)$ 肯定是$u$所有儿子中最大的,那么 $(u,v)$ 就应该是重边,与假设不符。

性质 $2$:在一棵有 $n$ 个节点的树中,根节点到任一节点 $v$ 的路径上,至多存在 $\log n$ 条轻边。

证明:显然在 $v$ 为叶子节点时,可以经过的轻边最多,那么根据性质 $1$ 可以知道,从根节点每往下走一层,若走的是轻边的话,该子树的 $\operatorname{size}$ 至少会除以2,所以最多走 $\log n$ 条轻边便会走到叶子节点。

性质 $3$:在一棵有 $n$ 个节点的树中,根节点到任一节点 $v$ 的路径上,至多存在 $\log n$ 条重链。

证明:显然我们可以发现,一条重链的两端肯定是轻边,而轻边最多有 $\log n$ 条,那么重链也最多有 $\log n$ 条。

性质 $4$:对于任意两点间的路径 $(u,v)$,这条路径上的重链条数是 $\log n$ 级的。

证明:考虑最坏情况,即 $\operatorname{lca}(u,v)=root$,那么$u$到根节点的重链条数是 $\log n$ 级的,$v$ 到根节点的重链条数也是 $\log n$ 级的。所以可以发现对于任意的路径,重链条数均为 $\log n$ 级,这为计算树剖的时间复杂度提供了基础。

再给出一些数组的定义:

size[i]表示以 $i$ 为根的子树大小(包括 $i$ 本身)。

dep[i]表示节点 $i$ 的深度(不妨设根节点的深度为 $1$)。

fa[i]表示节点 $i$ 的父亲的节点。

son[i]表示非叶子节点 $i$ 的的重儿子的编号。

第一次 dfs 就是要求出这些值,以上面那棵树为例(需要注意的是,第一次 dfs 对于访问子节点的顺序是无所谓的):

fa[] = { 0, 0, 1, 1, 1, 2, 2, 3, 6, 6, 7}
dep[] = {0, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4}
size[] ={0,10, 5, 3, 1, 1, 2, 1, 1, 1, 1}
son[] = {0, 0, 2, 7, 0, 0, 8,10, 0, 0, 0}

贴上第一次 dfs 的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
int son[N], size[N], dep[N], fa[N];
void dfs1(int now, int Fa){
int v, Max = -1;
size[now] = 1;
for (int i = head[now]; i; i = e[i].nxt){
v = e[i].to;
if (v == Fa)continue;
fa[v] = now;
dep[v] = dep[now] + 1;
dfs1(v, now);
size[now] += size[v];
if (size[v] > Max)
Max = size[v], son[now] = v;
}
}

接下来是第二次 dfs ,完成这次 dfs 后,便可以实现从树上问题到序列问题的转化了。

还是一样,先知道目前需要计算的几个数组:

top[i]表示节点 $i$ 所在的重链的链顶的编号。

id[i]表示将树转化为序列后,节点 $i$ 的信息保存在下标为id[i]的数据结构(如线段树)中。

bui[i]表示将树转化为序列后的数组。

不妨先思考如何计算top[i]

对于边 $(u,v)$,如果这条边是重边,显然 $\operatorname{top}(v)=\operatorname{top}(u)$,反之,则 $\operatorname{top}(v)=v$(也就是自己是自己这条重链的起点)
在进行第二次 dfs 之前,需要先了解重链剖分是如何的:

我们需要保证,在转成序列之后,每条重链的序列上都是连续的.又根据性质四,我们将树上两点的路径拆成不超过$\log n$条重链,并分别维护。

综上所述,为了保证重链的编号连续,我们遍历到每一个非叶子节点时,都要优先遍历其重儿子
所以第二次dfs的代码便得出了(建议结合代码看上面的讲解):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int sum, top[N], id[N], input[N], bui[N];//input[]表示初始输入的点权
void dfs2(int now, int topnow){
id[now] = ++sum;
bui[sum] = input[now];
top[now] = topnow;
if (!son[now])return;
dfs2(son[now], topnow);
int v;
for (int i = head[now]; i; i = e[i].nxt){
v = e[i].to;
if (!id[v])
dfs2(v, v);
}
}

oEzcO.md.png
那么对于前面这棵树,top数组和id数组的计算结果为(由于未给出点权,故只展现这两个数组的计算结果):
top[] = {0, 1, 1, 3, 4, 5, 1, 3, 1, 9, 3}
id[] = { 0, 1, 2, 7,10, 6, 3, 8, 4, 5, 9}

$2.$ 操作

有关子树的操作

这部分的内容相对简单,因为对于一棵子树下的点,他们转到序列上后必然对应着一段连续的区间
比如上面那棵树以 $2$ 为根的子树,可以发现 $2,6,8,9,5$ 所对应的的 $\operatorname{id}\text{分别为}2,3,4,5,6$,而区间长度自然为 $\operatorname{size}(2)=5$,所以每次对于以 $i$ 为根的子树,在序列上所对应的区间为 $[\operatorname{id}(i), \operatorname{id}(i)+\operatorname{size}(i))$。
放代码

1
2
3
4
5
6
void update2(int x, int z){//把以x为根的子树的点权全部加z
update(id[x], id[x] + size[x] - 1, z);
}
int query2(int x){//查询以x为根的子树的点权和
return query(id[x], id[x] + size[x] - 1);
}

显然,单次操作的时间复杂度为 $\operatorname{O}(\log n)$。

有关树上路径的操作

这一部分的操作稍为复杂。

不妨设对边$(u,v)$进行操作。首先考虑一种简单的情况,若这两个点位于同一条重链,那么因为同一重链的点在序列上的编号是连续的,可以直接进行修改。

更进一步,如果在不同的链上呢?
跳到同一条链上就好了啊!
我们采用类似倍增求 LCA 的方法,让两个点不断往上跳,直到跳到同一条重链上,并且边跳边维护。

以上面那棵树为例,比如说我们要维护路径 $(9,7)$,可以看到 $9$ 的链顶是 $9$,$7$ 的链顶是 $3$,由于 $9$ 比 $3$ 深,所以我们让 $9$ 跳(想一想,为什么让链顶深度大的跳)。所以维护的是 $[\operatorname{id}(9),\operatorname{id}(9)]$ 这个区间。
跳完后,让 $9$ 变成它链顶的父节点,也就是 $6$.现在需要维护的是 $(6,7)$,
$6$ 的链顶是 $1$,$7$ 的链顶是 $3$,$3$ 更深,所以维护 $[\operatorname{id}(3),\operatorname{id}(7)]$,$7$ 变成 $3$ 的父节点 $1$。
这时,$6$ 和 $1$ 在同一条重链上,直接维护。

回顾刚才的操作,分析时间复杂度。可以发现决定时间复杂度的是两点间路径所包含的重链的条数。我们已经知道了这个数是$\log n$级的,所以总操作的时间复杂度是$\log ^2n$.
放代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
void update1(int x, int y, int z){//将路径(x,y)上的点权全部加z
while(top[x] != top[y]){
if (dep[top[x]] < dep[top[y]]) swap(x, y);
update(id[top[x]], id[x], z);
x = fa[top[x]];
}
if (dep[x] > dep[y])swap(x, y);
update(id[x], id[y], z);
}
int query1(int x, int y){//查询路径(x,y)上的点权和(对p取模)
int s = 0;
while(top[x] != top[y]){
if (dep[top[x]] < dep[top[y]]) swap(x, y);
s = (s + query(id[top[x]], id[x])) % p;
x = fa[top[x]];
}
if (dep[x] > dep[y]) swap(x, y);
s = (s + query(id[x], id[y])) % p;
return s;
}

P3384的完整代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#include <iostream>
#include <cstdio>
#include <cstdlib>
using namespace std;
const int N = 1e5 + 5;
struct edge{int nxt, to;}e[200001];
int head[N], cnt;
void add(int u, int v){
e[++cnt].nxt = head[u];
e[cnt].to = v;
head[u] = cnt;
}
int son[N], size[N], dep[N], fa[N];
void dfs1(int now, int Fa){
int v, Max = -1;
size[now] = 1;
for (int i = head[now]; i; i = e[i].nxt){
v = e[i].to;
if (v == Fa)continue;
fa[v] = now;
dep[v] = dep[now] + 1;
dfs1(v, now);
size[now] += size[v];
if (size[v] > Max)
Max = size[v], son[now] = v;
}
}
int sum, top[N], id[N], input[N], bui[N];
void dfs2(int now, int topnow){
id[now] = ++sum;
bui[sum] = input[now];
top[now] = topnow;
if (!son[now])return;
dfs2(son[now], topnow);
int v;
for (int i = head[now]; i; i = e[i].nxt){
v = e[i].to;
if (!id[v])
dfs2(v, v);
}
}

int n, p;
struct Seg_T{
int l, r, s, tag;
}t[N << 2];
#define ls(x) x << 1
#define rs(x) x << 1 | 1
void build(int l = 1, int r = n, int now = 1){
t[now].l = l, t[now].r = r;
if (l == r){
t[now].s = bui[l] % p;
return;
}
int mid = l + r >> 1;
build(l, mid, ls(now));
build(mid + 1, r, rs(now));
t[now].s = (t[ls(now)].s + t[rs(now)].s) % p;
}
void pushdown(int now){
if (!t[now].tag)return;
t[ls(now)].tag = (t[ls(now)].tag + t[now].tag) % p;
t[rs(now)].tag = (t[rs(now)].tag + t[now].tag) % p;
t[ls(now)].s = (t[ls(now)].s + t[now].tag * (t[ls(now)].r - t[ls(now)].l + 1) % p) % p;
t[rs(now)].s = (t[rs(now)].s + t[now].tag * ((t[rs(now)].r - t[rs(now)].l + 1) % p) % p) % p;
t[now].tag = 0;
}
void update(int l, int r, int k, int now = 1){
if (l <= t[now].l && t[now].r <= r){
t[now].tag = (t[now].tag + k) % p;
t[now].s = (t[now].s + k * (t[now].r - t[now].l + 1) % p) % p;
return;
}
pushdown(now);
if (t[ls(now)].r >= l)update(l, r, k, ls(now));
if (t[rs(now)].l <= r)update(l, r, k, rs(now));
t[now].s = (t[ls(now)].s + t[rs(now)].s) % p;
}
int query(int l, int r, int now = 1){
if (l <= t[now].l && t[now].r <= r)return t[now].s;
int s = 0;
pushdown(now);
if (t[ls(now)].r >= l)s = query(l, r, ls(now));
if (t[rs(now)].l <= r)s = (s + query(l, r, rs(now))) % p;
return s;
}

void update1(int x, int y, int z){
while(top[x] != top[y]){
if (dep[top[x]] < dep[top[y]]) swap(x, y);
update(id[top[x]], id[x], z);
x = fa[top[x]];
}
if (dep[x] > dep[y])swap(x, y);
update(id[x], id[y], z);
}
int query1(int x, int y){
int s = 0;
while(top[x] != top[y]){
if (dep[top[x]] < dep[top[y]]) swap(x, y);
s = (s + query(id[top[x]], id[x])) % p;
x = fa[top[x]];
}
if (dep[x] > dep[y]) swap(x, y);
s = (s + query(id[x], id[y])) % p;
return s;
}
void update2(int x, int z){
update(id[x], id[x] + size[x] - 1, z);
}
int query2(int x){
return query(id[x], id[x] + size[x] - 1);
}

inline void read(int &x){
int f = 1;char ch;
while(ch = getchar(), ch < '0' || ch > '9')if(ch == '-')f = -1;
x = ch - '0';
while(ch = getchar(), ch >= '0' && ch <= '9')
x = x * 10 + ch - '0';
x *= f;
}
int main (){
int m, root;
read(n), read(m), read(root), read(p);
for (int i = 1; i <= n; i++)
read(input[i]);
int u, v;
for (int i = 1; i < n; i++){
read(u), read(v);
add(u, v);
add(v, u);
}
dep[root] = 1;
dfs1(root, 0);
dfs2(root, root);
build();
int ch, x, y, z;
for (int i = 1; i <= m; i++){
read(ch);
switch(ch){
case 1:read(x), read(y), read(z);update1(x, y, z);break;
case 2:read(x), read(y);printf("%d\n", query1(x, y));break;
case 3:read(x), read(z);update2(x, z);break;
case 4:read(x);printf("%d\n", query2(x));
}
}
system("pause");
return 0;
}

+ +