# distribution.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define the distribution functions
# __all__ = ['Categorical',
#            'MultivariateNormalDiag',
#            'Normal',
#            'sampling_id',
#            'Uniform']

import warnings

import numpy as np
import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid.data_feeder import (check_variable_and_dtype, convert_dtype)
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
from paddle.fluid.layers import tensor


class Distribution:
    """
    The abstract base class for probability distributions. Functions are
    implemented in specific distributions.

    Args:
        batch_shape(Sequence[int], optional):  independent, not identically
            distributed draws, aka a "collection" or "bunch" of distributions.
        event_shape(Sequence[int], optional): the shape of a single
            draw from the distribution; it may be dependent across dimensions.
            For scalar distributions, the event shape is []. For n-dimension
            multivariate distribution, the event shape is [n].
    """

    def __init__(self, batch_shape=(), event_shape=()):
        # Shapes are always stored as tuples so they can be concatenated
        # and compared uniformly, whatever sequence type the caller passed.
        self._batch_shape = batch_shape if isinstance(
            batch_shape, tuple) else tuple(batch_shape)
        self._event_shape = event_shape if isinstance(
            event_shape, tuple) else tuple(event_shape)

        super().__init__()

    @property
    def batch_shape(self):
        """Returns batch shape of distribution

        Returns:
            Sequence[int]: batch shape
        """
        return self._batch_shape

    @property
    def event_shape(self):
        """Returns event shape of distribution

        Returns:
            Sequence[int]: event shape
        """
        return self._event_shape

    @property
    def mean(self):
        """Mean of distribution"""
        raise NotImplementedError

    @property
    def variance(self):
        """Variance of distribution"""
        raise NotImplementedError

    def sample(self, shape=()):
        """Sampling from the distribution."""
        raise NotImplementedError

    def rsample(self, shape=()):
        """reparameterized sample"""
        raise NotImplementedError

    def entropy(self):
        """The entropy of the distribution."""
        raise NotImplementedError

    def kl_divergence(self, other):
        """The KL-divergence between self distributions and other."""
        raise NotImplementedError

    def prob(self, value):
        """Probability density/mass function evaluated at value.

        Computed as the exponential of ``log_prob(value)``, so it works for
        any subclass that implements ``log_prob``.

        Args:
            value (Tensor): value which will be evaluated
        """
        return self.log_prob(value).exp()

    def log_prob(self, value):
        """Log probability density/mass function."""
        raise NotImplementedError

    def probs(self, value):
        """Probability density/mass function.

        Note:

            This method will be deprecated in the future, please use `prob`
            instead.
        """
        raise NotImplementedError

    def _extend_shape(self, sample_shape):
        """compute shape of the sample

        Args:
            sample_shape (Sequence[int]): sample shape; any iterable of ints
                (tuple, list, ...) is accepted.

        Returns:
            tuple: ``sample_shape`` followed by the batch shape and the
            event shape of this distribution.
        """
        # Coerce every operand to a tuple: a list ``sample_shape`` would
        # otherwise raise TypeError when concatenated with the tuple shapes.
        return (tuple(sample_shape) + tuple(self._batch_shape) +
                tuple(self._event_shape))

    def _validate_args(self, *args):
        """
        Argument validation for distribution args

        Args:
            value (float, list, numpy.ndarray, Tensor)

        Returns:
            bool: True if every argument is a Tensor, False if none is
            (including when no argument is given).

        Raises:
            ValueError: if one argument is Tensor, all arguments should be Tensor
        """
        tensor_flags = [isinstance(arg, tensor.Variable) for arg in args]
        is_variable = any(tensor_flags)
        is_number = not all(tensor_flags)

        # Mixing Tensor and non-Tensor arguments is rejected.
        if is_variable and is_number:
            raise ValueError(
                'if one argument is Tensor, all arguments should be Tensor')

        return is_variable

    def _to_tensor(self, *args):
        """
        Argument convert args to Tensor

        Every argument is converted to a float32 numpy array, all arrays are
        broadcast to a common shape, and each one is assigned to a newly
        created tensor.

        Args:
            value (float, list, numpy.ndarray, Tensor)

        Returns:
            Tensor of args, as a tuple with one entry per argument (an empty
            tuple when called without arguments).

        Raises:
            TypeError: if an argument is not a float, list, tuple,
                numpy.ndarray or Tensor.
        """
        # Without arguments there is nothing to convert; the accumulator
        # below would still be a plain Python float with no ``dtype``.
        if not args:
            return ()

        numpy_args = []
        variable_args = []
        tmp = 0.

        for arg in args:
            if isinstance(arg, float):
                arg = [arg]
            if not isinstance(arg, (list, tuple, np.ndarray, tensor.Variable)):
                raise TypeError(
                    "Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}"
                    .format(type(arg)))

            arg_np = np.array(arg)
            arg_dtype = arg_np.dtype
            if str(arg_dtype) != 'float32':
                if str(arg_dtype) != 'float64':
                    # Only non-floating dtypes trigger the warning; float64
                    # input is downcast silently.
                    warnings.warn(
                        "data type of argument only support float32 and float64, your argument will be convert to float32."
                    )
                # "assign" op doesn't support float64. if dtype is float64, float32 variable will be generated
                #  and converted to float64 later using "cast".
                # NOTE(review): that later cast is not visible in this class;
                # presumably it is done by the caller - confirm.
                arg_np = arg_np.astype('float32')
            # tmp is used to support broadcast, it summarizes shapes of all the args and get the mixed shape.
            tmp = tmp + arg_np
            numpy_args.append(arg_np)

        dtype = tmp.dtype
        for arg in numpy_args:
            arg_broadcasted, _ = np.broadcast_arrays(arg, tmp)
            arg_variable = tensor.create_tensor(dtype=dtype)
            tensor.assign(arg_broadcasted, arg_variable)
            variable_args.append(arg_variable)

        return tuple(variable_args)

    def _check_values_dtype_in_probs(self, param, value):
        """
        Log_prob and probs methods have input ``value``, if value's dtype is different from param,
        convert value's dtype to be consistent with param's dtype.

        Args:
            param (Tensor): low and high in Uniform class, loc and scale in Normal class.
            value (Tensor): The input tensor.

        Returns:
            value (Tensor): Change value's dtype if value's dtype is different from param.
        """
        if _non_static_mode():
            # Dynamic-graph path: only float32/float64 values are cast; any
            # other dtype is returned untouched.
            if value.dtype != param.dtype and convert_dtype(
                    value.dtype) in ['float32', 'float64']:
                warnings.warn(
                    "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
                )
                if in_dygraph_mode():
                    return _C_ops.cast(value, param.dtype)
                if _in_legacy_dygraph():
                    return _legacy_C_ops.cast(value, 'in_dtype', value.dtype,
                                              'out_dtype', param.dtype)
            return value

        # Static-graph path: validate the dtype first, then cast if needed.
        check_variable_and_dtype(value, 'value', ['float32', 'float64'],
                                 'log_prob')
        if value.dtype != param.dtype:
            warnings.warn(
                "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
            )
            return tensor.cast(value, dtype=param.dtype)
        return value

    def _probs_to_logits(self, probs, is_binary=False):
        r"""
        Converts probabilities into logits. For the binary, probs denotes the
        probability of occurrence of the event indexed by `1`. For the
        multi-dimensional, values of last axis denote the probabilities of
        occurrence of each of the events.

        Args:
            probs (Tensor): probabilities to convert.
            is_binary (bool, optional): if True, return the log odds
                ``log(probs) - log1p(-probs)``; otherwise return
                ``log(probs)``. Default: False.
        """
        return (paddle.log(probs) - paddle.log1p(-probs)) \
            if is_binary else paddle.log(probs)

    def _logits_to_probs(self, logits, is_binary=False):
        r"""
        Converts logits into probabilities. For the binary, each value denotes
        log odds, whereas for the multi-dimensional case, the values along the
        last dimension denote the log probabilities of the events.

        Args:
            logits (Tensor): logits to convert.
            is_binary (bool, optional): if True, apply sigmoid; otherwise
                apply softmax over the last axis. Default: False.
        """
        return paddle.nn.functional.sigmoid(logits) \
            if is_binary else paddle.nn.functional.softmax(logits, axis=-1)