概述
几何深度学习(Geometric Deep Learning, GDL)是 Michael Bronstein、Joan Bruna、Taco Cohen、Petar Veličković 等人推动的统一框架1,将 CNN、GNN、Transformer、DeepSets 等多种深度学习架构纳入对称性原理(symmetry principles)的统一视角。
GDL 的核心论断:几乎所有成功的深度学习架构都是从数据所在几何域的对称性出发的归纳偏置。理解这些对称性,就能理解为什么某些架构在特定任务上有效,并能系统设计新的架构。
本文档深入讲解 GDL 的数学基础(群作用、商空间、规范),五大几何域(网格、群、图、规范、齐次空间),以及实际架构的对称性分析。
1. 几何深度学习的动机
1.1 传统深度学习的局限
传统深度学习假设数据处于欧氏空间 :
# 传统 MLP
y = nn.Linear(d_in, d_out)(x) # 假设 x ∈ ℝ^n但许多实际数据并非如此:
- 图像:2D 网格(平移、旋转对称)
- 文本:序列(置换对称)
- 图:图结构(节点重排对称)
- 3D 点云:3D 空间(旋转、平移对称)
- 球面信号:球面上的函数(SO(3) 对称)
关键问题:如何为这些非欧氏数据设计有效的神经网络?
1.2 GDL 的回答
GDL 的核心思想:先识别数据的对称性,再设计等变(equivariant)架构。
# GDL 风格
class EquivariantLayer(nn.Module):
"""对对称变换 g 等变的层"""
def forward(self, x):
# f(g·x) = g·f(x)
return self.transform(x)2. 群论基础
2.1 群的基本定义
群(Group)是集合 与运算 满足:
- 封闭性:
- 结合律:
- 单位元:
- 逆元:
2.2 深度学习中的常见群
| 群 | 符号 | 阶 | 几何含义 |
|---|---|---|---|
| 循环群 | 阶旋转 | ||
| 平移群 | 无限 | 欧氏空间平移 | |
| 旋转群 | SO() | 无限 | 维旋转 |
| 欧氏群 | E() | 无限 | 平移 + 旋转 |
| 置换群 | 元素置换 | ||
| 对称群 | O() | 无限 | 旋转 + 反射 |
2.3 群作用
群作用(Group Action):群 在集合 上的作用是映射:
满足:
- (单位元作用)
- (结合律)
2.4 表示理论
群表示(Group Representation)是群到向量空间线性变换的同态:
其中 是表示空间, 是 上的可逆线性变换群。
核心要求:,。
2.5 等变与不变
定义:函数 关于群作用是对称的,如果存在:
不变性(Invariance)
的输出不随 变化。
等变性(Equivariance)
的输出按 变换。
关键洞察:等变性比不变性更强,保留了更多信息。
3. 五大几何域
Bronstein et al. 在 GDL 蓝皮书中提出五大几何域:
3.1 网格(Grids)
定义
网格是规则离散化的空间:
其中 是维度(图像 ,视频 )。
对称性
平移群 :
反射:
等变架构:CNN
卷积是平移等变的:
平移后:
设 :
结论:CNN 是平移等变的。
3.2 群(Groups)
定义
群本身也是几何域。例如 SO(3)(3D 旋转群)。
对称性
群乘法 是 SO(3) 的几何变换。
等变架构:Group CNN
处理群上的函数 ,使用群卷积:
例子:SO(3) 等变网络用于 3D 旋转不变的任务。
3.3 图(Graphs)
定义
图 是节点和边的集合。
对称性
节点置换群 :重排节点编号不影响图结构。
等变架构:GNN
GNN 是节点置换等变的:
# GNN 的置换等变性
def gnn_forward(x, adj, perm):
# perm: 重排索引
return permute(gnn_forward(x[perm], adj[perm][perm]), inverse(perm))3.4 齐次空间(Homogeneous Spaces)
定义
齐次空间是 ,其中 是 的子群。
例子:
- 球面 :固定旋转轴的旋转
- 双曲空间
对称性
在 上的作用:。
等变架构
球面 CNN、双曲神经网络等。
3.5 规范(Gauges)
定义
规范理论关注局部参考系的选择。
例子:
- 3D 形状上的每个点都有一个局部切空间
- 不同点的切空间选择不同的”规范”
等变架构
Gauge CNN、Tensor Field Networks 等。
4. GDL 的五大原理
4.1 原理 1:对称性原理
陈述:成功的神经网络架构利用了数据的对称性。
实践:
- CNN:平移对称 → 图像
- RNN:时间平移对称 → 序列
- GNN:置换对称 → 图
- DeepSets:置换对称 → 集合
- Transformer:置换对称 + 全连接 → 序列
4.2 原理 2:尺度分离
陈述:自然数据在不同尺度上有不同结构。
实践:
- CNN:多尺度池化
- GNN:K 跳邻居聚合
- Transformer:层级注意力
4.3 原理 3:等变性
陈述:架构应该对数据对称变换保持等变。
实践:
- 等变 CNN(对旋转)
- 等变 GNN(对置换)
- 等变 Transformer(对某些子群)
4.4 原理 4:几何先验
陈述:在数据所在的几何域上设计架构。
实践:
- 双曲神经网络(树状数据)
- 球面神经网络(球面数据)
- 流形神经网络(流形数据)
4.5 原理 5:不变性 vs 等变性
陈述:根据任务选择不变或等变。
实践:
- 不变性:分类(输出类别不随输入变换)
- 等变性:分割(输出与输入同步变换)
5. 等变网络的数学框架
5.1 等变层的构造
设输入空间 ,输出空间 ,群 同时作用于两者。
等变层 满足:
构造方法:
- 平均池化:(不变)
- 等变卷积:保留等变结构
- Steerable filters:使用等变基函数
5.2 等变卷积
定义:等变卷积 :
性质:。
5.3 Steerable Filter
对于旋转群 SO(2),等变滤波器是旋转的谐波函数:
其中 是 阶谐波。
关键性质:,是旋转等变的。
6. 等变 CNN
6.1 普通 CNN 的对称性
普通 CNN 对平移等变,对旋转不等变。
6.2 E(2) 等变 CNN
对 2D 欧氏变换(平移 + 旋转 + 反射)等变。
架构:
class E2Conv(nn.Module):
"""E(2) 等变卷积"""
def __init__(self, in_channels, out_channels, num_rotations=8):
super().__init__()
self.num_rotations = num_rotations
# 每个旋转角度的滤波器
self.filters = nn.Parameter(
torch.randn(num_rotations, in_channels, out_channels, 3, 3)
)
def forward(self, x):
"""
x: (B, C, H, W)
"""
# 旋转滤波器
rotated_filters = []
for r in range(self.num_rotations):
angle = 2 * np.pi * r / self.num_rotations
rotated = rotate_filter(self.filters[r], angle)
rotated_filters.append(rotated)
# 应用每个旋转的滤波器
outputs = []
for f in rotated_filters:
outputs.append(F.conv2d(x, f, padding=1))
# 拼接
return torch.cat(outputs, dim=1)6.3 Group Equivariant CNN (Cohen & Welling, 2016)
关键创新:直接在群 上做卷积,而不是在 上。
例子:对 C4 群(4 个旋转)的等变 CNN:
7. 等变 GNN
7.1 GNN 的对称性
普通 GNN 对节点置换等变。
7.2 高阶 GNN
考虑边(而不是节点)作为消息传递的基本单位:
这对应二阶 WL 测试。
7.3 等变 GNN 的形式化
设图 ,节点特征 。
置换等变 GNN:
其中 是置换。
8. Transformer 的 GDL 视角
8.1 Transformer 的对称性
Transformer 是置换等变的:
# 验证置换等变性
def is_permutation_equivariant(model, x, perm):
out1 = model(x)
out2 = model(x[perm])[inverse(perm)]
return torch.allclose(out1, out2, atol=1e-6)8.2 Transformer 的 GDL 分类
| 几何域 | Transformer 变体 |
|---|---|
| 序列(置换对称) | Vanilla Transformer |
| 网格(平移对称) | ViT |
| 图(节点置换对称) | Graph Transformer |
| 群(如 SO(3)) | Equivariant Transformer |
8.3 等变 Transformer
将等变约束融入注意力:
对旋转等变需要:
- 使用等变表示(标量 + 矢量)
- 注意力权重保持标量
- 输出通过等变运算组合
9. DeepSets 与点云
9.1 集合的对称性
集合 对元素置换不变。
9.2 DeepSets 架构
Zaheer et al. (NeurIPS 2017) 证明:
定理:任何置换不变函数 可表示为:
其中 是任意函数。
9.3 PointNet
PointNet (Qi et al., CVPR 2017) 是 DeepSets 的实例:
class PointNet(nn.Module):
"""PointNet: 点云分类"""
def __init__(self, num_classes):
super().__init__()
self.phi = nn.Sequential(
nn.Linear(3, 64),
nn.ReLU(),
nn.Linear(64, 1024)
)
self.rho = nn.Sequential(
nn.Linear(1024, 512),
nn.ReLU(),
nn.Linear(512, num_classes)
)
def forward(self, points):
"""
points: (B, N, 3)
"""
# 应用 phi 到每个点
h = self.phi(points) # (B, N, 1024)
# 全局池化(置换不变)
h_global = h.max(dim=1)[0] # (B, 1024)
# 应用 rho
return self.rho(h_global)10. 双曲几何与神经网络
10.1 双曲空间的动机
树状数据(如知识图谱、层次结构)在欧氏空间难以表示,但在双曲空间自然。
10.2 双曲神经网络
在双曲空间 (双曲面模型)中:
- 距离:
- 平行移动:双曲空间测地线的移动
class HyperbolicLinear(nn.Module):
"""双曲空间线性层"""
def __init__(self, in_features, out_features):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_features, in_features))
def forward(self, x):
"""
x: (B, N, D) 在双曲空间
"""
# 映射到切空间
x_tan = log_map(x)
# 线性变换
h = x_tan @ self.weight.T
# 映射回双曲空间
h_hyp = exp_map(h)
return h_hyp11. 流形学习与神经网络
11.1 流形假设
假设:高维数据实际位于低维流形上。
11.2 流形上的神经网络
class ManifoldLayer(nn.Module):
"""流形上的神经网络层"""
def __init__(self, manifold, in_features, out_features):
super().__init__()
self.manifold = manifold # 流形对象
self.weight = nn.Parameter(torch.randn(out_features, in_features))
def forward(self, x):
# 切空间操作
x_tan = self.manifold.log_map(x)
h = x_tan @ self.weight.T
# 映射回流形
return self.manifold.exp_map(h)11.3 常见流形
- 欧氏空间
- 球面
- 双曲空间
- Stiefel 流形:正交矩阵集合
- Grassmann 流形:子空间集合
12. 实践指南
12.1 如何选择几何域
| 数据类型 | 推荐几何域 | 推荐架构 |
|---|---|---|
| 图像 | 网格(平移) | CNN, ViT |
| 视频 | 网格+时间 | 3D CNN, Video Transformer |
| 3D 形状 | 群(SO(3)) | Equivariant Networks |
| 图 | 图(置换) | GNN, Graph Transformer |
| 集合 | 集合(置换) | DeepSets, PointNet |
| 球面信号 | 球面 | Spherical CNN |
| 树状数据 | 双曲空间 | Hyperbolic NN |
12.2 等变 vs 不变选择
选择等变:
- 需要保留空间信息(分割、检测)
- 多个输出,每个输出与输入相关
选择不变:
- 输出是标量(分类)
- 输出不依赖输入的空间变换
12.3 实现库
| 库 | 几何域 | 语言 |
|---|---|---|
| e3nn | E(3) 等变 | Python |
| escnn | E(2) 等变 | Python |
| geomstats | 黎曼几何 | Python |
| PyTorch Geometric | 图 | Python |
| dgl-ke | 知识图谱 | Python/C++ |
13. 局限与挑战
13.1 计算复杂度
等变网络通常比非等变网络慢(如等变 CNN 慢 4-10 倍)。
13.2 通用性挑战
并非所有数据都有清晰的对称性(如蛋白质结构)。
13.3 理论不完善
部分等变性证明依赖简化假设。
13.4 离散化的复杂性
连续群作用在离散数据上需要谨慎处理。
14. 未来展望
14.1 趋势 1:物理启发的架构
未来的神经网络可能直接融合物理对称性:
- 相对论等变性(洛伦兹群)
- 规范对称性
- 离散对称性(晶体群)
14.2 趋势 2:自适应几何
数据驱动的几何选择:
- 学习数据所在的几何域
- 动态调整等变约束
14.3 趋势 3:组合对称性
多种对称性的组合:
- 平移 + 旋转
- 置换 + 平移
14.4 趋势 4:应用扩展
GDL 在新领域的应用:
- 蛋白质设计(SE(3) 等变)
- 物理模拟(E(3) 等变)
- 量子化学(酉群等变)
15. 总结
15.1 GDL 的核心贡献
- 统一视角:将 CNN、GNN、Transformer 等纳入同一框架
- 设计原则:从对称性出发设计架构
- 形式化:严格的数学框架(群论、表示论)
- 实用价值:指导新架构的设计
15.2 五大几何域速查
| 域 | 对称性 | 架构 | 数据 |
|---|---|---|---|
| 网格 | 平移 | CNN, ViT | 图像、视频 |
| 群 | 群乘法 | Group CNN | 旋转数据 |
| 图 | 置换 | GNN | 社交网络 |
| 齐次 | 群/子群 | Spherical CNN | 球面信号 |
| 规范 | 局部规范 | Gauge NN | 流形数据 |
15.3 关键洞察
- 对称性是归纳偏置的核心
- 等变性比不变性更强
- 几何先验提升泛化
- 群论统一架构设计
参考
Footnotes
-
Bronstein, Bruna, Cohen, Veličković, “Geometric Deep Learning: Grids, Groups, Graphs, Geodesics, and Gauges”, arXiv 2104.13478, 2021 ↩