diff --git a/python/paddle/fluid/layers/distributions.py b/python/paddle/fluid/layers/distributions.py index 6954b2d88298e98c66358828f461bb676a0a3da6..69b99bf577f7515d6a5cf67dffdfa6b319412d85 100644 --- a/python/paddle/fluid/layers/distributions.py +++ b/python/paddle/fluid/layers/distributions.py @@ -404,8 +404,18 @@ class Categorical(Distribution): one of K possible categories, with the probability of each category separately specified. + The probability mass function (pmf) is: + + .. math:: + + pmf(k; p_i) = \prod_{i=1}^{k} p_i^{[x=i]} + + In the above equation: + + * :math:`[x=i]` : evaluates to 1 if :math:`x==i`, and 0 otherwise. + Args: - logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. + logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32. Examples: .. code-block:: python @@ -439,7 +449,7 @@ class Categorical(Distribution): def __init__(self, logits): """ Args: - logits: A float32 tensor + logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32. """ if self._validate_args(logits): self.logits = logits @@ -450,7 +460,7 @@ class Categorical(Distribution): """The KL-divergence between two Categorical distributions. Args: - other (Categorical): instance of Categorical. + other (Categorical): instance of Categorical. The data type of its logits is float32. Returns: Variable: kl-divergence between two Categorical distributions. @@ -477,7 +487,7 @@ class Categorical(Distribution): """Shannon entropy in nats. Returns: - Variable: Shannon entropy of Categorical distribution. + Variable: Shannon entropy of Categorical distribution. The data type is float32. """ logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True) @@ -495,10 +505,31 @@ class MultivariateNormalDiag(Distribution): A multivariate normal (also called Gaussian) distribution parameterized by a mean vector and a covariance matrix. 
+ The probability density function (pdf) is: + + .. math:: + + pdf(x; loc, scale) = \\frac{e^{-\\frac{||y||^2}{2}}}{Z} + + where: + .. math:: + + y = inv(scale) @ (x - loc) + Z = (2\\pi)^{0.5k} |det(scale)| + + + In the above equation: + + * :math:`inv` : denotes the inverse of the matrix. + * :math:`@` : denotes matrix multiplication. + * :math:`det` : denotes the determinant of the matrix. + Args: - loc(list|numpy.ndarray|Variable): The mean of multivariateNormal distribution. - scale(list|numpy.ndarray|Variable): The positive definite diagonal covariance matrix of - multivariateNormal distribution. + loc(list|numpy.ndarray|Variable): The mean of multivariateNormal distribution with shape :math:`[k]` . + The data type is float32. + scale(list|numpy.ndarray|Variable): The positive definite diagonal covariance matrix of multivariateNormal + distribution with shape :math:`[k, k]` . All elements are 0 except diagonal elements. The data type is + float32. Examples: .. code-block:: python @@ -570,7 +601,7 @@ class MultivariateNormalDiag(Distribution): """Shannon entropy in nats. Returns: - Variable: Shannon entropy of Multivariate Normal distribution. + Variable: Shannon entropy of Multivariate Normal distribution. The data type is float32. """ entropy = 0.5 * ( @@ -586,7 +617,7 @@ class MultivariateNormalDiag(Distribution): other (MultivariateNormalDiag): instance of Multivariate Normal. Returns: - Variable: kl-divergence between two Multivariate Normal distributions. + Variable: kl-divergence between two Multivariate Normal distributions. The data type is float32. """ assert isinstance(other, MultivariateNormalDiag)