Probability Distributions
Author: Jacob Schreiber. Contact: jmschreiber91@gmail.com
Everything in pomegranate revolves around the use of probability distributions. Although these objects can be used by themselves, e.g., fit to data or, given parameters, used to evaluate new examples, they are intended to be used as part of larger compositional models such as mixtures or hidden Markov models. Because all components of pomegranate are designed to be plug-and-play, any probability distribution can be dropped directly into any other model.
A key difference between distributions in pomegranate v1.0.0 and earlier versions is that distributions in earlier versions were usually univariate, i.e., one object represented one dimension, whereas in v1.0.0 and later every distribution is multivariate. Modeling several dimensions in earlier versions required wrapping several distributions in an IndependentComponentsDistribution, whereas in v1.0.0 a single distribution object suffices. Note that IndependentComponents is still available so that you can model different dimensions with different distributions.
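For instance, here is a minimal sketch of modeling different columns with different distribution types; it assumes the IndependentComponents constructor simply takes a list of distributions, one per dimension:
[ ]:
from pomegranate.distributions import Normal, Exponential, IndependentComponents
# Sketch: the first column is modeled by a one-dimensional Normal and the
# second by an Exponential; the columns are treated as independent.
d = IndependentComponents([
    Normal([0.3], [1.1], covariance_type='diag'),
    Exponential([2.0]),
])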
[1]:
import numpy
import torch
from pomegranate.distributions import *
numpy.random.seed(0)
numpy.set_printoptions(suppress=True)
%load_ext watermark
%watermark -m -n -p numpy,torch,pomegranate
numpy : 1.23.4
torch : 1.13.0
pomegranate: 1.0.0
Compiler : GCC 11.2.0
OS : Linux
Release : 4.15.0-208-generic
Machine : x86_64
Processor : x86_64
CPU cores : 8
Architecture: 64bit
Initialization and Fitting
First, let's look at how to create a probability distribution. This is simple if you know which parameters you want to pass in. The parameters can be lists, tuples, numpy arrays, or torch tensors.
[2]:
d1 = Normal([0.3, 0.7, 1.1], [1.1, 0.3, 1.8], covariance_type='diag')
d2 = Exponential([0.8, 1.4, 4.1])
d3 = Categorical([[0.3, 0.2, 0.5], [0.2, 0.1, 0.7]])

# The same Normal distribution can equivalently be constructed from tuples,
# numpy arrays, or torch tensors.
d11 = Normal((0.3, 0.7, 1.1), (1.1, 0.3, 1.8), covariance_type='diag')
d12 = Normal(numpy.array([0.3, 0.7, 1.1]), numpy.array([1.1, 0.3, 1.8]), covariance_type='diag')
d13 = Normal(torch.tensor([0.3, 0.7, 1.1]), torch.tensor([1.1, 0.3, 1.8]), covariance_type='diag')
If you do not know the parameters, you can learn them directly from data. Previously, this was done with the Distribution.from_samples method. However, because pomegranate v1.0.0 aims to be more similar to sklearn, learning directly from data should now be done with fit. This derives the parameters from the data using maximum likelihood estimation (MLE).
[3]:
torch.manual_seed(0)
X = torch.randn(100, 3)
d4 = Normal().fit(X)
d4.means, d4.covs
[3]:
(Parameter containing:
tensor([-0.0132, -0.0643, 0.0985]),
Parameter containing:
tensor([[ 0.8174, 0.0668, -0.0590],
[ 0.0668, 0.7918, 0.1045],
[-0.0590, 0.1045, 0.9713]]))
[4]:
X2 = torch.randint(3, size=(20, 4))
d5 = Categorical().fit(X2)
d5.probs
[4]:
Parameter containing:
tensor([[0.3500, 0.2000, 0.4500],
[0.3500, 0.3500, 0.3000],
[0.4500, 0.3500, 0.2000],
[0.2500, 0.3500, 0.4000]])
Similar to sklearn, any hyperparameters used during training, such as regularization, are passed in at initialization.
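For example, here is a sketch of supplying smoothing at construction time; it assumes a pseudocount keyword on Categorical, which is a hedged reading of the v1.0.0 API rather than a confirmed signature:
[ ]:
# Sketch: the pseudocount is added to every category's count during fitting,
# so no category ends up with zero probability after fit.
d6 = Categorical(pseudocount=1.0).fit(X2)
d6.probs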
Probabilities and Log Probabilities
All distributions can calculate probabilities and log probabilities using the respective methods.
[5]:
d4.log_probability(X)
[5]:
tensor([-3.9452, -3.2879, -5.3004, -3.6380, -3.9600, -4.9730, -3.2313, -5.4351,
-3.0938, -4.7396, -3.6861, -2.6550, -2.8112, -2.9265, -2.6482, -8.2887,
-3.7147, -2.6614, -2.8981, -9.5658, -6.2381, -3.2002, -5.7639, -6.9646,
-4.4075, -3.8988, -3.0689, -3.2529, -3.6521, -5.3077, -5.5544, -3.2166,
-5.6651, -7.9825, -2.6263, -2.6650, -3.4593, -6.5449, -2.8980, -3.0915,
-4.5713, -3.1680, -4.8918, -3.0811, -4.6555, -3.1913, -3.5364, -3.1703,
-2.5797, -3.4614, -2.5375, -4.8910, -2.9253, -3.9987, -3.0313, -3.2010,
-2.6444, -3.2952, -3.7149, -3.9957, -4.4953, -3.8348, -4.1071, -4.5762,
-2.9732, -2.9576, -3.4012, -3.4736, -3.9769, -3.7505, -4.5513, -4.0950,
-4.5067, -2.7840, -3.3281, -4.1321, -2.9699, -3.8536, -3.9683, -5.8055,
-5.3984, -4.9514, -2.7441, -3.8885, -4.5353, -3.0082, -2.8207, -3.3852,
-3.9225, -3.7536, -6.9391, -3.0570, -5.8579, -3.4830, -2.6783, -5.0286,
-2.9454, -3.4192, -3.8757, -4.4241])
[6]:
d1.log_probability(X)
[6]:
tensor([ -9.6459, -2.8243, -16.0460, -3.9857, -2.9930, -11.8263, -2.6130,
-6.0255, -5.7931, -3.3603, -7.4735, -5.1014, -3.1393, -3.6045,
-3.4457, -14.1638, -2.9771, -4.5638, -5.1863, -5.1922, -10.3993,
-5.5200, -5.2215, -7.2889, -7.4847, -7.9908, -2.9989, -6.8441,
-4.6477, -4.3911, -6.8748, -3.9965, -10.5521, -22.9875, -3.1194,
-2.8532, -6.6198, -8.0589, -3.4627, -7.2507, -5.3280, -3.2750,
-4.5530, -6.5848, -2.8317, -4.6167, -9.5592, -5.2165, -3.4062,
-3.2597, -3.9544, -14.5495, -3.4490, -3.8333, -3.5855, -2.8570,
-3.3047, -5.5304, -10.1993, -11.8056, -3.3747, -8.7955, -3.4717,
-10.6717, -3.0119, -2.9799, -3.2086, -3.6065, -8.6801, -3.4716,
-4.2680, -6.6669, -4.1253, -3.1685, -5.0236, -3.8058, -3.1228,
-5.6273, -3.9447, -5.2440, -14.2746, -4.6809, -4.1667, -3.5050,
-3.8123, -4.4155, -4.7357, -5.1111, -3.4382, -6.3055, -6.9832,
-2.7879, -5.8146, -3.9857, -4.3523, -5.0716, -4.9841, -6.1210,
-3.9729, -10.4107])
[7]:
d4.probability(X)
[7]:
tensor([1.9348e-02, 3.7332e-02, 4.9895e-03, 2.6305e-02, 1.9064e-02, 6.9225e-03,
3.9505e-02, 4.3607e-03, 4.5328e-02, 8.7420e-03, 2.5069e-02, 7.0302e-02,
6.0135e-02, 5.3586e-02, 7.0780e-02, 2.5134e-04, 2.4362e-02, 6.9850e-02,
5.5130e-02, 7.0085e-05, 1.9536e-03, 4.0754e-02, 3.1389e-03, 9.4471e-04,
1.2186e-02, 2.0266e-02, 4.6471e-02, 3.8663e-02, 2.5936e-02, 4.9533e-03,
3.8702e-03, 4.0090e-02, 3.4649e-03, 3.4138e-04, 7.2346e-02, 6.9601e-02,
3.1452e-02, 1.4374e-03, 5.5134e-02, 4.5435e-02, 1.0344e-02, 4.2086e-02,
7.5080e-03, 4.5910e-02, 9.5090e-03, 4.1118e-02, 2.9118e-02, 4.1990e-02,
7.5798e-02, 3.1386e-02, 7.9060e-02, 7.5137e-03, 5.3647e-02, 1.8340e-02,
4.8254e-02, 4.0721e-02, 7.1046e-02, 3.7061e-02, 2.4357e-02, 1.8395e-02,
1.1162e-02, 2.1606e-02, 1.6455e-02, 1.0294e-02, 5.1139e-02, 5.1945e-02,
3.3334e-02, 3.1006e-02, 1.8743e-02, 2.3507e-02, 1.0554e-02, 1.6656e-02,
1.1035e-02, 6.1790e-02, 3.5861e-02, 1.6049e-02, 5.1307e-02, 2.1204e-02,
1.8905e-02, 3.0109e-03, 4.5240e-03, 7.0733e-03, 6.4304e-02, 2.0475e-02,
1.0723e-02, 4.9382e-02, 5.9563e-02, 3.3871e-02, 1.9791e-02, 2.3433e-02,
9.6916e-04, 4.7028e-02, 2.8573e-03, 3.0714e-02, 6.8677e-02, 6.5477e-03,
5.2580e-02, 3.2739e-02, 2.0740e-02, 1.1985e-02])
[8]:
d1.probability(X)
[8]:
tensor([6.4691e-05, 5.9351e-02, 1.0748e-07, 1.8580e-02, 5.0137e-02, 7.3099e-06,
7.3312e-02, 2.4163e-03, 3.0485e-03, 3.4725e-02, 5.6796e-04, 6.0884e-03,
4.3315e-02, 2.7200e-02, 3.1883e-02, 7.0588e-07, 5.0939e-02, 1.0422e-02,
5.5924e-03, 5.5596e-03, 3.0455e-05, 4.0060e-03, 5.3992e-03, 6.8307e-04,
5.6159e-04, 3.3856e-04, 4.9844e-02, 1.0658e-03, 9.5840e-03, 1.2387e-02,
1.0335e-03, 1.8379e-02, 2.6139e-05, 1.0391e-10, 4.4183e-02, 5.7661e-02,
1.3337e-03, 3.1628e-04, 3.1344e-02, 7.0967e-04, 4.8536e-03, 3.7815e-02,
1.0535e-02, 1.3812e-03, 5.8910e-02, 9.8850e-03, 7.0553e-05, 5.4264e-03,
3.3168e-02, 3.8402e-02, 1.9170e-02, 4.7997e-07, 3.1776e-02, 2.1639e-02,
2.7723e-02, 5.7439e-02, 3.6711e-02, 3.9645e-03, 3.7195e-05, 7.4626e-06,
3.4229e-02, 1.5142e-04, 3.1063e-02, 2.3192e-05, 4.9196e-02, 5.0799e-02,
4.0414e-02, 2.7148e-02, 1.6994e-04, 3.1066e-02, 1.4009e-02, 1.2723e-03,
1.6158e-02, 4.2068e-02, 6.5808e-03, 2.2242e-02, 4.4035e-02, 3.5983e-03,
1.9357e-02, 5.2789e-03, 6.3185e-07, 9.2710e-03, 1.5503e-02, 3.0046e-02,
2.2097e-02, 1.2088e-02, 8.7762e-03, 6.0295e-03, 3.2123e-02, 1.8262e-03,
9.2731e-04, 6.1549e-02, 2.9838e-03, 1.8579e-02, 1.2878e-02, 6.2726e-03,
6.8462e-03, 2.1963e-03, 1.8819e-02, 3.0109e-05])
Similar to initialization, the inputs can be lists, numpy arrays, or torch tensors.
[9]:
d4.log_probability(torch.tensor(X))
/tmp/ipykernel_88680/2372227.py:1: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
d4.log_probability(torch.tensor(X))
[9]:
tensor([-3.9452, -3.2879, -5.3004, -3.6380, -3.9600, -4.9730, -3.2313, -5.4351,
-3.0938, -4.7396, -3.6861, -2.6550, -2.8112, -2.9265, -2.6482, -8.2887,
-3.7147, -2.6614, -2.8981, -9.5658, -6.2381, -3.2002, -5.7639, -6.9646,
-4.4075, -3.8988, -3.0689, -3.2529, -3.6521, -5.3077, -5.5544, -3.2166,
-5.6651, -7.9825, -2.6263, -2.6650, -3.4593, -6.5449, -2.8980, -3.0915,
-4.5713, -3.1680, -4.8918, -3.0811, -4.6555, -3.1913, -3.5364, -3.1703,
-2.5797, -3.4614, -2.5375, -4.8910, -2.9253, -3.9987, -3.0313, -3.2010,
-2.6444, -3.2952, -3.7149, -3.9957, -4.4953, -3.8348, -4.1071, -4.5762,
-2.9732, -2.9576, -3.4012, -3.4736, -3.9769, -3.7505, -4.5513, -4.0950,
-4.5067, -2.7840, -3.3281, -4.1321, -2.9699, -3.8536, -3.9683, -5.8055,
-5.3984, -4.9514, -2.7441, -3.8885, -4.5353, -3.0082, -2.8207, -3.3852,
-3.9225, -3.7536, -6.9391, -3.0570, -5.8579, -3.4830, -2.6783, -5.0286,
-2.9454, -3.4192, -3.8757, -4.4241])
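As a quick check, passing the same data as a numpy array or a plain nested list should give identical values; a short sketch:
[ ]:
# Equivalent calls with a numpy array and a nested Python list;
# both should match the torch tensor results above.
d4.log_probability(X.numpy())
d4.log_probability(X.numpy().tolist())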
Summarization
Although the primary way to learn parameters from data is the fit method, the underlying engine for this learning is a pair of operations: summarize and from_summaries. In summarize, data is condensed into additive sufficient statistics that can be summed across batches. For example:
[10]:
d = Normal()
d.summarize(X[:5])
d._xw_sum
[10]:
tensor([-3.0192, -0.8312, 2.1886])
[11]:
d.summarize(X[5:])
d._xw_sum
[11]:
tensor([-1.3155, -6.4282, 9.8475])
If we had summarized the entire data set at once, the values would be identical.
[12]:
d_full = Normal()
d_full.summarize(X)
d_full._xw_sum
[12]:
tensor([-1.3155, -6.4282, 9.8475])
From these values, which are generally stored as _w_sum and _xw_sum, you can perfectly recreate the parameters you would get from fitting the entire data set. You can do this with the from_summaries method.
[13]:
d.from_summaries()
d_full.from_summaries()
d.means, d_full.means
[13]:
(Parameter containing:
tensor([-0.0132, -0.0643, 0.0985]),
Parameter containing:
tensor([-0.0132, -0.0643, 0.0985]))
We will explore these ideas in more detail in other tutorials, in particular how these methods make it easy to implement batching schemes for out-of-core learning and for fitting large data sets with limited GPU memory.
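As a preview, here is a minimal sketch of that batching pattern; the slices of X stand in for batches that would be streamed from disk, and the batch size of 25 is arbitrary:
[ ]:
# Sketch: out-of-core fitting by summarizing one batch at a time.
# Each summarize call adds to the additive sufficient statistics,
# and from_summaries then derives the MLE parameters from the totals.
d_batch = Normal()
for start in range(0, len(X), 25):
    d_batch.summarize(X[start:start+25])

d_batch.from_summaries()
d_batch.means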