# Probability distributions - torch.distributions

> Translator: [hijkzzz](https://github.com/hijkzzz)

The `distributions` package contains parameterizable probability distributions and sampling functions. This allows the construction of stochastic computation graphs and stochastic gradient estimators for optimization. This package generally follows the design of the [TensorFlow Distributions](https://arxiv.org/abs/1711.10604) package.

It is usually not possible to backpropagate directly through random samples. However, there are two main methods for creating surrogate functions that can be backpropagated through: the score function estimator (also called the likelihood ratio estimator or REINFORCE) and the pathwise derivative estimator. REINFORCE is commonly seen as the basis for policy gradient methods in reinforcement learning, and the pathwise derivative estimator is commonly seen in the reparameterization trick in variational autoencoders. Whereas the score function only requires the value of samples ![](img/cb804637f7fdaaf91569cfe4f047b418.jpg), the pathwise derivative requires the derivative ![](img/385dbaaac9dd8aad33acc31ac64d2f27.jpg). The next sections discuss these two methods in a reinforcement learning example. For more details, see [Gradient Estimation Using Stochastic Computation Graphs](https://arxiv.org/abs/1506.05254).
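The difference between the two estimators can be sketched with a short example (the `Normal` usage below is illustrative, not part of the original text):

```py
import torch
from torch.distributions import Normal

loc = torch.tensor(0.0, requires_grad=True)
scale = torch.tensor(1.0, requires_grad=True)
m = Normal(loc, scale)

# Pathwise derivative: rsample() draws x = loc + scale * eps, so the
# sample itself is differentiable w.r.t. the distribution parameters.
x = m.rsample()
x.backward()
assert loc.grad is not None

# Plain sampling gives no gradient path back to loc/scale; the score
# function estimator instead differentiates log_prob of the sample.
y = m.sample()
assert not y.requires_grad
surrogate = m.log_prob(y)  # differentiable w.r.t. loc and scale
```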
## Score function

When the probability density function is differentiable with respect to its parameters, we only need `sample()` and `log_prob()` to implement REINFORCE:

![](img/b50e881c13615b1d9aa00ad0c9cdfa99.jpg)

where ![](img/51b8359f970d2bfe2ad4cdc3ac1aed3c.jpg) are the parameters, ![](img/82005cc2e0087e2a52c7e43df4a19a00.jpg) is the learning rate, ![](img/f9f040e861365a0560b2552b4e4e17da.jpg) is the reward, and ![](img/2e84bb32ea0808870a16b888aeaf8d0d.jpg) is the probability of taking action ![](img/070b1af5eca3a5c5d72884b536090f17.jpg) in state ![](img/0492c0bfd615cb5e61c847ece512ff51.jpg) given policy ![](img/5f3ddae3395c04f9346a3ac1d327ae2a.jpg).
In practice we would sample an action from the output of a network, apply that action in an environment, and then use `log_prob` to construct an equivalent loss function. Note that we use a negative sign because optimizers use gradient descent, whereas the rule above assumes gradient ascent. With a categorical policy, the code for implementing REINFORCE would be as follows:

```py
probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

```

## Pathwise derivative

The other way to implement these stochastic/policy gradients would be to use the reparameterization trick from the `rsample()` method, where the parameterized random variable can be constructed via a parameterized deterministic function of a parameter-free random variable. The reparameterized sample is therefore differentiable. The code for implementing the pathwise derivative would be as follows:

```py
params = policy_network(state)
m = Normal(*params)
# Any distribution with .has_rsample == True could work based on the application
action = m.rsample()
next_state, reward = env.step(action)  # Assuming that reward is differentiable
loss = -reward
loss.backward()

```

## Distribution

```py
class torch.distributions.distribution.Distribution(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)
```

Bases: [`object`](https://docs.python.org/3/library/functions.html#object "(in Python v3.7)")

Distribution is the abstract base class for probability distributions.

```py
arg_constraints
```

Returns a dictionary from argument names to [`Constraint`](#torch.distributions.constraints.Constraint "torch.distributions.constraints.Constraint") objects that should be satisfied by each argument of this distribution. Args that are not tensors need not appear in this dict.

```py
batch_shape
```

Returns the shape over which parameters are batched.

```py
cdf(value)
```

Returns the cumulative density/mass function evaluated at `value`.

| Parameters: | **value** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – |


```py
entropy()
```

Returns the entropy of the distribution, batched over batch_shape.

| Returns: | Tensor of shape batch_shape. |


```py
enumerate_support(expand=True)
```

Returns a tensor containing all values supported by a discrete distribution. The result will enumerate over dimension 0, so the shape of the result will be `(cardinality,) + batch_shape + event_shape` (where `event_shape = ()` for univariate distributions).

Note that this enumerates over all batched tensors in lock-step `[[0, 0], [1, 1], …]`. With `expand=False`, enumeration happens along dim 0, but with the remaining batch dimensions being singleton dimensions, `[[0], [1], ..`.

To iterate over the full Cartesian product use `itertools.product(m.enumerate_support())`.

| Parameters: | **expand** ([_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")) – whether to expand the support over the batch dims to match the distribution's `batch_shape`. |

| Returns: | Tensor iterating over dimension 0. |
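The lock-step enumeration described above can be seen with a batched `Bernoulli` (a small sketch; the parameter values are illustrative):

```py
import torch
from torch.distributions import Bernoulli

# Two batched Bernoulli distributions; each has support {0, 1}.
m = Bernoulli(torch.tensor([0.3, 0.9]))

full = m.enumerate_support()                 # (cardinality,) + batch_shape = (2, 2)
compact = m.enumerate_support(expand=False)  # batch dim left as a singleton: (2, 1)

print(full.tolist())     # [[0.0, 0.0], [1.0, 1.0]]
print(compact.tolist())  # [[0.0], [1.0]]
```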


```py
event_shape
```

Returns the shape of a single sample (without batching).

```py
expand(batch_shape, _instance=None)
```

Returns a new distribution instance (or populates an existing instance provided by a derived class) with batch dimensions expanded to `batch_shape`. This method calls [`expand`](tensors.html#torch.Tensor.expand "torch.Tensor.expand") on the distribution's parameters. As such, this does not allocate new memory for the expanded distribution instance. Additionally, this does not repeat any args checking or parameter broadcasting in `__init__.py` when an instance is first created.

Parameters:

*   **batch_shape** (_torch.Size_) – the desired expanded size.
*   **_instance** – new instance provided by subclasses that need to override `.expand`.


| Returns: | New distribution instance with batch dimensions expanded to `batch_size`. |
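A short sketch of the expansion (parameter values are illustrative):

```py
import torch
from torch.distributions import Normal

m = Normal(torch.tensor(0.0), torch.tensor(1.0))  # batch_shape = ()
expanded = m.expand(torch.Size([3, 2]))           # no new parameter memory allocated

print(expanded.batch_shape)     # torch.Size([3, 2])
print(expanded.sample().shape)  # torch.Size([3, 2])
```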


```py
icdf(value)
```

Returns the inverse cumulative density/mass function evaluated at `value`.

| Parameters: | **value** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – |


```py
log_prob(value)
```

Returns the log of the probability density/mass function evaluated at `value`.

| Parameters: | **value** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – |


```py
mean
```

Returns the mean of the distribution.

```py
perplexity()
```

Returns the perplexity of the distribution, batched over batch_shape.

| Returns: | Tensor of shape batch_shape. |


```py
rsample(sample_shape=torch.Size([]))
```

Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched.

```py
sample(sample_shape=torch.Size([]))
```

Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.
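The interaction between `sample_shape` and batched parameters can be sketched as follows (values are illustrative):

```py
import torch
from torch.distributions import Normal

# Batched parameters: batch_shape = (2,)
m = Normal(torch.tensor([0.0, 10.0]), torch.tensor([1.0, 1.0]))

s = m.sample()                 # shape (2,): one draw per batch element
t = m.sample(torch.Size([5]))  # shape (5, 2): sample_shape + batch_shape
print(s.shape, t.shape)
```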

```py
sample_n(n)
```

Generates n samples or n batches of samples if the distribution parameters are batched.

```py
stddev
```

Returns the standard deviation of the distribution.

```py
support
```

Returns a [`Constraint`](#torch.distributions.constraints.Constraint "torch.distributions.constraints.Constraint") object representing this distribution's support.

```py
variance
```

Returns the variance of the distribution.

## ExponentialFamily

```py
class torch.distributions.exp_family.ExponentialFamily(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)
```

Bases: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

ExponentialFamily is the abstract base class for probability distributions belonging to an exponential family, whose probability mass/density function has the form

![](img/0c8313886f5c82dfae90e21b65152815.jpg)

where ![](img/51b8359f970d2bfe2ad4cdc3ac1aed3c.jpg) denotes the natural parameters, ![](img/e705d3772de12f4df3b0cd75af5110a1.jpg) denotes the sufficient statistic, ![](img/f876c4d8353c747436006e70fb6c4f5d.jpg) is the log normalizer function for a given family, and ![](img/d3b6af2f20ffbc8480c6ee97c42958b2.jpg) is the carrier measure.

Note

This class is an intermediary between the `Distribution` class and distributions which belong to an exponential family, mainly to check the correctness of the `.entropy()` and analytic KL divergence methods. We use this class to compute the entropy and KL divergence using the AD framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and Cross-entropies of Exponential Families).

```py
entropy()
```

Method to compute the entropy using the Bregman divergence of the log normalizer.

## Bernoulli

```py
class torch.distributions.bernoulli.Bernoulli(probs=None, logits=None, validate_args=None)
```

Bases: [`torch.distributions.exp_family.ExponentialFamily`](#torch.distributions.exp_family.ExponentialFamily "torch.distributions.exp_family.ExponentialFamily")

Creates a Bernoulli distribution parameterized by [`probs`](#torch.distributions.bernoulli.Bernoulli.probs "torch.distributions.bernoulli.Bernoulli.probs") or [`logits`](#torch.distributions.bernoulli.Bernoulli.logits "torch.distributions.bernoulli.Bernoulli.logits") (but not both).

Samples are binary (0 or 1). They take the value `1` with probability `p` and `0` with probability `1 - p`.

Example:

```py
>>> m = Bernoulli(torch.tensor([0.3]))
>>> m.sample()  # 30% chance 1; 70% chance 0
tensor([ 0.])

```

Parameters:

*   **probs** (_Number__,_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – the probability of sampling `1`
*   **logits** (_Number__,_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – the log-odds of sampling `1`

```py
arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}
```

```py
entropy()
```

```py
enumerate_support(expand=True)
```

```py
expand(batch_shape, _instance=None)
```

```py
has_enumerate_support = True
```

```py
log_prob(value)
```

```py
logits
```

```py
mean
```

```py
param_shape
```

```py
probs
```

```py
sample(sample_shape=torch.Size([]))
```

```py
support = Boolean()
```

```py
variance
```

## Beta

```py
class torch.distributions.beta.Beta(concentration1, concentration0, validate_args=None)
```

Bases: [`torch.distributions.exp_family.ExponentialFamily`](#torch.distributions.exp_family.ExponentialFamily "torch.distributions.exp_family.ExponentialFamily")

Beta distribution parameterized by [`concentration1`](#torch.distributions.beta.Beta.concentration1 "torch.distributions.beta.Beta.concentration1") and [`concentration0`](#torch.distributions.beta.Beta.concentration0 "torch.distributions.beta.Beta.concentration0").

Example:

```py
>>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
>>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
tensor([ 0.1046])

```

Parameters:

*   **concentration1** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 1st concentration parameter of the distribution (often referred to as alpha)
*   **concentration0** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 2nd concentration parameter of the distribution (often referred to as beta)

```py
arg_constraints = {'concentration0': GreaterThan(lower_bound=0.0), 'concentration1': GreaterThan(lower_bound=0.0)}
```

```py
concentration0
```

```py
concentration1
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
log_prob(value)
```

```py
mean
```

```py
rsample(sample_shape=())
```

```py
support = Interval(lower_bound=0.0, upper_bound=1.0)
```

```py
variance
```

## Binomial

```py
class torch.distributions.binomial.Binomial(total_count=1, probs=None, logits=None, validate_args=None)
```

Bases: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

Creates a Binomial distribution parameterized by `total_count` and either [`probs`](#torch.distributions.binomial.Binomial.probs "torch.distributions.binomial.Binomial.probs") or [`logits`](#torch.distributions.binomial.Binomial.logits "torch.distributions.binomial.Binomial.logits") (but not both). `total_count` must be broadcastable with [`probs`](#torch.distributions.binomial.Binomial.probs "torch.distributions.binomial.Binomial.probs")/[`logits`](#torch.distributions.binomial.Binomial.logits "torch.distributions.binomial.Binomial.logits").

Example:

```py
>>> m = Binomial(100, torch.tensor([0 , .2, .8, 1]))
>>> x = m.sample()
tensor([   0.,   22.,   71.,  100.])

>>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))
>>> x = m.sample()
tensor([[ 4.,  5.],
 [ 7.,  6.]])

```

Parameters:

*   **total_count** ([_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – number of Bernoulli trials
*   **probs** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – event probabilities
*   **logits** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – event log-odds

```py
arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0), 'total_count': IntegerGreaterThan(lower_bound=0)}
```

```py
enumerate_support(expand=True)
```

```py
expand(batch_shape, _instance=None)
```

```py
has_enumerate_support = True
```

```py
log_prob(value)
```

```py
logits
```

```py
mean
```

```py
param_shape
```

```py
probs
```

```py
sample(sample_shape=torch.Size([]))
```

```py
support
```

```py
variance
```

## Categorical

```py
class torch.distributions.categorical.Categorical(probs=None, logits=None, validate_args=None)
```

Bases: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

Creates a categorical distribution parameterized by either [`probs`](#torch.distributions.categorical.Categorical.probs "torch.distributions.categorical.Categorical.probs") or [`logits`](#torch.distributions.categorical.Categorical.logits "torch.distributions.categorical.Categorical.logits") (but not both).

Note

It is equivalent to the distribution that [`torch.multinomial()`](torch.html#torch.multinomial "torch.multinomial") samples from.

Samples are integers from ![](img/7c6904e60a8ff7044a079e10eaee1f57.jpg) where `K` is `probs.size(-1)`.

If [`probs`](#torch.distributions.categorical.Categorical.probs "torch.distributions.categorical.Categorical.probs") is 1D with length `K`, each element is the relative probability of sampling the class at that index.

If [`probs`](#torch.distributions.categorical.Categorical.probs "torch.distributions.categorical.Categorical.probs") is 2D, it is treated as a batch of relative probability vectors.

Note

[`probs`](#torch.distributions.categorical.Categorical.probs "torch.distributions.categorical.Categorical.probs") must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1.

See also: [`torch.multinomial()`](torch.html#torch.multinomial "torch.multinomial")

Example:

```py
>>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
>>> m.sample()  # equal probability of 0, 1, 2, 3
tensor(3)

```
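The 2D `probs` behavior and the normalization described above can be sketched as follows (the probability values are illustrative):

```py
import torch
from torch.distributions import Categorical

# 2D probs form a batch of relative probability vectors; rows need not
# sum to 1, since they are normalized internally.
m = Categorical(torch.tensor([[1.0, 1.0, 2.0],
                              [0.0, 3.0, 1.0]]))

print(m.batch_shape)     # torch.Size([2])
print(m.probs.sum(-1))   # each row normalized to sum to 1
print(m.sample().shape)  # torch.Size([2]): one class index per row
```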

Parameters:

*   **probs** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – event probabilities
*   **logits** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – event log probabilities

```py
arg_constraints = {'logits': Real(), 'probs': Simplex()}
```

```py
entropy()
```

```py
enumerate_support(expand=True)
```

```py
expand(batch_shape, _instance=None)
```

```py
has_enumerate_support = True
```

```py
log_prob(value)
```

```py
logits
```

```py
mean
```

```py
param_shape
```

```py
probs
```

```py
sample(sample_shape=torch.Size([]))
```

```py
support
```

```py
variance
```

## Cauchy

```py
class torch.distributions.cauchy.Cauchy(loc, scale, validate_args=None)
```

Bases: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of independent normally distributed random variables with means `0` follows a Cauchy distribution.

Example:

```py
>>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Cauchy distribution with loc=0 and scale=1
tensor([ 2.3214])

```

Parameters:

*   **loc** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – mode or median of the distribution.
*   **scale** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – half width at half maximum.


```py
arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
```

```py
cdf(value)
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
icdf(value)
```

```py
log_prob(value)
```

```py
mean
```

```py
rsample(sample_shape=torch.Size([]))
```

```py
support = Real()
```

```py
variance
```

## Chi2

```py
class torch.distributions.chi2.Chi2(df, validate_args=None)
```

Bases: [`torch.distributions.gamma.Gamma`](#torch.distributions.gamma.Gamma "torch.distributions.gamma.Gamma")

Creates a Chi2 distribution parameterized by the shape parameter [`df`](#torch.distributions.chi2.Chi2.df "torch.distributions.chi2.Chi2.df"). This is exactly equivalent to `Gamma(alpha=0.5*df, beta=0.5)`.

Example:

```py
>>> m = Chi2(torch.tensor([1.0]))
>>> m.sample()  # Chi2 distributed with shape df=1
tensor([ 0.1046])

```

| Parameters: | **df** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – shape parameter of the distribution |
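The stated equivalence with `Gamma(alpha=0.5*df, beta=0.5)` can be checked on the moments (a small sketch; the `df` value is illustrative):

```py
import torch
from torch.distributions import Chi2, Gamma

df = torch.tensor([4.0])
chi2 = Chi2(df)
gamma = Gamma(0.5 * df, torch.tensor([0.5]))

# A Chi2(df) has mean df and variance 2*df, matching Gamma(df/2, 1/2).
print(chi2.mean, gamma.mean)          # both tensor([4.])
print(chi2.variance, gamma.variance)  # both tensor([8.])
```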


```py
arg_constraints = {'df': GreaterThan(lower_bound=0.0)}
```

```py
df
```

```py
expand(batch_shape, _instance=None)
```

## Dirichlet

```py
class torch.distributions.dirichlet.Dirichlet(concentration, validate_args=None)
```

Bases: [`torch.distributions.exp_family.ExponentialFamily`](#torch.distributions.exp_family.ExponentialFamily "torch.distributions.exp_family.ExponentialFamily")

Creates a Dirichlet distribution parameterized by `concentration`.

Example:

```py
>>> m = Dirichlet(torch.tensor([0.5, 0.5]))
>>> m.sample()  # Dirichlet distributed with the given concentration
tensor([ 0.1046,  0.8954])

```

| Parameters: | **concentration** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – concentration parameter of the distribution (often referred to as alpha) |


```py
arg_constraints = {'concentration': GreaterThan(lower_bound=0.0)}
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
log_prob(value)
```

```py
mean
```

```py
rsample(sample_shape=())
```

```py
support = Simplex()
```

```py
variance
```

## Exponential

```py
class torch.distributions.exponential.Exponential(rate, validate_args=None)
```

Bases: [`torch.distributions.exp_family.ExponentialFamily`](#torch.distributions.exp_family.ExponentialFamily "torch.distributions.exp_family.ExponentialFamily")

Creates an Exponential distribution parameterized by `rate`.

Example:

```py
>>> m = Exponential(torch.tensor([1.0]))
>>> m.sample()  # Exponential distributed with rate=1
tensor([ 0.1046])

```

| Parameters: | **rate** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – rate = 1 / scale of the distribution |


```py
arg_constraints = {'rate': GreaterThan(lower_bound=0.0)}
```

```py
cdf(value)
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
icdf(value)
```

```py
log_prob(value)
```

```py
mean
```

```py
rsample(sample_shape=torch.Size([]))
```

```py
stddev
```

```py
support = GreaterThan(lower_bound=0.0)
```

```py
variance
```

## FisherSnedecor

```py
class torch.distributions.fishersnedecor.FisherSnedecor(df1, df2, validate_args=None)
```

Bases: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

Creates a Fisher-Snedecor distribution parameterized by `df1` and `df2`.

Example:

```py
>>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0]))
>>> m.sample()  # Fisher-Snedecor-distributed with df1=1 and df2=2
tensor([ 0.2453])

```

Parameters:

*   **df1** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – degrees of freedom parameter 1
*   **df2** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – degrees of freedom parameter 2

```py
arg_constraints = {'df1': GreaterThan(lower_bound=0.0), 'df2': GreaterThan(lower_bound=0.0)}
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
log_prob(value)
```

```py
mean
```

```py
rsample(sample_shape=torch.Size([]))
```

```py
support = GreaterThan(lower_bound=0.0)
```

```py
variance
```

## Gamma

```py
class torch.distributions.gamma.Gamma(concentration, rate, validate_args=None)
```

Bases: [`torch.distributions.exp_family.ExponentialFamily`](#torch.distributions.exp_family.ExponentialFamily "torch.distributions.exp_family.ExponentialFamily")

Creates a Gamma distribution parameterized by `concentration` and `rate`.

Example:

```py
>>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # Gamma distributed with concentration=1 and rate=1
tensor([ 0.1046])

```

Parameters:

*   **concentration** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – shape parameter of the distribution (often referred to as alpha)
*   **rate** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – rate = 1 / scale of the distribution (often referred to as beta)

```py
arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'rate': GreaterThan(lower_bound=0.0)}
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
log_prob(value)
```

```py
mean
```

```py
rsample(sample_shape=torch.Size([]))
```

```py
support = GreaterThan(lower_bound=0.0)
```

```py
variance
```

## Geometric

```py
class torch.distributions.geometric.Geometric(probs=None, logits=None, validate_args=None)
```

Bases: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

Creates a Geometric distribution parameterized by `probs`, where `probs` is the probability of success of Bernoulli trials. It represents the probability that in ![](img/10396db36bab7b7242cfe94f04374444.jpg) Bernoulli trials, the first ![](img/a1c2f8d5b1226e67bdb44b12a6ddf18b.jpg) trials failed before seeing a success.

Samples are non-negative integers [0, ![](img/06485c2c6e992cf346fdfe033a86a10d.jpg)).

Example:

```py
>>> m = Geometric(torch.tensor([0.3]))
>>> m.sample()  # underlying Bernoulli has 30% chance 1; 70% chance 0
tensor([ 2.])

```
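The pmf described above — k failures followed by a success — can be checked against `log_prob` (a small sketch; the values are illustrative):

```py
import torch
from torch.distributions import Geometric

p = torch.tensor([0.3])
m = Geometric(p)

# log_prob(k) should equal log((1 - p)**k * p): k failures, then a success.
k = torch.tensor([2.0])
expected = torch.log((1 - p) ** k * p)
print(torch.allclose(m.log_prob(k), expected))  # True
```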

Parameters:

*   **probs** (_Number__,_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – the probability of sampling `1`. Must be in range (0, 1]
*   **logits** (_Number__,_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – the log-odds of sampling `1`.

```py
arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
log_prob(value)
```

```py
logits
```

```py
mean
```

```py
probs
```

```py
sample(sample_shape=torch.Size([]))
```

```py
support = IntegerGreaterThan(lower_bound=0)
```

```py
variance
```

## Gumbel

```py
class torch.distributions.gumbel.Gumbel(loc, scale, validate_args=None)
```

基类: [`torch.distributions.transformed_distribution.TransformedDistribution`](#torch.distributions.transformed_distribution.TransformedDistribution "torch.distributions.transformed_distribution.TransformedDistribution")

来自 Gumbel 分布的样本.

例子:

```py
>>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
>>> m.sample()  # sample from Gumbel distribution with loc=1, scale=2
tensor([ 1.0124])

```

参数: 

*   **loc** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 分布的位置参数
*   **scale** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 分布的 scale 参数
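
作为补充, Gumbel 分布的均值为 loc + scale·γ(γ 为 Euler–Mascheroni 常数), 方差为 π²/6 · scale². 下面是一个不依赖 torch 的纯 Python 速写(仅作示意), 通过对密度函数数值积分来检验均值公式:

```py
import math

EULER_GAMMA = 0.5772156649015329  # Euler–Mascheroni 常数

def gumbel_pdf(x, loc, scale):
    # Gumbel 密度: exp(-(z + exp(-z))) / scale, 其中 z = (x - loc) / scale
    z = (x - loc) / scale
    return math.exp(-(z + math.exp(-z))) / scale

loc, scale = 1.0, 2.0
closed_form_mean = loc + scale * EULER_GAMMA
closed_form_var = math.pi ** 2 / 6.0 * scale ** 2

# 数值积分近似 E[X]
step = 0.001
xs = (loc - 20.0 + i * step for i in range(60000))
numeric_mean = sum(x * gumbel_pdf(x, loc, scale) * step for x in xs)
```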

```py
arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
mean
```

```py
stddev
```

```py
support = Real()
```

```py
variance
```

## HalfCauchy

```py
class torch.distributions.half_cauchy.HalfCauchy(scale, validate_args=None)
```

基类: [`torch.distributions.transformed_distribution.TransformedDistribution`](#torch.distributions.transformed_distribution.TransformedDistribution "torch.distributions.transformed_distribution.TransformedDistribution")

创建由`scale`参数化的半柯西分布:

```py
X ~ Cauchy(0, scale)
Y = |X| ~ HalfCauchy(scale)

```

例子:

```py
>>> m = HalfCauchy(torch.tensor([1.0]))
>>> m.sample()  # half-cauchy distributed with scale=1
tensor([ 2.3214])

```

| 参数: | **scale** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 完全柯西分布的scale |

```py
arg_constraints = {'scale': GreaterThan(lower_bound=0.0)}
```

```py
cdf(value)
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
icdf(prob)
```

```py
log_prob(value)
```

```py
mean
```

```py
scale
```

```py
support = GreaterThan(lower_bound=0.0)
```

```py
variance
```

## HalfNormal

```py
class torch.distributions.half_normal.HalfNormal(scale, validate_args=None)
```

基类: [`torch.distributions.transformed_distribution.TransformedDistribution`](#torch.distributions.transformed_distribution.TransformedDistribution "torch.distributions.transformed_distribution.TransformedDistribution")

创建按`scale`参数化的半正态分布:

```py
X ~ Normal(0, scale)
Y = |X| ~ HalfNormal(scale)

```
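
半正态分布的均值为 scale·√(2/π), 方差为 scale²·(1 − 2/π). 下面用不依赖 torch 的纯 Python(仅作示意)对折叠后的密度做数值积分, 检验归一化与均值公式:

```py
import math

def half_normal_pdf(x, scale):
    # |X| 的密度 (X ~ Normal(0, scale)), 仅在 x >= 0 上非零
    return math.sqrt(2.0 / math.pi) / scale * math.exp(-x * x / (2.0 * scale ** 2))

scale = 1.5
closed_form_mean = scale * math.sqrt(2.0 / math.pi)
closed_form_var = scale ** 2 * (1.0 - 2.0 / math.pi)

step = 0.001
xs = [i * step for i in range(20000)]  # 在 [0, 20) 上积分
numeric_mass = sum(half_normal_pdf(x, scale) * step for x in xs)
numeric_mean = sum(x * half_normal_pdf(x, scale) * step for x in xs)
```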

例子:

```py
>>> m = HalfNormal(torch.tensor([1.0]))
>>> m.sample()  # half-normal distributed with scale=1
tensor([ 0.1046])

```

| 参数: | **scale** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 完全正态分布的scale |

```py
arg_constraints = {'scale': GreaterThan(lower_bound=0.0)}
```

```py
cdf(value)
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
icdf(prob)
```

```py
log_prob(value)
```

```py
mean
```

```py
scale
```

```py
support = GreaterThan(lower_bound=0.0)
```

```py
variance
```

## Independent

```py
class torch.distributions.independent.Independent(base_distribution, reinterpreted_batch_ndims, validate_args=None)
```

基类: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

将某个分布的部分 batch 维度重新解释为 event 维度.

这主要用于改变 [`log_prob()`](#torch.distributions.independent.Independent.log_prob "torch.distributions.independent.Independent.log_prob") 结果的形状. 例如, 要创建与多元正态分布形状相同的对角正态分布(因此它们是可互换的), 您可以这样做:

```py
>>> loc = torch.zeros(3)
>>> scale = torch.ones(3)
>>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
>>> [mvn.batch_shape, mvn.event_shape]
[torch.Size(()), torch.Size((3,))]
>>> normal = Normal(loc, scale)
>>> [normal.batch_shape, normal.event_shape]
[torch.Size((3,)), torch.Size(())]
>>> diagn = Independent(normal, 1)
>>> [diagn.batch_shape, diagn.event_shape]
[torch.Size(()), torch.Size((3,))]

```
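
这种等价性可以直接从对数密度看出: 将一个 batch 维重新解释为 event 维后, `log_prob` 等于各分量一维正态对数密度之和, 也就是对角多元正态的对数密度. 下面是一个不依赖 torch 的纯 Python 验证速写(仅作示意):

```py
import math

def normal_log_prob(x, loc, scale):
    # 一维正态分布的对数密度
    return (-((x - loc) ** 2) / (2.0 * scale ** 2)
            - math.log(scale) - 0.5 * math.log(2.0 * math.pi))

x = [0.5, -1.0, 2.0]
loc = [0.0, 0.0, 0.0]
scale = [1.0, 2.0, 0.5]

# Independent(Normal, 1).log_prob(x): 在 event 维上对各分量求和
indep_lp = sum(normal_log_prob(xi, li, si) for xi, li, si in zip(x, loc, scale))

# 对角多元正态的对数密度, 直接按定义展开
k = len(x)
quad = sum((xi - li) ** 2 / si ** 2 for xi, li, si in zip(x, loc, scale))
log_det = sum(2.0 * math.log(si) for si in scale)
mvn_lp = -0.5 * (quad + log_det + k * math.log(2.0 * math.pi))
```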

参数: 

*   **base_distribution** ([_torch.distributions.distribution.Distribution_](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")) – 基础分布
*   **reinterpreted_batch_ndims** ([_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")) – 要重新解释为 event 维度的 batch 维度数量

```py
arg_constraints = {}
```

```py
entropy()
```

```py
enumerate_support(expand=True)
```

```py
expand(batch_shape, _instance=None)
```

```py
has_enumerate_support
```

```py
has_rsample
```

```py
log_prob(value)
```

```py
mean
```

```py
rsample(sample_shape=torch.Size([]))
```

```py
sample(sample_shape=torch.Size([]))
```

```py
support
```

```py
variance
```

## Laplace

```py
class torch.distributions.laplace.Laplace(loc, scale, validate_args=None)
```

基类: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

创建由 `loc` 和 `scale` 参数化的拉普拉斯分布.

例子:

```py
>>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # Laplace distributed with loc=0, scale=1
tensor([ 0.1046])

```

参数: 

*   **loc** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 分布的均值
*   **scale** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 分布的 scale 参数
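
拉普拉斯分布的密度为 exp(−|x − loc| / scale) / (2·scale), 方差为 2·scale². 下面是一个不依赖 torch 的纯 Python 数值检验速写(仅作示意):

```py
import math

def laplace_pdf(x, loc, scale):
    # 拉普拉斯分布密度
    return math.exp(-abs(x - loc) / scale) / (2.0 * scale)

loc, scale = 0.0, 1.0
step = 0.001
xs = [loc - 30.0 + i * step for i in range(60000)]  # 在 [-30, 30) 上积分
numeric_mass = sum(laplace_pdf(x, loc, scale) * step for x in xs)
numeric_var = sum((x - loc) ** 2 * laplace_pdf(x, loc, scale) * step for x in xs)
closed_form_var = 2.0 * scale ** 2
```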

```py
arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
```

```py
cdf(value)
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
icdf(value)
```

```py
log_prob(value)
```

```py
mean
```

```py
rsample(sample_shape=torch.Size([]))
```

```py
stddev
```

```py
support = Real()
```

```py
variance
```

## LogNormal

```py
class torch.distributions.log_normal.LogNormal(loc, scale, validate_args=None)
```

基类: [`torch.distributions.transformed_distribution.TransformedDistribution`](#torch.distributions.transformed_distribution.TransformedDistribution "torch.distributions.transformed_distribution.TransformedDistribution")

创建由 [`loc`](#torch.distributions.log_normal.LogNormal.loc "torch.distributions.log_normal.LogNormal.loc") 和 [`scale`](#torch.distributions.log_normal.LogNormal.scale "torch.distributions.log_normal.LogNormal.scale") 参数化的对数正态分布:

```py
X ~ Normal(loc, scale)
Y = exp(X) ~ LogNormal(loc, scale)

```

例子:

```py
>>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # log-normal distributed with mean=0 and stddev=1
tensor([ 0.1046])

```

参数: 

*   **loc** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 分布对数的平均值
*   **scale** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 分布对数的标准差
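
由上面的变换关系可以导出闭式矩: mean = exp(loc + scale²/2), variance = (exp(scale²) − 1)·exp(2·loc + scale²). 下面是一个不依赖 torch 的纯 Python 速写(仅作示意), 通过对正态密度数值积分 E[exp(X)] 来检验均值公式:

```py
import math

def normal_pdf(x, loc, scale):
    # 一维正态分布密度
    return math.exp(-((x - loc) ** 2) / (2.0 * scale ** 2)) / (scale * math.sqrt(2.0 * math.pi))

loc, scale = 0.5, 0.75
mean_closed = math.exp(loc + scale ** 2 / 2.0)
var_closed = (math.exp(scale ** 2) - 1.0) * math.exp(2.0 * loc + scale ** 2)

# 数值检验: E[exp(X)], X ~ Normal(loc, scale)
step = 0.001
xs = (loc - 10.0 + i * step for i in range(20000))
numeric_mean = sum(math.exp(x) * normal_pdf(x, loc, scale) * step for x in xs)
```

注意 exp(2·loc + scale²) 正是均值的平方, 所以 variance 也可写成 (exp(scale²) − 1)·mean².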

```py
arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
loc
```

```py
mean
```

```py
scale
```

```py
support = GreaterThan(lower_bound=0.0)
```

```py
variance
```

## LowRankMultivariateNormal

```py
class torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(loc, cov_factor, cov_diag, validate_args=None)
```

基类: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

使用由`cov_factor`和`cov_diag`参数化的低秩形式的协方差矩阵创建多元正态分布:

```py
covariance_matrix = cov_factor @ cov_factor.T + cov_diag

```

例子:

```py
>>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([1, 0]), torch.tensor([1, 1]))
>>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[1,0]`, cov_diag=`[1,1]`
tensor([-0.2102, -0.5429])

```

参数: 

*   **loc** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 分布的均值, 形状为 `batch_shape + event_shape`
*   **cov_factor** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 协方差矩阵低秩形式的因子部分, 形状为 `batch_shape + event_shape + (rank,)`
*   **cov_diag** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 协方差矩阵低秩形式的对角部分, 形状为 `batch_shape + event_shape`

注意

当 `cov_factor.shape[1] << cov_factor.shape[0]` 时, 借助 [Woodbury matrix identity](https://en.wikipedia.org/wiki/Woodbury_matrix_identity) 和 [matrix determinant lemma](https://en.wikipedia.org/wiki/Matrix_determinant_lemma) 可以避免直接计算协方差矩阵的行列式和逆. 由于这些公式, 我们只需要计算小尺寸"capacitance"矩阵的行列式和逆:

```py
capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor

```
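
下面用一个秩为 1 的 2x2 小例子(纯 Python, 不依赖 torch, 仅作示意)演示上述思路: 秩为 1 时 capacitance 矩阵退化为标量 1 + uᵀ·D⁻¹·u, 按 Woodbury/Sherman–Morrison 公式求逆, 并与直接求逆对比:

```py
d = [2.0, 4.0]  # cov_diag 的对角元素
u = [1.0, 3.0]  # cov_factor 的单列 (rank = 1)

d_inv = [1.0 / di for di in d]
# 秩 1 时 capacitance 退化为标量
capacitance = 1.0 + sum(ui * ui * dinv for ui, dinv in zip(u, d_inv))

# Woodbury / Sherman-Morrison: inv(diag(d) + u u^T)
inv_woodbury = [[(d_inv[i] if i == j else 0.0)
                 - d_inv[i] * u[i] * u[j] * d_inv[j] / capacitance
                 for j in range(2)] for i in range(2)]

# 直接求 2x2 矩阵 A = diag(d) + u u^T 的逆, 用于对比
A = [[d[0] + u[0] * u[0], u[0] * u[1]],
     [u[1] * u[0], d[1] + u[1] * u[1]]]
det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
inv_direct = [[A[1][1] / det, -A[0][1] / det],
              [-A[1][0] / det, A[0][0] / det]]
```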

```py
arg_constraints = {'cov_diag': GreaterThan(lower_bound=0.0), 'cov_factor': Real(), 'loc': Real()}
```

```py
covariance_matrix
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
log_prob(value)
```

```py
mean
```

```py
precision_matrix
```

```py
rsample(sample_shape=torch.Size([]))
```

```py
scale_tril
```

```py
support = Real()
```

```py
variance
```

## Multinomial

```py
class torch.distributions.multinomial.Multinomial(total_count=1, probs=None, logits=None, validate_args=None)
```

基类: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

创建由`total_count`和`probs`或`logits`(但不是两者)参数化的多项式分布. `probs`的最内层维度是对类别的索引, 所有其他维度是对批量的索引. 

注意: 当只调用 [`log_prob()`](#torch.distributions.multinomial.Multinomial.log_prob "torch.distributions.multinomial.Multinomial.log_prob") 时, 不需要指定 `total_count`.

注意

[`probs`](#torch.distributions.multinomial.Multinomial.probs "torch.distributions.multinomial.Multinomial.probs") 必须是非负的、有限的并且具有非零和, 并且它将被归一化为和为1.

*   [`sample()`](#torch.distributions.multinomial.Multinomial.sample "torch.distributions.multinomial.Multinomial.sample") 要求所有参数和样本共享同一个 `total_count`.
*   [`log_prob()`](#torch.distributions.multinomial.Multinomial.log_prob "torch.distributions.multinomial.Multinomial.log_prob") 允许每个参数和样本使用不同的 `total_count`.

例子:

```py
>>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
>>> x = m.sample()  # equal probability of 0, 1, 2, 3
tensor([ 21.,  24.,  30.,  25.])

>>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
tensor([-4.1338])

```
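
`log_prob` 的计算对应多项式分布的对数概率质量函数 log P(x) = log n! − Σᵢ log xᵢ! + Σᵢ xᵢ·log pᵢ. 下面是一个不依赖 torch 的纯 Python 速写(仅作示意), 用 `math.lgamma` 实现该公式, 并在小规模支撑集上验证 pmf 求和为 1:

```py
import math
from itertools import product

def multinomial_log_prob(x, probs):
    # log P(x) = log n! - sum_i log x_i! + sum_i x_i * log p_i
    n = sum(x)
    lp = math.lgamma(n + 1)
    for xi, pi in zip(x, probs):
        lp -= math.lgamma(xi + 1)
        if xi > 0:
            lp += xi * math.log(pi)
    return lp

probs = [0.2, 0.3, 0.5]
n = 4
# 枚举所有满足 x0+x1+x2 = n 的计数向量, pmf 之和应为 1
total = sum(math.exp(multinomial_log_prob(x, probs))
            for x in product(range(n + 1), repeat=3) if sum(x) == n)
```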

参数: 

*   **total_count** ([_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")) – 试验次数
*   **probs** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 事件概率
*   **logits** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 事件对数概率

```py
arg_constraints = {'logits': Real(), 'probs': Simplex()}
```

```py
expand(batch_shape, _instance=None)
```

```py
log_prob(value)
```

```py
logits
```

```py
mean
```

```py
param_shape
```

```py
probs
```

```py
sample(sample_shape=torch.Size([]))
```

```py
support
```

```py
variance
```

## MultivariateNormal

```py
class torch.distributions.multivariate_normal.MultivariateNormal(loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)
```

基类: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

创建由均值向量和协方差矩阵参数化的多元正态(也称为高斯)分布.

多元正态分布可以用正定协方差矩阵 ![](img/ea86c11eaef9af2b4d699b88c2474ffd.jpg) 来参数化, 或者用一个正定的精度矩阵 ![](img/1949bfcc1decf198a2ff50b6e25f4cf6.jpg), 或者用一个正对角项的下三角矩阵 ![](img/f4996f1b5056dd364eab16f975b808ff.jpg), 例如 ![](img/6749b6afc75abfc8e0652ac8e5c0b8d8.jpg). 这个三角矩阵可以通过协方差的Cholesky分解得到.

例子:

```py
>>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
>>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
tensor([-0.2102, -0.5429])

```

参数: 

*   **loc** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 分布的均值
*   **covariance_matrix** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 正定协方差矩阵
*   **precision_matrix** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 正定精度矩阵
*   **scale_tril** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 具有正值对角线的下三角协方差因子

注意

[`covariance_matrix`](#torch.distributions.multivariate_normal.MultivariateNormal.covariance_matrix "torch.distributions.multivariate_normal.MultivariateNormal.covariance_matrix")、[`precision_matrix`](#torch.distributions.multivariate_normal.MultivariateNormal.precision_matrix "torch.distributions.multivariate_normal.MultivariateNormal.precision_matrix") 和 [`scale_tril`](#torch.distributions.multivariate_normal.MultivariateNormal.scale_tril "torch.distributions.multivariate_normal.MultivariateNormal.scale_tril") 中只能指定一个.

使用 [`scale_tril`](#torch.distributions.multivariate_normal.MultivariateNormal.scale_tril "torch.distributions.multivariate_normal.MultivariateNormal.scale_tril") 会更有效率: 内部的所有计算都基于 [`scale_tril`](#torch.distributions.multivariate_normal.MultivariateNormal.scale_tril "torch.distributions.multivariate_normal.MultivariateNormal.scale_tril"). 如果传入的是 [`covariance_matrix`](#torch.distributions.multivariate_normal.MultivariateNormal.covariance_matrix "torch.distributions.multivariate_normal.MultivariateNormal.covariance_matrix") 或者 [`precision_matrix`](#torch.distributions.multivariate_normal.MultivariateNormal.precision_matrix "torch.distributions.multivariate_normal.MultivariateNormal.precision_matrix"), 它仅用于通过Cholesky分解计算相应的下三角矩阵.
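
下面用一个 2x2 的纯 Python 小例子(不依赖 torch, 仅作示意)演示 `scale_tril` 与协方差矩阵的关系: 对协方差矩阵做 Cholesky 分解得到下三角因子 L, 且 L·Lᵀ 能还原协方差矩阵:

```py
import math

# 一个正定协方差矩阵
cov = [[4.0, 2.0],
       [2.0, 3.0]]

# 2x2 Cholesky 分解的闭式解 (下三角因子, 对应 scale_tril)
l00 = math.sqrt(cov[0][0])
l10 = cov[1][0] / l00
l11 = math.sqrt(cov[1][1] - l10 ** 2)
scale_tril = [[l00, 0.0],
              [l10, l11]]

# 还原: scale_tril @ scale_tril.T 应等于 cov
recon = [[sum(scale_tril[i][k] * scale_tril[j][k] for k in range(2))
          for j in range(2)] for i in range(2)]
```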

```py
arg_constraints = {'covariance_matrix': PositiveDefinite(), 'loc': RealVector(), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}
```

```py
covariance_matrix
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
log_prob(value)
```

```py
mean
```

```py
precision_matrix
```

```py
rsample(sample_shape=torch.Size([]))
```

```py
scale_tril
```

```py
support = Real()
```

```py
variance
```

## NegativeBinomial

```py
class torch.distributions.negative_binomial.NegativeBinomial(total_count, probs=None, logits=None, validate_args=None)
```

基类: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

创建一个负二项分布, 即在达到 `total_count` 次失败之前, 成功的独立同分布伯努利试验次数的分布. 每次试验成功的概率为 `probs`. 

参数: 

*   **total_count** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 停止前的非负失败试验次数, 虽然对于实数值该分布仍然有效
*   **probs** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 事件概率, 区间为 [0, 1)
*   **logits** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 成功概率的对数几率(log-odds)
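
该分布的概率质量函数可以写成 P(X=k) = C(k + r − 1, k)·(1 − p)^r·p^k(r 为 `total_count`, p 为 `probs`), 均值为 r·p/(1 − p). 下面是一个不依赖 torch 的纯 Python 速写(仅作示意), 用 `math.lgamma` 实现该公式并做数值检验:

```py
import math

def neg_binomial_log_prob(k, total_count, probs):
    # log P(X = k): 在第 total_count 次失败发生前恰有 k 次成功
    r, p = total_count, probs
    return (math.lgamma(k + r) - math.lgamma(r) - math.lgamma(k + 1)
            + r * math.log(1.0 - p) + k * math.log(p))

r, p = 5.0, 0.4
pmf = [math.exp(neg_binomial_log_prob(k, r, p)) for k in range(200)]
total = sum(pmf)                              # 应接近 1
mean = sum(k * v for k, v in enumerate(pmf))  # 应接近 r * p / (1 - p)
```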

```py
arg_constraints = {'logits': Real(), 'probs': HalfOpenInterval(lower_bound=0.0, upper_bound=1.0), 'total_count': GreaterThanEq(lower_bound=0)}
```

```py
expand(batch_shape, _instance=None)
```

```py
log_prob(value)
```

```py
logits
```

```py
mean
```

```py
param_shape
```

```py
probs
```

```py
sample(sample_shape=torch.Size([]))
```

```py
support = IntegerGreaterThan(lower_bound=0)
```

```py
variance
```

## Normal

```py
class torch.distributions.normal.Normal(loc, scale, validate_args=None)
```

基类: [`torch.distributions.exp_family.ExponentialFamily`](#torch.distributions.exp_family.ExponentialFamily "torch.distributions.exp_family.ExponentialFamily")

创建由`loc`和`scale`参数化的正态(也称为高斯)分布.

例子:

```py
>>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # normally distributed with loc=0 and scale=1
tensor([ 0.1046])

```

参数: 

*   **loc** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 均值 (也被称为 mu)
*   **scale** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 标准差 (也被称为 sigma)

```py
arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
```

```py
cdf(value)
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
icdf(value)
```

```py
log_prob(value)
```

```py
mean
```

```py
rsample(sample_shape=torch.Size([]))
```

```py
sample(sample_shape=torch.Size([]))
```

```py
stddev
```

```py
support = Real()
```

```py
variance
```

## OneHotCategorical

```py
class torch.distributions.one_hot_categorical.OneHotCategorical(probs=None, logits=None, validate_args=None)
```

基类: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

创建一个由`probs`或`logits`参数化的 One Hot Categorical 分布.

样本是大小为 `probs.size(-1)` 的独热(one-hot)编码向量.

注意

`probs`必须是非负的, 有限的并且具有非零和, 并且它将被归一化为总和为1. 

关于如何指定 [`probs`](#torch.distributions.one_hot_categorical.OneHotCategorical.probs "torch.distributions.one_hot_categorical.OneHotCategorical.probs") 和 [`logits`](#torch.distributions.one_hot_categorical.OneHotCategorical.logits "torch.distributions.one_hot_categorical.OneHotCategorical.logits"), 请参见 `torch.distributions.Categorical()`.

例子:

```py
>>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
>>> m.sample()  # equal probability of 0, 1, 2, 3
tensor([ 0.,  0.,  0.,  1.])

```

参数: 

*   **probs** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 事件概率
*   **logits** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 事件对数概率

```py
arg_constraints = {'logits': Real(), 'probs': Simplex()}
```

```py
entropy()
```

```py
enumerate_support(expand=True)
```

```py
expand(batch_shape, _instance=None)
```

```py
has_enumerate_support = True
```

```py
log_prob(value)
```

```py
logits
```

```py
mean
```

```py
param_shape
```

```py
probs
```

```py
sample(sample_shape=torch.Size([]))
```

```py
support = Simplex()
```

```py
variance
```

## Pareto

```py
class torch.distributions.pareto.Pareto(scale, alpha, validate_args=None)
```

Bases: [`torch.distributions.transformed_distribution.TransformedDistribution`](#torch.distributions.transformed_distribution.TransformedDistribution "torch.distributions.transformed_distribution.TransformedDistribution")

Samples from a Pareto Type 1 distribution.

Example:

```py
>>> m = Pareto(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Pareto distribution with scale=1 and alpha=1
tensor([ 1.5623])

```

Parameters:

*   **scale** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – scale parameter of the distribution
*   **alpha** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – shape parameter of the distribution

```py
arg_constraints = {'alpha': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
mean
```

```py
support
```

```py
variance
```

## Poisson

```py
class torch.distributions.poisson.Poisson(rate, validate_args=None)
```

Bases: [`torch.distributions.exp_family.ExponentialFamily`](#torch.distributions.exp_family.ExponentialFamily "torch.distributions.exp_family.ExponentialFamily")

Creates a Poisson distribution parameterized by `rate`, the rate parameter.

Samples are nonnegative integers, with a pmf given by

![](img/32c47de57300c954795486fea3201bdc.jpg)

Example:

```py
>>> m = Poisson(torch.tensor([4]))
>>> m.sample()
tensor([ 3.])

```

| Parameters: | **rate** (_Number__,_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – the rate parameter |

```py
arg_constraints = {'rate': GreaterThan(lower_bound=0.0)}
```

```py
expand(batch_shape, _instance=None)
```

```py
log_prob(value)
```

```py
mean
```

```py
sample(sample_shape=torch.Size([]))
```

```py
support = IntegerGreaterThan(lower_bound=0)
```

```py
variance
```

## RelaxedBernoulli

```py
class torch.distributions.relaxed_bernoulli.RelaxedBernoulli(temperature, probs=None, logits=None, validate_args=None)
```

Bases: [`torch.distributions.transformed_distribution.TransformedDistribution`](#torch.distributions.transformed_distribution.TransformedDistribution "torch.distributions.transformed_distribution.TransformedDistribution")

Creates a RelaxedBernoulli distribution, parameterized by [`temperature`](#torch.distributions.relaxed_bernoulli.RelaxedBernoulli.temperature "torch.distributions.relaxed_bernoulli.RelaxedBernoulli.temperature"), and either `probs` or `logits` (but not both). This is a relaxed version of the Bernoulli distribution, so the values are in (0, 1), and it has reparameterizable samples.

Example:

```py
>>> m = RelaxedBernoulli(torch.tensor([2.2]),
 torch.tensor([0.1, 0.2, 0.3, 0.99]))
>>> m.sample()
tensor([ 0.2951,  0.3442,  0.8918,  0.9021])

```

Parameters:

*   **temperature** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – relaxation temperature
*   **probs** (_Number__,_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – the probability of sampling `1`
*   **logits** (_Number__,_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – the log-odds of sampling `1`

```py
arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
logits
```

```py
probs
```

```py
support = Interval(lower_bound=0.0, upper_bound=1.0)
```

```py
temperature
```
## RelaxedOneHotCategorical

```py
W
wizardforcel 已提交
2015
class torch.distributions.relaxed_categorical.RelaxedOneHotCategorical(temperature, probs=None, logits=None, validate_args=None)
W
wizardforcel 已提交
2016 2017
```

Bases: [`torch.distributions.transformed_distribution.TransformedDistribution`](#torch.distributions.transformed_distribution.TransformedDistribution "torch.distributions.transformed_distribution.TransformedDistribution")

Creates a RelaxedOneHotCategorical distribution parameterized by `temperature`, and either `probs` or `logits`. This is a relaxed version of the `OneHotCategorical` distribution, so its samples lie on the simplex and are reparameterizable.

Example:

```py
>>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
 torch.tensor([0.1, 0.2, 0.3, 0.4]))
>>> m.sample()
tensor([ 0.1294,  0.2324,  0.3859,  0.2523])

```

Parameters:

*   **temperature** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – relaxation temperature
*   **probs** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – event probabilities
*   **logits** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – the log probabilities of the events

```py
arg_constraints = {'logits': Real(), 'probs': Simplex()}
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
logits
```

```py
probs
```

```py
support = Simplex()
```

```py
temperature
```

## StudentT

```py
class torch.distributions.studentT.StudentT(df, loc=0.0, scale=1.0, validate_args=None)
```

Bases: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

Creates a Student's t-distribution parameterized by degrees of freedom `df`, mean `loc`, and scale `scale`.

Example:

```py
>>> m = StudentT(torch.tensor([2.0]))
>>> m.sample()  # Student's t-distributed with degrees of freedom=2
tensor([ 0.1046])

```

Parameters:

*   **df** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – degrees of freedom
*   **loc** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – mean of the distribution
*   **scale** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – scale of the distribution

```py
arg_constraints = {'df': GreaterThan(lower_bound=0.0), 'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
log_prob(value)
```

```py
mean
```

```py
rsample(sample_shape=torch.Size([]))
```

```py
support = Real()
```

```py
variance
```

## TransformedDistribution

```py
class torch.distributions.transformed_distribution.TransformedDistribution(base_distribution, transforms, validate_args=None)
```

Bases: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

Extension of the Distribution class, which applies a sequence of transforms to a base distribution. Let f be the composition of transforms applied:

```py
X ~ BaseDistribution
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
log p(Y) = log p(X) + log |det (dX/dY)|

```

Note that the `.event_shape` of a [`TransformedDistribution`](#torch.distributions.transformed_distribution.TransformedDistribution "torch.distributions.transformed_distribution.TransformedDistribution") is the maximum shape of its base distribution and its transforms, since transforms can introduce correlations among events.

An example of the use of [`TransformedDistribution`](#torch.distributions.transformed_distribution.TransformedDistribution "torch.distributions.transformed_distribution.TransformedDistribution"):

```py
# Building a Logistic Distribution
# X ~ Uniform(0, 1)
# f = a + b * logit(X)
# Y ~ f(X) ~ Logistic(a, b)
base_distribution = Uniform(0, 1)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)

```

For more examples, please look at the implementations of [`Gumbel`](#torch.distributions.gumbel.Gumbel "torch.distributions.gumbel.Gumbel"), [`HalfCauchy`](#torch.distributions.half_cauchy.HalfCauchy "torch.distributions.half_cauchy.HalfCauchy"), [`HalfNormal`](#torch.distributions.half_normal.HalfNormal "torch.distributions.half_normal.HalfNormal"), [`LogNormal`](#torch.distributions.log_normal.LogNormal "torch.distributions.log_normal.LogNormal"), [`Pareto`](#torch.distributions.pareto.Pareto "torch.distributions.pareto.Pareto"), [`Weibull`](#torch.distributions.weibull.Weibull "torch.distributions.weibull.Weibull"), [`RelaxedBernoulli`](#torch.distributions.relaxed_bernoulli.RelaxedBernoulli "torch.distributions.relaxed_bernoulli.RelaxedBernoulli") and [`RelaxedOneHotCategorical`](#torch.distributions.relaxed_categorical.RelaxedOneHotCategorical "torch.distributions.relaxed_categorical.RelaxedOneHotCategorical")

```py
arg_constraints = {}
```

```py
cdf(value)
```

Computes the cumulative distribution function by inverting the transform(s) and computing the score of the base distribution.

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample
```

```py
icdf(value)
```

Computes the inverse cumulative distribution function using transform(s) and computing the score of the base distribution.

```py
log_prob(value)
```

Scores the sample by inverting the transform(s) and computing the score using the score of the base distribution and the log abs det jacobian.

```py
rsample(sample_shape=torch.Size([]))
```

Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched. Samples first from the base distribution and applies `transform()` for every transform in the list.

```py
sample(sample_shape=torch.Size([]))
```

Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. Samples first from the base distribution and applies `transform()` for every transform in the list.

```py
support
```

## Uniform

```py
class torch.distributions.uniform.Uniform(low, high, validate_args=None)
```

Bases: [`torch.distributions.distribution.Distribution`](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")

Generates uniformly distributed random samples from the half-open interval `[low, high)`.

Example:

```py
>>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
>>> m.sample()  # uniformly distributed in the range [0.0, 5.0)
tensor([ 2.3418])

```

Parameters:

*   **low** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – lower range (inclusive).
*   **high** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – upper range (exclusive).

```py
arg_constraints = {'high': Dependent(), 'low': Dependent()}
```

```py
cdf(value)
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
has_rsample = True
```

```py
icdf(value)
```

```py
log_prob(value)
```

```py
mean
```

```py
rsample(sample_shape=torch.Size([]))
```

```py
stddev
```

```py
support
```

```py
variance
```

## Weibull

```py
class torch.distributions.weibull.Weibull(scale, concentration, validate_args=None)
```

Bases: [`torch.distributions.transformed_distribution.TransformedDistribution`](#torch.distributions.transformed_distribution.TransformedDistribution "torch.distributions.transformed_distribution.TransformedDistribution")

Samples from a two-parameter Weibull distribution.

Example:

```py
>>> m = Weibull(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Weibull distribution with scale=1, concentration=1
tensor([ 0.4784])

```

Parameters:

*   **scale** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – Scale parameter of the distribution (lambda).
*   **concentration** ([_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – Concentration parameter of the distribution (k/shape).

```py
arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}
```

```py
entropy()
```

```py
expand(batch_shape, _instance=None)
```

```py
mean
```

```py
support = GreaterThan(lower_bound=0.0)
```

```py
variance
```

## `KL Divergence`

```py
torch.distributions.kl.kl_divergence(p, q)
```

Computes the Kullback-Leibler divergence ![](img/739a8e4cd0597805c3e4daf35c0fc7c6.jpg) between two distributions.

![](img/ff8dcec3abe559720f8b0b464d2471b2.jpg)

Parameters:

*   **p** ([_Distribution_](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")) – A `Distribution` object.
*   **q** ([_Distribution_](#torch.distributions.distribution.Distribution "torch.distributions.distribution.Distribution")) – A `Distribution` object.

| Returns: | A batch of KL divergences of shape `batch_shape`. |

| Return type: | [Tensor](tensors.html#torch.Tensor "torch.Tensor") |

| Raises: | [`NotImplementedError`](https://docs.python.org/3/library/exceptions.html#NotImplementedError "(in Python v3.7)") – If the distribution types have not been registered via [`register_kl()`](#torch.distributions.kl.register_kl "torch.distributions.kl.register_kl"). |
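
As a concrete sketch, the KL divergence between pairs of Normal distributions is registered and has a closed form; for N(0, 1) and N(1, 1) it reduces to (μ_p − μ_q)² / 2 = 0.5:

```python
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

p = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
q = Normal(torch.tensor([1.0]), torch.tensor([1.0]))

kl = kl_divergence(p, q)  # shape batch_shape, here (1,)
print(kl)  # tensor([0.5000])
```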

```py
torch.distributions.kl.register_kl(type_p, type_q)
```

Decorator to register a pairwise function with [`kl_divergence()`](#torch.distributions.kl.kl_divergence "torch.distributions.kl.kl_divergence"). Usage:

```py
@register_kl(Normal, Normal)
def kl_normal_normal(p, q):
    # insert implementation here

```

Lookup returns the most specific (type, type) match ordered by subclass. If the match is ambiguous, a `RuntimeWarning` is raised. For example, to resolve the ambiguous situation:

```py
@register_kl(BaseP, DerivedQ)
def kl_version1(p, q): ...
@register_kl(DerivedP, BaseQ)
def kl_version2(p, q): ...

```

you should register a third most-specific implementation, e.g.:

```py
register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.

```

Parameters:

*   **type_p** ([_type_](https://docs.python.org/3/library/functions.html#type "(in Python v3.7)")) – A subclass of `Distribution`.
*   **type_q** ([_type_](https://docs.python.org/3/library/functions.html#type "(in Python v3.7)")) – A subclass of `Distribution`.
## `Transforms`

```py
class torch.distributions.transforms.Transform(cache_size=0)
```

Abstract class for invertible transformations with computable log det jacobians. They are primarily used in `torch.distributions.TransformedDistribution`.

Caching is useful for transforms whose inverses are either expensive or numerically unstable. Note that care must be taken with memoized values, since the autograd graph may be reversed. For example, the following works with or without caching:

```py
y = t(x)
t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.

```

However, the following will error when caching due to dependency reversal:

```py
y = t(x)
z = t.inv(y)
grad(z.sum(), [y])  # error because z is x

```

Derived classes should implement one or both of `_call()` or `_inverse()`. Derived classes that set `bijective=True` should also implement `log_abs_det_jacobian()`.

| Parameters: | **cache_size** ([_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")) – Size of cache. If zero, no caching is done. If one, the latest single value is cached. Only 0 and 1 are supported. |

| Variables: | 

*   **domain** ([`Constraint`](#torch.distributions.constraints.Constraint "torch.distributions.constraints.Constraint")) – The constraint representing valid inputs to this transform.
*   **codomain** ([`Constraint`](#torch.distributions.constraints.Constraint "torch.distributions.constraints.Constraint")) – The constraint representing valid outputs of this transform, which are inputs to the inverse transform.
*   **bijective** ([_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")) – Whether this transform is bijective. A transform `t` is bijective iff `t.inv(t(x)) == x` and `t(t.inv(y)) == y` for every `x` and `y`. Transforms that are not bijective should at least maintain the weaker pseudoinverse properties `t(t.inv(t(x)) == t(x)` and `t.inv(t(t.inv(y))) == t.inv(y)`.
*   **sign** ([_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)") _or_ [_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – For bijective univariate transforms, this should be +1 or -1 depending on whether the transform is monotone increasing or decreasing.
*   **event_dim** ([_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")) – Number of dimensions that are correlated together in the transform's `event_shape`. This should be 0 for pointwise transforms, 1 for transforms that act jointly on vectors, 2 for transforms that act jointly on matrices, etc.
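
As a sketch of the interface described above, a hypothetical `SquareTransform` (not part of the library) restricted to the positive half-line could look like:

```python
import torch
from torch.distributions import constraints
from torch.distributions.transforms import Transform

class SquareTransform(Transform):
    """Hypothetical bijection y = x**2 on the positive half-line."""
    domain = constraints.positive
    codomain = constraints.positive
    bijective = True
    sign = +1

    def _call(self, x):
        return x.pow(2)

    def _inverse(self, y):
        return y.sqrt()

    def log_abs_det_jacobian(self, x, y):
        # |dy/dx| = 2x on the positive half-line, so log |dy/dx| = log(2x)
        return torch.log(2 * x)

t = SquareTransform()
x = torch.tensor([1.0, 2.0, 3.0])
print(t(x))         # tensor([1., 4., 9.])
print(t.inv(t(x)))  # tensor([1., 2., 3.])
```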

```py
inv
```

Returns the inverse [`Transform`](#torch.distributions.transforms.Transform "torch.distributions.transforms.Transform") of this transform. This should satisfy `t.inv.inv is t`.

```py
sign
```

Returns the sign of the determinant of the Jacobian, if applicable. In general this only makes sense for bijective transforms.

```py
log_abs_det_jacobian(x, y)
```

Computes the log det jacobian `log |dy/dx|` given input and output.

```py
class torch.distributions.transforms.ComposeTransform(parts)
```

Composes multiple transforms in a chain. The transforms being composed are responsible for caching.

| Parameters: | **parts** (list of [`Transform`](#torch.distributions.transforms.Transform "torch.distributions.transforms.Transform")) – A list of transforms to compose. |


```py
class torch.distributions.transforms.ExpTransform(cache_size=0)
```

Transform via the mapping ![](img/ec8d939394f24908d017d86153e312ea.jpg).

```py
class torch.distributions.transforms.PowerTransform(exponent, cache_size=0)
```

Transform via the mapping ![](img/2062af7179e0c19c3599816de6768cee.jpg).

```py
class torch.distributions.transforms.SigmoidTransform(cache_size=0)
```

Transform via the mapping ![](img/749abef3418941161a1c6ff80d9eae76.jpg) and ![](img/6feb73eb74f2267e5caa87d9693362cb.jpg).

```py
class torch.distributions.transforms.AbsTransform(cache_size=0)
```

Transform via the mapping ![](img/dca0dc2e17c81b7ec261e70549de5507.jpg).

```py
class torch.distributions.transforms.AffineTransform(loc, scale, event_dim=0, cache_size=0)
```

Transform via the pointwise affine mapping ![](img/e1df459e7ff26d682fc956b62868f7c4.jpg).

Parameters:

*   **loc** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor") _or_ [_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)")) – Location parameter.
*   **scale** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor") _or_ [_float_](https://docs.python.org/3/library/functions.html#float "(in Python v3.7)")) – Scale parameter.
*   **event_dim** ([_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")) – Optional size of `event_shape`. This should be zero for univariate random variables, 1 for distributions over vectors, 2 for distributions over matrices, etc.
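
To illustrate, an affine map followed by `ExpTransform` can be chained with `ComposeTransform`; the log det jacobians of the parts add up. A small sketch (the loc/scale values here are arbitrary):

```python
import torch
from torch.distributions.transforms import AffineTransform, ComposeTransform, ExpTransform

t = ComposeTransform([AffineTransform(loc=1.0, scale=2.0), ExpTransform()])

x = torch.tensor([0.0, 1.0])
y = t(x)                             # exp(1 + 2 * x)
x_back = t.inv(y)                    # recovers x up to floating point error
ladj = t.log_abs_det_jacobian(x, y)  # log(2) + (1 + 2 * x)
print(y, x_back, ladj)
```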
```py
class torch.distributions.transforms.SoftmaxTransform(cache_size=0)
```

Transform from unconstrained space to the simplex via ![](img/ec8d939394f24908d017d86153e312ea.jpg) then normalizing.

This is not bijective and cannot be used for HMC. However this acts mostly coordinate-wise (except for the final normalization), and is thus appropriate for coordinate-wise optimization algorithms.

```py
class torch.distributions.transforms.StickBreakingTransform(cache_size=0)
```

Transform from unconstrained space to the simplex of one additional dimension via a stick-breaking process.

This transform arises as an iterated sigmoid transform in a stick-breaking construction of the `Dirichlet` distribution: the first logit is transformed via sigmoid to the first probability and the probability of everything else, and then the process recurses.

This is bijective and appropriate for use in HMC; however it mixes coordinates together and is less appropriate for optimization.

```py
class torch.distributions.transforms.LowerCholeskyTransform(cache_size=0)
```

Transform from unconstrained matrices to lower-triangular matrices with nonnegative diagonal entries.

This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization.

## `Constraints`

The following constraints are implemented:

*   `constraints.boolean`
*   `constraints.dependent`
*   `constraints.greater_than(lower_bound)`
*   `constraints.integer_interval(lower_bound, upper_bound)`
*   `constraints.interval(lower_bound, upper_bound)`
*   `constraints.lower_cholesky`
*   `constraints.lower_triangular`
*   `constraints.nonnegative_integer`
*   `constraints.positive`
*   `constraints.positive_definite`
*   `constraints.positive_integer`
*   `constraints.real`
*   `constraints.real_vector`
*   `constraints.simplex`
*   `constraints.unit_interval`

```py
class torch.distributions.constraints.Constraint
```

Abstract base class for constraints.

A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.

```py
check(value)
```

Returns a byte tensor of `sample_shape + batch_shape` indicating whether each event in value satisfies this constraint.
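
For instance (the exact dtype of the returned mask, byte vs. bool, depends on the PyTorch release):

```python
import torch
from torch.distributions import constraints

value = torch.tensor([-1.0, 0.5, 2.0])
print(constraints.positive.check(value))            # mask [0, 1, 1]
print(constraints.interval(0.0, 1.0).check(value))  # mask [0, 1, 0]
```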

```py
torch.distributions.constraints.dependent_property
```

alias of `torch.distributions.constraints._DependentProperty`

```py
torch.distributions.constraints.integer_interval
```

alias of `torch.distributions.constraints._IntegerInterval`

```py
torch.distributions.constraints.greater_than
```

alias of `torch.distributions.constraints._GreaterThan`

```py
torch.distributions.constraints.greater_than_eq
```

alias of `torch.distributions.constraints._GreaterThanEq`

```py
torch.distributions.constraints.less_than
```

alias of `torch.distributions.constraints._LessThan`

```py
torch.distributions.constraints.interval
```

alias of `torch.distributions.constraints._Interval`

```py
torch.distributions.constraints.half_open_interval
```

alias of `torch.distributions.constraints._HalfOpenInterval`

## `Constraint Registry`

PyTorch provides two global [`ConstraintRegistry`](#torch.distributions.constraint_registry.ConstraintRegistry "torch.distributions.constraint_registry.ConstraintRegistry") objects that link [`Constraint`](#torch.distributions.constraints.Constraint "torch.distributions.constraints.Constraint") objects to [`Transform`](#torch.distributions.transforms.Transform "torch.distributions.transforms.Transform") objects. These objects both input constraints and return transforms, but they have different guarantees on bijectivity.

1.  `biject_to(constraint)` looks up a bijective [`Transform`](#torch.distributions.transforms.Transform "torch.distributions.transforms.Transform") from `constraints.real` to the given `constraint`. The returned transform is guaranteed to have `.bijective = True` and should implement `.log_abs_det_jacobian()`.
2.  `transform_to(constraint)` looks up a not-necessarily bijective [`Transform`](#torch.distributions.transforms.Transform "torch.distributions.transforms.Transform") from `constraints.real` to the given `constraint`. The returned transform is not guaranteed to implement `.log_abs_det_jacobian()`.

The `transform_to()` registry is useful for performing unconstrained optimization of constrained parameters of probability distributions, which are indicated by each distribution's `.arg_constraints` dict. These transforms often overparameterize a space in order to avoid rotation; they are thus more suitable for coordinate-wise optimization algorithms like Adam:

```py
loc = torch.zeros(100, requires_grad=True)
unconstrained = torch.zeros(100, requires_grad=True)
scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
loss = -Normal(loc, scale).log_prob(data).sum()

```

The `biject_to()` registry is useful for Hamiltonian Monte Carlo, where samples from a probability distribution with constrained `.support` are propagated in an unconstrained space, and algorithms are typically rotation invariant:

```py
dist = Exponential(rate)
unconstrained = torch.zeros(100, requires_grad=True)
sample = biject_to(dist.support)(unconstrained)
potential_energy = -dist.log_prob(sample).sum()

```

Note

An example where `transform_to` and `biject_to` differ is `constraints.simplex`: `transform_to(constraints.simplex)` returns a [`SoftmaxTransform`](#torch.distributions.transforms.SoftmaxTransform "torch.distributions.transforms.SoftmaxTransform") that simply exponentiates and normalizes its inputs; this is a cheap and mostly coordinate-wise operation appropriate for algorithms like SVI. In contrast, `biject_to(constraints.simplex)` returns a [`StickBreakingTransform`](#torch.distributions.transforms.StickBreakingTransform "torch.distributions.transforms.StickBreakingTransform") that bijects its input down to a one-fewer-dimensional space; this is a more expensive, less numerically stable transform, but it is needed for algorithms like HMC.
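
A small sketch of this difference on the simplex:

```python
import torch
from torch.distributions import constraints
from torch.distributions.constraint_registry import biject_to, transform_to

u = torch.randn(4)
p1 = transform_to(constraints.simplex)(u)  # SoftmaxTransform: same dimension
p2 = biject_to(constraints.simplex)(u)     # StickBreakingTransform: one extra dimension
print(p1.shape, p2.shape)  # torch.Size([4]) torch.Size([5])
```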

The `biject_to` and `transform_to` objects can be extended by user-defined constraints and transforms using their `.register()` method, either as a function on singleton constraints:

```py
transform_to.register(my_constraint, my_transform)

```

or as a decorator on parameterized constraints:

```py
@transform_to.register(MyConstraintClass)
def my_factory(constraint):
    assert isinstance(constraint, MyConstraintClass)
    return MyTransform(constraint.param1, constraint.param2)

```

You can create your own registry by creating a new [`ConstraintRegistry`](#torch.distributions.constraint_registry.ConstraintRegistry "torch.distributions.constraint_registry.ConstraintRegistry") object.

```py
class torch.distributions.constraint_registry.ConstraintRegistry
```

Registry to link constraints to transforms.

```py
register(constraint, factory=None)
```

Registers a [`Constraint`](#torch.distributions.constraints.Constraint "torch.distributions.constraints.Constraint") subclass in this registry. Usage:

```py
@my_registry.register(MyConstraintClass)
def construct_transform(constraint):
    assert isinstance(constraint, MyConstraint)
    return MyTransform(constraint.arg_constraints)

```

Parameters:

*   **constraint** (subclass of [`Constraint`](#torch.distributions.constraints.Constraint "torch.distributions.constraints.Constraint")) – A subclass of [`Constraint`](#torch.distributions.constraints.Constraint "torch.distributions.constraints.Constraint"), or a singleton object of the desired class.
*   **factory** (_callable_) – A callable that inputs a constraint object and returns a [`Transform`](#torch.distributions.transforms.Transform "torch.distributions.transforms.Transform") object.
