Overview
Domain adaptation is an important branch of transfer learning. It aims to use labeled data from a source domain to improve a model's performance on a target domain.[1]
Problem Definition
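A domain $\mathcal{D} = \{\mathcal{X}, P(X)\}$ consists of two components:
- a feature space $\mathcal{X}$
- a marginal probability distribution $P(X)$, where $X = \{x_1, \dots, x_n\} \subset \mathcal{X}$
A task $\mathcal{T} = \{\mathcal{Y}, f(\cdot)\}$ consists of a label space $\mathcal{Y}$ and a predictive function $f(\cdot)$.
In domain adaptation:
- Source domain: $\mathcal{D}_S$, with labeled samples $\{(x_i^S, y_i^S)\}_{i=1}^{n_S}$
- Target domain: $\mathcal{D}_T$, with samples $\{x_j^T\}_{j=1}^{n_T}$ that have few or no labels
Goal: learn a predictor $h: \mathcal{X} \to \mathcal{Y}$ that performs well on $\mathcal{D}_T$.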
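Types of Domain Adaptation
| Type | Source labels | Target labels | Description |
|---|---|---|---|
| Supervised domain adaptation | Yes | Yes (few) | A small amount of labeled target data is available |
| Semi-supervised domain adaptation | Yes | Partial | A few labeled target samples plus unlabeled target data |
| Unsupervised domain adaptation | Yes | No | Only unlabeled target data is available |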
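Theoretical Framework of Domain Adaptation
Generalization Error Decomposition
Let $h$ be the classifier learned from the source domain and $h^*$ the optimal target classifier. The target error is
$$\epsilon_T(h) = \Pr_{(x, y) \sim \mathcal{D}_T}\left[h(x) \neq y\right]$$
and it can be decomposed as
$$\epsilon_T(h) = \epsilon_T(h^*) + \left[\epsilon_T(h) - \epsilon_T(h^*)\right]$$
that is, the error of the optimal target classifier plus the excess error incurred by learning $h$ on source data.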
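Sources of Domain Shift
Marginal distribution shift (covariate shift): $P_S(X) \neq P_T(X)$, while $P_S(Y \mid X) = P_T(Y \mid X)$
Conditional distribution shift (conditional shift): $P_S(Y \mid X) \neq P_T(Y \mid X)$
Joint distribution shift: $P_S(X, Y) \neq P_T(X, Y)$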
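Here is a minimal synthetic sketch of covariate shift. It assumes Gaussian inputs and a hypothetical labeling rule `labeling_rule`, which plays the role of the shared conditional $P(Y \mid X)$. Only the input marginal moves between the domains. Even so, the label proportions change as a side effect.

```python
import numpy as np

rng = np.random.default_rng(0)


def labeling_rule(x):
    # Hypothetical shared conditional P(Y | X): the same deterministic rule in both domains
    return (x[:, 0] + x[:, 1] > 0).astype(int)


# Covariate shift: only the marginal P(X) differs between the domains
X_source = rng.normal(loc=-1.0, scale=1.0, size=(1000, 2))
X_target = rng.normal(loc=1.0, scale=1.0, size=(1000, 2))
y_source = labeling_rule(X_source)
y_target = labeling_rule(X_target)

print(X_source.mean(axis=0), X_target.mean(axis=0))  # roughly [-1, -1] vs [1, 1]
print(y_source.mean(), y_target.mean())              # roughly 0.08 vs 0.92
```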
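Theoretical Bound
The domain adaptation bound based on the $\mathcal{H}\Delta\mathcal{H}$-divergence (Ben-David et al., 2010) states that for every $h \in \mathcal{H}$:
$$\epsilon_T(h) \le \epsilon_S(h) + \frac{1}{2} d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T) + \lambda$$
Here $d_{\mathcal{H}\Delta\mathcal{H}}$ is the $\mathcal{H}\Delta\mathcal{H}$-divergence, which measures the discrepancy between the two domains with respect to the hypothesis space $\mathcal{H}$. The term $\lambda = \min_{h' \in \mathcal{H}} \left[\epsilon_S(h') + \epsilon_T(h')\right]$ is the combined error of the ideal joint hypothesis.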
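Distribution Discrepancy Measures
$\mathcal{H}$-divergence
The $\mathcal{H}$-divergence is defined as:
$$d_{\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T) = 2 \sup_{h \in \mathcal{H}} \left| \Pr_{x \sim \mathcal{D}_S}[h(x) = 1] - \Pr_{x \sim \mathcal{D}_T}[h(x) = 1] \right|$$
In practice, it is estimated from finite unlabeled samples $\mathcal{U}_S$ and $\mathcal{U}_T$ of equal size $m$, assuming a symmetric hypothesis class $\mathcal{H}$:
$$\hat{d}_{\mathcal{H}}(\mathcal{U}_S, \mathcal{U}_T) = 2\left(1 - \min_{h \in \mathcal{H}} \left[\frac{1}{m}\sum_{x:\, h(x)=0} \mathbb{1}[x \in \mathcal{U}_S] + \frac{1}{m}\sum_{x:\, h(x)=1} \mathbb{1}[x \in \mathcal{U}_T]\right]\right)$$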
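In practice, the minimization over $\mathcal{H}$ is usually replaced by training a single domain classifier and measuring its held-out error $\epsilon$. This gives the so-called proxy A-distance, $\hat{d}_A = 2(1 - 2\epsilon)$. The sketch below assumes logistic regression as the domain classifier and a 50/50 train/test split. `proxy_a_distance` is an illustrative name, not a library function.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def proxy_a_distance(X_source, X_target, seed=0):
    """Proxy A-distance 2 * (1 - 2 * err) from a domain classifier's held-out error (sketch)."""
    X = np.vstack([X_source, X_target])
    d = np.hstack([np.zeros(len(X_source)), np.ones(len(X_target))])  # 0 = source, 1 = target
    X_tr, X_te, d_tr, d_te = train_test_split(
        X, d, test_size=0.5, random_state=seed, stratify=d)
    clf = LogisticRegression().fit(X_tr, d_tr)
    err = 1.0 - clf.score(X_te, d_te)  # held-out domain classification error
    return 2.0 * (1.0 - 2.0 * err)


rng = np.random.default_rng(0)
Xs = rng.normal(0.0, 1.0, size=(500, 2))
Xt_same = rng.normal(0.0, 1.0, size=(500, 2))
Xt_shift = rng.normal(3.0, 1.0, size=(500, 2))
pad_same = proxy_a_distance(Xs, Xt_same)    # close to 0: domains indistinguishable
pad_shift = proxy_a_distance(Xs, Xt_shift)  # close to 2: domains easily separated
print(pad_same, pad_shift)
```

Values near 0 mean the classifier cannot tell the domains apart. Values near 2 mean the domains are almost perfectly separable.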
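Adversarial Divergence
Classifier-based divergence estimate:
```python
import torch
import torch.nn as nn


def compute_adversarial_divergence(source_features, target_features, discriminator):
    """
    Compute the adversarial divergence.
    A domain discriminator is used to estimate the discrepancy between domains.
    Assumes `discriminator` maps features to the probability (sigmoid output)
    that a sample comes from the target domain (source = 0, target = 1).
    """
    eps = 1e-8
    # Discriminator predictions, flattened to shape (n,)
    source_pred = discriminator(source_features).reshape(-1)
    target_pred = discriminator(target_features).reshape(-1)
    # Discriminator log-likelihood: higher means the domains are easier to tell apart
    divergence = torch.mean(torch.log(1 - source_pred + eps)) + \
        torch.mean(torch.log(target_pred + eps))
    return -divergence  # negated so that the discriminator can minimize it
```
Maximum Mean Discrepancy (MMD)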
MMD measures the discrepancy between two distributions by comparing their kernel mean embeddings:
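Written out for a kernel $k$ with mean embedding $\mu_P = \mathbb{E}_{x \sim P}[k(x, \cdot)]$, the squared MMD expands into the three expectations that the code below estimates with batch means. This is a biased estimate, because the pairs $(x_i, x_i)$ are included:

$$\mathrm{MMD}^2(P_S, P_T) = \left\| \mu_{P_S} - \mu_{P_T} \right\|_{\mathcal{H}_k}^2 = \mathbb{E}_{x, x' \sim P_S}[k(x, x')] + \mathbb{E}_{y, y' \sim P_T}[k(y, y')] - 2\,\mathbb{E}_{x \sim P_S,\, y \sim P_T}[k(x, y)]$$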
```python
import torch


def mmd_linear(source, target, kernel='rbf', sigma=1.0):
    """
    Maximum Mean Discrepancy (MMD): biased estimate of MMD^2.
    Note: despite the name, the default kernel is RBF; pass kernel='linear'
    for a linear kernel.
    """
    def rbf_kernel(x, y):
        # Gaussian kernel with bandwidth sigma (captured from the enclosing scope)
        diff = x.unsqueeze(1) - y.unsqueeze(0)
        return torch.exp(-torch.sum(diff ** 2, dim=-1) / (2 * sigma ** 2))

    def linear_kernel(x, y):
        return torch.mm(x, y.T)

    k = rbf_kernel if kernel == 'rbf' else linear_kernel
    # Compute each term
    xx = k(source, source).mean()  # E[k(x_s, x_s')]
    yy = k(target, target).mean()  # E[k(x_t, x_t')]
    xy = k(source, target).mean()  # E[k(x_s, x_t)]
    return xx + yy - 2 * xy
```
Wasserstein Distance
Measures the distribution discrepancy using optimal transport theory:
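For one-dimensional samples, $W_1$ reduces to the area between the two cumulative distribution functions $F_P$ and $F_Q$. The multi-dimensional function below does not compute the exact $W_1$. Instead, it approximates the sliced Wasserstein distance: 1-D distances are averaged over $L$ random unit directions $\theta_l$, where $P_\theta$ denotes the distribution of the projection $\theta^\top x$ for $x \sim P$:

$$W_1(P, Q) = \int_{-\infty}^{\infty} \left| F_P(t) - F_Q(t) \right| dt$$

$$SW_1(P, Q) = \mathbb{E}_{\theta \sim \mathrm{Unif}(\mathbb{S}^{d-1})}\left[ W_1(P_\theta, Q_\theta) \right] \approx \frac{1}{L} \sum_{l=1}^{L} W_1(P_{\theta_l}, Q_{\theta_l})$$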
```python
import numpy as np
from scipy.stats import wasserstein_distance


def compute_wasserstein(source, target):
    """
    Compute the 1-D Wasserstein distance.
    """
    return wasserstein_distance(source, target)


def compute_wasserstein_nd(source, target, n_projections=100, seed=None):
    """
    Approximate a multi-dimensional Wasserstein distance
    (sliced Wasserstein: average of 1-D distances over random projections).
    """
    rng = np.random.default_rng(seed)
    n_dims = source.shape[1]
    distances = []
    for _ in range(n_projections):
        # Random unit projection direction
        direction = rng.standard_normal(n_dims)
        direction = direction / np.linalg.norm(direction)
        # Project both samples onto the direction
        source_proj = source @ direction
        target_proj = target @ direction
        # 1-D Wasserstein distance between the projections
        distances.append(wasserstein_distance(source_proj, target_proj))
    return np.mean(distances)
```
Main Methods
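Distribution Matching Methods
Kernel Methods
These methods learn a feature transformation under which the source and target distributions match.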
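One concrete, runnable kernel method of this kind is Transfer Component Analysis (TCA, Pan et al., 2011). It is swapped in here as an illustration and is not the `KernelFDA` class below. The sketch assumes an RBF kernel and solves TCA's generalized eigenproblem with SciPy. The goal is to find kernel projections that keep the data variance while bringing the source and target means (the MMD term) close together. `tca` and `standardized_mean_gap` are illustrative names.

```python
import numpy as np
from scipy.linalg import eigh
from sklearn.metrics.pairwise import rbf_kernel


def tca(X_source, X_target, n_components=2, mu=1.0, gamma=1.0):
    """Transfer Component Analysis (sketch): kernel features whose domain means are close."""
    X = np.vstack([X_source, X_target])
    n_s, n_t = len(X_source), len(X_target)
    n = n_s + n_t
    # MMD coefficient matrix L = e e^T, with e_i = 1/n_s (source) or -1/n_t (target)
    e = np.concatenate([np.full(n_s, 1.0 / n_s), np.full(n_t, -1.0 / n_t)])[:, None]
    L = e @ e.T
    H = np.eye(n) - np.full((n, n), 1.0 / n)  # centering matrix
    K = rbf_kernel(X, X, gamma=gamma)
    # Generalized eigenproblem (K H K) w = lam (K L K + mu I) w, eigenvalues ascending
    _, eigvecs = eigh(K @ H @ K, K @ L @ K + mu * np.eye(n))
    W = eigvecs[:, ::-1][:, :n_components]  # leading generalized eigenvectors
    Z = K @ W
    return Z[:n_s], Z[n_s:]


def standardized_mean_gap(A, B):
    # Norm of the per-feature difference in domain means, in units of pooled std
    pooled_std = np.vstack([A, B]).std(axis=0)
    return np.linalg.norm((A.mean(axis=0) - B.mean(axis=0)) / pooled_std)


rng = np.random.default_rng(0)
Xs = rng.normal(0.0, 1.0, size=(100, 2))
Xt = rng.normal([2.0, 0.0], 1.0, size=(100, 2))
Zs, Zt = tca(Xs, Xt, mu=0.1)
print(standardized_mean_gap(Xs, Xt))  # large gap in the input space
print(standardized_mean_gap(Zs, Zt))  # smaller gap in the TCA space
```

`mu` weights a penalty on the size of the projection against the MMD term. Because the constraint fixes the variance of the projected data, a smaller `mu` leaves the mean gap as the main quantity being minimized.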
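```python
import numpy as np


class KernelFDA:
    """
    Kernel-based domain adaptation (Kernel Fisher Discriminant Analysis).
    Skeleton only: the core of the method is elided below.
    """
    def __init__(self, kernel='rbf', gamma=1.0):
        self.kernel = kernel
        self.gamma = gamma

    def fit_transform(self, source, target):
        """
        Learn a domain-invariant representation.
        """
        # Merge the data
        X = np.vstack([source, target])
        n_source = len(source)
        n_total = len(X)
        # Build domain labels for a domain classifier (source = 0, target = 1)
        domain_labels = np.zeros(n_total)
        domain_labels[n_source:] = 1
        # Train the domain classifier and extract domain-invariant features
        # ... (simplified implementation)
        return self.transform(X)

    def transform(self, X):
        """
        Feature transformation.
        """
        # Return domain-invariant features; `kernel_transform` is not defined in this skeleton
        return self.kernel_transform(X)
```
Sample Reweighting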
Correct the distribution shift by reweighting the source samples:
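The reweighting rests on an importance-sampling identity. Under covariate shift ($P_S(Y \mid X) = P_T(Y \mid X)$), and assuming $P_S(x) > 0$ wherever $P_T(x) > 0$, the target risk of a loss $\ell$ equals a weighted source risk. If a domain classifier $p(d \mid x)$, with $d \in \{S, T\}$, is trained on $n_S$ source and $n_T$ target samples, Bayes' rule turns the density ratio into a ratio of its predicted probabilities, corrected for the class prior:

$$\mathbb{E}_{(x, y) \sim P_T}\left[\ell(h(x), y)\right] = \mathbb{E}_{(x, y) \sim P_S}\left[w(x)\, \ell(h(x), y)\right], \qquad w(x) = \frac{P_T(x)}{P_S(x)}$$

$$w(x) = \frac{P_T(x)}{P_S(x)} = \frac{p(d = T \mid x)}{p(d = S \mid x)} \cdot \frac{n_S}{n_T}$$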
```python
import numpy as np
from sklearn.linear_model import LogisticRegression


class CovariateShiftAdaptation:
    """
    Covariate shift adaptation.
    """
    def __init__(self):
        self.ratio_estimator = None

    def estimate_density_ratio(self, X_source, X_target):
        """
        Estimate the density ratio w(x) = P_t(x) / P_s(x) for each source sample,
        using a domain classifier.
        """
        # Build a binary dataset: source = 0, target = 1
        X = np.vstack([X_source, X_target])
        y = np.hstack([np.zeros(len(X_source)), np.ones(len(X_target))])
        # Train the domain classifier
        self.ratio_estimator = LogisticRegression()
        self.ratio_estimator.fit(X, y)
        # P(target | x) evaluated on the source samples
        prob_target = self.ratio_estimator.predict_proba(X_source)[:, 1]
        # Importance weights: P(target|x) / P(source|x), corrected for the sample-size ratio
        weights = prob_target / (1 - prob_target + 1e-8) * (len(X_source) / len(X_target))
        weights = np.clip(weights, 0.1, 10)  # clip extreme weights
        return weights

    def fit_weighted_classifier(self, X_source, y_source, weights):
        """
        Train a classifier on the weighted source samples.
        """
        clf = LogisticRegression()
        clf.fit(X_source, y_source, sample_weight=weights)
        return clf
```
Adversarial Domain Adaptation
DANN (Domain-Adversarial Neural Network)
Core idea: train a domain-invariant feature extractor that also fools the domain discriminator.
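Here is a standalone sketch of the gradient reversal layer (GRL) that makes this possible. In the forward pass it is the identity. In the backward pass it multiplies the gradient by $-\alpha$, so minimizing the domain loss trains the discriminator while pushing the feature extractor the opposite way. The sketch also includes the $\alpha$ schedule $2 / (1 + e^{-\gamma p}) - 1$ with $\gamma = 10$ used in the DANN paper (Ganin & Lempitsky, 2015), where $p \in [0, 1]$ is training progress. `GradReverse` and `grl_alpha` are illustrative names.

```python
import math

import torch


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; gradient multiplied by -alpha in the backward pass."""
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient is returned for alpha
        return -ctx.alpha * grad_output, None


def grl_alpha(progress, gamma=10.0):
    """Ramp alpha from 0 to about 1 as training progress goes from 0 to 1."""
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0


x = torch.ones(3, requires_grad=True)
y = GradReverse.apply(x, 0.5)
y.sum().backward()
print(y.detach())  # forward pass leaves the values unchanged
print(x.grad)      # each entry is -0.5: reversed and scaled
print(grl_alpha(0.0), grl_alpha(1.0))  # 0 at the start, close to 1 at the end
```

Starting with $\alpha$ near 0 keeps the noisy early discriminator from disrupting the feature extractor before the label classifier has learned anything useful.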
```python
import torch
import torch.nn as nn


class GradientReversalFunction(torch.autograd.Function):
    """
    Gradient reversal: identity in the forward pass,
    gradient multiplied by -alpha in the backward pass.
    """
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.alpha, None


def gradient_reverse(x, alpha):
    return GradientReversalFunction.apply(x, alpha)


class DomainAdversarialNetwork(nn.Module):
    """
    DANN: domain-adversarial neural network.
    """
    def __init__(self, input_dim, num_classes, hidden_dim=256):
        super().__init__()
        # Feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        # Label classifier
        self.label_classifier = nn.Sequential(
            nn.Linear(hidden_dim, num_classes)
        )
        # Domain discriminator
        self.domain_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 2)
        )

    def forward(self, x, alpha=1.0):
        """
        Forward pass.
        Args:
            x: input features
            alpha: gradient reversal strength
        """
        # Extract features
        features = self.feature_extractor(x)
        # Label prediction
        class_logits = self.label_classifier(features)
        # Domain prediction through the gradient reversal layer
        reversed_features = gradient_reverse(features, alpha)
        domain_logits = self.domain_classifier(reversed_features)
        return class_logits, domain_logits, features


def dann_loss(class_logits, domain_logits, y_source, d_source, alpha=1.0):
    """
    DANN loss.
    class_logits: label logits for the labeled source samples only
    domain_logits, d_source: domain logits and domain labels (0 = source, 1 = target)
    for the source and target samples
    """
    # Classification loss
    cls_loss = nn.CrossEntropyLoss()(class_logits, y_source)
    # Domain loss: the discriminator tries to separate source from target;
    # through the gradient reversal layer, the feature extractor tries to fool it
    dom_loss = nn.CrossEntropyLoss()(domain_logits, d_source)
    # Total loss
    total_loss = cls_loss + alpha * dom_loss
    return total_loss, cls_loss, dom_loss
```
ADDA (Adversarial Discriminative Domain Adaptation)
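ADDA uses asymmetric feature mappings. A separate target encoder is trained adversarially while the pre-trained source encoder stays fixed.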
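Following the ADDA formulation (Tzeng et al., 2017), the notation is:

- $M_S$: the fixed, pre-trained source encoder
- $M_T$: the target encoder, initialized from $M_S$
- $D$: a discriminator that outputs the probability that a feature vector comes from the source domain

Training alternates between two objectives. The target encoder uses the GAN-style inverted-label loss rather than maximizing the discriminator loss directly:

$$\min_D \; \mathcal{L}_D = -\mathbb{E}_{x_s \sim X_S}\left[\log D(M_S(x_s))\right] - \mathbb{E}_{x_t \sim X_T}\left[\log\left(1 - D(M_T(x_t))\right)\right]$$

$$\min_{M_T} \; \mathcal{L}_{M} = -\mathbb{E}_{x_t \sim X_T}\left[\log D(M_T(x_t))\right]$$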
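```python
import torch
import torch.nn.functional as F


class ADDA:
    """
    ADDA: Adversarial Discriminative Domain Adaptation.
    Assumes the encoders, classifier and discriminator are nn.Modules, the two
    encoders share an architecture, and the discriminator outputs one logit per
    sample (source = 1, target = 0).
    """
    def __init__(self, source_encoder, target_encoder, discriminator, classifier, lr=1e-4):
        self.source_encoder = source_encoder
        self.target_encoder = target_encoder
        self.discriminator = discriminator
        self.classifier = classifier
        self.discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
        self.target_encoder_optimizer = torch.optim.Adam(target_encoder.parameters(), lr=lr)

    def train(self, X_source, y_source, X_target, n_iterations=10000):
        """
        Training procedure.
        """
        # Stage 1: pre-train the source encoder and classifier on labeled source data
        self.pretrain_source(X_source, y_source)
        # Initialize the target encoder with the source encoder's weights
        self.target_encoder.load_state_dict(self.source_encoder.state_dict())
        # Stage 2: adversarial adaptation, alternating the two updates
        for _ in range(n_iterations):
            self.update_discriminator(X_source, X_target)
            self.update_target_encoder(X_target)

    def pretrain_source(self, X_source, y_source, n_epochs=100, lr=1e-3):
        params = list(self.source_encoder.parameters()) + list(self.classifier.parameters())
        optimizer = torch.optim.Adam(params, lr=lr)
        for _ in range(n_epochs):
            loss = F.cross_entropy(self.classifier(self.source_encoder(X_source)), y_source)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    def update_discriminator(self, X_source, X_target):
        """
        Update the discriminator; both encoders are frozen for this step.
        """
        with torch.no_grad():
            source_features = self.source_encoder(X_source)
            target_features = self.target_encoder(X_target)
        source_logits = self.discriminator(source_features).reshape(-1)
        target_logits = self.discriminator(target_features).reshape(-1)
        # Discriminator loss: source -> 1, target -> 0
        d_loss = F.binary_cross_entropy_with_logits(source_logits, torch.ones_like(source_logits)) + \
            F.binary_cross_entropy_with_logits(target_logits, torch.zeros_like(target_logits))
        self.discriminator_optimizer.zero_grad()
        d_loss.backward()
        self.discriminator_optimizer.step()
        return d_loss.item()

    def update_target_encoder(self, X_target):
        """
        Update the target encoder with inverted labels so its features look like source features.
        """
        target_logits = self.discriminator(self.target_encoder(X_target)).reshape(-1)
        g_loss = F.binary_cross_entropy_with_logits(target_logits, torch.ones_like(target_logits))
        self.target_encoder_optimizer.zero_grad()
        g_loss.backward()
        self.target_encoder_optimizer.step()
        return g_loss.item()
```
Representation Learning Methods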
DeepCORAL (Correlation Alignment)
Aligns the second-order statistics (covariance matrices) of the source and target features.
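The deep loss below has a non-deep predecessor: CORAL (Sun et al., 2016). It aligns the covariances in closed form by whitening the source features with $C_S^{-1/2}$ and re-coloring them with $C_T^{1/2}$, where both covariances are regularized by $\varepsilon I$. The matrix-power helper and the name `coral_transform` are illustrative choices for this NumPy sketch.

```python
import numpy as np


def _sym_matrix_power(C, p):
    # Power of a symmetric positive-definite matrix via its eigendecomposition
    vals, vecs = np.linalg.eigh(C)
    return (vecs * vals ** p) @ vecs.T


def coral_transform(X_source, X_target, eps=1.0):
    """Closed-form CORAL: whiten source features, then re-color with the target covariance."""
    d = X_source.shape[1]
    C_s = np.cov(X_source, rowvar=False) + eps * np.eye(d)
    C_t = np.cov(X_target, rowvar=False) + eps * np.eye(d)
    return X_source @ _sym_matrix_power(C_s, -0.5) @ _sym_matrix_power(C_t, 0.5)


rng = np.random.default_rng(0)
mix = np.array([[1.0, 0.5, 0.0], [0.0, 1.0, 0.3], [0.0, 0.0, 1.0]])
Xs = rng.normal(size=(500, 3)) @ mix
Xt = rng.normal(size=(500, 3)) * np.array([2.0, 1.0, 0.5])
# With a tiny eps, the aligned source covariance matches the target covariance
Xs_aligned = coral_transform(Xs, Xt, eps=1e-8)
print(np.round(np.cov(Xs_aligned, rowvar=False), 3))
print(np.round(np.cov(Xt, rowvar=False), 3))
```

The default `eps=1.0` follows the regularization in the CORAL paper. With it, the match is only approximate, but the inversion stays stable when features are nearly collinear.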
```python
import torch


def coral_loss(source, target):
    """
    CORAL loss: align the covariance matrices.
    CORAL loss = (1/4d²) ||C_s - C_t||_F²
    """
    d = source.shape[1]
    # Center the features
    source = source - source.mean(dim=0, keepdim=True)
    target = target - target.mean(dim=0, keepdim=True)
    # Covariance matrices
    C_s = torch.mm(source.T, source) / (source.size(0) - 1)
    C_t = torch.mm(target.T, target) / (target.size(0) - 1)
    # Squared Frobenius norm
    loss = torch.norm(C_s - C_t, p='fro') ** 2
    loss = loss / (4 * d * d)
    return loss
```
MMD Methods
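The multi-kernel training loss below averages over full kernel blocks, so it is a biased estimate. As a point of comparison, here is a minimal NumPy sketch of the unbiased MMD² estimator. It assumes a single Gaussian kernel with a hand-picked bandwidth `sigma` and drops the diagonal $k(x_i, x_i)$ terms, so its expected value equals the population MMD². `gaussian_gram` and `mmd2_unbiased` are illustrative names.

```python
import numpy as np


def gaussian_gram(A, B, sigma):
    # Pairwise squared distances, clipped at 0 to absorb rounding error
    sq = np.sum(A ** 2, axis=1)[:, None] + np.sum(B ** 2, axis=1)[None, :] - 2 * A @ B.T
    return np.exp(-np.maximum(sq, 0.0) / (2 * sigma ** 2))


def mmd2_unbiased(X, Y, sigma=1.0):
    """Unbiased estimate of squared MMD: the diagonal k(x_i, x_i) terms are excluded."""
    m, n = len(X), len(Y)
    Kxx = gaussian_gram(X, X, sigma)
    Kyy = gaussian_gram(Y, Y, sigma)
    Kxy = gaussian_gram(X, Y, sigma)
    term_xx = (Kxx.sum() - np.trace(Kxx)) / (m * (m - 1))
    term_yy = (Kyy.sum() - np.trace(Kyy)) / (n * (n - 1))
    return term_xx + term_yy - 2 * Kxy.mean()


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
Y_same = rng.normal(size=(200, 2))
Y_shift = rng.normal(loc=1.0, size=(200, 2))
print(mmd2_unbiased(X, Y_same))   # close to 0, may be slightly negative
print(mmd2_unbiased(X, Y_shift))  # clearly positive (population value about 0.19 here)
```

Because it is unbiased, the estimate can be negative when the two samples come from the same distribution. This is expected behavior, not an error.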
```python
import torch
import torch.nn as nn


class MMDLoss(nn.Module):
    """
    Maximum mean discrepancy loss with a sum of Gaussian kernels
    (multi-kernel MMD, biased estimate).
    """
    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
        super().__init__()
        self.kernel_type = kernel_type  # only 'rbf' is implemented
        self.kernel_mul = kernel_mul
        self.kernel_num = kernel_num

    def gaussian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """
        Compute the sum of Gaussian kernels over the concatenated samples.
        """
        n_samples = source.size(0) + target.size(0)
        total = torch.cat([source, target], dim=0)
        # Pairwise squared L2 distance matrix
        total0 = total.unsqueeze(0).expand(n_samples, n_samples, total.size(1))
        total1 = total.unsqueeze(1).expand(n_samples, n_samples, total.size(1))
        L2_distance = ((total0 - total1) ** 2).sum(2)
        # Bandwidth: mean pairwise distance (no gradient), unless fixed
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.sum(L2_distance.detach()) / (n_samples ** 2 - n_samples)
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
        # One Gaussian kernel per bandwidth
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def forward(self, source, target):
        n_source = int(source.size(0))
        kernels = self.gaussian_kernel(source, target,
                                       kernel_mul=self.kernel_mul,
                                       kernel_num=self.kernel_num)
        XX = kernels[:n_source, :n_source]
        YY = kernels[n_source:, n_source:]
        XY = kernels[:n_source, n_source:]
        YX = kernels[n_source:, :n_source]
        # Average each block separately so source and target batch sizes may differ
        loss = XX.mean() + YY.mean() - XY.mean() - YX.mean()
        return loss
```
Taxonomy of Domain Adaptation Methods
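Discrepancy-Based Methods
| Method | Loss term | Characteristics |
|---|---|---|
| CORAL | Second-order statistics alignment | Simple and efficient |
| MMD | Kernel mean matching | Strong theoretical foundation |
| CMD | Central moment discrepancy | Higher-order statistics |
| Wasserstein | Optimal transport | Geometric interpretation |
Adversarial Methods
| Method | Adversarial objective | Network architecture |
|---|---|---|
| DANN | Fool the domain discriminator | Shared encoder |
| ADDA | Discriminate source/target representations | Asymmetric encoders |
| Co-DA | Multi-domain alignment | Collaborative discrimination |
| CDAN | Conditional adversarial alignment | Conditional discriminator |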
Self-Supervised Methods
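Self-supervised and pseudo-labeling methods usually keep only confident target predictions. The `pseudo_threshold` value (0.9) in the hyperparameter configuration of this document plays this role. Below is a minimal sketch, assuming the model produces class logits for the target batch. `select_pseudo_labels` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def select_pseudo_labels(logits, threshold=0.9):
    """Keep target predictions whose max softmax probability reaches the threshold."""
    probs = F.softmax(logits, dim=1)
    confidence, labels = probs.max(dim=1)
    mask = confidence >= threshold
    return labels[mask], mask


logits = torch.tensor([[4.0, 0.0], [0.2, 0.1], [0.0, 5.0]])
labels, mask = select_pseudo_labels(logits, threshold=0.9)
print(mask.tolist())    # [True, False, True]: the uncertain middle row is dropped
print(labels.tolist())  # [0, 1]
```

A high threshold trades coverage for label quality. Early in training, few target samples pass it, which limits how much error the pseudo-labels can inject.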
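```python
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans


class SelfSupervisedDA:
    """
    Self-supervised domain adaptation.
    Assumes `backbone` maps inputs to features, `classifier` maps features to class
    logits, and num_clusters equals the number of classes, so that cluster k can be
    initialized at the centroid of source class k (every class must appear in the
    source batch).
    """
    def __init__(self, backbone, classifier, num_clusters, noise_std=0.1):
        self.backbone = backbone
        self.classifier = classifier
        self.num_clusters = num_clusters
        self.noise_std = noise_std  # strength of the assumed Gaussian-noise augmentation

    def train(self, X_source, y_source, X_target):
        """
        One training step; returns the total loss (call .backward() on it).
        """
        # Supervised learning on the source domain
        source_features = self.backbone(X_source)
        cls_loss = self.classification_loss(source_features, y_source)
        # Pseudo-labels for the target domain
        pseudo_labels = self.cluster_and_label(X_target, source_features, y_source)
        # Self-supervised consistency
        consistency_loss = self.consistency_loss(X_target, pseudo_labels)
        return cls_loss + consistency_loss

    def classification_loss(self, features, labels):
        return F.cross_entropy(self.classifier(features), labels)

    def cluster_and_label(self, X_target, source_features, y_source):
        """
        Cluster the target features and produce pseudo-labels.
        """
        with torch.no_grad():
            target_features = self.backbone(X_target).cpu().numpy()
        src = source_features.detach().cpu().numpy()
        y = y_source.cpu().numpy()
        # K-means initialized at the source class centroids, so cluster k ~ class k
        init = np.stack([src[y == k].mean(axis=0) for k in range(self.num_clusters)])
        kmeans = KMeans(n_clusters=self.num_clusters, init=init, n_init=1)
        pseudo_labels = kmeans.fit_predict(target_features)
        return torch.as_tensor(pseudo_labels, dtype=torch.long, device=X_target.device)

    def augment(self, X):
        # Assumed augmentation: additive Gaussian noise
        return X + self.noise_std * torch.randn_like(X)

    def consistency_loss(self, X_target, pseudo_labels):
        """
        Consistency loss.
        """
        # Apply two random augmentations to the target samples
        f1 = self.backbone(self.augment(X_target))
        f2 = self.backbone(self.augment(X_target))
        # Feature consistency between the two views + pseudo-label cross-entropy
        return torch.mean((f1 - f2) ** 2) + F.cross_entropy(self.classifier(f1), pseudo_labels)
```
Practical Guide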
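Method Selection
| Scenario | Recommended methods |
|---|---|
| Small datasets | MMD, CORAL |
| Large datasets | DANN, ADDA |
| Marginal distribution shift | Sample reweighting |
| Conditional distribution shift | Conditional adversarial methods (e.g., CDAN) |
| Limited compute | CORAL |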
Hyperparameter Settings
```python
class DomainAdaptationConfig:
    """
    Domain adaptation hyperparameter configuration.
    """
    default_config = {
        # Distribution alignment
        'mmd_sigma': 1.0,
        'coral_weight': 1.0,
        # Adversarial training
        'dann_alpha': 1.0,
        'lr_discriminator': 0.001,
        'lr_encoder': 0.001,
        # Pseudo-labels
        'pseudo_threshold': 0.9,
        'pseudo_confidence_weight': 0.5,
        # Regularization
        'weight_decay': 1e-4,
        'dropout': 0.5
    }
```
References
Footnotes
1. Pan, S. J., & Yang, Q. (2009). A Survey on Transfer Learning. IEEE Transactions on Knowledge and Data Engineering.