可扩展多智能体RL:序列模型与新基准
1. 引言
随着多智能体系统规模的增长,传统方法面临严峻的可扩展性挑战:
| 挑战 | 传统方法 | Transformer方法 |
|---|---|---|
| 参数复杂度 | $O(N^2)$ 通信 | $O(N^2)$ 但高效并行 |
| 泛化能力 | 差 | 好(注意力机制) |
| 计算效率 | 难以并行 | 高度可并行 |
| 多任务适应 | 需要重新训练 | 上下文适应 |
核心洞察:将多智能体决策建模为序列生成问题,利用Transformer的强大序列建模能力。
2. Oryx:可扩展序列模型
2.1 核心架构
Oryx[1]将多智能体系统建模为统一的序列生成过程:
class OryxArchitecture(nn.Module):
    """Transformer that treats the joint decision of N agents as one sequence.

    The token sequence is ``[history..., state, obs_1, ..., obs_N]`` and is
    processed under a causal mask. Each agent's action is decoded from the
    encoder output at that agent's observation token; the value estimate is
    decoded from the state token.

    Args:
        n_agents: Number of agents (one ID embedding and one action head each).
        state_dim: Size of the global state vector.
        obs_dim: Size of each agent's observation vector.
        action_dim: Size of each agent's action vector.
        hidden_dim: Transformer model width.
        n_layers: Number of encoder layers.
        n_heads: Attention heads per encoder layer.
        dropout: Dropout probability inside the encoder layers.
        max_timesteps: Number of distinct timestep indices the time embedding
            table can represent.
    """

    def __init__(
        self,
        n_agents: int,
        state_dim: int,
        obs_dim: int,
        action_dim: int,
        hidden_dim: int = 256,
        n_layers: int = 6,
        n_heads: int = 8,
        dropout: float = 0.1,
        max_timesteps: int = 1000,
    ):
        super().__init__()
        self.n_agents = n_agents
        self.hidden_dim = hidden_dim
        # Kept as an attribute: the original forward() read self.action_dim
        # without it ever being assigned.
        self.action_dim = action_dim

        # Input embeddings.
        self.state_embed = nn.Linear(state_dim, hidden_dim)
        self.obs_embed = nn.Linear(obs_dim, hidden_dim)
        self.action_embed = nn.Linear(action_dim, hidden_dim)

        # Timestep embedding. nn.Embedding takes ``num_embeddings``; the
        # original passed a non-existent ``max_len`` keyword.
        self.time_embed = nn.Embedding(
            num_embeddings=max_timesteps, embedding_dim=hidden_dim
        )

        # Per-agent identity embedding.
        self.agent_embed = nn.Embedding(n_agents, hidden_dim)

        # Transformer encoder.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True,  # Pre-LN for stability
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # One action head per agent.
        self.action_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, action_dim),
            )
            for _ in range(n_agents)
        ])

        # Value head (used for RL training).
        self.value_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(
        self,
        state: torch.Tensor,  # [B, state_dim]
        observations: torch.Tensor,  # [B, N, obs_dim]
        history_actions: torch.Tensor = None,  # [B, T, N, action_dim]
        timestep: torch.Tensor = None,  # [B]
    ) -> Dict[str, torch.Tensor]:
        """Encode the joint sequence and decode per-agent actions and a value.

        Args:
            state: Global state, ``[B, state_dim]``.
            observations: Per-agent observations, ``[B, N, obs_dim]``; N must
                equal ``n_agents``.
            history_actions: Optional past joint actions,
                ``[B, T, N, action_dim]``. Ignored when ``None`` or when T is 0.
            timestep: Optional integer timestep indices, ``[B]``, added to the
                state token through the time embedding.

        Returns:
            Dict with ``'actions'`` ``[B, N, action_dim]``, ``'value'``
            ``[B, 1]`` and ``'representations'`` ``[B, N, hidden_dim]``.

        Raises:
            ValueError: If the number of agents in ``observations`` differs
                from ``n_agents``.
        """
        B, N = observations.shape[:2]
        if N != self.n_agents:
            # The agent embedding table and the list of action heads are both
            # sized by n_agents, so any other N cannot be decoded.
            raise ValueError(
                f"expected observations for {self.n_agents} agents, got {N}"
            )

        # State token, optionally tagged with the timestep.
        state_tok = self.state_embed(state).unsqueeze(1)  # [B, 1, H]
        if timestep is not None:
            state_tok = state_tok + self.time_embed(timestep).unsqueeze(1)

        # Observation tokens tagged with the agent identity.
        agent_ids = torch.arange(N, device=observations.device)
        obs_tok = self.obs_embed(observations) + self.agent_embed(agent_ids).unsqueeze(0)

        # History goes FIRST. With a causal mask a token only attends to
        # itself and earlier positions, so history appended after the agent
        # tokens (as in the original ordering) could never influence the
        # state or agent outputs.
        tokens = [state_tok, obs_tok]
        n_hist = 0
        if history_actions is not None and history_actions.size(1) > 0:
            hist_tok = self.action_embed(history_actions)  # [B, T, N, H]
            hist_tok = hist_tok.reshape(B, -1, self.hidden_dim)  # [B, T*N, H]
            n_hist = hist_tok.size(1)
            tokens.insert(0, hist_tok)

        seq = torch.cat(tokens, dim=1)  # [B, T*N + 1 + N, H]

        # Causal mask: True marks positions that may NOT be attended to.
        seq_len = seq.size(1)
        attn_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq.device),
            diagonal=1,
        )

        encoded = self.encoder(seq, mask=attn_mask)  # [B, seq_len, H]

        # The state token sits right after the history block, followed by the
        # N agent tokens.
        state_pos = n_hist
        agent_repr = encoded[:, state_pos + 1:state_pos + 1 + N]  # [B, N, H]

        # Stack the per-agent head outputs instead of writing into a
        # pre-allocated buffer.
        actions = torch.stack(
            [head(agent_repr[:, i]) for i, head in enumerate(self.action_heads)],
            dim=1,
        )  # [B, N, action_dim]

        value = self.value_head(encoded[:, state_pos])  # [B, 1]

        return {
            'actions': actions,
            'value': value,
            'representations': agent_repr,
        }

# --- Section heading: 2.2 Offline MARL设置 ---
Oryx采用离线强化学习设置,从预先收集的数据集学习:
class OryxOfflineMARL:
    """Offline MARL trainer for Oryx: behaviour cloning plus a TD value loss.

    Holds an online model and a deep-copied target model that is updated by
    Polyak averaging in ``soft_update``.
    """

    def __init__(self, config: "OryxConfig"):
        """Build the online and target models from ``config``.

        Args:
            config: Object whose ``__dict__`` is expanded as keyword
                arguments to ``OryxArchitecture``.
        """
        self.model = OryxArchitecture(**config.__dict__)
        self.target_model = copy.deepcopy(self.model)
        self.gamma = 0.99
        self.tau = 0.005  # soft-update coefficient

    def compute_loss(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict]:
        """Compute the combined behaviour-cloning and value loss.

        Args:
            batch: Dict with keys ``'state'``, ``'observations'``,
                ``'actions'``, ``'rewards'``, ``'next_observations'``,
                ``'dones'`` and optionally ``'next_state'``.

        Returns:
            ``(total_loss, metrics)`` where ``total_loss`` is
            ``bc_loss + 0.5 * value_loss`` and ``metrics`` holds the three
            losses as Python floats.
        """
        state = batch['state']
        observations = batch['observations']
        actions = batch['actions']
        next_observations = batch['next_observations']
        # The bootstrap value must be evaluated at the NEXT state. Fall back
        # to the current state only when the dataset does not carry one
        # (this fallback is what the original always did).
        next_state = batch.get('next_state', state)

        # Current policy output.
        outputs = self.model(state, observations)
        value_pred = outputs['value']

        # Behaviour-cloning loss.
        bc_loss = F.mse_loss(outputs['actions'], actions)

        # TD target from the target network.
        with torch.no_grad():
            next_values = self.target_model(next_state, next_observations)['value']
            # Align rewards/dones with the value shape. Without this a [B]
            # reward against a [B, 1] value silently broadcasts to [B, B].
            # The cast also lets boolean `dones` be used in `1 - dones`.
            rewards = batch['rewards'].reshape(next_values.shape).to(next_values.dtype)
            dones = batch['dones'].reshape(next_values.shape).to(next_values.dtype)
            td_target = rewards + self.gamma * (1 - dones) * next_values

        # Value loss.
        value_loss = F.mse_loss(value_pred, td_target)

        # Weighted combination.
        total_loss = bc_loss + 0.5 * value_loss

        metrics = {
            'bc_loss': bc_loss.item(),
            'value_loss': value_loss.item(),
            'total_loss': total_loss.item()
        }
        return total_loss, metrics

    def soft_update(self):
        """Polyak-average the online parameters into the target network.

        Each target parameter becomes
        ``tau * online + (1 - tau) * target``.
        """
        with torch.no_grad():
            for target_param, param in zip(
                self.target_model.parameters(),
                self.model.parameters()
            ):
                # lerp_(end, w) computes target + w * (end - target) in place.
                target_param.lerp_(param, self.tau)

# --- Section heading: 2.3 Many-Agent协调 ---
Oryx通过分组注意力支持大规模智能体:
class GroupedAttention(nn.Module):
    """Two-level attention: within each agent group, then between groups.

    Agents are split into ``num_groups`` groups of ``group_size``. Attention
    is applied inside every group, the group mean is attended across groups,
    and the two results are concatenated per agent and projected back to
    ``hidden_dim``.

    Args:
        hidden_dim: Feature width; must be divisible by the 4 attention heads.
        num_groups: Number of agent groups.
        group_size: Number of agents per group.
    """

    def __init__(self, hidden_dim: int, num_groups: int, group_size: int):
        super().__init__()
        self.num_groups = num_groups
        self.group_size = group_size
        self.hidden_dim = hidden_dim

        # Attention among the agents of one group.
        self.intra_group_attn = nn.MultiheadAttention(
            hidden_dim, num_heads=4, batch_first=True
        )
        # Attention among the group summaries.
        self.inter_group_attn = nn.MultiheadAttention(
            hidden_dim, num_heads=4, batch_first=True
        )
        # Fuses the concatenated intra/inter features, so the first layer
        # takes 2 * hidden_dim inputs (the original declared hidden_dim and
        # would fail on the concatenated tensor).
        self.group_aggregator = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply grouped attention.

        Args:
            x: ``[B, num_groups * group_size, H]``.

        Returns:
            Tensor of shape ``[B, num_groups * group_size, H]``.

        Raises:
            ValueError: If the agent dimension is not
                ``num_groups * group_size``.
        """
        B = x.size(0)
        G, S, H = self.num_groups, self.group_size, self.hidden_dim
        if x.size(1) != G * S:
            raise ValueError(
                f"expected {G * S} agents ({G} groups x {S}), got {x.size(1)}"
            )

        # nn.MultiheadAttention only accepts 2-D/3-D inputs, so fold the group
        # axis into the batch axis for the intra-group pass.
        grouped = x.reshape(B * G, S, H)
        intra_group, _ = self.intra_group_attn(grouped, grouped, grouped)
        intra_group = intra_group.reshape(B, G, S, H)

        # One summary vector per group.
        group_repr = intra_group.mean(dim=2)  # [B, G, H]

        # Attention between groups.
        inter_group, _ = self.inter_group_attn(group_repr, group_repr, group_repr)

        # Broadcast each group's result back to its agents.
        inter_group_expanded = inter_group.unsqueeze(2).expand(-1, -1, S, -1)

        # Combine intra- and inter-group information.
        combined = torch.cat([intra_group, inter_group_expanded], dim=-1)  # [B, G, S, 2H]
        combined = self.group_aggregator(combined)  # [B, G, S, H]
        return combined.reshape(B, G * S, H)

# --- Section heading: 3. STAIRS-Former ---
3.1 时空注意力设计
STAIRS-Former[2]专门为多任务MARL设计,强调时空注意力:
class STAIRSFormer(nn.Module):
    """STAIRS-Former: spatio-temporal attention Transformer for multi-agent RL.

    Observations are encoded, passed through a stack of spatial (across
    agents) attention layers, then a stack of temporal (across time steps)
    attention layers, fused, optionally shifted by a task encoding, and
    decoded into per-agent actions and values.

    Args:
        n_agents: Number of agents.
        obs_dim: Size of each agent's observation vector.
        action_dim: Size of the action output per agent.
        hidden_dim: Feature width used throughout the model.
        n_heads: Attention heads in every attention layer.
        spatial_layers: Number of spatial attention layers.
        temporal_layers: Number of temporal attention layers.
        task_dim: Width of the task embedding passed to ``forward``.
            Defaults to ``hidden_dim`` when ``None``.
    """

    def __init__(
        self,
        n_agents: int,
        obs_dim: int,
        action_dim: int,
        hidden_dim: int = 256,
        n_heads: int = 8,
        spatial_layers: int = 4,
        temporal_layers: int = 2,
        task_dim: int = None,
    ):
        super().__init__()
        self.n_agents = n_agents
        self.hidden_dim = hidden_dim
        # The original referenced an undefined name `task_dim`; it is now an
        # optional argument that falls back to the model width.
        if task_dim is None:
            task_dim = hidden_dim
        self.task_dim = task_dim

        # Observation encoder.
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU()
        )

        # Spatial attention stack.
        self.spatial_attention = nn.ModuleList([
            SpatialAttentionLayer(hidden_dim, n_heads)
            for _ in range(spatial_layers)
        ])

        # Temporal attention stack.
        self.temporal_attention = nn.ModuleList([
            TemporalAttentionLayer(hidden_dim, n_heads)
            for _ in range(temporal_layers)
        ])

        # Spatio-temporal fusion.
        self.spatiotemporal_fusion = SpatioTemporalFusion(hidden_dim)

        # Task encoder.
        self.task_encoder = nn.Sequential(
            nn.Linear(task_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Output heads.
        self.action_head = nn.Linear(hidden_dim, action_dim)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(
        self,
        observations: torch.Tensor,  # [B, T, N, obs_dim]
        task_embedding: torch.Tensor = None
    ) -> Dict[str, torch.Tensor]:
        """Run the model on a batch of observation sequences.

        Args:
            observations: ``[B, T, N, obs_dim]`` (batch, time steps, agents).
            task_embedding: Optional task vector, expected ``[B, task_dim]``;
                its encoding is broadcast over the T and N axes.

        Returns:
            Dict with ``'actions'`` ``[B, T, N, action_dim]``, ``'values'``
            ``[B, T, N, 1]`` and ``'representations'`` (the fused features;
            assumed ``[B, T, N, hidden_dim]`` — depends on
            ``SpatioTemporalFusion``, TODO confirm).
        """
        # Encode observations.
        h = self.obs_encoder(observations)  # [B, T, N, H]

        # Spatial attention: relations between agents.
        for spatial_layer in self.spatial_attention:
            h = spatial_layer(h)  # [B, T, N, H]

        # Temporal attention: dependencies across time steps.
        for temporal_layer in self.temporal_attention:
            h = temporal_layer(h)  # [B, T, N, H]

        # Spatio-temporal fusion.
        h = self.spatiotemporal_fusion(h)

        # Add the task encoding, broadcast over time and agents.
        if task_embedding is not None:
            task_enc = self.task_encoder(task_embedding)
            h = h + task_enc.unsqueeze(1).unsqueeze(1)

        # Decode actions and values.
        actions = self.action_head(h)  # [B, T, N, action_dim]
        values = self.value_head(h)  # [B, T, N, 1]

        return {
            'actions': actions,
            'values': values,
            'representations': h
        }
class SpatialAttentionLayer(nn.Module):
    """Multi-head self-attention across the agent axis with a learned bias.

    Args:
        hidden_dim: Feature width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        max_agents: Largest number of agents the learned pairwise bias table
            supports. The table is sliced to the actual agent count on each
            call, so one layer serves any N up to this limit.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, hidden_dim: int, n_heads: int, max_agents: int = 64):
        super().__init__()
        if hidden_dim % n_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.max_agents = max_agents

        self.qkv = nn.Linear(hidden_dim, hidden_dim * 3)
        self.proj = nn.Linear(hidden_dim, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)

        # Learnable per-head bias for every (agent, agent) pair. The original
        # sized this with an undefined name `N`; it is now a fixed-capacity
        # table indexed by agent position.
        self.relative_bias = nn.Parameter(
            torch.zeros(n_heads, max_agents, max_agents)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend across agents.

        Args:
            x: ``[B, T, N, H]`` or ``[B, N, H]``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If N exceeds ``max_agents``.
        """
        has_time_axis = x.dim() == 4
        if has_time_axis:
            B, T, N, H = x.shape
            # Fold time into batch so attention runs per time step.
            x = x.reshape(B * T, N, H)
        else:
            B, N, H = x.shape
            T = 1

        if N > self.max_agents:
            raise ValueError(
                f"got {N} agents but the bias table supports at most {self.max_agents}"
            )

        head_dim = H // self.n_heads

        # QKV projection, split into heads: [B*T, heads, N, head_dim].
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.reshape(B * T, N, self.n_heads, head_dim).transpose(1, 2)
        k = k.reshape(B * T, N, self.n_heads, head_dim).transpose(1, 2)
        v = v.reshape(B * T, N, self.n_heads, head_dim).transpose(1, 2)

        # Scaled dot-product scores plus the learned pairwise bias.
        attn = (q @ k.transpose(-2, -1)) / head_dim ** 0.5
        attn = attn + self.relative_bias[:, :N, :N].unsqueeze(0)
        attn = F.softmax(attn, dim=-1)

        # Weighted sum of values, heads merged back into the feature axis.
        out = (attn @ v).transpose(1, 2).reshape(B * T, N, H)

        # Output projection, residual connection, post-norm.
        out = self.norm(x + self.proj(out))

        # Restore the time axis if the input had one.
        if has_time_axis:
            out = out.view(B, T, N, H)
        return out
class TemporalAttentionLayer(nn.Module):
    """Self-attention along the time axis, applied to each agent separately.

    Args:
        hidden_dim: Feature width.
        n_heads: Number of attention heads.
    """

    def __init__(self, hidden_dim: int, n_heads: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.self_attn = nn.MultiheadAttention(
            hidden_dim, n_heads, batch_first=True
        )
        # A single LayerNorm is shared by both residual branches below.
        self.norm = nn.LayerNorm(hidden_dim)
        expanded = hidden_dim * 4
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, expanded),
            nn.GELU(),
            nn.Linear(expanded, hidden_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over time steps.

        Args:
            x: ``[B, T, N, H]``.

        Returns:
            Contiguous tensor of shape ``[B, T, N, H]``.
        """
        batch, steps, agents, width = x.shape

        # Put agents next to batch and fold them together: [B*N, T, H].
        seq = x.transpose(1, 2).reshape(batch * agents, steps, width)

        # Temporal self-attention with residual + norm.
        attended = self.self_attn(seq, seq, seq)[0]
        seq = self.norm(seq + attended)

        # Feed-forward with residual + the same norm.
        seq = self.norm(seq + self.ffn(seq))

        # Unfold back to [B, N, T, H] and restore the [B, T, N, H] layout.
        restored = seq.reshape(batch, agents, steps, width).transpose(1, 2)
        return restored.contiguous()

# --- Section heading: 3.2 多任务MARL支持 ---
class MultiTaskMARL:
    """Multi-task wrapper around a base model with learned task embeddings.

    Note: this class is a plain object, not an ``nn.Module``. The parameters
    of ``task_embeddings`` and ``task_classifier`` are therefore not reached
    through ``base_model.parameters()`` and must be handed to the optimizer
    separately.

    Args:
        base_model: Model called as ``base_model(observations, task_emb)``
            that returns a dict with ``'actions'`` and ``'values'``.
        n_tasks: Number of distinct tasks.
    """

    def __init__(self, base_model: "STAIRSFormer", n_tasks: int):
        self.model = base_model
        self.n_tasks = n_tasks
        # Size the task embedding from the model instead of hard-coding 256,
        # so a base model built with another hidden_dim still lines up.
        # NOTE(review): assumes the base model's task encoder expects
        # hidden_dim-wide task vectors — confirm.
        embed_dim = getattr(base_model, 'hidden_dim', 256)
        # Task embeddings.
        self.task_embeddings = nn.Embedding(n_tasks, embed_dim)
        # Task classifier (auxiliary objective).
        self.task_classifier = nn.Linear(embed_dim, n_tasks)
        # Task-specific action-space sizes.
        self.task_action_dims = {}

    def forward_with_task(
        self,
        observations: torch.Tensor,
        task_id: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Run the base model conditioned on ``task_id``.

        Args:
            observations: Passed through to the base model unchanged.
            task_id: Integer task indices, one per batch element.

        Returns:
            The base model's output dict with an added ``'task_logits'``
            entry computed by the task classifier from the task embedding.
        """
        task_emb = self.task_embeddings(task_id)
        outputs = self.model(observations, task_emb)
        outputs['task_logits'] = self.task_classifier(task_emb)
        return outputs

    def compute_multitask_loss(
        self,
        batch: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute action, value and auxiliary task-classification losses.

        Args:
            batch: Dict with keys ``'task_id'``, ``'observations'``,
                ``'actions'`` (integer class indices) and ``'rewards'``.

        Returns:
            ``(total_loss, metrics)`` where ``total_loss`` is
            ``action_loss + 0.5 * value_loss + 0.1 * task_loss``.
        """
        task_id = batch['task_id']
        actions = batch['actions']
        rewards = batch['rewards']

        # One shared forward path (previously duplicated inline here).
        outputs = self.forward_with_task(batch['observations'], task_id)
        pred_actions = outputs['actions']

        # Action loss. reshape (not view) so non-contiguous targets work.
        action_loss = F.cross_entropy(
            pred_actions.reshape(-1, pred_actions.size(-1)),
            actions.reshape(-1),
            reduction='mean'
        )

        # Value loss.
        # NOTE(review): rewards are used directly as the regression target and
        # are assumed to match outputs['values'] minus its trailing unit axis
        # — confirm against the data pipeline.
        value_loss = F.mse_loss(outputs['values'], rewards.unsqueeze(-1))

        # Auxiliary task-classification loss.
        task_loss = F.cross_entropy(outputs['task_logits'], task_id)

        # Total loss.
        total_loss = action_loss + 0.5 * value_loss + 0.1 * task_loss

        metrics = {
            'action_loss': action_loss.item(),
            'value_loss': value_loss.item(),
            'task_loss': task_loss.item(),
            'total_loss': total_loss.item()
        }
        return total_loss, metrics

# --- Section heading: 4. 其他序列模型方法 ---
4.1 MAST:Multi-Agent Spatial Transformer
MAST[3]专注于空间结构的建模:
class MAST(nn.Module):
    """MAST: Multi-Agent Spatial Transformer.

    Adds a learned encoding looked up from each agent's 2-D grid position to
    the agent features, applies spatially-aware attention, and decodes a
    policy output per agent.

    Args:
        spatial_dim: Side length of the square position grid.
        hidden_dim: Feature width.
        action_dim: Size of the policy output per agent. Required; the
            original body referenced it without receiving it.
        obs_dim: Size of the raw observation vector. When given, observations
            are linearly projected to ``hidden_dim``; when ``None`` they are
            used as-is and must already be ``hidden_dim`` wide.

    Raises:
        ValueError: If ``action_dim`` is not provided.
    """

    def __init__(
        self,
        spatial_dim: int,
        hidden_dim: int,
        action_dim: int = None,
        obs_dim: int = None,
    ):
        super().__init__()
        if action_dim is None:
            raise ValueError("MAST requires action_dim to size the policy head")

        # Learned encoding for every cell of the spatial grid.
        self.spatial_encoding = nn.Parameter(
            torch.randn(spatial_dim, spatial_dim, hidden_dim)
        )
        # Projects observations to the hidden width so they can be summed
        # with the spatial encoding.
        self.obs_encoder = (
            nn.Identity() if obs_dim is None else nn.Linear(obs_dim, hidden_dim)
        )
        # Spatially-aware attention.
        self.spatial_attention = SpatialAwareAttention(hidden_dim)
        # Policy head.
        self.policy_head = nn.Linear(hidden_dim, action_dim)

    def get_spatial_encoding(self, positions: torch.Tensor) -> torch.Tensor:
        """Look up the learned encoding of each agent's grid cell.

        Args:
            positions: ``[B, N, 2]`` positions.
                NOTE(review): assumed to be grid coordinates; values are
                truncated to integers and clamped into the grid — confirm
                that positions are not continuous/normalised coordinates.

        Returns:
            Tensor of shape ``[B, N, hidden_dim]``.
        """
        last_cell = self.spatial_encoding.size(0) - 1
        cells = positions.long().clamp(min=0, max=last_cell)
        return self.spatial_encoding[cells[..., 0], cells[..., 1]]

    def forward(self, observations: torch.Tensor, positions: torch.Tensor):
        """Compute the per-agent policy output.

        Args:
            observations: ``[B, N, obs_dim]``.
            positions: ``[B, N, 2]`` 2-D positions.

        Returns:
            The policy head output; ``[B, N, action_dim]`` provided
            ``SpatialAwareAttention`` preserves the ``[B, N, H]`` shape
            (TODO confirm).
        """
        # Spatial encoding for every agent.
        spatial_enc = self.get_spatial_encoding(positions)  # [B, N, H]

        # Spatially-aware processing.
        h = self.obs_encoder(observations) + spatial_enc
        h = self.spatial_attention(h, positions)

        # Policy output.
        policy = self.policy_head(h)
        return policy

# --- Section heading: 4.2 MATWM:Multi-Agent Transformer World Model ---
MATWM将世界模型引入多智能体学习:
class MATWM(nn.Module):
    """MATWM: Multi-Agent Transformer World Model.

    Encodes observations, samples a latent from a learned Gaussian prior,
    runs it through a Transformer dynamics model and decodes a predicted
    state and reward.

    Args:
        n_agents: Number of agents (stored; not used by the layers below).
        state_dim: Size of the decoded state vector.
        hidden_dim: Feature width; must be divisible by the 8 attention heads.
        obs_dim: Size of the observation vector. Defaults to ``state_dim``
            when ``None`` (the original referenced an undefined ``obs_dim``).
    """

    def __init__(
        self,
        n_agents: int,
        state_dim: int,
        hidden_dim: int,
        obs_dim: int = None,
    ):
        super().__init__()
        self.n_agents = n_agents
        if obs_dim is None:
            obs_dim = state_dim

        # Observation encoder.
        self.obs_encoder = nn.Linear(obs_dim, hidden_dim)
        # State decoder.
        self.state_decoder = nn.Linear(hidden_dim, state_dim)
        # Reward predictor.
        self.reward_predictor = nn.Linear(hidden_dim, 1)
        # Variational inference heads (each emits mean and log-variance).
        self.prior_network = nn.Linear(hidden_dim, hidden_dim * 2)
        # Not used by forward(); kept so the module's parameters are unchanged.
        self.posterior_network = nn.Linear(hidden_dim * 2, hidden_dim * 2)
        # Transformer dynamics model. The keyword is `nhead`; the original
        # `n_heads=8` raises TypeError.
        self.dynamics_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(hidden_dim, nhead=8, batch_first=True),
            num_layers=6
        )

    def forward(self, observations: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Predict the next state and reward from observations.

        Args:
            observations: Tensor whose last axis is ``obs_dim``; assumed
                ``[B, L, obs_dim]`` for the batch-first Transformer
                (TODO confirm what L indexes — agents or time).

        Returns:
            Dict with ``'pred_state'``, ``'pred_reward'``, the sampled latent
            ``'z'`` and the prior parameters ``'prior_mean'`` /
            ``'prior_logvar'``.
        """
        # Encode observations.
        h = self.obs_encoder(observations)

        # Prior parameters.
        prior_mean, prior_logvar = self.prior_network(h).chunk(2, dim=-1)

        # Sample the latent.
        z = self.reparameterize(prior_mean, prior_logvar)

        # Predict the next step.
        pred_next = self.dynamics_transformer(z)

        # Decode state and reward.
        return {
            'pred_state': self.state_decoder(pred_next),
            'pred_reward': self.reward_predictor(pred_next),
            'z': z,
            'prior_mean': prior_mean,
            'prior_logvar': prior_logvar
        }

    def reparameterize(self, mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``mean + eps * std`` with ``eps ~ N(0, I)``.

        Args:
            mean: Mean of the Gaussian.
            logvar: Log-variance of the Gaussian, same shape as ``mean``.

        Returns:
            A sample with the same shape as ``mean``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

# --- Section heading: 5. 新基准测试 ---
5.1 Craftax:开放多智能体基准
Craftax是一个开放的、大规模的多智能体基准测试:
class CraftaxBenchmark:
    """Craftax: open-ended multi-agent benchmark wrapper.

    Creates a Craftax environment for a given agent count and task type and
    evaluates an agent on it.

    Args:
        n_agents: Number of agents in the environment.
        task_type: Task name passed to the environment factory.
    """

    ENV_CONFIG = {
        'grid_size': 64,
        'n_agents': [2, 4, 8, 16],  # supported scales
        'n_entity_types': 20,
        'max_steps': 1000,
        'observation_radius': 5,
        'task_types': [
            'exploration',
            'resource_gathering',
            'combat',
            'construction',
            'collaborative_puzzle'
        ]
    }

    def __init__(self, n_agents: int, task_type: str):
        self.n_agents = n_agents
        self.task_type = task_type
        self.env = self.create_env()

    def create_env(self):
        """Create the Craftax environment.

        NOTE(review): the whole ENV_CONFIG dict is forwarded as keyword
        arguments, including the list-valued 'n_agents' and 'task_types'
        entries, next to the explicit num_agents/task arguments — confirm the
        factory accepts these keywords.
        """
        import craftax
        return craftax.make(
            'Craftax-v0',
            num_agents=self.n_agents,
            task=self.task_type,
            **self.ENV_CONFIG
        )

    def evaluate(self, agent: nn.Module, n_episodes: int = 100) -> Dict:
        """Run ``n_episodes`` greedy episodes and summarise the returns.

        The agent is called on a ``[1, N, obs_dim]`` float tensor built from
        ``obs['observations']``. Its output may be either a score tensor or a
        dict containing one under ``'actions'`` (the form the models in this
        file return); the arg-max over the last axis is taken as each agent's
        action.

        Args:
            agent: Policy module; switched to eval mode for the duration of
                the evaluation and restored afterwards.
            n_episodes: Number of episodes to run.

        Returns:
            Dict with ``'mean_return'``, ``'std_return'``,
            ``'success_rate'``, ``'n_agents'`` and ``'task_type'``.
        """
        # Run the observations on whatever device the agent lives on.
        first_param = next(agent.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

        episode_returns = []
        success_rates = []

        was_training = agent.training
        agent.eval()  # no dropout noise while evaluating
        try:
            for _ in range(n_episodes):
                obs = self.env.reset()
                done = False
                episode_return = 0
                info = {}
                while not done:
                    # Batch the observations of all agents.
                    obs_batch = torch.as_tensor(
                        np.asarray(obs['observations']),
                        dtype=torch.float32,
                        device=device,
                    ).unsqueeze(0)

                    with torch.no_grad():
                        output = agent(obs_batch)
                    if isinstance(output, dict):
                        output = output['actions']

                    # Greedy action per agent; tolist() also copies off-GPU.
                    action_list = output.squeeze(0).argmax(dim=-1).tolist()
                    obs, reward, done, info = self.env.step(action_list)
                    # Works for a scalar team reward and sums per-agent
                    # rewards into one team return.
                    episode_return += float(np.sum(reward))

                episode_returns.append(episode_return)
                success_rates.append(info.get('success', 0))
        finally:
            agent.train(was_training)

        return {
            'mean_return': np.mean(episode_returns),
            'std_return': np.std(episode_returns),
            'success_rate': np.mean(success_rates),
            'n_agents': self.n_agents,
            'task_type': self.task_type
        }
class ScalabilityBenchmark:
    """Scalability benchmark: evaluates one model at several agent counts.

    Results are stored in ``self.results``, keyed by the integer agent count.
    """

    DEFAULT_AGENT_COUNTS = (2, 4, 8, 16, 32, 64)

    def __init__(self):
        self.results = {}

    def get_memory_usage(self) -> int:
        """Return PyTorch's peak allocated CUDA memory in bytes.

        Returns 0 when CUDA is not available.
        """
        if torch.cuda.is_available():
            return int(torch.cuda.max_memory_allocated())
        return 0

    def run_scalability_test(
        self,
        model: nn.Module,
        agent_counts: List[int] = None,
        n_episodes: int = 10,
    ):
        """Evaluate ``model`` on the 'exploration' task at each agent count.

        Args:
            model: Agent passed to ``CraftaxBenchmark.evaluate``.
            agent_counts: Agent counts to test; defaults to
                ``DEFAULT_AGENT_COUNTS`` when ``None``. An empty list runs
                nothing.
            n_episodes: Episodes per agent count, also used to compute the
                per-episode time.

        Returns:
            ``self.results``: for each agent count a dict with
            ``'performance'``, ``'time_per_episode'`` and ``'memory_usage'``.
        """
        # None sentinel instead of a mutable list default.
        if agent_counts is None:
            agent_counts = self.DEFAULT_AGENT_COUNTS

        for n_agents in agent_counts:
            benchmark = CraftaxBenchmark(n_agents, 'exploration')

            if torch.cuda.is_available():
                # Make the peak reading specific to this agent count.
                torch.cuda.reset_peak_memory_stats()

            # Timed evaluation.
            start_time = time.time()
            result = benchmark.evaluate(model, n_episodes=n_episodes)
            elapsed = time.time() - start_time

            self.results[n_agents] = {
                'performance': result['mean_return'],
                'time_per_episode': elapsed / n_episodes,
                'memory_usage': self.get_memory_usage()
            }

            # Free memory before the next (larger) run.
            del benchmark
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        return self.results

    def plot_scalability(self):
        """Plot performance and time per episode against the agent count.

        Returns:
            The matplotlib figure containing both plots.
        """
        import matplotlib.pyplot as plt

        agents = list(self.results.keys())
        performances = [self.results[a]['performance'] for a in agents]
        times = [self.results[a]['time_per_episode'] for a in agents]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        ax1.plot(agents, performances)
        ax1.set_xlabel('Number of Agents')
        ax1.set_ylabel('Performance')
        ax1.set_title('Performance vs. Scale')

        ax2.plot(agents, times)
        ax2.set_xlabel('Number of Agents')
        ax2.set_ylabel('Time per Episode (s)')
        ax2.set_title('Computational Cost vs. Scale')

        return fig

# --- Section heading: 5.2 基准对比表 ---
| 基准 | 特点 | 智能体数量 | 任务类型 | 评估指标 |
|---|---|---|---|---|
| SMAC | 星际争霸微管理 | 2-27 | 战斗 | 胜率 |
| MPE | 多粒子环境 | 2-10 | 多样 | 回合奖励 |
| Hanabi | 协作卡牌 | 2-4 | 推理 | 得分 |
| Craftax | 开放世界 | 2-64 | 开放式 | 综合指标 |
| Overcooked-AI | 厨房协作 | 2 | 协作烹饪 | 任务完成 |
| Neural MMO | 大规模MMO | 100+ | 生存 | 多样性 |
5.3 扩展MARL评估的重要性
class MARLEvaluationFramework:
    """Extended evaluation metrics for MARL: generalisation and coordination."""

    @staticmethod
    def _performance(results: Dict, n_agents: int):
        """Return ``results[n]['performance']`` for an int or str key ``n``.

        Integer keys are what ``ScalabilityBenchmark.results`` produces;
        string keys are what this class originally looked up.

        Raises:
            KeyError: If neither key form is present.
        """
        for key in (n_agents, str(n_agents)):
            if key in results:
                return results[key]['performance']
        raise KeyError(n_agents)

    @staticmethod
    def compute_generalization_metrics(results: Dict) -> Dict:
        """Compute generalisation metrics from benchmark results.

        Args:
            results: Dict holding per-agent-count entries for 4, 8, 16 and 32
                agents (int or str keys, each with a ``'performance'``
                value), ``'seen_combos'`` and ``'unseen_combos'``, and
                optionally ``'unseen_task'`` / ``'few_shot_task'``.

        Returns:
            Dict with ``'scale_generalization'``, ``'task_generalization'``
            and ``'compositional_generalization'`` sub-dicts.

        Raises:
            KeyError: If a required entry is missing.
        """
        perf = MARLEvaluationFramework._performance
        metrics = {}

        # 1. Scale generalisation: performance at the smaller scale minus
        # performance at the next larger one.
        metrics['scale_generalization'] = {
            'train_4_eval_8': perf(results, 4) - perf(results, 8),
            'train_8_eval_16': perf(results, 8) - perf(results, 16),
            'train_16_eval_32': perf(results, 16) - perf(results, 32)
        }

        # 2. Task generalisation: zero-/few-shot scores, 0 when absent.
        metrics['task_generalization'] = {
            'zero_shot_performance': results.get('unseen_task', 0),
            'few_shot_performance': results.get('few_shot_task', 0)
        }

        # 3. Compositional generalisation: seen vs. unseen skill combinations.
        seen = results['seen_combos']
        unseen = results['unseen_combos']
        metrics['compositional_generalization'] = {
            'seen_combos': seen,
            'unseen_combos': unseen,
            'generalization_gap': seen - unseen
        }

        return metrics

    @staticmethod
    def compute_coordination_metrics(trajectories: List) -> Dict:
        """Compute coordination-quality metrics from trajectories.

        Args:
            trajectories: List of dicts. ``'actions'`` is required in every
                entry; ``'messages'`` and ``'contributions'`` are optional
                and their metrics are skipped when no entry has them.

        Returns:
            Dict of metric name to value. The scoring helpers
            (``compute_redundancy``, ``compute_action_entropy``, ...) are
            not defined in this class and must be in scope where it lives.
        """
        metrics = {}

        # 1. Communication efficiency.
        messages = [t['messages'] for t in trajectories if 'messages' in t]
        if messages:
            metrics['avg_message_length'] = np.mean([len(m) for m in messages])
            metrics['message_redundancy'] = compute_redundancy(messages)

        # 2. Action coordination.
        actions = [t['actions'] for t in trajectories]
        metrics['action_diversity'] = compute_action_entropy(actions)
        metrics['coordinated_actions'] = compute_coordination_score(actions)

        # 3. Credit assignment.
        contributions = [t['contributions'] for t in trajectories if 'contributions' in t]
        if contributions:
            metrics['fairness'] = compute_fairness(contributions)
            metrics['credit_alignment'] = compute_credit_alignment(contributions)

        return metrics

# --- Section heading: 6. 总结 ---
可扩展多智能体RL的Transformer方法带来了新的可能性:
- 统一建模:将多智能体决策建模为序列生成问题
- 高效并行:Transformer的并行计算特性
- 强大泛化:注意力机制的泛化能力
- 新基准:Craftax等开放环境推动研究
未来方向包括:
- 更高效的大规模注意力机制
- 更丰富的多任务基准
- 因果推断与Transformer的结合
参考文献
Footnotes
1. Oryx框架详见 Oryx - 可扩展多智能体序列模型
2. STAIRS-Former结合了时空注意力设计
3. MAST和MATWM是其他值得关注的序列模型方法