torch_frame.config.ModelConfig

class ModelConfig(model: Callable[[Union[Tensor, MultiNestedTensor, MultiEmbeddingTensor, dict[str, torch_frame.data.multi_nested_tensor.MultiNestedTensor]]], Tensor], out_channels: int)[source]

基础类:object

可学习的模型,将单列的TensorData对象映射为行嵌入。

Parameters:
  • model (可调用的) – 一个可调用的模型,它接受一个形状为 [batch_size, 1, *]TensorData 对象作为输入,并输出形状为 [batch_size, 1, out_channels] 的嵌入。

  • out_channels (int) – 模型输出通道。