#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define the distribution functions
# __all__ = ['Categorical',
#            'MultivariateNormalDiag',
#            'Normal',
#            'sampling_id',
#            'Uniform']

from __future__ import print_function

import math
import warnings

import numpy as np
import paddle
from paddle import _C_ops

from ..fluid import core
from ..fluid.data_feeder import (check_dtype, check_type,
                                 check_variable_and_dtype, convert_dtype)
from ..fluid.framework import _non_static_mode
from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
                            elementwise_mul, elementwise_sub, nn, ops, tensor)
from ..tensor import arange, concat, gather_nd, multinomial


class Distribution(object):
    """
    The abstract base class for probability distributions. Functions are
    implemented in specific distributions.

    Args:
        batch_shape(Sequence[int], optional):  independent, not identically
            distributed draws, aka a "collection" or "bunch" of distributions.
        event_shape(Sequence[int], optional): the shape of a single
            draw from the distribution; it may be dependent across dimensions.
            For scalar distributions, the event shape is []. For n-dimension
            multivariate distribution, the event shape is [n].
    """

    def __init__(self, batch_shape=(), event_shape=()):
        # Both shapes are stored as tuples so that they can be concatenated
        # with each other (see ``_extend_shape``).
        self._batch_shape = batch_shape if isinstance(
            batch_shape, tuple) else tuple(batch_shape)
        self._event_shape = event_shape if isinstance(
            event_shape, tuple) else tuple(event_shape)

        super(Distribution, self).__init__()

    @property
    def batch_shape(self):
        """Returns batch shape of distribution

        Returns:
            Sequence[int]: batch shape
        """
        return self._batch_shape

    @property
    def event_shape(self):
        """Returns event shape of distribution

        Returns:
            Sequence[int]: event shape
        """
        return self._event_shape

    def sample(self, shape=()):
        """Sampling from the distribution."""
        raise NotImplementedError

    def entropy(self):
        """The entropy of the distribution."""
        raise NotImplementedError

    def kl_divergence(self, other):
        """The KL-divergence between self distributions and other."""
        raise NotImplementedError

    def prob(self, value):
        """Probability density/mass function evaluated at value.

        Args:
            value (Tensor): value which will be evaluated
        """
        raise NotImplementedError

    def log_prob(self, value):
        """Log probability density/mass function."""
        raise NotImplementedError

    def probs(self, value):
        """Probability density/mass function.

        .. note::

            This method will be deprecated in the future, please use `prob`
            instead.
        """
        raise NotImplementedError

    def _extend_shape(self, sample_shape):
        """Compute the full shape of a sample.

        Args:
            sample_shape (Sequence[int]): leading sample dimensions. Any
                iterable accepted by ``tuple()`` works (list or tuple).

        Returns:
            tuple: ``sample_shape`` followed by the batch shape and the
            event shape.
        """
        # ``_batch_shape`` / ``_event_shape`` are always tuples (see
        # ``__init__``); coerce ``sample_shape`` as well so that a list
        # argument does not fail with "can only concatenate list to list".
        return tuple(sample_shape) + self._batch_shape + self._event_shape

    def _validate_args(self, *args):
        """
        Argument validation for distribution args.

        Args:
            *args (float, list, numpy.ndarray, Tensor): arguments to check.

        Returns:
            bool: True if at least one argument is a ``tensor.Variable``
            (in which case all of them are), False otherwise.

        Raises:
            ValueError: if one argument is Tensor, all arguments should be Tensor
        """
        is_variable = False
        is_number = False
        for arg in args:
            if isinstance(arg, tensor.Variable):
                is_variable = True
            else:
                is_number = True

        # Mixing Tensor and non-Tensor arguments is rejected.
        if is_variable and is_number:
            raise ValueError(
                'if one argument is Tensor, all arguments should be Tensor')

        return is_variable

    def _to_tensor(self, *args):
        """
        Convert the arguments to Tensors that share one broadcast shape.

        Args:
            *args (float, list, tuple, numpy.ndarray, Tensor): values to
                convert.

        Returns:
            tuple: one Tensor per argument, in the same order.

        Raises:
            TypeError: if an argument is not a float, list, tuple,
                numpy.ndarray or Tensor.
        """
        numpy_args = []
        variable_args = []
        tmp = 0.

        for arg in args:
            # A bare float is wrapped so that it becomes a 1-element array.
            if isinstance(arg, float):
                arg = [arg]
            if not isinstance(arg, (list, tuple, np.ndarray, tensor.Variable)):
                raise TypeError(
                    "Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}".
                    format(type(arg)))

            arg_np = np.array(arg)
            arg_dtype = arg_np.dtype
            if str(arg_dtype) != 'float32':
                if str(arg_dtype) != 'float64':
                    # Only non-float dtypes trigger the warning; float64 is
                    # converted silently below.
                    # Original note: "assign" op doesn't support float64. if
                    # dtype is float64, float32 variable will be generated
                    # and converted to float64 later using "cast".
                    warnings.warn(
                        "data type of argument only support float32 and float64, your argument will be convert to float32."
                    )
                arg_np = arg_np.astype('float32')
            # tmp is used to support broadcast, it summarizes shapes of all
            # the args and gets the mixed (broadcast) shape.
            tmp = tmp + arg_np
            numpy_args.append(arg_np)

        dtype = tmp.dtype
        for arg in numpy_args:
            # Expand every argument to the common shape before assigning it
            # to a freshly created Tensor.
            arg_broadcasted, _ = np.broadcast_arrays(arg, tmp)
            arg_variable = tensor.create_tensor(dtype=dtype)
            tensor.assign(arg_broadcasted, arg_variable)
            variable_args.append(arg_variable)

        return tuple(variable_args)

    def _check_values_dtype_in_probs(self, param, value):
        """
        Log_prob and probs methods have input ``value``, if value's dtype is different from param,
        convert value's dtype to be consistent with param's dtype.

        Args:
            param (Tensor): low and high in Uniform class, loc and scale in Normal class.
            value (Tensor): The input tensor.

        Returns:
            value (Tensor): Change value's dtype if value's dtype is different from param.
        """
        if _non_static_mode():
            # Non-static mode: cast directly through the C op, and only when
            # ``value`` is float32/float64 with a dtype different from
            # ``param``; otherwise ``value`` is returned untouched.
            if value.dtype != param.dtype and convert_dtype(
                    value.dtype) in ['float32', 'float64']:
                warnings.warn(
                    "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
                )
                return _C_ops.cast(value, 'in_dtype', value.dtype, 'out_dtype',
                                   param.dtype)
            return value

        # Static mode: validate ``value`` first, then cast when the dtypes
        # differ.
        check_variable_and_dtype(value, 'value', ['float32', 'float64'],
                                 'log_prob')
        if value.dtype != param.dtype:
            warnings.warn(
                "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
            )
            return tensor.cast(value, dtype=param.dtype)
        return value

    def _probs_to_logits(self, probs, is_binary=False):
        r"""
        Converts probabilities into logits. For the binary, probs denotes the
        probability of occurrence of the event indexed by `1`. For the
        multi-dimensional, values of last axis denote the probabilities of
        occurrence of each of the events.
        """
        return (paddle.log(probs) - paddle.log1p(-probs)) \
            if is_binary else paddle.log(probs)

    def _logits_to_probs(self, logits, is_binary=False):
        r"""
        Converts logits into probabilities. For the binary, each value denotes
        log odds, whereas for the multi-dimensional case, the values along the
        last dimension denote the log probabilities of the events.
        """
        return paddle.nn.functional.sigmoid(logits) \
            if is_binary else paddle.nn.functional.softmax(logits, axis=-1)