P3369【模板】普通平衡树复盘:01Trie 维护有序多重集

前言

P3369 【模板】普通平衡树 是一道经典的数据结构模板题。常规做法通常是 Treap、Splay、fhq Treap 或 pbds。

这篇文章记录一种非常规实现:用 01Trie 维护有序多重集合
只要值域可控,题目要求的六类操作:

  • 插入
  • 删除
  • 查询排名
  • 查询第 kk
  • 查询前驱
  • 查询后继

都可以通过 Trie 上的计数信息完成。


一、问题本质

题目要求维护一个支持顺序统计的可重集。
从本质上看,它并不强制要求使用“平衡树”,而是要求维护以下信息:

  1. 元素集合中的有序性
  2. 某个值前面有多少元素
  3. 某个排名对应哪个值

若能在 01Trie 上维护“经过某节点的元素个数”,就可以完成这些操作。


二、核心思路

01Trie 按二进制从高位到低位建树:

  • 左儿子表示当前位为 0
  • 右儿子表示当前位为 1

在高位前缀相同的前提下:

左子树中的所有数一定小于右子树中的所有数

因此,Trie 也具备“局部有序性”。
若再维护每个节点经过了多少个元素,就可以像顺序统计树一样支持:

  • 查询严格小于某值的元素个数
  • 查询第 kk 小元素

前驱、后继则可由上述两个操作进一步推导。


三、值域平移

代码中使用了:

1
const int offset = 1e7;

所有输入值统一平移为:

1
x + offset

目的是将原本可能出现的负数整体映射到非负范围,便于按普通二进制进行 Trie 维护。

若原值范围为:

[107, 107] [-10^7,\ 10^7]

则平移后范围为:

[0, 2×107] [0,\ 2\times 10^7]

输出答案时再减去 offset 即可恢复原值。


四、位数选择

代码按如下方式枚举二进制位:

1
for (int i = 25; i >= 0; --i)

原因是平移后的最大值约为 2×1072\times 10^7。而:

  • 224=167772162^{24}=16777216
  • 225=335544322^{25}=33554432

因此需要使用第 25 位到第 0 位,共 26 位,足以覆盖全部取值。


五、维护的信息

核心数组如下:

1
int trie[maxn][2], cnt[maxn];

含义为:

  • trie[p][0]:节点 p 的 0 儿子
  • trie[p][1]:节点 p 的 1 儿子
  • cnt[p]:经过节点 p 的元素个数

这里的 cnt[p] 可以理解为:
以该节点为根的这棵子树中,共有多少个元素经过这个前缀。

由于所有数都固定走满 26 层,因此不需要单独维护结束标记。


六、插入与删除

1. 插入

1
2
3
4
5
6
7
8
9
10
11
void Insert(int x) {
int p = 0;
for (int i = 25; i >= 0; --i) {
int v = x >> i & 1;
if (!trie[p][v]) {
trie[p][v] = ++idx;
}
p = trie[p][v];
cnt[p]++;
}
}

从高位到低位依次取出每一位:

  • 若对应儿子不存在,则新建节点
  • 沿路径向下走
  • 将经过节点的计数加一

插入完成后,x 所在路径上的所有节点计数均被正确维护。


2. 删除

1
2
3
4
5
6
7
8
void Delete(int x) {
int p = 0;
for (int i = 25; i >= 0; --i) {
int v = x >> i & 1;
p = trie[p][v];
cnt[p]--;
}
}

删除时沿原路径走一遍,并将沿途 cnt 减一即可。
不必真正删除节点,只需保证计数正确。

该写法默认题目保证删除操作合法,即待删除元素一定存在。


七、排名查询 getRank

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int getRank(int x) {
int p = 0, rank = 0;
for (int i = 25; i >= 0; --i) {
int v = x >> i & 1;
if (v) {
rank += cnt[trie[p][0]];
}
p = trie[p][v];
if (!p) {
break;
}
}
return rank;
}

该函数返回的不是题目中的“排名”,而是:

严格小于 x 的元素个数

因此主函数中查询排名时需要输出:

1
getRank(x + offset) + 1

正确性分析

从高位到低位考虑当前位。

设当前位为 v

v = 0

若某个数在当前位取 1,则在高位前缀相同的前提下,它一定大于 x
因此不会对“小于 x 的元素个数”产生贡献,直接沿 0 分支继续即可。

v = 1

此时,所有高位前缀与 x 相同、但当前位取 0 的数,一定严格小于 x
因此可以直接累计:

1
rank += cnt[trie[p][0]];

随后继续沿 1 分支向下,统计剩余部分。

提前退出

若某一步 p 变为 0,说明当前前缀已不存在,后续更低位不可能再产生贡献,可以直接结束。


八、第 kk 小查询 getVal

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
int getVal(int x) {
int p = 0, val = 0;
for (int i = 25; i >= 0; --i) {
if (cnt[trie[p][0]] >= x) {
p = trie[p][0];
} else {
x -= cnt[trie[p][0]];
p = trie[p][1];
val |= 1 << i;
}
}
return val;
}

该函数返回当前集合中的第 x 小元素,x 为 1-based。


正确性分析

在某个 Trie 节点处:

  • 左子树对应当前位为 0
  • 右子树对应当前位为 1

由于当前位 0 < 1,因此左子树中的所有元素都小于右子树中的所有元素。

于是:

  • 若左子树元素个数不少于 x,则第 x 小一定在左子树中
  • 否则,第 x 小一定在右子树中,同时应减去左子树的元素个数

若走向右子树,则当前位应置为 1:

1
val |= 1 << i;

最终构造出完整数值。


九、前驱与后继

1. 前驱

1
2
3
int getPrev(int x) {
return getVal(getRank(x));
}

前驱定义为“严格小于 x 的最大值”。

设严格小于 x 的元素个数为 k,那么这些元素中最大的一个,恰好就是第 k 小元素,因此:

Prev(x)=getVal(getRank(x)) \text{Prev}(x)=\text{getVal}(\text{getRank}(x))

2. 后继

1
2
3
int getNext(int x) {
return getVal(getRank(x + 1) + 1);
}

后继定义为“严格大于 x 的最小值”。

由于 getRank(y) 返回的是严格小于 y 的元素个数,因此:

cnt(x)=cnt(<x+1) \text{cnt}(\le x)=\text{cnt}(<x+1)

所以:

1
getRank(x + 1)

表示集合中小于等于 x 的元素个数。

那么严格大于 x 的最小元素,其排名就是:

1
getRank(x + 1) + 1

再通过 getVal 取得对应值即可。


十、主函数中的操作对应

1
2
3
4
5
6
7
8
9
10
11
12
13
if (op == 1) {
Insert(x + offset);
} else if (op == 2) {
Delete(x + offset);
} else if (op == 3) {
cout << getRank(x + offset) + 1 << '\n';
} else if (op == 4) {
cout << getVal(x) - offset << '\n';
} else if (op == 5) {
cout << getPrev(x + offset) - offset << '\n';
} else {
cout << getNext(x + offset) - offset << '\n';
}

各操作含义如下:

  • 1 x:插入 x
  • 2 x:删除一个 x
  • 3 x:查询 x 的排名
  • 4 x:查询第 x 小的值
  • 5 x:查询 x 的前驱
  • 6 x:查询 x 的后继

由于 Trie 中维护的是平移后的值,凡是输出实际数值时都需要减去 offset


十一、复杂度分析

设值域大小为 VV

每次操作都只需沿 Trie 从高位走到低位,因此时间复杂度为:

  • 插入:O(logV)O(\log V)
  • 删除:O(logV)O(\log V)
  • 查询排名:O(logV)O(\log V)
  • 查询第 kk 小:O(logV)O(\log V)
  • 查询前驱:O(logV)O(\log V)
  • 查询后继:O(logV)O(\log V)

在本题中,值域长度固定为 26 位,因此单次操作的常数较小。

空间复杂度方面,最坏情况下每插入一个新数都会新建 26 个节点。
若操作规模为 10510^5 级别,则需要开约:

26×105=2.6×106 26\times 10^5 = 2.6\times 10^6

个节点,因此代码中定义:

1
const int maxn = 2.6e6;

十二、完整代码及提交记录

点击展开/折叠 最终AC代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2.6e6, offset = 1e7;

int n, idx, trie[maxn][2], cnt[maxn];

void Insert(int x) {
int p = 0;
for (int i = 25; i >= 0; --i) {
int v = x >> i & 1;
if (!trie[p][v]) {
trie[p][v] = ++idx;
}
p = trie[p][v];
cnt[p]++;
}
}

void Delete(int x) {
int p = 0;
for (int i = 25; i >= 0; --i) {
int v = x >> i & 1;
p = trie[p][v];
cnt[p]--;
}
}

int getRank(int x) {
int p = 0, rank = 0;
for (int i = 25; i >= 0; --i) {
int v = x >> i & 1;
if (v) {
rank += cnt[trie[p][0]];
}
p = trie[p][v];
if (!p) {
break;
}
}
return rank;
}

int getVal(int x) {
int p = 0, val = 0;
for (int i = 25; i >= 0; --i) {
if (cnt[trie[p][0]] >= x) {
p = trie[p][0];
} else {
x -= cnt[trie[p][0]];
p = trie[p][1];
val |= 1 << i;
}
}
return val;
}

int getPrev(int x) {
return getVal(getRank(x));
}

int getNext(int x) {
return getVal(getRank(x + 1) + 1);
}

int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> n;
for (int i = 1; i <= n; i++) {
int op, x;
cin >> op >> x;
if (op == 1) {
Insert(x + offset);
} else if (op == 2) {
Delete(x + offset);
} else if (op == 3) {
cout << getRank(x + offset) + 1 << '\n';
} else if (op == 4) {
cout << getVal(x) - offset << '\n';
} else if (op == 5) {
cout << getPrev(x + offset) - offset << '\n';
} else {
cout << getNext(x + offset) - offset << '\n';
}
}
return 0;
}

可以发现 01trie 写法在性能上要大大优于常规写法。


结语

这份实现的关键在于两点:

  1. 利用 01Trie 的二进制字典序维护数值大小关系
  2. 利用 cnt 实现顺序统计,再由 rankkth 推导前驱、后继

在值域可控的前提下,这是一种实现简洁、复杂度稳定的替代方案。