torch_geometric.nn.models.GRetriever

class GRetriever(llm: LLM, gnn: Module, use_lora: bool = False, mlp_out_channels: int = 4096, mlp_out_tokens: int = 1)[source]

Bases: Module

来自“G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and Question Answering”论文的G-Retriever模型。

Parameters:
  • llm (LLM) – The LLM to use.

  • gnn (torch.nn.Module) – 使用的GNN。

  • use_lora (bool, optional) – 如果设置为 True,将使用 peft 中的 LORA 来训练 LLM,详情请参见 这里。 (默认值: False)

  • mlp_out_channels (int, optional) – 投影后每个图嵌入的大小。(默认值: 4096)

  • mlp_out_tokens (int, optional) – 为GNN输出保留的LLM前缀标记数量。(默认值:1

警告

该模块已使用以下HuggingFace模型进行测试

  • llm_to_use="meta-llama/Llama-2-7b-chat-hf"

  • llm_to_use="google/gemma-7b"

and may not work with other models. See other models at HuggingFace Models and let us know if you encounter any issues.

注意

有关使用 GRetriever 的示例,请参见 examples/llm/g_retriever.py

forward(question: List[str], x: Tensor, edge_index: Tensor, batch: Tensor, label: List[str], edge_attr: Optional[Tensor] = None, additional_text_context: Optional[List[str]] = None)[source]

前向传播。

Parameters:
  • 问题 (列表[str]) – 问题/提示。

  • x (torch.Tensor) – The input node features.

  • edge_index (torch.Tensor) – 边的索引。

  • batch (torch.Tensor) – 批次向量 \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), 它将 每个元素分配给特定的示例。

  • 标签 (列表[str]) – 答案/标签。

  • edge_attr (torch.Tensor, optional) – 边的特征(如果GNN支持的话)。(默认值: None)

  • additional_text_context (List[str], optional) – 提供给LLM的额外上下文,例如文本化的知识图谱。 (默认值: None)

inference(question: List[str], x: Tensor, edge_index: Tensor, batch: Tensor, edge_attr: Optional[Tensor] = None, additional_text_context: Optional[List[str]] = None, max_out_tokens: Optional[int] = 32)[source]

推理过程。

Parameters:
  • 问题 (列表[str]) – 问题/提示。

  • x (torch.Tensor) – The input node features.

  • edge_index (torch.Tensor) – 边的索引。

  • batch (torch.Tensor) – 批次向量 \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), 它将 每个元素分配给特定的示例。

  • edge_attr (torch.Tensor, optional) – 边的特征(如果GNN支持的话)。(默认值: None)

  • additional_text_context (List[str], optional) – 提供给LLM的额外上下文,例如文本化的知识图谱。 (default: None)

  • max_out_tokens (int, optional) – 大语言模型生成的最大令牌数。(默认值: 32)