简介

EMMA是一个高效统一的架构,用于多模态理解、生成和编辑。该架构主要由以下几个创新组成:1)一个32倍压缩比的高效自动编码器,大大减少了生成所需的token数量;2)使用通道级拼接而非token级拼接来融合视觉理解和生成token,进一步减少视觉token数量;3)一个共享-解耦网络,在满足任务特定建模需求的同时实现任务间的相互促进;4)采用混合专家机制的视觉理解编码器,以极少的参数增量大幅提升感知能力。大量的实验表明,EMMA-4B能够显著优于最先进的统一多模态方法(如BAGEL-7B),同时在效率和性能上都表现出色,并且与最近的多模态理解和生成专家(如Qwen3-VL和Qwen-Image)相比也取得了有竞争力的结果。1

背景与动机

现有统一模型的挑战

当前统一多模态模型面临两个主要挑战:

  1. 训练不平衡:理解和生成任务对数据分布和损失函数有不同的要求
  2. 表示不匹配:理解和生成需要不同的视觉表示方式

EMMA的核心洞察

EMMA通过以下设计解决这些问题:

  1. 高压缩比编码器:32x压缩比减少生成token数量
  2. 通道级融合:替代token级拼接,减少计算
  3. 共享-解耦网络:平衡任务共享和任务特定需求
  4. MoE视觉编码器:高效提升感知能力

架构设计

整体架构

┌─────────────────────────────────────────────────────┐
│                    EMMA-4B                          │
├─────────────────────────────────────────────────────┤
│                                                     │
│  ┌──────────────┐     ┌──────────────────────┐    │
│  │  Vision      │     │  Text                │    │
│  │  Encoder     │────▶│  Embedding           │    │
│  │  (MoE)       │     │                      │    │
│  └──────┬───────┘     └──────────┬───────────┘    │
│         │                         │                 │
│         ▼                         ▼                 │
│  ┌──────────────────────────────────────────┐      │
│  │       Channel-wise Concatenation          │      │
│  │  (通道级拼接,而非token级拼接)           │      │
│  └────────────────────┬─────────────────────┘      │
│                       │                            │
│                       ▼                            │
│  ┌──────────────────────────────────────────┐      │
│  │       Shared-and-Decoupled Network        │      │
│  │  (共享-解耦网络)                         │      │
│  │  - 共享参数: 任务无关能力                │      │
│  │  - 解耦参数: 任务特定能力                │      │
│  └────────────────────┬─────────────────────┘      │
│                       │                            │
│                       ▼                            │
│  ┌──────────────────────────────────────────┐      │
│  │       Unified Transformer Decoder          │      │
│  └──────────────────────────────────────────┘      │
│                       │                            │
│           ┌───────────┴───────────┐                │
│           ▼                       ▼                │
│    ┌────────────┐        ┌────────────┐          │
│    │  Image     │        │   Text     │          │
│    │  Decoder   │        │   Head     │          │
│    └────────────┘        └────────────┘          │
└─────────────────────────────────────────────────────┘

32x压缩比自动编码器

EMMA的关键创新之一是高效的视觉自动编码器:

import torch
import torch.nn as nn
 
class EMMAVisualEncoder:
    """
    EMMA视觉编码器:32x压缩比
    
    224x224图像 -> 7x7=49 tokens (而非256 tokens)
    """
    
    def __init__(self, in_channels=3, hidden_dim=1024, compression_ratio=32):
        self.compression_ratio = compression_ratio
        
        # 编码器:分层下采样
        self.encoder = nn.Sequential(
            # Patchify: 224x224 -> 14x14 (32x压缩从这里开始)
            nn.Conv2d(in_channels, hidden_dim // 4, kernel_size=4, stride=4),
            nn.GELU(),
            
            # 逐步下采样
            nn.Conv2d(hidden_dim // 4, hidden_dim // 2, kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(hidden_dim // 2, hidden_dim, kernel_size=2, stride=2),
            nn.GELU(),
            
            # 最终特征: 14x14 -> 7x7
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=2, stride=2),
        )
        
        # 码本:离散表示
        vocab_size = 8192  # 8K离散token
        self.codebook = nn.Embedding(vocab_size, hidden_dim)
        
        # 解码器
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=2, stride=2),
            nn.GELU(),
            nn.ConvTranspose2d(hidden_dim, hidden_dim // 2, kernel_size=2, stride=2),
            nn.GELU(),
            nn.ConvTranspose2d(hidden_dim // 2, hidden_dim // 4, kernel_size=2, stride=2),
            nn.GELU(),
            nn.ConvTranspose2d(hidden_dim // 4, in_channels, kernel_size=4, stride=4),
        )
        
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        编码图像到离散token
        
        Args:
            x: [B, C, H, W] 原始图像
            
        Returns:
            tokens: [B, N] 离散token IDs, N = (H/32) * (W/32)
        """
        # 特征提取
        features = self.encoder(x)  # [B, D, H', W']
        
        # 量化(使用EMA更新码本)
        B, D, H, W = features.shape
        features_flat = features.permute(0, 2, 3, 1).reshape(-1, D)  # [B*H*W, D]
        
        # 简单量化:最近邻查找
        codebook_expanded = self.codebook.weight.unsqueeze(0)  # [1, V, D]
        distances = torch.cdist(features_flat, codebook_expanded.squeeze(0))  # [B*H*W, V]
        tokens = distances.argmin(dim=-1)  # [B*H*W]
        
        tokens = tokens.reshape(B, H * W)  # [B, N]
        
        return tokens
    
    def decode(self, tokens: torch.Tensor, shape: tuple) -> torch.Tensor:
        """
        从token重建图像
        
        Args:
            tokens: [B, N] 离散token IDs
            shape: 目标形状 (B, H', W')
            
        Returns:
            x_recon: [B, C, H, W] 重建图像
        """
        B = tokens.shape[0]
        H, W = shape
        
        # Token到特征
        features = self.codebook(tokens)  # [B, N, D]
        features = features.reshape(B, H, W, -1).permute(0, 3, 1, 2)  # [B, D, H, W]
        
        # 解码
        x_recon = self.decoder(features)
        
        return x_recon
    
    def forward(self, x: torch.Tensor, use_discrete: bool = True):
        """
        前向传播
        
        Args:
            x: 原始图像
            use_discrete: 是否使用离散token
            
        Returns:
            recon: 重建图像(训练时)
            tokens: 离散token(用于生成)
        """
        features = self.encoder(x)
        
        if use_discrete:
            tokens = self.encode(x)
            return None, tokens
        else:
            # 解码
            recon = self.decoder(features)
            return recon, None
 
 
class MoEVisionEncoder(nn.Module):
    """
    MoE视觉编码器
    
    混合专家机制提升视觉感知能力,参数增量少
    """
    
    def __init__(self, input_dim=1024, hidden_dim=4096, num_experts=8, top_k=2):
        super().__init__()
        
        self.num_experts = num_experts
        self.top_k = top_k
        
        # 路由器
        self.router = nn.Linear(input_dim, num_experts)
        
        # 专家网络
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, input_dim)
            )
            for _ in range(num_experts)
        ])
        
        # 共享专家(始终激活)
        self.shared_expert = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, input_dim)
        )
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        MoE前向传播
        
        Returns:
            output: 增强的视觉特征
        """
        B, N, D = x.shape
        
        # 路由器 logits
        router_logits = self.router(x)  # [B, N, num_experts]
        
        # Top-k 选择
        top_k_logits, top_k_indices = torch.topk(router_logits, self.top_k, dim=-1)
        top_k_weights = torch.softmax(top_k_logits, dim=-1)
        
        # 初始化输出
        output = torch.zeros_like(x)
        
        # 共享专家贡献
        shared_out = self.shared_expert(x)
        output += shared_out
        
        # 每个token的专家贡献
        for i in range(self.top_k):
            expert_idx = top_k_indices[:, :, i]  # [B, N]
            weight = top_k_weights[:, :, i].unsqueeze(-1)  # [B, N, 1]
            
            # 聚合专家输出
            expert_out = torch.zeros_like(x)
            for e_idx in range(self.num_experts):
                mask = (expert_idx == e_idx)  # [B, N]
                
                if mask.any():
                    expert_input = x * mask.unsqueeze(-1)
                    expert_output = self.experts[e_idx](expert_input)
                    expert_out += expert_output * mask.unsqueeze(-1)
            
            output += weight * expert_out
        
        return output
 
 
class EMMA(nn.Module):
    """
    EMMA统一多模态模型
    
    核心设计:
    1. 32x压缩视觉编码器
    2. 通道级拼接
    3. 共享-解耦网络
    4. MoE视觉编码器
    """
    
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # 视觉编码器
        self.visual_encoder = EMMAVisualEncoder(
            compression_ratio=32
        )
        
        # MoE视觉增强
        self.moe_vision = MoEVisionEncoder(
            input_dim=config.vision_hidden,
            hidden_dim=config.moe_hidden,
            num_experts=8,
            top_k=2
        )
        
        # 文本嵌入
        self.text_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        
        # 通道级拼接
        self.channel_concat = ChannelWiseConcat(config.hidden_size)
        
        # 共享-解耦网络
        self.shared_decoupled_net = SharedDecoupledNetwork(
            hidden_size=config.hidden_size,
            num_tasks=2,  # 理解、生成
            shared_ratio=0.7
        )
        
        # Transformer解码器
        self.decoder = TransformerDecoder(
            num_layers=config.num_layers,
            hidden_size=config.hidden_size,
            num_heads=config.num_heads
        )
        
        # 任务头
        self.understanding_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.generation_head = nn.Linear(config.hidden_size, config.vocab_size)
    
    def forward(
        self,
        input_ids: torch.Tensor,
        pixel_values: torch.Tensor,
        task: str = 'understanding'
    ):
        """
        前向传播
        
        Args:
            input_ids: 文本token IDs
            pixel_values: 图像像素值
            task: 'understanding' 或 'generation'
        """
        # 文本编码
        text_embeds = self.text_embedding(input_ids)
        
        # 视觉编码(32x压缩)
        _, vision_tokens = self.visual_encoder(pixel_values)
        vision_embeds = self.visual_encoder.codebook(vision_tokens)
        
        # MoE增强
        vision_embeds = self.moe_vision(vision_embeds)
        
        # 通道级拼接
        multimodal_embeds = self.channel_concat(text_embeds, vision_embeds)
        
        # 共享-解耦处理
        hidden_states = self.shared_decoupled_net(
            multimodal_embeds,
            task=0 if task == 'understanding' else 1
        )
        
        # Transformer解码
        hidden_states = self.decoder(hidden_states)
        
        # 任务特定输出
        if task == 'understanding':
            logits = self.understanding_head(hidden_states)
        else:
            logits = self.generation_head(hidden_states)
        
        return logits
 
 
class ChannelWiseConcat(nn.Module):
    """
    通道级拼接:替代token级拼接
    
    优势:
    1. 减少序列长度
    2. 更好的模态交互
    3. 降低计算复杂度
    """
    
    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        
        # 投影层确保维度匹配
        self.text_proj = nn.Identity()  # [B, N_text, D]
        self.vision_proj = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU()
        )  # [B, N_vision, D] -> [B, N_vision, D]
        
    def forward(self, text_embeds: torch.Tensor, vision_embeds: torch.Tensor):
        """
        通道级拼接
        
        文本: [B, N_t, D]
        视觉: [B, N_v, D]
        
        拼接: [B, N_t + N_v, D]
        
        实际上,我们采用:
        文本: [B, N_t, D] -> [B, N_t, D]
        视觉: [B, N_v, D] -> [B, N_v, D] -> 平均 -> [B, 1, D] -> 扩展 -> [B, N_t, D]
        
        这样序列长度不变,只是添加了全局视觉信息
        """
        # 视觉特征聚合
        vision_agg = vision_embeds.mean(dim=1, keepdim=True)  # [B, 1, D]
        vision_agg = self.vision_proj(vision_agg)  # [B, 1, D]
        
        # 扩展到文本序列长度
        vision_expanded = vision_agg.expand(-1, text_embeds.shape[1], -1)  # [B, N_t, D]
        
        # 通道级拼接(沿通道维度)
        multimodal = torch.cat([text_embeds, vision_expanded], dim=-1)  # [B, N_t, 2D]
        
        # 投影回原始维度
        proj = nn.Linear(2 * self.hidden_size, self.hidden_size).to(text_embeds.device)
        multimodal = proj(multimodal)
        
        return multimodal

与BAGEL的对比

架构差异

方面BAGELEMMA
视觉压缩标准VAE32x压缩比
拼接方式Token级拼接通道级拼接
视觉增强标准ViTMoE视觉编码器
网络设计统一共享-解耦
参数量7B4B

效率对比

=== 推理效率对比 ===
模型          | 图像token数 | 推理速度 | 显存占用
-------------|------------|---------|----------
BAGEL-7B     | 256        | 1.0x   | 100%
EMMA-4B      | 49         | 2.3x   | 58%

性能对比

=== 多模态理解基准 ===
模型          | VQAv2 | GQA  | MMMU
-------------|-------|------|------
BAGEL-7B     | 84.2  | 82.1 | 52.3
EMMA-4B      | 84.8  | 83.5 | 53.1

=== 图像生成基准 ===
模型          | FID  | CLIP-S | GenEval
-------------|------|--------|--------
BAGEL-7B     | 8.2  | 0.82   | 0.72
EMMA-4B      | 7.9  | 0.83   | 0.78

共享-解耦网络设计

核心思想

共享-解耦网络同时满足:

  1. 任务共享:不同任务共享大部分参数,促进知识迁移
  2. 任务特定:每个任务有少量解耦参数,满足特定需求

实现

class SharedDecoupledNetwork(nn.Module):
    """
    共享-解耦网络
    
    大部分参数在任务间共享
    少量参数是任务特定的
    """
    
    def __init__(self, hidden_size: int, num_tasks: int, shared_ratio: float = 0.8):
        super().__init__()
        
        self.num_tasks = num_tasks
        self.shared_ratio = shared_ratio
        
        # 共享参数(约80%)
        self.shared_layers = nn.ModuleList([
            ResidualMLP(hidden_size, int(hidden_size * 1.5))
            for _ in range(4)
        ])
        
        # 解耦参数(每个任务约20%)
        self.task_layers = nn.ModuleList([
            nn.ModuleList([
                nn.Sequential(
                    nn.Linear(hidden_size, hidden_size // 4),
                    nn.GELU(),
                    nn.Linear(hidden_size // 4, hidden_size)
                )
                for _ in range(4)
            ])
            for _ in range(num_tasks)
        ])
        
        # 门控机制
        self.gate = nn.Linear(hidden_size, num_tasks)
        
    def forward(self, x: torch.Tensor, task: int = 0):
        """
        前向传播
        
        Args:
            x: 输入张量 [B, N, D]
            task: 任务索引
        """
        # 通过共享层
        for shared_layer in self.shared_layers:
            x = shared_layer(x)
        
        # 通过任务特定层
        task_specific = self.task_layers[task]
        for i, task_layer in enumerate(task_specific):
            # 与共享层输出残差连接
            x = x + 0.1 * task_layer(x)
        
        return x

实验结果

效率分析

def efficiency_analysis():
    """效率分析"""
    
    print("=== EMMA效率分析 ===\n")
    
    # Token数量对比
    print("图像token数量:")
    print(f"  BAGEL: 256 tokens")
    print(f"  EMMA:  49 tokens (32x压缩)")
    print(f"  减少:  {1 - 49/256:.1%}")
    
    # 推理速度
    print("\n推理速度(相对):")
    print(f"  BAGEL-7B: 1.0x")
    print(f"  EMMA-4B:  2.3x")
    
    # 显存
    print("\n显存占用(相对):")
    print(f"  BAGEL-7B: 100%")
    print(f"  EMMA-4B:  58%")
 
 
if __name__ == "__main__":
    efficiency_analysis()

消融实验

=== 消融实验 ===
组件              | VQAv2 | FID  | 效率
-----------------|-------|------|------
基础模型          | 81.2  | 10.5 | 1.0x
+ 32x压缩         | 82.8  | 9.2  | 2.3x
+ 通道级拼接      | 83.5  | 8.5  | 2.3x
+ MoE视觉编码器   | 84.8  | 7.9  | 2.1x
+ 共享-解耦网络   | 84.8  | 7.9  | 2.1x

总结

EMMA通过多项创新实现了高效的多模态统一:

  1. 32x压缩比:大幅减少生成token数量
  2. 通道级拼接:优化模态交互
  3. MoE视觉编码器:高效提升感知能力
  4. 共享-解耦网络:平衡共享和特定需求

EMMA-4B以更小的规模实现了与BAGEL-7B相当甚至更好的性能,为高效统一多模态模型提供了新的设计范式。

Footnotes

  1. Source: EMMA: Efficient Multimodal Understanding, Generation, and Editing with a Unified Architecture