Pyro 示例¶
数据集¶
多MNIST¶
此脚本生成类似于[1]中描述的多MNIST数据集的数据集。
[1] Eslami, SM Ali, 等. “Attend, infer, repeat: 使用生成模型快速理解场景.” 神经信息处理系统进展. 2016.
BART乘客量¶
- load_bart_od()[source]¶
加载一个数据集,包含2011-2019年间每小时BART车站之间的起点-终点乘客量统计。
来源 https://www.bart.gov/about/reports/ridership
这会在第一次调用时下载数据集。在后续调用中,这会从本地缓存文件
.pkl.bz2
中读取。这会尝试下载由Pyro团队维护的预处理压缩缓存文件。如果缓存命中,这应该会非常快。如果缓存未命中,则会回退到下载原始数据源并预处理数据集,需要大约350MB的文件传输,存储几GB的临时文件,并可能需要超过30分钟的时间。- Returns
数据集是一个包含以下字段的字典:
“stations”:一个包含站点名称的字符串列表
“start_date”:第一个观测值的
datetime.datetime
“counts”:一个形状为
(num_hours, len(stations), len(stations))
的torch.FloatTensor
,表示乘车次数
Nextstrain SARS-CoV-2 计数¶
- load_nextstrain_counts(map_location=None) dict [source]¶
加载一个SARS-CoV-2数据集。
原始数据集是一个预处理的中间
metadata.tsv.gz
,可通过 nextstrain获取。 然后,metadata.tsv.gz
文件被Broad Institute的 预处理脚本聚合到 (月,位置,谱系)和(谱系,突变)的区间中。
实用工具¶
- class MNIST(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)[source]¶
基础类:
Generic
[torch.utils.data.dataset.T_co
]- mirrors = ['https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/', 'http://yann.lecun.com/exdb/mnist/', 'https://ossci-datasets.s3.amazonaws.com/mnist/']¶