LLaVA:大型多模态模型
LLaVA(Large Language and Vision Assistant)是开源多模态大模型的重要代表,由Microsoft和威斯康星大学麦迪逊分校于2023年提出。1 它首次验证了”仅使用语言模型进行多模态训练”范式的可行性。
核心思想
关键洞察
LLaVA的核心洞察是:GPT-4展示的多模态能力可能主要来自语言模型的推理能力,而非视觉编码器本身。
因此,LLaVA的设计目标是:
- 保持预训练LLM的推理能力
- 仅训练轻量级视觉-语言对齐模块
架构概览
图像 → CLIP ViT → 视觉特征 → 投影层(W) → 视觉tokens → [USER] <image>...<stop> [ASSISTANT]...
↓
文本 → Tokenizer → 文本tokens → ←——————————————— LLaMA ———————————————→ 输出
模型架构
LLaVA 1.0
import torch
import torch.nn as nn
import torch.nn.functional as F
class LLaVA(nn.Module):
def __init__(self, vision_encoder, language_model, embed_dim=4096,
vision_dim=1024, image_size=224):
super().__init__()
self.vision_encoder = vision_encoder # CLIP ViT-L/14
self.language_model = language_model # Vicuna/LLaMA
# 投影层:线性映射 or 2层MLP
self.projection = nn.Linear(vision_dim, language_model.hidden_size)
def forward(self, images, input_ids, attention_mask=None, labels=None):
# 1. 视觉编码
image_features = self.vision_encoder(images) # (B, num_patches, vision_dim)
# 2. 投影到语言空间
image_tokens = self.projection(image_features) # (B, num_patches, hidden)
# 3. 获取文本embedding
text_tokens = self.language_model.get_input_embeddings()(input_ids)
# 4. 拼接视觉tokens和文本tokens
# <image> token表示视觉内容的起始位置
# <stop> token表示视觉内容的结束位置
inputs_embeds = self._merge_inputs(image_tokens, text_tokens, input_ids)
# 5. LLM前向
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels
)
return outputs
def _merge_inputs(self, image_tokens, text_tokens, input_ids):
"""
将视觉tokens嵌入到文本序列中
假设 <image> token ID 为 32000
"""
IMAGE_TOKEN_ID = 32000
device = image_tokens.device
# 找到 <image> token的位置
image_mask = (input_ids == IMAGE_TOKEN_ID)
# 获取每个batch中 <image> 之前的文本
batch_size = input_ids.shape[0]
merged = []
for b in range(batch_size):
img_tok = image_tokens[b] # (num_patches, hidden)
txt_emb = text_tokens[b] # (seq_len, hidden)
# 找到image token的位置
img_positions = (input_ids[b] == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]
if len(img_positions) > 0:
pos = img_positions[0].item()
# 替换image token为视觉tokens
new_emb = torch.cat([
txt_emb[:pos], # 之前的文本
img_tok, # 视觉tokens
txt_emb[pos+1:] # 之后的文本
], dim=0)
else:
new_emb = txt_emb
merged.append(new_emb)
# Padding到相同长度
max_len = max(e.shape[0] for e in merged)
result = torch.zeros(batch_size, max_len, text_tokens.shape[-1], device=device)
for b, e in enumerate(merged):
result[b, :e.shape[0]] = e
return resultLLaVA 1.5:MLP投影 + 高分辨率
class LLaVA15(nn.Module):
def __init__(self, vision_encoder, language_model):
super().__init__()
self.vision_encoder = vision_encoder
self.language_model = language_model
# 改进:2层MLP投影
vision_dim = vision_encoder.embed_dim
hidden_dim = 4096 # 或与LLM对齐
self.projection = nn.Sequential(
nn.Linear(vision_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim)
)
def forward(self, images, input_ids, ...):
# 视觉编码
image_features = self.vision_encoder(images)
# MLP投影
image_tokens = self.projection(image_features)
# 后续与LLaVA 1.0相同...LLaVA-NeXT:高分辨率与多图像
class LLaVANeXT(nn.Module):
def __init__(self, vision_encoder, language_model, patch_size=14):
super().__init__()
self.vision_encoder = vision_encoder
self.language_model = language_model
self.patch_size = patch_size
# 高分辨率处理:将图像分成多个子图
self.projection = nn.Linear(
vision_encoder.embed_dim,
language_model.hidden_size
)
def encode_images_high_res(self, image, grid_size=2):
"""
将高分辨率图像分成多个子图分别编码
grid_size: 2x2 网格
"""
B, C, H, W = image.shape
h, w = H // grid_size, W // grid_size
patches = []
for i in range(grid_size):
for j in range(grid_size):
patch = image[:, :, i*h:(i+1)*h, j*w:(j+1)*w]
features = self.vision_encoder(patch)
patches.append(features)
# 全局特征
global_features = self.vision_encoder(image)
# 拼接所有特征
all_features = torch.cat([global_features] + patches, dim=1)
return all_features
def forward(self, images, input_ids, grid_size=2, ...):
# 支持多图像
all_visual_tokens = []
for img in images:
if img.dim() == 3:
img = img.unsqueeze(0)
visual_tokens = self.encode_images_high_res(img, grid_size)
all_visual_tokens.append(visual_tokens)
# 拼接多图像的视觉tokens
visual_tokens = torch.cat(all_visual_tokens, dim=1)
visual_tokens = self.projection(visual_tokens)
# 与文本tokens合并...训练流程
两阶段训练
LLaVA采用两阶段训练范式:
阶段1: 预训练(特征对齐)
┌─────────────────┐ ┌─────────────────┐
│ CLIP ViT-L/14 │ │ Vicuna │
│ (冻结) │ │ (冻结) │
└────────┬────────┘ └────────┬────────┘
│ │
↓ ↓
┌────────────────────────────────┐
│ 投影层W (可训练) │
│ 学习:将CLIP特征映射到LLM空间 │
└────────────────────────────────┘
目标:LLM能够理解CLIP提取的视觉信息
阶段2: 指令微调
┌─────────────────┐ ┌─────────────────┐
│ CLIP ViT-L/14 │ │ Vicuna │
│ (冻结) │ │ (微调) │
└────────┬────────┘ └────────┬────────┘
│ │
↓ ↓
┌────────────────────────────────┐
│ 投影层W (冻结) │
│ 保持视觉-语言对齐 │
└────────────────────────────────┘
目标:让模型学会遵循视觉指令
训练数据
阶段1:预训练数据
使用CC3M数据集的子集(595K图文对),进行语言-图像对齐。
# 预训练配置
pretrain_config = {
"batch_size": 128,
"learning_rate": 1e-3,
"epochs": 1,
"optimizer": "AdamW",
"trainable_layers": ["projection"], # 仅训练投影层
}阶段2:指令微调数据
使用GPT-4生成的视觉指令数据:
| 数据集 | 规模 | 描述 |
|---|---|---|
| LLaVA-Instruct | 158K | 图像对话、详细描述、复杂推理 |
| ShareGPT | 多轮对话 | 转换为视觉问答格式 |
| VQAv2 | 真实问答 | 补充真实场景 |
# 指令微调数据格式示例
instruction_data = {
"conversations": [
{"from": "human", "value": "<image>\nDescribe this image in detail."},
{"from": "gpt", "value": "This image shows a serene landscape with mountains in the background..."}
],
"image": "path/to/image.jpg"
}
# 训练配置
finetune_config = {
"batch_size": 32,
"learning_rate": 2e-5,
"epochs": 3,
"trainable_layers": ["projection", "llm"], # 训练投影层和LLM
"lora": True, # 可选:使用LoRA微调LLM
}LLaVA-OneVision
LLaVA-OneVision (2024) 是LLaVA系列的集大成者:
- 单图像: 高分辨率处理
- 多图像: 图像对比、关系推理
- 视频: 帧级理解、时序推理
- 统一架构: 一个模型处理所有任务
class LLaVAOneVision(nn.Module):
def __init__(self, vision_encoder, language_model):
super().__init__()
self.vision_encoder = vision_encoder
self.language_model = language_model
self.mm_projector = nn.Linear(
vision_encoder.embed_dim,
language_model.hidden_size
)
def encode_multimodal(self, images, is_video=False):
if is_video:
# 视频:采样多帧
frames = self.sample_video_frames(images)
features = torch.stack([self.vision_encoder(f) for f in frames], dim=1)
features = features.flatten(1, 2) # (B, num_frames*num_patches, dim)
else:
features = self.vision_encoder(images)
return self.mm_projector(features)与GPT-4V的对比
| 特性 | LLaVA | GPT-4V |
|---|---|---|
| 视觉编码器 | CLIP ViT | 专有设计 |
| 语言模型 | Vicuna/LLaMA | GPT-4 |
| 训练数据 | 公开数据 | 专有数据 |
| 开源 | 是 | 否 |
| 能力差距 | 仍有差距 | 最强baseline |
LLaVA的优势
- 开源可复现 - 模型权重、训练代码开源
- 高效微调 - 支持LoRA、QLoRA等PEFT方法
- 社区生态 - 持续改进和扩展
LLaVA的局限
- 视觉理解能力受限于CLIP
- 推理能力受限于开源LLM
- 长上下文处理能力有限
PEFT应用:LLaVA + LoRA
from peft import LoraConfig, get_peft_model
class LLaVALoRA:
@staticmethod
def apply_lora(model, r=16, alpha=32, target_modules=["q_proj", "v_proj"]):
lora_config = LoraConfig(
r=r,
lora_alpha=alpha,
target_modules=target_modules,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
# 仅对LLM应用LoRA
model.language_model = get_peft_model(
model.language_model,
lora_config
)
return model与现有内容的衔接
| 关联 | 内容 |
|---|---|
| Transformer | LLaMA架构基础 |
| LoRA | LLaVA的高效微调 |
| PEFT | 多模态模型的高效微调技术 |
| MoE | 专家模型在多模态中的应用 |
| CLIP | 视觉编码器基础 |
参考文献
Footnotes
-
Liu et al., Visual Instruction Tuning, NeurIPS 2023 ↩