# distribution.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define the distribution functions
# __all__ = ['Categorical',
#            'MultivariateNormalDiag',
#            'Normal',
#            'sampling_id',
#            'Uniform']

import warnings

import numpy as np
import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid.data_feeder import check_variable_and_dtype, convert_dtype
from paddle.fluid.framework import (
    _non_static_mode,
    in_dygraph_mode,
    _in_legacy_dygraph,
)
from paddle.fluid.layers import tensor


class Distribution(object):
    """
    The abstract base class for probability distributions. Functions are
    implemented in specific distributions.

    Args:
        batch_shape(Sequence[int], optional): independent, not identically
            distributed draws, aka a "collection" or "bunch" of distributions.
        event_shape(Sequence[int], optional): the shape of a single
            draw from the distribution; it may be dependent across dimensions.
            For scalar distributions, the event shape is []. For n-dimension
            multivariate distribution, the event shape is [n].
    """

    def __init__(self, batch_shape=(), event_shape=()):
        # Store both shapes as tuples so they can be combined with tuple
        # concatenation (see ``_extend_shape``).
        self._batch_shape = (
            batch_shape
            if isinstance(batch_shape, tuple)
            else tuple(batch_shape)
        )
        self._event_shape = (
            event_shape
            if isinstance(event_shape, tuple)
            else tuple(event_shape)
        )

        super().__init__()

    @property
    def batch_shape(self):
        """Returns batch shape of distribution

        Returns:
            Sequence[int]: batch shape (stored as a tuple)
        """
        return self._batch_shape

    @property
    def event_shape(self):
        """Returns event shape of distribution

        Returns:
            Sequence[int]: event shape (stored as a tuple)
        """
        return self._event_shape

    @property
    def mean(self):
        """Mean of distribution"""
        raise NotImplementedError

    @property
    def variance(self):
        """Variance of distribution"""
        raise NotImplementedError

    def sample(self, shape=()):
        """Sampling from the distribution."""
        raise NotImplementedError

    def rsample(self, shape=()):
        """reparameterized sample"""
        raise NotImplementedError

    def entropy(self):
        """The entropy of the distribution."""
        raise NotImplementedError

    def kl_divergence(self, other):
        """The KL-divergence between this distribution and ``other``."""
        raise NotImplementedError

    def prob(self, value):
        """Probability density/mass function evaluated at value.

        Computed as ``exp(log_prob(value))``, so subclasses only need to
        implement ``log_prob``.

        Args:
            value (Tensor): value which will be evaluated
        """
        return self.log_prob(value).exp()

    def log_prob(self, value):
        """Log probability density/mass function."""
        raise NotImplementedError

    def probs(self, value):
        """Probability density/mass function.

        Note:

            This method will be deprecated in the future, please use `prob`
            instead.
        """
        raise NotImplementedError

    def _extend_shape(self, sample_shape):
        """Compute the shape of a sample.

        The result is ``sample_shape + batch_shape + event_shape``.

        Args:
            sample_shape (Sequence[int]): shape of the samples to draw; any
                iterable of ints (e.g. a list or a tuple) is accepted.

        Returns:
            tuple: generated sample data shape
        """
        # Coerce every component to a tuple: concatenating e.g. a list with
        # the stored tuples would otherwise raise ``TypeError``.
        return (
            tuple(sample_shape)
            + tuple(self._batch_shape)
            + tuple(self._event_shape)
        )

    def _validate_args(self, *args):
        """
        Argument validation for distribution args.

        Args:
            *args (float, list, numpy.ndarray, Tensor): arguments to check.

        Returns:
            bool: True if every argument is a Tensor (``tensor.Variable``),
            False if none is (including when no argument is given).

        Raises:
            ValueError: if one argument is Tensor, all arguments should be Tensor
        """
        is_variable = False
        is_number = False
        for arg in args:
            if isinstance(arg, tensor.Variable):
                is_variable = True
            else:
                is_number = True

        if is_variable and is_number:
            raise ValueError(
                'if one argument is Tensor, all arguments should be Tensor'
            )

        return is_variable

    def _to_tensor(self, *args):
        """
        Convert arguments to Tensors broadcast to a common shape.

        Every argument is first turned into a numpy array; anything that is
        not float32 is cast to float32 (a warning is emitted unless the input
        was float64). All arrays are then broadcast against each other and
        assigned into newly created tensors.

        Args:
            *args (float, list, numpy.ndarray, Tensor): values to convert.

        Returns:
            tuple: one Tensor per argument, in the same order (an empty tuple
            when no argument is given).

        Raises:
            TypeError: if an argument is not a float, list, tuple,
                numpy.ndarray or Tensor.
        """
        # Without this guard ``tmp`` below stays the Python float 0.0 and
        # ``tmp.dtype`` raises AttributeError.
        if not args:
            return ()

        numpy_args = []
        variable_args = []
        tmp = 0.0

        for arg in args:
            if isinstance(arg, float):
                arg = [arg]
            if not isinstance(arg, (list, tuple, np.ndarray, tensor.Variable)):
                raise TypeError(
                    "Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}".format(
                        type(arg)
                    )
                )

            arg_np = np.array(arg)
            arg_dtype = arg_np.dtype
            if str(arg_dtype) != 'float32':
                if str(arg_dtype) != 'float64':
                    # "assign" op doesn't support float64. if dtype is float64, float32 variable will be generated
                    #  and converted to float64 later using "cast".
                    warnings.warn(
                        "data type of argument only support float32 and float64, your argument will be convert to float32."
                    )
                arg_np = arg_np.astype('float32')
            # tmp is used to support broadcast, it summarizes shapes of all the args and get the mixed shape.
            tmp = tmp + arg_np
            numpy_args.append(arg_np)

        dtype = tmp.dtype
        for arg in numpy_args:
            arg_broadcasted, _ = np.broadcast_arrays(arg, tmp)
            arg_variable = tensor.create_tensor(dtype=dtype)
            tensor.assign(arg_broadcasted, arg_variable)
            variable_args.append(arg_variable)

        return tuple(variable_args)

    def _check_values_dtype_in_probs(self, param, value):
        """
        Log_prob and probs methods have input ``value``, if value's dtype is different from param,
        convert value's dtype to be consistent with param's dtype.

        In dynamic mode only float32/float64 values are cast; any other dtype
        is returned unchanged. In static mode ``value`` must be float32 or
        float64 (checked by ``check_variable_and_dtype``).

        Args:
            param (Tensor): low and high in Uniform class, loc and scale in Normal class.
            value (Tensor): The input tensor.

        Returns:
            value (Tensor): Change value's dtype if value's dtype is different from param.
        """
        if _non_static_mode():
            if value.dtype != param.dtype and convert_dtype(value.dtype) in [
                'float32',
                'float64',
            ]:
                warnings.warn(
                    "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
                )
                if in_dygraph_mode():
                    return _C_ops.cast(value, param.dtype)
                if _in_legacy_dygraph():
                    return _legacy_C_ops.cast(
                        value, 'in_dtype', value.dtype, 'out_dtype', param.dtype
                    )
            return value

        check_variable_and_dtype(
            value, 'value', ['float32', 'float64'], 'log_prob'
        )
        if value.dtype != param.dtype:
            warnings.warn(
                "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
            )
            return tensor.cast(value, dtype=param.dtype)
        return value

    def _probs_to_logits(self, probs, is_binary=False):
        r"""
        Converts probabilities into logits. For the binary, probs denotes the
        probability of occurrence of the event indexed by `1`. For the
        multi-dimensional, values of last axis denote the probabilities of
        occurrence of each of the events.
        """
        # Binary: log-odds log(p) - log(1 - p); categorical: plain log(p).
        return (
            (paddle.log(probs) - paddle.log1p(-probs))
            if is_binary
            else paddle.log(probs)
        )

    def _logits_to_probs(self, logits, is_binary=False):
        r"""
        Converts logits into probabilities. For the binary, each value denotes
        log odds, whereas for the multi-dimensional case, the values along the
        last dimension denote the log probabilities of the events.
        """
        # Binary: sigmoid of the log-odds; categorical: softmax over last axis.
        return (
            paddle.nn.functional.sigmoid(logits)
            if is_binary
            else paddle.nn.functional.softmax(logits, axis=-1)
        )