torch_frame.config.ModelConfig
- class ModelConfig(model: Callable[[Union[Tensor, MultiNestedTensor, MultiEmbeddingTensor, dict[str, torch_frame.data.multi_nested_tensor.MultiNestedTensor]]], Tensor], out_channels: int)[source]
基础类:
object
可学习的模型,将单列的
TensorData
对象映射为行嵌入。- Parameters:
model (可调用的) – 一个可调用的模型,它接受一个形状为
[batch_size, 1, *]
的TensorData
对象作为输入,并输出形状为[batch_size, 1, out_channels]
的嵌入。out_channels (int) – 模型输出通道。