Overview

Domain Adaptation is an important branch of transfer learning. It uses labeled data from a Source Domain to improve a model's performance on a Target Domain.[1]

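Problem Definition

A domain $\mathcal{D} = \{\mathcal{X}, P(X)\}$ consists of two parts:

  • A feature space $\mathcal{X}$
  • A marginal probability distribution $P(X)$, with $X = \{x_1, \dots, x_n\} \subset \mathcal{X}$

A task $\mathcal{T} = \{\mathcal{Y}, f(\cdot)\}$ consists of a label space $\mathcal{Y}$ and a predictive function $f(\cdot)$.

In domain adaptation:

  • Source domain $\mathcal{D}_S = \{(x_i^S, y_i^S)\}_{i=1}^{n_S}$, labeled
  • Target domain $\mathcal{D}_T = \{x_j^T\}_{j=1}^{n_T}$, with few or no labels

Goal: learn a predictor $h: \mathcal{X} \to \mathcal{Y}$ that performs well on the target domain $\mathcal{D}_T$.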

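Types of Domain Adaptation

| Type | Source labels | Target labels | Description |
|---|---|---|---|
| Supervised DA | Yes | Yes (few) | The target domain has a small amount of labeled data |
| Semi-supervised DA | Yes | Partial | The target domain has a few labeled samples plus unlabeled data |
| Unsupervised DA | Yes | No | Only unlabeled target-domain data is available |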

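Theoretical Framework of Domain Adaptation

Decomposition of the Generalization Error

Let $h_S$ be the classifier learned from the source domain and $h_T^*$ the optimal target classifier. The target error is

$$\epsilon_T(h) = \mathbb{E}_{(x, y) \sim \mathcal{D}_T}\left[\ell(h(x), y)\right]$$

and it can be decomposed as:

$$\epsilon_T(h_S) = \underbrace{\epsilon_T(h_T^*)}_{\text{irreducible target error}} + \underbrace{\left[\epsilon_T(h_S) - \epsilon_T(h_T^*)\right]}_{\text{excess error from learning on the source}}$$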

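Sources of Domain Shift

Marginal distribution shift (Covariate Shift): $P_S(X) \neq P_T(X)$, while $P_S(Y \mid X) = P_T(Y \mid X)$

Conditional distribution shift (Conditional Shift): $P_S(Y \mid X) \neq P_T(Y \mid X)$

Joint distribution shift: both change, so $P_S(X, Y) \neq P_T(X, Y)$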
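
The shift types can be made concrete with synthetic 1-D data. The sketch below makes its own assumptions: Gaussian inputs and a threshold labeling rule, and `label_fn` and all numbers are hypothetical. Under covariate shift, the labeling rule stays the same and only the input distribution moves. Under conditional shift, the inputs keep their distribution but the rule changes.

```python
import numpy as np

rng = np.random.default_rng(0)

def label_fn(x):
    # Shared labeling rule P(Y|X) (hypothetical): y = 1 if x > 0
    return (x > 0).astype(int)

# Covariate shift: P(X) moves from N(-1, 1) to N(1, 1), the rule is unchanged
x_source = rng.normal(loc=-1.0, scale=1.0, size=1000)
x_target = rng.normal(loc=1.0, scale=1.0, size=1000)
y_source = label_fn(x_source)
y_target = label_fn(x_target)

# Conditional shift: same P(X) as the source, but the threshold moves to 0.5
x_target_cond = rng.normal(loc=-1.0, scale=1.0, size=1000)
y_target_cond = (x_target_cond > 0.5).astype(int)
```

A joint shift would combine both changes: a different input distribution and a different labeling rule.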

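Theoretical Bound

Domain adaptation bound based on the $\mathcal{H}\Delta\mathcal{H}$-divergence (Ben-David et al., 2010). For every $h \in \mathcal{H}$:

$$\epsilon_T(h) \le \epsilon_S(h) + \frac{1}{2} d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T) + \lambda, \qquad \lambda = \min_{h' \in \mathcal{H}} \left[ \epsilon_S(h') + \epsilon_T(h') \right]$$

The terms are:

  • $\epsilon_S$ and $\epsilon_T$: the source and target errors.
  • $d_{\mathcal{H}\Delta\mathcal{H}}$: the $\mathcal{H}\Delta\mathcal{H}$-divergence, which measures the discrepancy between the two domains with respect to the hypothesis space $\mathcal{H}$.
  • $\lambda$: the combined error of the ideal joint hypothesis.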
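
Plugging numbers into the bound shows the trade-off it describes. The helper below only evaluates the right-hand side. `ben_david_bound` is a hypothetical name, and the inputs are made-up illustrative values, not measurements. A source error of 0.05, a divergence of 0.4 and λ = 0.02 give an upper bound of 0.05 + 0.2 + 0.02 = 0.27 on the target error.

```python
def ben_david_bound(source_error, divergence, lambda_joint):
    """Right-hand side of eps_T(h) <= eps_S(h) + 0.5 * d_HdH + lambda."""
    return source_error + 0.5 * divergence + lambda_joint

# Hypothetical values, for illustration only
loose = ben_david_bound(0.05, 0.40, 0.02)   # 0.27
tight = ben_david_bound(0.05, 0.10, 0.02)   # same source error, better-aligned domains
```

Only the divergence term changes between the two calls. This is why the methods below try to shrink that term while keeping the source error low.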


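Distribution Discrepancy Measures

$\mathcal{H}$-Divergence

The $\mathcal{H}$-divergence between the source distribution $\mathcal{D}_S$ and the target distribution $\mathcal{D}_T$ is defined as:

$$d_{\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T) = 2 \sup_{h \in \mathcal{H}} \left| \Pr_{x \sim \mathcal{D}_S}[h(x) = 1] - \Pr_{x \sim \mathcal{D}_T}[h(x) = 1] \right|$$

In practice, it is estimated from finite samples $\mathcal{U}_S$ and $\mathcal{U}_T$ of size $m$ each:

$$\hat{d}_{\mathcal{H}}(\mathcal{U}_S, \mathcal{U}_T) = 2 \left( 1 - \min_{h \in \mathcal{H}} \left[ \frac{1}{m} \sum_{x : h(x) = 0} \mathbb{I}[x \in \mathcal{U}_S] + \frac{1}{m} \sum_{x : h(x) = 1} \mathbb{I}[x \in \mathcal{U}_T] \right] \right)$$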
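
A common practical surrogate for this estimate is the proxy A-distance $2(1 - 2\epsilon)$. Here $\epsilon$ is the held-out error of a domain classifier trained to tell source samples from target samples. The NumPy-only sketch below assumes a linear hypothesis class and uses a hand-rolled logistic regression. `train_logreg` and `proxy_a_distance` are hypothetical helpers, not library functions.

```python
import numpy as np

def train_logreg(X, y, lr=0.1, n_steps=500):
    """Logistic regression fitted by batch gradient descent (bias as an extra column)."""
    Xb = np.hstack([X, np.ones((len(X), 1))])
    w = np.zeros(Xb.shape[1])
    for _ in range(n_steps):
        p = 1.0 / (1.0 + np.exp(-(Xb @ w)))
        w -= lr * Xb.T @ (p - y) / len(y)  # gradient of the mean log-loss
    return w

def proxy_a_distance(X_source, X_target, seed=0):
    """2 * (1 - 2 * err), where err is the held-out error of a linear domain classifier."""
    rng = np.random.default_rng(seed)
    X = np.vstack([X_source, X_target])
    X = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-12)  # standardize features
    y = np.hstack([np.zeros(len(X_source)), np.ones(len(X_target))])  # 0 = source, 1 = target
    idx = rng.permutation(len(X))
    train, test = idx[: len(X) // 2], idx[len(X) // 2 :]
    w = train_logreg(X[train], y[train])
    Xb_test = np.hstack([X[test], np.ones((len(test), 1))])
    pred = (Xb_test @ w > 0).astype(float)
    err = np.mean(pred != y[test])
    return 2.0 * (1.0 - 2.0 * err)
```

The value behaves as follows:

  • Close to 0 when the classifier cannot beat chance, so the domains look alike.
  • Close to 2 when the domains are linearly separable.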

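Adversarial Divergence

A classifier-based divergence estimate. A domain discriminator is trained, and how well it separates the two domains reflects how far apart they are: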

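import torch
import torch.nn as nn
import torch.nn.functional as F
 
def compute_adversarial_divergence(source_features, target_features, discriminator):
    """
    Adversarial divergence
    
    Estimates the domain discrepancy with a domain discriminator. Assumes
    `discriminator` outputs the probability (after a sigmoid) that a sample
    comes from the target domain.
    """
    # Discriminator predictions, flattened to shape (N,)
    source_pred = discriminator(source_features).view(-1)
    target_pred = discriminator(target_features).view(-1)
    
    # Source domain labeled 0, target domain labeled 1
    source_labels = torch.zeros_like(source_pred)
    target_labels = torch.ones_like(target_pred)
    
    # Log-likelihood of the discriminator:
    # E_s[log(1 - D(x))] + E_t[log D(x)]
    divergence = -(F.binary_cross_entropy(source_pred, source_labels)
                   + F.binary_cross_entropy(target_pred, target_labels))
    
    # Negated (the discriminator's loss) so that it can be minimized;
    # a lower loss means the domains are easier to separate
    return -divergence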

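Maximum Mean Discrepancy (MMD)

MMD measures the discrepancy between two distributions $P$ and $Q$ by comparing their kernel mean embeddings $\mu_P$ and $\mu_Q$ in a reproducing kernel Hilbert space $\mathcal{H}_k$:

$$\mathrm{MMD}^2(P, Q) = \|\mu_P - \mu_Q\|_{\mathcal{H}_k}^2 = \mathbb{E}[k(x_s, x_s')] + \mathbb{E}[k(x_t, x_t')] - 2\,\mathbb{E}[k(x_s, x_t)]$$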

def mmd_linear(source, target, kernel='rbf', sigma=1.0):
    """
    Maximum Mean Discrepancy (biased estimate of the squared MMD)
    
    Despite the function name, the default kernel is RBF;
    pass kernel='linear' for the linear kernel.
    """
    def rbf_kernel(x, y, sigma):
        diff = x.unsqueeze(1) - y.unsqueeze(0)
        return torch.exp(-torch.sum(diff ** 2, dim=-1) / (2 * sigma ** 2))
    
    def linear_kernel(x, y):
        return torch.mm(x, y.T)
    
    if kernel == 'rbf':
        k = lambda x, y: rbf_kernel(x, y, sigma)  # bind the bandwidth
    else:
        k = linear_kernel
    
    # Compute each term
    xx = k(source, source).mean()  # E[k(x_s, x_s')]
    yy = k(target, target).mean()  # E[k(x_t, x_t')]
    xy = k(source, target).mean()  # E[k(x_s, x_t)]
    
    return xx + yy - 2 * xy
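
The estimator in `mmd_linear` averages over all pairs, including the self-similarity terms $k(x_i, x_i)$. It is therefore the squared distance between the two empirical mean embeddings. It is never negative and stays above zero even when both samples come from the same distribution. The unbiased estimator drops those diagonal terms. Below is a minimal NumPy sketch, assuming an RBF kernel with a fixed `sigma`. `rbf` and `mmd2_unbiased` are hypothetical names.

```python
import numpy as np

def rbf(X, Y, sigma=1.0):
    # Pairwise squared distances via broadcasting, then the Gaussian kernel
    d2 = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2 * sigma ** 2))

def mmd2_unbiased(X, Y, sigma=1.0):
    """Unbiased MMD^2: exclude k(x_i, x_i) and k(y_j, y_j) from the within-sample means."""
    m, n = len(X), len(Y)
    Kxx, Kyy, Kxy = rbf(X, X, sigma), rbf(Y, Y, sigma), rbf(X, Y, sigma)
    term_xx = (Kxx.sum() - np.trace(Kxx)) / (m * (m - 1))
    term_yy = (Kyy.sum() - np.trace(Kyy)) / (n * (n - 1))
    return term_xx + term_yy - 2 * Kxy.mean()
```

Unlike the biased version, this estimate can be slightly negative when the two distributions are equal. Its expected value is exactly $\mathrm{MMD}^2$.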

Wasserstein Distance

Measures the discrepancy between distributions using optimal transport theory:

import numpy as np
from scipy.stats import wasserstein_distance
 
def compute_wasserstein(source, target):
    """
    1-D Wasserstein distance
    """
    return wasserstein_distance(source, target)
 
def compute_wasserstein_nd(source, target, n_projections=100):
    """
    Approximate the Wasserstein distance in several dimensions
    (sliced Wasserstein: average 1-D distances over random projections)
    """
    n_dims = source.shape[1]
    distances = []
    
    for _ in range(n_projections):
        # Random unit projection direction
        direction = np.random.randn(n_dims)
        direction = direction / np.linalg.norm(direction)
        
        # Project both samples onto it
        source_proj = source @ direction
        target_proj = target @ direction
        
        # 1-D Wasserstein distance between the projections
        dist = wasserstein_distance(source_proj, target_proj)
        distances.append(dist)
    
    return np.mean(distances)
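
Take two equal-size 1-D samples with uniform weights. The optimal transport plan simply matches sorted values, so $W_1$ is the mean absolute difference of the sorted samples. The NumPy-only sketch below uses that fact to reimplement the sliced approximation without SciPy. `wasserstein_1d_equal` and `sliced_wasserstein` are hypothetical names, and the equal-size assumption is checked in the code.

```python
import numpy as np

def wasserstein_1d_equal(a, b):
    """W1 between two equal-size 1-D samples: match sorted values (quantiles)."""
    assert len(a) == len(b), "this shortcut assumes equal sample sizes"
    return np.mean(np.abs(np.sort(a) - np.sort(b)))

def sliced_wasserstein(X, Y, n_projections=100, seed=0):
    """Average 1-D W1 over random unit projection directions."""
    rng = np.random.default_rng(seed)
    dirs = rng.normal(size=(n_projections, X.shape[1]))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)
    return float(np.mean([wasserstein_1d_equal(X @ d, Y @ d) for d in dirs]))
```

Shifting a 1-D sample by a constant $c$ moves every sorted value by $c$, so $W_1 = |c|$. `scipy.stats.wasserstein_distance`, as used in `compute_wasserstein_nd`, also handles unequal sample sizes and weights.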

Main Methods

Distribution Matching Methods

Kernel Methods

Match the source and target distributions through a feature transformation:

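import numpy as np
from sklearn.metrics.pairwise import rbf_kernel

class KernelFDA:
    """
    Kernel-based domain adaptation (Kernel Fisher Discriminant Analysis)
    
    Simplified skeleton: only the shared kernel feature map is implemented;
    the step that extracts domain-invariant directions is omitted.
    """
    def __init__(self, kernel='rbf', gamma=1.0):
        self.kernel = kernel
        self.gamma = gamma
        self.X_fit_ = None
    
    def fit_transform(self, source, target):
        """
        Learn a domain-invariant representation
        """
        # Merge data from both domains
        X = np.vstack([source, target])
        n_source = len(source)
        n_total = len(X)
        
        # Domain labels: 0 = source, 1 = target
        domain_labels = np.zeros(n_total)
        domain_labels[n_source:] = 1
        
        # Train a domain classifier and extract domain-invariant features
        # ... (omitted in this simplified implementation)
        
        self.X_fit_ = X
        return self.transform(X)
    
    def transform(self, X):
        """
        Feature transformation
        """
        # Return domain-invariant features
        return self.kernel_transform(X)
    
    def kernel_transform(self, X):
        """
        Empirical kernel map: similarity of X to every fitted sample
        """
        if self.kernel == 'rbf':
            return rbf_kernel(X, self.X_fit_, gamma=self.gamma)
        return X @ self.X_fit_.T  # linear kernel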

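Sample Weighting

Correct the distribution shift by reweighting source samples with importance weights $w(x) = P_T(x) / P_S(x)$. This correction is valid under covariate shift, where $P(Y \mid X)$ is shared: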

import numpy as np
from sklearn.linear_model import LogisticRegression

class CovariateShiftAdaptation:
    """
    Covariate shift adaptation
    """
    def __init__(self):
        self.ratio_estimator = None
    
    def estimate_density_ratio(self, X_source, X_target):
        """
        Estimate the density ratio w(x) = P_t(x) / P_s(x) on the source samples
        
        Classifier-based method: with a domain classifier p(d=1|x),
        w(x) = (n_s / n_t) * p(d=1|x) / p(d=0|x)
        """
        # Build a binary dataset: 0 = source, 1 = target
        X = np.vstack([X_source, X_target])
        y = np.hstack([np.zeros(len(X_source)), np.ones(len(X_target))])
        
        # Train the domain classifier
        clf = LogisticRegression()
        clf.fit(X, y)
        self.ratio_estimator = clf
        
        # Probability of "target" for each source sample
        p_target = clf.predict_proba(X_source)[:, 1]
        
        # Importance weights, corrected for the sample-size ratio n_s / n_t
        weights = (len(X_source) / len(X_target)) * p_target / (1 - p_target + 1e-8)
        weights = np.clip(weights, 0.1, 10)  # Clip extreme weights
        
        return weights
    
    def fit_weighted_classifier(self, X_source, y_source, weights):
        """
        Train a classifier with weighted samples
        """
        clf = LogisticRegression()
        clf.fit(X_source, y_source, sample_weight=weights)
        return clf
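
Why reweighting works can be checked on a case where the density ratio is known in closed form. Assume $P_S = \mathcal{N}(0, 1)$ and $P_T = \mathcal{N}(1, 1)$ (illustrative choices). Then $w(x) = \exp(x - 1/2)$, and a self-normalized weighted average of source samples recovers a target expectation. This is a NumPy-only sketch. In practice $w$ would come from an estimator such as `estimate_density_ratio`.

```python
import numpy as np

rng = np.random.default_rng(0)
mu_t = 1.0                                   # assumed target mean
x_source = rng.normal(0.0, 1.0, 200_000)     # samples from P_S = N(0, 1)

# Closed-form ratio N(x; mu_t, 1) / N(x; 0, 1) = exp(mu_t * x - mu_t**2 / 2)
w = np.exp(mu_t * x_source - 0.5 * mu_t ** 2)

unweighted_mean = x_source.mean()                  # estimates E_S[x] = 0
weighted_mean = np.sum(w * x_source) / np.sum(w)   # estimates E_T[x] = mu_t
```

Although only source samples are used, the weighted average lands near the target mean of 1. The same principle underlies passing `sample_weight` to the classifier above.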

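Adversarial Domain Adaptation

DANN (Domain-Adversarial Neural Network)

Core idea: train a feature extractor that supports the label classifier while fooling a domain discriminator, so that the learned features become domain-invariant. A gradient reversal layer (GRL) lets both objectives be optimized in a single backward pass.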

class DomainAdversarialNetwork(nn.Module):
    """
    DANN: Domain-Adversarial Neural Network
    """
    def __init__(self, input_dim, num_classes, hidden_dim=256):
        super().__init__()
        
        # Feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Label classifier
        self.label_classifier = nn.Sequential(
            nn.Linear(hidden_dim, num_classes)
        )
        
        # Domain discriminator (2 outputs: source / target)
        self.domain_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 2)
        )
    
    def forward(self, x, alpha=1.0):
        """
        Forward pass
        
        Args:
            x: input features
            alpha: gradient reversal strength
        """
        # Extract features
        features = self.feature_extractor(x)
        
        # Label prediction
        class_logits = self.label_classifier(features)
        
        # Domain prediction through the gradient reversal layer: identity in
        # the forward pass, gradients multiplied by -alpha in the backward pass
        reversed_features = gradient_reverse(features, alpha)
        domain_logits = self.domain_classifier(reversed_features)
        
        return class_logits, domain_logits, features
 
class GradientReversalFunction(torch.autograd.Function):
    """
    Gradient reversal function
    """
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)
    
    @staticmethod
    def backward(ctx, grad_output):
        # Flip and scale the gradient; alpha itself receives no gradient
        return grad_output.neg() * ctx.alpha, None
 
def gradient_reverse(x, alpha):
    return GradientReversalFunction.apply(x, alpha)
 
def dann_loss(class_logits, domain_logits, y_source, d_source, alpha=1.0):
    """
    DANN loss function
    
    class_logits: predictions for the labeled source batch only
    domain_logits, d_source: predictions and domain labels (0 = source,
        1 = target, dtype long) for the combined source + target batch
    """
    # Classification loss (source domain only)
    cls_loss = nn.CrossEntropyLoss()(class_logits, y_source)
    
    # Domain loss: the discriminator tries to separate source from target,
    # while the GRL makes the feature extractor try to fool it
    dom_loss = nn.CrossEntropyLoss()(domain_logits, d_source)
    
    # Total loss (alpha here is an extra weight on top of the GRL coefficient)
    total_loss = cls_loss + alpha * dom_loss
    
    return total_loss, cls_loss, dom_loss
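
The code above keeps `alpha` fixed. The DANN paper (Ganin & Lempitsky, 2015) instead ramps the gradient-reversal coefficient from 0 toward 1 with $\lambda_p = \frac{2}{1 + \exp(-\gamma p)} - 1$. Here $p \in [0, 1]$ is training progress and $\gamma = 10$, so the noisy domain signal does not dominate early training. Below is a minimal sketch of that schedule. `grl_alpha_schedule` is a hypothetical helper whose output would be passed as `alpha` to `forward`.

```python
import math

def grl_alpha_schedule(progress, gamma=10.0):
    """GRL coefficient 2 / (1 + exp(-gamma * p)) - 1, ramping from 0 toward 1."""
    p = min(max(progress, 0.0), 1.0)  # clamp training progress to [0, 1]
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0

# Coefficient at 0%, 25%, 50%, 75% and 100% of training
schedule = [grl_alpha_schedule(step / 100) for step in range(0, 101, 25)]
```

In a training loop, progress would typically be `current_step / total_steps`.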

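ADDA (Adversarial Discriminative Domain Adaptation)

ADDA uses asymmetric (untied) encoders. The source encoder is pretrained on labeled source data and then frozen. A separate target encoder is then trained adversarially so that the discriminator cannot tell its outputs from source features: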

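class ADDA:
    """
    ADDA: Adversarial Discriminative Domain Adaptation
    
    Assumptions of this sketch: the source and target encoders share the
    same architecture, `classifier` is the label head trained together with
    the source encoder, `discriminator` outputs one logit per sample
    (source = 1, target = 0), and inputs are full-batch tensors.
    """
    def __init__(self, source_encoder, target_encoder, discriminator, classifier, lr=1e-4):
        self.source_encoder = source_encoder
        self.target_encoder = target_encoder
        self.discriminator = discriminator
        self.classifier = classifier
        self.bce = nn.BCEWithLogitsLoss()
        self.discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
        self.target_encoder_optimizer = torch.optim.Adam(target_encoder.parameters(), lr=lr)
    
    def train(self, X_source, y_source, X_target, n_iterations=10000):
        """
        Training procedure
        """
        # Stage 1: pretrain the source encoder (and classifier)
        self.pretrain_source(X_source, y_source)
        
        # Initialize the target encoder with the source weights
        self.target_encoder.load_state_dict(self.source_encoder.state_dict())
        
        # Stage 2: adversarial adaptation
        for iteration in range(n_iterations):
            # Freeze the target encoder, update the discriminator
            self.update_discriminator(X_source, X_target)
            
            # Freeze the discriminator, update the target encoder
            self.update_target_encoder(X_target)
    
    def pretrain_source(self, X_source, y_source, n_epochs=100):
        """
        Supervised training of the source encoder and classifier
        """
        params = list(self.source_encoder.parameters()) + list(self.classifier.parameters())
        optimizer = torch.optim.Adam(params, lr=1e-3)
        for _ in range(n_epochs):
            logits = self.classifier(self.source_encoder(X_source))
            loss = nn.functional.cross_entropy(logits, y_source)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
    def update_discriminator(self, X_source, X_target):
        """
        Train the discriminator to tell source features (1) from target features (0)
        """
        # Both encoders are frozen in this step
        source_features = self.source_encoder(X_source).detach()
        target_features = self.target_encoder(X_target).detach()
        
        source_pred = self.discriminator(source_features)
        target_pred = self.discriminator(target_features)
        d_loss = self.bce(source_pred, torch.ones_like(source_pred)) + \
                 self.bce(target_pred, torch.zeros_like(target_pred))
        
        self.discriminator_optimizer.zero_grad()
        d_loss.backward()
        self.discriminator_optimizer.step()
    
    def update_target_encoder(self, X_target):
        """
        Train the target encoder to fool the discriminator (inverted labels)
        """
        target_pred = self.discriminator(self.target_encoder(X_target))
        g_loss = self.bce(target_pred, torch.ones_like(target_pred))
        
        self.target_encoder_optimizer.zero_grad()
        g_loss.backward()
        self.target_encoder_optimizer.step()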

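Representation Learning Methods

DeepCORAL (Correlation Alignment)

Aligns the second-order statistics (covariance matrices) of the source and target features: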

def coral_loss(source, target):
    """
    CORAL loss: align the covariance matrices
    
    CORAL loss = (1/4d²) ||C_s - C_t||_F²
    """
    d = source.shape[1]
    
    # Center the features
    source = source - source.mean(dim=0, keepdim=True)
    target = target - target.mean(dim=0, keepdim=True)
    
    # Covariance matrices
    C_s = torch.mm(source.T, source) / (source.size(0) - 1)
    C_t = torch.mm(target.T, target) / (target.size(0) - 1)
    
    # Squared Frobenius norm
    loss = torch.norm(C_s - C_t, p='fro') ** 2
    loss = loss / (4 * d * d)
    
    return loss
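
Deep CORAL uses this loss during training. The original, non-deep CORAL (Sun et al., 2016) instead aligns covariances in closed form. It whitens the source features with $C_s^{-1/2}$ and re-colors them with $C_t^{1/2}$, adding an identity term to both covariances for regularization. Below is a NumPy sketch of that transform. `coral_transform` and `coral_distance` are hypothetical names. `eps` is an assumed parameterization of the identity term; the original adds the full identity, which corresponds to `eps=1.0`.

```python
import numpy as np

def _matrix_power(C, power):
    # Power of a symmetric positive-definite matrix via eigendecomposition
    vals, vecs = np.linalg.eigh(C)
    return (vecs * vals ** power) @ vecs.T

def coral_transform(X_source, X_target, eps=1.0):
    """Whiten source features, then re-color them with the target covariance."""
    d = X_source.shape[1]
    C_s = np.cov(X_source, rowvar=False) + eps * np.eye(d)
    C_t = np.cov(X_target, rowvar=False) + eps * np.eye(d)
    return X_source @ _matrix_power(C_s, -0.5) @ _matrix_power(C_t, 0.5)

def coral_distance(X_s, X_t):
    # Same quantity as coral_loss: ||C_s - C_t||_F^2 / (4 d^2)
    d = X_s.shape[1]
    diff = np.cov(X_s, rowvar=False) - np.cov(X_t, rowvar=False)
    return np.sum(diff ** 2) / (4 * d * d)
```

With a tiny `eps`, the transformed source covariance matches the target covariance almost exactly. The larger default trades exact matching for numerical stability.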

MMD-Based Methods

class MMDLoss(nn.Module):
    """
    Maximum Mean Discrepancy loss (multi-kernel Gaussian)
    
    Only the Gaussian kernel is implemented; kernel_type is kept for
    interface compatibility.
    """
    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
        super().__init__()
        self.kernel_type = kernel_type
        self.kernel_mul = kernel_mul
        self.kernel_num = kernel_num
    
    def gaussian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """
        Sum of Gaussian kernels over a geometric grid of bandwidths
        """
        n_samples = source.size(0) + target.size(0)
        total = torch.cat([source, target], dim=0)
        
        # Pairwise squared L2 distance matrix
        total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        L2_distance = ((total0 - total1) ** 2).sum(2)
        
        # Bandwidth: fixed, or the mean pairwise squared distance
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.sum(L2_distance) / (n_samples ** 2 - n_samples)
        bandwidth = bandwidth / kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
        
        # Gaussian kernels, summed
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)
    
    def forward(self, source, target):
        batch_size = int(source.size(0))
        kernels = self.gaussian_kernel(source, target,
                                       kernel_mul=self.kernel_mul,
                                       kernel_num=self.kernel_num)
        
        XX = kernels[:batch_size, :batch_size]
        YY = kernels[batch_size:, batch_size:]
        XY = kernels[:batch_size, batch_size:]
        YX = kernels[batch_size:, :batch_size]
        
        # Average each block separately so that the source and target
        # batches may have different sizes
        loss = XX.mean() + YY.mean() - XY.mean() - YX.mean()
        return loss
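
`MMDLoss` does not use a single bandwidth. It takes the mean pairwise squared distance as a base value and spreads `kernel_num` bandwidths geometrically around it, with ratio `kernel_mul`. It then sums the resulting Gaussian kernels (a multi-kernel MMD). The pure-Python sketch below reproduces just that arithmetic; `bandwidth_list` is a hypothetical helper. With the defaults `kernel_mul=2.0` and `kernel_num=5`, a base bandwidth of 1 yields five bandwidths from 1/4 to 4.

```python
def bandwidth_list(base_bandwidth, kernel_mul=2.0, kernel_num=5):
    """Geometric bandwidth grid centered on base_bandwidth, as in MMDLoss."""
    start = base_bandwidth / kernel_mul ** (kernel_num // 2)
    return [start * kernel_mul ** i for i in range(kernel_num)]

print(bandwidth_list(1.0))  # [0.25, 0.5, 1.0, 2.0, 4.0]
```

Spanning several scales makes the loss less sensitive to a single badly chosen bandwidth.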

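Taxonomy of Domain Adaptation Methods

Discrepancy-Based Methods

| Method | Loss term | Characteristics |
|---|---|---|
| CORAL | Second-order statistics alignment | Simple and efficient |
| MMD | Kernel mean matching | Strong theoretical foundation |
| CMD | Central moment discrepancy | Higher-order statistics |
| Wasserstein | Optimal transport | Geometric interpretation |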
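
CMD (Central Moment Discrepancy, Zellinger et al., 2017) compares means and element-wise central moments up to order $K$, for features bounded in $[a, b]$:

$$\mathrm{CMD}_K = \frac{1}{|b-a|}\|\mathbb{E}[X] - \mathbb{E}[Y]\|_2 + \sum_{k=2}^{K} \frac{1}{|b-a|^k}\|c_k(X) - c_k(Y)\|_2$$

Here $c_k$ is the vector of $k$-th central moments. Below is a NumPy sketch using empirical moments. `cmd` is a hypothetical name, and the defaults $K = 5$ and $[a, b] = [0, 1]$ are assumptions.

```python
import numpy as np

def cmd(X, Y, n_moments=5, a=0.0, b=1.0):
    """Central Moment Discrepancy between samples X and Y of shape (n, d)."""
    span = abs(b - a)
    mx, my = X.mean(axis=0), Y.mean(axis=0)
    total = np.linalg.norm(mx - my) / span          # first-order (mean) term
    for k in range(2, n_moments + 1):
        ck_x = ((X - mx) ** k).mean(axis=0)         # k-th central moments, per feature
        ck_y = ((Y - my) ** k).mean(axis=0)
        total += np.linalg.norm(ck_x - ck_y) / span ** k
    return total
```

A pure shift leaves all central moments unchanged, so only the mean term contributes.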

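Adversarial Methods

| Method | Adversarial objective | Network structure |
|---|---|---|
| DANN | Fool the domain discriminator | Shared encoder |
| ADDA | Discriminate source/target representations | Asymmetric encoders |
| Co-DA | Multi-domain alignment | Collaborative discrimination |
| CDAN | Conditional adversarial alignment | Conditional discriminator |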
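
CDAN conditions the domain discriminator on the classifier's predictions. It feeds the discriminator the multilinear map $f \otimes g$: the outer product of the feature vector $f$ and the predicted class distribution $g$ (Long et al., 2018). The paper also uses a randomized variant when $d_f \times C$ is too large; that variant is omitted here. Below is a NumPy sketch of the per-sample outer product; `multilinear_map` is a hypothetical name.

```python
import numpy as np

def multilinear_map(features, probs):
    """Flattened per-sample outer product: (batch, d_f), (batch, C) -> (batch, C * d_f)."""
    outer = probs[:, :, None] * features[:, None, :]  # shape (batch, C, d_f)
    return outer.reshape(len(features), -1)
```

With one-hot predictions, the discriminator sees the features placed in the block of the predicted class. This is how class information conditions the domain alignment.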

Self-Supervised Methods

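import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans

class SelfSupervisedDA:
    """
    Self-supervised domain adaptation (simplified sketch)
    
    Assumes `backbone` maps inputs to class logits and uses additive
    Gaussian noise as the augmentation.
    """
    def __init__(self, backbone, num_clusters, noise_std=0.1):
        self.backbone = backbone
        self.num_clusters = num_clusters
        self.noise_std = noise_std
    
    def train(self, X_source, y_source, X_target):
        """
        Training step (returns the loss)
        """
        # Supervised learning on the source domain
        source_logits = self.backbone(X_source)
        cls_loss = self.classification_loss(source_logits, y_source)
        
        # Pseudo-labels for the target domain (cluster IDs, not class IDs)
        pseudo_labels = self.cluster_and_label(X_target)
        
        # Self-supervised consistency
        consistency_loss = self.consistency_loss(X_target, pseudo_labels)
        
        return cls_loss + consistency_loss
    
    def classification_loss(self, logits, labels):
        return F.cross_entropy(logits, labels)
    
    def augment(self, X):
        # Stand-in augmentation: additive Gaussian noise
        return X + self.noise_std * torch.randn_like(X)
    
    def cluster_and_label(self, X_target):
        """
        Cluster the target features and generate pseudo-labels
        """
        with torch.no_grad():
            target_features = self.backbone(X_target)
        
        # K-means clustering
        kmeans = KMeans(n_clusters=self.num_clusters, n_init=10)
        pseudo_labels = kmeans.fit_predict(target_features.cpu().numpy())
        
        return torch.as_tensor(pseudo_labels, dtype=torch.long)
    
    def consistency_loss(self, X_target, pseudo_labels):
        """
        Consistency loss (pseudo_labels are not used in this simplified version)
        """
        # Apply two random augmentations to the target samples
        X_aug1 = self.augment(X_target)
        X_aug2 = self.augment(X_target)
        
        # Extract outputs for both views
        f1 = self.backbone(X_aug1)
        f2 = self.backbone(X_aug2)
        
        # Consistency: the two views should produce the same output
        return torch.mean((f1 - f2) ** 2)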
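
K-means cluster IDs are not aligned with class indices, so they cannot directly serve as class targets. A common alternative keeps only the target samples whose predicted class probability exceeds a threshold. That alternative is swapped in here; it is not taken from the code above. The 0.9 default mirrors the `pseudo_threshold` setting in the hyperparameter configuration. `softmax` and `select_pseudo_labels` are hypothetical NumPy helpers.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # subtract row max for stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def select_pseudo_labels(target_logits, threshold=0.9):
    """Return (labels, indices) of target samples predicted with confidence >= threshold."""
    probs = softmax(target_logits)
    mask = probs.max(axis=1) >= threshold
    return probs.argmax(axis=1)[mask], np.flatnonzero(mask)

logits = np.array([[5.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 4.0]])
labels, idx = select_pseudo_labels(logits)
```

The second sample is uniformly uncertain (probability 1/3 per class) and is dropped. The first and third samples are kept with labels 0 and 2.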

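Practical Guide

Choosing a Method

| Scenario | Recommended methods |
|---|---|
| Small dataset | MMD, CORAL |
| Large dataset | DANN, ADDA |
| Marginal distribution shift | Sample weighting |
| Conditional distribution shift | Conditional adversarial methods (e.g. CDAN) |
| Limited compute | CORAL |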

Hyperparameter Settings

class DomainAdaptationConfig:
    """
    Default hyperparameters for domain adaptation
    """
    default_config = {
        # Distribution alignment
        'mmd_sigma': 1.0,
        'coral_weight': 1.0,
        
        # Adversarial training
        'dann_alpha': 1.0,
        'lr_discriminator': 0.001,
        'lr_encoder': 0.001,
        
        # Pseudo-labels
        'pseudo_threshold': 0.9,
        'pseudo_confidence_weight': 0.5,
        
        # Regularization
        'weight_decay': 1e-4,
        'dropout': 0.5
    }

Footnotes

  1. Pan, S. J., & Yang, Q. (2009). A Survey on Transfer Learning. IEEE Transactions on Knowledge and Data Engineering.