多模态对齐与融合方法

多模态学习(Multimodal Learning)的核心挑战在于如何有效地对齐(Alignment)来自不同模态的表示,并将其融合(Fusion)以完成下游任务。1 本文系统梳理多模态对齐的理论基础、主流方法(CLIP、SigLIP、VISTA),以及多模态融合的主要范式(Early Fusion、Late Fusion、Cross Fusion),并讨论多层视觉特征融合的最佳实践。


1. 跨模态对齐理论

1.1 对齐的数学定义

跨模态对齐(Cross-modal Alignment)的目标是学习一个映射,使得不同模态的表示空间在语义上对齐。形式化地说,给定两个模态 ,对齐的目标是学习编码器 ,使得对齐后的表示 在语义空间中距离相近,当且仅当 描述相同语义内容时。2

度量学习视角:对齐可以视为度量学习的特例,其目标函数为:

其中 为正样本对分布, 为负样本对分布, 为距离度量(如余弦距离), 为间隔阈值。

1.2 对齐损失函数:InfoNCE 与 NT-Xent

InfoNCE(Information Noise-Contrastive Estimation)是对比学习的核心损失函数,定义为:

其中 为余弦相似度, 为温度参数。3

NT-Xent(Normalized Temperature-scaled Cross Entropy)是InfoNCE的对称版本,同时考虑两个方向的预测:

温度参数 的影响

特性适用场景
聚焦最相似的正样本,忽略负样本硬负样本挖掘
平衡正负样本权重CLIP默认设置
所有样本权重趋于均匀均匀表示学习

1.3 对齐质量的评估指标

检索任务指标

  • Recall@K:正确检索结果在前K个中的比例
  • mAP@K:平均精度均值
  • MRR:平均倒数排名

表示空间指标

  • Alignment Score:衡量配对样本在表示空间中的接近程度
  • Uniformity Score:衡量表示分布的均匀性

2. 对比学习对齐方法

2.1 CLIP: Contrastive Language-Image Pre-training

CLIP由OpenAI于2021年提出,是多模态对齐领域的里程碑工作。4

双塔架构

CLIP采用双塔(Dual-Encoder)架构,分别用视觉编码器和文本编码器处理图像和文本:

图像 → Vision Transformer → $z_I \in \mathbb{R}^D$
文本 → Text Transformer   → $z_T \in \mathbb{R}^D$

对齐:通过对比损失拉近配对的 $(z_I, z_T)$

架构特点

  • 视觉编码器:ViT(Vision Transformer)或 ResNet
  • 文本编码器:Transformer decoder
  • 输出维度统一为 维表示空间

对比损失函数

CLIP使用对称的InfoNCE损失,同时优化图像→文本和文本→图像两个方向:

其中:

训练目标

CLIP的预训练目标可以理解为最大化配对样本之间的互信息下界:

详见 CLIP对比学习理论

2.2 SigLIP: Sigmoid Loss for Vision-Language

SigLIP由Google Research提出,使用Sigmoid损失替代CLIP中的softmax归一化损失。5

Sigmoid对比损失

CLIP的softmax归一化

SigLIP的sigmoid损失(独立二分类):

其中 当且仅当 (正样本对),否则

核心差异

特性CLIP (softmax)SigLIP (sigmoid)
归一化全局softmax独立sigmoid
负样本关系隐式竞争独立处理
训练稳定性温度可学习固定
可扩展性负样本数量受限支持大批量

多语言支持

SigLIP在大规模多语言数据上训练,支持超过100种语言的图文对齐:

def siglip_loss(image_features, text_features, temperature):
    """
    SigLIP Sigmoid损失实现
    image_features: (B, D)
    text_features: (B, D)
    """
    # 计算相似度矩阵
    scores = image_features @ text_features.t()  # (B, B)
    
    # 创建标签:对角线为1,其余为0
    labels = torch.eye(scores.shape[0], device=scores.device)
    
    # Sigmoid概率
    probs = torch.sigmoid(scores / temperature)
    
    # 二分类交叉熵
    loss = -labels * torch.log(probs + 1e-8) - (1 - labels) * torch.log(1 - probs + 1e-8)
    
    return loss.mean()

2.3 SigLIP 2改进

SigLIP 2在SigLIP基础上进行了多项改进。6

主要改进点

  1. 更强的视觉编码器:使用更大的ViT架构(ViT-g, ViT-G/14)
  2. 改进的文本编码器:采用Gemma/LLaMA风格的decoder-only架构
  3. 动态温度学习:引入可学习的温度参数
  4. 更大的训练规模:使用更大的图文数据集

架构对比

版本视觉编码器文本编码器损失函数
CLIPViT-B/LTransformerSoftmax CE
SigLIPViT-B/L/GTransformerSigmoid CE
SigLIP 2EVA-CLIP ViTDecoder-only LLMSigmoid CE + Hard Negatives

3. 跨模态互信息最大化(VISTA)

VISTA(Variational Information Scaling for Text-Image Alignment)提出了一种基于互信息最大化的统一对齐框架。7

3.1 互信息在多模态中的应用

跨模态学习的核心目标是最大化不同模态之间的互信息:

在多模态场景下,互信息度量了从一种模态中可以获取的关于另一种模态的信息量。

3.2 下界估计方法

直接计算互信息是困难的,通常使用变分下界(VLB)估计:

其中 是可学习的判别器, 是负样本数量。

VISTA的核心创新:引入可学习的温度缩放因子

其中 通过梯度上升自适应调整。

3.3 实践应用

class VISTALoss(nn.Module):
    """
    VISTA: Variational Information Scaling for Text-Image Alignment
    """
    def __init__(self, temperature=0.07, beta=0.1):
        super().__init__()
        self.temperature = temperature
        self.beta = beta  # 负样本权重
        self.alpha = nn.Parameter(torch.ones(1))  # 可学习缩放因子
    
    def forward(self, image_features, text_features):
        # 相似度矩阵
        sim = image_features @ text_features.t() / self.temperature
        
        # 正样本对:对角线
        batch_size = sim.shape[0]
        pos_mask = torch.eye(batch_size, device=sim.device, dtype=torch.bool)
        
        # 正样本损失
        pos_sim = sim[pos_mask]
        loss_pos = -torch.log(torch.sigmoid(self.alpha * pos_sim) + 1e-8).mean()
        
        # 负样本损失
        neg_sim = sim[~pos_mask].view(batch_size, -1)
        loss_neg = -torch.log(1 - torch.sigmoid(neg_sim) + 1e-8).mean()
        
        # 总损失
        return loss_pos + self.beta * loss_neg

4. 多模态融合方法

多模态融合(Multimodal Fusion)旨在组合来自不同模态的信息以完成下游任务。根据融合发生的阶段,主要分为三类:Early FusionLate FusionCross Fusion8

4.1 Early Fusion

Early Fusion(早期融合)将原始或浅层特征在输入层进行拼接,然后通过统一的模型进行处理。

特征拼接

其中 是第 个模态的初始表示, 为投影参数。

class EarlyFusion(nn.Module):
    def __init__(self, dim_vision, dim_text, hidden_dim):
        super().__init__()
        # 投影到统一空间
        self.proj_vision = nn.Linear(dim_vision, hidden_dim)
        self.proj_text = nn.Linear(dim_text, hidden_dim)
        
        # 融合层
        self.fusion_layer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
    
    def forward(self, vision_feat, text_feat):
        # 投影到统一空间
        v = self.proj_vision(vision_feat)
        t = self.proj_text(text_feat)
        
        # 拼接
        fused = torch.cat([v, t], dim=-1)
        return self.fusion_layer(fused)

优缺点分析

优点缺点
模态间交互充分需要模态对齐的原始数据
联合优化难以处理异构特征
捕获深层交互计算复杂度高

4.2 Late Fusion

Late Fusion(晚期融合)分别处理各模态,然后在决策层进行融合。

决策融合

class LateFusion(nn.Module):
    def __init__(self, vision_encoder, text_encoder, classifier):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.text_encoder = text_encoder
        self.classifier = classifier
        self.fusion_weight = nn.Parameter(torch.ones(2) / 2)  # 可学习权重
    
    def forward(self, images, texts):
        # 分别编码
        z_vision = self.vision_encoder(images)
        z_text = self.text_encoder(texts)
        
        # 独立预测
        logits_v = self.classifier(z_vision)
        logits_t = self.classifier(z_text)
        
        # 加权融合
        weights = F.softmax(self.fusion_weight, dim=0)
        logits_fused = weights[0] * logits_v + weights[1] * logits_t
        
        return logits_fused, {'vision': logits_v, 'text': logits_t}

优缺点分析

优点缺点
模态独立训练无法捕获模态间交互
可处理异构数据融合策略固定
容错性强可能丢失互补信息

4.3 Cross Fusion(FUSION, FLARE)

Cross Fusion(跨模态融合)通过层次化交互实现深度模态融合,典型方法包括FUSION9和FLARE。

层次化交互

FUSION框架的核心思想是在多个语义层级上进行模态交互:

class CrossAttentionFusion(nn.Module):
    """
    基于Cross-Attention的跨模态融合
    """
    def __init__(self, dim_vision, dim_text, num_heads=8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=dim_vision, 
            num_heads=num_heads,
            kdim=dim_text,
            vdim=dim_text
        )
        self.norm = nn.LayerNorm(dim_vision)
        self.ffn = nn.Sequential(
            nn.Linear(dim_vision, dim_vision * 4),
            nn.GELU(),
            nn.Linear(dim_vision * 4, dim_vision)
        )
    
    def forward(self, vision_feat, text_feat):
        """
        vision_feat: (L_v, B, D_v) - 视觉特征序列
        text_feat: (L_t, B, D_t) - 文本特征序列
        """
        # Cross-attention: vision queries text
        attn_out, _ = self.cross_attn(vision_feat, text_feat, text_feat)
        vision_feat = self.norm(vision_feat + attn_out)
        
        # FFN
        vision_feat = self.norm(vision_feat + self.ffn(vision_feat))
        
        return vision_feat

注意力机制

FLARE(Fusion with Latent Regularization)引入潜在正则化来平衡模态贡献:

class FLAREFusion(nn.Module):
    """
    FLARE: Fusion with Latent Regularization
    """
    def __init__(self, dim, num_modalities=2, lambda_reg=0.1):
        super().__init__()
        self.dim = dim
        self.lambda_reg = lambda_reg
        
        # 模态特定投影
        self.proj = nn.ModuleList([
            nn.Linear(dim, dim) for _ in range(num_modalities)
        ])
        
        # 融合权重(可学习)
        self.fusion_weights = nn.Parameter(torch.ones(num_modalities))
        
        # 潜在空间对齐
        self.latent_proj = nn.Linear(dim, dim)
    
    def forward(self, features_list):
        """
        features_list: [feat1, feat2, ...] 各模态特征
        """
        # 归一化权重
        weights = F.softmax(self.fusion_weights, dim=0)
        
        # 投影各模态
        projected = [self.proj[i](feat) for i, feat in enumerate(features_list)]
        
        # 加权融合
        fused = sum(w * p for w, p in zip(weights, projected))
        
        # 潜在正则化:鼓励各模态在潜在空间中对齐
        latent_loss = 0
        for i in range(len(projected)):
            for j in range(i + 1, len(projected)):
                latent_loss += torch.norm(projected[i] - projected[j], p=2)
        
        return fused, self.lambda_reg * latent_loss

实际实现

FUSION框架的完整实现

class FUSIONModel(nn.Module):
    """
    Hierarchical Multimodal Fusion
    arXiv:2504.09925
    """
    def __init__(self, vision_dim, text_dim, hidden_dim, num_layers=3):
        super().__init__()
        
        # 模态特定编码器
        self.vision_encoder = nn.Linear(vision_dim, hidden_dim)
        self.text_encoder = nn.Linear(text_dim, hidden_dim)
        
        # 多层跨模态交互
        self.fusion_layers = nn.ModuleList([
            CrossAttentionFusion(hidden_dim, hidden_dim)
            for _ in range(num_layers)
        ])
        
        # 层级注意力(融合多层特征)
        self.layer_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=8
        )
        
        # 输出分类器
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, vision_feat, text_feat, return_layers=False):
        # 初始投影
        v = self.vision_encoder(vision_feat)
        t = self.text_encoder(text_feat)
        
        # 多层融合
        layer_outputs = []
        for layer in self.fusion_layers:
            v = layer(v, t)
            layer_outputs.append(v)
        
        # 层级注意力聚合
        layer_outputs = torch.stack(layer_outputs, dim=0)  # (num_layers, B, D)
        fused, _ = self.layer_attention(layer_outputs, layer_outputs, layer_outputs)
        
        # 最终输出
        final_feat = fused.mean(dim=0)  # (B, D)
        
        if return_layers:
            return self.classifier(final_feat), layer_outputs
        return self.classifier(final_feat)

5. 多层视觉特征融合

5.1 动机:浅层 vs 深层特征

视觉模型的不同层捕获不同级别的语义信息:

层级特征类型特性
浅层局部纹理、边缘、颜色高分辨率,细粒度
中层部件、模式中等语义
深层高级语义、类别全局上下文,低分辨率

5.2 Multi-Layer Feature Fusion方法

特征金字塔融合(Feature Pyramid Fusion):

class MultiLayerFeatureFusion(nn.Module):
    """
    多层视觉特征融合
    """
    def __init__(self, feature_dims, output_dim):
        super().__init__()
        
        # 投影层(统一维度)
        self.projections = nn.ModuleList([
            nn.Linear(dim, output_dim) for dim in feature_dims
        ])
        
        # 注意力权重(可学习)
        self.layer_weights = nn.Parameter(torch.ones(len(feature_dims)))
        
        # 上采样层(处理不同分辨率)
        self.upsamples = nn.ModuleList([
            nn.Identity() if i == len(feature_dims) - 1 
            else nn.Sequential(
                nn.Linear(output_dim, output_dim * 4),
                nn.GELU(),
                nn.Linear(output_dim * 4, output_dim)
            )
            for i in range(len(feature_dims))
        ])
    
    def forward(self, multi_layer_features):
        """
        multi_layer_features: List of features from different layers
        """
        projected = []
        for i, feat in enumerate(multi_layer_features):
            p = self.projections[i](feat)
            p = self.upsamples[i](p)
            projected.append(p)
        
        # 注意力加权
        weights = F.softmax(self.layer_weights, dim=0)
        fused = sum(w * p for w, p in zip(weights, projected))
        
        return fused

U-Net风格的跳跃连接融合

class UNetStyleFusion(nn.Module):
    """
    类似U-Net的编码器-解码器特征融合
    """
    def __init__(self, encoder_dims, decoder_dim):
        super().__init__()
        
        # 解码器(上采样路径)
        self.decoder_blocks = nn.ModuleList()
        self.skip_connections = nn.ModuleList()
        
        for i, enc_dim in enumerate(encoder_dims):
            # 跳跃连接投影
            self.skip_connections.append(
                nn.Linear(enc_dim, decoder_dim)
            )
            
            if i < len(encoder_dims) - 1:
                # 上采样块
                self.decoder_blocks.append(
                    nn.Sequential(
                        nn.Linear(decoder_dim, decoder_dim * 2),
                        nn.GELU(),
                        nn.Linear(decoder_dim * 2, decoder_dim)
                    )
                )
    
    def forward(self, encoder_features):
        """
        encoder_features: 从浅到深的多层特征
        """
        x = encoder_features[-1]  # 从最深层开始
        
        for i in range(len(encoder_features) - 2, -1, -1):
            # 跳跃连接
            skip = self.skip_connections[i](encoder_features[i])
            
            # 融合
            x = x + skip
            
            if i < len(encoder_features) - 1:
                x = self.decoder_blocks[i](x)
        
        return x

5.3 最佳实践指南

特征选择策略

  1. 任务相关性:根据下游任务选择关键层级

    • 图像分类:深层特征为主
    • 目标检测:多层特征金字塔
    • 分割任务:浅层+深层结合
  2. 维度对齐:确保融合前各层特征维度一致

  3. 注意力机制:使用可学习的注意力权重自适应融合

class AdaptiveFeatureFusion(nn.Module):
    """
    自适应特征融合(基于任务自适应权重)
    """
    def __init__(self, feature_dims, output_dim, num_heads=4):
        super().__init__()
        
        # 特征投影
        self.projections = nn.ModuleList([
            nn.Linear(dim, output_dim) for dim in feature_dims
        ])
        
        # 任务相关注意力
        self.task_attention = nn.MultiheadAttention(
            embed_dim=output_dim,
            num_heads=num_heads,
            kdim=output_dim,
            vdim=output_dim
        )
        
        # 门控机制
        self.gates = nn.ModuleList([
            nn.Sequential(
                nn.Linear(output_dim, 1),
                nn.Sigmoid()
            ) for _ in feature_dims
        ])
    
    def forward(self, features, task_embedding=None):
        # 投影所有层
        projected = [proj(feat) for proj, feat in zip(self.projections, features)]
        
        # 门控加权
        gated = [g(p) * p for g, p in zip(self.gates, projected)]
        
        # 聚合
        fused = sum(gated) / len(gated)
        
        return fused

6. 最新进展

6.1 动态融合策略

动态融合根据输入内容自适应调整融合方式:

class DynamicFusion(nn.Module):
    """
    输入自适应的动态融合
    """
    def __init__(self, vision_dim, text_dim, hidden_dim):
        super().__init__()
        
        # 模态编码器
        self.vision_proj = nn.Linear(vision_dim, hidden_dim)
        self.text_proj = nn.Linear(text_dim, hidden_dim)
        
        # 动态融合权重预测器
        self.fusion_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),  # 两个模态的权重
            nn.Softmax(dim=-1)
        )
        
        # 融合操作选择器
        self.fusion_ops = nn.ModuleDict({
            'concat': ConcatFusion(hidden_dim),
            'attention': AttentionFusion(hidden_dim),
            'product': ProductFusion(hidden_dim)
        })
        
        self.op_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, 3),
            nn.Softmax(dim=-1)
        )
    
    def forward(self, vision_feat, text_feat):
        # 投影
        v = self.vision_proj(vision_feat)
        t = self.text_proj(text_feat)
        
        # 预测融合权重
        fusion_weights = self.fusion_predictor(torch.cat([v, t], dim=-1))
        
        # 预测融合操作
        op_weights = self.op_predictor(torch.cat([v, t], dim=-1))
        
        # 执行多种融合操作
        outputs = []
        for name, op in self.fusion_ops.items():
            outputs.append(op(v, t))
        
        outputs = torch.stack(outputs, dim=0)  # (num_ops, B, D)
        
        # 加权融合操作
        fused = (outputs * op_weights.unsqueeze(1)).sum(dim=0)
        
        # 加权模态
        final = fusion_weights[0] * v + fusion_weights[1] * t
        
        return fused + 0.1 * final  # 结合两种融合

6.2 自适应融合

Modality DropoutGated Networks 是常见的自适应融合技术:

class GatedMultimodalFusion(nn.Module):
    """
    门控多模态融合
    """
    def __init__(self, vision_dim, text_dim, hidden_dim):
        super().__init__()
        
        # 模态编码
        self.vision_proj = nn.Linear(vision_dim, hidden_dim)
        self.text_proj = nn.Linear(text_dim, hidden_dim)
        
        # 门控网络
        self.gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid()
        )
        
        # 特征变换
        self.transform = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Tanh()
        )
    
    def forward(self, vision_feat, text_feat, training=True):
        v = self.vision_proj(vision_feat)
        t = self.text_proj(text_feat)
        
        # 门控权重
        gate_values = self.gate(torch.cat([v, t], dim=-1))
        
        # 变换后的特征
        transformed = self.transform(torch.cat([v, t], dim=-1))
        
        # 门控融合
        fused = gate_values * transformed
        
        # 模态dropout(训练时随机丢弃模态)
        if training and self.training:
            batch_size = vision_feat.shape[0]
            drop_mask = torch.rand(batch_size, 1, device=vision_feat.device) > 0.5
            fused = fused * drop_mask.float()
        
        return fused

6.3 稀疏融合

稀疏融合通过稀疏注意力机制减少计算复杂度:

class SparseMultimodalFusion(nn.Module):
    """
    基于稀疏注意力的多模态融合
    """
    def __init__(self, dim, num_heads=8, sparsity_ratio=0.3):
        super().__init__()
        self.num_heads = num_heads
        self.sparsity_ratio = sparsity_ratio
        
        # 多头注意力
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        
        # 稀疏性预测器
        self.sparsity_predictor = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.Sigmoid()
        )
    
    def forward(self, vision_feat, text_feat):
        B = vision_feat.shape[0]
        
        # 计算Query, Key, Value
        q = self.q_proj(vision_feat)
        k_v = self.k_proj(vision_feat)
        k_t = self.k_proj(text_feat)
        v_v = self.v_proj(vision_feat)
        v_t = self.v_proj(text_feat)
        
        # 预测稀疏性
        sparse_weights = self.sparsity_predictor(torch.cat([q, k_t], dim=-1))
        
        # 应用稀疏性(随机mask一部分连接)
        if self.training:
            mask = torch.rand_like(sparse_weights) > self.sparsity_ratio
            sparse_weights = sparse_weights * mask.float()
        
        # 跨模态注意力
        cross_attn = torch.sigmoid(q @ k_t.transpose(-2, -1)) / (self.num_heads ** 0.5)
        cross_attn = cross_attn * sparse_weights
        
        # 聚合文本信息
        output = cross_attn @ v_t
        
        # 残差连接
        return vision_feat + output

7. 实践建议

融合策略选择指南

场景推荐策略
模态特征相似Early Fusion
模态异构性强Late Fusion
需要深度交互Cross Fusion
计算资源受限稀疏融合
模态可能缺失Late Fusion + Gating

常见陷阱与解决方案

  1. 模态不平衡:使用加权损失或动态权重
  2. 过拟合:早停、正则化、模态dropout
  3. 梯度冲突:使用梯度平衡技术(如GradNorm)

参考文献

Footnotes

  1. Survey on Multimodal Alignment and Fusion Techniques, arXiv:2411.17040, 2024

  2. Multimodal Learning: Theories and Applications, Springer, 2023

  3. Oord et al., Representation Learning with Contrastive Predictive Coding, arXiv:1807.03748, 2018

  4. Radford et al., Learning Transferable Visual Models From Natural Language Supervision, ICML 2021

  5. Zhai et al., SigLIP: Simple Sigmoid Loss for Language-Image Pre-Training, arXiv:2312.12245, 2023

  6. SigLIP 2: Improved Sigmoid Loss for Vision-Language Pre-training, arXiv:2502.14786, 2025

  7. VISTA: Variational Information Scaling for Text-Image Alignment, arXiv:2505.10917, 2025

  8. Ramachandran et al., STAND-UP: Sparse Multimodal Fusion for Detection and Localization, CVPR 2021

  9. FUSION: Hierarchical Multimodal Fusion with Cross-Modal Attention, arXiv:2504.09925, 2025