torch_geometric.nn.models.GRetriever
- class GRetriever(llm: LLM, gnn: Module, use_lora: bool = False, mlp_out_channels: int = 4096, mlp_out_tokens: int = 1)[source]
Bases:
Module来自“G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and Question Answering”论文的G-Retriever模型。
- Parameters:
警告
该模块已使用以下HuggingFace模型进行测试
llm_to_use="meta-llama/Llama-2-7b-chat-hf"llm_to_use="google/gemma-7b"
and may not work with other models. See other models at HuggingFace Models and let us know if you encounter any issues.
注意
有关使用
GRetriever的示例,请参见 examples/llm/g_retriever.py。- forward(question: List[str], x: Tensor, edge_index: Tensor, batch: Tensor, label: List[str], edge_attr: Optional[Tensor] = None, additional_text_context: Optional[List[str]] = None)[source]
前向传播。
- Parameters:
问题 (列表[str]) – 问题/提示。
x (torch.Tensor) – The input node features.
edge_index (torch.Tensor) – 边的索引。
batch (torch.Tensor) – 批次向量 \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), 它将 每个元素分配给特定的示例。
标签 (列表[str]) – 答案/标签。
edge_attr (torch.Tensor, optional) – 边的特征(如果GNN支持的话)。(默认值:
None)additional_text_context (List[str], optional) – 提供给LLM的额外上下文,例如文本化的知识图谱。 (默认值:
None)
- inference(question: List[str], x: Tensor, edge_index: Tensor, batch: Tensor, edge_attr: Optional[Tensor] = None, additional_text_context: Optional[List[str]] = None, max_out_tokens: Optional[int] = 32)[source]
推理过程。
- Parameters:
问题 (列表[str]) – 问题/提示。
x (torch.Tensor) – The input node features.
edge_index (torch.Tensor) – 边的索引。
batch (torch.Tensor) – 批次向量 \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), 它将 每个元素分配给特定的示例。
edge_attr (torch.Tensor, optional) – 边的特征(如果GNN支持的话)。(默认值:
None)additional_text_context (List[str], optional) – 提供给LLM的额外上下文,例如文本化的知识图谱。 (default:
None)max_out_tokens (int, optional) – 大语言模型生成的最大令牌数。(默认值:
32)