截断#

class pymc.Truncated(name, *args, rng=None, dims=None, initval=None, observed=None, total_size=None, transform=UNSET, **kwargs)[源代码]#

截断分布

截断分布的pdf是

\[\begin{split}\begin{cases} 0 & \text{对于 } x < lower, \\ \frac{\text{PDF}(x, dist)}{\text{CDF}(upper, dist) - \text{CDF}(lower, dist)} & \text{对于 } lower <= x <= upper, \\ 0 & \text{对于 } x > upper, \end{cases}\end{split}\]
参数:
dist: 未命名的发行版

通过 .dist() API 创建的单变量分布将被截断。该分布必须是一个纯 RandomVariable,并且必须实现 logcdf 方法以用于 MCMC 采样。

警告

dist 将被克隆,使其独立于作为输入传递的那个。

lower: tensor_like of float or None

下(左)截断点。如果为 None,则分布不会被左截断。

upper: tensor_like of float or None

上(右)截断点。如果为 None,分布将不会在右侧被截断。

max_n_steps: int, 默认为 10_000

在进行拒绝采样时尝试的最大重采样次数。如果在这么多步骤后仍未达到收敛,则会引发 TruncationError

返回:
truncated_distribution: TensorVariable

表示截断 RandomVariable 的图。如果分布的 Op 有一个分派的 _truncated 函数,则可以使用专门的 Op。否则,将返回一个表示截断过程的 SymbolicRandomVariable 图,通过逆 CDF 采样(如果底层分布有 logcdf 方法)或拒绝采样。

示例

with pm.Model():
    normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
    truncated_normal = pm.Truncated("truncated_normal", normal_dist, lower=-1, upper=1)

方法

Truncated.dist(dist[, lower, upper, max_n_steps])

创建一个与 cls 分布相对应的张量变量。