# distribution.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define the distribution functions
# __all__ = ['Categorical',
#            'MultivariateNormalDiag',
#            'Normal',
#            'sampling_id',
#            'Uniform']

import warnings

import numpy as np

import paddle
from paddle import _C_ops
from paddle.fluid.data_feeder import check_variable_and_dtype, convert_dtype
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layers import tensor


class Distribution:
    """
    The abstract base class for probability distributions. Functions are
    implemented in specific distributions.

    Args:
        batch_shape(Sequence[int], optional):  independent, not identically
            distributed draws, aka a "collection" or "bunch" of distributions.
        event_shape(Sequence[int], optional): the shape of a single
            draw from the distribution; it may be dependent across dimensions.
            For scalar distributions, the event shape is []. For n-dimension
            multivariate distribution, the event shape is [n].
    """

    def __init__(self, batch_shape=(), event_shape=()):
        # Store both shapes as tuples so they can be concatenated directly in
        # ``_extend_shape`` no matter which sequence type the caller passed.
        self._batch_shape = (
            batch_shape
            if isinstance(batch_shape, tuple)
            else tuple(batch_shape)
        )
        self._event_shape = (
            event_shape
            if isinstance(event_shape, tuple)
            else tuple(event_shape)
        )

        super().__init__()

    @property
    def batch_shape(self):
        """Returns batch shape of distribution

        Returns:
            Sequence[int]: batch shape
        """
        return self._batch_shape

    @property
    def event_shape(self):
        """Returns event shape of distribution

        Returns:
            Sequence[int]: event shape
        """
        return self._event_shape

    @property
    def mean(self):
        """Mean of distribution"""
        raise NotImplementedError

    @property
    def variance(self):
        """Variance of distribution"""
        raise NotImplementedError

    def sample(self, shape=()):
        """Sampling from the distribution."""
        raise NotImplementedError

    def rsample(self, shape=()):
        """reparameterized sample"""
        raise NotImplementedError

    def entropy(self):
        """The entropy of the distribution."""
        raise NotImplementedError

    def kl_divergence(self, other):
        """The KL-divergence between this distribution and ``other``."""
        raise NotImplementedError

    def prob(self, value):
        """Probability density/mass function evaluated at value.

        Computed as ``exp(log_prob(value))``, so subclasses only need to
        implement ``log_prob``.

        Args:
            value (Tensor): value which will be evaluated
        """
        return self.log_prob(value).exp()

    def log_prob(self, value):
        """Log probability density/mass function."""
        raise NotImplementedError

    def probs(self, value):
        """Probability density/mass function.

        Note:

            This method will be deprecated in the future, please use `prob`
            instead.
        """
        raise NotImplementedError

    def _extend_shape(self, sample_shape):
        """Compute the full shape of a sample drawn with ``sample_shape``.

        Args:
            sample_shape (Sequence[int]): leading sample dimensions, e.g.
                ``(5,)`` or ``[5]``.

        Returns:
            tuple: ``sample_shape + batch_shape + event_shape``.
        """
        # ``tuple(...)`` lets callers pass a list; a plain ``list + tuple``
        # concatenation would raise TypeError.
        return tuple(sample_shape) + self._batch_shape + self._event_shape

    def _validate_args(self, *args):
        """
        Argument validation for distribution args.

        Args:
            *args (float, list, numpy.ndarray, Tensor): arguments to check.

        Returns:
            bool: True if the arguments are Tensors (Variables), False otherwise.

        Raises:
            ValueError: if one argument is Tensor, all arguments should be Tensor
        """
        is_variable = False
        is_number = False
        for arg in args:
            if isinstance(arg, tensor.Variable):
                is_variable = True
            else:
                is_number = True

        if is_variable and is_number:
            raise ValueError(
                'if one argument is Tensor, all arguments should be Tensor'
            )

        return is_variable

    def _to_tensor(self, *args):
        """
        Convert args to Tensors broadcast to a common shape.

        A bare float is wrapped in a one-element list. Every argument is
        converted to a float32 numpy array (with a warning when its dtype is
        neither float32 nor float64). All arrays are broadcast against each
        other, and each one is written into a newly created tensor.

        Args:
            *args (float, list, tuple, numpy.ndarray, Tensor): values to convert.

        Returns:
            tuple: one Tensor per argument, in the same order. The tuple is
            empty when no arguments are given.

        Raises:
            TypeError: if an argument is not one of the supported types.
        """
        if not args:
            # Without this guard ``tmp`` below stays the Python float 0.0 and
            # ``tmp.dtype`` raises AttributeError.
            return ()

        numpy_args = []
        variable_args = []
        tmp = 0.0

        for arg in args:
            if isinstance(arg, float):
                arg = [arg]
            if not isinstance(arg, (list, tuple, np.ndarray, tensor.Variable)):
                raise TypeError(
                    "Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}".format(
                        type(arg)
                    )
                )

            arg_np = np.array(arg)
            arg_dtype = arg_np.dtype
            if str(arg_dtype) != 'float32':
                if str(arg_dtype) != 'float64':
                    # "assign" op doesn't support float64. if dtype is float64, float32 variable will be generated
                    #  and converted to float64 later using "cast".
                    warnings.warn(
                        "data type of argument only support float32 and float64, your argument will be convert to float32."
                    )
                arg_np = arg_np.astype('float32')
            # tmp is used to support broadcast, it summarizes shapes of all the args and get the mixed shape.
            tmp = tmp + arg_np
            numpy_args.append(arg_np)

        dtype = tmp.dtype
        for arg in numpy_args:
            arg_broadcasted, _ = np.broadcast_arrays(arg, tmp)
            arg_variable = paddle.tensor.create_tensor(dtype=dtype)
            tensor.assign(arg_broadcasted, arg_variable)
            variable_args.append(arg_variable)

        return tuple(variable_args)

    def _check_values_dtype_in_probs(self, param, value):
        """
        Log_prob and probs methods have input ``value``, if value's dtype is different from param,
        convert value's dtype to be consistent with param's dtype.

        In dynamic-graph mode the cast happens only when ``value`` is float32
        or float64. Any other dtype is returned unchanged. In static-graph mode,
        ``value`` must be float32 or float64 (checked via
        ``check_variable_and_dtype``) before it is cast.

        Args:
            param (Tensor): low and high in Uniform class, loc and scale in Normal class.
            value (Tensor): The input tensor.

        Returns:
            value (Tensor): Change value's dtype if value's dtype is different from param.
        """
        if in_dygraph_mode():
            if value.dtype != param.dtype and convert_dtype(value.dtype) in [
                'float32',
                'float64',
            ]:
                warnings.warn(
                    "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
                )
                return _C_ops.cast(value, param.dtype)
            return value

        check_variable_and_dtype(
            value, 'value', ['float32', 'float64'], 'log_prob'
        )
        if value.dtype != param.dtype:
            warnings.warn(
                "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
            )
            return tensor.cast(value, dtype=param.dtype)
        return value

    def _probs_to_logits(self, probs, is_binary=False):
        r"""
        Converts probabilities into logits. For the binary, probs denotes the
        probability of occurrence of the event indexed by `1`. For the
        multi-dimensional, values of last axis denote the probabilities of
        occurrence of each of the events.
        """
        # Binary: log-odds log(p) - log(1 - p); otherwise plain log(p).
        return (
            (paddle.log(probs) - paddle.log1p(-probs))
            if is_binary
            else paddle.log(probs)
        )

    def _logits_to_probs(self, logits, is_binary=False):
        r"""
        Converts logits into probabilities. For the binary, each value denotes
        log odds, whereas for the multi-dimensional case, the values along the
        last dimension denote the log probabilities of the events.
        """
        return (
            paddle.nn.functional.sigmoid(logits)
            if is_binary
            else paddle.nn.functional.softmax(logits, axis=-1)
        )