Overview

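Temporal Fusion Transformer (TFT) is a time-series forecasting architecture proposed in 2019 by researchers at Google Cloud AI and the University of Oxford. It combines Transformer-style self-attention with components designed specifically for time series: recurrent (LSTM) layers for local processing, gating layers, and variable selection networks. [1]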

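TFT's defining feature is interpretability: attention weights and variable selection weights can be inspected and visualized to understand what drives a forecast. This gives it an edge in business settings where model explanations are required, such as financial forecasting and energy dispatch.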


1. Model Architecture Overview

1.1 Overall Structure

TFT uses a sequence-to-sequence (Seq2Seq) architecture built from the following core components:

┌─────────────────────────────────────────────────────────────────┐
│                        TFT Architecture                         │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌──────────────┐                                               │
│  │ Input        │  Static covariates → static context           │
│  │ processing   │  Known/observed inputs → temporal features    │
│  └──────┬───────┘                                               │
│         │                                                       │
│  ┌──────▼───────┐                                               │
│  │ Variable     │  Per-variable GRNs + softmax weights          │
│  │ selection    │  GRN = skip connection + gate + layer norm    │
│  └──────┬───────┘                                               │
│         │                                                       │
│  ┌──────▼───────┐                                               │
│  │ LSTM encoder │  Encoder summarizes the observed history      │
│  │ and decoder  │  Decoder consumes the known future inputs     │
│  └──────┬───────┘                                               │
│         │                                                       │
│  ┌──────▼───────┐                                               │
│  │ Static       │  GRN conditioned on the static context        │
│  │ enrichment   │                                               │
│  └──────┬───────┘                                               │
│         │                                                       │
│  ┌──────▼───────┐                                               │
│  │ Multi-head   │  Interpretable multi-head attention           │
│  │ attention    │  Captures long-range dependencies             │
│  └──────┬───────┘                                               │
│         │                                                       │
│  ┌──────▼───────┐                                               │
│  │ Feed-forward │  Position-wise GRN with a gated skip          │
│  │ GRN          │                                               │
│  └──────┬───────┘                                               │
│         │                                                       │
│  ┌──────▼───────┐                                               │
│  │ Quantile     │  Linear layer, one output per quantile        │
│  │ output       │  for every forecast step                      │
│  └──────────────┘                                               │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

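1.2 Input Types

TFT groups time-series data into three types of input:

Input type            Description                                   Examples
Static Covariates     Features that do not change over time         Product category, geographic location
Known Future Inputs   Information known in advance for the future   Holiday calendar, planned events
Observed Inputs       Features that are only observed in the past   Historical sales, past prices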
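
To make the three input groups concrete, here is a minimal pure-Python sketch that slices one training window out of a table of rows. The function name `split_window`, the column names, and the toy data are all hypothetical; the point is only that known inputs span the history plus the forecast horizon, while observed inputs stop at the end of the history.

```python
def split_window(rows, start, encoder_len, decoder_len,
                 static_cols, known_cols, observed_cols):
    """Slice one training window into the three TFT input groups."""
    history = rows[start:start + encoder_len]
    full = rows[start:start + encoder_len + decoder_len]
    return {
        # Static covariates: one value per series, read from the first row.
        "static": [rows[start][c] for c in static_cols],
        # Known future inputs: available over history AND forecast horizon.
        "known": [[r[c] for c in known_cols] for r in full],
        # Observed inputs: only available over the history.
        "observed": [[r[c] for c in observed_cols] for r in history],
    }


# Hypothetical toy data: 6 days of sales for one store.
rows = [
    {"store_type": 1, "is_holiday": d in (2, 5), "sales": 100 + 10 * d}
    for d in range(6)
]
window = split_window(rows, start=0, encoder_len=4, decoder_len=2,
                      static_cols=["store_type"],
                      known_cols=["is_holiday"],
                      observed_cols=["sales"])
print(len(window["known"]), len(window["observed"]))  # 6 4
```

The holiday flag on day 5 is visible to the model even though day 5 lies in the forecast horizon, whereas the sales figure for that day is not.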

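2. Core Components in Detail

2.1 Gated Residual Network (GRN)

The GRN is the basic building block of TFT. It adaptively controls how much non-linear processing is applied to its input: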

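import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedResidualNetwork(nn.Module):
    """
    Gated Residual Network (GRN).

    Core idea: a gate controls how much of the non-linear branch is added
    to the skip connection, so the network can learn to bypass the block
    when the extra processing is not needed.
    """

    def __init__(self, input_dim, hidden_dim, output_dim,
                 dropout=0.1, context_dim=None, activation='elu'):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.context_dim = context_dim
        self.hidden_dim = hidden_dim

        # Dense layers of the non-linear branch
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

        # Optional context projection (no bias: fc1 already provides one)
        if context_dim is not None:
            self.ctx_fc1 = nn.Linear(context_dim, hidden_dim, bias=False)

        # GLU gate: one projection produces the value half and the gate half
        self.gate_fc = nn.Linear(output_dim, output_dim * 2)

        self.dropout = nn.Dropout(dropout)

        # The TFT paper uses ELU; GELU and ReLU are offered as alternatives
        activations = {'elu': nn.ELU, 'gelu': nn.GELU, 'relu': nn.ReLU}
        self.activation = activations[activation]()

        # Project the skip path only when input and output sizes differ
        self.skip_fc = nn.Linear(input_dim, output_dim) if input_dim != output_dim else None

        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, x, context=None):
        """
        Args:
            x: input tensor [batch, ..., input_dim]
            context: optional context tensor, broadcastable to
                [batch, ..., context_dim]
        Returns:
            output tensor [batch, ..., output_dim]
        """
        # Keep the input for the residual (skip) connection
        skip = x if self.skip_fc is None else self.skip_fc(x)

        # First dense layer (+ context), then the activation
        h = self.fc1(x)
        if context is not None and self.context_dim is not None:
            h = h + self.ctx_fc1(context)
        h = self.activation(h)
        h = self.dropout(h)

        # Second dense layer
        h = self.fc2(h)
        h = self.dropout(h)

        # GLU gating: value * sigmoid(gate)
        gated = F.glu(self.gate_fc(h), dim=-1)

        # Residual connection followed by layer normalization
        return self.layer_norm(gated + skip)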

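Mathematical formulation of the GRN

$$
\begin{aligned}
\mathrm{GRN}(a, c) &= \mathrm{LayerNorm}\big(a + \mathrm{GLU}(\eta_1)\big) \\
\eta_1 &= W_1 \eta_2 + b_1 \\
\eta_2 &= \mathrm{ELU}(W_2 a + W_3 c + b_2) \\
\mathrm{GLU}(\gamma) &= \sigma(W_4 \gamma + b_4) \odot (W_5 \gamma + b_5)
\end{aligned}
$$

where:

  • $a$ is the primary input and $c$ is the (optional) context
  • $\mathrm{ELU}$ is the activation function, and $\mathrm{GLU}$ is the gating layer built on the sigmoid $\sigma$
  • $\odot$ is element-wise multiplication
  • $W_i$ and $b_i$ are learnable weights and biases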
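
The gate is what lets a GRN fall back to its skip connection. The following pure-Python sketch illustrates this with made-up numbers. As a simplifying assumption it uses scalar (diagonal) gate weights instead of full weight matrices, and the helper names are hypothetical. When the gate bias is very negative the gate closes and the output reduces to the normalized input; when it is very positive the non-linear branch passes through in full.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def layer_norm(v, eps=1e-5):
    """Normalize a vector to zero mean and (almost) unit variance."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def glu(gamma, w_gate, b_gate, w_val, b_val):
    """Element-wise GLU with scalar weights (assumption for readability)."""
    return [sigmoid(w_gate * g + b_gate) * (w_val * g + b_val) for g in gamma]


def gated_skip(a, eta1, b_gate):
    """LayerNorm(a + GLU(eta1)) with toy weights w_gate = w_val = 1, b_val = 0."""
    gated = glu(eta1, w_gate=1.0, b_gate=b_gate, w_val=1.0, b_val=0.0)
    return layer_norm([x + g for x, g in zip(a, gated)])


a = [1.0, 2.0, 4.0]       # primary input
eta1 = [0.5, -1.0, 3.0]   # output of the two dense layers (made-up values)

closed = gated_skip(a, eta1, b_gate=-20.0)  # gate ~ 0: the block is bypassed
opened = gated_skip(a, eta1, b_gate=20.0)   # gate ~ 1: full non-linear path

print([round(x, 3) for x in closed])
print([round(x, 3) for x in layer_norm(a)])  # matches the line above
```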

2.2 Variable Selection Network

The variable selection network identifies which input features matter most for the forecast:

class VariableSelectionNetwork(nn.Module):
    """
    Variable Selection Network.

    Learns softmax weights over the input variables at every time step and
    uses them to blend per-variable feature vectors. Simplification: each
    variable enters as a raw scalar rather than a learned embedding.
    """
    
    def __init__(self, n_variables, hidden_dim, dropout=0.1):
        super().__init__()
        
        self.n_variables = n_variables
        self.hidden_dim = hidden_dim
        
        # One independent GRN per variable
        self.var_grns = nn.ModuleList([
            GatedResidualNetwork(1, hidden_dim, hidden_dim, dropout)
            for _ in range(n_variables)
        ])
        
        # GRN that scores the variables
        self.selection_grn = GatedResidualNetwork(
            n_variables, hidden_dim, n_variables, dropout
        )
        
    def forward(self, x):
        """
        Args:
            x: input tensor [batch, time, n_variables]
        Returns:
            selected: weighted features [batch, time, hidden_dim]
            weights: selection weights [batch, time, n_variables]
        """
        # Selection weights, computed per time step and normalized over variables
        var_weights = torch.softmax(self.selection_grn(x), dim=-1)
        
        # Transform each variable independently, then scale it by its weight
        selected = []
        for i in range(self.n_variables):
            var_out = self.var_grns[i](x[..., i:i+1])
            selected.append(var_out * var_weights[..., i:i+1])
        
        # Sum the weighted per-variable features
        selected = torch.stack(selected, dim=-1).sum(dim=-1)
        
        return selected, var_weights

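Mathematical formulation of variable selection

Let $x_t^{(j)}$ be the value of variable $j$ at time step $t$ and $X_t = [x_t^{(1)}, \dots, x_t^{(m)}]$ the vector of all $m$ variables. Variable selection computes per-step weights and a weighted sum of per-variable features (the static context that the TFT paper also feeds into $\mathrm{GRN}_v$ is omitted here):

$$
\begin{aligned}
v_t &= \mathrm{Softmax}\big(\mathrm{GRN}_v(X_t)\big) \\
\tilde{x}_t^{(j)} &= \mathrm{GRN}_j\big(x_t^{(j)}\big) \\
\tilde{x}_t &= \sum_{j=1}^{m} v_t^{(j)} \, \tilde{x}_t^{(j)}
\end{aligned}
$$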
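
As a worked example of the weighting step, the pure-Python sketch below turns three made-up selection scores into softmax weights and blends three made-up per-variable feature vectors. The function names and all numbers are hypothetical; in the real network both the scores and the feature vectors come from GRNs.

```python
import math


def softmax(scores):
    m = max(scores)                      # subtract the max for stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def select_variables(scores, transformed):
    """Blend per-variable feature vectors using softmax selection weights."""
    weights = softmax(scores)
    dim = len(transformed[0])
    blended = [
        sum(w * vec[k] for w, vec in zip(weights, transformed))
        for k in range(dim)
    ]
    return weights, blended


scores = [2.0, 0.0, -2.0]                            # one score per variable
transformed = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]   # one vector per variable
weights, blended = select_variables(scores, transformed)

print([round(w, 3) for w in weights])  # [0.867, 0.117, 0.016]
```

The third variable has large feature values but a low score, so it contributes little to the blended vector. The weights themselves are what gets reported as variable importance.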

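2.3 Temporal Processing

Positional Encoding

The original TFT does not use sinusoidal positional encoding; the LSTM encoder-decoder supplies the ordering signal instead. When an explicit position signal is wanted, a learnable embedding of the time index can be used: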

class TimeDistributedEmbedding(nn.Module):
    """Embedding layer applied at every time step."""
    
    def __init__(self, n_tokens, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, embedding_dim)
        
    def forward(self, x):
        """
        Args:
            x: integer time indices [batch, time]
        Returns:
            embedding vectors [batch, time, embedding_dim]
        """
        return self.embedding(x)

LSTM Encoder-Decoder

class LSTMSequenceEncoder(nn.Module):
    """LSTM sequence encoder."""
    
    def __init__(self, input_dim, hidden_dim, num_layers=1, dropout=0.1):
        super().__init__()
        
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )
        
    def forward(self, x):
        """Return the last layer's hidden state for every time step."""
        output, (h_n, c_n) = self.lstm(x)
        return output  # [batch, time, hidden_dim]
 
 
class InterpretableMultiHeadAttention(nn.Module):
    """
    Interpretable multi-head self-attention.

    Differences from standard multi-head attention:
    1. The value projection is shared across heads, which makes the
       attention weights comparable between heads.
    2. Head outputs are averaged before the output projection instead of
       being concatenated.
    """
    
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Per-head query and key projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        # A single value projection shared by all heads
        self.W_v = nn.Linear(d_model, self.d_k)
        
        # Output projection applied to the head-averaged result
        self.W_o = nn.Linear(self.d_k, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, query, key, value, mask=None):
        """
        Args:
            query, key, value: [batch, time, d_model]
            mask: optional attention mask, broadcastable to
                [batch, n_heads, time, time]; positions equal to 0 are hidden
        Returns:
            output: [batch, time, d_model]
            attention_weights: [batch, n_heads, time, time]
        """
        batch_size = query.size(0)
        
        # Project and reshape queries and keys to [batch, n_heads, time, d_k]
        Q = self.W_q(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # Shared values: [batch, 1, time, d_k], broadcast over the heads
        V = self.W_v(value).unsqueeze(1)
        
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attention_weights = F.softmax(scores, dim=-1)
        
        # Apply attention (dropout only on the copy used for the output)
        context = torch.matmul(self.dropout(attention_weights), V)
        
        # Average over heads: [batch, time, d_k]
        context = context.mean(dim=1)
        
        output = self.W_o(context)
        
        return output, attention_weights
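
To show what a causal mask and head averaging do to the attention matrix, here is a small pure-Python sketch. The function names and the toy query/key vectors are hypothetical, and real implementations operate on batched tensors. Each row is a query position; masked (future) positions receive zero weight, and the head-averaged matrix is the one that would be inspected for interpretation.

```python
import math


def causal_attention_weights(queries, keys):
    """Scaled dot-product attention weights with a causal mask.

    queries, keys: lists of equal-length vectors, one per time step.
    Row t of the result only covers positions 0..t.
    """
    d_k = len(keys[0])
    weights = []
    for t, q in enumerate(queries):
        scores = [
            sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k)
            for k in keys[: t + 1]          # future keys are masked out
        ]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        row = [e / total for e in exps] + [0.0] * (len(keys) - t - 1)
        weights.append(row)
    return weights


def average_heads(per_head_weights):
    """Average the attention matrices of several heads element-wise."""
    n_heads = len(per_head_weights)
    n = len(per_head_weights[0])
    return [
        [sum(h[i][j] for h in per_head_weights) / n_heads for j in range(n)]
        for i in range(n)
    ]


keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
head_a = causal_attention_weights([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], keys)
head_b = causal_attention_weights([[0.5, 0.5], [1.0, -1.0], [0.0, 2.0]], keys)
combined = average_heads([head_a, head_b])

for row in combined:
    print([round(w, 3) for w in row])
```

Every row of `combined` still sums to one, so it can be read directly as "how much each past step contributes to this position".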

3. Full Model Implementation

3.1 TFT Main Model

class TemporalFusionTransformer(nn.Module):
    """
    Temporal Fusion Transformer (simplified).

    Wires the components above together: static context, variable
    selection, LSTM encoder/decoder, interpretable attention, and a
    quantile output layer.
    """
    
    def __init__(self, n_static_features, n_known_features, n_observed_features,
                 d_model=128, n_heads=4, d_ffn=256, n_encoder_layers=2,
                 n_decoder_layers=2, dropout=0.1, n_quantiles=9):
        super().__init__()
        
        # The decoder LSTM starts from the encoder's final state,
        # so both stacks need the same depth
        assert n_encoder_layers == n_decoder_layers
        
        self.d_model = d_model
        self.n_quantiles = n_quantiles
        
        # ========== Static features ==========
        # One GRN yields both contexts (encoder half, decoder half)
        self.static_context_grn = GatedResidualNetwork(
            n_static_features, d_model, d_model * 2, dropout
        )
        
        # ========== Variable selection (encoder) ==========
        self.encoder_vsn = VariableSelectionNetwork(
            n_known_features + n_observed_features, d_model, dropout
        )
        
        # ========== Temporal features ==========
        # Both GRNs take the static context as an additional input
        self.encoder_grn = GatedResidualNetwork(
            d_model, d_model, d_model, dropout, context_dim=d_model
        )
        self.known_grn = GatedResidualNetwork(
            n_known_features, d_model, d_model, dropout, context_dim=d_model
        )
        
        # ========== Sequence encoding/decoding ==========
        # nn.LSTM only applies dropout between stacked layers
        lstm_dropout = dropout if n_encoder_layers > 1 else 0.0
        self.encoder_lstm = nn.LSTM(
            d_model, d_model, n_encoder_layers,
            dropout=lstm_dropout, batch_first=True
        )
        self.decoder_lstm = nn.LSTM(
            d_model, d_model, n_decoder_layers,
            dropout=lstm_dropout, batch_first=True
        )
        
        # ========== Self-attention ==========
        self.self_attention = InterpretableMultiHeadAttention(
            d_model, n_heads, dropout
        )
        
        # ========== Output ==========
        # Position-wise feed-forward GRN applied to the decoder steps
        self.feed_forward_grn = GatedResidualNetwork(
            d_model, d_ffn, d_model, dropout
        )
        
        # Quantile levels; the default of 9 gives 0.1, 0.2, ..., 0.9,
        # which includes the median
        quantiles = torch.linspace(0.1, 0.9, n_quantiles)
        self.register_buffer('quantiles', quantiles)
        
        self.output_layer = nn.Linear(d_model, n_quantiles)
        
    def forward(self, static_input, known_input, observed_input):
        """
        Args:
            static_input: [batch, n_static_features]
            known_input: [batch, encoder_len + decoder_len, n_known_features]
            observed_input: [batch, encoder_len, n_observed_features]
        Returns:
            predictions: [batch, decoder_len, n_quantiles]
            attention_weights: [batch, n_heads, total_len, total_len]
        """
        # Window lengths follow from the input shapes
        encoder_len = observed_input.size(1)
        
        # ========== Static context ==========
        static_context = self.static_context_grn(static_input)
        # Add a time axis so the contexts broadcast over all steps
        static_encoder_context = static_context[:, :self.d_model].unsqueeze(1)
        static_decoder_context = static_context[:, self.d_model:].unsqueeze(1)
        
        # ========== Encoder input ==========
        # Concatenate the historical part of the known inputs with the observed inputs
        encoder_combined = torch.cat([
            known_input[:, :encoder_len],
            observed_input
        ], dim=-1)
        
        # Variable selection, then enrichment with the static context
        encoder_features, encoder_weights = self.encoder_vsn(encoder_combined)
        encoder_features = self.encoder_grn(
            encoder_features, context=static_encoder_context
        )
        
        # ========== LSTM encoding ==========
        encoder_output, encoder_state = self.encoder_lstm(encoder_features)
        
        # ========== Decoder input ==========
        decoder_features = self.known_grn(
            known_input[:, encoder_len:], context=static_decoder_context
        )
        
        # ========== LSTM decoding ==========
        # Start from the encoder's final (h_n, c_n), shaped [layers, batch, d_model]
        decoder_output, _ = self.decoder_lstm(decoder_features, encoder_state)
        
        # ========== Self-attention ==========
        # Attend over the joint encoder + decoder sequence
        combined_sequence = torch.cat([encoder_output, decoder_output], dim=1)
        total_len = combined_sequence.size(1)
        
        # Causal mask: each step may only attend to itself and earlier steps
        causal_mask = torch.tril(
            torch.ones(total_len, total_len, device=combined_sequence.device)
        )
        
        attention_output, attention_weights = self.self_attention(
            combined_sequence, combined_sequence, combined_sequence,
            mask=causal_mask
        )
        
        # Keep only the decoder positions of the attention output
        decoder_attention = attention_output[:, encoder_len:]
        
        # ========== Output ==========
        decoder_output = self.feed_forward_grn(decoder_output + decoder_attention)
        
        predictions = self.output_layer(decoder_output)
        
        return predictions, attention_weights

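3.2 Data Preprocessing

def prepare_tft_data(df, target_col, time_col, 
                      static_cols, known_cols, observed_cols,
                      encoder_len, decoder_len):
    """
    Prepare sliding-window samples for TFT.

    Assumes df holds a single series and that all feature columns are
    numeric. Returns a list of (static, known, observed, target) tensor
    tuples, matching the batch layout unpacked by the training loop,
    plus the fitted target scaler.
    """
    from sklearn.preprocessing import StandardScaler
    
    # Windows must follow chronological order
    df = df.sort_values(time_col).reset_index(drop=True)
    
    # Extract features
    static = df[static_cols].values.astype('float32')
    known = df[known_cols].values.astype('float32')
    observed = df[observed_cols].values.astype('float32')
    target = df[target_col].values
    
    # Standardize the target. To avoid leakage, df should contain only the
    # training period; reuse the returned scaler for validation and test data.
    scaler = StandardScaler()
    target_scaled = scaler.fit_transform(
        target.reshape(-1, 1)
    ).flatten().astype('float32')
    
    # Build one sample per sliding window
    sequences = []
    for i in range(len(df) - encoder_len - decoder_len + 1):
        sequences.append((
            torch.from_numpy(static[i]),
            torch.from_numpy(known[i:i+encoder_len+decoder_len]),
            torch.from_numpy(observed[i:i+encoder_len]),
            torch.from_numpy(target_scaled[i+encoder_len:i+encoder_len+decoder_len]),
        ))
    
    return sequences, scaler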
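
The window arithmetic is easy to get off by one, so here is a minimal sketch of just the index bookkeeping. The generator name `window_bounds` is hypothetical. A series of `n_rows` points yields `n_rows - encoder_len - decoder_len + 1` windows, and each window's forecast horizon starts right where its history ends.

```python
def window_bounds(n_rows, encoder_len, decoder_len):
    """Yield (encoder_start, decoder_start, end) for every full window."""
    total = encoder_len + decoder_len
    for i in range(n_rows - total + 1):
        yield i, i + encoder_len, i + total


bounds = list(window_bounds(n_rows=10, encoder_len=6, decoder_len=2))
print(len(bounds))   # 3
print(bounds[0])     # (0, 6, 8)
print(bounds[-1])    # (2, 8, 10)
```

The last window ends exactly at row 10 (exclusive), so no sample reaches past the available data.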

3.3 Training and Evaluation

def train_tft(model, train_loader, val_loader, epochs=50, lr=1e-3):
    """Train the TFT model."""
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=5, factor=0.5
    )
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        
        for batch in train_loader:
            static, known_in, observed_in, target = batch
            
            optimizer.zero_grad()
            
            predictions, _ = model(
                static, known_in, observed_in
            )
            
            # Quantile (pinball) loss
            loss = quantile_loss(predictions, target, model.quantiles)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            train_loss += loss.item()
        
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                static, known_in, observed_in, target = batch
                predictions, _ = model(static, known_in, observed_in)
                loss = quantile_loss(predictions, target, model.quantiles)
                val_loss += loss.item()
        
        scheduler.step(val_loss)
        
        if epoch % 5 == 0:
            print(f"Epoch {epoch}: Train Loss = {train_loss/len(train_loader):.4f}, "
                  f"Val Loss = {val_loss/len(val_loader):.4f}")
 
 
def quantile_loss(pred, target, quantiles):
    """Quantile (pinball) loss, averaged over all quantile levels."""
    losses = []
    for i, q in enumerate(quantiles):
        errors = target - pred[:, :, i]
        loss = torch.max((q - 1) * errors, q * errors)
        losses.append(loss.mean())
    return sum(losses) / len(quantiles)
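
The asymmetry of the quantile loss is easiest to see on single numbers. This sketch applies the same `max((q - 1) * error, q * error)` rule to one observation; the function name `pinball_loss` and the values are illustrative only.

```python
def pinball_loss(target, pred, q):
    """Quantile loss for one observation and one quantile level q."""
    error = target - pred
    return max((q - 1) * error, q * error)


under = pinball_loss(target=10.0, pred=8.0, q=0.9)   # forecast too low
over = pinball_loss(target=10.0, pred=12.0, q=0.9)   # forecast too high
median_under = pinball_loss(10.0, 8.0, q=0.5)
median_over = pinball_loss(10.0, 12.0, q=0.5)

print(round(under, 3), round(over, 3))  # 1.8 0.2
print(median_under, median_over)        # 1.0 1.0
```

For the 0.9 quantile, under-forecasting by 2 costs nine times as much as over-forecasting by 2, which pushes that output towards a high estimate. At q = 0.5 the loss is symmetric and equals half the absolute error.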

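4. Interpretability Features

4.1 Variable Importance

def get_variable_importance(model, dataloader, feature_names):
    """
    Compute variable importance from the encoder's selection weights.

    feature_names must list the known features first, then the observed
    features, matching the concatenation order below.
    """
    model.eval()
    
    # Collect the selection weights of every batch
    all_weights = []
    with torch.no_grad():
        for batch in dataloader:
            static, known_in, observed_in, _ = batch
            # Keep only the historical part of the known inputs so that
            # both tensors cover the same time steps
            encoder_len = observed_in.size(1)
            _, weights = model.encoder_vsn(
                torch.cat([known_in[:, :encoder_len], observed_in], dim=-1)
            )
            all_weights.append(weights.mean(dim=1))  # average over time
    
    weights = torch.cat(all_weights, dim=0).mean(dim=0)
    
    # Sort by importance, highest first
    importance = dict(zip(feature_names, weights.cpu().numpy()))
    importance = dict(sorted(importance.items(), key=lambda x: -x[1]))
    
    return importance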
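
The aggregation itself is just an average over samples and time steps followed by a sort. The pure-Python sketch below reproduces it on made-up weights; the function name `mean_importance` and the feature names are hypothetical.

```python
def mean_importance(weights, feature_names):
    """Average selection weights over samples and time, then rank them.

    weights: nested list indexed as [sample][time][variable].
    """
    totals = [0.0] * len(feature_names)
    count = 0
    for sample in weights:
        for step in sample:
            for j, w in enumerate(step):
                totals[j] += w
            count += 1
    means = [t / count for t in totals]
    return sorted(zip(feature_names, means), key=lambda item: -item[1])


# Made-up weights: 2 samples, 2 time steps, 3 variables (each row sums to 1).
weights = [
    [[0.6, 0.3, 0.1], [0.5, 0.4, 0.1]],
    [[0.7, 0.2, 0.1], [0.6, 0.3, 0.1]],
]
ranking = mean_importance(weights, ["price", "holiday", "weekday"])
for name, score in ranking:
    print(f"{name}: {score:.2f}")
```

Because every per-step weight vector sums to one, the averaged scores also sum to one and can be read as shares of importance.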

4.2 Attention Visualization

def visualize_attention(attention_weights, encoder_len, decoder_len,
                       time_index=None, save_path='attention.png'):
    """Plot the attention weights as a heatmap."""
    import matplotlib.pyplot as plt
    
    # Average over the batch and over all heads
    attn = attention_weights.mean(dim=(0, 1)).detach().cpu().numpy()
    
    plt.figure(figsize=(12, 10))
    plt.imshow(attn, cmap='viridis', aspect='auto')
    plt.colorbar()
    
    # Mark the encoder/decoder boundary (cell edges sit at half-integers)
    plt.axhline(y=encoder_len - 0.5, color='white', linestyle='--', linewidth=2)
    plt.axvline(x=encoder_len - 0.5, color='white', linestyle='--', linewidth=2)
    
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    plt.title('Attention Weights Heatmap')
    
    if time_index is not None:
        plt.xticks(range(0, len(time_index), 24), 
                   [time_index[i] for i in range(0, len(time_index), 24)],
                   rotation=45)
        plt.yticks(range(0, len(time_index), 24),
                   [time_index[i] for i in range(0, len(time_index), 24)])
    
    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()

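5. Comparison with Other Models

Feature   TFT   N-BEATS   N-HiTS   LSTM
Interpretability ✓✓✓✓✓
Multivariate support ✓✓✓✓✓
Static features ✓✓✓
Long-sequence handling ✓✓✓✓
Computational efficiency ✓✓✓✓✓✓✓
Quantile forecasting ✓✓✓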

6. References


Footnotes

  1. Lim, B., Arık, S. Ö., Loeff, N., & Pfister, T. (2021). Temporal Fusion Transformers for interpretable multi-horizon time series forecasting. International Journal of Forecasting, 37(4), 1748-1764.