distribution.py 8.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define the distribution functions
# __all__ = ['Categorical',
#            'MultivariateNormalDiag',
#            'Normal',
#            'sampling_id',
#            'Uniform']

import math
import warnings

import numpy as np
26
import paddle
27
from paddle import _C_ops, _legacy_C_ops
28 29 30
from paddle.fluid import core
from paddle.fluid.data_feeder import (check_dtype, check_type,
                                      check_variable_and_dtype, convert_dtype)
31
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
32 33 34 35
from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
                                 elementwise_mul, elementwise_sub, nn, ops,
                                 tensor)
from paddle.tensor import arange, concat, gather_nd, multinomial
36 37 38 39


class Distribution(object):
    """
    The abstract base class for probability distributions. Functions are
    implemented in specific distributions.

    Args:
        batch_shape(Sequence[int], optional):  independent, not identically
            distributed draws, aka a "collection" or "bunch" of distributions.
        event_shape(Sequence[int], optional): the shape of a single
            draw from the distribution; it may be dependent across dimensions.
            For scalar distributions, the event shape is []. For n-dimension
            multivariate distribution, the event shape is [n].
    """

    def __init__(self, batch_shape=(), event_shape=()):
        # Store both shapes as tuples so ``_extend_shape`` can concatenate
        # them with ``+`` regardless of the sequence type the caller passed.
        self._batch_shape = batch_shape if isinstance(
            batch_shape, tuple) else tuple(batch_shape)
        self._event_shape = event_shape if isinstance(
            event_shape, tuple) else tuple(event_shape)

        super(Distribution, self).__init__()

    @property
    def batch_shape(self):
        """Returns batch shape of distribution

        Returns:
            tuple[int]: batch shape (normalized to a tuple in ``__init__``)
        """
        return self._batch_shape

    @property
    def event_shape(self):
        """Returns event shape of distribution

        Returns:
            tuple[int]: event shape (normalized to a tuple in ``__init__``)
        """
        return self._event_shape

    @property
    def mean(self):
        """Mean of distribution"""
        raise NotImplementedError

    @property
    def variance(self):
        """Variance of distribution"""
        raise NotImplementedError

    def sample(self, shape=()):
        """Sampling from the distribution."""
        raise NotImplementedError

    def rsample(self, shape=()):
        """reparameterized sample"""
        raise NotImplementedError

    def entropy(self):
        """The entropy of the distribution."""
        raise NotImplementedError

    def kl_divergence(self, other):
        """The KL-divergence between this distribution and ``other``."""
        raise NotImplementedError

    def prob(self, value):
        """Probability density/mass function evaluated at value.

        Computed as ``exp(log_prob(value))``, so subclasses only need to
        implement ``log_prob``.

        Args:
            value (Tensor): value which will be evaluated
        """
        return self.log_prob(value).exp()

    def log_prob(self, value):
        """Log probability density/mass function."""
        raise NotImplementedError

    def probs(self, value):
        """Probability density/mass function.

        Note:

            This method will be deprecated in the future, please use `prob`
            instead.
        """
        raise NotImplementedError

    def _extend_shape(self, sample_shape):
        """Compute the full shape of a sample drawn with ``sample_shape``.

        Args:
            sample_shape (Sequence[int]): leading sample shape. Lists and
                other sequences are accepted and converted to a tuple.

        Returns:
            tuple: ``sample_shape + batch_shape + event_shape``.
        """
        # ``tuple(...)`` lets list inputs work; concatenating a list with the
        # stored tuples would otherwise raise TypeError.
        return tuple(sample_shape) + self._batch_shape + self._event_shape

    def _validate_args(self, *args):
        """
        Argument validation for distribution args.

        Args:
            *args (float, list, numpy.ndarray, Tensor): arguments to check.

        Returns:
            bool: True if the arguments are Tensors (``tensor.Variable``),
            False otherwise (including when no arguments are given).

        Raises:
            ValueError: if one argument is Tensor, all arguments should be Tensor
        """
        is_variable = False
        is_number = False
        for arg in args:
            if isinstance(arg, tensor.Variable):
                is_variable = True
            else:
                is_number = True

        if is_variable and is_number:
            raise ValueError(
                'if one argument is Tensor, all arguments should be Tensor')

        return is_variable

    def _to_tensor(self, *args):
        """
        Convert args to Tensors broadcast to a common shape.

        Args:
            *args (float, list, tuple, numpy.ndarray, Tensor): values to
                convert. A bare float is wrapped into a one-element list.

        Returns:
            tuple: one Tensor per argument, all broadcast to the mixed shape
            of the inputs. An empty tuple when no arguments are given.

        Raises:
            TypeError: if an argument is not one of the supported types.
        """
        # Without arguments ``tmp`` below would stay a Python float and
        # ``tmp.dtype`` would raise AttributeError.
        if not args:
            return ()

        numpy_args = []
        variable_args = []
        tmp = 0.

        for arg in args:
            if isinstance(arg, float):
                arg = [arg]
            if not isinstance(arg, (list, tuple, np.ndarray, tensor.Variable)):
                raise TypeError(
                    "Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}"
                    .format(type(arg)))

            arg_np = np.array(arg)
            arg_dtype = arg_np.dtype
            if str(arg_dtype) != 'float32':
                if str(arg_dtype) != 'float64':
                    # "assign" op doesn't support float64. if dtype is float64, float32 variable will be generated
                    #  and converted to float64 later using "cast".
                    warnings.warn(
                        "data type of argument only support float32 and float64, your argument will be convert to float32."
                    )
                arg_np = arg_np.astype('float32')
            # tmp is used to support broadcast, it summarizes shapes of all the args and get the mixed shape.
            # NOTE(review): under NumPy < 2 promotion rules, adding the Python
            # float ``0.`` to 0-d float32 arrays yields float64, so ``dtype``
            # below may differ from the args' dtype for scalar-shaped inputs —
            # confirm whether 0-d ndarray inputs are expected here.
            tmp = tmp + arg_np
            numpy_args.append(arg_np)

        dtype = tmp.dtype
        for arg in numpy_args:
            arg_broadcasted, _ = np.broadcast_arrays(arg, tmp)
            arg_variable = tensor.create_tensor(dtype=dtype)
            tensor.assign(arg_broadcasted, arg_variable)
            variable_args.append(arg_variable)

        return tuple(variable_args)

    def _check_values_dtype_in_probs(self, param, value):
        """
        Log_prob and probs methods have input ``value``, if value's dtype is different from param,
        convert value's dtype to be consistent with param's dtype.

        In dynamic-graph mode, ``value`` is cast only when its dtype is
        float32 or float64; other dtypes are returned unchanged. In static
        mode, ``value`` is first validated by ``check_variable_and_dtype``
        (float32/float64 only) and then cast if needed.

        Args:
            param (Tensor): low and high in Uniform class, loc and scale in Normal class.
            value (Tensor): The input tensor.

        Returns:
            value (Tensor): Change value's dtype if value's dtype is different from param.
        """
        if _non_static_mode():
            if value.dtype != param.dtype and convert_dtype(
                    value.dtype) in ['float32', 'float64']:
                warnings.warn(
                    "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
                )
                if in_dygraph_mode():
                    return _C_ops.cast(value, param.dtype)
                if _in_legacy_dygraph():
                    return _legacy_C_ops.cast(value, 'in_dtype', value.dtype,
                                              'out_dtype', param.dtype)
            return value

        check_variable_and_dtype(value, 'value', ['float32', 'float64'],
                                 'log_prob')
        if value.dtype != param.dtype:
            warnings.warn(
                "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
            )
            return tensor.cast(value, dtype=param.dtype)
        return value

    def _probs_to_logits(self, probs, is_binary=False):
        r"""
        Converts probabilities into logits. For the binary, probs denotes the
        probability of occurrence of the event indexed by `1`. For the
        multi-dimensional, values of last axis denote the probabilities of
        occurrence of each of the events.

        Binary case returns ``log(p) - log1p(-p)`` (the log-odds); otherwise
        returns ``log(p)``.
        """
        return (paddle.log(probs) - paddle.log1p(-probs)) \
            if is_binary else paddle.log(probs)

    def _logits_to_probs(self, logits, is_binary=False):
        r"""
        Converts logits into probabilities. For the binary, each value denotes
        log odds, whereas for the multi-dimensional case, the values along the
        last dimension denote the log probabilities of the events.

        Binary case applies ``sigmoid``; otherwise ``softmax`` over the last
        axis.
        """
        return paddle.nn.functional.sigmoid(logits) \
            if is_binary else paddle.nn.functional.softmax(logits, axis=-1)