title: NExT-OMNI离散Flow Matching统一模型
date: 2026-05-08
description: 通过离散Flow Matching实现任意到任意理解和生成,支持跨模态检索和多轮多模态交互
tags:

  • multimodal
  • omnimodal
  • discrete-flow-matching
  • next-omni
    draft: false
    permalink:

简介

能够进行任意到任意跨模态生成和多轮交互的下一代多模态基础模型将成为通用人工智能系统的核心组件,在人机交互中发挥关键作用。然而,大多数现有的多模态模型仍然受限于自回归架构,其固有限制阻碍了理解和生成能力的均衡整合。虽然已经探索了混合和分离策略来在统一框架内分别处理这些任务,但其冗余、非整合的设计限制了其在更广泛场景(如跨模态检索)中的适用性。本文介绍NExT-OMNI,一个通过离散Flow范式实现统一建模的开源全能多模态基础模型。通过利用度量诱导概率路径和动力学最优速度,NExT-OMNI原生支持任意到任意的理解和生成,同时通过简洁的统一表示而非任务解耦设计实现增强的响应效率,同时支持更广泛的应用场景。在大规模交错的文本、图像、视频和音频数据上训练后,NExT-OMNI在多模态生成和理解基准测试上展现出竞争力的性能,同时在多轮多模态交互和跨模态检索方面优于先前的统一模型,突显其作为下一代多模态基础模型的架构优势。1

背景与动机

AR架构的固有限制

自回归(AR)架构在统一多模态模型中存在几个关键限制:

  1. 序列化生成:必须按顺序生成token,无法并行
  2. 误差累积:长序列生成中误差累积严重
  3. 训练-推理不一致:teacher forcing训练但自回归推理
  4. 跨模态不平衡:文本和图像的生成动态不同

Flow Matching的优势

Flow Matching是一种新兴的生成范式:

  • 并行生成:一次前向传播完成生成
  • 连续插值:在噪声和数据之间平滑插值
  • 训练目标简单:简单的均方误差
  • 理论保证:具有最优传输解释

NExT-OMNI核心设计

核心思想

NExT-OMNI的核心创新是将离散Flow Matching应用于统一多模态建模:

  1. 度量诱导概率路径:根据语义相似度设计概率路径
  2. 动力学最优速度:自适应调整采样动态
  3. 离散token空间:在离散token而非连续空间进行Flow Matching

整体架构

┌─────────────────────────────────────────────────────────────────┐
│                        NExT-OMNI Architecture                     │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │              Modality-Specific Encoders                  │    │
│  │  ┌─────────┐  ┌─────────┐  ┌─────────┐  ┌─────────┐   │    │
│  │  │   Text  │  │  Image  │  │  Video  │  │  Audio  │   │    │
│  │  │Encoder  │  │Encoder  │  │Encoder  │  │Encoder  │   │    │
│  │  └────┬────┘  └────┬────┘  └────┬────┘  └────┬────┘   │    │
│  │       └─────┬─────┘       ┌─────┘       ┌─────┘       │    │
│  └─────────────┴─────────────┴─────────────┴───────────────┘    │
│                           │                                      │
│                           ▼                                      │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │              Unified Token Space                         │    │
│  │         所有模态映射到统一的离散token空间                │    │
│  │         V = V_text ∪ V_image ∪ V_video ∪ V_audio       │    │
│  └────────────────────────────┬────────────────────────────┘    │
│                                │                                 │
│                                ▼                                 │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │              Discrete Flow Matching Module               │    │
│  │                                                          │    │
│  │  概率路径: P_t(x|x_0) = σ(t)·x_1 + (1-σ(t))·x_0      │    │
│  │  度量诱导: 基于语义相似度调整路径                       │    │
│  │  最优速度: 自适应采样动态                              │    │
│  └────────────────────────────┬────────────────────────────┘    │
│                                │                                 │
│                                ▼                                 │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │              Metric-Induced Probability Path              │    │
│  │                                                          │    │
│  │  d(x_t)/dt = v(x_t, x_1) - α(t)·x_t                   │    │
│  │                                                          │    │
│  │  其中 v(x_t, x_1) = (x_1 - x_0)·s(x_t, x_1)           │    │
│  │  s(x_t, x_1) = 度量诱导的相似度函数                   │    │
│  └────────────────────────────┬────────────────────────────┘    │
│                                │                                 │
│  ┌────────────────────────────┼────────────────────────────┐    │
│  │                    Tasks Heads                           │    │
│  │  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐     │    │
│  │  │ Understanding│  │  Generation │  │  Retrieval  │     │    │
│  │  │    Head     │  │    Head     │  │    Head     │     │    │
│  │  └─────────────┘  └─────────────┘  └─────────────┘     │    │
│  └─────────────────────────────────────────────────────────┘    │
└─────────────────────────────────────────────────────────────────┘

核心组件实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict, List
 
class MetricInducedProbabilityPath:
    """
    度量诱导概率路径
    
    核心思想:根据数据点之间的语义相似度调整Flow路径
    """
    
    def __init__(self, sigma: float = 0.5):
        self.sigma = sigma
        
    def compute_similarity(self, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
        """
        计算语义相似度
        
        使用学习的度量函数而非简单的欧氏距离
        """
        # L2距离
        l2_dist = torch.norm(x0 - x1, dim=-1)
        
        # 转换为相似度 (0到1)
        similarity = torch.exp(-l2_dist / self.sigma)
        
        return similarity
    
    def compute_velocity(
        self,
        x_t: torch.Tensor,
        x_0: torch.Tensor,
        x_1: torch.Tensor,
        t: torch.Tensor
    ) -> torch.Tensor:
        """
        计算Flow速度
        
        v(x_t, x_1) = (x_1 - x_0) · s(x_t, x_1)
        
        其中 s(x_t, x_1) 是度量诱导的相似度
        """
        # 度量诱导相似度
        similarity = self.compute_similarity(x_t, x_1)
        
        # 速度 = 方向 × 相似度
        velocity = (x_1 - x_0) * similarity.unsqueeze(-1)
        
        return velocity
    
    def flow_step(
        self,
        x_t: torch.Tensor,
        x_0: torch.Tensor,
        x_1: torch.Tensor,
        t: float,
        dt: float = 0.01
    ) -> torch.Tensor:
        """
        单步Flow更新
        
        dx/dt = v(x_t, x_1) - α(t)·x_t
        """
        # 速度
        velocity = self.compute_velocity(x_t, x_0, x_1, t)
        
        # 漂移项
        alpha_t = t  # 简化的线性调度
        
        # 更新
        x_next = x_t + dt * (velocity - alpha_t * x_t)
        
        return x_next
 
 
class KineticOptimalVelocity:
    """
    动力学最优速度
    
    自适应调整Flow动态以实现更快更稳定的采样
    """
    
    def __init__(self, num_steps: int = 10):
        self.num_steps = num_steps
        
    def compute_optimal_schedule(self, difficulties: torch.Tensor) -> torch.Tensor:
        """
        根据难度计算最优时间调度
        
        困难样本(高不确定性)需要更多步骤
        """
        # 难度归一化
        difficulties = difficulties / difficulties.sum()
        
        # 时间步分配(反比于难度)
        t_schedule = torch.cumsum(
            difficulties / difficulties.sum(),
            dim=0
        )
        
        return t_schedule
    
    def sample_with_adaptive_steps(
        self,
        model: nn.Module,
        x_0: torch.Tensor,
        condition: Optional[torch.Tensor] = None,
        difficulties: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        自适应步数采样
        
        简单样本大步长,复杂样本小步长
        """
        if difficulties is None:
            difficulties = torch.ones(self.num_steps)
        
        # 计算最优调度
        t_schedule = self.compute_optimal_schedule(difficulties)
        
        x_t = x_0
        for i, t_i in enumerate(t_schedule):
            dt = 1.0 / self.num_steps
            
            # 预测速度
            velocity = model(x_t, t_i, condition)
            
            # 更新
            x_t = x_t + dt * velocity
            
            # 自适应:困难样本额外步
            if difficulties[i] > difficulties.mean():
                x_t = x_t + 0.5 * dt * velocity
        
        return x_t
 
 
class UnifiedDiscreteTokenizer:
    """
    统一离散Tokenizer
    
    将所有模态映射到统一的离散token空间
    """
    
    def __init__(
        self,
        vocab_size: int = 131072,
        modality_dims: Dict[str, int] = None
    ):
        self.vocab_size = vocab_size
        
        if modality_dims is None:
            modality_dims = {
                'text': 512,
                'image': 1024,
                'video': 2048,
                'audio': 512
            }
        
        self.modality_dims = modality_dims
        
        # 模态特定码本
        self.codebooks = nn.ModuleDict({
            modality: nn.Embedding(vocab_size // sum(modality_dims.values()) * dim, dim)
            for modality, dim in modality_dims.items()
        })
        
        # 共享码本(用于跨模态对齐)
        self.shared_codebook = nn.Embedding(8192, 1024)
        
    def tokenize(
        self,
        modality: str,
        values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Token化任意模态输入
        
        Args:
            modality: 'text', 'image', 'video', 'audio'
            values: 原始输入
            
        Returns:
            tokens: 离散token IDs
            embeddings: 连续嵌入
        """
        if modality == 'text':
            return self._tokenize_text(values)
        elif modality == 'image':
            return self._tokenize_image(values)
        elif modality == 'video':
            return self._tokenize_video(values)
        elif modality == 'audio':
            return self._tokenize_audio(values)
        else:
            raise ValueError(f"Unknown modality: {modality}")
    
    def _tokenize_image(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """图像token化"""
        # 简化的图像token化
        B, C, H, W = x.shape
        
        # 展平并投影
        x_flat = x.flatten(2).transpose(1, 2)  # [B, H*W, C]
        
        # 量化到码本
        codebook = self.codebooks['image']
        B, N, D = x_flat.shape
        
        # 随机量化(实际中用VQ)
        tokens = torch.randint(0, codebook.num_embeddings, (B, N))
        embeddings = codebook(tokens)
        
        return tokens, embeddings
    
    def _tokenize_text(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """文本token化"""
        codebook = self.codebooks['text']
        tokens = x  # 假设已经是token IDs
        embeddings = codebook(tokens)
        return tokens, embeddings
    
    def detokenize(
        self,
        tokens: torch.Tensor,
        modality: str,
        shape: Tuple
    ) -> torch.Tensor:
        """反token化"""
        if modality == 'image':
            return self._detokenize_image(tokens, shape)
        elif modality == 'text':
            return tokens  # 返回token IDs
        else:
            raise NotImplementedError()
 
 
class NExTOMNI(nn.Module):
    """
    NExT-OMNI: 离散Flow Matching全能多模态模型
    """
    
    def __init__(
        self,
        vocab_size: int = 131072,
        hidden_size: int = 1024,
        num_layers: int = 24,
        num_heads: int = 16
    ):
        super().__init__()
        
        # 统一tokenizer
        self.tokenizer = UnifiedDiscreteTokenizer(vocab_size)
        
        # 嵌入层
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        
        # 度量诱导Flow模块
        self.metric_path = MetricInducedProbabilityPath(sigma=0.5)
        self.kinetic_velocity = KineticOptimalVelocity(num_steps=10)
        
        # Transformer backbone
        self.layers = nn.ModuleList([
            FlowTransformerLayer(
                hidden_size=hidden_size,
                num_heads=num_heads
            )
            for _ in range(num_layers)
        ])
        
        self.norm = nn.LayerNorm(hidden_size)
        
        # 任务头
        self.understanding_head = UnderstandingHead(hidden_size, vocab_size)
        self.generation_head = GenerationHead(hidden_size, vocab_size)
        self.retrieval_head = RetrievalHead(hidden_size, vocab_size)
        
    def forward(
        self,
        modality: str,
        inputs: torch.Tensor,
        task: str = 'understanding',
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        前向传播
        """
        # Token化
        tokens, embeddings = self.tokenizer.tokenize(modality, inputs)
        
        # 嵌入
        hidden_states = self.embedding(tokens)
        
        # Flow Matching主干
        for layer in self.layers:
            hidden_states = layer(hidden_states, **kwargs)
        
        hidden_states = self.norm(hidden_states)
        
        # 任务特定输出
        if task == 'understanding':
            return self.understanding_head(hidden_states)
        elif task == 'generation':
            return self.generation_head(hidden_states)
        elif task == 'retrieval':
            return self.retrieval_head(hidden_states)
        else:
            raise ValueError(f"Unknown task: {task}")
    
    def generate(
        self,
        condition: torch.Tensor,
        target_modality: str,
        num_steps: int = 10
    ) -> torch.Tensor:
        """
        生成任意模态内容
        """
        # 从噪声开始
        x_t = torch.randn_like(condition)
        
        # 条件编码
        cond_emb = self.embedding(condition)
        
        # 离散Flow Matching采样
        for t in torch.linspace(0, 1, num_steps):
            # 预测速度
            velocity = self.forward_flow(x_t, t, cond_emb)
            
            # 更新
            dt = 1.0 / num_steps
            x_t = x_t + dt * velocity
        
        # 反token化
        return self.tokenizer.detokenize(x_t.argmax(dim=-1), target_modality)
 
 
def demo_next_omni():
    """演示NExT-OMNI能力"""
    print("=== NExT-OMNI全能多模态模型演示 ===\n")
    
    print("核心能力:")
    print("  1. 任意到任意理解和生成")
    print("  2. 跨模态检索")
    print("  3. 多轮多模态交互")
    print("  4. 并行生成(Flow Matching优势)")
    
    print("\n支持模态:")
    print("  - 文本 ↔ 图像")
    print("  - 图像 ↔ 视频")
    print("  - 音频 ↔ 文本")
    print("  - 视频 ↔ 图像")
    
    print("\n技术特点:")
    print("  - 度量诱导概率路径")
    print("  - 动力学最优速度")
    print("  - 统一离散token空间")
 
 
if __name__ == "__main__":
    demo_next_omni()

度量诱导概率路径详解

核心公式

NExT-OMNI的概率路径定义为:

其中:

  • :标准高斯噪声
  • :目标数据分布
  • :时间调度函数
  • :度量诱导速度
  • :语义相似度函数

度量函数设计

class LearnedMetric:
    """
    学习到的语义度量
    
    超越简单的L2距离,考虑语义关系
    """
    
    def __init__(self, hidden_dim: int = 1024):
        # 双MLP头
        self.encoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        计算学习的度量距离
        
        d(x, y) = ||f(x) - f(y)||²
        """
        fx = self.encoder(x)
        fy = self.encoder(y)
        
        distance = torch.norm(fx - fy, dim=-1)
        return distance
    
    def to_similarity(self, distance: torch.Tensor) -> torch.Tensor:
        """距离转相似度"""
        return torch.exp(-distance / 0.5)

实验结果

基准性能

=== 多模态理解基准 ===
模型              | VQAv2 | GQA  | MMMU | MMBench
-----------------|-------|------|------|--------
GPT-4o           | 86.5  | 84.2 | 69.1 | 83.2
Gemini 1.5       | 85.1  | 83.5 | 67.8 | 81.5
NExT-OMNI-7B    | 83.2  | 81.8 | 65.4 | 79.8

=== 文本到图像生成基准 ===
模型              | FID  | CLIP-S | 推理速度
-----------------|------|--------|--------
SDXL             | 7.6  | 0.81   | 1.0x
DALL-E 3         | 6.8  | 0.84   | 0.5x
NExT-OMNI-7B    | 7.8  | 0.82   | 2.5x

跨模态检索

=== 跨模态检索性能 ===
任务                    | R@1  | R@5  | MRR
----------------------|------|------|----
文本→图像              | 78.5 | 95.2 | 0.86
图像→文本              | 75.8 | 93.1 | 0.83
视频→文本              | 72.3 | 91.5 | 0.81
音频→图像              | 68.9 | 88.7 | 0.78

生成效率对比

=== 生成效率对比 ===
模型              | 并行生成 | 生成步数 | 总时间(相对)
-----------------|---------|---------|-------------
AR模型            | 否      | N       | 1.0x
标准Flow          | 是      | 50      | 0.8x
NExT-OMNI        | 是      | 10      | 0.3x

与其他模型对比

方面NExT-OMNIBAGELEMMAMogao
生成范式离散FlowARARAR
并行生成
任意到任意部分部分部分
跨模态检索
多轮交互

总结

NExT-OMNI通过离散Flow Matching范式实现真正的全能多模态建模:

  1. 度量诱导路径:根据语义相似度自适应调整Flow
  2. 最优速度调度:根据难度自适应采样步数
  3. 统一离散空间:支持任意模态的映射和对齐
  4. 跨模态检索:原生支持检索任务

NExT-OMNI展示了Flow Matching在统一多模态建模中的潜力,为下一代多模态基础模型提供了新的设计范式。

Footnotes

  1. Source: NExT-OMNI: Towards Any-to-Any Omnimodal Foundation Models with Discrete Flow Matching