未验证 提交 28dd2a58 编写于 作者: A Aurelius84 提交者: GitHub

refine Categorical and MultivariateNormalDiag en doc (#20723)

* refine Categorical and MultivariateNormalDiag en doc test=develop, test=document_fix

* refine Categorical and MultivariateNormalDiag en doc test=develop, test=document_fix
上级 dfa0549f
......@@ -404,8 +404,18 @@ class Categorical(Distribution):
one of K possible categories, with the probability of each category
separately specified.
The probability mass function (pmf) is:
.. math::
pmf(k; p_i) = \prod_{i=1}^{k} p_i^{[x=i]}
In the above equation:
* :math:`[x=i]` : it evaluates to 1 if :math:`x==i` , 0 otherwise.
Args:
logits(list|numpy.ndarray|Variable): The logits input of categorical distribution.
logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32.
Examples:
.. code-block:: python
......@@ -439,7 +449,7 @@ class Categorical(Distribution):
def __init__(self, logits):
"""
Args:
logits: A float32 tensor
logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32.
"""
if self._validate_args(logits):
self.logits = logits
......@@ -450,7 +460,7 @@ class Categorical(Distribution):
"""The KL-divergence between two Categorical distributions.
Args:
other (Categorical): instance of Categorical.
other (Categorical): instance of Categorical. The data type is float32.
Returns:
Variable: kl-divergence between two Categorical distributions.
......@@ -477,7 +487,7 @@ class Categorical(Distribution):
"""Shannon entropy in nats.
Returns:
Variable: Shannon entropy of Categorical distribution.
Variable: Shannon entropy of Categorical distribution. The data type is float32.
"""
logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
......@@ -495,10 +505,31 @@ class MultivariateNormalDiag(Distribution):
A multivariate normal (also called Gaussian) distribution parameterized by a mean vector
and a covariance matrix.
The probability density function (pdf) is:
.. math::
pdf(x; loc, scale) = \\frac{e^{-\\frac{||y||^2}{2}}}{Z}
where:
.. math::
y = inv(scale) @ (x - loc)
Z = (2\\pi)^{0.5k} |det(scale)|
In the above equation:
* :math:`inv` : denotes to take the inverse of the matrix.
* :math:`@` : denotes matrix multiplication.
* :math:`det` : denotes to evaluate the determinant.
Args:
loc(list|numpy.ndarray|Variable): The mean of multivariateNormal distribution.
scale(list|numpy.ndarray|Variable): The positive definite diagonal covariance matrix of
multivariateNormal distribution.
loc(list|numpy.ndarray|Variable): The mean of multivariateNormal distribution with shape :math:`[k]` .
The data type is float32.
scale(list|numpy.ndarray|Variable): The positive definite diagonal covariance matrix of multivariateNormal
distribution with shape :math:`[k, k]` . All elements are 0 except diagonal elements. The data type is
float32.
Examples:
.. code-block:: python
......@@ -570,7 +601,7 @@ class MultivariateNormalDiag(Distribution):
"""Shannon entropy in nats.
Returns:
Variable: Shannon entropy of Multivariate Normal distribution.
Variable: Shannon entropy of Multivariate Normal distribution. The data type is float32.
"""
entropy = 0.5 * (
......@@ -586,7 +617,7 @@ class MultivariateNormalDiag(Distribution):
other (MultivariateNormalDiag): instance of Multivariate Normal.
Returns:
Variable: kl-divergence between two Multivariate Normal distributions.
Variable: kl-divergence between two Multivariate Normal distributions. The data type is float32.
"""
assert isinstance(other, MultivariateNormalDiag)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册