# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distribution import exponential_family
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper


class Dirichlet(exponential_family.ExponentialFamily):
    r"""
    Dirichlet distribution with parameter "concentration".

    The Dirichlet distribution is defined over the `(k-1)-simplex` using a
    positive, length-k vector concentration (`k > 1`).
    The Dirichlet is identically the Beta distribution when `k = 2`.

    For a continuous random vector :math:`\boldsymbol X \in R^k` with support
    :math:`\boldsymbol X \in (0,1)^k, \sum_{i=1}^{k} x_i = 1` ,
    the probability density function (pdf) is

    .. math::

        f(\boldsymbol X; \boldsymbol \alpha) = \frac{1}{B(\boldsymbol \alpha)} \prod_{i=1}^{k}x_i^{\alpha_i-1}

    where :math:`\boldsymbol \alpha = (\alpha_1,...,\alpha_k), k \ge 2` is the
    parameter vector, and the normalizing constant is the multivariate beta
    function

    .. math::

        B(\boldsymbol \alpha) = \frac{\prod_{i=1}^{k} \Gamma(\alpha_i)}{\Gamma(\alpha_0)}

    :math:`\alpha_0=\sum_{i=1}^{k} \alpha_i` is the sum of parameters and
    :math:`\Gamma(\alpha)` is the gamma function.

    Args:
        concentration (Tensor): "Concentration" parameter of dirichlet
            distribution, also called :math:`\alpha`. When it's over one
            dimension, the last axis denotes the parameter of distribution,
            ``event_shape=concentration.shape[-1:]`` , axes other than last are
            considered batch dimensions with ``batch_shape=concentration.shape[:-1]`` .

    Raises:
        ValueError: If ``concentration`` is zero dimensional.

    Examples:

        .. code-block:: python

            import paddle

            dirichlet = paddle.distribution.Dirichlet(paddle.to_tensor([1., 2., 3.]))

            print(dirichlet.entropy())
            # Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #        -1.24434423)
            print(dirichlet.prob(paddle.to_tensor([.3, .5, .6])))
            # Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #        10.80000114)

    """

    def __init__(self, concentration):
        # A scalar has no event axis to hold the k parameters.
        if concentration.dim() < 1:
            raise ValueError(
                "`concentration` parameter must be at least one dimensional"
            )

        self.concentration = concentration
        # batch_shape = all axes but the last; event_shape = the last axis.
        super().__init__(concentration.shape[:-1], concentration.shape[-1:])

    @property
    def mean(self):
        """Mean of Dirichlet distribution.

        Returns:
            Mean value of distribution, :math:`\\alpha_i / \\alpha_0`.
        """
        return self.concentration / self.concentration.sum(-1, keepdim=True)

    @property
    def variance(self):
        """Variance of Dirichlet distribution.

        Returns:
            Variance value of distribution,
            :math:`\\alpha_i (\\alpha_0 - \\alpha_i) / (\\alpha_0^2 (\\alpha_0 + 1))`.
        """
        concentration0 = self.concentration.sum(-1, keepdim=True)
        return (self.concentration * (concentration0 - self.concentration)) / (
            concentration0.pow(2) * (concentration0 + 1)
        )

    def sample(self, shape=()):
        """Sample from dirichlet distribution.

        Args:
            shape (Sequence[int], optional): Sample shape. Defaults to empty tuple.

        Returns:
            Samples drawn with the concentration broadcast (``expand``) to
            ``self._extend_shape(shape)``.
        """
        # Normalize any sequence (e.g. a list) to a tuple before extending it.
        shape = shape if isinstance(shape, tuple) else tuple(shape)
        return _dirichlet(self.concentration.expand(self._extend_shape(shape)))

    def prob(self, value):
        """Probability density function(PDF) evaluated at value.

        Args:
            value (Tensor): Value to be evaluated.

        Returns:
            PDF evaluated at value.
        """
        return paddle.exp(self.log_prob(value))

    def log_prob(self, value):
        """Log of probability density function.

        Args:
            value (Tensor): Value to be evaluated.

        Returns:
            Log-PDF evaluated at ``value``, reduced over the last (event) axis.
        """
        # sum((alpha_i - 1) * log x_i) - log B(alpha), with
        # log B(alpha) = sum(lgamma(alpha_i)) - lgamma(alpha_0).
        return (
            (paddle.log(value) * (self.concentration - 1.0)).sum(-1)
            + paddle.lgamma(self.concentration.sum(-1))
            - paddle.lgamma(self.concentration).sum(-1)
        )

    def entropy(self):
        """Entropy of Dirichlet distribution.

        Returns:
            Entropy of distribution.
        """
        concentration0 = self.concentration.sum(-1)
        k = self.concentration.shape[-1]
        # H = log B(alpha) + (alpha_0 - k) * digamma(alpha_0)
        #     - sum((alpha_i - 1) * digamma(alpha_i))
        return (
            paddle.lgamma(self.concentration).sum(-1)
            - paddle.lgamma(concentration0)
            - (k - concentration0) * paddle.digamma(concentration0)
            - (
                (self.concentration - 1.0) * paddle.digamma(self.concentration)
            ).sum(-1)
        )

    @property
    def _natural_parameters(self):
        # The concentration itself is exposed as the (single) natural parameter.
        return (self.concentration,)

    def _log_normalizer(self, x):
        # Log of the multivariate beta function: sum(lgamma(x_i)) - lgamma(sum(x_i)).
        return x.lgamma().sum(-1) - paddle.lgamma(x.sum(-1))


def _dirichlet(concentration, name=None):
    """Draw samples using the ``dirichlet`` operator for the given concentration.

    In dynamic-graph mode the C++ op is called directly; otherwise a
    ``dirichlet`` op is appended to the static program.

    Args:
        concentration (Tensor): Concentration parameters; the dtype must be one
            of float16, float32, float64 or uint16 (checked in static mode).
        name (str, optional): Forwarded to ``LayerHelper`` via ``locals()``.

    Returns:
        Tensor with the same dtype as ``concentration`` holding the samples.
    """
    if in_dygraph_mode():
        return paddle._C_ops.dirichlet(concentration)
    else:
        op_type = 'dirichlet'
        check_variable_and_dtype(
            concentration,
            'concentration',
            ['float16', 'float32', 'float64', 'uint16'],
            op_type,
        )
        # NOTE: locals() here is {concentration, name, op_type}; keep no other
        # locals above this line so LayerHelper receives the same kwargs.
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(
            dtype=concentration.dtype
        )
        helper.append_op(
            type=op_type,
            inputs={"Alpha": concentration},
            outputs={'Out': out},
            attrs={},
        )
        return out