LLaVA: Large Multimodal Models

LLaVA (Large Language and Vision Assistant) is a leading open-source multimodal large model, introduced in 2023 by Microsoft and the University of Wisconsin-Madison.[^1] It was among the first to validate the paradigm of building multimodal capability on top of a pretrained language model, rather than training a multimodal model from scratch.

Core Idea

Key Insight

LLaVA's core insight: the multimodal abilities demonstrated by GPT-4 likely come primarily from the language model's reasoning ability, not from the vision encoder itself.

Accordingly, LLaVA's design goals are:

  1. Preserve the pretrained LLM's reasoning ability
  2. Train only a lightweight vision-language alignment module

Architecture Overview

Image → CLIP ViT → visual features → projection (W) → visual tokens → [USER] <image>...<stop> [ASSISTANT]...
                                                                                  ↓
Text  → Tokenizer → text tokens ——————————————————→ LLaMA ——————————————————→ output

Model Architecture

LLaVA 1.0

import torch
import torch.nn as nn

class LLaVA(nn.Module):
    def __init__(self, vision_encoder, language_model, vision_dim=1024):
        super().__init__()
        self.vision_encoder = vision_encoder  # CLIP ViT-L/14
        self.language_model = language_model  # Vicuna/LLaMA
        
        # Projection layer: a single linear map in LLaVA 1.0
        # (replaced by a 2-layer MLP in LLaVA 1.5)
        self.projection = nn.Linear(vision_dim, language_model.hidden_size)
    
    def forward(self, images, input_ids, attention_mask=None, labels=None):
        # 1. Vision encoding
        image_features = self.vision_encoder(images)  # (B, num_patches, vision_dim)
        
        # 2. Project into the language embedding space
        image_tokens = self.projection(image_features)  # (B, num_patches, hidden)
        
        # 3. Look up the text embeddings
        text_tokens = self.language_model.get_input_embeddings()(input_ids)
        
        # 4. Splice the visual tokens into the text sequence;
        #    the <image> placeholder marks where they are inserted
        inputs_embeds = self._merge_inputs(image_tokens, text_tokens, input_ids)
        
        # 5. LLM forward pass
        #    (attention_mask and labels must be expanded the same way as the
        #    embeddings; omitted here for brevity)
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels
        )
        return outputs
    
    def _merge_inputs(self, image_tokens, text_tokens, input_ids):
        """
        Embed the visual tokens into the text sequence.
        Assumes the <image> placeholder has token ID 32000.
        """
        IMAGE_TOKEN_ID = 32000
        device = image_tokens.device
        
        batch_size = input_ids.shape[0]
        merged = []
        
        for b in range(batch_size):
            img_tok = image_tokens[b]  # (num_patches, hidden)
            txt_emb = text_tokens[b]   # (seq_len, hidden)
            
            # Locate the <image> placeholder in this sequence
            img_positions = (input_ids[b] == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]
            
            if len(img_positions) > 0:
                pos = img_positions[0].item()
                # Replace the placeholder with the visual tokens
                new_emb = torch.cat([
                    txt_emb[:pos],    # text before the image
                    img_tok,          # visual tokens
                    txt_emb[pos+1:]   # text after the image
                ], dim=0)
            else:
                new_emb = txt_emb
            
            merged.append(new_emb)
        
        # Pad all sequences to the same length
        max_len = max(e.shape[0] for e in merged)
        result = torch.zeros(batch_size, max_len, text_tokens.shape[-1], device=device)
        for b, e in enumerate(merged):
            result[b, :e.shape[0]] = e
        
        return result
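
A quick smoke test of the merge logic with stand-in modules (`DummyEncoder` and `DummyLM` are hypothetical stubs for illustration, not the real CLIP/Vicuna wrappers):

class DummyEncoder(nn.Module):
    def forward(self, images):  # (B, 3, H, W) -> (B, 16, 1024)
        return torch.randn(images.shape[0], 16, 1024)

class DummyLM(nn.Module):
    hidden_size = 4096
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(32001, self.hidden_size)
    def get_input_embeddings(self):
        return self.emb
    def forward(self, inputs_embeds=None, attention_mask=None, labels=None):
        return inputs_embeds  # echo the merged embeddings back for inspection

model = LLaVA(DummyEncoder(), DummyLM())
input_ids = torch.tensor([[1, 32000, 42, 43]])  # <s> <image> t1 t2
out = model(torch.randn(1, 3, 224, 224), input_ids)
print(out.shape)  # torch.Size([1, 19, 4096]): the <image> slot expanded into 16 patch tokens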

LLaVA 1.5: MLP Projection + Higher Resolution

class LLaVA15(nn.Module):
    def __init__(self, vision_encoder, language_model):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.language_model = language_model
        
        # Improvement over 1.0: a 2-layer MLP projection
        vision_dim = vision_encoder.embed_dim
        hidden_dim = language_model.hidden_size  # match the LLM width
        
        self.projection = nn.Sequential(
            nn.Linear(vision_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
    
    def forward(self, images, input_ids, attention_mask=None, labels=None):
        # Vision encoding
        image_features = self.vision_encoder(images)
        
        # MLP projection
        image_tokens = self.projection(image_features)
        
        # The rest (merging and the LLM forward pass) is the same as LLaVA 1.0
        ...

LLaVA-NeXT: High Resolution and Multiple Images

class LLaVANeXT(nn.Module):
    def __init__(self, vision_encoder, language_model, patch_size=14):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.language_model = language_model
        self.patch_size = patch_size
        
        # High-resolution handling: the image is split into sub-images
        self.projection = nn.Linear(
            vision_encoder.embed_dim, 
            language_model.hidden_size
        )
    
    def encode_images_high_res(self, image, grid_size=2):
        """
        Split a high-resolution image into a grid of sub-images and
        encode each one separately (grid_size=2 gives a 2x2 grid).
        """
        B, C, H, W = image.shape
        h, w = H // grid_size, W // grid_size
        
        patches = []
        for i in range(grid_size):
            for j in range(grid_size):
                patch = image[:, :, i*h:(i+1)*h, j*w:(j+1)*w]
                features = self.vision_encoder(patch)
                patches.append(features)
        
        # Global view of the whole image
        global_features = self.vision_encoder(image)
        
        # Concatenate global and local features along the token axis
        all_features = torch.cat([global_features] + patches, dim=1)
        return all_features
    
    def forward(self, images, input_ids, grid_size=2, attention_mask=None, labels=None):
        # Multiple images are supported: encode each one
        all_visual_tokens = []
        for img in images:
            if img.dim() == 3:
                img = img.unsqueeze(0)
            visual_tokens = self.encode_images_high_res(img, grid_size)
            all_visual_tokens.append(visual_tokens)
        
        # Concatenate the visual tokens of all images
        visual_tokens = torch.cat(all_visual_tokens, dim=1)
        visual_tokens = self.projection(visual_tokens)
        
        # ...then merge with the text tokens as in LLaVA 1.0
        ...
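
The grid scheme multiplies the visual token count quickly. A back-of-envelope check, assuming a CLIP ViT-L/14 backbone at 336px resolution (the setting used by LLaVA 1.5/NeXT), i.e. 24×24 = 576 patches per view:

# Tokens for one image with a 2x2 grid plus a global view
patches_per_view = (336 // 14) ** 2   # 576 patches per encoded view
num_views = 2 * 2 + 1                 # four local crops + one global view
print(num_views * patches_per_view)   # 2880 visual tokens per image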

Training Pipeline

Two-Stage Training

LLaVA adopts a two-stage training paradigm:

Stage 1: Pretraining (feature alignment)
┌─────────────────┐     ┌─────────────────┐
│  CLIP ViT-L/14  │     │     Vicuna      │
│    (frozen)     │     │    (frozen)     │
└────────┬────────┘     └────────┬────────┘
         │                       │
         ↓                       ↓
    ┌────────────────────────────────┐
    │    Projection W (trainable)    │
    │  Learns to map CLIP features   │
    │    into the LLM token space    │
    └────────────────────────────────┘
    Goal: the LLM can understand the visual
    information extracted by CLIP
    
Stage 2: Instruction fine-tuning
┌─────────────────┐     ┌─────────────────┐
│  CLIP ViT-L/14  │     │     Vicuna      │
│    (frozen)     │     │  (fine-tuned)   │
└────────┬────────┘     └────────┬────────┘
         │                       │
         ↓                       ↓
    ┌────────────────────────────────┐
    │    Projection W (trainable)    │
    │  Updated jointly with the LLM  │
    └────────────────────────────────┘
    Goal: teach the model to follow visual instructions
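
A minimal sketch of how the two stages differ in what gets gradients (`model` refers to the `LLaVA` module defined above; learning rates follow the configs below):

# Stage 1: freeze the vision encoder and the LLM; train only the projection
for p in model.vision_encoder.parameters():
    p.requires_grad = False
for p in model.language_model.parameters():
    p.requires_grad = False
optimizer = torch.optim.AdamW(model.projection.parameters(), lr=1e-3)

# Stage 2: unfreeze the LLM; the vision encoder stays frozen
for p in model.language_model.parameters():
    p.requires_grad = True
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=2e-5
)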

Training Data

Stage 1: Pretraining Data

A filtered subset of CC3M (595K image-text pairs) is used for language-image feature alignment.
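
Each caption is recast as a single-turn conversation for this stage; the record below is an illustrative assumption (same schema as the instruction-tuning example further down, with a brief-description prompt):

# Illustrative stage-1 record: an image-caption pair as a one-turn conversation
alignment_data = {
    "conversations": [
        {"from": "human", "value": "<image>\nDescribe the image briefly."},
        {"from": "gpt", "value": "A dog running across a grassy field."}
    ],
    "image": "path/to/image.jpg"
}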

# Pretraining configuration
pretrain_config = {
    "batch_size": 128,
    "learning_rate": 1e-3,
    "epochs": 1,
    "optimizer": "AdamW",
    "trainable_layers": ["projection"],  # train only the projection layer
}

Stage 2: Instruction-Tuning Data

Visual instruction data generated with GPT-4:

| Dataset | Size | Description |
| --- | --- | --- |
| LLaVA-Instruct | 158K | image conversations, detailed descriptions, complex reasoning |
| ShareGPT | – | multi-turn dialogues converted into a visual-QA format |
| VQAv2 | – | real question-answer pairs, adding real-world scene coverage |

# Example instruction-tuning record
instruction_data = {
    "conversations": [
        {"from": "human", "value": "<image>\nDescribe this image in detail."},
        {"from": "gpt", "value": "This image shows a serene landscape with mountains in the background..."}
    ],
    "image": "path/to/image.jpg"
}
 
# Fine-tuning configuration
finetune_config = {
    "batch_size": 32,
    "learning_rate": 2e-5,
    "epochs": 3,
    "trainable_layers": ["projection", "llm"],  # train the projection layer and the LLM
    "lora": True,  # optional: fine-tune the LLM with LoRA
}
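
Fine-tuning typically computes the loss only on the assistant's tokens. A minimal sketch of that masking, assuming a generic `tokenizer` (the real preprocessing also inserts role templates and separators):

IGNORE_INDEX = -100  # labels with this value are skipped by the cross-entropy loss

def build_labels(conversations, tokenizer):
    """Tokenize a conversation, supervising only the assistant ('gpt') turns."""
    input_ids, labels = [], []
    for turn in conversations:
        ids = tokenizer.encode(turn["value"], add_special_tokens=False)
        input_ids += ids
        if turn["from"] == "gpt":
            labels += ids                        # learn to generate the reply
        else:
            labels += [IGNORE_INDEX] * len(ids)  # no loss on the user prompt
    return input_ids, labels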

LLaVA-OneVision

LLaVA-OneVision (2024) is the culmination of the LLaVA series:

  • Single image: high-resolution processing
  • Multiple images: cross-image comparison and relational reasoning
  • Video: frame-level understanding and temporal reasoning
  • Unified architecture: one model handles all of these tasks

class LLaVAOneVision(nn.Module):
    def __init__(self, vision_encoder, language_model, num_video_frames=8):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.language_model = language_model
        self.num_video_frames = num_video_frames
        self.mm_projector = nn.Linear(
            vision_encoder.embed_dim, 
            language_model.hidden_size
        )
    
    def sample_video_frames(self, video):
        # Uniformly sample frames from a (B, T, C, H, W) clip
        T = video.shape[1]
        idx = torch.linspace(0, T - 1, self.num_video_frames).long()
        return [video[:, t] for t in idx.tolist()]
    
    def encode_multimodal(self, images, is_video=False):
        if is_video:
            # Video: encode sampled frames, then flatten over time
            frames = self.sample_video_frames(images)
            features = torch.stack([self.vision_encoder(f) for f in frames], dim=1)
            features = features.flatten(1, 2)  # (B, num_frames*num_patches, dim)
        else:
            features = self.vision_encoder(images)
        
        return self.mm_projector(features)
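
Usage is then uniform across modalities (a sketch; `model` is assumed to be a `LLaVAOneVision` instance whose encoder accepts 336px inputs and returns `(B, num_patches, embed_dim)`):

# Single image: (B, C, H, W); video: (B, T, C, H, W)
image_tokens = model.encode_multimodal(torch.randn(1, 3, 336, 336))
video_tokens = model.encode_multimodal(torch.randn(1, 32, 3, 336, 336), is_video=True)
# The video path yields num_video_frames * num_patches visual tokens per clip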

Comparison with GPT-4V

| Feature | LLaVA | GPT-4V |
| --- | --- | --- |
| Vision encoder | CLIP ViT | proprietary |
| Language model | Vicuna/LLaMA | GPT-4 |
| Training data | public | proprietary |
| Open source | yes | no |
| Capability | still a gap behind | strongest baseline |

Strengths of LLaVA

  1. Open and reproducible - model weights and training code are released
  2. Efficient fine-tuning - supports LoRA, QLoRA, and other PEFT methods
  3. Community ecosystem - continuously improved and extended

Limitations of LLaVA

  1. Visual understanding is bounded by CLIP
  2. Reasoning ability is bounded by the underlying open-source LLM
  3. Limited long-context handling

PEFT in Practice: LLaVA + LoRA

from peft import LoraConfig, get_peft_model
 
class LLaVALoRA:
    @staticmethod
    def apply_lora(model, r=16, alpha=32, target_modules=["q_proj", "v_proj"]):
        lora_config = LoraConfig(
            r=r,
            lora_alpha=alpha,
            target_modules=target_modules,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        
        # Apply LoRA to the LLM only
        model.language_model = get_peft_model(
            model.language_model, 
            lora_config
        )
        return model
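
Applying it and inspecting the result (a sketch; assumes `model.language_model` is a Hugging Face causal LM, so PEFT's `print_trainable_parameters` is available):

model = LLaVALoRA.apply_lora(model, r=16, alpha=32)
model.language_model.print_trainable_parameters()
# Only the LoRA adapter matrices (and, if desired, the projection layer)
# receive gradients; the base LLM weights stay frozen.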

Connections to Related Content

| Topic | Relevance |
| --- | --- |
| Transformer | architectural foundation of LLaMA |
| LoRA | efficient fine-tuning for LLaVA |
| PEFT | parameter-efficient fine-tuning techniques for multimodal models |
| MoE | expert models applied in multimodal settings |
| CLIP | foundation of the vision encoder |

References

[^1]: Liu et al., Visual Instruction Tuning, NeurIPS 2023.