# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.framework import _non_static_mode
from ..fluid.layer_helper import LayerHelper
from .exponential_family import ExponentialFamily


class Dirichlet(ExponentialFamily):
    r"""
    Dirichlet distribution with parameter "concentration".

    The Dirichlet distribution is defined over the `(k-1)-simplex` using a 
    positive, length-k vector of concentration parameters (`k > 1`).
    The Dirichlet is identical to the Beta distribution when `k = 2`.

    For a continuous random vector :math:`\boldsymbol X \in R^k` with support 
    :math:`x_i \in (0, 1)` and :math:`\sum_{i=1}^{k} x_i = 1` , the 
    probability density function (pdf) is

    .. math::
    
        f(\boldsymbol X; \boldsymbol \alpha) = \frac{1}{B(\boldsymbol \alpha)} \prod_{i=1}^{k}x_i^{\alpha_i-1} 

    where :math:`\boldsymbol \alpha = {\alpha_1,...,\alpha_k}, k \ge 2` is the 
    parameter vector and the normalizing constant :math:`B(\boldsymbol \alpha)` is 
    the multivariate beta function:

    .. math::

        B(\boldsymbol \alpha) = \frac{\prod_{i=1}^{k} \Gamma(\alpha_i)}{\Gamma(\alpha_0)}

    where :math:`\alpha_0=\sum_{i=1}^{k} \alpha_i` is the sum of the parameters and 
    :math:`\Gamma(\alpha)` is the gamma function.

    Args:
        concentration (Tensor): "Concentration" parameter of the Dirichlet 
            distribution, also called :math:`\alpha`. When it has more than one 
            dimension, the last axis denotes the parameters of the distribution,
            ``event_shape=concentration.shape[-1:]`` ; axes other than the last are
            considered batch dimensions with ``batch_shape=concentration.shape[:-1]`` .

    Examples:

        .. code-block:: python

            import paddle

            dirichlet = paddle.distribution.Dirichlet(paddle.to_tensor([1., 2., 3.]))

            print(dirichlet.entropy())
            # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #        [-1.24434423])
            print(dirichlet.prob(paddle.to_tensor([.3, .5, .6])))
            # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #        [10.80000114])
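
            # Illustrative batched usage (a sketch; exact printed output may differ):
            # a 2-D concentration yields a batch of Dirichlet distributions with
            # batch_shape=concentration.shape[:-1] and event_shape=concentration.shape[-1:].
            batched = paddle.distribution.Dirichlet(
                paddle.to_tensor([[1., 2., 3.], [4., 5., 6.]]))
            print(batched.mean)
            # a Tensor of shape [2, 3]; each row sums to 1.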

    """

    def __init__(self, concentration):
        if concentration.dim() < 1:
            raise ValueError(
                "`concentration` parameter must be at least one dimensional")

        self.concentration = concentration
        super(Dirichlet, self).__init__(concentration.shape[:-1],
                                        concentration.shape[-1:])

    @property
    def mean(self):
        """Mean of Dirichlet distribution.

        Returns:
            Mean value of distribution.
        """
        return self.concentration / self.concentration.sum(-1, keepdim=True)

    @property
    def variance(self):
        """Variance of Dirichlet distribution.

        Returns:
            Variance value of distribution.
        """
        concentration0 = self.concentration.sum(-1, keepdim=True)
        return (self.concentration * (concentration0 - self.concentration)) / (
            concentration0.pow(2) * (concentration0 + 1))

    def sample(self, shape=()):
        """Sample from Dirichlet distribution.

        Args:
            shape (Sequence[int], optional): Sample shape. Defaults to empty tuple.
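
        Examples:

            .. code-block:: python

                # An illustrative sketch: sampled values are random, only the
                # shape is determined by the arguments.
                import paddle

                dirichlet = paddle.distribution.Dirichlet(paddle.to_tensor([1., 2., 3.]))
                samples = dirichlet.sample((4,))
                print(samples.shape)  # [4, 3]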
        """
        shape = shape if isinstance(shape, tuple) else tuple(shape)
        return _dirichlet(self.concentration.expand(self._extend_shape(shape)))

    def prob(self, value):
        """Probability density function (PDF) evaluated at value.

        Args:
            value (Tensor): Value to be evaluated.

        Returns:
            PDF evaluated at value.
        """
        return paddle.exp(self.log_prob(value))

    def log_prob(self, value):
        """Log of probability density function.

        Args:
            value (Tensor): Value to be evaluated.
        """
        return ((paddle.log(value) * (self.concentration - 1.0)
                 ).sum(-1) + paddle.lgamma(self.concentration.sum(-1)) -
                paddle.lgamma(self.concentration).sum(-1))

    def entropy(self):
        """Entropy of Dirichlet distribution.

        Returns:
            Entropy of distribution.
        """
        concentration0 = self.concentration.sum(-1)
        k = self.concentration.shape[-1]
        return (paddle.lgamma(self.concentration).sum(-1) -
                paddle.lgamma(concentration0) -
                (k - concentration0) * paddle.digamma(concentration0) - (
                    (self.concentration - 1.0
                     ) * paddle.digamma(self.concentration)).sum(-1))

    @property
    def _natural_parameters(self):
        return (self.concentration, )

    def _log_normalizer(self, x):
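        # A(alpha) = sum_i log Gamma(alpha_i) - log Gamma(sum_i alpha_i), i.e. log B(alpha).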
        return x.lgamma().sum(-1) - paddle.lgamma(x.sum(-1))


def _dirichlet(concentration, name=None):
    op_type = 'dirichlet'

    check_variable_and_dtype(concentration, 'concentration',
                             ['float32', 'float64'], op_type)

    if _non_static_mode():
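        # Dynamic graph (eager) mode: call the compiled dirichlet kernel directly.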
        return paddle._C_ops.dirichlet(concentration)

    else:
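        # Static graph mode: register a `dirichlet` op on the current program.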
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(
            dtype=concentration.dtype)
        helper.append_op(
            type=op_type,
            inputs={"Alpha": concentration},
            outputs={'Out': out},
            attrs={})
        return out