1. 引言
在现实世界中,不同机构往往持有同一批实体的不同特征,而非不同的样本。例如:
- 银行持有用户的信用特征(收入、负债)
- 电商持有用户的消费特征(购买记录、浏览历史)
- 医院持有用户的健康特征(体检数据、病史)
垂直联邦学习(Vertical Federated Learning, VFL)正是解决这种“特征分散、样本对齐”场景的技术。
2. 水平联邦 vs 垂直联邦
| 维度 | 水平联邦学习 (HFL) | 垂直联邦学习 (VFL) |
|---|---|---|
| 数据划分 | 样本划分 | 特征划分 |
| 样本重叠 | 很少或没有重叠 | 高度重叠 |
| 标签分布 | 各客户端持有各自样本的标签 | 标签通常只由一方持有 |
| 典型场景 | 跨银行建模 | 跨行业联合建模 |
| 核心挑战 | 数据异构(Non-IID) | 样本(实体)对齐与隐私保护 |
水平联邦学习(各方持有不同样本,特征空间相同):
  银行A   银行B   银行C
┌───────┬───────┬───────┐
│ 样本1 │ 样本2 │ 样本3 │
│ 样本4 │ 样本5 │ 样本6 │
└───────┴───────┴───────┘
垂直联邦学习(各方持有相同样本,特征空间不同):
   银行    电商
┌───────┬───────┐
│ 特征1 │ 特征2 │
│  ...  │  ...  │
└───────┴───────┘
3. 垂直联邦学习的核心问题
3.1 样本对齐
不同机构需要找到共同的样本ID进行对齐,但不能直接共享ID:
def privacy_preserving_entity_alignment(parties, k=100):
    """
    Privacy-preserving entity alignment via private set intersection (PSI).

    Each party hashes its own ID list locally; the sorted hash lists are
    then handed to a PSI protocol, so raw ID lists are never exchanged.

    Args:
        parties: Participants; each must expose an ``ids`` attribute.
        k: Forwarded unchanged to ``psi_protocol``; its meaning is defined
            by that helper (not visible here) -- TODO confirm.

    Returns:
        Whatever ``psi_protocol`` returns (presumably the aligned IDs).
    """
    # Hash each party's IDs locally, then sort the digests before PSI.
    sorted_digests = [sorted(hash_ids(party.ids)) for party in parties]
    return psi_protocol(sorted_digests, k)

### 3.2 标签分布
在VFL中,标签通常只存在于一方:
| 场景 | 标签持有方 | 挑战 |
|---|---|---|
| 信用评分 | 银行 | 标签不离开银行 |
| 广告点击 | 广告平台 | 特征来自多方 |
| 疾病预测 | 医院 | 需要安全的多方计算 |
3.3 隐私保护要求
- ID对齐时:除双方的交集外,不泄露各方拥有的用户列表
- 模型训练时:不泄露各方的特征和标签
- 模型推理时:不泄露中间结果
4. VFL的模型架构
4.1 基本架构
class VerticalFederatedLearning:
    """
    Basic vertical federated learning (VFL) architecture.

    Every party computes an embedding from its own feature columns; the
    embeddings are combined by ``secure_aggregate`` and, on the label
    holder, fed to the prediction head.
    """

    def __init__(self, parties, label_holder, *, is_label_holder=True):
        """
        Args:
            parties: Participating parties. Each must expose ``id`` and
                ``feature_dim`` (used here) and ``compute_embedding``
                (used by ``forward``).
            label_holder: The party that owns the labels.
            is_label_holder: Whether this object runs on the label
                holder's side; ``forward`` only produces logits when this
                is true. Defaults to True (this object acts as the label
                holder / coordinator). ``forward`` previously read this
                attribute without it ever being assigned.
        """
        self.parties = parties
        self.label_holder = label_holder
        self.is_label_holder = is_label_holder
        # One feature encoder per party, sized to that party's features.
        self.encoders = {
            party.id: self.create_encoder(party.feature_dim)
            for party in parties
        }
        # Heads owned by the label holder.
        # NOTE(review): ``self.encoders`` and ``self.aggregator`` are built
        # here but ``forward`` does not use them -- confirm intended wiring.
        self.aggregator = self.create_aggregator()
        self.predictor = self.create_predictor()

    def forward(self, batch):
        """
        Run one forward pass.

        Args:
            batch: Mapping from party id to that party's feature batch.

        Returns:
            Logits from ``self.predictor`` when this object is the label
            holder, otherwise ``None``.
        """
        # Step 1: each party computes its embedding locally.
        embeddings = {
            party.id: party.compute_embedding(batch[party.id])
            for party in self.parties
        }
        # Step 2: combine the embeddings.
        aggregated = self.secure_aggregate(embeddings)
        # Step 3: only the label holder owns the prediction head.
        if not self.is_label_holder:
            return None
        return self.predictor(aggregated)

    def secure_aggregate(self, embeddings):
        """
        Aggregate the per-party embeddings.

        Delegates to ``homomorphic_sum`` (defined outside this class); any
        secrecy guarantee comes entirely from that helper.
        """
        return homomorphic_sum(embeddings)

### 4.2 联邦嵌入学习
class FederatedEmbeddingLearning:
    """
    Federated embedding learning: parties map their local features into a
    shared embedding space that the label holder consumes.

    The following attributes are read by the methods below but are not
    created in ``__init__``; a subclass or the caller must provide them:
        local_encoder:  mapping party id -> local feature encoder
        projection:     mapping party id -> projection into the shared space
        parties:        iterable of party ids
        predictor:      prediction head (a torch module)
        loss_fn:        callable(logits, labels) -> scalar loss
        send_gradient_to_party: callable(party_id, grad)
    """

    def __init__(self, embedding_dim):
        """
        Args:
            embedding_dim: Dimension of the shared embedding space (stored
                only; not enforced by the methods here).
        """
        self.embedding_dim = embedding_dim

    def party_forward(self, party_id, features):
        """Encode ``features`` locally, then project them into the shared space."""
        encoding = self.local_encoder[party_id](features)
        return self.projection[party_id](encoding)

    def aggregate_embeddings(self, embeddings, weights=None):
        """
        Average the parties' embeddings.

        The original called ``sum()`` on the mapping itself, which adds the
        party ids (its keys) rather than the embeddings.

        Args:
            embeddings: Mapping party id -> embedding, or a sequence of
                embeddings.
            weights: Optional weights. When both ``embeddings`` and
                ``weights`` are mappings they are matched by key; otherwise
                ``weights`` is a sequence aligned with the embeddings.
                Omit for an unweighted mean.

        Returns:
            The (weighted) mean embedding.

        Raises:
            ValueError: if there are no embeddings, the weight count does
                not match, or the weights sum to zero.
        """
        from collections.abc import Mapping

        if isinstance(embeddings, Mapping):
            keys = list(embeddings)
            values = [embeddings[key] for key in keys]
            if isinstance(weights, Mapping):
                weights = [weights[key] for key in keys]
        else:
            values = list(embeddings)
        if not values:
            raise ValueError("aggregate_embeddings() needs at least one embedding")
        if weights is None:
            return sum(values) / len(values)
        weights = list(weights)
        if len(weights) != len(values):
            raise ValueError("weights must match embeddings one-to-one")
        total = sum(weights)
        if total == 0:
            raise ValueError("weights must not sum to zero")
        return sum(w * v for w, v in zip(weights, values)) / total

    def training_step(self, batch):
        """
        One training step, simulated in a single process.

        Args:
            batch: Mapping with one feature batch per party id plus the
                labels under ``'labels'``.

        Returns:
            The scalar loss.

        Side effects:
            Accumulates the predictor's parameter gradients into ``.grad``
            and sends each party the gradient of the loss with respect to
            that party's embedding via ``send_gradient_to_party``.
        """
        # Each party computes its embedding locally.
        embeddings = {
            pid: self.party_forward(pid, batch[pid])
            for pid in self.parties
        }
        party_ids = list(embeddings)

        # Aggregate, then the label holder computes the loss.
        agg_embedding = self.aggregate_embeddings(embeddings)
        logits = self.predictor(agg_embedding)
        loss = self.loss_fn(logits, batch['labels'])

        # Differentiate w.r.t. the predictor parameters and the embedding
        # *tensors* in one call. The original asked for gradients w.r.t. the
        # projection modules and called autograd.grad twice without
        # retain_graph, which fails once the graph has been freed.
        predictor_params = [p for p in self.predictor.parameters() if p.requires_grad]
        grads = torch.autograd.grad(
            loss,
            predictor_params + [embeddings[pid] for pid in party_ids],
        )
        grad_predictor = grads[:len(predictor_params)]
        grad_embedding = dict(zip(party_ids, grads[len(predictor_params):]))

        # Label holder: accumulate the head's gradients so an optimizer can step.
        for param, grad in zip(predictor_params, grad_predictor):
            param.grad = grad if param.grad is None else param.grad + grad

        # Send every party the gradient w.r.t. its own embedding.
        for pid in party_ids:
            self.send_gradient_to_party(pid, grad_embedding[pid])
        return loss

## 5. 隐私保护技术
5.1 安全多方计算
class SecureVFL:
    """
    VFL building blocks based on secure multi-party computation (MPC).

    NOTE: both "parties" run inside one process here; the methods are
    illustrative and are not a secure implementation.
    """

    def __init__(self, parties):
        """
        Args:
            parties: The participating parties (stored only).
        """
        self.parties = parties

    def secure_dot_product(self, vec_a, vec_b, party_a, party_b):
        """
        Compute ``dot(vec_a, vec_b)`` with additive masking.

        Party A masks its vector with a random ``r`` and sends
        ``vec_a + r`` to B; B computes ``(vec_a + r) . vec_b``; subtracting
        ``r . vec_b`` leaves ``vec_a . vec_b``.

        NOTE(review): this is additive masking, not oblivious transfer (as
        the original comment said). The correction term ``r . vec_b``
        needs A's mask and B's vector in one place, which only works here
        because both parties share one process; a real protocol needs e.g.
        Beaver triples or homomorphic encryption for that term. A Gaussian
        mask also does not perfectly hide ``vec_a``.

        Args:
            vec_a: Party A's 1-D tensor.
            vec_b: Party B's 1-D tensor of the same length.
            party_a, party_b: Not used by this simulation.

        Returns:
            A 0-dim tensor equal (up to float rounding) to ``vec_a . vec_b``.
        """
        # Party A: private random mask.
        r = torch.randn_like(vec_a)
        # Party A -> B: masked vector.
        masked_a = vec_a + r
        send_to_party_b(masked_a)
        # Party B: (a + r) . b = a . b + r . b
        partial_result = torch.dot(masked_a, vec_b)
        # Correction term (see NOTE(review) in the docstring).
        r_dot_b = torch.dot(r, vec_b)
        send_to_party_a(r_dot_b)
        # Party A: remove the mask's contribution.
        return partial_result - r_dot_b

    def secure_aggregation(self, embeddings):
        """
        Sum the parties' embeddings with a pairwise (binary-tree) reduction.

        Each level adds neighbouring pairs, so ``n`` inputs need
        ``ceil(log2(n))`` levels, the shape a tree-structured secure
        aggregation protocol follows. The additions here are plaintext.
        (The original returned the undefined name ``aggregated``.)

        Args:
            embeddings: Mapping party id -> embedding (values summed in
                iteration order), or a sequence of embeddings.

        Returns:
            The sum of all embeddings.

        Raises:
            ValueError: if there are no embeddings.
        """
        from collections.abc import Mapping

        level = list(embeddings.values() if isinstance(embeddings, Mapping) else embeddings)
        if not level:
            raise ValueError("secure_aggregation() needs at least one embedding")
        while len(level) > 1:
            paired = [level[i] + level[i + 1] for i in range(0, len(level) - 1, 2)]
            if len(level) % 2:
                # The odd one out moves up a level unchanged.
                paired.append(level[-1])
            level = paired
        aggregated = level[0]
        return aggregated

### 5.2 同态加密
class HomomorphicVFL:
    """
    VFL aggregation based on additively homomorphic encryption.

    Parties encrypt their embeddings, the ciphertexts are added in
    ciphertext space, and only the secret-key holder decrypts the sum.
    """

    def __init__(self, encryption_key):
        """
        Args:
            encryption_key: Public key passed to ``paillier.encrypt``.
        """
        self.pk = encryption_key
        self.encrypted = {}  # party id -> ciphertext, in insertion order

    def encrypt_embedding(self, party_id, embedding):
        """
        Encrypt ``embedding`` with the public key and store it under ``party_id``.

        NOTE(review): ``paillier`` is not defined in this class. The
        python-paillier (``phe``) API is ``public_key.encrypt(value)`` on
        scalars -- confirm which library this call targets.

        Returns:
            The ciphertext.
        """
        encrypted = paillier.encrypt(embedding, self.pk)
        self.encrypted[party_id] = encrypted
        return encrypted

    def aggregate_encrypted(self):
        """
        Sum the stored ciphertexts (homomorphic addition).

        Uses ``self.parties`` for the party order if it has been set on the
        instance; otherwise sums every ciphertext stored via
        ``encrypt_embedding`` in insertion order. (The original always read
        ``self.parties``, which this class never assigns.)

        Returns:
            The ciphertext of the sum.

        Raises:
            ValueError: if there is nothing to aggregate.
        """
        party_ids = list(getattr(self, "parties", None) or self.encrypted)
        if not party_ids:
            raise ValueError("no encrypted embeddings to aggregate")
        result = self.encrypted[party_ids[0]]
        for pid in party_ids[1:]:
            # Homomorphic addition of ciphertexts.
            result = result + self.encrypted[pid]
        return result

    def decrypt_result(self, encrypted_result, secret_key):
        """Decrypt the aggregated ciphertext with ``secret_key``."""
        return paillier.decrypt(encrypted_result, secret_key)

### 5.3 联邦学习中的DNN
class FederatedDeepNeuralNetwork:
    """
    Vertical federated deep neural network (split-learning style).

    Each party runs a small local MLP on its own features. The outputs are
    concatenated in a fixed party order, then passed through an
    aggregation layer and a prediction head.
    """

    def __init__(self, party_configs, label_dim, aggregator_dim=256):
        """
        Args:
            party_configs: Mapping party id -> dict with ``'input_dim'``,
                ``'hidden_dim'`` and ``'output_dim'``. Its iteration order
                fixes the order in which embeddings are concatenated.
                Keys must be strings (``nn.ModuleDict`` requirement).
            label_dim: Output width of the prediction head.
            aggregator_dim: Width of the aggregation layer (previously
                hard-coded; 256 remains the default).
        """
        # One local two-layer MLP per party.
        self.local_networks = nn.ModuleDict({
            pid: nn.Sequential(
                nn.Linear(config['input_dim'], config['hidden_dim']),
                nn.ReLU(),
                nn.Linear(config['hidden_dim'], config['output_dim'])
            )
            for pid, config in party_configs.items()
        })
        # Aggregation layer over the concatenated party outputs.
        self.aggregator = nn.Linear(
            sum(config['output_dim'] for config in party_configs.values()),
            aggregator_dim
        )
        # Prediction head.
        # NOTE(review): no non-linearity between aggregator and predictor,
        # so the two layers compose to a single linear map -- confirm intent.
        self.predictor = nn.Linear(aggregator_dim, label_dim)

    def party_forward(self, party_id, features):
        """Run ``party_id``'s local network on its own features."""
        return self.local_networks[party_id](features)

    def secure_inference(self, party_embeddings):
        """
        Aggregate the parties' embeddings and predict.

        Embeddings are concatenated in the order of ``party_configs`` (not
        the caller's dict order), so each party's slice always lands on the
        aggregator columns it was sized for.

        Args:
            party_embeddings: Mapping party id -> embedding; must contain
                exactly the parties given to ``__init__``.

        Returns:
            Predictor outputs.

        Raises:
            ValueError: if party ids are missing or unexpected.
        """
        expected = list(self.local_networks.keys())
        if set(party_embeddings) != set(expected):
            raise ValueError(
                f"expected embeddings for parties {expected}, "
                f"got {list(party_embeddings)}"
            )
        concat = torch.cat([party_embeddings[pid] for pid in expected], dim=-1)
        aggregated = self.aggregator(concat)
        # Computed in plaintext here; a deployment needs secure computation.
        return self.predictor(aggregated)

## 6. 联邦学习算法
6.1 联邦逻辑回归
class FederatedLogisticRegression:
    """
    Vertical federated logistic regression (single-process simulation).

    The ``n_features`` columns are split into contiguous chunks across
    ``n_parties`` parties with ids ``0 .. n_parties - 1``; party ``pid``
    owns a weight slice of length ``feature_splits[pid]``. The model is
    ``sigmoid(sum_pid x_pid @ w_pid + bias)``.
    """

    def __init__(self, n_features, n_parties, *, is_label_holder=True):
        """
        Args:
            n_features: Total number of feature columns.
            n_parties: Number of parties sharing the features.
            is_label_holder: Whether ``forward`` runs on the label holder
                and therefore applies the sigmoid (default True). The
                original read this attribute without ever setting it.
        """
        self.n_parties = n_parties
        self.is_label_holder = is_label_holder
        # Party id -> number of feature columns it owns.
        self.feature_splits = self.split_features(n_features, n_parties)
        # Per-party weight slice (random init) and a shared bias.
        self.weights = {
            pid: nn.Parameter(torch.randn(split_size))
            for pid, split_size in self.feature_splits.items()
        }
        self.bias = nn.Parameter(torch.zeros(1))

    def split_features(self, n_features, n_parties):
        """
        Split ``n_features`` columns into ``n_parties`` contiguous chunks.

        Each party gets ``n_features // n_parties`` columns; the last party
        also takes the remainder.

        Returns:
            Dict party id (0-based int) -> column count.

        Raises:
            ValueError: if ``n_parties`` < 1 (previously a
                ZeroDivisionError or a silently empty split).
        """
        if n_parties < 1:
            raise ValueError("n_parties must be at least 1")
        split_size = n_features // n_parties
        splits = {}
        start = 0
        for i in range(n_parties):
            end = start + split_size if i < n_parties - 1 else n_features
            splits[i] = end - start
            start = end
        return splits

    def _feature_offset(self, party_id):
        """Return the first column index owned by ``party_id``."""
        offset = 0
        for pid, size in self.feature_splits.items():
            if pid == party_id:
                return offset
            offset += size
        raise KeyError(party_id)

    def _local_features(self, party_id, features):
        """
        Return the columns of ``features`` that belong to ``party_id``.

        ``features`` may be the party's own block (last dim equal to its
        split size) or the full feature matrix, from which the party's
        contiguous column range is sliced.
        """
        size = self.feature_splits[party_id]
        if features.shape[-1] == size:
            return features
        start = self._feature_offset(party_id)
        return features[..., start:start + size]

    def party_compute(self, party_id, features):
        """
        Compute one party's local activation ``x_pid @ w_pid``.

        The original sliced ``features[:, size:]`` (using the split *size*
        as a column offset) and used ``torch.dot``, which only accepts 1-D
        tensors; a batch needs a matrix-vector product.

        Returns:
            Tensor of shape ``features.shape[:-1]`` (one value per sample).
        """
        x = self._local_features(party_id, features)
        return x @ self.weights[party_id]

    def secure_aggregate(self, local_activations):
        """
        Sum the parties' local activations and add the bias.

        Args:
            local_activations: Mapping party id -> activation, or a
                sequence of activations. (The original summed the mapping
                itself, which adds its keys -- the party ids.)
        """
        from collections.abc import Mapping

        if isinstance(local_activations, Mapping):
            local_activations = local_activations.values()
        return sum(local_activations) + self.bias

    def sigmoid(self, z):
        """
        Sigmoid, applied by the label holder.

        ``torch.sigmoid`` gives the same values as ``1 / (1 + exp(-z))``,
        but it avoids the NaN gradient that form produces once
        ``exp(-z)`` overflows for large negative ``z``.
        """
        return torch.sigmoid(z)

    def forward(self, party_features, party_ids):
        """
        Full forward pass.

        Args:
            party_features: Mapping party id -> features (see ``party_compute``).
            party_ids: Parties taking part.

        Returns:
            Probabilities if this object is the label holder, otherwise the
            raw logits ``z``.
        """
        # Step 1: each party computes its local activation.
        local_activations = {
            pid: self.party_compute(pid, party_features[pid])
            for pid in party_ids
        }
        # Step 2: aggregate.
        z = self.secure_aggregate(local_activations)
        # Step 3: the label holder applies the sigmoid.
        if self.is_label_holder:
            return self.sigmoid(z)
        return z

    def backward(self, party_features, party_ids, labels):
        """
        Compute each party's weight gradient of the mean binary cross-entropy.

        For ``L = mean_i BCE(sigmoid(z_i), y_i)`` the gradient is
        ``dL/dw_pid = x_pid^T (p - y) / n``. The label holder therefore only
        needs to share the residual ``p - y`` for every party to form its
        own gradient. (The original relied on an undefined
        ``compute_local_gradient`` and discarded the loss it computed.)

        Args:
            party_features: Mapping party id -> features (see ``party_compute``).
            party_ids: Parties taking part in this step.
            labels: 0/1 targets, one per sample.

        Returns:
            Dict party id -> gradient tensor shaped like ``self.weights[pid]``.
        """
        local_activations = {
            pid: self.party_compute(pid, party_features[pid])
            for pid in party_ids
        }
        probs = self.sigmoid(self.secure_aggregate(local_activations))
        targets = torch.as_tensor(labels, dtype=probs.dtype).reshape(probs.shape)
        # Gradient of the batch-mean BCE w.r.t. each logit, times n.
        residual = (probs - targets).detach().reshape(-1)
        n_samples = residual.numel()
        gradients = {}
        for pid in party_ids:
            x = self._local_features(pid, party_features[pid])
            x = x.reshape(-1, x.shape[-1]).to(residual.dtype)
            gradients[pid] = residual @ x / n_samples
        return gradients

### 6.2 联邦XGBoost
class FederatedXGBoost:
    """
    Vertical federated XGBoost (skeleton).

    ``compute_gradient`` and ``compute_hessian`` are hooks that this class
    does not implement; provide them in a subclass or on the instance.
    ``build_tree`` is a placeholder that returns ``None``.
    """

    def __init__(self, n_trees=100, max_depth=6):
        """
        Args:
            n_trees: Intended ensemble size (stored; ``fit`` builds one tree per call).
            max_depth: Intended maximum tree depth (stored only).
        """
        self.n_trees = n_trees
        self.max_depth = max_depth
        self.trees = []

    def fit(self, party_data, label_party, labels=None):
        """
        Run one boosting round and append the resulting tree to ``self.trees``.

        Args:
            party_data: Mapping party id -> that party's local data.
            label_party: The label-holding party (not used by this skeleton).
            labels: Training targets, forwarded to the gradient/hessian
                hooks. The original referenced an undefined ``labels``
                name and so always raised ``NameError``.

        NOTE(review): ``n_trees`` is not used -- one tree is built per call
        and there is no prediction update between rounds.
        """
        # Each party computes statistics.
        histograms = self.compute_local_histograms(party_data, labels)
        # Aggregate the statistics.
        aggregated = self.secure_aggregate(histograms)
        # Build a tree from the aggregate.
        tree = self.build_tree(aggregated)
        self.trees.append(tree)

    def compute_local_histograms(self, party_data, labels=None):
        """
        Compute per-party first/second-order gradient statistics.

        Args:
            party_data: Mapping party id -> local data (must support ``len``).
            labels: Targets passed to ``compute_gradient``/``compute_hessian``.

        Returns:
            Dict party id -> ``{'g_sum', 'h_sum', 'counts'}``.

        NOTE(review): in SecureBoost-style VFL the label holder computes
        g/h once and passive parties bucket them per feature bin; the
        whole-dataset sums per party here are only a sketch.
        """
        histograms = {}
        for party_id, data in party_data.items():
            g = self.compute_gradient(data, labels)  # first order
            h = self.compute_hessian(data, labels)   # second order
            histograms[party_id] = {
                'g_sum': g.sum(),
                'h_sum': h.sum(),
                'counts': len(data)
            }
        return histograms

    def secure_aggregate(self, histograms):
        """
        Sum the per-party gradient statistics.

        Done in plaintext here; a deployment would use homomorphic
        encryption or secret sharing.
        """
        total_g = sum(h['g_sum'] for h in histograms.values())
        total_h = sum(h['h_sum'] for h in histograms.values())
        return {'total_g': total_g, 'total_h': total_h}

    def build_tree(self, aggregated_stats):
        """Placeholder for tree construction from aggregated statistics; returns ``None``."""
        return None

## 7. 实际应用
7.1 金融风控
class CreditScoringVFL:
    """
    Credit scoring with vertical federated learning.

    Three parties contribute different features for the same customers:
    the bank (income, debt, credit history), an e-commerce platform
    (purchases, browsing history) and a telecom operator (call records,
    social graph). The bank holds the credit labels.
    """

    def __init__(self):
        # Feature holders, created in order: bank, e-commerce, telecom.
        self.bank = BankFeatures()
        self.ecommerce = EcommerceFeatures()
        self.telecom = TelecomFeatures()
        # The bank owns the credit labels.
        self.label_holder = self.bank
        participants = [self.bank, self.ecommerce, self.telecom]
        self.model = VerticalFederatedLearning(
            parties=participants,
            label_holder=self.label_holder,
        )

    def train(self, n_rounds):
        """
        Run ``n_rounds`` federated training rounds, printing each round's loss.

        Each round performs privacy-preserving sample alignment, local
        computation on every party, secure aggregation, the label holder's
        loss and backward pass, and gradient distribution. The helper
        methods called here are not defined in this class.
        """
        for round_id in range(n_rounds):
            batch = self.privacy_preserving_alignment()
            prediction = self.secure_aggregate(self.compute_local_outputs(batch))
            loss = self.compute_loss(prediction, batch.labels)
            self.distribute_gradients(self.backward(loss))
            print(f"Round {round_id}: Loss = {loss.item():.4f}")

### 7.2 医疗诊断
class MedicalDiagnosisVFL:
    """
    Medical diagnosis with vertical federated learning.

    Three hospitals hold different feature groups for the same patients:
    imaging (A), lab results (B) and medical history (C). Hospital A holds
    the labels.
    """

    def __init__(self):
        # Hospital A: imaging features.
        self.hospital_a = ImageFeatures()
        # Hospital B: lab results.
        self.hospital_b = LabResults()
        # Hospital C: medical history.
        self.hospital_c = MedicalHistory()
        # Label holder.
        self.label_holder = self.hospital_a
        # The original passed a literal ``...`` (Ellipsis) as ``parties`` and
        # omitted the required ``label_holder`` argument; wire the parties
        # the same way as the credit-scoring example.
        self.model = VerticalFederatedLearning(
            parties=[self.hospital_a, self.hospital_b, self.hospital_c],
            label_holder=self.label_holder,
        )

    def predict(self, patient_data):
        """
        Federated inference for one patient.

        Args:
            patient_data: Mapping party id -> that party's features.

        Returns:
            Whatever ``self.model.predict`` returns.

        NOTE(review): ``encrypt`` and ``secure_aggregate`` are not defined in
        this class, and ``VerticalFederatedLearning`` as shown defines
        ``forward`` but no ``predict`` method. The "decrypt and predict"
        step also performs no decryption here -- confirm the intended flow.
        """
        # Each party encrypts its own features.
        encrypted_features = {
            party_id: encrypt(feature)
            for party_id, feature in patient_data.items()
        }
        # Aggregate in ciphertext space.
        encrypted_agg = self.secure_aggregate(encrypted_features)
        prediction = self.model.predict(encrypted_agg)
        return prediction

## 8. 参考文献
9. 相关主题
- federated-learning-fundamentals — 联邦学习基础
- federated-learning-llm-finetuning — 联邦LLM微调
- federated-learning-privacy-dp — 差分隐私保护