# distribution.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define the distribution functions
# __all__ = ['Categorical',
#            'MultivariateNormalDiag',
#            'Normal',
#            'sampling_id',
#            'Uniform']

from __future__ import print_function

import math
import warnings

import numpy as np
import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid import core
from paddle.fluid.data_feeder import (check_dtype, check_type,
                                      check_variable_and_dtype, convert_dtype)
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
                                 elementwise_mul, elementwise_sub, nn, ops,
                                 tensor)
from paddle.tensor import arange, concat, gather_nd, multinomial


class Distribution(object):
    """
    The abstract base class for probability distributions. Functions are
    implemented in specific distributions.

    Args:
        batch_shape(Sequence[int], optional): independent, not identically
            distributed draws, aka a "collection" or "bunch" of distributions.
        event_shape(Sequence[int], optional): the shape of a single
            draw from the distribution; it may be dependent across dimensions.
            For scalar distributions, the event shape is []. For n-dimension
            multivariate distribution, the event shape is [n].
    """

    def __init__(self, batch_shape=(), event_shape=()):
        # Both shapes are normalised to tuples so that ``_extend_shape`` can
        # concatenate them; an argument that already is a tuple is kept as is.
        self._batch_shape = batch_shape if isinstance(
            batch_shape, tuple) else tuple(batch_shape)
        self._event_shape = event_shape if isinstance(
            event_shape, tuple) else tuple(event_shape)

        super(Distribution, self).__init__()

    @property
    def batch_shape(self):
        """Returns batch shape of distribution

        Returns:
            Sequence[int]: batch shape
        """
        return self._batch_shape

    @property
    def event_shape(self):
        """Returns event shape of distribution

        Returns:
            Sequence[int]: event shape
        """
        return self._event_shape

    @property
    def mean(self):
        """Mean of distribution"""
        raise NotImplementedError

    @property
    def variance(self):
        """Variance of distribution"""
        raise NotImplementedError

    def sample(self, shape=()):
        """Sampling from the distribution."""
        raise NotImplementedError

    def rsample(self, shape=()):
        """reparameterized sample"""
        raise NotImplementedError

    def entropy(self):
        """The entropy of the distribution."""
        raise NotImplementedError

    def kl_divergence(self, other):
        """The KL-divergence between self distributions and other."""
        raise NotImplementedError

    def prob(self, value):
        """Probability density/mass function evaluated at value.

        Computed as ``self.log_prob(value).exp()``, so it works for every
        subclass that implements ``log_prob``.

        Args:
            value (Tensor): value which will be evaluated
        """
        return self.log_prob(value).exp()

    def log_prob(self, value):
        """Log probability density/mass function."""
        raise NotImplementedError

    def probs(self, value):
        """Probability density/mass function.

        Note:

            This method will be deprecated in the future, please use `prob`
            instead.
        """
        raise NotImplementedError

    def _extend_shape(self, sample_shape):
        """Compute the shape of a sample.

        Args:
            sample_shape (Sequence[int]): leading shape of the requested
                draws. Any sequence is accepted (tuple, list, ...).

        Returns:
            tuple: ``sample_shape`` followed by the batch shape and the
            event shape.
        """
        # ``tuple(...)`` lets list-like shapes concatenate with the stored
        # tuples; a bare ``list + tuple`` raises TypeError.
        return tuple(sample_shape) + self._batch_shape + self._event_shape

    def _validate_args(self, *args):
        """
        Argument validation for distribution args

        Args:
            *args (float, list, numpy.ndarray, Tensor): arguments to check.

        Returns:
            bool: True if every argument is a Tensor, False if none is
            (including when no argument is given).

        Raises:
            ValueError: if one argument is Tensor, all arguments should be Tensor
        """
        is_tensor_flags = [isinstance(arg, tensor.Variable) for arg in args]
        is_variable = any(is_tensor_flags)

        # A mix of Tensor and non-Tensor arguments is rejected.
        if is_variable and not all(is_tensor_flags):
            raise ValueError(
                'if one argument is Tensor, all arguments should be Tensor')

        return is_variable

    def _to_tensor(self, *args):
        """
        Argument convert args to Tensor

        Every argument is converted to a float32 numpy array, all arrays are
        broadcast to their common shape, and each one is assigned to a newly
        created tensor.

        Args:
            *args (float, list, numpy.ndarray, Tensor): values to convert.

        Returns:
            Tensor of args, as a tuple with one entry per argument (an empty
            tuple when called without arguments).

        Raises:
            TypeError: if an argument is not a float, list, tuple,
                numpy.ndarray or Tensor.
        """
        # Nothing to convert; without this guard ``tmp`` below would stay a
        # plain Python float and ``tmp.dtype`` would raise AttributeError.
        if not args:
            return ()

        numpy_args = []
        variable_args = []
        tmp = 0.

        for arg in args:
            if isinstance(arg, float):
                arg = [arg]
            if not isinstance(arg, (list, tuple, np.ndarray, tensor.Variable)):
                raise TypeError(
                    "Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}"
                    .format(type(arg)))

            arg_np = np.array(arg)
            arg_dtype = arg_np.dtype
            if str(arg_dtype) != 'float32':
                if str(arg_dtype) != 'float64':
                    warnings.warn(
                        "data type of argument only support float32 and float64, your argument will be convert to float32."
                    )
                # float64 is downcast without a warning: per the original
                # note, the "assign" op doesn't support float64, so a float32
                # variable is generated here and converted to float64 later
                # using "cast" (that cast is not performed in this method).
                arg_np = arg_np.astype('float32')
            # tmp is used to support broadcast, it summarizes shapes of all
            # the args and get the mixed shape.
            tmp = tmp + arg_np
            numpy_args.append(arg_np)

        dtype = tmp.dtype
        for arg in numpy_args:
            arg_broadcasted, _ = np.broadcast_arrays(arg, tmp)
            arg_variable = tensor.create_tensor(dtype=dtype)
            tensor.assign(arg_broadcasted, arg_variable)
            variable_args.append(arg_variable)

        return tuple(variable_args)

    def _check_values_dtype_in_probs(self, param, value):
        """
        Log_prob and probs methods have input ``value``, if value's dtype is different from param,
        convert value's dtype to be consistent with param's dtype.

        Args:
            param (Tensor): low and high in Uniform class, loc and scale in Normal class.
            value (Tensor): The input tensor.

        Returns:
            value (Tensor): Change value's dtype if value's dtype is different from param.
        """
        if _non_static_mode():
            # Dynamic-graph path: only float32/float64 values are cast; any
            # other dtype is returned unchanged.
            if value.dtype != param.dtype and convert_dtype(
                    value.dtype) in ['float32', 'float64']:
                warnings.warn(
                    "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
                )
                if in_dygraph_mode():
                    return _C_ops.cast(value, param.dtype)
                if _in_legacy_dygraph():
                    return _legacy_C_ops.cast(value, 'in_dtype', value.dtype,
                                              'out_dtype', param.dtype)
            return value

        # Static-graph path: validate the dtype first, then cast on mismatch.
        check_variable_and_dtype(value, 'value', ['float32', 'float64'],
                                 'log_prob')
        if value.dtype != param.dtype:
            warnings.warn(
                "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
            )
            return tensor.cast(value, dtype=param.dtype)
        return value

    def _probs_to_logits(self, probs, is_binary=False):
        r"""
        Converts probabilities into logits. For the binary, probs denotes the
        probability of occurrence of the event indexed by `1`. For the
        multi-dimensional, values of last axis denote the probabilities of
        occurrence of each of the events.
        """
        return (paddle.log(probs) - paddle.log1p(-probs)) \
            if is_binary else paddle.log(probs)

    def _logits_to_probs(self, logits, is_binary=False):
        r"""
        Converts logits into probabilities. For the binary, each value denotes
        log odds, whereas for the multi-dimensional case, the values along the
        last dimension denote the log probabilities of the events.
        """
        return paddle.nn.functional.sigmoid(logits) \
            if is_binary else paddle.nn.functional.softmax(logits, axis=-1)