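Overview

GRAG (Graph Retrieval-Augmented Generation) is a graph-based retrieval-augmented generation framework designed for networked documents, proposed by Hu et al. at NAACL 2025. [1]

Conventional Naive RAG retrieves documents one at a time, which makes the following kinds of data hard to handle:

  • Citation networks: citation links between academic papers and patents
  • Social networks: follow, repost, and comment relationships between users
  • Knowledge graphs: multi-hop associations between entities
  • Temporal data: dynamic relationships in which events evolve over time

The core idea of GRAG is to widen the retrieval unit from a single document to a subgraph. The graph structure preserves the semantic links between documents, which supports more complex reasoning tasks.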


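1. GRAG Core Principles

1.1 Problem Definition

Given a query $q$ and a graph-structured document collection $G = (V, E)$, where:

  • $V$ is the set of nodes (entities, document chunks)
  • $E$ is the set of edges (relations, citations, links)

the goal of GRAG is to retrieve the subgraph $G_q \subseteq G$ most relevant to the query $q$ and to generate an answer conditioned on $G_q$.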
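
To make the problem statement concrete, here is a minimal sketch of retrieving a subgraph rather than a single document. The toy graph, the token-overlap score, and the function names are all assumptions made for illustration; they are not the scoring used by GRAG.

```python
from typing import Dict, List, Set, Tuple

# Toy graph: node id -> text, plus directed edges (all data is made up).
V: Dict[str, str] = {
    "p1": "graph neural networks for citation analysis",
    "p2": "retrieval augmented generation with language models",
    "p3": "subgraph retrieval for question answering",
}
E: List[Tuple[str, str]] = [("p3", "p1"), ("p3", "p2")]  # p3 cites p1 and p2


def overlap_score(query: str, text: str) -> int:
    """Hypothetical relevance score: number of shared lowercase tokens."""
    return len(set(query.lower().split()) & set(text.lower().split()))


def retrieve_subgraph(query: str) -> Tuple[Set[str], List[Tuple[str, str]]]:
    """Return the best-matching node together with its direct neighbours."""
    seed = max(V, key=lambda nid: overlap_score(query, V[nid]))
    edges = [(s, t) for s, t in E if seed in (s, t)]
    nodes = {seed} | {n for edge in edges for n in edge}
    return nodes, edges


nodes, edges = retrieve_subgraph("subgraph retrieval")
```

A chunk-based retriever would return only the seed node here; the subgraph also carries the two cited papers and the edges that connect them.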

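1.2 Comparison with Naive RAG

Dimension                  Naive RAG                      GRAG
Retrieval unit             Document chunk                 Subgraph
Relation modeling          Implicit (vector similarity)   Explicit (graph structure)
Multi-hop reasoning        Difficult                      Naturally supported
Context coherence          Fragmented                     Structure-preserving
Computational complexity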

1.3 System Architecture

┌─────────────────────────────────────────────────────────────────────┐
│                            GRAG Pipeline                            │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  ┌────────────────┐    ┌────────────────┐    ┌───────────────────┐  │
│  │   Networked    │ →  │     Graph      │ →  │     Subgraph      │  │
│  │   Documents    │    │  Construction  │    │     Retrieval     │  │
│  └────────────────┘    └───────┬────────┘    └─────────┬─────────┘  │
│                                │                       │            │
│                                ↓                       ↓            │
│                        ┌────────────────┐    ┌───────────────────┐  │
│                        │  Graph Index   │    │ Subgraph Encoding │  │
│                        └────────────────┘    └─────────┬─────────┘  │
│                                                        │            │
│                                                        ↓            │
│  ┌────────────────┐    ┌────────────────┐    ┌───────────────────┐  │
│  │     Answer     │ ←  │    Context     │ ←  │ Subgraph Relevance│  │
│  │   Generation   │    │     Fusion     │    │      Ranking      │  │
│  └────────────────┘    └────────────────┘    └───────────────────┘  │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

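2. Subgraph Retrieval Strategies

2.1 Core Challenges of Subgraph Retrieval

  1. Subgraph matching: matching a query against substructures of the knowledge graph
  2. Subgraph ranking: ranking multiple candidate subgraphs by relevance
  3. Subgraph encoding: turning graph structure into a computable representation

2.2 Anchor-Based Subgraph Retrieval

"""
Anchor-based Subgraph Retrieval
"""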
 
import numpy as np
from typing import List, Tuple, Dict, Set
from dataclasses import dataclass
 
@dataclass
class Entity:
    """Entity node"""
    id: str
    name: str
    type: str
    embedding: np.ndarray
 
@dataclass
class Relation:
    """Relation edge"""
    source: str
    target: str
    type: str
    weight: float = 1.0
 
@dataclass
class Subgraph:
    """A retrieved subgraph"""
    nodes: List[Entity]
    edges: List[Relation]
    relevance_score: float
    path_explanation: List[str]
 
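class AnchorBasedSubgraphRetrieval:
    """
    Anchor-based subgraph retrieval strategy
    
    Core ideas:
    1. Identify the anchor entities in the query
    2. Traverse the graph starting from each anchor
    3. Expand the traversal into candidate subgraphs
    """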
    
    def __init__(self, kg: Dict, embedding_model):
        self.kg = kg
        self.embedding_model = embedding_model
        self.node_index = {n['id']: n for n in kg['nodes']}
        
    def retrieve(self, query: str, top_k: int = 5, 
                 max_hops: int = 2) -> List[Subgraph]:
        """
        Retrieve the subgraphs most relevant to a query
        
        Args:
            query: query text
            top_k: number of subgraphs to return
            max_hops: maximum number of hops (bounds the subgraph size)
        
        Returns:
            Subgraphs sorted by relevance, best first
        """
        # Step 1: identify anchor entities for the query
        query_embedding = self.embedding_model.encode(query)
        anchors = self._identify_anchors(query_embedding, threshold=0.7)
        
        # Step 2: expand a subgraph from each anchor
        candidate_subgraphs = []
        for anchor in anchors:
            subgraph = self._expand_subgraph(anchor, max_hops)
            subgraph.relevance_score = self._compute_relevance(
                subgraph, query_embedding
            )
            candidate_subgraphs.append(subgraph)
        
        # Step 3: sort and return the top_k
        candidate_subgraphs.sort(key=lambda x: x.relevance_score, reverse=True)
        return candidate_subgraphs[:top_k]
    
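    def _identify_anchors(self, query_emb: np.ndarray, 
                          threshold: float) -> List[Entity]:
        """Select anchor entities whose embedding is close to the query"""
        anchors = []
        query_norm = np.linalg.norm(query_emb)
        for node in self.kg['nodes']:
            # Cosine similarity, so the fixed threshold stays meaningful
            # even when the embeddings are not L2-normalized
            denom = query_norm * np.linalg.norm(node['embedding'])
            similarity = np.dot(query_emb, node['embedding']) / denom if denom else 0.0
            if similarity >= threshold:
                anchors.append(Entity(
                    id=node['id'],
                    name=node['name'],
                    type=node['type'],
                    embedding=node['embedding']
                ))
        return anchors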
    
    def _expand_subgraph(self, anchor: Entity, max_hops: int) -> Subgraph:
        """Expand a subgraph around the anchor with a breadth-first traversal"""
        nodes = {anchor.id: anchor}
        edges = []
        visited = {anchor.id}
        queue = [(anchor.id, 0)]
        
        while queue:
            current_id, depth = queue.pop(0)
            if depth >= max_hops:
                continue
                
            # Follow outgoing edges only
            for edge in self.kg['edges']:
                if edge['source'] == current_id:
                    target_id = edge['target']
                    if target_id not in visited:
                        visited.add(target_id)
                        # node_index holds raw dicts; wrap them as Entity so
                        # every node in the subgraph has the same type
                        raw = self.node_index[target_id]
                        nodes[target_id] = Entity(
                            id=raw['id'],
                            name=raw['name'],
                            type=raw['type'],
                            embedding=raw['embedding']
                        )
                        queue.append((target_id, depth + 1))
                    edges.append(Relation(
                        source=edge['source'],
                        target=edge['target'],
                        type=edge['type'],
                        weight=edge.get('weight', 1.0)
                    ))
        
        return Subgraph(
            nodes=list(nodes.values()),
            edges=edges,
            relevance_score=0.0,
            path_explanation=[]
        )
    
    def _compute_relevance(self, subgraph: Subgraph, 
                           query_emb: np.ndarray) -> float:
        """Score a subgraph by mean node similarity, discounted by its size"""
        node_scores = []
        for node in subgraph.nodes:
            sim = np.dot(query_emb, node.embedding)
            node_scores.append(sim)
        
        avg_score = np.mean(node_scores)
        # Fraction of the whole graph covered by this subgraph
        coverage = len(subgraph.nodes) / len(self.kg['nodes'])
        # The penalty grows from 0 toward 1 with coverage, so larger
        # subgraphs receive the larger discount (at most 20%)
        penalty = 1.0 - np.exp(-coverage * 0.5)
        
        return float(avg_score * (1 - penalty * 0.2))
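
The role of max_hops in the expansion step is easiest to see on a tiny graph. The following is a minimal, dependency-free sketch of bounded breadth-first expansion; the adjacency list ADJ and the function name k_hop_subgraph are made up for this illustration and are not part of the class above.

```python
from collections import deque
from typing import Dict, List, Set, Tuple

# Toy directed graph as an adjacency list (node ids are made up).
ADJ: Dict[str, List[str]] = {
    "A": ["B", "C"],
    "B": ["D"],
    "C": ["D"],
    "D": ["E"],
    "E": [],
}


def k_hop_subgraph(anchor: str, max_hops: int) -> Tuple[Set[str], List[Tuple[str, str]]]:
    """Breadth-first expansion that stops max_hops edges away from the anchor."""
    depth = {anchor: 0}  # node -> hop distance from the anchor
    edges: List[Tuple[str, str]] = []
    queue = deque([anchor])
    while queue:
        current = queue.popleft()
        if depth[current] == max_hops:
            continue  # frontier reached: do not expand further
        for nxt in ADJ.get(current, []):
            edges.append((current, nxt))
            if nxt not in depth:
                depth[nxt] = depth[current] + 1
                queue.append(nxt)
    return set(depth), edges


nodes_1, edges_1 = k_hop_subgraph("A", max_hops=1)
nodes_2, edges_2 = k_hop_subgraph("A", max_hops=2)
```

With one hop the subgraph holds A and its two direct neighbours; with two hops it also reaches D, while E stays outside because it is three edges away.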

2.3 Path-Based Subgraph Retrieval

"""
Path-based Subgraph Retrieval
"""
 
from collections import deque
from typing import Optional
 
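class PathBasedSubgraphRetrieval:
    """
    Path-based subgraph retrieval strategy
    
    Core ideas:
    1. Find the paths between two entities in the query, shortest first
    2. Collect the nodes and edges on those paths into a subgraph
    3. Use the paths as an explainable reasoning chain
    """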
    
    def __init__(self, kg: Dict):
        self.kg = kg
        self.adjacency = self._build_adjacency()
        
    def _build_adjacency(self) -> Dict[str, List[Tuple[str, str]]]:
        """Build an adjacency list; reverse edges get a REVERSE_ prefix"""
        adj = {}
        for edge in self.kg['edges']:
            src, tgt, rel = edge['source'], edge['target'], edge['type']
            if src not in adj:
                adj[src] = []
            adj[src].append((tgt, rel))
            if tgt not in adj:
                adj[tgt] = []
            adj[tgt].append((src, f"REVERSE_{rel}"))
        return adj
    
    def find_path_between_entities(self, entity1: str, entity2: str, 
                                    max_length: int = 3) -> List[List[str]]:
        """
        Find all simple paths between two entities
        
        Args:
            entity1: start entity ID
            entity2: target entity ID
            max_length: maximum path length, counted in edges (hops)
        
        Returns:
            All matching paths as lists of node IDs, shortest first
        """
        if entity1 not in self.adjacency or entity2 not in self.adjacency:
            return []
        
        paths = []
        queue = deque([(entity1, [entity1])])
        
        while queue:
            current, path = queue.popleft()
            
            if current == entity2 and len(path) > 1:
                paths.append(path)
                continue
            
            # A path with k nodes has k - 1 edges; stop extending at the limit
            if len(path) - 1 >= max_length:
                continue
            
            for next_node, relation in self.adjacency.get(current, []):
                if next_node not in path:  # avoid cycles
                    queue.append((next_node, path + [next_node]))
        
        return paths
    
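    def extract_subgraph_from_paths(self, paths: List[List[str]]) -> Subgraph:
        """Build a subgraph from the nodes and edges that lie on the given paths"""
        # Lookups built from kg['nodes'] and the adjacency list
        node_index = {n['id']: n for n in self.kg['nodes']}
        rel_index = {
            (src, tgt): rel
            for src, neighbors in self.adjacency.items()
            for tgt, rel in neighbors
        }
        
        nodes_set = set()
        unique_edges = []
        seen = set()
        
        for path in paths:
            nodes_set.update(path)
            for i in range(len(path) - 1):
                key = (path[i], path[i + 1])
                if key in seen:  # deduplicate edges shared by several paths
                    continue
                seen.add(key)
                unique_edges.append(Relation(
                    source=path[i],
                    target=path[i + 1],
                    # Keep the real relation type when it is known
                    type=rel_index.get(key, "PATH_EDGE"),
                    weight=1.0 / len(path)
                ))
        
        nodes = []
        for nid in sorted(nodes_set):
            raw = node_index[nid]
            nodes.append(Entity(id=raw['id'], name=raw['name'],
                                type=raw['type'], embedding=raw['embedding']))
        
        return Subgraph(
            nodes=nodes,
            edges=unique_edges,
            relevance_score=0.0,
            path_explanation=[f"Path: {' -> '.join(p)}" for p in paths]
        )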
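
When only the single shortest connection is needed, breadth-first search with parent pointers avoids enumerating every path. The sketch below is one possible variant and not the class above; the triples and the helper names build_adjacency and shortest_path are invented for illustration.

```python
from collections import deque
from typing import Dict, List, Optional, Tuple

# Toy knowledge graph: (source, relation, target) triples, all made up.
TRIPLES: List[Tuple[str, str, str]] = [
    ("alice", "WORKS_AT", "acme"),
    ("acme", "LOCATED_IN", "berlin"),
    ("alice", "KNOWS", "bob"),
    ("bob", "LIVES_IN", "berlin"),
    ("berlin", "CAPITAL_OF", "germany"),
]


def build_adjacency(triples) -> Dict[str, List[Tuple[str, str]]]:
    """Undirected adjacency list; reverse edges get a REVERSE_ prefix."""
    adj: Dict[str, List[Tuple[str, str]]] = {}
    for src, rel, tgt in triples:
        adj.setdefault(src, []).append((tgt, rel))
        adj.setdefault(tgt, []).append((src, f"REVERSE_{rel}"))
    return adj


def shortest_path(adj, start: str, goal: str) -> Optional[List[str]]:
    """BFS with parent pointers; returns alternating node and relation labels."""
    parent: Dict[str, Optional[Tuple[str, str]]] = {start: None}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if node == goal:
            break
        for nxt, rel in adj.get(node, []):
            if nxt not in parent:
                parent[nxt] = (node, rel)
                queue.append(nxt)
    if goal not in parent:
        return None
    chain = [goal]
    while parent[chain[0]] is not None:  # walk back to the start
        prev, rel = parent[chain[0]]
        chain[:0] = [prev, rel]
    return chain


adjacency = build_adjacency(TRIPLES)
path = shortest_path(adjacency, "alice", "germany")
explanation = " -> ".join(path)
```

The resulting chain alternates entities and relation labels, so it can be dropped directly into a prompt as a reasoning trace.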

3. Graph Neural Networks in RAG

3.1 GNN Encoder Architecture

Graph neural networks (GNNs) can capture the multi-hop relations and structural information in a knowledge graph:

"""
GNN-based Subgraph Encoder
"""
 
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
 
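class GraphAttentionLayer(nn.Module):
    """Graph attention layer"""
    
    def __init__(self, in_features: int, out_features: int, 
                 num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        # The per-head split below requires an exact division
        assert out_features % num_heads == 0, "out_features must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = out_features // num_heads
        self.out_features = out_features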
        
        self.W = nn.Linear(in_features, out_features, bias=False)
        self.att = nn.Parameter(torch.Tensor(1, num_heads, 2 * self.head_dim))
        self.bias = nn.Parameter(torch.Tensor(out_features))
        self.dropout = nn.Dropout(dropout)
        
        nn.init.xavier_uniform_(self.W.weight)
        nn.init.xavier_uniform_(self.att)
        nn.init.zeros_(self.bias)
    
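    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: node features [num_nodes, in_features]
            edge_index: edge indices [2, num_edges]
        """
        N = x.size(0)
        H = self.num_heads
        D = self.head_dim
        
        # Linear projection, then split into heads
        x = self.W(x)  # [N, out_features]
        x = x.view(N, H, D)  # [N, H, D]
        
        # Attention logits for every edge
        src, dst = edge_index
        src_x = x[src]  # [num_edges, H, D]
        dst_x = x[dst]  # [num_edges, H, D]
        
        # Concatenate src and dst features
        cat_x = torch.cat([src_x, dst_x], dim=-1)  # [num_edges, H, 2D]
        att = (cat_x * self.att).sum(dim=-1)  # [num_edges, H]
        att = F.leaky_relu(att, 0.2)
        
        # Softmax over the edges of each src node, not over all edges at once.
        # Subtracting the global max only improves numerical stability.
        if att.numel() > 0:
            att = torch.exp(att - att.max())
        denom = torch.zeros(N, H, device=x.device, dtype=x.dtype)
        denom.index_add_(0, src, att)
        att = att / (denom[src] + 1e-16)
        att = self.dropout(att)
        
        # Each src node aggregates the features of its dst neighbors
        out = torch.zeros(N, H, D, device=x.device, dtype=x.dtype)
        out.index_add_(0, src, (att.unsqueeze(-1) * dst_x))
        
        out = out.view(N, -1) + self.bias
        return F.elu(out)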
 
 
class GNNSubgraphEncoder(nn.Module):
    """
    GNN subgraph encoder
    
    Encodes a subgraph of the knowledge graph as a fixed-size vector
    that later similarity and ranking steps can consume.
    """
    
    def __init__(self, node_features: int, hidden_dim: int = 256,
                 num_layers: int = 3, num_heads: int = 4,
                 dropout: float = 0.1):
        super().__init__()
        
        self.node_embedding = nn.Linear(node_features, hidden_dim)
        self.layers = nn.ModuleList([
            GraphAttentionLayer(hidden_dim, hidden_dim, num_heads, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = dropout
        
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                node_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: node features [num_nodes, node_features]
            edge_index: edge indices [2, num_edges]
            node_mask: boolean mask [num_nodes] selecting the nodes to pool over
        
        Returns:
            Subgraph representation [2 * hidden_dim]
        """
        # Feature projection
        x = self.node_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        # Forward pass through the GNN layers
        for layer in self.layers:
            x = layer(x, edge_index)
            x = self.norm(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        
        # With a mask, pool over the selected nodes only. Zeroing the other
        # rows instead would dilute the mean and distort the max.
        if node_mask is not None:
            x = x[node_mask.bool()]
        
        # Graph-level pooling: concatenate mean and max pooling
        graph_repr = torch.cat([
            x.mean(dim=0),  # mean pooling
            x.max(dim=0)[0]  # max pooling
        ], dim=-1)
        
        return graph_repr
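
In graph attention, the softmax is taken separately over the edges of each node, so that every node's attention weights sum to 1. The following plain-Python sketch shows that grouping on three edges; the logits and the function name neighborhood_softmax are made up, and grouping by the source node is an assumption chosen to match the aggregation direction of the layer above.

```python
import math
from collections import defaultdict
from typing import Dict, List, Tuple

# Toy edges (src, dst) with made-up raw attention logits.
EDGES: List[Tuple[int, int]] = [(0, 1), (0, 2), (1, 2)]
LOGITS: List[float] = [2.0, 0.0, 5.0]


def neighborhood_softmax(edges, logits) -> List[float]:
    """Normalize logits so the weights of each src node's edges sum to 1."""
    groups: Dict[int, List[int]] = defaultdict(list)
    for i, (src, _dst) in enumerate(edges):
        groups[src].append(i)
    weights = [0.0] * len(edges)
    for idxs in groups.values():
        peak = max(logits[i] for i in idxs)  # shift for numerical stability
        exps = {i: math.exp(logits[i] - peak) for i in idxs}
        total = sum(exps.values())
        for i in idxs:
            weights[i] = exps[i] / total
    return weights


weights = neighborhood_softmax(EDGES, LOGITS)
```

A single softmax over all three edges would instead let the large logit on node 1's edge shrink the weights of node 0's edges, even though the two nodes have separate neighbourhoods.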

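3.2 GNN-Enhanced Retrieval-Generation Framework

"""
GNN-enhanced Retrieval-Generation Framework
"""
 
class GNNEnhancedRAG:
    """
    GNN-enhanced RAG system
    
    Combines the structure awareness of a GNN with the generation ability
    of an LLM for more precise knowledge retrieval and answer generation.
    """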
    
    def __init__(self, kg: Dict, llm, embedding_model,
                 hidden_dim: int = 256):
        self.kg = kg
        self.llm = llm
        self.subgraph_retriever = AnchorBasedSubgraphRetrieval(kg, embedding_model)
        
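        # GNN encoder
        self.gnn_encoder = GNNSubgraphEncoder(
            node_features=embedding_model.embedding_dim,
            hidden_dim=hidden_dim
        )
        
        # Fusion network. Its input is the query embedding (embedding_dim)
        # concatenated with the pooled subgraph vector (2 * hidden_dim).
        self.query_encoder = nn.Sequential(
            nn.Linear(embedding_model.embedding_dim + hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )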
        
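    def retrieve_and_encode(self, query: str) -> Tuple[Optional[Subgraph], Optional[torch.Tensor]]:
        """Retrieve the best subgraph and encode it together with the query"""
        # Step 1: retrieve relevant subgraphs
        subgraphs = self.subgraph_retriever.retrieve(query, top_k=5)
        
        if not subgraphs:
            return None, None
        
        # Step 2: GNN encoding
        subgraph = subgraphs[0]  # most relevant subgraph
        node_features = torch.as_tensor(
            np.stack([n.embedding for n in subgraph.nodes]), dtype=torch.float32
        )
        
        # Build the edge index. Node IDs are strings, so map each ID to its
        # row in node_features first.
        id_to_idx = {n.id: i for i, n in enumerate(subgraph.nodes)}
        src = [id_to_idx[e.source] for e in subgraph.edges]
        tgt = [id_to_idx[e.target] for e in subgraph.edges]
        edge_index = torch.tensor([src, tgt], dtype=torch.long)
        
        # Encode
        with torch.no_grad():
            subgraph_repr = self.gnn_encoder(node_features, edge_index)
        
        # Step 3: fuse with the query representation
        query_emb = self.subgraph_retriever.embedding_model.encode(query)
        query_emb = torch.as_tensor(query_emb, dtype=torch.float32)
        combined_repr = torch.cat([query_emb, subgraph_repr], dim=-1)
        fused_repr = self.query_encoder(combined_repr)
        
        return subgraph, fused_repr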
    
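    def generate(self, query: str, subgraph: Subgraph, 
                 context_repr: torch.Tensor) -> str:
        """Generate an answer from the subgraph context"""
        # Build the prompt. context_repr is currently unused: the prompt
        # below is text-only.
        context_text = self._format_subgraph_context(subgraph)
        
        prompt = f"""Answer the question based on the following knowledge graph subgraph.
 
Knowledge graph subgraph:
{context_text}
 
Question: {query}
 
Requirements:
1. Use only the information provided in the subgraph
2. If the information is insufficient, state that the question cannot be answered
3. Cite entities and relations from the subgraph to support the answer
"""
        
        response = self.llm.generate(prompt)
        return response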
    
    def _format_subgraph_context(self, subgraph: Subgraph) -> str:
        """Render the subgraph as a plain-text description"""
        lines = []
        lines.append("Nodes (entities):")
        for node in subgraph.nodes:
            lines.append(f"  - {node.name} ({node.type})")
        
        # Resolve edge endpoints to names through the subgraph's own nodes,
        # falling back to the raw ID when a node is missing
        names = {node.id: node.name for node in subgraph.nodes}
        lines.append("\nRelations:")
        for edge in subgraph.edges:
            src_name = names.get(edge.source, edge.source)
            tgt_name = names.get(edge.target, edge.target)
            lines.append(f"  - {src_name} --[{edge.type}]--> {tgt_name}")
        
        return "\n".join(lines)
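
The encoder expects edge_index as integer positions laid out as [2, num_edges], while relations refer to nodes by string ID. A minimal sketch of that conversion follows, using plain lists in place of tensors; the node IDs and the helper name build_edge_index are made up for illustration.

```python
from typing import Dict, List, Tuple

# Made-up node IDs, in the same order as the rows of the feature matrix.
node_ids: List[str] = ["paper_a", "paper_b", "paper_c"]
edges: List[Tuple[str, str]] = [("paper_a", "paper_b"), ("paper_c", "paper_a")]


def build_edge_index(node_ids: List[str], edges: List[Tuple[str, str]]) -> List[List[int]]:
    """Return [sources, targets] as integer positions into node_ids."""
    id_to_idx: Dict[str, int] = {nid: i for i, nid in enumerate(node_ids)}
    sources = [id_to_idx[s] for s, _ in edges]
    targets = [id_to_idx[t] for _, t in edges]
    return [sources, targets]


edge_index = build_edge_index(node_ids, edges)
```

The first inner list holds the source rows and the second the target rows, so column i describes edge i.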

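4. Dynamic Knowledge Graph Update Mechanism

4.1 Incremental Update Strategy

A dynamic knowledge graph must be able to absorb new knowledge in real time:

"""
Dynamic Knowledge Graph Update Mechanism
"""
 
from dataclasses import dataclass, field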
from typing import Dict, List, Optional, Set
from datetime import datetime
import threading
 
@dataclass
class TimeStampedEntity:
    """Entity with timestamps"""
    id: str
    name: str
    type: str
    attributes: Dict
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    version: int = 1
    is_active: bool = True
 
@dataclass
class TimeStampedRelation:
    """Relation with timestamps"""
    source: str
    target: str
    type: str
    start_time: datetime
    end_time: Optional[datetime] = None
    confidence: float = 1.0
 
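class DynamicKnowledgeGraph:
    """
    Dynamic knowledge graph
    
    Supports:
    - adding, updating, and deleting entities
    - managing the time range of relations
    - version control and rollback
    - incremental index updates
    """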
    
    def __init__(self, storage_backend=None, index_backend=None):
        self.entities: Dict[str, TimeStampedEntity] = {}
        self.relations: List[TimeStampedRelation] = []
        self.entity_history: Dict[str, List[TimeStampedEntity]] = {}
        
        self.storage = storage_backend  # persistent storage
        self.index = index_backend  # index service
        
        self._lock = threading.RLock()
        
    def add_entity(self, entity: TimeStampedEntity) -> None:
        """Add a new entity"""
        with self._lock:
            if entity.id in self.entities:
                raise ValueError(f"Entity {entity.id} already exists")
            
            entity.created_at = datetime.now()
            entity.updated_at = datetime.now()
            entity.version = 1
            
            self.entities[entity.id] = entity
            self.entity_history[entity.id] = [entity]
            
            # Update the index incrementally
            if self.index:
                self.index.add_entity(entity)
    
    def update_entity(self, entity_id: str, 
                      updates: Dict, 
                      timestamp: Optional[datetime] = None) -> TimeStampedEntity:
        """
        Update an entity
        
        Args:
            entity_id: entity ID
            updates: fields to change
            timestamp: time of the update (used by temporal queries)
        
        Returns:
            The updated entity
        """
        with self._lock:
            if entity_id not in self.entities:
                raise ValueError(f"Entity {entity_id} not found")
            
            old_entity = self.entities[entity_id]
            ts = timestamp or datetime.now()
            
            # Create the new version
            new_entity = TimeStampedEntity(
                id=entity_id,
                name=updates.get('name', old_entity.name),
                type=updates.get('type', old_entity.type),
                attributes={**old_entity.attributes, **updates.get('attributes', {})},
                created_at=old_entity.created_at,
                updated_at=ts,
                version=old_entity.version + 1,
                is_active=updates.get('is_active', old_entity.is_active)
            )
            
            self.entities[entity_id] = new_entity
            self.entity_history[entity_id].append(new_entity)
            
            # Update the index incrementally
            if self.index:
                self.index.update_entity(entity_id, new_entity)
            
            return new_entity
    
    def add_relation(self, relation: TimeStampedRelation) -> None:
        """Add a new relation"""
        with self._lock:
            # Check that both endpoints exist
            if relation.source not in self.entities:
                raise ValueError(f"Source entity {relation.source} not found")
            if relation.target not in self.entities:
                raise ValueError(f"Target entity {relation.target} not found")
            
            self.relations.append(relation)
            
            # Update the index
            if self.index:
                self.index.add_relation(relation)
    
    def deactivate_relation(self, source: str, target: str, 
                           relation_type: str,
                           timestamp: Optional[datetime] = None) -> bool:
        """
        Soft-delete a relation by setting its end time
        
        Keeping the closed relation makes time-range queries possible,
        for example queries about a historical state.
        """
        ts = timestamp or datetime.now()
        
        # Hold the lock, as the other mutating methods do
        with self._lock:
            for rel in self.relations:
                if (rel.source == source and 
                    rel.target == target and 
                    rel.type == relation_type and
                    rel.end_time is None):
                    rel.end_time = ts
                    
                    if self.index:
                        self.index.update_relation(rel)
                    return True
        
        return False
    
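    def get_active_state(self, timestamp: Optional[datetime] = None) -> Dict:
        """
        Return the state of the graph at a given point in time
        
        Used for looking back at history or for point-in-time queries.
        Without a timestamp, the current state is returned.
        """
        with self._lock:
            ts = timestamp or datetime.now()
            
            active_entities = {}
            for eid, current in self.entities.items():
                if timestamp is None:
                    snapshot = current
                else:
                    # Latest recorded version that already existed at ts, so an
                    # entity updated after ts still appears in its older form
                    candidates = [
                        e for e in self.entity_history.get(eid, [])
                        if e.updated_at <= ts
                    ]
                    snapshot = (max(candidates, key=lambda e: e.updated_at)
                                if candidates else None)
                if snapshot is not None and snapshot.is_active:
                    active_entities[eid] = snapshot
            
            active_relations = [
                rel for rel in self.relations
                if rel.start_time <= ts and 
                   (rel.end_time is None or rel.end_time > ts)
            ]
            
            return {
                'entities': active_entities,
                'relations': active_relations,
                'timestamp': ts
            }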
    
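    def rollback_to_version(self, entity_id: str, 
                            target_version: int) -> TimeStampedEntity:
        """
        Roll an entity back to an earlier version
        
        Used for error recovery or for switching between versions. The
        rollback is recorded as a new version, so version numbers keep
        increasing and the history stays append-only.
        """
        with self._lock:
            if entity_id not in self.entity_history:
                raise ValueError(f"No history for entity {entity_id}")
            
            history = self.entity_history[entity_id]
            target_entity = None
            
            for entity in history:
                if entity.version == target_version:
                    target_entity = entity
                    break
            
            if target_entity is None:
                raise ValueError(f"Version {target_version} not found")
            
            # Copy the target content into a new head version
            restored = TimeStampedEntity(
                id=entity_id,
                name=target_entity.name,
                type=target_entity.type,
                attributes=dict(target_entity.attributes),
                created_at=target_entity.created_at,
                updated_at=datetime.now(),
                version=self.entities[entity_id].version + 1,
                is_active=target_entity.is_active
            )
            self.entities[entity_id] = restored
            history.append(restored)
            
            if self.index:
                self.index.update_entity(entity_id, restored)
            
            return restored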
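
A point-in-time lookup reduces to finding the latest version whose timestamp does not exceed the query time. If a version history is kept sorted by update time, a binary search is enough. The sketch below assumes such a sorted history; HISTORY, its contents, and version_at are illustrative names and data, not part of the class above.

```python
from bisect import bisect_right
from datetime import datetime
from typing import List, Optional, Tuple

# Version history of one entity as (updated_at, name), sorted by time (made-up data).
HISTORY: List[Tuple[datetime, str]] = [
    (datetime(2024, 1, 1), "Acme Ltd"),
    (datetime(2024, 6, 1), "Acme Corp"),
    (datetime(2025, 3, 1), "Acme Global"),
]


def version_at(history: List[Tuple[datetime, str]], ts: datetime) -> Optional[str]:
    """Return the value that was current at ts, or None if none existed yet."""
    times = [updated_at for updated_at, _ in history]
    pos = bisect_right(times, ts)  # number of versions with updated_at <= ts
    return history[pos - 1][1] if pos else None


mid_2024 = version_at(HISTORY, datetime(2024, 7, 15))
before_creation = version_at(HISTORY, datetime(2023, 1, 1))
```

Using bisect_right means a query exactly at an update time already sees the new version.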

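4.2 Change Detection and Propagation

import logging
from typing import Callable

logger = logging.getLogger(__name__)

class ChangeDetector:
    """
    Change detector for the knowledge graph
    
    Detects graph changes and triggers the corresponding downstream updates:
    - index updates
    - cache invalidation
    - subscriber notifications
    """
    
    def __init__(self, dkg: DynamicKnowledgeGraph):
        self.dkg = dkg
        self.subscribers: List[Callable[[Dict], None]] = []
        self.change_log: List[Dict] = []
        
    def subscribe(self, callback: Callable[[Dict], None]) -> None:
        """Subscribe to change notifications"""
        self.subscribers.append(callback)
    
    def notify_subscribers(self, change: Dict) -> None:
        """Notify every subscriber; a failing callback does not block the rest"""
        for subscriber in self.subscribers:
            try:
                subscriber(change)
            except Exception:
                logger.exception("Subscriber notification failed")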
    
    def detect_and_propagate(self, change_type: str, 
                             change_data: Dict) -> None:
        """
        Record a change and propagate it
        
        Args:
            change_type: kind of change (entity_add, entity_update, entity_delete,
                         relation_add, relation_delete)
            change_data: payload of the change
        """
        change = {
            'type': change_type,
            'data': change_data,
            'timestamp': datetime.now(),
            'id': len(self.change_log)
        }
        
        self.change_log.append(change)
        
        # Determine the affected nodes and edges
        affected_nodes = self._get_affected_nodes(change_type, change_data)
        affected_edges = self._get_affected_edges(change_type, change_data)
        
        change['affected_nodes'] = affected_nodes
        change['affected_edges'] = affected_edges
        
        # Propagate the change
        self.notify_subscribers(change)
    
    def _get_affected_nodes(self, change_type: str, 
                             change_data: Dict) -> Set[str]:
        """Return the nodes affected by a change"""
        if change_type in ['entity_add', 'entity_update', 'entity_delete']:
            return {change_data['entity_id']}
        elif change_type in ['relation_add', 'relation_delete']:
            return {change_data['source'], change_data['target']}
        return set()
    
    def _get_affected_edges(self, change_type: str, 
                             change_data: Dict) -> List[Tuple[str, str]]:
        """Return the edges affected by a change"""
        if change_type in ['relation_add', 'relation_delete']:
            return [(change_data['source'], change_data['target'])]
        return []
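
One typical subscriber is a cache that drops entries touched by a change. The sketch below shows the idea with a plain dictionary, assuming change records carry an affected_nodes set as built above; the cache layout and the names invalidate_cache and publish are hypothetical.

```python
from typing import Callable, Dict, List, Set

# Hypothetical cache: retrieval results keyed by the node ids they contain.
cache: Dict[frozenset, str] = {
    frozenset({"a", "b"}): "answer about a and b",
    frozenset({"c"}): "answer about c",
}
subscribers: List[Callable[[Dict], None]] = []


def invalidate_cache(change: Dict) -> None:
    """Drop every cached entry that touches a node affected by the change."""
    affected: Set[str] = change["affected_nodes"]
    for key in [k for k in cache if k & affected]:
        del cache[key]


def publish(change: Dict) -> None:
    """Deliver a change record to all subscribers."""
    for callback in subscribers:
        callback(change)


subscribers.append(invalidate_cache)
publish({"type": "entity_update", "affected_nodes": {"a"}})
```

Only the entry that contains node a is removed; results that do not touch the changed node stay cached.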

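5. Temporal Knowledge Graph RAG

5.1 Background

Facts in a temporal knowledge graph (TKG) hold over a time span. Each fact is a quadruple:

$$(s, p, o, [t_s, t_e])$$

where $s$, $p$, and $o$ are the subject, predicate, and object, and $[t_s, t_e]$ is the time range during which the fact is valid.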
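
Two interval checks cover most temporal lookups: whether a fact holds at an instant, and whether its validity overlaps a query window. The sketch below writes the valid range as [t_s, t_e], a notation assumed here, and treats a missing end time as "still valid"; the function names and the sample fact are made up.

```python
from datetime import datetime
from typing import Optional


def is_valid_at(t_s: datetime, t_e: Optional[datetime], t: datetime) -> bool:
    """True if t lies in [t_s, t_e]; t_e=None means the fact still holds."""
    return t_s <= t and (t_e is None or t <= t_e)


def overlaps(t_s: datetime, t_e: Optional[datetime],
             q_s: datetime, q_e: datetime) -> bool:
    """True if [t_s, t_e] and the query window [q_s, q_e] share any instant."""
    return t_s <= q_e and (t_e is None or t_e >= q_s)


# Made-up fact: (alice, WORKS_AT, acme), valid from 2020-01-01 to 2022-12-31.
start, end = datetime(2020, 1, 1), datetime(2022, 12, 31)
valid_in_2021 = is_valid_at(start, end, datetime(2021, 6, 1))
```

Both checks use closed intervals, so a query that starts exactly on the last valid day still matches.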

5.2 STAR-RAG Framework

STAR-RAG is a representative work on temporal graph-based RAG:

"""
Temporal Knowledge Graph RAG
"""
 
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from datetime import datetime, timedelta
 
@dataclass
class TemporalFact:
    """Temporal fact"""
    subject: str
    predicate: str
    object: str
    start_time: datetime
    end_time: Optional[datetime] = None
    
    def is_active_at(self, t: datetime) -> bool:
        """Check whether the fact is valid at time t"""
        if t < self.start_time:
            return False
        if self.end_time is not None and t > self.end_time:
            return False
        return True
 
@dataclass
class TemporalQuery:
    """Temporal query"""
    question: str
    query_time: Optional[datetime] = None  # point in time the question refers to
    time_range: Optional[Tuple[datetime, datetime]] = None  # or a time range
 
class TemporalKnowledgeGraph:
    """Temporal knowledge graph"""
    
    def __init__(self):
        self.facts: List[TemporalFact] = []
        # Despite its name, entity_index is keyed by the full
        # (subject, predicate, object) triple
        self.entity_index: Dict[Tuple[str, str, str], List[TemporalFact]] = {}
        self.time_index: Dict[datetime, List[TemporalFact]] = {}
        
    def add_fact(self, fact: TemporalFact) -> None:
        """Add a temporal fact"""
        self.facts.append(fact)
        
        # Update the triple index
        key = (fact.subject, fact.predicate, fact.object)
        if key not in self.entity_index:
            self.entity_index[key] = []
        self.entity_index[key].append(fact)
        
        # Time index, bucketed by the month in which the fact starts
        month_key = fact.start_time.replace(day=1, hour=0, minute=0,
                                            second=0, microsecond=0)
        if month_key not in self.time_index:
            self.time_index[month_key] = []
        self.time_index[month_key].append(fact)
    
    def query_at_time(self, t: datetime) -> List[TemporalFact]:
        """Return the facts that are valid at time t"""
        return [f for f in self.facts if f.is_active_at(t)]
    
    def query_in_range(self, start: datetime, 
                       end: datetime) -> List[TemporalFact]:
        """Return the facts whose validity overlaps the range [start, end]"""
        return [
            f for f in self.facts
            if f.start_time <= end and 
               (f.end_time is None or f.end_time >= start)
        ]
 
 
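class STAR_RAG:
    """
    STAR-RAG: temporally aware retrieval-augmented generation
    
    Core ideas:
    1. Temporally aware subgraph retrieval
    2. Temporal consistency constraints
    3. Construction of temporal reasoning chains
    """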
    
    def __init__(self, tkg: TemporalKnowledgeGraph, llm):
        self.tkg = tkg
        self.llm = llm
        
    def retrieve_temporal_subgraph(self, query: TemporalQuery) -> Dict:
        """
        Retrieve a temporal subgraph
        
        Selects the relevant part of the temporal knowledge graph
        according to the time constraint of the query.
        """
        if query.query_time:
            # Point-in-time query
            facts = self.tkg.query_at_time(query.query_time)
        elif query.time_range:
            # Time-range query
            facts = self.tkg.query_in_range(
                query.time_range[0], 
                query.time_range[1]
            )
        else:
            # No time constraint: return everything
            facts = self.tkg.facts
        
        # Assemble the temporal subgraph
        temporal_subgraph = {
            'facts': facts,
            'query': query,
            'temporal_paths': self._extract_temporal_paths(facts)
        }
        
        return temporal_subgraph
    
    def _extract_temporal_paths(self, facts: List[TemporalFact]) -> List[Dict]:
        """Extract temporal reasoning paths"""
        # Sort the facts by start time
        sorted_facts = sorted(facts, key=lambda f: f.start_time)
        
        paths = []
        current_path = []
        
        for fact in sorted_facts:
            if not current_path:
                current_path.append(fact)
            else:
                last_fact = current_path[-1]
                # Check temporal continuity
                time_gap = (fact.start_time - last_fact.end_time).days if last_fact.end_time else 0
                
                if time_gap <= 30:  # gaps of up to 30 days count as continuous
                    current_path.append(fact)
                else:
                    if len(current_path) > 1:
                        paths.append({
                            'facts': current_path.copy(),
                            'duration': self._calc_duration(current_path)
                        })
                    current_path = [fact]
        
        if len(current_path) > 1:
            paths.append({
                'facts': current_path,
                'duration': self._calc_duration(current_path)
            })
        
        return paths
    
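    def _calc_duration(self, facts: List[TemporalFact]) -> timedelta:
        """Time span covered by a path, from its first start to its latest end"""
        start = facts[0].start_time
        # Open-ended facts (end_time is None) are treated as lasting until now
        now = datetime.now()
        end = max(f.end_time or now for f in facts)
        return end - start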
    
    def generate_with_temporal_awareness(self, query: TemporalQuery,
                                          subgraph: Dict) -> str:
        """Generate a temporally aware answer"""
        context = self._format_temporal_context(subgraph)
        
        # Build the temporally aware prompt
        time_constraint = ""
        if query.query_time:
            time_constraint = f"Point in time of interest: {query.query_time.strftime('%Y-%m-%d')}"
        elif query.time_range:
            time_constraint = (
                f"Time range of interest: "
                f"{query.time_range[0].strftime('%Y-%m-%d')} to "
                f"{query.time_range[1].strftime('%Y-%m-%d')}"
            )
        
        prompt = f"""Answer the question based on the following temporal knowledge graph.
 
{time_constraint}
 
Temporal knowledge graph:
{context}
 
Question: {query.question}
 
Requirements:
1. Pay attention to time: the answer may differ from one period to another
2. Cite the time ranges of the facts to support the answer
3. Reflect how events developed and changed over time
"""
        
        return self.llm.generate(prompt)
    
    def _format_temporal_context(self, subgraph: Dict) -> str:
        """Format the temporal context"""
        lines = []
        
        for fact in subgraph['facts']:
            start_str = fact.start_time.strftime('%Y-%m-%d')
            end_str = fact.end_time.strftime('%Y-%m-%d') if fact.end_time else 'present'
            # Separate start and end explicitly so the two dates do not run together
            lines.append(
                f"- {fact.subject} --[{fact.predicate}]--> "
                f"{fact.object} ({start_str} to {end_str})"
            )
        
        return "\n".join(lines)

6. SimGRAG Variant

6.1 Core Idea of SimGRAG

SimGRAG (ACL 2025) proposes a query-pattern-subgraph alignment paradigm:

┌─────────────────────────────────────────────────────────────┐
│                    SimGRAG Pipeline                         │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│  ┌──────────┐    ┌─────────────┐    ┌──────────────────┐   │
│  │ Query Q  │ →  │ Query-to-   │ →  │ Pattern P         │   │
│  │          │    │ Pattern     │    │ (pattern graph)  │   │
│  └──────────┘    └─────────────┘    └────────┬─────────┘   │
│                                               │              │
│                                               ↓              │
│  ┌──────────┐    ┌─────────────┐    ┌──────────────────┐   │
│  │Subgraph S│ ←  │ Graph       │ ←  │ Pattern-to-      │   │
│  │          │    │ Semantic    │    │ Subgraph         │   │
│  └──────────┘    │ Distance    │    │ Alignment        │   │
│                  └─────────────┘    └──────────────────┘   │
│                                                              │
└─────────────────────────────────────────────────────────────┘
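
The ranking step in the diagram can be illustrated with a toy distance between a pattern and candidate subgraphs. The sketch below assumes the distance is a sum of embedding distances over aligned elements; this is an illustrative simplification, and the exact metric should be taken from the SimGRAG paper. The 2-d embeddings, EMB, and graph_distance are made up.

```python
import math
from typing import Dict, List, Sequence

# Made-up 2-d embeddings for pattern elements and candidate subgraph elements.
EMB: Dict[str, Sequence[float]] = {
    "PERSON": (1.0, 0.0), "WORKS_AT": (0.0, 1.0), "ORGANIZATION": (1.0, 1.0),
    "alice": (0.9, 0.1), "employed_by": (0.1, 0.9), "acme": (1.0, 0.8),
    "paris": (-1.0, 0.0), "capital_of": (0.0, -1.0), "france": (-1.0, -1.0),
}


def l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def graph_distance(pattern: List[str], candidate: List[str]) -> float:
    """Sum of distances between aligned pattern and candidate elements."""
    return sum(l2(EMB[p], EMB[c]) for p, c in zip(pattern, candidate))


pattern = ["PERSON", "WORKS_AT", "ORGANIZATION"]
candidates = [["paris", "capital_of", "france"], ["alice", "employed_by", "acme"]]
best = min(candidates, key=lambda c: graph_distance(pattern, c))
```

The candidate whose nodes and relation sit closest to the pattern in embedding space is selected, regardless of its position in the candidate list.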

6.2 SimGRAG Implementation

"""
SimGRAG: Similar Subgraph Enhanced RAG
"""
 
from typing import List, Tuple, Dict, Optional
import numpy as np
 
class GraphPattern:
    """Graph pattern (query pattern)"""
    
    def __init__(self, entities: List[str], relations: List[str],
                 structure: List[Tuple[int, int]]):
        """
        Args:
            entities: list of entity types
            relations: list of relation types
            structure: adjacency structure [(src_idx, tgt_idx), ...]
        """
        self.entities = entities
        self.relations = relations
        self.structure = structure
        self.embedding: Optional[np.ndarray] = None
    
    def to_adjacency_matrix(self, size: int) -> np.ndarray:
        """Convert the structure into an adjacency matrix"""
        adj = np.zeros((size, size))
        for src, tgt in self.structure:
            adj[src][tgt] = 1
        return adj
 
 
class SimGRAG:
    """
    SimGRAG implementation
    
    Two-stage alignment:
    1. Query-to-Pattern: turn the natural-language query into a graph pattern
    2. Pattern-to-Subgraph: find the subgraph most similar to the pattern
    """
    
    def __init__(self, kg: Dict, llm, embedding_model):
        self.kg = kg
        self.llm = llm
        self.embedding_model = embedding_model
        
    def query_to_pattern(self, query: str) -> GraphPattern:
        """
        阶段1: 查询转化为图模式
        
        使用 LLM 从查询中提取实体类型和关系结构。
        """
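        prompt = f"""Extract a graph pattern from the following query.

Query: {query}

Requirements:
1. Identify the entity types in the query (use [ENTITY] as a placeholder)
2. Identify the relation types between entities (use [RELATION] as a placeholder)
3. Give the connections between entities as index pairs into "entities"

Output format (JSON only, no comments):
{{
  "entities": ["type1", "type2", ...],
  "relations": ["relation1", "relation2", ...],
  "structure": [[0, 1], [1, 2], ...]
}}

Example:
Query: "Who works at Google and lives in California?"
Output:
{{
  "entities": ["PERSON", "ORGANIZATION", "LOCATION"],
  "relations": ["WORKS_AT", "LIVES_IN"],
  "structure": [[0, 1], [0, 2]]
}}
"""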
        
        response = self.llm.generate(prompt)
        pattern_data = self._parse_pattern_response(response)
        
        pattern = GraphPattern(
            entities=pattern_data['entities'],
            relations=pattern_data['relations'],
            structure=pattern_data['structure']
        )
        
        # Build the pattern embedding
        pattern.embedding = self._encode_pattern(pattern)
        
        return pattern
    
    def _encode_pattern(self, pattern: GraphPattern) -> np.ndarray:
        """Encode a graph pattern as a vector."""
        # Entity and relation type embeddings
        entity_embs = [self.embedding_model.encode(t) for t in pattern.entities]
        rel_embs = [self.embedding_model.encode(t) for t in pattern.relations]

        # Structure encoding
        n = len(pattern.entities)
        adj = pattern.to_adjacency_matrix(n)

        # Zero-vector fallbacks use the model's dimension (128 only if unknown)
        dim = len((entity_embs or rel_embs or [np.zeros(128)])[0])
        entity_feat = np.mean(entity_embs, axis=0) if entity_embs else np.zeros(dim)
        rel_feat = np.mean(rel_embs, axis=0) if rel_embs else np.zeros(dim)
        struct_feat = adj.flatten()

        # Combine: mean entity embedding + mean relation embedding + structure.
        # Note: struct_feat has n*n entries, so the vector length depends on the
        # pattern size; subgraph embeddings must be built the same way to compare.
        return np.concatenate([entity_feat, rel_feat, struct_feat])
    
    def pattern_to_subgraph(self, pattern: GraphPattern,
                            candidate_subgraphs: List[Dict],
                            top_k: int = 5) -> List[Dict]:
        """
        Stage 2: align the pattern with subgraphs.

        Uses the Graph Semantic Distance (GSD) to score how similar each
        candidate subgraph is to the pattern.
        """
        scored_subgraphs = []

        for subgraph in candidate_subgraphs:
            gsd = self._compute_graph_semantic_distance(pattern, subgraph)
            scored_subgraphs.append({
                'subgraph': subgraph,
                'gsd_score': gsd,
                'pattern_match': self._extract_pattern_match(pattern, subgraph)
            })

        # Sort by GSD ascending (smaller distance means more similar)
        scored_subgraphs.sort(key=lambda x: x['gsd_score'])

        return scored_subgraphs[:top_k]
    
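    def _compute_graph_semantic_distance(self, pattern: GraphPattern,
                                          subgraph: Dict) -> float:
        """
        Compute the Graph Semantic Distance (GSD).

        This simplified GSD combines:
        1. Entity type match (explicit penalty)
        2. Relation type match (only through the embedding)
        3. Graph structure similarity (embedding plus a size penalty)
        """
        p_emb = pattern.embedding
        s_emb = subgraph.get('embedding')
        # Explicit None checks: `not ndarray` raises ValueError for arrays
        if p_emb is None or s_emb is None:
            return float('inf')
        s_emb = np.asarray(s_emb)
        if p_emb.shape != s_emb.shape:
            return float('inf')  # vectors of different length are incomparable

        # Distance in embedding space
        emb_dist = float(np.linalg.norm(p_emb - s_emb))

        # Penalty for pattern entity types missing from the subgraph
        entity_types = set(pattern.entities)
        subgraph_types = set(subgraph.get('entity_types', []))
        type_penalty = len(entity_types - subgraph_types) / max(len(entity_types), 1)

        # Size difference (the extra 1 guards against two empty graphs)
        p_size = len(pattern.entities)
        s_size = len(subgraph.get('nodes', []))
        size_penalty = abs(p_size - s_size) / max(p_size, s_size, 1)

        # Combined GSD
        return emb_dist * (1 + type_penalty * 0.3 + size_penalty * 0.2)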
    
    def _extract_pattern_match(self, pattern: GraphPattern,
                               subgraph: Dict) -> Dict:
        """Extract pattern-match information."""
        matches = {
            'entity_mapping': {},
            'relation_mapping': {},   # left empty in this simplified version
            'matched_paths': []       # left empty in this simplified version
        }

        # Simple matching: similarity between a pattern entity type and node names.
        # Assumes embedding_model exposes similarity(text_a, text_b) -> float.
        for i, p_entity in enumerate(pattern.entities):
            best_match = None
            best_score = 0.0

            for node in subgraph.get('nodes', []):
                score = self.embedding_model.similarity(
                    p_entity, node.get('name', '')
                )
                if score > best_score:
                    best_score = score
                    best_match = node

            if best_match is not None and best_score > 0.7:
                matches['entity_mapping'][i] = best_match

        return matches
    
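    def retrieve(self, query: str, top_k: int = 5) -> List[Dict]:
        """Full SimGRAG retrieval flow."""
        # Stage 1: query to pattern
        pattern = self.query_to_pattern(query)

        # Fetch candidate subgraphs
        candidate_subgraphs = self._get_candidate_subgraphs(query)

        # Stage 2: align the pattern with subgraphs
        matched = self.pattern_to_subgraph(pattern, candidate_subgraphs, top_k)

        return matched

    def _get_candidate_subgraphs(self, query: str) -> List[Dict]:
        """Fetch candidate subgraphs (other retrieval methods can be plugged in)."""
        # Simplified: returns every precomputed subgraph and ignores the query.
        # In practice, recall candidates with keyword search, vector retrieval,
        # graph traversal, or similar.
        return self.kg.get('subgraphs', [])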
    
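    def _parse_pattern_response(self, response: str) -> Dict:
        """Parse the pattern emitted by the LLM."""
        import json
        import re

        empty = {'entities': [], 'relations': [], 'structure': []}

        # Try to extract a JSON object from the response
        match = re.search(r'\{.*\}', response, re.DOTALL)
        if match:
            try:
                data = json.loads(match.group())
                if isinstance(data, dict):
                    # Fill in missing keys so callers can index safely
                    return {**empty, **data}
            except json.JSONDecodeError:
                pass

        return empty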
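
To make the scoring concrete, here is a minimal standalone sketch of the GSD heuristic used in the listing above: an embedding distance scaled by a type penalty and a size penalty. It is not claimed to be the metric defined in the SimGRAG paper. The function name `graph_semantic_distance`, its parameters, and the toy inputs are assumptions introduced for illustration; the weights 0.3 and 0.2 are the ones from the listing.

```python
import math

def graph_semantic_distance(pattern_emb, subgraph_emb,
                            pattern_types, subgraph_types,
                            pattern_size, subgraph_size,
                            type_weight=0.3, size_weight=0.2):
    """Heuristic GSD: Euclidean embedding distance scaled by two penalties."""
    if len(pattern_emb) != len(subgraph_emb):
        return float("inf")  # vectors of different length are incomparable

    emb_dist = math.dist(pattern_emb, subgraph_emb)

    # Fraction of pattern entity types that the subgraph does not contain
    wanted = set(pattern_types)
    type_penalty = len(wanted - set(subgraph_types)) / max(len(wanted), 1)

    # Relative difference in node counts (the 1 guards against two empty graphs)
    size_penalty = abs(pattern_size - subgraph_size) / max(pattern_size, subgraph_size, 1)

    return emb_dist * (1 + type_weight * type_penalty + size_weight * size_penalty)


# Toy inputs, made up for illustration: both candidates sit at distance 5.0
pattern_emb = [0.0, 0.0]
close = graph_semantic_distance(pattern_emb, [3.0, 4.0],
                                ["PERSON", "ORGANIZATION"], ["PERSON", "ORGANIZATION"],
                                2, 2)
partial = graph_semantic_distance(pattern_emb, [3.0, 4.0],
                                  ["PERSON", "ORGANIZATION"], ["PERSON"],
                                  2, 4)
print(close, partial)  # 5.0 and approximately 6.25
```

The second candidate has the same embedding distance but is missing one of two entity types and has twice as many nodes, so its distance is inflated by a factor of about 1.25 and it ranks lower.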

7. GRAG vs Microsoft GraphRAG

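7.1 Core Differences

| Dimension | GRAG (Hu et al., 2025) | Microsoft GraphRAG |
|---|---|---|
| Proposed | NAACL 2025 | 2024 (Microsoft Research) |
| Core goal | Graph retrieval over networked documents | Global semantic understanding |
| Retrieval granularity | Subgraph | Community |
| Index structure | Original graph structure | Hierarchical community index |
| Typical scenarios | Citation networks, social networks | Large document collections |
| Generation | Directly from the retrieved subgraph | Community summaries + global summarization |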

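7.2 Architecture Comparison

┌──────────────────────────────────────────────────────────────────┐
│                        Microsoft GraphRAG                        │
├──────────────────────────────────────────────────────────────────┤
│                                                                  │
│  Documents → Entity extraction → Knowledge graph                 │
│            → Community detection → Community summaries           │
│                                         ↓                        │
│  Query → Local/Global search → Community summaries → Generation  │
│                                                                  │
│  Characteristics:                                                │
│  - Handles global questions (those needing the whole corpus)     │
│  - Community summaries supply high-level semantics               │
│  - Suited to open-domain question answering                      │
│                                                                  │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│                        GRAG (NAACL 2025)                         │
├──────────────────────────────────────────────────────────────────┤
│                                                                  │
│  Networked documents → Graph construction → Subgraph retrieval   │
│                      → Subgraph encoding → Generation            │
│                                   ↓                              │
│  Query → Anchor identification → Subgraph expansion              │
│        → GNN encoding → Ranking                                  │
│                                                                  │
│  Characteristics:                                                │
│  - Preserves the original graph structure                        │
│  - GNN encoding supports complex relations                       │
│  - Suited to structured knowledge question answering             │
│                                                                  │
└──────────────────────────────────────────────────────────────────┘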
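
The difference in retrieval granularity can be sketched on a toy graph. In the sketch below the subgraph unit is a k-hop neighborhood around a query-dependent anchor, while the community unit is computed once for the whole graph. Connected components are used only as a crude stand-in for community detection; they are not the hierarchical community algorithm Microsoft GraphRAG actually uses. The function names and the toy graph are assumptions made for illustration.

```python
from collections import deque

def k_hop_subgraph(adj, anchor, max_hops):
    """Subgraph-style unit: all nodes within max_hops of an anchor node."""
    seen = {anchor}
    queue = deque([(anchor, 0)])
    while queue:
        node, depth = queue.popleft()
        if depth == max_hops:
            continue  # do not expand beyond the hop limit
        for nbr in adj.get(node, ()):
            if nbr not in seen:
                seen.add(nbr)
                queue.append((nbr, depth + 1))
    return seen

def connected_components(adj):
    """Stand-in for community detection: one group per connected component."""
    remaining, communities = set(adj), []
    while remaining:
        start = remaining.pop()
        # len(adj) hops is enough to reach every node in the component
        comp = k_hop_subgraph(adj, start, max_hops=len(adj))
        communities.append(comp)
        remaining -= comp
    return communities

# Toy undirected graph: a chain a-b-c-d plus a separate pair x-y
adj = {
    "a": ["b"], "b": ["a", "c"], "c": ["b", "d"], "d": ["c"],
    "x": ["y"], "y": ["x"],
}

print(sorted(k_hop_subgraph(adj, "b", max_hops=1)))
# ['a', 'b', 'c']
print(sorted(sorted(c) for c in connected_components(adj)))
# [['a', 'b', 'c', 'd'], ['x', 'y']]
```

The k-hop unit changes with the anchor chosen for each query, whereas the community partition is fixed at indexing time and can be summarized in advance.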

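7.3 Selection Guide

Application scenario
    │
    ├── Academic citation-network analysis ─────────────→ GRAG
    │
    ├── Social-media relationship reasoning ────────────→ GRAG
    │
    ├── Enterprise document KB (global summaries) ──────→ Microsoft GraphRAG
    │
    ├── Legal/medical knowledge-graph QA ───────────────→ GRAG + temporal extension
    │
    └── Temporal event analysis ────────────────────────→ T-GRAG / STAR-RAG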

8. Complete Example: GRAG System Implementation

"""
Complete GRAG System Implementation
"""

from dataclasses import dataclass, field
from typing import List, Dict, Optional, Tuple
from enum import Enum
import numpy as np

class RetrievalStrategy(Enum):
    """Retrieval strategy options."""
    ANCHOR_BASED = "anchor"
    PATH_BASED = "path"
    GNN_BASED = "gnn"
    SIMGRAG = "simgrag"


@dataclass
class GRAGConfig:
    """GRAG configuration."""
    embedding_dim: int = 256
    hidden_dim: int = 256
    num_gnn_layers: int = 3
    num_attention_heads: int = 4
    max_subgraph_nodes: int = 100
    max_hops: int = 2
    retrieval_top_k: int = 5
    strategy: RetrievalStrategy = RetrievalStrategy.ANCHOR_BASED


@dataclass
class DocumentNode:
    """Document node."""
    id: str
    content: str
    metadata: Dict = field(default_factory=dict)
    embedding: Optional[np.ndarray] = None


@dataclass
class KnowledgeGraphEdge:
    """Knowledge graph edge."""
    source: str
    target: str
    relation_type: str
    weight: float = 1.0
    metadata: Dict = field(default_factory=dict)
 
class GRAGSystem:
    """
    Complete GRAG system.

    Integrates subgraph retrieval, GNN encoding, and LLM generation.
    """

    def __init__(self, config: GRAGConfig, llm, embedding_model):
        self.config = config
        self.llm = llm
        self.embedding_model = embedding_model

        # Graph data; adjacency maps node id -> [(neighbor id, relation type), ...]
        self.nodes: Dict[str, DocumentNode] = {}
        self.edges: List[KnowledgeGraphEdge] = []
        self.adjacency: Dict[str, List[Tuple[str, str]]] = {}

        # Components (placeholders; not used by this simplified example)
        self.subgraph_retriever = None
        self.gnn_encoder = None

    def build_graph(self, documents: List[Dict]) -> None:
        """
        Build the knowledge graph from documents.

        In practice an LLM can be used for entity and relation extraction.
        """
        # Add nodes
        for doc in documents:
            node = DocumentNode(
                id=doc['id'],
                content=doc['content'],
                metadata=doc.get('metadata', {}),
                embedding=self.embedding_model.encode(doc['content'])
            )
            self.nodes[node.id] = node

        # Add edges (example: keyword overlap between each pair of documents)
        for i, doc1 in enumerate(documents):
            for doc2 in documents[i + 1:]:
                if self._has_relation(doc1, doc2):
                    edge = KnowledgeGraphEdge(
                        source=doc1['id'],
                        target=doc2['id'],
                        relation_type='RELATED_TO',
                        weight=0.8
                    )
                    self.edges.append(edge)
                    self._add_to_adjacency(edge)
    
    def _has_relation(self, doc1: Dict, doc2: Dict) -> bool:
        """Decide whether two documents are related."""
        # Simplified: related if they share at least two keywords
        keywords1 = set(doc1.get('keywords', []))
        keywords2 = set(doc2.get('keywords', []))
        overlap = len(keywords1 & keywords2)
        return overlap >= 2

    def _add_to_adjacency(self, edge: KnowledgeGraphEdge) -> None:
        """Add an edge to the adjacency list in both directions."""
        if edge.source not in self.adjacency:
            self.adjacency[edge.source] = []
        self.adjacency[edge.source].append((edge.target, edge.relation_type))

        if edge.target not in self.adjacency:
            self.adjacency[edge.target] = []
        self.adjacency[edge.target].append((edge.source, f"REVERSE_{edge.relation_type}"))
    
    def query(self, question: str, strategy: Optional[RetrievalStrategy] = None) -> str:
        """
        Run a GRAG query.

        Args:
            question: the user's question
            strategy: retrieval strategy (optional; defaults to the configured one)

        Returns:
            The generated answer
        """
        strategy = strategy or self.config.strategy

        # Step 1: retrieve the relevant subgraph
        if strategy == RetrievalStrategy.ANCHOR_BASED:
            subgraph = self._anchor_based_retrieval(question)
        elif strategy == RetrievalStrategy.PATH_BASED:
            subgraph = self._path_based_retrieval(question)
        elif strategy == RetrievalStrategy.GNN_BASED:
            subgraph = self._gnn_based_retrieval(question)
        elif strategy == RetrievalStrategy.SIMGRAG:
            subgraph = self._simgrag_retrieval(question)
        else:
            raise ValueError(f"Unknown strategy: {strategy}")

        if not subgraph:
            return ("Sorry, I could not find relevant information in the "
                    "knowledge graph to answer this question.")

        # Step 2: build the context
        context = self._build_context(subgraph)

        # Step 3: generate the answer
        answer = self._generate(question, context)

        return answer
    
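    def _anchor_based_retrieval(self, question: str) -> Dict:
        """Anchor-based subgraph retrieval."""
        from collections import deque

        query_emb = self.embedding_model.encode(question)

        # Find anchor candidates by cosine similarity (a raw dot product is only
        # comparable to a fixed threshold when the embeddings are normalized)
        anchors = []
        for node_id, node in self.nodes.items():
            denom = np.linalg.norm(query_emb) * np.linalg.norm(node.embedding)
            sim = float(np.dot(query_emb, node.embedding) / denom) if denom else 0.0
            if sim > 0.6:
                anchors.append((node_id, sim))

        if not anchors:
            return {}

        anchors.sort(key=lambda x: x[1], reverse=True)
        anchor_id = anchors[0][0]

        # Expand breadth-first, bounded by max_hops and max_subgraph_nodes
        subgraph_edges = []
        queue = deque([(anchor_id, 0)])
        visited = {anchor_id}
        seen_pairs = set()

        while queue:
            current, depth = queue.popleft()
            if depth >= self.config.max_hops:
                continue

            for neighbor, rel_type in self.adjacency.get(current, []):
                if neighbor not in visited:
                    if len(visited) >= self.config.max_subgraph_nodes:
                        continue
                    visited.add(neighbor)
                    queue.append((neighbor, depth + 1))

                # Record each undirected edge once (skip its REVERSE_ twin)
                pair = frozenset((current, neighbor))
                if pair not in seen_pairs:
                    seen_pairs.add(pair)
                    subgraph_edges.append(KnowledgeGraphEdge(
                        source=current, target=neighbor, relation_type=rel_type
                    ))

        return {
            'nodes': [self.nodes[nid] for nid in visited],
            'edges': subgraph_edges,
            'anchor': anchor_id
        }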
    
    def _path_based_retrieval(self, question: str) -> Dict:
        """Path-based subgraph retrieval."""
        # Find relevant nodes by cosine similarity
        query_emb = self.embedding_model.encode(question)
        relevant_nodes = []

        for node_id, node in self.nodes.items():
            denom = np.linalg.norm(query_emb) * np.linalg.norm(node.embedding)
            sim = float(np.dot(query_emb, node.embedding) / denom) if denom else 0.0
            if sim > 0.5:
                relevant_nodes.append((node_id, sim))

        # A path needs two endpoints; otherwise fall back to anchor expansion
        if len(relevant_nodes) < 2:
            return self._anchor_based_retrieval(question)

        relevant_nodes.sort(key=lambda x: x[1], reverse=True)

        # Find a path between the two most relevant nodes
        path = self._find_shortest_path(
            relevant_nodes[0][0],
            relevant_nodes[1][0]
        )

        if not path:
            return {}

        # Extract the path subgraph, keeping the stored relation type of each hop
        subgraph_edges = []
        for src, tgt in zip(path, path[1:]):
            rel_type = next(
                (rel for nbr, rel in self.adjacency.get(src, []) if nbr == tgt),
                'PATH'
            )
            subgraph_edges.append(KnowledgeGraphEdge(
                source=src, target=tgt, relation_type=rel_type
            ))

        return {
            'nodes': [self.nodes[nid] for nid in path],
            'edges': subgraph_edges,
            'path': path
        }
    
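    def _gnn_based_retrieval(self, question: str) -> Dict:
        """GNN-based subgraph retrieval."""
        # Simplified: a full implementation needs a GNN encoder; this falls
        # back to anchor-based retrieval
        return self._anchor_based_retrieval(question)

    def _simgrag_retrieval(self, question: str) -> Dict:
        """SimGRAG retrieval."""
        # SimGRAG needs the extra pattern-matching logic; this simplified
        # version falls back to anchor-based retrieval
        return self._anchor_based_retrieval(question)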
    
    def _find_shortest_path(self, src: str, tgt: str) -> List[str]:
        """Find the shortest path with BFS; returns [] if tgt is unreachable."""
        from collections import deque
        
        queue = deque([(src, [src])])
        visited = {src}
        
        while queue:
            current, path = queue.popleft()
            
            if current == tgt:
                return path
            
            for neighbor, _ in self.adjacency.get(current, []):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, path + [neighbor]))
        
        return []
    
    def _build_context(self, subgraph: Dict) -> str:
        """Build the retrieval context."""
        lines = ["Knowledge graph context:", ""]

        lines.append("Entities:")
        for node in subgraph.get('nodes', []):
            # Truncate long content and mark the truncation only when it happens
            snippet = node.content[:100] + ("..." if len(node.content) > 100 else "")
            lines.append(f"  - [{node.id}] {snippet}")

        lines.append("")
        lines.append("Relations:")
        for edge in subgraph.get('edges', []):
            lines.append(f"  - {edge.source} --[{edge.relation_type}]--> {edge.target}")

        return "\n".join(lines)
    
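    def _generate(self, question: str, context: str) -> str:
        """Generate the answer."""
        prompt = f"""Answer the question using the knowledge graph context below.

{context}

Question: {question}

Requirements:
1. Use only the information provided in the graph
2. Cite the relevant entities and relations
3. Keep the answer accurate and concise
"""

        return self.llm.generate(prompt)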
 
 
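# Usage example
def demo():
    """GRAG system demo using stub models; swap in a real LLM and embedder."""

    class StubEmbedder:
        """Hypothetical stand-in: normalized bag-of-characters vector."""
        def encode(self, text: str) -> np.ndarray:
            vec = np.zeros(64)
            for ch in text.lower():
                vec[ord(ch) % 64] += 1.0
            norm = np.linalg.norm(vec)
            return vec / norm if norm else vec

    class StubLLM:
        """Hypothetical stand-in: returns the prompt instead of calling a model."""
        def generate(self, prompt: str) -> str:
            return prompt

    # Mock data
    documents = [
        {
            'id': 'doc1',
            'content': 'The Transformer architecture, proposed by Google in 2017, is the foundation of modern large language models.',
            'keywords': ['transformer', 'google', 'llm']
        },
        {
            'id': 'doc2',
            'content': 'BERT is a Transformer-based pretrained model that achieved breakthroughs on NLP tasks.',
            'keywords': ['bert', 'transformer', 'nlp', 'pretraining']
        },
        {
            'id': 'doc3',
            'content': 'The GPT series is developed by OpenAI; GPT-3 has 175 billion parameters.',
            'keywords': ['gpt', 'openai', 'llm']
        },
        {
            'id': 'doc4',
            'content': 'ChatGPT is a conversational application built on GPT models, released by OpenAI in 2022.',
            'keywords': ['chatgpt', 'openai', 'dialogue', 'gpt']
        }
    ]

    print("Initializing the GRAG system...")
    system = GRAGSystem(GRAGConfig(), StubLLM(), StubEmbedder())
    system.build_graph(documents)
    print(f"Loaded {len(system.nodes)} documents, built {len(system.edges)} edges")

    print("\nExample queries:")
    for question in [
        "Who proposed the Transformer architecture?",
        "How are BERT and GPT related?",
        "Which models did OpenAI develop?",
    ]:
        print(f"  - {question}")
        print(system.query(question))


if __name__ == "__main__":
    demo()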
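
The edge rule in `_has_relation` (at least two shared keywords) decides how connected the demo graph is, so it helps to check it in isolation. The sketch below applies the same rule to the keyword sets of the four demo documents, written here in English. The helper `overlap_edges` and its `min_overlap` parameter are assumptions introduced for illustration, not part of the listing above.

```python
from itertools import combinations

# Keyword sets of the four demo documents (English renderings)
keywords = {
    "doc1": {"transformer", "google", "llm"},
    "doc2": {"bert", "transformer", "nlp", "pretraining"},
    "doc3": {"gpt", "openai", "llm"},
    "doc4": {"chatgpt", "openai", "dialogue", "gpt"},
}

def overlap_edges(keywords, min_overlap):
    """Return (doc_a, doc_b) pairs that share at least min_overlap keywords."""
    return [
        (a, b)
        for a, b in combinations(sorted(keywords), 2)
        if len(keywords[a] & keywords[b]) >= min_overlap
    ]

print(overlap_edges(keywords, 2))
# [('doc3', 'doc4')]
print(overlap_edges(keywords, 1))
# [('doc1', 'doc2'), ('doc1', 'doc3'), ('doc3', 'doc4')]
```

With the threshold of two used in the listing, only doc3 and doc4 are linked. A question anchored at doc1 or doc2 therefore expands to a single-node subgraph, and no path exists between the BERT and GPT documents. Lowering the threshold to one connects all four documents in a chain, at the cost of weaker edges.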

9. References


Footnotes

  1. Hu, Y., Lei, Z., Zhang, Z., Pan, B., Ling, C., & Zhao, L. (2025). GRAG: Graph Retrieval-Augmented Generation. Findings of the Association for Computational Linguistics: NAACL 2025. https://aclanthology.org/2025.findings-naacl.232/