From a3e6f18cf7e5f3cfe5ffaeefcf2eabace109778f Mon Sep 17 00:00:00 2001 From: Xiaoxu Chen Date: Thu, 23 Dec 2021 19:29:44 +0800 Subject: [PATCH] move distribution.py into distribution package and split into different file for better scalability (#38047) --- python/paddle/distribution.py | 968 ------------ python/paddle/distribution/__init__.py | 20 + python/paddle/distribution/categorical.py | 357 +++++ python/paddle/distribution/distribution.py | 163 +++ python/paddle/distribution/normal.py | 279 ++++ python/paddle/distribution/uniform.py | 252 ++++ .../fluid/tests/unittests/CMakeLists.txt | 1 + .../unittests/distribution/CMakeLists.txt | 6 + .../distribution/test_distribution.py | 130 ++ .../test_distribution_categorical.py | 441 ++++++ .../distribution/test_distribution_normal.py | 456 ++++++ .../distribution/test_distribution_uniform.py | 345 +++++ .../tests/unittests/test_distribution.py | 1299 ----------------- python/setup.py.in | 1 + 14 files changed, 2451 insertions(+), 2267 deletions(-) delete mode 100644 python/paddle/distribution.py create mode 100644 python/paddle/distribution/__init__.py create mode 100644 python/paddle/distribution/categorical.py create mode 100644 python/paddle/distribution/distribution.py create mode 100644 python/paddle/distribution/normal.py create mode 100644 python/paddle/distribution/uniform.py create mode 100644 python/paddle/fluid/tests/unittests/distribution/CMakeLists.txt create mode 100644 python/paddle/fluid/tests/unittests/distribution/test_distribution.py create mode 100644 python/paddle/fluid/tests/unittests/distribution/test_distribution_categorical.py create mode 100644 python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py create mode 100644 python/paddle/fluid/tests/unittests/distribution/test_distribution_uniform.py delete mode 100644 python/paddle/fluid/tests/unittests/test_distribution.py diff --git a/python/paddle/distribution.py b/python/paddle/distribution.py deleted file mode 100644 index cf198eab1e8..00000000000 --- a/python/paddle/distribution.py +++ /dev/null @@ -1,968 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TODO: define the distribution functions -# __all__ = ['Categorical', -# 'MultivariateNormalDiag', -# 'Normal', -# 'sampling_id', -# 'Uniform'] - -from __future__ import print_function - -from .fluid.layers import control_flow -from .fluid.layers import tensor -from .fluid.layers import ops -from .fluid.layers import nn -from .fluid.layers import elementwise_mul, elementwise_div, elementwise_add, elementwise_sub -from .fluid import core -from .fluid.framework import in_dygraph_mode -from .tensor import arange, gather_nd, concat, multinomial -import math -import numpy as np -import warnings - -from .fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype -from paddle import _C_ops - -__all__ = ['Distribution', 'Uniform', 'Normal', 'Categorical'] - - -class Distribution(object): - """ - The abstract base class for probability distributions. Functions are - implemented in specific distributions. - """ - - def __init__(self): - super(Distribution, self).__init__() - - def sample(self): - """Sampling from the distribution.""" - raise NotImplementedError - - def entropy(self): - """The entropy of the distribution.""" - raise NotImplementedError - - def kl_divergence(self, other): - """The KL-divergence between self distributions and other.""" - raise NotImplementedError - - def log_prob(self, value): - """Log probability density/mass function.""" - raise NotImplementedError - - def probs(self, value): - """Probability density/mass function.""" - raise NotImplementedError - - def _validate_args(self, *args): - """ - Argument validation for distribution args - Args: - value (float, list, numpy.ndarray, Tensor) - Raises - ValueError: if one argument is Tensor, all arguments should be Tensor - """ - is_variable = False - is_number = False - for arg in args: - if isinstance(arg, tensor.Variable): - is_variable = True - else: - is_number = True - - if is_variable and is_number: - raise ValueError( - 'if one argument is Tensor, all arguments should be Tensor') - - return is_variable - - def _to_tensor(self, *args): - """ - Argument convert args to Tensor - - Args: - value (float, list, numpy.ndarray, Tensor) - Returns: - Tensor of args. - """ - numpy_args = [] - variable_args = [] - tmp = 0. - - for arg in args: - if isinstance(arg, float): - arg = [arg] - if not isinstance(arg, (list, tuple, np.ndarray, tensor.Variable)): - raise TypeError( - "Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}". - format(type(arg))) - - arg_np = np.array(arg) - arg_dtype = arg_np.dtype - if str(arg_dtype) != 'float32': - if str(arg_dtype) != 'float64': - # "assign" op doesn't support float64. if dtype is float64, float32 variable will be generated - # and converted to float64 later using "cast". - warnings.warn( - "data type of argument only support float32 and float64, your argument will be convert to float32." - ) - arg_np = arg_np.astype('float32') - # tmp is used to support broadcast, it summarizes shapes of all the args and get the mixed shape. - tmp = tmp + arg_np - numpy_args.append(arg_np) - - dtype = tmp.dtype - for arg in numpy_args: - arg_broadcasted, _ = np.broadcast_arrays(arg, tmp) - arg_variable = tensor.create_tensor(dtype=dtype) - tensor.assign(arg_broadcasted, arg_variable) - variable_args.append(arg_variable) - - return tuple(variable_args) - - def _check_values_dtype_in_probs(self, param, value): - """ - Log_prob and probs methods have input ``value``, if value's dtype is different from param, - convert value's dtype to be consistent with param's dtype. - - Args: - param (Tensor): low and high in Uniform class, loc and scale in Normal class. - value (Tensor): The input tensor. - - Returns: - value (Tensor): Change value's dtype if value's dtype is different from param. - """ - if in_dygraph_mode(): - if value.dtype != param.dtype and convert_dtype( - value.dtype) in ['float32', 'float64']: - warnings.warn( - "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted." - ) - return _C_ops.cast(value, 'in_dtype', value.dtype, 'out_dtype', - param.dtype) - return value - - check_variable_and_dtype(value, 'value', ['float32', 'float64'], - 'log_prob') - if value.dtype != param.dtype: - warnings.warn( - "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted." - ) - return tensor.cast(value, dtype=param.dtype) - return value - - -class Uniform(Distribution): - r"""Uniform distribution with `low` and `high` parameters. - - Mathematical Details - - The probability density function (pdf) is - - .. math:: - - pdf(x; a, b) = \\frac{1}{Z}, \ a <=x 1: - sample_shape = shape + logits_shape[:-1] - logits = nn.reshape(self.logits, - [np.prod(logits_shape[:-1]), logits_shape[-1]]) - else: - sample_shape = shape - logits = self.logits - - sample_index = multinomial(logits, num_samples, True) - return nn.reshape(sample_index, sample_shape, name=name) - - def kl_divergence(self, other): - """The KL-divergence between two Categorical distributions. - - Args: - other (Categorical): instance of Categorical. The data type is float32. - - Returns: - Tensor: kl-divergence between two Categorical distributions. - - Examples: - .. code-block:: python - - import paddle - from paddle.distribution import Categorical - - paddle.seed(100) # on CPU device - x = paddle.rand([6]) - print(x) - # [0.5535528 0.20714243 0.01162981 - # 0.51577556 0.36369765 0.2609165 ] - - paddle.seed(200) # on CPU device - y = paddle.rand([6]) - print(y) - # [0.77663314 0.90824795 0.15685187 - # 0.04279523 0.34468332 0.7955718 ] - - cat = Categorical(x) - cat2 = Categorical(y) - - cat.kl_divergence(cat2) - # [0.071952] - - """ - name = self.name + '_kl_divergence' - if not in_dygraph_mode(): - check_type(other, 'other', Categorical, 'kl_divergence') - - logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True) - other_logits = other.logits - nn.reduce_max( - other.logits, dim=-1, keep_dim=True) - e_logits = ops.exp(logits) - other_e_logits = ops.exp(other_logits) - z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True) - other_z = nn.reduce_sum(other_e_logits, dim=-1, keep_dim=True) - prob = e_logits / z - kl = nn.reduce_sum( - prob * (logits - nn.log(z) - other_logits + nn.log(other_z)), - dim=-1, - keep_dim=True, - name=name) - - return kl - - def entropy(self): - """Shannon entropy in nats. - - Returns: - Tensor: Shannon entropy of Categorical distribution. The data type is float32. - - Examples: - .. code-block:: python - - import paddle - from paddle.distribution import Categorical - - paddle.seed(100) # on CPU device - x = paddle.rand([6]) - print(x) - # [0.5535528 0.20714243 0.01162981 - # 0.51577556 0.36369765 0.2609165 ] - - cat = Categorical(x) - - cat.entropy() - # [1.77528] - - """ - name = self.name + '_entropy' - logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True) - e_logits = ops.exp(logits) - z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True) - prob = e_logits / z - - neg_entropy = nn.reduce_sum( - prob * (logits - nn.log(z)), dim=-1, keep_dim=True) - entropy = nn.scale(neg_entropy, scale=-1.0, name=name) - return entropy - - def probs(self, value): - """Probabilities of the given category (``value``). - - If ``logits`` is 2-D or higher dimension, the last dimension will be regarded as - category, and the others represents the different distributions. - At the same time, if ``vlaue`` is 1-D Tensor, ``value`` will be broadcast to the - same number of distributions as ``logits``. - If ``value`` is not 1-D Tensor, ``value`` should have the same number distributions - with ``logits. That is, ``value[:-1] = logits[:-1]``. - - Args: - value (Tensor): The input tensor represents the selected category index. - - Returns: - Tensor: probability according to the category index. - - Examples: - .. code-block:: python - - import paddle - from paddle.distribution import Categorical - - paddle.seed(100) # on CPU device - x = paddle.rand([6]) - print(x) - # [0.5535528 0.20714243 0.01162981 - # 0.51577556 0.36369765 0.2609165 ] - - cat = Categorical(x) - - value = paddle.to_tensor([2,1,3]) - cat.probs(value) - # [0.00608027 0.108298 0.269656] - - """ - name = self.name + '_probs' - - dist_sum = nn.reduce_sum(self.logits, dim=-1, keep_dim=True) - prob = self.logits / dist_sum - - shape = list(prob.shape) - value_shape = list(value.shape) - if len(shape) == 1: - num_value_in_one_dist = np.prod(value_shape) - index_value = nn.reshape(value, [num_value_in_one_dist, 1]) - index = index_value - else: - num_dist = np.prod(shape[:-1]) - num_value_in_one_dist = value_shape[-1] - prob = nn.reshape(prob, [num_dist, shape[-1]]) - if len(value_shape) == 1: - value = nn.expand(value, [num_dist]) - value_shape = shape[:-1] + value_shape - index_value = nn.reshape(value, [num_dist, -1, 1]) - if shape[:-1] != value_shape[:-1]: - raise ValueError( - "shape of value {} must match shape of logits {}".format( - str(value_shape[:-1]), str(shape[:-1]))) - - index_prefix = nn.unsqueeze( - arange( - num_dist, dtype=index_value.dtype), axes=-1) - index_prefix = nn.expand(index_prefix, [1, num_value_in_one_dist]) - index_prefix = nn.unsqueeze(index_prefix, axes=-1) - - if index_value.dtype != index_prefix.dtype: - tensor.cast(index_prefix, dtype=index_value.dtype) - index = concat([index_prefix, index_value], axis=-1) - - # value is the category index to search for the corresponding probability. - select_prob = gather_nd(prob, index) - return nn.reshape(select_prob, value_shape, name=name) - - def log_prob(self, value): - """Log probabilities of the given category. Refer to ``probs`` method. - - Args: - value (Tensor): The input tensor represents the selected category index. - - Returns: - Tensor: Log probability. - - Examples: - .. code-block:: python - - import paddle - from paddle.distribution import Categorical - - paddle.seed(100) # on CPU device - x = paddle.rand([6]) - print(x) - # [0.5535528 0.20714243 0.01162981 - # 0.51577556 0.36369765 0.2609165 ] - - cat = Categorical(x) - - value = paddle.to_tensor([2,1,3]) - cat.log_prob(value) - # [-5.10271 -2.22287 -1.31061] - - """ - name = self.name + '_log_prob' - - return nn.log(self.probs(value), name=name) diff --git a/python/paddle/distribution/__init__.py b/python/paddle/distribution/__init__.py new file mode 100644 index 00000000000..282b53f5ac4 --- /dev/null +++ b/python/paddle/distribution/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .categorical import Categorical +from .distribution import Distribution +from .normal import Normal +from .uniform import Uniform + +__all__ = ['Categorical', 'Distribution', 'Normal', 'Uniform'] diff --git a/python/paddle/distribution/categorical.py b/python/paddle/distribution/categorical.py new file mode 100644 index 00000000000..151e060e29b --- /dev/null +++ b/python/paddle/distribution/categorical.py @@ -0,0 +1,357 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings + +import numpy as np +from paddle import _C_ops + +from ..fluid import core +from ..fluid.data_feeder import (check_dtype, check_type, + check_variable_and_dtype, convert_dtype) +from ..fluid.framework import in_dygraph_mode +from ..fluid.layers import (control_flow, elementwise_add, elementwise_div, + elementwise_mul, elementwise_sub, nn, ops, tensor) +from ..tensor import arange, concat, gather_nd, multinomial +from .distribution import Distribution + + +class Categorical(Distribution): + r""" + Categorical distribution is a discrete probability distribution that + describes the possible results of a random variable that can take on + one of K possible categories, with the probability of each category + separately specified. + + The probability mass function (pmf) is: + + .. math:: + + pmf(k; p_i) = \prod_{i=1}^{k} p_i^{[x=i]} + + In the above equation: + + * :math:`[x=i]` : it evaluates to 1 if :math:`x==i` , 0 otherwise. + + Args: + logits(list|tuple|numpy.ndarray|Tensor): The logits input of categorical distribution. The data type is float32 or float64. + name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Examples: + .. code-block:: python + + import paddle + from paddle.distribution import Categorical + + paddle.seed(100) # on CPU device + x = paddle.rand([6]) + print(x) + # [0.5535528 0.20714243 0.01162981 + # 0.51577556 0.36369765 0.2609165 ] + + paddle.seed(200) # on CPU device + y = paddle.rand([6]) + print(y) + # [0.77663314 0.90824795 0.15685187 + # 0.04279523 0.34468332 0.7955718 ] + + cat = Categorical(x) + cat2 = Categorical(y) + + paddle.seed(1000) # on CPU device + cat.sample([2,3]) + # [[0, 0, 5], + # [3, 4, 5]] + + cat.entropy() + # [1.77528] + + cat.kl_divergence(cat2) + # [0.071952] + + value = paddle.to_tensor([2,1,3]) + cat.probs(value) + # [0.00608027 0.108298 0.269656] + + cat.log_prob(value) + # [-5.10271 -2.22287 -1.31061] + + """ + + def __init__(self, logits, name=None): + """ + Args: + logits(list|tuple|numpy.ndarray|Tensor): The logits input of categorical distribution. The data type is float32 or float64. + name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + """ + if not in_dygraph_mode(): + check_type(logits, 'logits', + (np.ndarray, tensor.Variable, list, tuple), + 'Categorical') + + self.name = name if name is not None else 'Categorical' + self.dtype = 'float32' + + if self._validate_args(logits): + self.logits = logits + self.dtype = convert_dtype(logits.dtype) + else: + if isinstance(logits, np.ndarray) and str( + logits.dtype) in ['float32', 'float64']: + self.dtype = logits.dtype + self.logits = self._to_tensor(logits)[0] + if self.dtype != convert_dtype(self.logits.dtype): + self.logits = tensor.cast(self.logits, dtype=self.dtype) + + def sample(self, shape): + """Generate samples of the specified shape. + + Args: + shape (list): Shape of the generated samples. + + Returns: + Tensor: A tensor with prepended dimensions shape. + + Examples: + .. code-block:: python + + import paddle + from paddle.distribution import Categorical + + paddle.seed(100) # on CPU device + x = paddle.rand([6]) + print(x) + # [0.5535528 0.20714243 0.01162981 + # 0.51577556 0.36369765 0.2609165 ] + + cat = Categorical(x) + + paddle.seed(1000) # on CPU device + cat.sample([2,3]) + # [[0, 0, 5], + # [3, 4, 5]] + + """ + name = self.name + '_sample' + if not in_dygraph_mode(): + check_type(shape, 'shape', (list), 'sample') + + num_samples = np.prod(np.array(shape)) + + logits_shape = list(self.logits.shape) + if len(logits_shape) > 1: + sample_shape = shape + logits_shape[:-1] + logits = nn.reshape(self.logits, + [np.prod(logits_shape[:-1]), logits_shape[-1]]) + else: + sample_shape = shape + logits = self.logits + + sample_index = multinomial(logits, num_samples, True) + return nn.reshape(sample_index, sample_shape, name=name) + + def kl_divergence(self, other): + """The KL-divergence between two Categorical distributions. + + Args: + other (Categorical): instance of Categorical. The data type is float32. + + Returns: + Tensor: kl-divergence between two Categorical distributions. + + Examples: + .. code-block:: python + + import paddle + from paddle.distribution import Categorical + + paddle.seed(100) # on CPU device + x = paddle.rand([6]) + print(x) + # [0.5535528 0.20714243 0.01162981 + # 0.51577556 0.36369765 0.2609165 ] + + paddle.seed(200) # on CPU device + y = paddle.rand([6]) + print(y) + # [0.77663314 0.90824795 0.15685187 + # 0.04279523 0.34468332 0.7955718 ] + + cat = Categorical(x) + cat2 = Categorical(y) + + cat.kl_divergence(cat2) + # [0.071952] + + """ + name = self.name + '_kl_divergence' + if not in_dygraph_mode(): + check_type(other, 'other', Categorical, 'kl_divergence') + + logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True) + other_logits = other.logits - nn.reduce_max( + other.logits, dim=-1, keep_dim=True) + e_logits = ops.exp(logits) + other_e_logits = ops.exp(other_logits) + z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True) + other_z = nn.reduce_sum(other_e_logits, dim=-1, keep_dim=True) + prob = e_logits / z + kl = nn.reduce_sum( + prob * (logits - nn.log(z) - other_logits + nn.log(other_z)), + dim=-1, + keep_dim=True, + name=name) + + return kl + + def entropy(self): + """Shannon entropy in nats. + + Returns: + Tensor: Shannon entropy of Categorical distribution. The data type is float32. + + Examples: + .. code-block:: python + + import paddle + from paddle.distribution import Categorical + + paddle.seed(100) # on CPU device + x = paddle.rand([6]) + print(x) + # [0.5535528 0.20714243 0.01162981 + # 0.51577556 0.36369765 0.2609165 ] + + cat = Categorical(x) + + cat.entropy() + # [1.77528] + + """ + name = self.name + '_entropy' + logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True) + e_logits = ops.exp(logits) + z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True) + prob = e_logits / z + + neg_entropy = nn.reduce_sum( + prob * (logits - nn.log(z)), dim=-1, keep_dim=True) + entropy = nn.scale(neg_entropy, scale=-1.0, name=name) + return entropy + + def probs(self, value): + """Probabilities of the given category (``value``). + + If ``logits`` is 2-D or higher dimension, the last dimension will be regarded as + category, and the others represents the different distributions. + At the same time, if ``vlaue`` is 1-D Tensor, ``value`` will be broadcast to the + same number of distributions as ``logits``. + If ``value`` is not 1-D Tensor, ``value`` should have the same number distributions + with ``logits. That is, ``value[:-1] = logits[:-1]``. + + Args: + value (Tensor): The input tensor represents the selected category index. + + Returns: + Tensor: probability according to the category index. + + Examples: + .. code-block:: python + + import paddle + from paddle.distribution import Categorical + + paddle.seed(100) # on CPU device + x = paddle.rand([6]) + print(x) + # [0.5535528 0.20714243 0.01162981 + # 0.51577556 0.36369765 0.2609165 ] + + cat = Categorical(x) + + value = paddle.to_tensor([2,1,3]) + cat.probs(value) + # [0.00608027 0.108298 0.269656] + + """ + name = self.name + '_probs' + + dist_sum = nn.reduce_sum(self.logits, dim=-1, keep_dim=True) + prob = self.logits / dist_sum + + shape = list(prob.shape) + value_shape = list(value.shape) + if len(shape) == 1: + num_value_in_one_dist = np.prod(value_shape) + index_value = nn.reshape(value, [num_value_in_one_dist, 1]) + index = index_value + else: + num_dist = np.prod(shape[:-1]) + num_value_in_one_dist = value_shape[-1] + prob = nn.reshape(prob, [num_dist, shape[-1]]) + if len(value_shape) == 1: + value = nn.expand(value, [num_dist]) + value_shape = shape[:-1] + value_shape + index_value = nn.reshape(value, [num_dist, -1, 1]) + if shape[:-1] != value_shape[:-1]: + raise ValueError( + "shape of value {} must match shape of logits {}".format( + str(value_shape[:-1]), str(shape[:-1]))) + + index_prefix = nn.unsqueeze( + arange( + num_dist, dtype=index_value.dtype), axes=-1) + index_prefix = nn.expand(index_prefix, [1, num_value_in_one_dist]) + index_prefix = nn.unsqueeze(index_prefix, axes=-1) + + if index_value.dtype != index_prefix.dtype: + tensor.cast(index_prefix, dtype=index_value.dtype) + index = concat([index_prefix, index_value], axis=-1) + + # value is the category index to search for the corresponding probability. + select_prob = gather_nd(prob, index) + return nn.reshape(select_prob, value_shape, name=name) + + def log_prob(self, value): + """Log probabilities of the given category. Refer to ``probs`` method. + + Args: + value (Tensor): The input tensor represents the selected category index. + + Returns: + Tensor: Log probability. + + Examples: + .. code-block:: python + + import paddle + from paddle.distribution import Categorical + + paddle.seed(100) # on CPU device + x = paddle.rand([6]) + print(x) + # [0.5535528 0.20714243 0.01162981 + # 0.51577556 0.36369765 0.2609165 ] + + cat = Categorical(x) + + value = paddle.to_tensor([2,1,3]) + cat.log_prob(value) + # [-5.10271 -2.22287 -1.31061] + + """ + name = self.name + '_log_prob' + + return nn.log(self.probs(value), name=name) diff --git a/python/paddle/distribution/distribution.py b/python/paddle/distribution/distribution.py new file mode 100644 index 00000000000..6cdbedf4194 --- /dev/null +++ b/python/paddle/distribution/distribution.py @@ -0,0 +1,163 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: define the distribution functions +# __all__ = ['Categorical', +# 'MultivariateNormalDiag', +# 'Normal', +# 'sampling_id', +# 'Uniform'] + +from __future__ import print_function + +import math +import warnings + +import numpy as np +from paddle import _C_ops + +from ..fluid import core +from ..fluid.data_feeder import (check_dtype, check_type, + check_variable_and_dtype, convert_dtype) +from ..fluid.framework import in_dygraph_mode +from ..fluid.layers import (control_flow, elementwise_add, elementwise_div, + elementwise_mul, elementwise_sub, nn, ops, tensor) +from ..tensor import arange, concat, gather_nd, multinomial + + +class Distribution(object): + """ + The abstract base class for probability distributions. Functions are + implemented in specific distributions. + """ + + def __init__(self): + super(Distribution, self).__init__() + + def sample(self): + """Sampling from the distribution.""" + raise NotImplementedError + + def entropy(self): + """The entropy of the distribution.""" + raise NotImplementedError + + def kl_divergence(self, other): + """The KL-divergence between self distributions and other.""" + raise NotImplementedError + + def log_prob(self, value): + """Log probability density/mass function.""" + raise NotImplementedError + + def probs(self, value): + """Probability density/mass function.""" + raise NotImplementedError + + def _validate_args(self, *args): + """ + Argument validation for distribution args + Args: + value (float, list, numpy.ndarray, Tensor) + Raises + ValueError: if one argument is Tensor, all arguments should be Tensor + """ + is_variable = False + is_number = False + for arg in args: + if isinstance(arg, tensor.Variable): + is_variable = True + else: + is_number = True + + if is_variable and is_number: + raise ValueError( + 'if one argument is Tensor, all arguments should be Tensor') + + return is_variable + + def _to_tensor(self, *args): + """ + Argument convert args to Tensor + + Args: + value (float, list, numpy.ndarray, Tensor) + Returns: + Tensor of args. + """ + numpy_args = [] + variable_args = [] + tmp = 0. + + for arg in args: + if isinstance(arg, float): + arg = [arg] + if not isinstance(arg, (list, tuple, np.ndarray, tensor.Variable)): + raise TypeError( + "Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}". + format(type(arg))) + + arg_np = np.array(arg) + arg_dtype = arg_np.dtype + if str(arg_dtype) != 'float32': + if str(arg_dtype) != 'float64': + # "assign" op doesn't support float64. if dtype is float64, float32 variable will be generated + # and converted to float64 later using "cast". + warnings.warn( + "data type of argument only support float32 and float64, your argument will be convert to float32." + ) + arg_np = arg_np.astype('float32') + # tmp is used to support broadcast, it summarizes shapes of all the args and get the mixed shape. + tmp = tmp + arg_np + numpy_args.append(arg_np) + + dtype = tmp.dtype + for arg in numpy_args: + arg_broadcasted, _ = np.broadcast_arrays(arg, tmp) + arg_variable = tensor.create_tensor(dtype=dtype) + tensor.assign(arg_broadcasted, arg_variable) + variable_args.append(arg_variable) + + return tuple(variable_args) + + def _check_values_dtype_in_probs(self, param, value): + """ + Log_prob and probs methods have input ``value``, if value's dtype is different from param, + convert value's dtype to be consistent with param's dtype. + + Args: + param (Tensor): low and high in Uniform class, loc and scale in Normal class. + value (Tensor): The input tensor. + + Returns: + value (Tensor): Change value's dtype if value's dtype is different from param. + """ + if in_dygraph_mode(): + if value.dtype != param.dtype and convert_dtype( + value.dtype) in ['float32', 'float64']: + warnings.warn( + "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted." + ) + return _C_ops.cast(value, 'in_dtype', value.dtype, 'out_dtype', + param.dtype) + return value + + check_variable_and_dtype(value, 'value', ['float32', 'float64'], + 'log_prob') + if value.dtype != param.dtype: + warnings.warn( + "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted." + ) + return tensor.cast(value, dtype=param.dtype) + return value diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py new file mode 100644 index 00000000000..95d544577df --- /dev/null +++ b/python/paddle/distribution/normal.py @@ -0,0 +1,279 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings + +import numpy as np +from paddle import _C_ops + +from ..fluid import core +from ..fluid.data_feeder import (check_dtype, check_type, + check_variable_and_dtype, convert_dtype) +from ..fluid.framework import in_dygraph_mode +from ..fluid.layers import (control_flow, elementwise_add, elementwise_div, + elementwise_mul, elementwise_sub, nn, ops, tensor) +from ..tensor import arange, concat, gather_nd, multinomial +from .distribution import Distribution + + +class Normal(Distribution): + r"""The Normal distribution with location `loc` and `scale` parameters. + + Mathematical details + + The probability density function (pdf) is + + .. math:: + + pdf(x; \mu, \sigma) = \\frac{1}{Z}e^{\\frac {-0.5 (x - \mu)^2} {\sigma^2} } + + .. math:: + + Z = (2 \pi \sigma^2)^{0.5} + + In the above equation: + + * :math:`loc = \mu`: is the mean. + * :math:`scale = \sigma`: is the std. + * :math:`Z`: is the normalization constant. + + Args: + loc(int|float|list|tuple|numpy.ndarray|Tensor): The mean of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor. + scale(int|float|list|tuple|numpy.ndarray|Tensor): The std of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor. + name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Examples: + .. code-block:: python + + import paddle + from paddle.distribution import Normal + + # Define a single scalar Normal distribution. + dist = Normal(loc=0., scale=3.) + # Define a batch of two scalar valued Normals. + # The first has mean 1 and standard deviation 11, the second 2 and 22. + dist = Normal(loc=[1., 2.], scale=[11., 22.]) + # Get 3 samples, returning a 3 x 2 tensor. + dist.sample([3]) + + # Define a batch of two scalar valued Normals. + # Both have mean 1, but different standard deviations. + dist = Normal(loc=1., scale=[11., 22.]) + + # Complete example + value_tensor = paddle.to_tensor([0.8], dtype="float32") + + normal_a = Normal([0.], [1.]) + normal_b = Normal([0.5], [2.]) + sample = normal_a.sample([2]) + # a random tensor created by normal distribution with shape: [2, 1] + entropy = normal_a.entropy() + # [1.4189385] with shape: [1] + lp = normal_a.log_prob(value_tensor) + # [-1.2389386] with shape: [1] + p = normal_a.probs(value_tensor) + # [0.28969154] with shape: [1] + kl = normal_a.kl_divergence(normal_b) + # [0.34939718] with shape: [1] + """ + + def __init__(self, loc, scale, name=None): + if not in_dygraph_mode(): + check_type(loc, 'loc', + (int, float, np.ndarray, tensor.Variable, list, tuple), + 'Normal') + check_type(scale, 'scale', + (int, float, np.ndarray, tensor.Variable, list, tuple), + 'Normal') + + self.batch_size_unknown = False + self.all_arg_is_float = False + self.name = name if name is not None else 'Normal' + self.dtype = 'float32' + + if isinstance(loc, int): + loc = float(loc) + if isinstance(scale, int): + scale = float(scale) + + if self._validate_args(loc, scale): + self.batch_size_unknown = True + self.loc = loc + self.scale = scale + self.dtype = convert_dtype(loc.dtype) + else: + if isinstance(loc, float) and isinstance(scale, float): + self.all_arg_is_float = True + if isinstance( + loc, + np.ndarray) and str(loc.dtype) in ['float32', 'float64']: + self.dtype = loc.dtype + elif isinstance( + scale, + np.ndarray) and str(scale.dtype) in ['float32', 'float64']: + self.dtype = scale.dtype + # pylint: disable=unbalanced-tuple-unpacking + self.loc, self.scale = self._to_tensor(loc, scale) + if self.dtype != convert_dtype(self.loc.dtype): + self.loc = tensor.cast(self.loc, dtype=self.dtype) + self.scale = tensor.cast(self.scale, dtype=self.dtype) + + def sample(self, shape, seed=0): + """Generate samples of the specified shape. + + Args: + shape (list): 1D `int32`. Shape of the generated samples. + seed (int): Python integer number. + + Returns: + Tensor: A tensor with prepended dimensions shape.The data type is float32. + + """ + if not in_dygraph_mode(): + check_type(shape, 'shape', (list), 'sample') + check_type(seed, 'seed', (int), 'sample') + + batch_shape = list((self.loc + self.scale).shape) + name = self.name + '_sample' + + if self.batch_size_unknown: + output_shape = shape + batch_shape + zero_tmp = tensor.fill_constant_batch_size_like( + self.loc + self.scale, batch_shape + shape, self.dtype, 0.) + zero_tmp_reshape = nn.reshape(zero_tmp, output_shape) + zero_tmp_shape = nn.shape(zero_tmp_reshape) + normal_random_tmp = nn.gaussian_random( + zero_tmp_shape, mean=0., std=1., seed=seed, dtype=self.dtype) + output = normal_random_tmp * (zero_tmp_reshape + self.scale) + output = elementwise_add(output, self.loc, name=name) + return output + else: + output_shape = shape + batch_shape + output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * \ + (tensor.zeros(output_shape, dtype=self.dtype) + self.scale) + output = elementwise_add(output, self.loc, name=name) + if self.all_arg_is_float: + return nn.reshape(output, shape, name=name) + else: + return output + + def entropy(self): + r"""Shannon entropy in nats. + + The entropy is + + .. math:: + + entropy(\sigma) = 0.5 \\log (2 \pi e \sigma^2) + + In the above equation: + + * :math:`scale = \sigma`: is the std. + + Returns: + Tensor: Shannon entropy of normal distribution.The data type is float32. + + """ + name = self.name + '_entropy' + batch_shape = list((self.loc + self.scale).shape) + zero_tmp = tensor.fill_constant_batch_size_like( + self.loc + self.scale, batch_shape, self.dtype, 0.) + return elementwise_add( + 0.5 + zero_tmp, + 0.5 * math.log(2 * math.pi) + nn.log((self.scale + zero_tmp)), + name=name) + + def log_prob(self, value): + """Log probability density/mass function. + + Args: + value (Tensor): The input tensor. + + Returns: + Tensor: log probability.The data type is same with value. + + """ + name = self.name + '_log_prob' + value = self._check_values_dtype_in_probs(self.loc, value) + + var = self.scale * self.scale + log_scale = nn.log(self.scale) + return elementwise_sub( + -1. * ((value - self.loc) * (value - self.loc)) / (2. * var), + log_scale + math.log(math.sqrt(2. * math.pi)), + name=name) + + def probs(self, value): + """Probability density/mass function. + + Args: + value (Tensor): The input tensor. + + Returns: + Tensor: probability.The data type is same with value. + + """ + name = self.name + '_probs' + value = self._check_values_dtype_in_probs(self.loc, value) + + var = self.scale * self.scale + return elementwise_div( + ops.exp(-1. * ((value - self.loc) * (value - self.loc)) / + (2. * var)), (math.sqrt(2 * math.pi) * self.scale), + name=name) + + def kl_divergence(self, other): + r"""The KL-divergence between two normal distributions. + + The probability density function (pdf) is + + .. math:: + + KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\\frac{diff}{\sigma_1})^2 - 1 - 2 \\ln {ratio}) + + .. math:: + + ratio = \\frac{\sigma_0}{\sigma_1} + + .. math:: + + diff = \mu_1 - \mu_0 + + In the above equation: + + * :math:`loc = \mu_0`: is the mean of current Normal distribution. + * :math:`scale = \sigma_0`: is the std of current Normal distribution. + * :math:`loc = \mu_1`: is the mean of other Normal distribution. + * :math:`scale = \sigma_1`: is the std of other Normal distribution. + * :math:`ratio`: is the ratio of scales. + * :math:`diff`: is the difference between means. + + Args: + other (Normal): instance of Normal. + + Returns: + Tensor: kl-divergence between two normal distributions.The data type is float32. + + """ + if not in_dygraph_mode(): + check_type(other, 'other', Normal, 'kl_divergence') + + name = self.name + '_kl_divergence' + var_ratio = self.scale / other.scale + var_ratio = (var_ratio * var_ratio) + t1 = (self.loc - other.loc) / other.scale + t1 = (t1 * t1) + return elementwise_add( + 0.5 * var_ratio, 0.5 * (t1 - 1. - nn.log(var_ratio)), name=name) diff --git a/python/paddle/distribution/uniform.py b/python/paddle/distribution/uniform.py new file mode 100644 index 00000000000..516071dafe6 --- /dev/null +++ b/python/paddle/distribution/uniform.py @@ -0,0 +1,252 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings + +import numpy as np +from paddle import _C_ops + +from ..fluid import core +from ..fluid.data_feeder import (check_dtype, check_type, + check_variable_and_dtype, convert_dtype) +from ..fluid.framework import in_dygraph_mode +from ..fluid.layers import (control_flow, elementwise_add, elementwise_div, + elementwise_mul, elementwise_sub, nn, ops, tensor) +from ..tensor import arange, concat, gather_nd, multinomial +from .distribution import Distribution + + +class Uniform(Distribution): + r"""Uniform distribution with `low` and `high` parameters. + + Mathematical Details + + The probability density function (pdf) is + + .. math:: + + pdf(x; a, b) = \\frac{1}{Z}, \ a <=x 0): + self.scale_np = np.random.randn(batch_size, dims).astype('float32') + self.values_np = np.random.randn(batch_size, dims).astype('float32') + # used to construct another Normal object to calculate kl_divergence + self.other_loc_np = (np.random.ranf() - 0.5) * 4 + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float32') + while not np.all(self.scale_np > 0): + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float32') + + def init_static_data(self, batch_size, dims): + self.static_loc = self.loc_np + self.static_scale = self.scale_np + self.static_other_loc = self.other_loc_np + self.static_other_scale = self.other_scale_np + with fluid.program_guard(self.test_program): + self.static_values = layers.data( + name='values', shape=[dims], dtype='float32') + + +class NormalTest4(NormalTest): + def init_numpy_data(self, batch_size, dims): + # loc and scale are numpy.ndarray with dtype 'float32'. + self.loc_np = np.random.randn(batch_size, dims).astype('float32') + self.scale_np = np.random.randn(batch_size, dims).astype('float32') + while not np.all(self.scale_np > 0): + self.scale_np = np.random.randn(batch_size, dims).astype('float32') + self.values_np = np.random.randn(batch_size, dims).astype('float32') + # used to construct another Normal object to calculate kl_divergence + self.other_loc_np = np.random.randn(batch_size, dims).astype('float32') + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float32') + while not np.all(self.scale_np > 0): + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float32') + + def init_static_data(self, batch_size, dims): + self.static_loc = self.loc_np + self.static_scale = self.scale_np + self.static_other_loc = self.other_loc_np + self.static_other_scale = self.other_scale_np + with fluid.program_guard(self.test_program): + self.static_values = layers.data( + name='values', shape=[dims], dtype='float32') + + +class NormalTest5(NormalTest): + def init_numpy_data(self, batch_size, dims): + # loc and scale are numpy.ndarray with dtype 'float64'. + self.loc_np = np.random.randn(batch_size, dims).astype('float64') + self.scale_np = np.random.randn(batch_size, dims).astype('float64') + while not np.all(self.scale_np > 0): + self.scale_np = np.random.randn(batch_size, dims).astype('float64') + self.values_np = np.random.randn(batch_size, dims).astype('float64') + # used to construct another Normal object to calculate kl_divergence + self.other_loc_np = np.random.randn(batch_size, dims).astype('float64') + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float64') + while not np.all(self.scale_np > 0): + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float64') + + def init_dynamic_data(self, batch_size, dims): + self.dynamic_loc = self.loc_np + self.dynamic_scale = self.scale_np + self.dynamic_other_loc = self.other_loc_np + self.dynamic_other_scale = self.other_scale_np + self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64') + + def init_static_data(self, batch_size, dims): + self.static_loc = self.loc_np + self.static_scale = self.scale_np + self.static_other_loc = self.other_loc_np + self.static_other_scale = self.other_scale_np + with fluid.program_guard(self.test_program): + self.static_values = layers.data( + name='values', shape=[dims], dtype='float64') + + +class NormalTest6(NormalTest): + def init_numpy_data(self, batch_size, dims): + # loc and scale are Tensor with dtype 'VarType.FP32'. + self.loc_np = np.random.randn(batch_size, dims).astype('float32') + self.scale_np = np.random.randn(batch_size, dims).astype('float32') + while not np.all(self.scale_np > 0): + self.scale_np = np.random.randn(batch_size, dims).astype('float32') + self.values_np = np.random.randn(batch_size, dims).astype('float32') + # used to construct another Normal object to calculate kl_divergence + self.other_loc_np = np.random.randn(batch_size, dims).astype('float32') + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float32') + while not np.all(self.scale_np > 0): + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float32') + + def init_dynamic_data(self, batch_size, dims): + self.dynamic_loc = paddle.to_tensor(self.loc_np) + self.dynamic_scale = paddle.to_tensor(self.scale_np) + self.dynamic_values = paddle.to_tensor(self.values_np) + self.dynamic_other_loc = paddle.to_tensor(self.other_loc_np) + self.dynamic_other_scale = paddle.to_tensor(self.other_scale_np) + + def init_static_data(self, batch_size, dims): + with fluid.program_guard(self.test_program): + self.static_loc = layers.data( + name='loc', shape=[dims], dtype='float32') + self.static_scale = layers.data( + name='scale', shape=[dims], dtype='float32') + self.static_values = layers.data( + name='values', shape=[dims], dtype='float32') + self.static_other_loc = layers.data( + name='other_loc', shape=[dims], dtype='float32') + self.static_other_scale = layers.data( + name='other_scale', shape=[dims], dtype='float32') + + +class NormalTest7(NormalTest): + def init_numpy_data(self, batch_size, dims): + # loc and scale are Tensor with dtype 'VarType.FP64'. + self.loc_np = np.random.randn(batch_size, dims).astype('float64') + self.scale_np = np.random.randn(batch_size, dims).astype('float64') + while not np.all(self.scale_np > 0): + self.scale_np = np.random.randn(batch_size, dims).astype('float64') + self.values_np = np.random.randn(batch_size, dims).astype('float64') + # used to construct another Normal object to calculate kl_divergence + self.other_loc_np = np.random.randn(batch_size, dims).astype('float64') + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float64') + while not np.all(self.scale_np > 0): + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float64') + + def init_dynamic_data(self, batch_size, dims): + self.dynamic_loc = paddle.to_tensor(self.loc_np, dtype='float64') + self.dynamic_scale = paddle.to_tensor(self.scale_np, dtype='float64') + self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64') + self.dynamic_other_loc = paddle.to_tensor( + self.other_loc_np, dtype='float64') + self.dynamic_other_scale = paddle.to_tensor( + self.other_scale_np, dtype='float64') + + def init_static_data(self, batch_size, dims): + with fluid.program_guard(self.test_program): + self.static_loc = layers.data( + name='loc', shape=[dims], dtype='float64') + self.static_scale = layers.data( + name='scale', shape=[dims], dtype='float64') + self.static_values = layers.data( + name='values', shape=[dims], dtype='float64') + self.static_other_loc = layers.data( + name='other_loc', shape=[dims], dtype='float64') + self.static_other_scale = layers.data( + name='other_scale', shape=[dims], dtype='float64') + + +class NormalTest8(NormalTest): + def init_numpy_data(self, batch_size, dims): + # loc and scale are Tensor with dtype 'VarType.FP64'. value's dtype is 'VarType.FP32'. + self.loc_np = np.random.randn(batch_size, dims).astype('float64') + self.scale_np = np.random.randn(batch_size, dims).astype('float64') + while not np.all(self.scale_np > 0): + self.scale_np = np.random.randn(batch_size, dims).astype('float64') + self.values_np = np.random.randn(batch_size, dims).astype('float32') + # used to construct another Normal object to calculate kl_divergence + self.other_loc_np = np.random.randn(batch_size, dims).astype('float64') + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float64') + while not np.all(self.scale_np > 0): + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float64') + + def init_dynamic_data(self, batch_size, dims): + self.dynamic_loc = paddle.to_tensor(self.loc_np, dtype='float64') + self.dynamic_scale = paddle.to_tensor(self.scale_np, dtype='float64') + self.dynamic_values = paddle.to_tensor(self.values_np) + self.dynamic_other_loc = paddle.to_tensor( + self.other_loc_np, dtype='float64') + self.dynamic_other_scale = paddle.to_tensor( + self.other_scale_np, dtype='float64') + + def init_static_data(self, batch_size, dims): + with fluid.program_guard(self.test_program): + self.static_loc = layers.data( + name='loc', shape=[dims], dtype='float64') + self.static_scale = layers.data( + name='scale', shape=[dims], dtype='float64') + self.static_values = layers.data( + name='values', shape=[dims], dtype='float32') + self.static_other_loc = layers.data( + name='other_loc', shape=[dims], dtype='float64') + self.static_other_scale = layers.data( + name='other_scale', shape=[dims], dtype='float64') + + +class NormalTest9(NormalTest): + def init_numpy_data(self, batch_size, dims): + # loc and scale are list. + self.loc_np = np.random.randn(batch_size, + dims).astype('float32').tolist() + self.scale_np = np.random.randn(batch_size, dims).astype('float32') + while not np.all(self.scale_np > 0): + self.scale_np = np.random.randn(batch_size, dims).astype('float32') + self.scale_np = self.scale_np.tolist() + self.values_np = np.random.randn(batch_size, dims).astype('float32') + # used to construct another Normal object to calculate kl_divergence + self.other_loc_np = np.random.randn(batch_size, + dims).astype('float32').tolist() + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float32') + while not np.all(self.other_scale_np > 0): + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float32') + self.other_scale_np = self.other_scale_np.tolist() + + def init_static_data(self, batch_size, dims): + self.static_loc = self.loc_np + self.static_scale = self.scale_np + self.static_other_loc = self.other_loc_np + self.static_other_scale = self.other_scale_np + with fluid.program_guard(self.test_program): + self.static_values = layers.data( + name='values', shape=[dims], dtype='float32') + + +class NormalTest10(NormalTest): + def init_numpy_data(self, batch_size, dims): + # loc and scale are tuple. + self.loc_np = tuple( + np.random.randn(batch_size, dims).astype('float32').tolist()) + self.scale_np = np.random.randn(batch_size, dims).astype('float32') + while not np.all(self.scale_np > 0): + self.scale_np = np.random.randn(batch_size, dims).astype('float32') + self.scale_np = tuple(self.scale_np.tolist()) + self.values_np = np.random.randn(batch_size, dims).astype('float32') + # used to construct another Normal object to calculate kl_divergence + self.other_loc_np = tuple( + np.random.randn(batch_size, dims).astype('float32').tolist()) + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float32') + while not np.all(self.other_scale_np > 0): + self.other_scale_np = np.random.randn(batch_size, + dims).astype('float32') + self.other_scale_np = tuple(self.other_scale_np.tolist()) + + def init_static_data(self, batch_size, dims): + self.static_loc = self.loc_np + self.static_scale = self.scale_np + self.static_other_loc = self.other_loc_np + self.static_other_scale = self.other_scale_np + with fluid.program_guard(self.test_program): + self.static_values = layers.data( + name='values', shape=[dims], dtype='float32') diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_uniform.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_uniform.py new file mode 100644 index 00000000000..e6076764b04 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_uniform.py @@ -0,0 +1,345 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import unittest + +import numpy as np +import paddle +from paddle import fluid +from paddle.distribution import * +from paddle.fluid import layers + +from test_distribution import DistributionNumpy + + +class UniformNumpy(DistributionNumpy): + def __init__(self, low, high): + self.low = np.array(low) + self.high = np.array(high) + if str(self.low.dtype) not in ['float32', 'float64']: + self.low = self.low.astype('float32') + self.high = self.high.astype('float32') + + def sample(self, shape): + shape = tuple(shape) + (self.low + self.high).shape + return self.low + (np.random.uniform(size=shape) * + (self.high - self.low)) + + def log_prob(self, value): + lb = np.less(self.low, value).astype(self.low.dtype) + ub = np.less(value, self.high).astype(self.low.dtype) + return np.log(lb * ub) - np.log(self.high - self.low) + + def probs(self, value): + lb = np.less(self.low, value).astype(self.low.dtype) + ub = np.less(value, self.high).astype(self.low.dtype) + return (lb * ub) / (self.high - self.low) + + def entropy(self): + return np.log(self.high - self.low) + + +class UniformTest(unittest.TestCase): + def setUp(self, use_gpu=False, batch_size=5, dims=6): + self.use_gpu = use_gpu + if not use_gpu: + self.place = fluid.CPUPlace() + self.gpu_id = -1 + else: + self.place = fluid.CUDAPlace(0) + self.gpu_id = 0 + + self.init_numpy_data(batch_size, dims) + + paddle.disable_static(self.place) + self.init_dynamic_data(batch_size, dims) + + paddle.enable_static() + self.test_program = fluid.Program() + self.executor = fluid.Executor(self.place) + self.init_static_data(batch_size, dims) + + def init_numpy_data(self, batch_size, dims): + # low ans high are 'float' + self.low_np = np.random.uniform(-2, 1) + self.high_np = np.random.uniform(2, 4) + self.values_np = np.array([1.0]).astype('float32') + + def init_dynamic_data(self, batch_size, dims): + self.dynamic_low = self.low_np + self.dynamic_high = self.high_np + self.dynamic_values = paddle.to_tensor(self.values_np) + + def init_static_data(self, batch_size, dims): + self.static_low = self.low_np + self.static_high = self.high_np + with fluid.program_guard(self.test_program): + self.static_values = layers.data( + name='values', shape=[], dtype='float32') + + def compare_with_numpy(self, fetch_list, sample_shape=7, tolerance=1e-6): + sample, entropy, log_prob, probs = fetch_list + + np_uniform = UniformNumpy(self.low_np, self.high_np) + np_sample = np_uniform.sample([sample_shape]) + np_entropy = np_uniform.entropy() + np_lp = np_uniform.log_prob(self.values_np) + np_p = np_uniform.probs(self.values_np) + + np.testing.assert_equal(sample.shape, np_sample.shape) + np.testing.assert_allclose( + entropy, np_entropy, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + log_prob, np_lp, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose(probs, np_p, rtol=tolerance, atol=tolerance) + + def test_uniform_distribution_dygraph(self, sample_shape=7, tolerance=1e-6): + paddle.disable_static(self.place) + uniform = Uniform(self.dynamic_low, self.dynamic_high) + sample = uniform.sample([sample_shape]).numpy() + entropy = uniform.entropy().numpy() + log_prob = uniform.log_prob(self.dynamic_values).numpy() + probs = uniform.probs(self.dynamic_values).numpy() + fetch_list = [sample, entropy, log_prob, probs] + + self.compare_with_numpy(fetch_list) + + def test_uniform_distribution_static(self, sample_shape=7, tolerance=1e-6): + paddle.enable_static() + with fluid.program_guard(self.test_program): + uniform = Uniform(self.static_low, self.static_high) + sample = uniform.sample([sample_shape]) + entropy = uniform.entropy() + log_prob = uniform.log_prob(self.static_values) + probs = uniform.probs(self.static_values) + fetch_list = [sample, entropy, log_prob, probs] + + feed_vars = { + 'low': self.low_np, + 'high': self.high_np, + 'values': self.values_np + } + + self.executor.run(fluid.default_startup_program()) + fetch_list = self.executor.run(program=self.test_program, + feed=feed_vars, + fetch_list=fetch_list) + + self.compare_with_numpy(fetch_list) + + +class UniformTest2(UniformTest): + def init_numpy_data(self, batch_size, dims): + # low ans high are 'int' + self.low_np = int(np.random.uniform(-2, 1)) + self.high_np = int(np.random.uniform(2, 4)) + self.values_np = np.array([1.0]).astype('float32') + + +class UniformTest3(UniformTest): + def init_numpy_data(self, batch_size, dims): + # test broadcast: low is float, high is numpy.ndarray with dtype 'float32'. + self.low_np = np.random.uniform(-2, 1) + self.high_np = np.random.uniform(5.0, 15.0, + (batch_size, dims)).astype('float32') + self.values_np = np.random.randn(batch_size, dims).astype('float32') + + def init_static_data(self, batch_size, dims): + self.static_low = self.low_np + self.static_high = self.high_np + with fluid.program_guard(self.test_program): + self.static_values = layers.data( + name='values', shape=[dims], dtype='float32') + + +class UniformTest4(UniformTest): + def init_numpy_data(self, batch_size, dims): + # low and high are numpy.ndarray with dtype 'float32'. + self.low_np = np.random.randn(batch_size, dims).astype('float32') + self.high_np = np.random.uniform(5.0, 15.0, + (batch_size, dims)).astype('float32') + self.values_np = np.random.randn(batch_size, dims).astype('float32') + + def init_static_data(self, batch_size, dims): + self.static_low = self.low_np + self.static_high = self.high_np + with fluid.program_guard(self.test_program): + self.static_values = layers.data( + name='values', shape=[dims], dtype='float32') + + +class UniformTest5(UniformTest): + def init_numpy_data(self, batch_size, dims): + # low and high are numpy.ndarray with dtype 'float64'. + self.low_np = np.random.randn(batch_size, dims).astype('float64') + self.high_np = np.random.uniform(5.0, 15.0, + (batch_size, dims)).astype('float64') + self.values_np = np.random.randn(batch_size, dims).astype('float64') + + def init_dynamic_data(self, batch_size, dims): + self.dynamic_low = self.low_np + self.dynamic_high = self.high_np + self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64') + + def init_static_data(self, batch_size, dims): + self.static_low = self.low_np + self.static_high = self.high_np + with fluid.program_guard(self.test_program): + self.static_values = layers.data( + name='values', shape=[dims], dtype='float64') + + +class UniformTest6(UniformTest): + def init_numpy_data(self, batch_size, dims): + # low and high are Tensor with dtype 'VarType.FP32'. + self.low_np = np.random.randn(batch_size, dims).astype('float32') + self.high_np = np.random.uniform(5.0, 15.0, + (batch_size, dims)).astype('float32') + self.values_np = np.random.randn(batch_size, dims).astype('float32') + + def init_dynamic_data(self, batch_size, dims): + self.dynamic_low = paddle.to_tensor(self.low_np) + self.dynamic_high = paddle.to_tensor(self.high_np) + self.dynamic_values = paddle.to_tensor(self.values_np) + + def init_static_data(self, batch_size, dims): + with fluid.program_guard(self.test_program): + self.static_low = layers.data( + name='low', shape=[dims], dtype='float32') + self.static_high = layers.data( + name='high', shape=[dims], dtype='float32') + self.static_values = layers.data( + name='values', shape=[dims], dtype='float32') + + +class UniformTest7(UniformTest): + def init_numpy_data(self, batch_size, dims): + # low and high are Tensor with dtype 'VarType.FP64'. + self.low_np = np.random.randn(batch_size, dims).astype('float64') + self.high_np = np.random.uniform(5.0, 15.0, + (batch_size, dims)).astype('float64') + self.values_np = np.random.randn(batch_size, dims).astype('float64') + + def init_dynamic_data(self, batch_size, dims): + self.dynamic_low = paddle.to_tensor(self.low_np, dtype='float64') + self.dynamic_high = paddle.to_tensor(self.high_np, dtype='float64') + self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64') + + def init_static_data(self, batch_size, dims): + with fluid.program_guard(self.test_program): + self.static_low = layers.data( + name='low', shape=[dims], dtype='float64') + self.static_high = layers.data( + name='high', shape=[dims], dtype='float64') + self.static_values = layers.data( + name='values', shape=[dims], dtype='float64') + + +class UniformTest8(UniformTest): + def init_numpy_data(self, batch_size, dims): + # low and high are Tensor with dtype 'VarType.FP64'. value's dtype is 'VarType.FP32'. + self.low_np = np.random.randn(batch_size, dims).astype('float64') + self.high_np = np.random.uniform(5.0, 15.0, + (batch_size, dims)).astype('float64') + self.values_np = np.random.randn(batch_size, dims).astype('float32') + + def init_dynamic_data(self, batch_size, dims): + self.dynamic_low = paddle.to_tensor(self.low_np, dtype='float64') + self.dynamic_high = paddle.to_tensor(self.high_np, dtype='float64') + self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float32') + + def init_static_data(self, batch_size, dims): + with fluid.program_guard(self.test_program): + self.static_low = layers.data( + name='low', shape=[dims], dtype='float64') + self.static_high = layers.data( + name='high', shape=[dims], dtype='float64') + self.static_values = layers.data( + name='values', shape=[dims], dtype='float32') + + +class UniformTest9(UniformTest): + def init_numpy_data(self, batch_size, dims): + # low and high are numpy.ndarray with dtype 'float32'. + # high < low. + self.low_np = np.random.randn(batch_size, dims).astype('float32') + self.high_np = np.random.uniform(-10.0, -5.0, + (batch_size, dims)).astype('float32') + self.values_np = np.random.randn(batch_size, dims).astype('float32') + + def init_static_data(self, batch_size, dims): + self.static_low = self.low_np + self.static_high = self.high_np + with fluid.program_guard(self.test_program): + self.static_values = layers.data( + name='values', shape=[dims], dtype='float32') + + +class UniformTest10(UniformTest): + def init_numpy_data(self, batch_size, dims): + # low and high are list. + self.low_np = np.random.randn(batch_size, + dims).astype('float32').tolist() + self.high_np = np.random.uniform( + 5.0, 15.0, (batch_size, dims)).astype('float32').tolist() + self.values_np = np.random.randn(batch_size, dims).astype('float32') + + def init_static_data(self, batch_size, dims): + self.static_low = self.low_np + self.static_high = self.high_np + with fluid.program_guard(self.test_program): + self.static_values = layers.data( + name='values', shape=[dims], dtype='float32') + + +class UniformTest11(UniformTest): + def init_numpy_data(self, batch_size, dims): + # low and high are tuple. + self.low_np = tuple( + np.random.randn(batch_size, dims).astype('float32').tolist()) + self.high_np = tuple( + np.random.uniform(5.0, 15.0, (batch_size, dims)).astype('float32') + .tolist()) + self.values_np = np.random.randn(batch_size, dims).astype('float32') + + def init_static_data(self, batch_size, dims): + self.static_low = self.low_np + self.static_high = self.high_np + with fluid.program_guard(self.test_program): + self.static_values = layers.data( + name='values', shape=[dims], dtype='float32') + + +class UniformTestSample(unittest.TestCase): + def setUp(self): + self.init_param() + + def init_param(self): + self.low = 3.0 + self.high = 4.0 + + def test_uniform_sample(self): + paddle.disable_static() + uniform = Uniform(low=self.low, high=self.high) + s = uniform.sample([100]) + self.assertTrue((s >= self.low).all()) + self.assertTrue((s < self.high).all()) + paddle.enable_static() + + +class UniformTestSample2(UniformTestSample): + def init_param(self): + self.low = -5.0 + self.high = 2.0 diff --git a/python/paddle/fluid/tests/unittests/test_distribution.py b/python/paddle/fluid/tests/unittests/test_distribution.py deleted file mode 100644 index 6cf2c5f6e2c..00000000000 --- a/python/paddle/fluid/tests/unittests/test_distribution.py +++ /dev/null @@ -1,1299 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import unittest -import paddle -from paddle import fluid -from paddle.fluid import layers -from paddle.distribution import * -import math - - -class DistributionNumpy(): - def sample(self): - raise NotImplementedError - - def entropy(self): - raise NotImplementedError - - def kl_divergence(self, other): - raise NotImplementedError - - def log_prob(self, value): - raise NotImplementedError - - def probs(self, value): - raise NotImplementedError - - -class UniformNumpy(DistributionNumpy): - def __init__(self, low, high): - self.low = np.array(low) - self.high = np.array(high) - if str(self.low.dtype) not in ['float32', 'float64']: - self.low = self.low.astype('float32') - self.high = self.high.astype('float32') - - def sample(self, shape): - shape = tuple(shape) + (self.low + self.high).shape - return self.low + (np.random.uniform(size=shape) * - (self.high - self.low)) - - def log_prob(self, value): - lb = np.less(self.low, value).astype(self.low.dtype) - ub = np.less(value, self.high).astype(self.low.dtype) - return np.log(lb * ub) - np.log(self.high - self.low) - - def probs(self, value): - lb = np.less(self.low, value).astype(self.low.dtype) - ub = np.less(value, self.high).astype(self.low.dtype) - return (lb * ub) / (self.high - self.low) - - def entropy(self): - return np.log(self.high - self.low) - - -class UniformTest(unittest.TestCase): - def setUp(self, use_gpu=False, batch_size=5, dims=6): - self.use_gpu = use_gpu - if not use_gpu: - self.place = fluid.CPUPlace() - self.gpu_id = -1 - else: - self.place = fluid.CUDAPlace(0) - self.gpu_id = 0 - - self.init_numpy_data(batch_size, dims) - - paddle.disable_static(self.place) - self.init_dynamic_data(batch_size, dims) - - paddle.enable_static() - self.test_program = fluid.Program() - self.executor = fluid.Executor(self.place) - self.init_static_data(batch_size, dims) - - def init_numpy_data(self, batch_size, dims): - # low ans high are 'float' - self.low_np = np.random.uniform(-2, 1) - self.high_np = np.random.uniform(2, 4) - self.values_np = np.array([1.0]).astype('float32') - - def init_dynamic_data(self, batch_size, dims): - self.dynamic_low = self.low_np - self.dynamic_high = self.high_np - self.dynamic_values = paddle.to_tensor(self.values_np) - - def init_static_data(self, batch_size, dims): - self.static_low = self.low_np - self.static_high = self.high_np - with fluid.program_guard(self.test_program): - self.static_values = layers.data( - name='values', shape=[], dtype='float32') - - def compare_with_numpy(self, fetch_list, sample_shape=7, tolerance=1e-6): - sample, entropy, log_prob, probs = fetch_list - - np_uniform = UniformNumpy(self.low_np, self.high_np) - np_sample = np_uniform.sample([sample_shape]) - np_entropy = np_uniform.entropy() - np_lp = np_uniform.log_prob(self.values_np) - np_p = np_uniform.probs(self.values_np) - - np.testing.assert_equal(sample.shape, np_sample.shape) - np.testing.assert_allclose( - entropy, np_entropy, rtol=tolerance, atol=tolerance) - np.testing.assert_allclose( - log_prob, np_lp, rtol=tolerance, atol=tolerance) - np.testing.assert_allclose(probs, np_p, rtol=tolerance, atol=tolerance) - - def test_uniform_distribution_dygraph(self, sample_shape=7, tolerance=1e-6): - paddle.disable_static(self.place) - uniform = Uniform(self.dynamic_low, self.dynamic_high) - sample = uniform.sample([sample_shape]).numpy() - entropy = uniform.entropy().numpy() - log_prob = uniform.log_prob(self.dynamic_values).numpy() - probs = uniform.probs(self.dynamic_values).numpy() - fetch_list = [sample, entropy, log_prob, probs] - - self.compare_with_numpy(fetch_list) - - def test_uniform_distribution_static(self, sample_shape=7, tolerance=1e-6): - paddle.enable_static() - with fluid.program_guard(self.test_program): - uniform = Uniform(self.static_low, self.static_high) - sample = uniform.sample([sample_shape]) - entropy = uniform.entropy() - log_prob = uniform.log_prob(self.static_values) - probs = uniform.probs(self.static_values) - fetch_list = [sample, entropy, log_prob, probs] - - feed_vars = { - 'low': self.low_np, - 'high': self.high_np, - 'values': self.values_np - } - - self.executor.run(fluid.default_startup_program()) - fetch_list = self.executor.run(program=self.test_program, - feed=feed_vars, - fetch_list=fetch_list) - - self.compare_with_numpy(fetch_list) - - -class UniformTest2(UniformTest): - def init_numpy_data(self, batch_size, dims): - # low ans high are 'int' - self.low_np = int(np.random.uniform(-2, 1)) - self.high_np = int(np.random.uniform(2, 4)) - self.values_np = np.array([1.0]).astype('float32') - - -class UniformTest3(UniformTest): - def init_numpy_data(self, batch_size, dims): - # test broadcast: low is float, high is numpy.ndarray with dtype 'float32'. - self.low_np = np.random.uniform(-2, 1) - self.high_np = np.random.uniform(5.0, 15.0, - (batch_size, dims)).astype('float32') - self.values_np = np.random.randn(batch_size, dims).astype('float32') - - def init_static_data(self, batch_size, dims): - self.static_low = self.low_np - self.static_high = self.high_np - with fluid.program_guard(self.test_program): - self.static_values = layers.data( - name='values', shape=[dims], dtype='float32') - - -class UniformTest4(UniformTest): - def init_numpy_data(self, batch_size, dims): - # low and high are numpy.ndarray with dtype 'float32'. - self.low_np = np.random.randn(batch_size, dims).astype('float32') - self.high_np = np.random.uniform(5.0, 15.0, - (batch_size, dims)).astype('float32') - self.values_np = np.random.randn(batch_size, dims).astype('float32') - - def init_static_data(self, batch_size, dims): - self.static_low = self.low_np - self.static_high = self.high_np - with fluid.program_guard(self.test_program): - self.static_values = layers.data( - name='values', shape=[dims], dtype='float32') - - -class UniformTest5(UniformTest): - def init_numpy_data(self, batch_size, dims): - # low and high are numpy.ndarray with dtype 'float64'. - self.low_np = np.random.randn(batch_size, dims).astype('float64') - self.high_np = np.random.uniform(5.0, 15.0, - (batch_size, dims)).astype('float64') - self.values_np = np.random.randn(batch_size, dims).astype('float64') - - def init_dynamic_data(self, batch_size, dims): - self.dynamic_low = self.low_np - self.dynamic_high = self.high_np - self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64') - - def init_static_data(self, batch_size, dims): - self.static_low = self.low_np - self.static_high = self.high_np - with fluid.program_guard(self.test_program): - self.static_values = layers.data( - name='values', shape=[dims], dtype='float64') - - -class UniformTest6(UniformTest): - def init_numpy_data(self, batch_size, dims): - # low and high are Tensor with dtype 'VarType.FP32'. - self.low_np = np.random.randn(batch_size, dims).astype('float32') - self.high_np = np.random.uniform(5.0, 15.0, - (batch_size, dims)).astype('float32') - self.values_np = np.random.randn(batch_size, dims).astype('float32') - - def init_dynamic_data(self, batch_size, dims): - self.dynamic_low = paddle.to_tensor(self.low_np) - self.dynamic_high = paddle.to_tensor(self.high_np) - self.dynamic_values = paddle.to_tensor(self.values_np) - - def init_static_data(self, batch_size, dims): - with fluid.program_guard(self.test_program): - self.static_low = layers.data( - name='low', shape=[dims], dtype='float32') - self.static_high = layers.data( - name='high', shape=[dims], dtype='float32') - self.static_values = layers.data( - name='values', shape=[dims], dtype='float32') - - -class UniformTest7(UniformTest): - def init_numpy_data(self, batch_size, dims): - # low and high are Tensor with dtype 'VarType.FP64'. - self.low_np = np.random.randn(batch_size, dims).astype('float64') - self.high_np = np.random.uniform(5.0, 15.0, - (batch_size, dims)).astype('float64') - self.values_np = np.random.randn(batch_size, dims).astype('float64') - - def init_dynamic_data(self, batch_size, dims): - self.dynamic_low = paddle.to_tensor(self.low_np, dtype='float64') - self.dynamic_high = paddle.to_tensor(self.high_np, dtype='float64') - self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64') - - def init_static_data(self, batch_size, dims): - with fluid.program_guard(self.test_program): - self.static_low = layers.data( - name='low', shape=[dims], dtype='float64') - self.static_high = layers.data( - name='high', shape=[dims], dtype='float64') - self.static_values = layers.data( - name='values', shape=[dims], dtype='float64') - - -class UniformTest8(UniformTest): - def init_numpy_data(self, batch_size, dims): - # low and high are Tensor with dtype 'VarType.FP64'. value's dtype is 'VarType.FP32'. - self.low_np = np.random.randn(batch_size, dims).astype('float64') - self.high_np = np.random.uniform(5.0, 15.0, - (batch_size, dims)).astype('float64') - self.values_np = np.random.randn(batch_size, dims).astype('float32') - - def init_dynamic_data(self, batch_size, dims): - self.dynamic_low = paddle.to_tensor(self.low_np, dtype='float64') - self.dynamic_high = paddle.to_tensor(self.high_np, dtype='float64') - self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float32') - - def init_static_data(self, batch_size, dims): - with fluid.program_guard(self.test_program): - self.static_low = layers.data( - name='low', shape=[dims], dtype='float64') - self.static_high = layers.data( - name='high', shape=[dims], dtype='float64') - self.static_values = layers.data( - name='values', shape=[dims], dtype='float32') - - -class UniformTest9(UniformTest): - def init_numpy_data(self, batch_size, dims): - # low and high are numpy.ndarray with dtype 'float32'. - # high < low. - self.low_np = np.random.randn(batch_size, dims).astype('float32') - self.high_np = np.random.uniform(-10.0, -5.0, - (batch_size, dims)).astype('float32') - self.values_np = np.random.randn(batch_size, dims).astype('float32') - - def init_static_data(self, batch_size, dims): - self.static_low = self.low_np - self.static_high = self.high_np - with fluid.program_guard(self.test_program): - self.static_values = layers.data( - name='values', shape=[dims], dtype='float32') - - -class UniformTest10(UniformTest): - def init_numpy_data(self, batch_size, dims): - # low and high are list. - self.low_np = np.random.randn(batch_size, - dims).astype('float32').tolist() - self.high_np = np.random.uniform( - 5.0, 15.0, (batch_size, dims)).astype('float32').tolist() - self.values_np = np.random.randn(batch_size, dims).astype('float32') - - def init_static_data(self, batch_size, dims): - self.static_low = self.low_np - self.static_high = self.high_np - with fluid.program_guard(self.test_program): - self.static_values = layers.data( - name='values', shape=[dims], dtype='float32') - - -class UniformTest11(UniformTest): - def init_numpy_data(self, batch_size, dims): - # low and high are tuple. - self.low_np = tuple( - np.random.randn(batch_size, dims).astype('float32').tolist()) - self.high_np = tuple( - np.random.uniform(5.0, 15.0, (batch_size, dims)).astype('float32') - .tolist()) - self.values_np = np.random.randn(batch_size, dims).astype('float32') - - def init_static_data(self, batch_size, dims): - self.static_low = self.low_np - self.static_high = self.high_np - with fluid.program_guard(self.test_program): - self.static_values = layers.data( - name='values', shape=[dims], dtype='float32') - - -class UniformTestSample(unittest.TestCase): - def setUp(self): - self.init_param() - - def init_param(self): - self.low = 3.0 - self.high = 4.0 - - def test_uniform_sample(self): - paddle.disable_static() - uniform = Uniform(low=self.low, high=self.high) - s = uniform.sample([100]) - self.assertTrue((s >= self.low).all()) - self.assertTrue((s < self.high).all()) - paddle.enable_static() - - -class UniformTestSample2(UniformTestSample): - def init_param(self): - self.low = -5.0 - self.high = 2.0 - - -class NormalNumpy(DistributionNumpy): - def __init__(self, loc, scale): - self.loc = np.array(loc) - self.scale = np.array(scale) - if str(self.loc.dtype) not in ['float32', 'float64']: - self.loc = self.loc.astype('float32') - self.scale = self.scale.astype('float32') - - def sample(self, shape): - shape = tuple(shape) + (self.loc + self.scale).shape - return self.loc + (np.random.randn(*shape) * self.scale) - - def log_prob(self, value): - var = self.scale * self.scale - log_scale = np.log(self.scale) - return -((value - self.loc) * (value - self.loc)) / ( - 2. * var) - log_scale - math.log(math.sqrt(2. * math.pi)) - - def probs(self, value): - var = self.scale * self.scale - return np.exp(-1. * ((value - self.loc) * (value - self.loc)) / - (2. * var)) / (math.sqrt(2 * math.pi) * self.scale) - - def entropy(self): - return 0.5 + 0.5 * np.log( - np.array(2. * math.pi).astype(self.loc.dtype)) + np.log(self.scale) - - def kl_divergence(self, other): - var_ratio = (self.scale / other.scale) - var_ratio = var_ratio * var_ratio - t1 = ((self.loc - other.loc) / other.scale) - t1 = (t1 * t1) - return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio)) - - -class NormalTest(unittest.TestCase): - def setUp(self, use_gpu=False, batch_size=2, dims=3): - self.use_gpu = use_gpu - if not use_gpu: - self.place = fluid.CPUPlace() - self.gpu_id = -1 - else: - self.place = fluid.CUDAPlace(0) - self.gpu_id = 0 - - self.init_numpy_data(batch_size, dims) - - paddle.disable_static(self.place) - self.init_dynamic_data(batch_size, dims) - - paddle.enable_static() - self.test_program = fluid.Program() - self.executor = fluid.Executor(self.place) - self.init_static_data(batch_size, dims) - - def init_numpy_data(self, batch_size, dims): - # loc ans scale are 'float' - self.loc_np = (np.random.ranf() - 0.5) * 4 - self.scale_np = (np.random.ranf() - 0.5) * 4 - while self.scale_np < 0: - self.scale_np = (np.random.ranf() - 0.5) * 4 - # used to construct another Normal object to calculate kl_divergence - self.other_loc_np = (np.random.ranf() - 0.5) * 4 - self.other_scale_np = (np.random.ranf() - 0.5) * 4 - while self.other_scale_np < 0: - self.other_scale_np = (np.random.ranf() - 0.5) * 4 - self.values_np = np.random.ranf(1).astype('float32') - - def init_dynamic_data(self, batch_size, dims): - self.dynamic_loc = self.loc_np - self.dynamic_scale = self.scale_np - self.dynamic_other_loc = self.other_loc_np - self.dynamic_other_scale = self.other_scale_np - self.dynamic_values = paddle.to_tensor(self.values_np) - - def init_static_data(self, batch_size, dims): - self.static_loc = self.loc_np - self.static_scale = self.scale_np - self.static_other_loc = self.other_loc_np - self.static_other_scale = self.other_scale_np - with fluid.program_guard(self.test_program): - self.static_values = layers.data( - name='values', shape=[], dtype='float32') - - def compare_with_numpy(self, fetch_list, sample_shape=7, tolerance=1e-6): - sample, entropy, log_prob, probs, kl = fetch_list - - np_normal = NormalNumpy(self.loc_np, self.scale_np) - np_sample = np_normal.sample([sample_shape]) - np_entropy = np_normal.entropy() - np_lp = np_normal.log_prob(self.values_np) - np_p = np_normal.probs(self.values_np) - np_other_normal = NormalNumpy(self.other_loc_np, self.other_scale_np) - np_kl = np_normal.kl_divergence(np_other_normal) - - # Because assign op does not support the input of numpy.ndarray whose dtype is FP64. - # When loc and scale are FP64 numpy.ndarray, we need to use assign op to convert it - # to FP32 Tensor. And then use cast op to convert it to a FP64 Tensor. - # There is a loss of accuracy in this conversion. - # So set the tolerance from 1e-6 to 1e-4. - log_tolerance = 1e-4 - - np.testing.assert_equal(sample.shape, np_sample.shape) - np.testing.assert_allclose( - entropy, np_entropy, rtol=tolerance, atol=tolerance) - np.testing.assert_allclose( - log_prob, np_lp, rtol=log_tolerance, atol=log_tolerance) - np.testing.assert_allclose( - probs, np_p, rtol=log_tolerance, atol=log_tolerance) - np.testing.assert_allclose( - kl, np_kl, rtol=log_tolerance, atol=log_tolerance) - - def test_normal_distribution_dygraph(self, sample_shape=7, tolerance=1e-6): - paddle.disable_static(self.place) - normal = Normal(self.dynamic_loc, self.dynamic_scale) - - sample = normal.sample([sample_shape]).numpy() - entropy = normal.entropy().numpy() - log_prob = normal.log_prob(self.dynamic_values).numpy() - probs = normal.probs(self.dynamic_values).numpy() - other_normal = Normal(self.dynamic_other_loc, self.dynamic_other_scale) - kl = normal.kl_divergence(other_normal).numpy() - - fetch_list = [sample, entropy, log_prob, probs, kl] - self.compare_with_numpy(fetch_list) - - def test_normal_distribution_static(self, sample_shape=7, tolerance=1e-6): - paddle.enable_static() - with fluid.program_guard(self.test_program): - normal = Normal(self.static_loc, self.static_scale) - - sample = normal.sample([sample_shape]) - entropy = normal.entropy() - log_prob = normal.log_prob(self.static_values) - probs = normal.probs(self.static_values) - other_normal = Normal(self.static_other_loc, - self.static_other_scale) - kl = normal.kl_divergence(other_normal) - - fetch_list = [sample, entropy, log_prob, probs, kl] - - feed_vars = { - 'loc': self.loc_np, - 'scale': self.scale_np, - 'values': self.values_np, - 'other_loc': self.other_loc_np, - 'other_scale': self.other_scale_np - } - - self.executor.run(fluid.default_startup_program()) - fetch_list = self.executor.run(program=self.test_program, - feed=feed_vars, - fetch_list=fetch_list) - - self.compare_with_numpy(fetch_list) - - -class NormalTest2(NormalTest): - def init_numpy_data(self, batch_size, dims): - # loc ans scale are 'int' - self.loc_np = int((np.random.ranf() - 0.5) * 8) - self.scale_np = int((np.random.ranf() - 0.5) * 8) - while self.scale_np < 0: - self.scale_np = int((np.random.ranf() - 0.5) * 8) - # used to construct another Normal object to calculate kl_divergence - self.other_loc_np = int((np.random.ranf() - 0.5) * 8) - self.other_scale_np = int((np.random.ranf() - 0.5) * 8) - while self.other_scale_np < 0: - self.other_scale_np = int((np.random.ranf() - 0.5) * 8) - self.values_np = np.random.ranf(1).astype('float32') - - -class NormalTest3(NormalTest): - def init_numpy_data(self, batch_size, dims): - # test broadcast: loc is float, scale is numpy.ndarray with dtype 'float32'. - self.loc_np = (np.random.ranf() - 0.5) * 4 - self.scale_np = np.random.randn(batch_size, dims).astype('float32') - while not np.all(self.scale_np > 0): - self.scale_np = np.random.randn(batch_size, dims).astype('float32') - self.values_np = np.random.randn(batch_size, dims).astype('float32') - # used to construct another Normal object to calculate kl_divergence - self.other_loc_np = (np.random.ranf() - 0.5) * 4 - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float32') - while not np.all(self.scale_np > 0): - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float32') - - def init_static_data(self, batch_size, dims): - self.static_loc = self.loc_np - self.static_scale = self.scale_np - self.static_other_loc = self.other_loc_np - self.static_other_scale = self.other_scale_np - with fluid.program_guard(self.test_program): - self.static_values = layers.data( - name='values', shape=[dims], dtype='float32') - - -class NormalTest4(NormalTest): - def init_numpy_data(self, batch_size, dims): - # loc and scale are numpy.ndarray with dtype 'float32'. - self.loc_np = np.random.randn(batch_size, dims).astype('float32') - self.scale_np = np.random.randn(batch_size, dims).astype('float32') - while not np.all(self.scale_np > 0): - self.scale_np = np.random.randn(batch_size, dims).astype('float32') - self.values_np = np.random.randn(batch_size, dims).astype('float32') - # used to construct another Normal object to calculate kl_divergence - self.other_loc_np = np.random.randn(batch_size, dims).astype('float32') - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float32') - while not np.all(self.scale_np > 0): - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float32') - - def init_static_data(self, batch_size, dims): - self.static_loc = self.loc_np - self.static_scale = self.scale_np - self.static_other_loc = self.other_loc_np - self.static_other_scale = self.other_scale_np - with fluid.program_guard(self.test_program): - self.static_values = layers.data( - name='values', shape=[dims], dtype='float32') - - -class NormalTest5(NormalTest): - def init_numpy_data(self, batch_size, dims): - # loc and scale are numpy.ndarray with dtype 'float64'. - self.loc_np = np.random.randn(batch_size, dims).astype('float64') - self.scale_np = np.random.randn(batch_size, dims).astype('float64') - while not np.all(self.scale_np > 0): - self.scale_np = np.random.randn(batch_size, dims).astype('float64') - self.values_np = np.random.randn(batch_size, dims).astype('float64') - # used to construct another Normal object to calculate kl_divergence - self.other_loc_np = np.random.randn(batch_size, dims).astype('float64') - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float64') - while not np.all(self.scale_np > 0): - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float64') - - def init_dynamic_data(self, batch_size, dims): - self.dynamic_loc = self.loc_np - self.dynamic_scale = self.scale_np - self.dynamic_other_loc = self.other_loc_np - self.dynamic_other_scale = self.other_scale_np - self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64') - - def init_static_data(self, batch_size, dims): - self.static_loc = self.loc_np - self.static_scale = self.scale_np - self.static_other_loc = self.other_loc_np - self.static_other_scale = self.other_scale_np - with fluid.program_guard(self.test_program): - self.static_values = layers.data( - name='values', shape=[dims], dtype='float64') - - -class NormalTest6(NormalTest): - def init_numpy_data(self, batch_size, dims): - # loc and scale are Tensor with dtype 'VarType.FP32'. - self.loc_np = np.random.randn(batch_size, dims).astype('float32') - self.scale_np = np.random.randn(batch_size, dims).astype('float32') - while not np.all(self.scale_np > 0): - self.scale_np = np.random.randn(batch_size, dims).astype('float32') - self.values_np = np.random.randn(batch_size, dims).astype('float32') - # used to construct another Normal object to calculate kl_divergence - self.other_loc_np = np.random.randn(batch_size, dims).astype('float32') - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float32') - while not np.all(self.scale_np > 0): - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float32') - - def init_dynamic_data(self, batch_size, dims): - self.dynamic_loc = paddle.to_tensor(self.loc_np) - self.dynamic_scale = paddle.to_tensor(self.scale_np) - self.dynamic_values = paddle.to_tensor(self.values_np) - self.dynamic_other_loc = paddle.to_tensor(self.other_loc_np) - self.dynamic_other_scale = paddle.to_tensor(self.other_scale_np) - - def init_static_data(self, batch_size, dims): - with fluid.program_guard(self.test_program): - self.static_loc = layers.data( - name='loc', shape=[dims], dtype='float32') - self.static_scale = layers.data( - name='scale', shape=[dims], dtype='float32') - self.static_values = layers.data( - name='values', shape=[dims], dtype='float32') - self.static_other_loc = layers.data( - name='other_loc', shape=[dims], dtype='float32') - self.static_other_scale = layers.data( - name='other_scale', shape=[dims], dtype='float32') - - -class NormalTest7(NormalTest): - def init_numpy_data(self, batch_size, dims): - # loc and scale are Tensor with dtype 'VarType.FP64'. - self.loc_np = np.random.randn(batch_size, dims).astype('float64') - self.scale_np = np.random.randn(batch_size, dims).astype('float64') - while not np.all(self.scale_np > 0): - self.scale_np = np.random.randn(batch_size, dims).astype('float64') - self.values_np = np.random.randn(batch_size, dims).astype('float64') - # used to construct another Normal object to calculate kl_divergence - self.other_loc_np = np.random.randn(batch_size, dims).astype('float64') - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float64') - while not np.all(self.scale_np > 0): - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float64') - - def init_dynamic_data(self, batch_size, dims): - self.dynamic_loc = paddle.to_tensor(self.loc_np, dtype='float64') - self.dynamic_scale = paddle.to_tensor(self.scale_np, dtype='float64') - self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64') - self.dynamic_other_loc = paddle.to_tensor( - self.other_loc_np, dtype='float64') - self.dynamic_other_scale = paddle.to_tensor( - self.other_scale_np, dtype='float64') - - def init_static_data(self, batch_size, dims): - with fluid.program_guard(self.test_program): - self.static_loc = layers.data( - name='loc', shape=[dims], dtype='float64') - self.static_scale = layers.data( - name='scale', shape=[dims], dtype='float64') - self.static_values = layers.data( - name='values', shape=[dims], dtype='float64') - self.static_other_loc = layers.data( - name='other_loc', shape=[dims], dtype='float64') - self.static_other_scale = layers.data( - name='other_scale', shape=[dims], dtype='float64') - - -class NormalTest8(NormalTest): - def init_numpy_data(self, batch_size, dims): - # loc and scale are Tensor with dtype 'VarType.FP64'. value's dtype is 'VarType.FP32'. - self.loc_np = np.random.randn(batch_size, dims).astype('float64') - self.scale_np = np.random.randn(batch_size, dims).astype('float64') - while not np.all(self.scale_np > 0): - self.scale_np = np.random.randn(batch_size, dims).astype('float64') - self.values_np = np.random.randn(batch_size, dims).astype('float32') - # used to construct another Normal object to calculate kl_divergence - self.other_loc_np = np.random.randn(batch_size, dims).astype('float64') - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float64') - while not np.all(self.scale_np > 0): - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float64') - - def init_dynamic_data(self, batch_size, dims): - self.dynamic_loc = paddle.to_tensor(self.loc_np, dtype='float64') - self.dynamic_scale = paddle.to_tensor(self.scale_np, dtype='float64') - self.dynamic_values = paddle.to_tensor(self.values_np) - self.dynamic_other_loc = paddle.to_tensor( - self.other_loc_np, dtype='float64') - self.dynamic_other_scale = paddle.to_tensor( - self.other_scale_np, dtype='float64') - - def init_static_data(self, batch_size, dims): - with fluid.program_guard(self.test_program): - self.static_loc = layers.data( - name='loc', shape=[dims], dtype='float64') - self.static_scale = layers.data( - name='scale', shape=[dims], dtype='float64') - self.static_values = layers.data( - name='values', shape=[dims], dtype='float32') - self.static_other_loc = layers.data( - name='other_loc', shape=[dims], dtype='float64') - self.static_other_scale = layers.data( - name='other_scale', shape=[dims], dtype='float64') - - -class NormalTest9(NormalTest): - def init_numpy_data(self, batch_size, dims): - # loc and scale are list. - self.loc_np = np.random.randn(batch_size, - dims).astype('float32').tolist() - self.scale_np = np.random.randn(batch_size, dims).astype('float32') - while not np.all(self.scale_np > 0): - self.scale_np = np.random.randn(batch_size, dims).astype('float32') - self.scale_np = self.scale_np.tolist() - self.values_np = np.random.randn(batch_size, dims).astype('float32') - # used to construct another Normal object to calculate kl_divergence - self.other_loc_np = np.random.randn(batch_size, - dims).astype('float32').tolist() - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float32') - while not np.all(self.other_scale_np > 0): - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float32') - self.other_scale_np = self.other_scale_np.tolist() - - def init_static_data(self, batch_size, dims): - self.static_loc = self.loc_np - self.static_scale = self.scale_np - self.static_other_loc = self.other_loc_np - self.static_other_scale = self.other_scale_np - with fluid.program_guard(self.test_program): - self.static_values = layers.data( - name='values', shape=[dims], dtype='float32') - - -class NormalTest10(NormalTest): - def init_numpy_data(self, batch_size, dims): - # loc and scale are tuple. - self.loc_np = tuple( - np.random.randn(batch_size, dims).astype('float32').tolist()) - self.scale_np = np.random.randn(batch_size, dims).astype('float32') - while not np.all(self.scale_np > 0): - self.scale_np = np.random.randn(batch_size, dims).astype('float32') - self.scale_np = tuple(self.scale_np.tolist()) - self.values_np = np.random.randn(batch_size, dims).astype('float32') - # used to construct another Normal object to calculate kl_divergence - self.other_loc_np = tuple( - np.random.randn(batch_size, dims).astype('float32').tolist()) - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float32') - while not np.all(self.other_scale_np > 0): - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float32') - self.other_scale_np = tuple(self.other_scale_np.tolist()) - - def init_static_data(self, batch_size, dims): - self.static_loc = self.loc_np - self.static_scale = self.scale_np - self.static_other_loc = self.other_loc_np - self.static_other_scale = self.other_scale_np - with fluid.program_guard(self.test_program): - self.static_values = layers.data( - name='values', shape=[dims], dtype='float32') - - -class CategoricalNumpy(DistributionNumpy): - def __init__(self, logits): - self.logits = np.array(logits).astype('float32') - - def entropy(self): - logits = self.logits - np.max(self.logits, axis=-1, keepdims=True) - e_logits = np.exp(logits) - z = np.sum(e_logits, axis=-1, keepdims=True) - prob = e_logits / z - return -1. * np.sum(prob * (logits - np.log(z)), axis=-1, keepdims=True) - - def kl_divergence(self, other): - logits = self.logits - np.max(self.logits, axis=-1, keepdims=True) - other_logits = other.logits - np.max( - other.logits, axis=-1, keepdims=True) - e_logits = np.exp(logits) - other_e_logits = np.exp(other_logits) - z = np.sum(e_logits, axis=-1, keepdims=True) - other_z = np.sum(other_e_logits, axis=-1, keepdims=True) - prob = e_logits / z - return np.sum(prob * (logits - np.log(z) - other_logits \ - + np.log(other_z)), axis=-1, keepdims=True) - - -class CategoricalTest(unittest.TestCase): - def setUp(self, use_gpu=False, batch_size=3, dims=5): - self.use_gpu = use_gpu - if not use_gpu: - self.place = fluid.CPUPlace() - self.gpu_id = -1 - else: - self.place = fluid.CUDAPlace(0) - self.gpu_id = 0 - - self.batch_size = batch_size - self.dims = dims - self.init_numpy_data(batch_size, dims) - - paddle.disable_static(self.place) - self.init_dynamic_data(batch_size, dims) - - paddle.enable_static() - self.test_program = fluid.Program() - self.executor = fluid.Executor(self.place) - self.init_static_data(batch_size, dims) - - def init_numpy_data(self, batch_size, dims): - # input logtis is 2-D Tensor - # value used in probs and log_prob method is 1-D Tensor - self.logits_np = np.random.rand(batch_size, dims).astype('float32') - self.other_logits_np = np.random.rand(batch_size, - dims).astype('float32') - self.value_np = np.array([2, 1, 3]).astype('int64') - - self.logits_shape = [batch_size, dims] - # dist_shape = logits_shape[:-1], it represents the number of - # different distributions. - self.dist_shape = [batch_size] - # sample shape represents the number of samples - self.sample_shape = [2, 4] - # value used in probs and log_prob method - # If value is 1-D and logits is 2-D or higher dimension, value will be - # broadcasted to have the same number of distributions with logits. - # If value is 2-D or higher dimentsion, it should have the same number - # of distributions with logtis. ``value[:-1] = logits[:-1] - self.value_shape = [3] - - def init_dynamic_data(self, batch_size, dims): - self.logits = paddle.to_tensor(self.logits_np) - self.other_logits = paddle.to_tensor(self.other_logits_np) - self.value = paddle.to_tensor(self.value_np) - - def init_static_data(self, batch_size, dims): - with fluid.program_guard(self.test_program): - self.logits_static = fluid.data( - name='logits', shape=self.logits_shape, dtype='float32') - self.other_logits_static = fluid.data( - name='other_logits', shape=self.logits_shape, dtype='float32') - self.value_static = fluid.data( - name='value', shape=self.value_shape, dtype='int64') - - def get_numpy_selected_probs(self, probability): - np_probs = np.zeros(self.dist_shape + self.value_shape) - for i in range(self.batch_size): - for j in range(3): - np_probs[i][j] = probability[i][self.value_np[j]] - return np_probs - - def compare_with_numpy(self, fetch_list, tolerance=1e-6): - sample, entropy, kl, probs, log_prob = fetch_list - log_tolerance = 1e-4 - - np.testing.assert_equal(sample.shape, - self.sample_shape + self.dist_shape) - - np_categorical = CategoricalNumpy(self.logits_np) - np_other_categorical = CategoricalNumpy(self.other_logits_np) - np_entropy = np_categorical.entropy() - np_kl = np_categorical.kl_divergence(np_other_categorical) - - np.testing.assert_allclose( - entropy, np_entropy, rtol=log_tolerance, atol=log_tolerance) - np.testing.assert_allclose( - kl, np_kl, rtol=log_tolerance, atol=log_tolerance) - - sum_dist = np.sum(self.logits_np, axis=-1, keepdims=True) - probability = self.logits_np / sum_dist - np_probs = self.get_numpy_selected_probs(probability) - np_log_prob = np.log(np_probs) - - np.testing.assert_allclose( - probs, np_probs, rtol=tolerance, atol=tolerance) - np.testing.assert_allclose( - log_prob, np_log_prob, rtol=tolerance, atol=tolerance) - - def test_categorical_distribution_dygraph(self, tolerance=1e-6): - paddle.disable_static(self.place) - categorical = Categorical(self.logits) - other_categorical = Categorical(self.other_logits) - - sample = categorical.sample(self.sample_shape).numpy() - entropy = categorical.entropy().numpy() - kl = categorical.kl_divergence(other_categorical).numpy() - probs = categorical.probs(self.value).numpy() - log_prob = categorical.log_prob(self.value).numpy() - - fetch_list = [sample, entropy, kl, probs, log_prob] - self.compare_with_numpy(fetch_list) - - def test_categorical_distribution_static(self, tolerance=1e-6): - paddle.enable_static() - with fluid.program_guard(self.test_program): - categorical = Categorical(self.logits_static) - other_categorical = Categorical(self.other_logits_static) - - sample = categorical.sample(self.sample_shape) - entropy = categorical.entropy() - kl = categorical.kl_divergence(other_categorical) - probs = categorical.probs(self.value_static) - log_prob = categorical.log_prob(self.value_static) - - fetch_list = [sample, entropy, kl, probs, log_prob] - - feed_vars = { - 'logits': self.logits_np, - 'other_logits': self.other_logits_np, - 'value': self.value_np - } - - self.executor.run(fluid.default_startup_program()) - fetch_list = self.executor.run(program=self.test_program, - feed=feed_vars, - fetch_list=fetch_list) - - self.compare_with_numpy(fetch_list) - - -class CategoricalTest2(CategoricalTest): - def init_numpy_data(self, batch_size, dims): - # input logtis is 2-D Tensor with dtype Float64 - # value used in probs and log_prob method is 1-D Tensor - self.logits_np = np.random.rand(batch_size, dims).astype('float64') - self.other_logits_np = np.random.rand(batch_size, - dims).astype('float64') - self.value_np = np.array([2, 1, 3]).astype('int64') - - self.logits_shape = [batch_size, dims] - self.dist_shape = [batch_size] - self.sample_shape = [2, 4] - self.value_shape = [3] - - def init_static_data(self, batch_size, dims): - with fluid.program_guard(self.test_program): - self.logits_static = fluid.data( - name='logits', shape=self.logits_shape, dtype='float64') - self.other_logits_static = fluid.data( - name='other_logits', shape=self.logits_shape, dtype='float64') - self.value_static = fluid.data( - name='value', shape=self.value_shape, dtype='int64') - - -class CategoricalTest3(CategoricalTest): - def init_dynamic_data(self, batch_size, dims): - # input logtis is 2-D numpy.ndarray with dtype Float32 - # value used in probs and log_prob method is 1-D Tensor - self.logits = self.logits_np - self.other_logits = self.other_logits_np - self.value = paddle.to_tensor(self.value_np) - - def init_static_data(self, batch_size, dims): - with fluid.program_guard(self.test_program): - self.logits_static = self.logits_np - self.other_logits_static = self.other_logits_np - self.value_static = fluid.data( - name='value', shape=self.value_shape, dtype='int64') - - -class CategoricalTest4(CategoricalTest): - def init_numpy_data(self, batch_size, dims): - # input logtis is 2-D numpy.ndarray with dtype Float64 - # value used in probs and log_prob method is 1-D Tensor - self.logits_np = np.random.rand(batch_size, dims).astype('float64') - self.other_logits_np = np.random.rand(batch_size, - dims).astype('float64') - self.value_np = np.array([2, 1, 3]).astype('int64') - - self.logits_shape = [batch_size, dims] - self.dist_shape = [batch_size] - self.sample_shape = [2, 4] - self.value_shape = [3] - - def init_dynamic_data(self, batch_size, dims): - self.logits = self.logits_np - self.other_logits = self.other_logits_np - self.value = paddle.to_tensor(self.value_np) - - def init_static_data(self, batch_size, dims): - with fluid.program_guard(self.test_program): - self.logits_static = self.logits_np - self.other_logits_static = self.other_logits_np - self.value_static = fluid.data( - name='value', shape=self.value_shape, dtype='int64') - - -# test shape of logits and value used in probs and log_prob method -class CategoricalTest5(CategoricalTest): - def init_numpy_data(self, batch_size, dims): - # input logtis is 1-D Tensor - # value used in probs and log_prob method is 1-D Tensor - self.logits_np = np.random.rand(dims).astype('float32') - self.other_logits_np = np.random.rand(dims).astype('float32') - self.value_np = np.array([2, 1, 3]).astype('int64') - - self.logits_shape = [dims] - self.dist_shape = [] - self.sample_shape = [2, 4] - self.value_shape = [3] - - def get_numpy_selected_probs(self, probability): - np_probs = np.zeros(self.value_shape) - for i in range(3): - np_probs[i] = probability[self.value_np[i]] - return np_probs - - -class CategoricalTest6(CategoricalTest): - def init_numpy_data(self, batch_size, dims): - # input logtis is 2-D Tensor - # value used in probs and log_prob method has the same number of batches with input - self.logits_np = np.random.rand(3, 5).astype('float32') - self.other_logits_np = np.random.rand(3, 5).astype('float32') - self.value_np = np.array([[2, 1], [0, 3], [2, 3]]).astype('int64') - - self.logits_shape = [3, 5] - self.dist_shape = [3] - self.sample_shape = [2, 4] - self.value_shape = [3, 2] - - def get_numpy_selected_probs(self, probability): - np_probs = np.zeros(self.value_shape) - for i in range(3): - for j in range(2): - np_probs[i][j] = probability[i][self.value_np[i][j]] - return np_probs - - -class CategoricalTest7(CategoricalTest): - def init_numpy_data(self, batch_size, dims): - # input logtis is 3-D Tensor - # value used in probs and log_prob method has the same number of distribuions with input - self.logits_np = np.random.rand(3, 2, 5).astype('float32') - self.other_logits_np = np.random.rand(3, 2, 5).astype('float32') - self.value_np = np.array([2, 1, 3]).astype('int64') - - self.logits_shape = [3, 2, 5] - self.dist_shape = [3, 2] - self.sample_shape = [2, 4] - self.value_shape = [3] - - def get_numpy_selected_probs(self, probability): - np_probs = np.zeros(self.dist_shape + self.value_shape) - for i in range(3): - for j in range(2): - for k in range(3): - np_probs[i][j][k] = probability[i][j][self.value_np[k]] - return np_probs - - -class CategoricalTest8(CategoricalTest): - def init_dynamic_data(self, batch_size, dims): - # input logtis is 2-D list - # value used in probs and log_prob method is 1-D Tensor - self.logits = self.logits_np.tolist() - self.other_logits = self.other_logits_np.tolist() - self.value = paddle.to_tensor(self.value_np) - - def init_static_data(self, batch_size, dims): - with fluid.program_guard(self.test_program): - self.logits_static = self.logits_np.tolist() - self.other_logits_static = self.other_logits_np.tolist() - self.value_static = fluid.data( - name='value', shape=self.value_shape, dtype='int64') - - -class CategoricalTest9(CategoricalTest): - def init_dynamic_data(self, batch_size, dims): - # input logtis is 2-D tuple - # value used in probs and log_prob method is 1-D Tensor - self.logits = tuple(self.logits_np.tolist()) - self.other_logits = tuple(self.other_logits_np.tolist()) - self.value = paddle.to_tensor(self.value_np) - - def init_static_data(self, batch_size, dims): - with fluid.program_guard(self.test_program): - self.logits_static = tuple(self.logits_np.tolist()) - self.other_logits_static = tuple(self.other_logits_np.tolist()) - self.value_static = fluid.data( - name='value', shape=self.value_shape, dtype='int64') - - -class DistributionTestError(unittest.TestCase): - def test_distribution_error(self): - distribution = Distribution() - - self.assertRaises(NotImplementedError, distribution.sample) - self.assertRaises(NotImplementedError, distribution.entropy) - - normal = Normal(0.0, 1.0) - self.assertRaises(NotImplementedError, distribution.kl_divergence, - normal) - - value_npdata = np.array([0.8], dtype="float32") - value_tensor = layers.create_tensor(dtype="float32") - self.assertRaises(NotImplementedError, distribution.log_prob, - value_tensor) - self.assertRaises(NotImplementedError, distribution.probs, value_tensor) - - def test_normal_error(self): - paddle.enable_static() - normal = Normal(0.0, 1.0) - - value = [1.0, 2.0] - # type of value must be variable - self.assertRaises(TypeError, normal.log_prob, value) - - value = [1.0, 2.0] - # type of value must be variable - self.assertRaises(TypeError, normal.probs, value) - - shape = 1.0 - # type of shape must be list - self.assertRaises(TypeError, normal.sample, shape) - - seed = 1.0 - # type of seed must be int - self.assertRaises(TypeError, normal.sample, [2, 3], seed) - - normal_other = Uniform(1.0, 2.0) - # type of other must be an instance of Normal - self.assertRaises(TypeError, normal.kl_divergence, normal_other) - - def test_uniform_error(self): - paddle.enable_static() - uniform = Uniform(0.0, 1.0) - - value = [1.0, 2.0] - # type of value must be variable - self.assertRaises(TypeError, uniform.log_prob, value) - - value = [1.0, 2.0] - # type of value must be variable - self.assertRaises(TypeError, uniform.probs, value) - - shape = 1.0 - # type of shape must be list - self.assertRaises(TypeError, uniform.sample, shape) - - seed = 1.0 - # type of seed must be int - self.assertRaises(TypeError, uniform.sample, [2, 3], seed) - - def test_categorical_error(self): - paddle.enable_static() - - categorical = Categorical([0.4, 0.6]) - - value = [1, 0] - # type of value must be variable - self.assertRaises(AttributeError, categorical.log_prob, value) - - value = [1, 0] - # type of value must be variable - self.assertRaises(AttributeError, categorical.probs, value) - - shape = 1.0 - # type of shape must be list - self.assertRaises(TypeError, categorical.sample, shape) - - categorical_other = Uniform(1.0, 2.0) - # type of other must be an instance of Categorical - self.assertRaises(TypeError, categorical.kl_divergence, - categorical_other) - - def test_shape_not_match_error(): - # shape of value must match shape of logits - # value_shape[:-1] == logits_shape[:-1] - paddle.disable_static() - logits = paddle.rand([3, 5]) - cat = Categorical(logits) - value = paddle.to_tensor([[2, 1, 3], [3, 2, 1]], dtype='int64') - cat.log_prob(value) - - self.assertRaises(ValueError, test_shape_not_match_error) - - -class DistributionTestName(unittest.TestCase): - def get_prefix(self, string): - return (string.split('.')[0]) - - def test_normal_name(self): - name = 'test_normal' - normal1 = Normal(0.0, 1.0, name=name) - self.assertEqual(normal1.name, name) - - normal2 = Normal(0.0, 1.0) - self.assertEqual(normal2.name, 'Normal') - - paddle.enable_static() - - sample = normal1.sample([2]) - self.assertEqual(self.get_prefix(sample.name), name + '_sample') - - entropy = normal1.entropy() - self.assertEqual(self.get_prefix(entropy.name), name + '_entropy') - - value_npdata = np.array([0.8], dtype="float32") - value_tensor = layers.create_tensor(dtype="float32") - layers.assign(value_npdata, value_tensor) - - lp = normal1.log_prob(value_tensor) - self.assertEqual(self.get_prefix(lp.name), name + '_log_prob') - - p = normal1.probs(value_tensor) - self.assertEqual(self.get_prefix(p.name), name + '_probs') - - kl = normal1.kl_divergence(normal2) - self.assertEqual(self.get_prefix(kl.name), name + '_kl_divergence') - - def test_uniform_name(self): - name = 'test_uniform' - uniform1 = Uniform(0.0, 1.0, name=name) - self.assertEqual(uniform1.name, name) - - uniform2 = Uniform(0.0, 1.0) - self.assertEqual(uniform2.name, 'Uniform') - - paddle.enable_static() - - sample = uniform1.sample([2]) - self.assertEqual(self.get_prefix(sample.name), name + '_sample') - - entropy = uniform1.entropy() - self.assertEqual(self.get_prefix(entropy.name), name + '_entropy') - - value_npdata = np.array([0.8], dtype="float32") - value_tensor = layers.create_tensor(dtype="float32") - layers.assign(value_npdata, value_tensor) - - lp = uniform1.log_prob(value_tensor) - self.assertEqual(self.get_prefix(lp.name), name + '_log_prob') - - p = uniform1.probs(value_tensor) - self.assertEqual(self.get_prefix(p.name), name + '_probs') - - def test_categorical_name(self): - name = 'test_categorical' - categorical1 = Categorical([0.4, 0.6], name=name) - self.assertEqual(categorical1.name, name) - - categorical2 = Categorical([0.5, 0.5]) - self.assertEqual(categorical2.name, 'Categorical') - - paddle.enable_static() - - sample = categorical1.sample([2]) - self.assertEqual(self.get_prefix(sample.name), name + '_sample') - - entropy = categorical1.entropy() - self.assertEqual(self.get_prefix(entropy.name), name + '_entropy') - - kl = categorical1.kl_divergence(categorical2) - self.assertEqual(self.get_prefix(kl.name), name + '_kl_divergence') - - value_npdata = np.array([0], dtype="int64") - value_tensor = layers.create_tensor(dtype="int64") - layers.assign(value_npdata, value_tensor) - - p = categorical1.probs(value_tensor) - self.assertEqual(self.get_prefix(p.name), name + '_probs') - - lp = categorical1.log_prob(value_tensor) - self.assertEqual(self.get_prefix(lp.name), name + '_log_prob') - - -if __name__ == '__main__': - unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index e286b9cc735..f14111c7dab 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -276,6 +276,7 @@ packages=['paddle', 'paddle.incubate.tensor', 'paddle.incubate.nn', 'paddle.incubate.passes', + 'paddle.distribution', 'paddle.distributed.fleet', 'paddle.distributed.fleet.base', 'paddle.distributed.fleet.elastic', -- GitLab