# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distribution import exponential_family
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
from paddle.fluid.layer_helper import LayerHelper


class Dirichlet(exponential_family.ExponentialFamily):
    r"""
    Dirichlet distribution with parameter "concentration".

    The Dirichlet distribution is defined over the `(k-1)-simplex` using a
    positive, length-k vector ``concentration`` (`k > 1`).
    The Dirichlet is identically the Beta distribution when `k = 2`.

    For a continuous random vector :math:`\boldsymbol X \in \mathbb{R}^k`
    with support :math:`x_i \in (0, 1), \sum_{i=1}^{k} x_i = 1`,
    the probability density function (pdf) is

    .. math::

        f(\boldsymbol X; \boldsymbol \alpha) = \frac{1}{B(\boldsymbol \alpha)} \prod_{i=1}^{k}x_i^{\alpha_i-1}

    where :math:`\boldsymbol \alpha = \{\alpha_1, \ldots, \alpha_k\}, k \ge 2`
    is the parameter, and the normalizing constant is the multivariate beta
    function

    .. math::

        B(\boldsymbol \alpha) = \frac{\prod_{i=1}^{k} \Gamma(\alpha_i)}{\Gamma(\alpha_0)}

    Here :math:`\alpha_0=\sum_{i=1}^{k} \alpha_i` is the sum of the parameters
    and :math:`\Gamma(\alpha)` is the gamma function.

    Args:
        concentration (Tensor): "Concentration" parameter of the Dirichlet
            distribution, also called :math:`\alpha`. The last axis holds the
            parameters of one distribution, so
            ``event_shape=concentration.shape[-1:]``. When it has more than
            one dimension, all axes other than the last are considered batch
            dimensions, with ``batch_shape=concentration.shape[:-1]``.

    Raises:
        ValueError: If ``concentration`` is zero-dimensional.

    Examples:

        .. code-block:: python

            import paddle

            dirichlet = paddle.distribution.Dirichlet(paddle.to_tensor([1., 2., 3.]))

            print(dirichlet.entropy())
            # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #        [-1.24434423])
            print(dirichlet.prob(paddle.to_tensor([.3, .5, .6])))
            # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #        [10.80000114])

    """

    def __init__(self, concentration):
        # The last axis is the event axis, so at least one dimension is needed.
        if concentration.dim() < 1:
            raise ValueError(
                "`concentration` parameter must be at least one dimensional")

        self.concentration = concentration
        # batch_shape = every axis but the last; event_shape = the last axis.
        super(Dirichlet, self).__init__(concentration.shape[:-1],
                                        concentration.shape[-1:])

    @property
    def mean(self):
        """Mean of the Dirichlet distribution.

        Computed as ``alpha_i / alpha_0``, where ``alpha_0`` is the sum of the
        concentration over the last axis.

        Returns:
            Tensor: Mean value of the distribution, with the same shape as
            ``concentration``.
        """
        return self.concentration / self.concentration.sum(-1, keepdim=True)

    @property
    def variance(self):
        """Variance of the Dirichlet distribution.

        Computed as ``alpha_i * (alpha_0 - alpha_i) / (alpha_0**2 * (alpha_0 + 1))``,
        where ``alpha_0`` is the sum of the concentration over the last axis.

        Returns:
            Tensor: Variance of each component, with the same shape as
            ``concentration``.
        """
        concentration0 = self.concentration.sum(-1, keepdim=True)
        return (self.concentration * (concentration0 - self.concentration)) / (
            concentration0.pow(2) * (concentration0 + 1))

    def sample(self, shape=()):
        """Sample from the Dirichlet distribution.

        Args:
            shape (Sequence[int], optional): Sample shape. Defaults to an empty
                tuple.

        Returns:
            Tensor: Samples drawn by the ``dirichlet`` op, using
            ``concentration`` expanded to ``self._extend_shape(shape)``.
        """
        # Normalize list-like shapes to a tuple before extending them with
        # the batch/event shape.
        shape = shape if isinstance(shape, tuple) else tuple(shape)
        return _dirichlet(self.concentration.expand(self._extend_shape(shape)))

    def prob(self, value):
        """Probability density function (PDF) evaluated at value.

        Args:
            value (Tensor): Value to be evaluated.

        Returns:
            Tensor: PDF evaluated at value, i.e. ``exp(log_prob(value))``.
        """
        return paddle.exp(self.log_prob(value))

    def log_prob(self, value):
        """Log of the probability density function.

        ``value`` is not checked to lie on the simplex.

        Args:
            value (Tensor): Value to be evaluated; its last axis is the event
                axis.

        Returns:
            Tensor: Log-density, reduced over the last (event) axis:
            ``sum((alpha_i - 1) * log(x_i)) + lgamma(alpha_0) - sum(lgamma(alpha_i))``.
        """
        return ((paddle.log(value) * (self.concentration - 1.0)).sum(-1) +
                paddle.lgamma(self.concentration.sum(-1)) -
                paddle.lgamma(self.concentration).sum(-1))

    def entropy(self):
        """Entropy of the Dirichlet distribution.

        Uses the closed form
        ``log B(alpha) + (alpha_0 - k) * digamma(alpha_0)
        - sum((alpha_i - 1) * digamma(alpha_i))``,
        where ``log B(alpha) = sum(lgamma(alpha_i)) - lgamma(alpha_0)``.

        Returns:
            Tensor: Entropy of the distribution, reduced over the last axis.
        """
        concentration0 = self.concentration.sum(-1)
        k = self.concentration.shape[-1]
        # -(k - alpha_0) * digamma(alpha_0) == (alpha_0 - k) * digamma(alpha_0)
        return (paddle.lgamma(self.concentration).sum(-1) -
                paddle.lgamma(concentration0) -
                (k - concentration0) * paddle.digamma(concentration0) -
                ((self.concentration - 1.0) *
                 paddle.digamma(self.concentration)).sum(-1))

    @property
    def _natural_parameters(self):
        """Natural parameters used by the exponential-family base class."""
        return (self.concentration, )

    def _log_normalizer(self, x):
        """Log of the multivariate beta function: ``sum(lgamma(x)) - lgamma(sum(x))``
        over the last axis."""
        return x.lgamma().sum(-1) - paddle.lgamma(x.sum(-1))


def _dirichlet(concentration, name=None):
    """Draw Dirichlet samples parameterized by ``concentration``.

    In dygraph mode (new or legacy) the eager ``dirichlet`` C op is called
    directly. Otherwise a ``dirichlet`` op is appended to the current static
    program through ``LayerHelper``.

    Args:
        concentration (Tensor): Concentration parameters; dtype must be
            float32 or float64.
        name (str, optional): Forwarded to ``LayerHelper`` in static mode.

    Returns:
        Tensor: The sampled tensor. In static mode the output variable is
        created with ``concentration``'s dtype.
    """
    op_type = 'dirichlet'
    check_variable_and_dtype(concentration, 'concentration',
                             ['float32', 'float64'], op_type)

    # Eager paths: dispatch straight to the C ops.
    if in_dygraph_mode():
        return paddle._C_ops.final_state_dirichlet(concentration)
    if _in_legacy_dygraph():
        return paddle._C_ops.dirichlet(concentration)

    # Static graph path. NOTE: ``**locals()`` forwards exactly
    # concentration, name and op_type, so no new locals may be introduced
    # above this line.
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(
        dtype=concentration.dtype)
    helper.append_op(
        type=op_type,
        inputs={"Alpha": concentration},
        outputs={'Out': out},
        attrs={},
    )
    return out