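Overview
The Temporal Fusion Transformer (TFT) is a multi-horizon time-series forecasting architecture proposed by researchers at Google Cloud AI, first released as a preprint in 2019.[^1] It combines recurrent (LSTM) layers for local temporal processing with an interpretable self-attention layer for long-range dependencies. It adds gating and variable-selection components designed specifically for time series.
TFT's defining feature is interpretability: attention weights and variable-importance scores can be inspected to see what drives each forecast. This makes it particularly useful in business settings where predictions must be explained, such as financial forecasting and energy dispatch.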
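I. Model Architecture Overview
1.1 Overall Structure
TFT uses a sequence-to-sequence (Seq2Seq) design built from the following core components, in data-flow order: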
```text
Static covariates
  └─▶ Static covariate encoders (GRN) ─▶ static context vectors
      (condition variable selection, LSTM initial state, static enrichment)
Known inputs (history + horizon) and observed inputs (history only)
  └─▶ Variable selection networks (GRN + softmax weights)
        │
        ▼
LSTM encoder (history) ─▶ LSTM decoder (horizon)
  local processing; also supplies positional information
  gated skip connection + LayerNorm
        │
        ▼
Static enrichment (GRN conditioned on static context)
        │
        ▼
Interpretable multi-head attention (decoder masking, shared values)
  gated skip connection + LayerNorm
        │
        ▼
Position-wise feed-forward (GRN)
        │
        ▼
Quantile output layer (e.g. P10 / P50 / P90)
```
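The data flow can be made concrete by tracing tensor shapes through the main stages for one batch. The following pure-Python sketch does this. It assumes the layout used by the full model implementation later in this article: encoder inputs cover the history window, decoder inputs cover the forecast horizon, and attention runs over both. The function name `tft_shapes` is hypothetical.

```python
def tft_shapes(batch, encoder_len, decoder_len, d_model, n_heads, n_quantiles):
    """Return the tensor shape produced at each main stage of a TFT forward pass."""
    total_len = encoder_len + decoder_len
    return {
        # variable selection maps every time step to a d_model vector
        "encoder_vsn": (batch, encoder_len, d_model),
        "decoder_vsn": (batch, decoder_len, d_model),
        # LSTM encoder and decoder outputs are concatenated along time
        "lstm_output": (batch, total_len, d_model),
        # self-attention keeps the sequence shape; weights are kept per head
        "attention_output": (batch, total_len, d_model),
        "attention_weights": (batch, n_heads, total_len, total_len),
        # only the forecast horizon reaches the quantile output layer
        "predictions": (batch, decoder_len, n_quantiles),
    }


# One week of hourly history, one day of hourly forecasts
shapes = tft_shapes(batch=32, encoder_len=168, decoder_len=24,
                    d_model=128, n_heads=4, n_quantiles=3)
for stage, shape in shapes.items():
    print(f"{stage:18s} {shape}")
```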
1.2 Input Types
TFT divides its time-series inputs into three types:
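| Input type | Description | Examples |
|---|---|---|
| Static covariates | Features that do not change over time | Product category, location |
| Known future inputs | Values known in advance for the forecast horizon | Holiday calendar, planned events |
| Observed inputs | Values available only up to the present (including the target's own history) | Historical sales, past prices |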
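The following minimal plain-Python sketch shows how one training window maps onto the three input types. The names `TFTWindow` and `make_window` are hypothetical. The point it illustrates is that known inputs extend into the forecast horizon, while observed inputs (here the past sales, which double as the target) stop where the forecast begins.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class TFTWindow:
    static: List[float]          # identical for every time step
    known: List[List[float]]     # history + horizon: encoder_len + decoder_len rows
    observed: List[List[float]]  # history only: encoder_len rows
    target: List[float]          # values to forecast: decoder_len values


def make_window(static: Sequence[float],
                known: Sequence[Sequence[float]],
                observed: Sequence[Sequence[float]],
                target: Sequence[float],
                start: int, encoder_len: int, decoder_len: int) -> TFTWindow:
    """Cut one sample whose history starts at index `start`."""
    split = start + encoder_len  # first forecast step
    end = split + decoder_len    # exclusive end of the horizon
    if end > len(known) or end > len(target):
        raise ValueError("window runs past the end of the series")
    return TFTWindow(
        static=list(static),
        known=[list(row) for row in known[start:end]],
        observed=[list(row) for row in observed[start:split]],
        target=list(target[split:end]),
    )


# 10 daily steps: known = [day_of_week, is_holiday], observed = [past sales]
sales = [12.0, 15.0, 14.0, 20.0, 22.0, 9.0, 11.0, 13.0, 16.0, 15.0]
known = [[d % 7, 1.0 if d == 8 else 0.0] for d in range(10)]
observed = [[s] for s in sales]
w = make_window(static=[3.0], known=known, observed=observed, target=sales,
                start=2, encoder_len=5, decoder_len=2)
print(len(w.known), len(w.observed), w.target)
```

The holiday flag on day 8 is visible to the model even though day 8 lies in the forecast horizon. The sales value for that day is not visible.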
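II. Core Components in Detail
2.1 Gated Residual Network (GRN)
The GRN is TFT's basic building block. A gate lets the network adaptively decide how much nonlinear processing to apply to each input: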
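```python
import torch
import torch.nn as nn


class GatedResidualNetwork(nn.Module):
    """
    Gated Residual Network (GRN)
    Core idea: a GLU gate controls how much the nonlinear branch adds to the
    residual connection. When the gate is close to 0, the output reduces to
    LayerNorm(skip), so the layer is effectively skipped.
    """
    def __init__(self, input_dim, hidden_dim, output_dim,
                 dropout=0.1, context_dim=None, activation='elu'):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.context_dim = context_dim
        self.hidden_dim = hidden_dim
        # eta_2 = act(W2 a + W3 c + b2), eta_1 = W1 eta_2 + b1
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # Optional context projection (W3, no bias)
        if context_dim is not None:
            self.ctx_fc1 = nn.Linear(context_dim, hidden_dim, bias=False)
        # GLU: sigmoid(W4 x + b4) * (W5 x + b5)
        self.gate_fc = nn.Linear(output_dim, output_dim)
        self.value_fc = nn.Linear(output_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        # The paper uses ELU; GELU and ReLU are kept as alternatives
        activations = {'elu': nn.ELU(), 'gelu': nn.GELU(), 'relu': nn.ReLU()}
        self.activation = activations[activation]
        # Linear projection on the skip path when input and output sizes differ
        self.skip_fc = nn.Linear(input_dim, output_dim) if input_dim != output_dim else None
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, x, context=None):
        """
        Args:
            x: input tensor [batch, ..., input_dim]
            context: optional context tensor [batch, ..., context_dim]
        Returns:
            output tensor [batch, ..., output_dim]
        """
        # Skip path for the residual connection
        skip = x if self.skip_fc is None else self.skip_fc(x)
        # First linear layer (+ context) and activation
        h = self.fc1(x)
        if context is not None and self.context_dim is not None:
            h = h + self.ctx_fc1(context)
        h = self.activation(h)
        # Second linear layer
        h = self.fc2(h)
        h = self.dropout(h)
        # GLU gating
        gated = torch.sigmoid(self.gate_fc(h)) * self.value_fc(h)
        # Residual connection + LayerNorm
        return self.layer_norm(skip + gated)
```
Mathematically, the GRN is defined as: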
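$$
\begin{aligned}
\mathrm{GRN}(a, c) &= \mathrm{LayerNorm}\big(a + \mathrm{GLU}(\eta_1)\big) \\
\eta_1 &= W_1 \eta_2 + b_1 \\
\eta_2 &= \mathrm{ELU}(W_2 a + W_3 c + b_2) \\
\mathrm{GLU}(\gamma) &= \sigma(W_4 \gamma + b_4) \odot (W_5 \gamma + b_5)
\end{aligned}
$$
where:
- $a$ is the primary input and $c$ is the optional context (for example, the static context vector)
- $\mathrm{ELU}$ is the activation function; $\sigma$ is the sigmoid that forms the gate of the gated linear unit ($\mathrm{GLU}$)
- $\odot$ denotes element-wise multiplication
- $W_i$ and $b_i$ are learnable weights and biases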
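These equations can be checked by hand with a small dependency-free sketch. It implements the GRN for plain Python lists and uses LayerNorm without learnable scale and shift, which is a simplification assumed for brevity. It shows the two extreme gate settings with hand-picked weights:

- A closed gate reduces the block to LayerNorm(a), so the block is effectively skipped.
- An open gate adds the full nonlinear branch.

```python
import math
from typing import Dict, List, Optional, Sequence

Vector = List[float]
Matrix = List[List[float]]


def affine(W: Matrix, v: Sequence[float], b: Sequence[float]) -> Vector:
    """Compute W v + b."""
    return [sum(w * x for w, x in zip(row, v)) + bias for row, bias in zip(W, b)]


def elu(x: float) -> float:
    return x if x > 0 else math.exp(x) - 1.0


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def layer_norm(v: Sequence[float], eps: float = 1e-5) -> Vector:
    """LayerNorm without learnable scale/shift."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def grn(a: Sequence[float], p: Dict[str, list], c: Optional[Sequence[float]] = None) -> Vector:
    """GRN(a, c) = LayerNorm(a + sigmoid(W4 eta1 + b4) * (W5 eta1 + b5)),
    with eta1 = W1 eta2 + b1 and eta2 = ELU(W2 a + W3 c + b2)."""
    pre = affine(p["W2"], a, p["b2"])
    if c is not None:
        # W3 c carries no bias of its own
        pre = [x + y for x, y in zip(pre, affine(p["W3"], c, [0.0] * len(pre)))]
    eta2 = [elu(x) for x in pre]
    eta1 = affine(p["W1"], eta2, p["b1"])
    gate = [sigmoid(x) for x in affine(p["W4"], eta1, p["b4"])]
    value = affine(p["W5"], eta1, p["b5"])
    return layer_norm([x + g * v for x, g, v in zip(a, gate, value)])


I = [[1.0, 0.0], [0.0, 1.0]]
SWAP = [[0.0, 1.0], [1.0, 0.0]]
zeros = [0.0, 0.0]
a = [1.0, 3.0]


def params(gate_bias: float) -> Dict[str, list]:
    return {"W1": I, "b1": zeros, "W2": I, "b2": zeros,
            "W4": I, "b4": [gate_bias, gate_bias], "W5": SWAP, "b5": zeros}


closed = grn(a, params(-50.0))  # gate ~ 0: output ~ LayerNorm(a)
opened = grn(a, params(50.0))   # gate ~ 1: a + [3, 1] = [4, 4], normalised to zeros
print(closed, layer_norm(a))
print(opened)
```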
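2.2 Variable Selection Network
The variable selection network learns which input features matter most for the forecast. It assigns each variable a weight at every time step and combines the variables accordingly: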
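```python
class VariableSelectionNetwork(nn.Module):
    """
    Variable Selection Network
    Computes softmax weights over the input variables at each time step
    and returns the weighted combination of the per-variable GRN outputs.
    """
    def __init__(self, n_variables, hidden_dim, dropout=0.1, context_dim=None):
        super().__init__()
        self.n_variables = n_variables
        self.hidden_dim = hidden_dim
        # One independent GRN per variable: scalar -> hidden_dim
        self.var_grns = nn.ModuleList([
            GatedResidualNetwork(1, hidden_dim, hidden_dim, dropout)
            for _ in range(n_variables)
        ])
        # Selection GRN: all variables (+ optional static context) -> one score per variable
        self.selection_grn = GatedResidualNetwork(
            n_variables, hidden_dim, n_variables, dropout, context_dim=context_dim
        )

    def forward(self, x, context=None):
        """
        Args:
            x: input tensor [batch, time, n_variables]
            context: optional static context [batch, time, context_dim]
        Returns:
            selected: weighted features [batch, time, hidden_dim]
            weights: selection weights [batch, time, n_variables]
        """
        # Per-time-step selection weights (no averaging over time)
        var_weights = torch.softmax(self.selection_grn(x, context), dim=-1)
        # Process each variable independently: [batch, time, hidden_dim, n_variables]
        var_outputs = torch.stack(
            [grn(x[..., i:i + 1]) for i, grn in enumerate(self.var_grns)], dim=-1
        )
        # Weighted sum over variables -> [batch, time, hidden_dim]
        selected = (var_outputs * var_weights.unsqueeze(-2)).sum(dim=-1)
        return selected, var_weights
```
Variable selection can be written as follows. Here $\Xi_t = [\xi_t^{(1)}, \dots, \xi_t^{(m_\chi)}]$ concatenates all $m_\chi$ input variables at time $t$, and $c_s$ is the optional static context:

$$
\begin{aligned}
v_{\chi_t} &= \mathrm{Softmax}\big(\mathrm{GRN}_{v_\chi}(\Xi_t, c_s)\big) \\
\tilde{\xi}_t^{(j)} &= \mathrm{GRN}_{\tilde{\xi}(j)}\big(\xi_t^{(j)}\big) \\
\tilde{\xi}_t &= \sum_{j=1}^{m_\chi} v_{\chi_t}^{(j)}\, \tilde{\xi}_t^{(j)}
\end{aligned}
$$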
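For a single time step, variable selection reduces to a softmax over one score per variable followed by a weighted sum of the per-variable vectors. In this minimal pure-Python sketch, the vectors and scores are hand-picked stand-ins for the GRN outputs, and the function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_variables(processed: Sequence[Sequence[float]],
                     selection_logits: Sequence[float]) -> Tuple[List[float], List[float]]:
    """
    processed:        one vector per variable (stand-in for the per-variable GRN outputs)
    selection_logits: one score per variable (stand-in for the selection GRN output)
    Returns the weighted combination and the selection weights.
    """
    if len(processed) != len(selection_logits):
        raise ValueError("need one score per variable")
    weights = softmax(selection_logits)
    dim = len(processed[0])
    combined = [sum(w * vec[k] for w, vec in zip(weights, processed)) for k in range(dim)]
    return combined, weights


# Three variables at one time step, each already mapped to a 2-d vector
processed = [[1.0, 0.0],   # e.g. price
             [0.0, 1.0],   # e.g. holiday flag
             [0.0, 0.0]]   # e.g. weekday
combined, weights = select_variables(processed, [2.0, 0.0, 0.0])
print([round(w, 3) for w in weights], [round(x, 3) for x in combined])
```

Because the weights always sum to 1, they can be compared across variables and averaged over time. This is what makes them usable as importance scores.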
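2.3 Temporal Processing
Positional Information and Embeddings
TFT does not add a separate positional encoding. In the original paper, the LSTM encoder-decoder layer that performs local processing also supplies ordering information and replaces the standard positional encoding. Calendar information, such as the hour or day of week, is instead fed in as known inputs. Categorical inputs are mapped to vectors through learnable embeddings.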
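Calendar inputs like these depend only on the timestamp, so they are known for future steps as well. The following standard-library sketch derives integer calendar indices that could feed embedding tables. The function name `calendar_indices` and the choice of features are illustrative assumptions.

```python
from datetime import datetime, timedelta
from typing import Dict, List


def calendar_indices(start: datetime, n_steps: int, step: timedelta) -> List[Dict[str, int]]:
    """
    Integer calendar features for n_steps timestamps beginning at `start`.
    Each value can index an embedding table (e.g. 24 rows for the hour).
    """
    rows = []
    for i in range(n_steps):
        ts = start + i * step
        rows.append({
            "hour": ts.hour,              # 0-23
            "day_of_week": ts.weekday(),  # 0 = Monday ... 6 = Sunday
            "month": ts.month - 1,        # 0-11, shifted to start at 0 for lookup
        })
    return rows


# A 3-step hourly horizon starting Sunday 2024-03-31 22:00
future = calendar_indices(datetime(2024, 3, 31, 22), 3, timedelta(hours=1))
for row in future:
    print(row)
```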
```python
class TimeDistributedEmbedding(nn.Module):
    """Learnable embedding applied independently at every time step"""
    def __init__(self, n_tokens, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, embedding_dim)

    def forward(self, x):
        """
        Args:
            x: integer category indices [batch, time] (torch.long), e.g. hour of day
        Returns:
            embedding vectors [batch, time, embedding_dim]
        """
        return self.embedding(x)
```
LSTM Encoder-Decoder
```python
class LSTMSequenceEncoder(nn.Module):
    """LSTM sequence encoder"""
    def __init__(self, input_dim, hidden_dim, num_layers=1, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            # nn.LSTM applies dropout only between stacked layers
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )

    def forward(self, x):
        """Return the hidden states for all time steps"""
        # The final (h_n, c_n) can initialise a decoder LSTM
        output, (h_n, c_n) = self.lstm(x)
        return output  # [batch, time, hidden_dim]
```
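Interpretable Multi-Head Attention
Standard multi-head attention gives each head its own value projection, which makes the per-head weights hard to compare. TFT shares one value projection across all heads and averages the heads, so the averaged attention matrix can be read directly:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class InterpretableMultiHeadAttention(nn.Module):
    """
    Interpretable multi-head self-attention
    Changes relative to standard multi-head attention:
    1. A single value projection is shared by all heads
    2. Head outputs are averaged rather than concatenated before the
       output projection
    """
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # Per-head query/key projections (packed into one Linear each)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        # Value projection shared by all heads
        self.W_v = nn.Linear(d_model, self.d_k)
        # Projects the head-averaged output back to d_model
        self.W_o = nn.Linear(self.d_k, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query, key, value: [batch, time, d_model]
            mask: optional mask broadcastable to [batch, n_heads, time, time]
                  (0 = blocked, 1 = allowed)
        Returns:
            output: [batch, time, d_model]
            attention_weights: [batch, n_heads, time, time]
        """
        batch_size = query.size(0)
        # Queries/keys: [batch, n_heads, time, d_k]
        Q = self.W_q(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # Shared values: [batch, 1, time, d_k], broadcast over heads
        V = self.W_v(value).unsqueeze(1)
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention_weights = F.softmax(scores, dim=-1)
        # Apply attention to the shared values: [batch, n_heads, time, d_k]
        heads = torch.matmul(self.dropout(attention_weights), V)
        # Average over heads instead of concatenating: [batch, time, d_k]
        context = heads.mean(dim=1)
        output = self.W_o(context)
        return output, attention_weights
```
III. Full Model Implementation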
3.1 Main TFT Model
```python
class TemporalFusionTransformer(nn.Module):
    """
    Temporal Fusion Transformer (simplified)
    Covers the core components: static context, variable selection, LSTM
    encoder-decoder, static enrichment, interpretable self-attention and
    quantile outputs. Simplifications vs. the paper: one static context vector
    (used for static enrichment), and residual + LayerNorm instead of
    GLU-gated skip connections around the LSTM and attention layers.
    """
    def __init__(self, n_static_features, n_known_features, n_observed_features,
                 d_model=128, n_heads=4, d_ffn=256, n_encoder_layers=2,
                 n_decoder_layers=2, dropout=0.1, n_quantiles=3):
        super().__init__()
        # The decoder LSTM starts from the encoder's final (h, c),
        # so both stacks need the same number of layers
        assert n_encoder_layers == n_decoder_layers
        self.d_model = d_model
        self.n_quantiles = n_quantiles
        # ========== Static covariates ==========
        self.static_context_grn = GatedResidualNetwork(
            n_static_features, d_model, d_model, dropout
        )
        # ========== Variable selection ==========
        # Encoder: known + observed inputs over the history window
        self.encoder_vsn = VariableSelectionNetwork(
            n_known_features + n_observed_features, d_model, dropout
        )
        # Decoder: only known inputs exist for future time steps
        self.decoder_vsn = VariableSelectionNetwork(
            n_known_features, d_model, dropout
        )
        # ========== Sequence encoding/decoding ==========
        lstm_dropout = dropout if n_encoder_layers > 1 else 0
        self.encoder_lstm = nn.LSTM(
            d_model, d_model, n_encoder_layers,
            dropout=lstm_dropout, batch_first=True
        )
        self.decoder_lstm = nn.LSTM(
            d_model, d_model, n_decoder_layers,
            dropout=lstm_dropout, batch_first=True
        )
        self.lstm_norm = nn.LayerNorm(d_model)
        # ========== Static enrichment ==========
        self.static_enrichment = GatedResidualNetwork(
            d_model, d_model, d_model, dropout, context_dim=d_model
        )
        # ========== Self-attention ==========
        self.self_attention = InterpretableMultiHeadAttention(
            d_model, n_heads, dropout
        )
        self.attention_norm = nn.LayerNorm(d_model)
        # ========== Position-wise feed-forward ==========
        self.ffn = GatedResidualNetwork(d_model, d_ffn, d_model, dropout)
        # ========== Output layer ==========
        # n_quantiles=3 gives [0.1, 0.5, 0.9], the quantiles used in the paper
        quantiles = torch.linspace(0.1, 0.9, n_quantiles)
        self.register_buffer('quantiles', quantiles)
        self.output_layer = nn.Linear(d_model, n_quantiles)

    def forward(self, static_input, known_input, observed_input):
        """
        Args:
            static_input: [batch, n_static_features]
            known_input: [batch, encoder_len + decoder_len, n_known_features]
            observed_input: [batch, encoder_len, n_observed_features]
        Returns:
            predictions: [batch, decoder_len, n_quantiles]
            attention_weights: [batch, n_heads, total_len, total_len]
        """
        # Window lengths are inferred from the inputs, so no masks are needed
        encoder_len = observed_input.size(1)
        total_len = known_input.size(1)
        # ========== Static context ==========
        static_context = self.static_context_grn(static_input)  # [batch, d_model]
        # ========== Encoder inputs: known + observed history ==========
        encoder_combined = torch.cat(
            [known_input[:, :encoder_len], observed_input], dim=-1
        )
        encoder_features, _ = self.encoder_vsn(encoder_combined)
        # ========== Decoder inputs: known future only ==========
        decoder_features, _ = self.decoder_vsn(known_input[:, encoder_len:])
        # ========== LSTM encoder-decoder ==========
        encoder_output, (h_n, c_n) = self.encoder_lstm(encoder_features)
        # h_n, c_n are [num_layers, batch, d_model], the shape nn.LSTM expects
        decoder_output, _ = self.decoder_lstm(decoder_features, (h_n, c_n))
        lstm_output = torch.cat([encoder_output, decoder_output], dim=1)
        lstm_input = torch.cat([encoder_features, decoder_features], dim=1)
        temporal = self.lstm_norm(lstm_output + lstm_input)  # residual + LayerNorm
        # ========== Static enrichment ==========
        enriched = self.static_enrichment(
            temporal, context=static_context.unsqueeze(1).expand(-1, total_len, -1)
        )
        # ========== Self-attention with causal mask ==========
        # Position t may attend only to positions <= t
        causal_mask = torch.tril(
            torch.ones(total_len, total_len, device=enriched.device)
        ).view(1, 1, total_len, total_len)
        attention_output, attention_weights = self.self_attention(
            enriched, enriched, enriched, mask=causal_mask
        )
        attention_output = self.attention_norm(attention_output + enriched)
        # ========== Outputs (forecast horizon only) ==========
        decoder_part = self.ffn(attention_output[:, encoder_len:])
        predictions = self.output_layer(decoder_part)
        return predictions, attention_weights
```
3.2 Data Preprocessing
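Each sample needs `encoder_len + decoder_len` consecutive steps, so a series of N steps yields `N - encoder_len - decoder_len + 1` overlapping windows. This is the same count as the loop in `prepare_tft_data`. When these windows are split into training and validation sets, splitting by time (rather than shuffling windows) prevents the same target values from appearing in both. The following pure-Python sketch shows such a split. The helper names, and the rule of dropping windows that straddle the cutoff, are assumptions for illustration.

```python
from typing import List, Tuple


def window_starts(n_steps: int, encoder_len: int, decoder_len: int) -> List[int]:
    """Start index of every full window (history + horizon) in a series of n_steps."""
    return list(range(n_steps - encoder_len - decoder_len + 1))


def split_by_time(starts: List[int], encoder_len: int, decoder_len: int,
                  cutoff: int) -> Tuple[List[int], List[int]]:
    """
    Training windows: the forecast horizon ends at or before `cutoff`.
    Validation windows: the forecast horizon starts at or after `cutoff`.
    Windows whose horizon straddles the cutoff are dropped.
    """
    train, val = [], []
    for s in starts:
        horizon_start = s + encoder_len
        horizon_end = horizon_start + decoder_len  # exclusive
        if horizon_end <= cutoff:
            train.append(s)
        elif horizon_start >= cutoff:
            val.append(s)
    return train, val


starts = window_starts(10, encoder_len=3, decoder_len=2)
train, val = split_by_time(starts, 3, 2, cutoff=7)
print(starts, train, val)
```

Validation windows may still use pre-cutoff values as encoder history. That is fine, because only their horizons need to lie after the cutoff.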
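```python
import numpy as np
from sklearn.preprocessing import StandardScaler


def prepare_tft_data(df, target_col, time_col,
                     static_cols, known_cols, observed_cols,
                     encoder_len, decoder_len):
    """
    Build sliding-window TFT samples from a single time series.
    observed_cols should normally include target_col so the encoder
    sees the target's history.
    Returns a list of (static, known, observed, target) float32 tuples,
    which torch.utils.data.DataLoader can batch directly, and the target scaler.
    """
    df = df.sort_values(time_col)
    # Extract features
    static = df[static_cols].values.astype(np.float32)
    known = df[known_cols].values.astype(np.float32)
    observed = df[observed_cols].values.astype(np.float32)
    target = df[target_col].values.reshape(-1, 1)
    # Standardise observed inputs and the target.
    # Fitting on the full series leaks validation statistics;
    # in practice, fit the scalers on the training period only.
    observed = StandardScaler().fit_transform(observed).astype(np.float32)
    scaler = StandardScaler()
    target_scaled = scaler.fit_transform(target).flatten().astype(np.float32)
    # Build sliding windows
    sequences = []
    for i in range(len(df) - encoder_len - decoder_len + 1):
        sequences.append((
            static[i],                                                     # [n_static]
            known[i:i + encoder_len + decoder_len],                        # [enc + dec, n_known]
            observed[i:i + encoder_len],                                   # [enc, n_observed]
            target_scaled[i + encoder_len:i + encoder_len + decoder_len],  # [dec]
        ))
    return sequences, scaler
```
3.3 Training and Evaluation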
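The training loop below minimises the quantile (pinball) loss implemented in `quantile_loss`. The asymmetry of that loss is easiest to see on single points. This pure-Python sketch mirrors the same formula; the function names `pinball` and `mean_quantile_loss` are hypothetical.

```python
from typing import Sequence


def pinball(y_true: float, y_pred: float, q: float) -> float:
    """Quantile loss for one point: max(q * e, (q - 1) * e) with e = y_true - y_pred."""
    e = y_true - y_pred
    return max(q * e, (q - 1) * e)


def mean_quantile_loss(y_true: Sequence[float],
                       y_pred: Sequence[Sequence[float]],
                       quantiles: Sequence[float]) -> float:
    """y_pred[t][i] is the forecast for quantile quantiles[i] at step t."""
    total = 0.0
    for y, preds in zip(y_true, y_pred):
        for p, q in zip(preds, quantiles):
            total += pinball(y, p, q)
    return total / (len(y_true) * len(quantiles))


# Under-predicting by 2 is cheap for q=0.1 but expensive for q=0.9
print(pinball(10.0, 8.0, 0.1), pinball(10.0, 8.0, 0.9))
# Over-predicting by 2 is the reverse
print(pinball(10.0, 12.0, 0.1), pinball(10.0, 12.0, 0.9))
```

This asymmetry pushes the q=0.9 output above most observations and the q=0.1 output below them, which yields a prediction interval.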
```python
import torch


def train_tft(model, train_loader, val_loader, epochs=50, lr=1e-3):
    """Train a TFT model with the quantile loss."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=5, factor=0.5
    )
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            static, known_in, observed_in, target = batch
            optimizer.zero_grad()
            predictions, _ = model(static, known_in, observed_in)
            # Quantile loss
            loss = quantile_loss(predictions, target, model.quantiles)
            loss.backward()
            # Gradient clipping keeps LSTM/attention training stable
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)
        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                static, known_in, observed_in, target = batch
                predictions, _ = model(static, known_in, observed_in)
                val_loss += quantile_loss(predictions, target, model.quantiles).item()
        val_loss /= len(val_loader)
        # Halve the learning rate when the mean validation loss plateaus
        scheduler.step(val_loss)
        if epoch % 5 == 0:
            print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, "
                  f"Val Loss = {val_loss:.4f}")
    return model


def quantile_loss(pred, target, quantiles):
    """
    Quantile (pinball) loss averaged over all quantiles.
    pred: [batch, decoder_len, n_quantiles], target: [batch, decoder_len]
    """
    losses = []
    for i, q in enumerate(quantiles):
        errors = target - pred[:, :, i]
        # q * e when under-predicting (e > 0), (q - 1) * e when over-predicting
        loss = torch.max((q - 1) * errors, q * errors)
        losses.append(loss.mean())
    return sum(losses) / len(quantiles)
```
IV. Interpretability
4.1 Variable Importance
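```python
import torch


def get_variable_importance(model, dataloader, feature_names):
    """
    Average the encoder variable-selection weights over a dataset.
    feature_names must follow the encoder input order: known_cols + observed_cols.
    """
    model.eval()
    # Collect weights from all batches
    all_weights = []
    with torch.no_grad():
        for batch in dataloader:
            static, known_in, observed_in, _ = batch
            encoder_len = observed_in.size(1)
            # The encoder VSN only sees the history part of the known inputs
            _, weights = model.encoder_vsn(
                torch.cat([known_in[:, :encoder_len], observed_in], dim=-1)
            )
            all_weights.append(weights.mean(dim=1))  # average over time -> [batch, n_vars]
    weights = torch.cat(all_weights, dim=0).mean(dim=0)
    # Sort by importance
    importance = dict(zip(feature_names, weights.cpu().numpy()))
    importance = dict(sorted(importance.items(), key=lambda x: -x[1]))
    return importance
```
4.2 Attention Visualization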
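The `visualize_attention` function below averages the weights over the batch and all heads. According to the TFT paper, two properties of its interpretable attention make that average meaningful:

- All heads share one value matrix, so averaging the head outputs is the same as applying the head-averaged attention matrix to the values once.
- Decoder masking stops each position from attending to later positions, so the upper triangle of the matrix is zero.

The following pure-Python sketch checks both properties on random scores. It is an illustration with hypothetical helper names, not the PyTorch module itself.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def causal_softmax(scores: Matrix) -> Matrix:
    """Row-wise softmax in which position t may attend only to positions <= t."""
    out = []
    for t, row in enumerate(scores):
        allowed = row[:t + 1]
        m = max(allowed)
        exps = [math.exp(s - m) for s in allowed]
        total = sum(exps)
        out.append([e / total for e in exps] + [0.0] * (len(row) - t - 1))
    return out


def average(mats: List[Matrix]) -> Matrix:
    n = len(mats)
    return [[sum(m[i][j] for m in mats) / n for j in range(len(mats[0][0]))]
            for i in range(len(mats[0]))]


random.seed(0)
T, d_v, n_heads = 4, 2, 3
# One masked attention matrix per head, and one value matrix V shared by all heads
heads = [causal_softmax([[random.gauss(0, 1) for _ in range(T)] for _ in range(T)])
         for _ in range(n_heads)]
V = [[random.gauss(0, 1) for _ in range(d_v)] for _ in range(T)]

A_mean = average(heads)                                  # the matrix that gets plotted
out_avg_heads = average([matmul(A, V) for A in heads])   # average of the head outputs
out_avg_attn = matmul(A_mean, V)                         # averaged attention applied once
for row in A_mean:
    print([round(x, 3) for x in row])
```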
```python
def visualize_attention(attention_weights, encoder_len, decoder_len,
                        time_index=None, save_path='attention.png'):
    """
    Plot attention weights averaged over the batch and all heads.
    attention_weights: [batch, n_heads, total_len, total_len]
    """
    import matplotlib.pyplot as plt
    # Average over batch and heads -> [total_len, total_len]
    attn = attention_weights.detach().mean(dim=(0, 1)).cpu().numpy()
    plt.figure(figsize=(12, 10))
    plt.imshow(attn, cmap='viridis', aspect='auto')
    plt.colorbar()
    # Encoder/decoder boundary: imshow cell edges sit at half-integer coordinates
    boundary = encoder_len - 0.5
    plt.axhline(y=boundary, color='white', linestyle='--', linewidth=2)
    plt.axvline(x=boundary, color='white', linestyle='--', linewidth=2)
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    plt.title('Attention Weights Heatmap')
    if time_index is not None:
        # One tick every 24 steps (one day for hourly data)
        ticks = list(range(0, len(time_index), 24))
        plt.xticks(ticks, [time_index[i] for i in ticks], rotation=45)
        plt.yticks(ticks, [time_index[i] for i in ticks])
    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()
```
V. Comparison with Other Models
| Feature | TFT | N-BEATS | N-HiTS | LSTM |
|---|---|---|---|---|
| Interpretability | ✓✓✓ | ✓✓ | ✓ | ✓ |
| Multivariate input support | ✓✓✓ | ✓ | ✓ | ✓✓ |
| Static features | ✓✓✓ | ✗ | ✗ | ✗ |
| Long-sequence handling | ✓✓ | ✓ | ✓✓ | ✓ |
| Computational efficiency | ✓ | ✓✓ | ✓✓ | ✓✓✓ |
| Quantile forecasting | ✓✓✓ | ✓ | ✓ | ✓ |
VI. References
[^1]: Lim, B., Arık, S. Ö., Loeff, N., & Pfister, T. (2021). Temporal Fusion Transformers for interpretable multi-horizon time series forecasting. *International Journal of Forecasting*, 37(4), 1748-1764.