Graph Neural Networks for Fluid Simulation

Graph neural networks (GNNs) offer a powerful paradigm for fluid simulation: they handle unstructured data directly and naturally express complex geometric relationships. Architectures such as MeshGraphNets achieve efficient flow-field prediction on large, geometrically complex domains while maintaining physical consistency.

1. Why Are GNNs Well Suited to Fluid Simulation?

1.1 Limitations of Traditional CFD

Traditional CFD methods face the following challenges:

| Method | Strength | Limitation |
|---|---|---|
| Finite element (FEM) | handles complex geometry | computationally expensive |
| Finite volume (FVM) | good conservation properties | needs high-quality meshes |
| Finite difference (FDM) | simple to implement | limited on complex geometry |
| Spectral methods | high accuracy | restricted to regular domains |

1.2 Core Advantages of GNNs

  1. Arbitrary geometry: no dependence on regular grids; unstructured meshes are handled directly
  2. Adaptive resolution: different regions can have different node densities
  3. Physical priors: the graph structure encodes locality and neighborhood relations
  4. Scalability: the message-passing mechanism is naturally parallel

1.3 Graph Representation in Fluid Simulation

The fluid domain is discretized into a graph:

  • Nodes: mesh vertices or cell centers, carrying the flow variables (velocity, pressure)
  • Edges: mesh connectivity, encoding spatial adjacency

from dataclasses import dataclass

import torch


@dataclass
class FluidMesh:
    """Graph representation of a fluid simulation mesh."""
    num_nodes: int
    node_features: torch.Tensor      # (N, node_dim) per-node features
    edge_index: torch.Tensor         # (2, E) edge indices
    edge_features: torch.Tensor      # (E, edge_dim) per-edge features
    cell_connectivity: torch.Tensor  # cell connectivity
    
    # Optional geometric information
    node_positions: torch.Tensor  # (N, 3) node coordinates
    edge_lengths: torch.Tensor    # (E,) edge lengths
    
    def to(self, device):
        """Move all tensors to the given device."""
        return FluidMesh(
            num_nodes=self.num_nodes,
            node_features=self.node_features.to(device),
            edge_index=self.edge_index.to(device),
            edge_features=self.edge_features.to(device),
            cell_connectivity=self.cell_connectivity.to(device),
            node_positions=self.node_positions.to(device),
            edge_lengths=self.edge_lengths.to(device)
        )

2. The MeshGraphNets Architecture

2.1 Core Architecture

MeshGraphNets (Pfaff et al., 2021) is a landmark in applying GNNs to fluid simulation:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_mean  # pip install torch-scatter


class MeshGraphNets(nn.Module):
    """MeshGraphNets: learning mesh-based PDE dynamics."""
    def __init__(self, node_dim, edge_dim, latent_dim, n_message_passing, n_decoder_layers):
        super().__init__()
        
        # Encoders
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, latent_dim),
            nn.LayerNorm(latent_dim),
            nn.ReLU(),
        )
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, latent_dim),
            nn.LayerNorm(latent_dim),
            nn.ReLU(),
        )
        
        # Message-passing layers: edge and node blocks must alternate,
        # matching the even/odd dispatch in forward()
        self.message_passing = nn.ModuleList()
        for _ in range(n_message_passing):
            self.message_passing.append(EdgeBlock(latent_dim))
            self.message_passing.append(NodeBlock(latent_dim))
        
        # Decoder: build fresh layers per repetition (list multiplication
        # would reuse the same Linear module and share its weights)
        decoder_layers = []
        for _ in range(n_decoder_layers):
            decoder_layers += [nn.Linear(latent_dim, latent_dim), nn.ReLU()]
        self.decoder = nn.Sequential(*decoder_layers)
        
        # Output heads
        self.velocity_head = nn.Linear(latent_dim, 3)  # velocity prediction
        self.pressure_head = nn.Linear(latent_dim, 1)  # pressure prediction
    
    def forward(self, graph):
        """
        graph: carries node_features, edge_index, edge_features
        """
        # Encode
        node_features = self.node_encoder(graph.node_features)
        edge_features = self.edge_encoder(graph.edge_features)
        
        # Message passing: alternate edge and node updates
        for i in range(len(self.message_passing)):
            if i % 2 == 0:  # EdgeBlock
                edge_features = self.message_passing[i](
                    node_features, edge_features, graph.edge_index
                )
            else:  # NodeBlock
                node_features = self.message_passing[i](
                    node_features, edge_features, graph.edge_index
                )
        
        # Decode
        latent = self.decoder(node_features)
        
        # Output heads
        velocity = self.velocity_head(latent)
        pressure = self.pressure_head(latent)
        
        return velocity, pressure
 
 
class EdgeBlock(nn.Module):
    """Edge update block."""
    def __init__(self, latent_dim):
        super().__init__()
        self.message_net = nn.Sequential(
            nn.Linear(3 * latent_dim, latent_dim),
            nn.LayerNorm(latent_dim),
            nn.ReLU(),
        )
    
    def forward(self, node_features, edge_features, edge_index):
        src, dst = edge_index
        
        # Concatenate source node, destination node, and edge features
        messages = torch.cat([
            node_features[src],
            node_features[dst],
            edge_features
        ], dim=-1)
        
        # Edge update with a residual connection on the edge features
        return edge_features + self.message_net(messages)
 
 
class NodeBlock(nn.Module):
    """Node update block."""
    def __init__(self, latent_dim):
        super().__init__()
        self.aggregate_net = nn.Sequential(
            nn.Linear(2 * latent_dim, latent_dim),
            nn.LayerNorm(latent_dim),
            nn.ReLU(),
        )
        self.update_net = nn.Sequential(
            nn.Linear(2 * latent_dim, latent_dim),
            nn.LayerNorm(latent_dim),
            nn.ReLU(),
        )
    
    def forward(self, node_features, edge_features, edge_index):
        src, dst = edge_index
        
        # Aggregate incoming edge messages per destination node
        # (scatter_mean comes from the torch_scatter package)
        incoming = scatter_mean(edge_features, dst, dim=0, dim_size=len(node_features))
        
        aggregated = torch.cat([node_features, incoming], dim=-1)
        aggregated = self.aggregate_net(aggregated)
        
        # Node update with a residual connection
        return node_features + self.update_net(aggregated)
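NodeBlock relies on `scatter_mean` from the third-party torch_scatter package. If that dependency is unavailable, the same per-node aggregation can be emulated with PyTorch's built-in `index_add_` — a minimal, unoptimized stand-in, not the library kernel:

```python
import torch

def scatter_mean(src, index, dim_size):
    """Average the rows of `src` that share the same destination index."""
    out = torch.zeros(dim_size, src.shape[-1])
    out.index_add_(0, index, src)  # sum incoming messages per node
    count = torch.zeros(dim_size)
    count.index_add_(0, index, torch.ones(len(index)))
    return out / count.clamp(min=1).unsqueeze(-1)  # avoid 0/0 on isolated nodes

# Three edge messages flowing into nodes 0, 0, and 2 of a 3-node graph
edge_features = torch.tensor([[1.0], [3.0], [5.0]])
dst = torch.tensor([0, 0, 2])
pooled = scatter_mean(edge_features, dst, dim_size=3)
# pooled: node 0 averages to 2.0, isolated node 1 stays 0.0, node 2 gets 5.0
```

The `clamp(min=1)` guard matters: nodes with no incoming edges would otherwise divide zero by zero.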

2.2 Training Strategy

def train_meshgraphnets(model, train_loader, val_loader, device):
    """Training loop for MeshGraphNets."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=10, factor=0.5
    )
    
    best_val_loss = float('inf')
    
    for epoch in range(100):
        # Training
        model.train()
        train_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            
            # One-step autoregressive prediction of the next state
            velocity_pred, pressure_pred = model(batch)
            velocity_target = batch.next_velocity
            pressure_target = batch.next_pressure
            
            loss = F.mse_loss(velocity_pred, velocity_target) + \
                   0.1 * F.mse_loss(pressure_pred, pressure_target)
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            train_loss += loss.item()
        
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                velocity_pred, _ = model(batch)
                val_loss += F.mse_loss(velocity_pred, batch.next_velocity).item()
        
        scheduler.step(val_loss)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pt')
        
        print(f"Epoch {epoch}: train_loss={train_loss:.6f}, val_loss={val_loss:.6f}")

2.3 Key Design Choices

  1. Message-passing direction: edge blocks use directed message passing
  2. Residual connections: keep information flowing through the deep stack
  3. Normalization: LayerNorm stabilizes training
  4. Multi-step prediction: teacher forcing is used during training
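At inference time the model is applied autoregressively: each prediction becomes the input for the next step, so one-step errors can compound, which is why training choices like teacher forcing matter. A minimal rollout sketch (the `rollout` helper and the toy increment function standing in for the trained GNN are illustrative, not part of MeshGraphNets):

```python
import torch

def rollout(model, state, n_steps):
    """Autoregressively integrate: feed each prediction back as the next input.
    During training, teacher forcing would substitute the ground-truth state here."""
    trajectory = [state]
    for _ in range(n_steps):
        with torch.no_grad():
            state = state + model(state)  # the network predicts the increment
        trajectory.append(state)
    return torch.stack(trajectory)

# Toy stand-in for the GNN: a fixed 10% growth of the state per step
model = lambda x: 0.1 * x
traj = rollout(model, torch.ones(4, 3), n_steps=5)
# traj has shape (6, 4, 3): the initial state plus 5 predicted steps
```

Predicting the increment rather than the full next state matches the residual structure of the network and keeps the identity map as an easy baseline.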

3. 3D Fluid Simulation

3.1 3D Graph Representation

3D meshes are considerably more complex than 2D ones:

  • Tetrahedral meshes (TetMesh): the most common choice, suits arbitrarily complex geometry
  • Hexahedral meshes (HexMesh): more structured, with better accuracy
  • Hybrid meshes: different cell types in different regions

class TetrahedralMesh:
    """Tetrahedral mesh handling."""
    def __init__(self, nodes, elements):
        self.nodes = nodes        # (N, 3) node coordinates
        self.elements = elements  # (M, 4) tetrahedron vertex indices
        
        # Build the graph structure; node features (velocity, pressure, ...)
        # are supplied separately by the simulator state
        self.edge_index = self._build_edges()
        self.edge_features = self._compute_edge_features()
    
    def _build_edges(self):
        """Extract the unique undirected edges from the tetrahedra."""
        edges = set()
        for elem in self.elements:
            # Each tetrahedron has 6 edges
            for (i, j) in [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]:
                a, b = int(elem[i]), int(elem[j])  # plain ints so the set deduplicates
                edges.add((min(a, b), max(a, b)))
        return torch.tensor(sorted(edges), dtype=torch.long).T
    
    def _compute_edge_features(self):
        """Edge features: distance and unit direction."""
        src, dst = self.edge_index
        diff = self.nodes[dst] - self.nodes[src]
        distance = torch.norm(diff, dim=-1, keepdim=True)
        direction = diff / (distance + 1e-8)
        
        return torch.cat([distance, direction], dim=-1)
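As a sanity check of the edge-extraction logic above, here is a standalone version run on two tetrahedra sharing a face: 12 raw edges reduce to 9 unique undirected edges, because the 3 edges of the shared face are counted once (the function name `tet_edges` is just for this example):

```python
import torch

def tet_edges(elements):
    """Collect the unique undirected edges of a tetrahedral mesh."""
    edges = set()
    for elem in elements:
        # Each tetrahedron contributes 6 edges; the set drops duplicates
        for i, j in [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]:
            a, b = int(elem[i]), int(elem[j])
            edges.add((min(a, b), max(a, b)))
    return torch.tensor(sorted(edges), dtype=torch.long).T

# Two tetrahedra sharing the face (1, 2, 3)
elements = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
edge_index = tet_edges(elements)
# edge_index has shape (2, 9) rather than (2, 12)
```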

3.2 3D MeshGraphNets

class MeshGraphNets3D(nn.Module):
    """3D MeshGraphNets."""
    def __init__(self, node_dim, edge_dim, latent_dim, n_steps):
        super().__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(node_dim + 3, latent_dim),  # +3 for node positions
            nn.LayerNorm(latent_dim),
            nn.GELU(),
        )
        
        self.encoder_edge = nn.Sequential(
            nn.Linear(edge_dim, latent_dim),
            nn.LayerNorm(latent_dim),
            nn.GELU(),
        )
        
        # A deeper message-passing stack to handle 3D complexity
        self.mp_steps = nn.ModuleList([
            MessagePassing3D(latent_dim) for _ in range(n_steps)
        ])
        
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.GELU(),
            nn.Linear(latent_dim, 3),  # velocity
        )
    
    def forward(self, mesh):
        x = torch.cat([mesh.node_features, mesh.node_positions], dim=-1)
        x = self.encoder(x)
        edge_attr = self.encoder_edge(mesh.edge_features)
        
        for step in self.mp_steps:
            x, edge_attr = step(x, edge_attr, mesh.edge_index)
        
        return self.decoder(x)
 
 
class MessagePassing3D(nn.Module):
    """3D message passing."""
    def __init__(self, latent_dim):
        super().__init__()
        self.edge_net = nn.Sequential(
            nn.Linear(3 * latent_dim, latent_dim),
            nn.LayerNorm(latent_dim),
            nn.GELU(),
        )
        self.node_net = nn.Sequential(
            nn.Linear(2 * latent_dim, latent_dim),
            nn.LayerNorm(latent_dim),
            nn.GELU(),
        )
    
    def forward(self, node_features, edge_features, edge_index):
        src, dst = edge_index
        
        # Edge update
        edge_msg = torch.cat([
            node_features[src],
            node_features[dst],
            edge_features
        ], dim=-1)
        edge_out = self.edge_net(edge_msg)
        
        # Node update (scatter_add comes from the torch_scatter package)
        aggregated = scatter_add(edge_out, dst, dim=0, dim_size=len(node_features))
        node_out = node_features + self.node_net(
            torch.cat([node_features, aggregated], dim=-1)
        )
        
        return node_out, edge_out

4. GNNs on Deforming Domains

4.1 The Dynamic Mesh Problem

In practice, the mesh itself may change over time:

  • Fluid-structure interaction (FSI)
  • Aeroelasticity
  • Free-surface flows

4.2 Dynamically Updating the Graph Structure

class DynamicMeshGNN(nn.Module):
    """GNN on a dynamically changing mesh."""
    def __init__(self, latent_dim):
        super().__init__()
        self.feature_encoder = FeatureEncoder(latent_dim)
        self.dynamics_predictor = DynamicsPredictor(latent_dim)
        self.mesh_updater = MeshUpdater(latent_dim)
    
    def forward(self, mesh, predict_mesh=True):
        """
        Args:
            mesh: the mesh at the current time step
            predict_mesh: whether to also predict the next mesh positions
        """
        # Encode features
        node_features = self.feature_encoder(mesh)
        
        # Predict the velocity update
        velocity_update = self.dynamics_predictor(node_features, mesh.edge_index)
        
        # Optionally predict the mesh positions at the next step
        if predict_mesh:
            mesh_update = self.mesh_updater(
                node_features, mesh.node_positions
            )
            next_positions = mesh.node_positions + mesh_update
        else:
            next_positions = mesh.node_positions
        
        return velocity_update, next_positions
 
 
class MeshUpdater(nn.Module):
    """Predicts node displacements for the mesh."""
    def __init__(self, latent_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + 3, latent_dim),
            nn.GELU(),
            nn.Linear(latent_dim, 3),  # displacement
        )
    
    def forward(self, features, positions):
        x = torch.cat([features, positions], dim=-1)
        return self.net(x)

5. Hybrid Methods: GNN + Classical Solvers

5.1 Coarse-Grid Solve + GNN Correction

A classical solver computes the solution on a coarse grid, and the GNN applies a correction on the fine grid:

    u_fine = I(u_coarse) + G_θ(I(u_coarse))

where I is the coarse-to-fine interpolation operator and G_θ is the GNN corrector.
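A minimal sketch of this coarse-to-fine pipeline, with 1-D linear interpolation playing the coarse-to-fine operator and an untrained linear layer standing in for the GNN corrector (all names here are illustrative assumptions, not a reference implementation):

```python
import torch
import torch.nn.functional as F

# Coarse solution on 5 grid points, e.g. from a cheap classical solve
u_coarse = torch.linspace(0.0, 1.0, 5).view(1, 1, 5)  # (batch, channels, length)

# Coarse-to-fine interpolation onto 9 fine grid points
u_interp = F.interpolate(u_coarse, size=9, mode='linear', align_corners=True)

# The learned corrector (an untrained stand-in for the GNN)
corrector = torch.nn.Linear(1, 1)
u_fine = u_interp.view(9, 1) + corrector(u_interp.view(9, 1))
```

The corrector only has to learn the residual between the interpolated coarse solution and the true fine-grid solution, which is typically a much easier target than the full field.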

5.2 Integrating Physical Constraints

Incorporating physical conservation constraints into the GNN:

class PhysicsConstrainedGNN(nn.Module):
    """GNN with a physics constraint term."""
    def __init__(self, node_dim, edge_dim, latent_dim):
        super().__init__()
        # Backbone predicts velocity (and pressure, unused here)
        self.gnn = MeshGraphNets(node_dim, edge_dim, latent_dim,
                                 n_message_passing=8, n_decoder_layers=2)
        
        # Divergence penalty layer
        self.divergence_penalty = DivergencePenalty()
    
    def forward(self, mesh, return_penalty=True):
        velocity_pred, _ = self.gnn(mesh)
        
        if return_penalty:
            # Divergence penalty (incompressibility constraint)
            div_penalty = self.divergence_penalty(
                velocity_pred, mesh.node_positions, mesh.edge_index
            )
            return velocity_pred, div_penalty
        return velocity_pred
 
 
class DivergencePenalty(nn.Module):
    """Divergence penalty term."""
    def forward(self, velocity, positions, edge_index):
        src, dst = edge_index
        
        # Relative displacement and velocity along each edge
        rel_pos = positions[dst] - positions[src]
        rel_vel = velocity[dst] - velocity[src]
        
        # Edge-wise approximation of the divergence
        divergence = torch.sum(rel_vel * rel_pos, dim=-1) / \
                    (torch.norm(rel_pos, dim=-1)**2 + 1e-8)
        
        return torch.mean(divergence**2)
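The penalty can be sanity-checked on velocity fields with known divergence: a uniform translation should incur zero penalty, while the radially expanding field v(x) = x has unit divergence along every edge. A standalone check (the helper `edge_divergence_penalty` mirrors the module above; the ring-shaped edge set is just a convenient test graph):

```python
import torch

def edge_divergence_penalty(velocity, positions, edge_index):
    src, dst = edge_index
    rel_pos = positions[dst] - positions[src]
    rel_vel = velocity[dst] - velocity[src]
    # Velocity difference projected along each edge approximates div(v)
    divergence = (rel_vel * rel_pos).sum(-1) / (rel_pos.norm(dim=-1) ** 2 + 1e-8)
    return (divergence ** 2).mean()

torch.manual_seed(0)
positions = torch.randn(10, 3)
# A ring of edges 0->1->...->9->0 (no self-loops)
edge_index = torch.stack([torch.arange(10), (torch.arange(10) + 1) % 10])

# Uniform translation: rel_vel is zero, so the penalty vanishes
p_const = edge_divergence_penalty(torch.ones(10, 3), positions, edge_index)
# Radial expansion v(x) = x: divergence is ~1 on every edge, penalty ~1
p_expand = edge_divergence_penalty(positions.clone(), positions, edge_index)
```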

6. Large-Scale Fluid Simulation

6.1 Distributed Training

class DistributedMeshGraphNets:
    """Distributed MeshGraphNets (schematic sketch; MeshPartitioner,
    local_forward, and sync_gradients stand for a real distributed
    framework such as torch.distributed)."""
    def __init__(self, model, num_gpus):
        self.model = model
        self.num_gpus = num_gpus
        
        # Spatial graph partitioner
        self.partitioner = MeshPartitioner()
    
    def train(self, mesh_dataset, batch_size):
        # Partition the mesh spatially across GPUs
        partitioned_meshes = self.partitioner.partition(mesh_dataset, self.num_gpus)
        
        # Each GPU trains on its local subgraph
        for local_mesh in partitioned_meshes:
            local_loss = self.local_forward(local_mesh)
            local_loss.backward()
        
        # Synchronize gradients across workers
        self.sync_gradients()

6.2 Inference Acceleration

  • Graph sampling: run inference on randomly sampled subgraphs
  • Hierarchical methods: alternate between coarse and fine meshes
  • GNN compilation: TorchScript optimization
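For the compilation route, a minimal TorchScript example: scripting compiles the module to TorchScript IR, removing Python-interpreter overhead and enabling Python-free deployment. (A small MLP is used here; the same call applies to message-passing modules whose forward passes use TorchScript-supported ops.)

```python
import torch
import torch.nn as nn

decoder = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 3))
scripted = torch.jit.script(decoder)  # compile to TorchScript IR

x = torch.randn(5, 16)
# The scripted module computes exactly the same outputs as the original
same = torch.allclose(decoder(x), scripted(x))
```

The scripted module can then be serialized with `scripted.save(...)` and loaded from C++ via `torch::jit::load`.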

7. Recent Advances and Future Directions

7.1 Graph Transformer for Fluids

Bringing Transformer attention into graph-based fluid simulation:

  • Global information propagation: captures long-range dependencies
  • Dynamic edge weights: adaptive receptive fields

7.2 Neural PDE Solver Integration

Combining GNNs with neural PDE solvers:

  • PINN-GNN hybrids: PINNs supply the physics constraints
  • Data-driven initial conditions: learning better initializations

7.3 Industrial Applications

  • Aerospace: fast prediction of the flow field around aircraft
  • Automotive: a surrogate for wind-tunnel testing
  • Energy: pipeline flow optimization

References

Pfaff, T., Fortunato, M., Sanchez-Gonzalez, A., & Battaglia, P. W. (2021). Learning Mesh-Based Simulation with Graph Networks. ICLR 2021.