Probability Distributions

Author: Jacob Schreiber. Contact: jmschreiber91@gmail.com

Everything in pomegranate revolves around the use of probability distributions. Although these objects can be used by themselves, e.g., fit to data or given parameters and used to evaluate new examples, they are designed to be used as part of larger compositional models such as mixtures or hidden Markov models. Because all components in pomegranate are designed to be plug-and-play, any probability distribution can be dropped directly into other models.
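
For instance, a distribution can be passed straight into a composite model. A minimal sketch, assuming that GeneralMixtureModel from pomegranate.gmm accepts a list of (possibly uninitialized) distribution objects:

import torch
from pomegranate.distributions import Normal
from pomegranate.gmm import GeneralMixtureModel

X = torch.randn(100, 3)

# Two Normal components plugged directly into a mixture model and
# fit to data with the same sklearn-like API the distributions use.
model = GeneralMixtureModel([Normal(), Normal()]).fit(X)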

A key difference between distributions in pomegranate v1.0.0 and in earlier versions is that distributions used to be mostly univariate, with one object representing a single dimension, whereas every distribution in v1.0.0 and later is multivariate. Modeling several dimensions in earlier versions meant wrapping several distributions in an IndependentComponentsDistribution, but in v1.0.0 a single distribution object is all you need. As a side note, IndependentComponents is still available so that you can model different dimensions with different distributions.
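
A rough sketch of that usage, assuming IndependentComponents accepts a list of distributions, one per column of the data:

import torch
from pomegranate.distributions import IndependentComponents, Normal, Exponential

# Non-negative data so that the Exponential's support is respected.
X = torch.randn(100, 2).abs()

# The first dimension is modeled with a Normal and the second with
# an Exponential; each distribution sees only its own column.
d = IndependentComponents([Normal(), Exponential()]).fit(X)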

[1]:
import numpy
import torch
from pomegranate.distributions import *

numpy.random.seed(0)
numpy.set_printoptions(suppress=True)

%load_ext watermark
%watermark -m -n -p numpy,torch,pomegranate
numpy      : 1.23.4
torch      : 1.13.0
pomegranate: 1.0.0

Compiler    : GCC 11.2.0
OS          : Linux
Release     : 4.15.0-208-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 8
Architecture: 64bit

Initialization and Fitting

First, let's look at how to create a probability distribution. This is simple to do if you know the parameters you want to pass in. These parameters can take the form of lists, tuples, numpy arrays, or torch tensors.

[2]:
d1 = Normal([0.3, 0.7, 1.1], [1.1, 0.3, 1.8], covariance_type='diag')
d2 = Exponential([0.8, 1.4, 4.1])
d3 = Categorical([[0.3, 0.2, 0.5], [0.2, 0.1, 0.7]])

d11 = Normal((0.3, 0.7, 1.1), (1.1, 0.3, 1.8), covariance_type='diag')
d12 = Normal(numpy.array([0.3, 0.7, 1.1]), numpy.array([1.1, 0.3, 1.8]), covariance_type='diag')
d13 = Normal(torch.tensor([0.3, 0.7, 1.1]), torch.tensor([1.1, 0.3, 1.8]), covariance_type='diag')

If you do not have parameters, you can learn them directly from data. Previously, this was done with the Distribution.from_samples method. However, because pomegranate v1.0.0 aims to be more similar to sklearn, learning directly from data should now be done with fit, which derives the parameters from the data using maximum likelihood estimation.

[3]:
torch.manual_seed(0)
X = torch.randn(100, 3)

d4 = Normal().fit(X)
d4.means, d4.covs
[3]:
(Parameter containing:
 tensor([-0.0132, -0.0643,  0.0985]),
 Parameter containing:
 tensor([[ 0.8174,  0.0668, -0.0590],
         [ 0.0668,  0.7918,  0.1045],
         [-0.0590,  0.1045,  0.9713]]))
[4]:
X2 = torch.randint(3, size=(20, 4))

d5 = Categorical().fit(X2)
d5.probs
[4]:
Parameter containing:
tensor([[0.3500, 0.2000, 0.4500],
        [0.3500, 0.3500, 0.3000],
        [0.4500, 0.3500, 0.2000],
        [0.2500, 0.3500, 0.4000]])

Similar to sklearn, any hyperparameters used for training, such as regularization, are passed in at initialization.
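
For example, a hedged sketch assuming pseudocount is the name of Categorical's additive smoothing hyperparameter:

import torch
from pomegranate.distributions import Categorical

X2 = torch.randint(3, size=(20, 4))

# The smoothing strength is set once at construction, sklearn-style,
# and is then applied every time the distribution is fit to data.
d = Categorical(pseudocount=1.0).fit(X2)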

Probabilities and Log Probabilities

All distributions can calculate probabilities and log probabilities using the respective methods.

[5]:
d4.log_probability(X)
[5]:
tensor([-3.9452, -3.2879, -5.3004, -3.6380, -3.9600, -4.9730, -3.2313, -5.4351,
        -3.0938, -4.7396, -3.6861, -2.6550, -2.8112, -2.9265, -2.6482, -8.2887,
        -3.7147, -2.6614, -2.8981, -9.5658, -6.2381, -3.2002, -5.7639, -6.9646,
        -4.4075, -3.8988, -3.0689, -3.2529, -3.6521, -5.3077, -5.5544, -3.2166,
        -5.6651, -7.9825, -2.6263, -2.6650, -3.4593, -6.5449, -2.8980, -3.0915,
        -4.5713, -3.1680, -4.8918, -3.0811, -4.6555, -3.1913, -3.5364, -3.1703,
        -2.5797, -3.4614, -2.5375, -4.8910, -2.9253, -3.9987, -3.0313, -3.2010,
        -2.6444, -3.2952, -3.7149, -3.9957, -4.4953, -3.8348, -4.1071, -4.5762,
        -2.9732, -2.9576, -3.4012, -3.4736, -3.9769, -3.7505, -4.5513, -4.0950,
        -4.5067, -2.7840, -3.3281, -4.1321, -2.9699, -3.8536, -3.9683, -5.8055,
        -5.3984, -4.9514, -2.7441, -3.8885, -4.5353, -3.0082, -2.8207, -3.3852,
        -3.9225, -3.7536, -6.9391, -3.0570, -5.8579, -3.4830, -2.6783, -5.0286,
        -2.9454, -3.4192, -3.8757, -4.4241])
[6]:
d1.log_probability(X)
[6]:
tensor([ -9.6459,  -2.8243, -16.0460,  -3.9857,  -2.9930, -11.8263,  -2.6130,
         -6.0255,  -5.7931,  -3.3603,  -7.4735,  -5.1014,  -3.1393,  -3.6045,
         -3.4457, -14.1638,  -2.9771,  -4.5638,  -5.1863,  -5.1922, -10.3993,
         -5.5200,  -5.2215,  -7.2889,  -7.4847,  -7.9908,  -2.9989,  -6.8441,
         -4.6477,  -4.3911,  -6.8748,  -3.9965, -10.5521, -22.9875,  -3.1194,
         -2.8532,  -6.6198,  -8.0589,  -3.4627,  -7.2507,  -5.3280,  -3.2750,
         -4.5530,  -6.5848,  -2.8317,  -4.6167,  -9.5592,  -5.2165,  -3.4062,
         -3.2597,  -3.9544, -14.5495,  -3.4490,  -3.8333,  -3.5855,  -2.8570,
         -3.3047,  -5.5304, -10.1993, -11.8056,  -3.3747,  -8.7955,  -3.4717,
        -10.6717,  -3.0119,  -2.9799,  -3.2086,  -3.6065,  -8.6801,  -3.4716,
         -4.2680,  -6.6669,  -4.1253,  -3.1685,  -5.0236,  -3.8058,  -3.1228,
         -5.6273,  -3.9447,  -5.2440, -14.2746,  -4.6809,  -4.1667,  -3.5050,
         -3.8123,  -4.4155,  -4.7357,  -5.1111,  -3.4382,  -6.3055,  -6.9832,
         -2.7879,  -5.8146,  -3.9857,  -4.3523,  -5.0716,  -4.9841,  -6.1210,
         -3.9729, -10.4107])
[7]:
d4.probability(X)
[7]:
tensor([1.9348e-02, 3.7332e-02, 4.9895e-03, 2.6305e-02, 1.9064e-02, 6.9225e-03,
        3.9505e-02, 4.3607e-03, 4.5328e-02, 8.7420e-03, 2.5069e-02, 7.0302e-02,
        6.0135e-02, 5.3586e-02, 7.0780e-02, 2.5134e-04, 2.4362e-02, 6.9850e-02,
        5.5130e-02, 7.0085e-05, 1.9536e-03, 4.0754e-02, 3.1389e-03, 9.4471e-04,
        1.2186e-02, 2.0266e-02, 4.6471e-02, 3.8663e-02, 2.5936e-02, 4.9533e-03,
        3.8702e-03, 4.0090e-02, 3.4649e-03, 3.4138e-04, 7.2346e-02, 6.9601e-02,
        3.1452e-02, 1.4374e-03, 5.5134e-02, 4.5435e-02, 1.0344e-02, 4.2086e-02,
        7.5080e-03, 4.5910e-02, 9.5090e-03, 4.1118e-02, 2.9118e-02, 4.1990e-02,
        7.5798e-02, 3.1386e-02, 7.9060e-02, 7.5137e-03, 5.3647e-02, 1.8340e-02,
        4.8254e-02, 4.0721e-02, 7.1046e-02, 3.7061e-02, 2.4357e-02, 1.8395e-02,
        1.1162e-02, 2.1606e-02, 1.6455e-02, 1.0294e-02, 5.1139e-02, 5.1945e-02,
        3.3334e-02, 3.1006e-02, 1.8743e-02, 2.3507e-02, 1.0554e-02, 1.6656e-02,
        1.1035e-02, 6.1790e-02, 3.5861e-02, 1.6049e-02, 5.1307e-02, 2.1204e-02,
        1.8905e-02, 3.0109e-03, 4.5240e-03, 7.0733e-03, 6.4304e-02, 2.0475e-02,
        1.0723e-02, 4.9382e-02, 5.9563e-02, 3.3871e-02, 1.9791e-02, 2.3433e-02,
        9.6916e-04, 4.7028e-02, 2.8573e-03, 3.0714e-02, 6.8677e-02, 6.5477e-03,
        5.2580e-02, 3.2739e-02, 2.0740e-02, 1.1985e-02])
[8]:
d1.probability(X)
[8]:
tensor([6.4691e-05, 5.9351e-02, 1.0748e-07, 1.8580e-02, 5.0137e-02, 7.3099e-06,
        7.3312e-02, 2.4163e-03, 3.0485e-03, 3.4725e-02, 5.6796e-04, 6.0884e-03,
        4.3315e-02, 2.7200e-02, 3.1883e-02, 7.0588e-07, 5.0939e-02, 1.0422e-02,
        5.5924e-03, 5.5596e-03, 3.0455e-05, 4.0060e-03, 5.3992e-03, 6.8307e-04,
        5.6159e-04, 3.3856e-04, 4.9844e-02, 1.0658e-03, 9.5840e-03, 1.2387e-02,
        1.0335e-03, 1.8379e-02, 2.6139e-05, 1.0391e-10, 4.4183e-02, 5.7661e-02,
        1.3337e-03, 3.1628e-04, 3.1344e-02, 7.0967e-04, 4.8536e-03, 3.7815e-02,
        1.0535e-02, 1.3812e-03, 5.8910e-02, 9.8850e-03, 7.0553e-05, 5.4264e-03,
        3.3168e-02, 3.8402e-02, 1.9170e-02, 4.7997e-07, 3.1776e-02, 2.1639e-02,
        2.7723e-02, 5.7439e-02, 3.6711e-02, 3.9645e-03, 3.7195e-05, 7.4626e-06,
        3.4229e-02, 1.5142e-04, 3.1063e-02, 2.3192e-05, 4.9196e-02, 5.0799e-02,
        4.0414e-02, 2.7148e-02, 1.6994e-04, 3.1066e-02, 1.4009e-02, 1.2723e-03,
        1.6158e-02, 4.2068e-02, 6.5808e-03, 2.2242e-02, 4.4035e-02, 3.5983e-03,
        1.9357e-02, 5.2789e-03, 6.3185e-07, 9.2710e-03, 1.5503e-02, 3.0046e-02,
        2.2097e-02, 1.2088e-02, 8.7762e-03, 6.0295e-03, 3.2123e-02, 1.8262e-03,
        9.2731e-04, 6.1549e-02, 2.9838e-03, 1.8579e-02, 1.2878e-02, 6.2726e-03,
        6.8462e-03, 2.1963e-03, 1.8819e-02, 3.0109e-05])

Similar to initialization, the inputs here can be lists, numpy arrays, or torch tensors.

[9]:
d4.log_probability(torch.tensor(X))
/tmp/ipykernel_88680/2372227.py:1: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  d4.log_probability(torch.tensor(X))
[9]:
tensor([-3.9452, -3.2879, -5.3004, -3.6380, -3.9600, -4.9730, -3.2313, -5.4351,
        -3.0938, -4.7396, -3.6861, -2.6550, -2.8112, -2.9265, -2.6482, -8.2887,
        -3.7147, -2.6614, -2.8981, -9.5658, -6.2381, -3.2002, -5.7639, -6.9646,
        -4.4075, -3.8988, -3.0689, -3.2529, -3.6521, -5.3077, -5.5544, -3.2166,
        -5.6651, -7.9825, -2.6263, -2.6650, -3.4593, -6.5449, -2.8980, -3.0915,
        -4.5713, -3.1680, -4.8918, -3.0811, -4.6555, -3.1913, -3.5364, -3.1703,
        -2.5797, -3.4614, -2.5375, -4.8910, -2.9253, -3.9987, -3.0313, -3.2010,
        -2.6444, -3.2952, -3.7149, -3.9957, -4.4953, -3.8348, -4.1071, -4.5762,
        -2.9732, -2.9576, -3.4012, -3.4736, -3.9769, -3.7505, -4.5513, -4.0950,
        -4.5067, -2.7840, -3.3281, -4.1321, -2.9699, -3.8536, -3.9683, -5.8055,
        -5.3984, -4.9514, -2.7441, -3.8885, -4.5353, -3.0082, -2.8207, -3.3852,
        -3.9225, -3.7536, -6.9391, -3.0570, -5.8579, -3.4830, -2.6783, -5.0286,
        -2.9454, -3.4192, -3.8757, -4.4241])

Summarization

Although the primary method for learning parameters from data is the fit method, the underlying engine behind this learning is a pair of operations: summarize and from_summaries. In summarize, the data is condensed into additive sufficient statistics that can be summed across batches. For example:

[10]:
d = Normal()
d.summarize(X[:5])

d._xw_sum
[10]:
tensor([-3.0192, -0.8312,  2.1886])
[11]:
d.summarize(X[5:])

d._xw_sum
[11]:
tensor([-1.3155, -6.4282,  9.8475])

These values would be identical if we had summarized the entire data set at once.

[12]:
d2 = Normal()
d2.summarize(X)

d2._xw_sum
[12]:
tensor([-1.3155, -6.4282,  9.8475])

Using these values, which are usually stored as _w_sum and _xw_sum, you can perfectly recreate the parameters you would get if you fit to the entire data set. You can do this with the from_summaries method.

[13]:
d.from_summaries()
d2.from_summaries()

d.means, d2.means
[13]:
(Parameter containing:
 tensor([-0.0132, -0.0643,  0.0985]),
 Parameter containing:
 tensor([-0.0132, -0.0643,  0.0985]))

We will explore these concepts in more detail in other tutorials, in particular how these methods make it easy to implement batching schemes for out-of-core learning, and how to fit to large data sets when GPU memory is limited.
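
As a preview, here is a minimal sketch of such a batching scheme, using only the two methods shown above:

import torch
from pomegranate.distributions import Normal

X = torch.randn(1000, 3)

# Accumulate additive sufficient statistics one batch at a time and
# derive the parameters once at the end; this matches fitting to the
# full data set without ever holding all of it in memory at once.
d = Normal()
for batch in torch.split(X, 100):
    d.summarize(batch)
d.from_summaries()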