Prompt Tuning方法详解

Prompt Tuning是一类通过修改输入提示来实现高效微调的方法,不需要修改模型参数。

分类概览

┌─────────────────────────────────────────────────────────────────┐
│                        Prompt Tuning                            │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌─────────────────┐  ┌─────────────────┐  ┌─────────────────┐  │
│  │ Soft Prompt     │  │ Prefix Tuning   │  │ P-Tuning        │  │
│  │ (Prompt Tunning)│  │                 │  │                 │  │
│  │                 │  │                 │  │                 │  │
│  │ 可学习的连续     │  │ 可学习的连续     │  │ 连续提示 +      │  │
│  │ 提示向量         │  │ 前缀向量         │  │ MLP编码器       │  │
│  └─────────────────┘  └─────────────────┘  └─────────────────┘  │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

1. Soft Prompt (Prompt Tuning)

核心思想

将离散的token替换为可学习的连续向量:

传统Prompt:
[CLS] Sentiment: The movie is [MASK] [SEP]
         ↓ 离散token,固定
         ↓ 可学习的连续向量
Soft Prompt:
[CLS] [v₁] [v₂] [v₃] [v₄] [v₅] The movie is [SEP]
       ↑ 可学习参数
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
 
class SoftPrompt(nn.Module):
    """
    Soft Prompt: 可学习的连续提示向量
    """
    def __init__(self, 
                 num_virtual_tokens: int,    # 软提示长度
                 embedding_dim: int,          # 嵌入维度
                 init_type: str = "uniform"): # 初始化方式
        super().__init__()
        self.num_virtual_tokens = num_virtual_tokens
        
        # 可学习的软提示嵌入
        self.soft_prompt = nn.Parameter(
            torch.randn(num_virtual_tokens, embedding_dim)
        )
        
        # 不同初始化方式
        if init_type == "uniform":
            nn.init.uniform_(self.soft_prompt, -1, 1)
        elif init_type == "normal":
            nn.init.normal_(self.soft_prompt, mean=0, std=0.02)
        elif init_type == "text":
            # 使用真实词嵌入初始化
            pass
    
    def forward(self, batch_size: int) -> torch.Tensor:
        """
        复制软提示到指定batch大小
        """
        return self.soft_prompt.unsqueeze(0).expand(batch_size, -1, -1)
 
 
class PromptTuningModel(nn.Module):
    """
    带Prompt Tuning的完整模型
    """
    def __init__(self, 
                 model_name: str,
                 num_virtual_tokens: int = 20):
        super().__init__()
        
        # 加载预训练模型
        self.backbone = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.config = self.backbone.config
        
        # Soft Prompt
        self.soft_prompt = SoftPrompt(
            num_virtual_tokens=num_virtual_tokens,
            embedding_dim=self.config.hidden_size
        )
        
        # 冻结backbone
        for param in self.backbone.parameters():
            param.requires_grad = False
    
    def forward(self, input_ids, attention_mask, labels=None):
        batch_size = input_ids.shape[0]
        
        # 获取文本嵌入
        inputs_embeds = self.backbone.get_input_embeddings()(input_ids)
        
        # 获取软提示嵌入
        soft_embeds = self.soft_prompt(batch_size)
        
        # 拼接: [batch, num_virtual_tokens + seq_len, hidden]
        inputs_embeds = torch.cat([soft_embeds, inputs_embeds], dim=1)
        
        # 调整attention mask
        soft_mask = torch.ones(batch_size, self.soft_prompt.num_virtual_tokens, 
                              device=attention_mask.device)
        attention_mask = torch.cat([soft_mask, attention_mask], dim=1)
        
        # 前向传播
        outputs = self.backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels
        )
        
        return outputs

Prompt Tuning配置

from peft import PromptTuningConfig, get_peft_model
 
config = PromptTuningConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=20,        # 软提示长度
    prompt_tuning_init="TEXT",   # 初始化方式
    prompt_tuning_init_text="Translate the following to French:",  # 初始文本
    tokenizer_name_or_path="gpt2"
)
 
model = get_peft_model(base_model, config)

2. Prefix Tuning

核心思想

在每层Transformer的注意力计算前添加可学习前缀:

原始Attention:
Q = W_q x, K = W_k x, V = W_v x
Attention = softmax(QK^T / √d) V

Prefix Tuning:
Q = W_q x, K = W_k [prefix; x], V = W_v [prefix; x]
                  ↑ 可学习前缀拼接在前面
class PrefixTuningConfig:
    """Prefix Tuning配置"""
    def __init__(self,
                 num_virtual_tokens: int = 20,
                 encoder_hidden_size: int = None,
                 encoder_num_layers: int = 2,
                 prefix_projection: bool = False):
        self.num_virtual_tokens = num_virtual_tokens
        self.encoder_hidden_size = encoder_hidden_size
        self.encoder_num_layers = encoder_num_layers
        self.prefix_projection = prefix_projection
 
 
class PrefixTuningEmbedding(nn.Module):
    """Prefix Embedding层"""
    def __init__(self, config: PrefixTuningConfig, model_config):
        super().__init__()
        self.num_virtual_tokens = config.num_virtual_tokens
        self.hidden_size = model_config.hidden_size
        
        # 可学习的prefix embedding
        self.prefix_tokens = nn.Parameter(
            torch.zeros(2, num_virtual_tokens, self.hidden_size)  # 2: 前缀(不用于生成)和后缀(用于生成)
        )
        
        # 可选的MLP投影层
        if config.prefix_projection:
            self.prefix_encoder = nn.Sequential(
                nn.Linear(self.hidden_size, config.encoder_hidden_size),
                nn.Tanh(),
                nn.Linear(config.encoder_hidden_size, self.hidden_size)
            )
        else:
            self.prefix_encoder = None
    
    def get_prefix(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """获取前缀向量"""
        prefix = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1, -1)
        
        if self.prefix_encoder is not None:
            prefix = self.prefix_encoder(prefix)
        
        return prefix[:, 0], prefix[:, 1]  # (prefix_emb, past_key_values_prefix)

完整Prefix Tuning实现

class PrefixTuningModel(nn.Module):
    """
    Prefix Tuning完整实现
    """
    def __init__(self, base_model, config: PrefixTuningConfig):
        super().__init__()
        self.base_model = base_model
        self.config = config
        
        # 冻结base model
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        # Prefix embedding
        self.prefix_emb = PrefixTuningEmbedding(config, base_model.config)
    
    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        batch_size = input_ids.shape[0]
        
        # 获取prefix
        prefix_kv, prefix_q = self.prefix_emb.get_prefix(batch_size)
        
        # 将prefix拼接到输入
        # ... (具体实现依赖于base_model类型)
        
        return self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,  # 预计算的prefix kv
            use_cache=True
        )

Prefix Tuning vs Soft Prompt

特性Prefix TuningSoft Prompt
插入位置每层注意力前仅输入嵌入层
参数量较多较少
表达能力更强较弱
训练难度较难较易
效果通常更好尚可

3. P-Tuning

核心思想

使用MLP编码器将可学习的伪token映射到模型输入空间:

P-Tuning v1:
[CLS] [v₁] [v₂] ... [vₘ] [SEP] → MLP编码 → [e₁] [e₂] ... [eₘ] → 原始输入

P-Tuning v2:
在每层都添加可学习的伪token,类似Prefix Tuning但使用MLP编码
class PTokenEncoder(nn.Module):
    """
    P-Tuning的伪token编码器
    """
    def __init__(self, 
                 num_virtual_tokens: int,
                 hidden_size: int,
                 encoder_hidden_size: int = 128,
                 num_layers: int = 2,
                 dropout: float = 0.0):
        super().__init__()
        
        # 可学习的伪token ID
        self.pseudo_token_ids = nn.Parameter(
            torch.arange(num_virtual_tokens),
            requires_grad=False  # 不更新,只作为索引
        )
        
        # MLP编码器
        encoder_layers = []
        input_size = hidden_size
        
        for _ in range(num_layers - 1):
            encoder_layers.extend([
                nn.Linear(input_size, encoder_hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout)
            ])
            input_size = encoder_hidden_size
        
        encoder_layers.append(nn.Linear(input_size, hidden_size))
        
        self.encoder = nn.Sequential(*encoder_layers)
        
        # 初始化伪token
        self.embedding = nn.Embedding(num_virtual_tokens, hidden_size)
    
    def forward(self, batch_size: int) -> torch.Tensor:
        # 获取伪token的embedding
        pseudo_tokens = self.embedding(self.pseudo_token_ids)
        
        # 通过MLP编码
        encoded = self.encoder(pseudo_tokens)
        
        # 复制到batch维度
        return encoded.unsqueeze(0).expand(batch_size, -1, -1)
 
 
class PTuningModel(nn.Module):
    """
    P-Tuning模型
    """
    def __init__(self, base_model, num_virtual_tokens=20):
        super().__init__()
        self.base_model = base_model
        
        # P-Token编码器
        self.p_token_encoder = PTokenEncoder(
            num_virtual_tokens=num_virtual_tokens,
            hidden_size=base_model.config.hidden_size
        )
        
        # 冻结base model
        for param in self.base_model.parameters():
            param.requires_grad = False
    
    def forward(self, input_ids, labels=None):
        batch_size = input_ids.shape[0]
        
        # 编码伪token
        pseudo_embeds = self.p_token_encoder(batch_size)
        
        # 获取原始输入嵌入
        inputs_embeds = self.base_model.get_input_embeddings()(input_ids)
        
        # 拼接伪token和原始嵌入
        inputs_embeds = torch.cat([pseudo_embeds, inputs_embeds], dim=1)
        
        # 前向传播
        return self.base_model(
            inputs_embeds=inputs_embeds,
            labels=labels
        )

P-Tuning v2

P-Tuning v2在每层都添加可学习提示,性能接近全量微调:

class PTuningV2Config:
    """P-Tuning v2配置"""
    num_layers = 24      # 与模型层数相同
    num_virtual_tokens = 16
    hidden_size = 1024
    layer_prefix_tokens = 16  # 每层的前缀token数
 
class PTuningV2Model(nn.Module):
    """
    P-Tuning v2: 每层都有可学习提示
    """
    def __init__(self, base_model, config: PTuningV2Config):
        super().__init__()
        self.base_model = base_model
        self.config = config
        
        # 每层独立的前缀提示
        self.layer_prefixes = nn.ParameterList([
            nn.Parameter(torch.randn(
                config.layer_prefix_tokens, 
                config.hidden_size
            ))
            for _ in range(config.num_layers)
        ])
    
    def _inject_prefix(self, layer_idx: int, hidden_states: torch.Tensor) -> torch.Tensor:
        """在特定层注入前缀"""
        prefix = self.layer_prefixes[layer_idx].unsqueeze(0)
        batch_size = hidden_states.shape[0]
        prefix = prefix.expand(batch_size, -1, -1)
        
        return torch.cat([prefix, hidden_states], dim=1)

4. 比较与选择

方法对比

方法可训练参数推理开销表达能力收敛速度效果
Soft Prompt<0.1%中等较好
Prefix Tuning0.1-3%中等
P-Tuning v1<0.1%MLP前向中等尚可
P-Tuning v20.1-1%中等很好

选择指南

┌─────────────────────────────────────────────────────────────────┐
│                         PEFT方法选择                             │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  模型规模:                                                       │
│  ├─ 超大模型 (>100B)  →  Prompt Tuning / Prefix Tuning         │
│  └─ 中大模型 (<100B)  →  LoRA / Adapter                         │
│                                                                 │
│  任务类型:                                                       │
│  ├─ 文本生成     →  Prefix Tuning / LoRA                       │
│  ├─ 分类任务     →  Soft Prompt / LoRA                         │
│  └─ 多任务学习   →  Adapter / Multi-LoRA                        │
│                                                                 │
│  资源限制:                                                       │
│  ├─ 极端受限     →  Prompt Tuning                               │
│  ├─ 资源有限     →  LoRA / QLoRA                               │
│  └─ 资源充足     →  Full FT / Adapter                          │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

5. 提示学习的理论分析

连续提示的表达能力

# 定理:连续提示可以近似任意离散提示
class PromptExpressivity:
    """
    提示表达能力分析
    """
    
    @staticmethod
    def can_represent():
        """
        定理:在足够大的嵌入空间中,
        连续提示可以表示任意离散提示序列
        """
        # 离散提示: [CLS] "Translate to French" [SEP]
        #           ↓ 编码
        # 连续提示: [e₁, e₂, ..., eₘ]
        #
        # 由于嵌入空间通常是高维的(768-4096),
        # 而离散词表大小有限(~50000),
        # 存在足够自由度来表示任意序列
        pass
    
    @staticmethod
    def optimal_init():
        """
        提示初始化策略
        
        1. 文本初始化:使用相关任务的真实prompt词嵌入
           - 收敛更快
           - 效果更好
        
        2. 随机初始化:使用均匀/正态分布
           - 更通用
           - 需要更多训练步数
        """
        pass

提示长度与效果的关系

import matplotlib.pyplot as plt
import numpy as np
 
def plot_prompt_length_effect():
    """
    软提示长度vs效果的关系图
    
    典型曲线:
    - 长度 < 5: 欠拟合,表达能力不足
    - 长度 5-20: 快速提升期
    - 长度 20-100: 边际收益递减
    - 长度 > 100: 近似全量微调
    """
    lengths = np.arange(1, 101)
    # 模拟效果曲线 (sigmoid-like)
    performance = 50 + 45 * (1 - np.exp(-lengths / 20))
    
    return lengths, performance
 
# 图示效果
"""
性能
 ^
 |        ┌─────────────────────────
 |       /
 |      /
 |     /
 |    /
 |   /
 |  /
 | /
 +--------------------------------→ 提示长度
   5    20   50   100

参考