Pyro 示例

数据集

多MNIST

此脚本生成类似于[1]中描述的多MNIST数据集的数据集。

[1] Eslami, SM Ali, 等. “Attend, infer, repeat: 使用生成模型快速理解场景.” 神经信息处理系统进展. 2016.

imresize(arr, size)[source]
sample_one(canvas_size, mnist)[source]
sample_multi(num_digits, canvas_size, mnist)[source]
mk_dataset(n, mnist, max_digits, canvas_size)[source]
load_mnist(root_path)[source]
load(root_path)[源代码]

BART乘客量

load_bart_od()[source]

加载一个数据集,包含2011-2019年间每小时BART车站之间的起点-终点乘客量统计。

来源 https://www.bart.gov/about/reports/ridership

这会在第一次调用时下载数据集。在后续调用中,这会从本地缓存文件.pkl.bz2中读取。这会尝试下载由Pyro团队维护的预处理压缩缓存文件。如果缓存命中,这应该会非常快。如果缓存未命中,则会回退到下载原始数据源并预处理数据集,需要大约350MB的文件传输,存储几GB的临时文件,并可能需要超过30分钟的时间。

Returns

数据集是一个包含以下字段的字典:

  • “stations”:一个包含站点名称的字符串列表

  • “start_date”:第一个观测值的datetime.datetime

  • “counts”:一个形状为(num_hours, len(stations), len(stations))torch.FloatTensor,表示乘车次数

load_fake_od()[source]

创建一个用于冒烟测试的小型合成数据集。

Nextstrain SARS-CoV-2 计数

load_nextstrain_counts(map_location=None) dict[source]

加载一个SARS-CoV-2数据集。

原始数据集是一个预处理的中间metadata.tsv.gz,可通过 nextstrain获取。 然后,metadata.tsv.gz文件被Broad Institute的 预处理脚本聚合到 (月,位置,谱系)和(谱系,突变)的区间中。

实用工具

class MNIST(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)[source]

基础类:Generic[torch.utils.data.dataset.T_co]

mirrors = ['https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/', 'http://yann.lecun.com/exdb/mnist/', 'https://ossci-datasets.s3.amazonaws.com/mnist/']
get_data_loader(dataset_name, data_dir, batch_size=1, dataset_transforms=None, is_training_set=True, shuffle=True)[source]
print_and_log(logger, msg)[source]
get_data_directory(filepath=None)[source]