Unverified commit a3e6f18c, authored by Xiaoxu Chen, committed by GitHub

move distribution.py into distribution package and split into different file for better scalability (#38047)
Parent 7da5368d
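For orientation, a sketch of the package layout this commit introduces, inferred from the imports and test changes in the diff below (exact paths are assumptions, not shown verbatim in this excerpt):

python/paddle/distribution/__init__.py      # re-exports Categorical, Distribution, Normal, Uniform
python/paddle/distribution/distribution.py  # Distribution base class
python/paddle/distribution/categorical.py   # Categorical
python/paddle/distribution/normal.py        # Normal
python/paddle/distribution/uniform.py       # Uniform
python/paddle/fluid/tests/unittests/distribution/  # tests split per distribution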
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .categorical import Categorical
from .distribution import Distribution
from .normal import Normal
from .uniform import Uniform
__all__ = ['Categorical', 'Distribution', 'Normal', 'Uniform']
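The public import path is unchanged by the split; a minimal usage sketch (dygraph mode, values arbitrary, expected numbers taken from the docstrings below):

import paddle
from paddle.distribution import Normal, Uniform

normal = Normal(loc=0., scale=1.)
uniform = Uniform(low=0., high=2.)
print(normal.sample([2]))   # two draws from N(0, 1)
print(uniform.entropy())    # log(2 - 0) = 0.6931472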
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,631 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: define the distribution functions
# __all__ = ['Categorical',
# 'MultivariateNormalDiag',
# 'Normal',
# 'sampling_id',
# 'Uniform']
from __future__ import print_function
from .fluid.layers import control_flow
from .fluid.layers import tensor
from .fluid.layers import ops
from .fluid.layers import nn
from .fluid.layers import elementwise_mul, elementwise_div, elementwise_add, elementwise_sub
from .fluid import core
from .fluid.framework import in_dygraph_mode
from .tensor import arange, gather_nd, concat, multinomial
import math
import numpy as np
import warnings
from .fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from paddle import _C_ops
__all__ = ['Distribution', 'Uniform', 'Normal', 'Categorical']
class Distribution(object):
"""
The abstract base class for probability distributions. Functions are
implemented in specific distributions.
"""
def __init__(self):
super(Distribution, self).__init__()
def sample(self):
"""Sampling from the distribution."""
raise NotImplementedError
def entropy(self):
"""The entropy of the distribution."""
raise NotImplementedError
def kl_divergence(self, other):
"""The KL-divergence between self distributions and other."""
raise NotImplementedError
def log_prob(self, value):
"""Log probability density/mass function."""
raise NotImplementedError
def probs(self, value):
"""Probability density/mass function."""
raise NotImplementedError
def _validate_args(self, *args):
"""
Argument validation for distribution args
Args:
value (float, list, numpy.ndarray, Tensor)
Raises
ValueError: if one argument is Tensor, all arguments should be Tensor
"""
is_variable = False
is_number = False
for arg in args:
if isinstance(arg, tensor.Variable):
is_variable = True
else:
is_number = True
if is_variable and is_number:
raise ValueError(
'if one argument is Tensor, all arguments should be Tensor')
return is_variable
def _to_tensor(self, *args):
"""
Argument convert args to Tensor
Args:
value (float, list, numpy.ndarray, Tensor)
Returns:
Tensor of args.
"""
numpy_args = []
variable_args = []
tmp = 0.
for arg in args:
if isinstance(arg, float):
arg = [arg]
if not isinstance(arg, (list, tuple, np.ndarray, tensor.Variable)):
raise TypeError(
"Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}".
format(type(arg)))
arg_np = np.array(arg)
arg_dtype = arg_np.dtype
if str(arg_dtype) != 'float32':
if str(arg_dtype) != 'float64':
# "assign" op doesn't support float64. if dtype is float64, float32 variable will be generated
# and converted to float64 later using "cast".
warnings.warn(
"data type of argument only support float32 and float64, your argument will be convert to float32."
)
arg_np = arg_np.astype('float32')
# tmp is used to support broadcast, it summarizes shapes of all the args and get the mixed shape.
tmp = tmp + arg_np
numpy_args.append(arg_np)
dtype = tmp.dtype
for arg in numpy_args:
arg_broadcasted, _ = np.broadcast_arrays(arg, tmp)
arg_variable = tensor.create_tensor(dtype=dtype)
tensor.assign(arg_broadcasted, arg_variable)
variable_args.append(arg_variable)
return tuple(variable_args)
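# Illustration (not part of the diff): the NumPy broadcasting trick used by
# _to_tensor above, with the same accumulator name `tmp`; the arrays are arbitrary.
import numpy as np
low = np.array([1.0, 2.0], dtype='float32')                  # shape (2,)
high = np.array([[3.0, 4.0], [5.0, 6.0]], dtype='float32')   # shape (2, 2)
tmp = 0.
for arg in (low, high):
    tmp = tmp + arg                         # accumulates the mixed (broadcast) shape
for arg in (low, high):
    broadcasted, _ = np.broadcast_arrays(arg, tmp)
    print(broadcasted.shape)                # (2, 2) for both arguments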
def _check_values_dtype_in_probs(self, param, value):
"""
Log_prob and probs methods have input ``value``, if value's dtype is different from param,
convert value's dtype to be consistent with param's dtype.
Args:
param (Tensor): low and high in Uniform class, loc and scale in Normal class.
value (Tensor): The input tensor.
Returns:
value (Tensor): Change value's dtype if value's dtype is different from param.
"""
if in_dygraph_mode():
if value.dtype != param.dtype and convert_dtype(
value.dtype) in ['float32', 'float64']:
warnings.warn(
"dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
)
return _C_ops.cast(value, 'in_dtype', value.dtype, 'out_dtype',
param.dtype)
return value
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
if value.dtype != param.dtype:
warnings.warn(
"dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
)
return tensor.cast(value, dtype=param.dtype)
return value
class Uniform(Distribution):
r"""Uniform distribution with `low` and `high` parameters.
Mathematical Details
The probability density function (pdf) is
.. math::
pdf(x; a, b) = \\frac{1}{Z}, \ a <=x <b
.. math::
Z = b - a
In the above equation:
* :math:`low = a`,
* :math:`high = b`,
* :math:`Z`: is the normalizing constant.
The parameters `low` and `high` must be shaped in a way that supports
[broadcasting](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/beginners_guide/basic_concept/broadcasting_en.html) (e.g., `high - low` is a valid operation).
Args:
low(int|float|list|tuple|numpy.ndarray|Tensor): The lower boundary of uniform distribution.The data type is int, float, list, numpy.ndarray or Tensor
high(int|float|list|tuple|numpy.ndarray|Tensor): The higher boundary of uniform distribution.The data type is int, float, list, numpy.ndarray or Tensor
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Uniform
# Without broadcasting, a single uniform distribution [3, 4]:
u1 = Uniform(low=3.0, high=4.0)
# 2 distributions [1, 3], [2, 4]
u2 = Uniform(low=[1.0, 2.0], high=[3.0, 4.0])
# 4 distributions
u3 = Uniform(low=[[1.0, 2.0], [3.0, 4.0]],
high=[[1.5, 2.5], [3.5, 4.5]])
# With broadcasting:
u4 = Uniform(low=3.0, high=[5.0, 6.0, 7.0])
# Complete example
value_tensor = paddle.to_tensor([0.8], dtype="float32")
uniform = Uniform([0.], [2.])
sample = uniform.sample([2])
# a random tensor created by uniform distribution with shape: [2, 1]
entropy = uniform.entropy()
# [0.6931472] with shape: [1]
lp = uniform.log_prob(value_tensor)
# [-0.6931472] with shape: [1]
p = uniform.probs(value_tensor)
# [0.5] with shape: [1]
"""
def __init__(self, low, high, name=None):
if not in_dygraph_mode():
check_type(low, 'low',
(int, float, np.ndarray, tensor.Variable, list, tuple),
'Uniform')
check_type(high, 'high',
(int, float, np.ndarray, tensor.Variable, list, tuple),
'Uniform')
self.all_arg_is_float = False
self.batch_size_unknown = False
self.name = name if name is not None else 'Uniform'
self.dtype = 'float32'
if isinstance(low, int):
low = float(low)
if isinstance(high, int):
high = float(high)
if self._validate_args(low, high):
self.batch_size_unknown = True
self.low = low
self.high = high
self.dtype = convert_dtype(low.dtype)
else:
if isinstance(low, float) and isinstance(high, float):
self.all_arg_is_float = True
if isinstance(
low,
np.ndarray) and str(low.dtype) in ['float32', 'float64']:
self.dtype = low.dtype
elif isinstance(
high,
np.ndarray) and str(high.dtype) in ['float32', 'float64']:
self.dtype = high.dtype
self.low, self.high = self._to_tensor(low, high)
if self.dtype != convert_dtype(self.low.dtype):
self.low = tensor.cast(self.low, dtype=self.dtype)
self.high = tensor.cast(self.high, dtype=self.dtype)
def sample(self, shape, seed=0):
"""Generate samples of the specified shape.
Args:
shape (list): 1D `int32`. Shape of the generated samples.
seed (int): Python integer number.
Returns:
Tensor: A tensor with prepended dimensions shape.The data type is float32.
"""
if not in_dygraph_mode():
check_type(shape, 'shape', (list), 'sample')
check_type(seed, 'seed', (int), 'sample')
name = self.name + '_sample'
batch_shape = list((self.low + self.high).shape)
if self.batch_size_unknown:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.low + self.high, batch_shape + shape, self.dtype, 0.)
uniform_random_tmp = nn.uniform_random_batch_size_like(
zero_tmp,
zero_tmp.shape,
dtype=self.dtype,
min=0.,
max=1.,
seed=seed)
zero_tmp_reshape = nn.reshape(zero_tmp, output_shape)
uniform_random_tmp_reshape = nn.reshape(uniform_random_tmp,
output_shape)
output = uniform_random_tmp_reshape * (
zero_tmp_reshape + self.high - self.low)
output = elementwise_add(output, self.low, name=name)
return output
else:
output_shape = shape + batch_shape
output = nn.uniform_random(
output_shape, dtype=self.dtype, min=0., max=1.,
seed=seed) * (tensor.zeros(
output_shape, dtype=self.dtype) + (self.high - self.low))
output = elementwise_add(output, self.low, name=name)
if self.all_arg_is_float:
return nn.reshape(output, shape, name=name)
else:
return output
def log_prob(self, value):
"""Log probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: log probability.The data type is same with value.
"""
value = self._check_values_dtype_in_probs(self.low, value)
if in_dygraph_mode():
# ensure value in [low, high]
lb_bool = self.low < value
ub_bool = value < self.high
lb = _C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
value.dtype)
ub = _C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
value.dtype)
return nn.log(lb * ub) - nn.log(self.high - self.low)
name = self.name + '_log_prob'
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_sub(
nn.log(lb * ub), nn.log(self.high - self.low), name=name)
def probs(self, value):
"""Probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: probability.The data type is same with value.
"""
value = self._check_values_dtype_in_probs(self.low, value)
if in_dygraph_mode():
lb_bool = self.low < value
ub_bool = value < self.high
lb = _C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
value.dtype)
ub = _C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
value.dtype)
return (lb * ub) / (self.high - self.low)
name = self.name + '_probs'
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_div((lb * ub), (self.high - self.low), name=name)
def entropy(self):
r"""Shannon entropy in nats.
The entropy is
.. math::
entropy(low, high) = \\log (high - low)
Returns:
Tensor: Shannon entropy of uniform distribution.The data type is float32.
"""
name = self.name + '_entropy'
return nn.log(self.high - self.low, name=name)
class Normal(Distribution):
r"""The Normal distribution with location `loc` and `scale` parameters.
Mathematical details
The probability density function (pdf) is
.. math::
pdf(x; \mu, \sigma) = \\frac{1}{Z}e^{\\frac {-0.5 (x - \mu)^2} {\sigma^2} }
.. math::
Z = (2 \pi \sigma^2)^{0.5}
In the above equation:
* :math:`loc = \mu`: is the mean.
* :math:`scale = \sigma`: is the std.
* :math:`Z`: is the normalization constant.
Args:
loc(int|float|list|tuple|numpy.ndarray|Tensor): The mean of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor.
scale(int|float|list|tuple|numpy.ndarray|Tensor): The std of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Normal
# Define a single scalar Normal distribution.
dist = Normal(loc=0., scale=3.)
# Define a batch of two scalar valued Normals.
# The first has mean 1 and standard deviation 11, the second 2 and 22.
dist = Normal(loc=[1., 2.], scale=[11., 22.])
# Get 3 samples, returning a 3 x 2 tensor.
dist.sample([3])
# Define a batch of two scalar valued Normals.
# Both have mean 1, but different standard deviations.
dist = Normal(loc=1., scale=[11., 22.])
# Complete example
value_tensor = paddle.to_tensor([0.8], dtype="float32")
normal_a = Normal([0.], [1.])
normal_b = Normal([0.5], [2.])
sample = normal_a.sample([2])
# a random tensor created by normal distribution with shape: [2, 1]
entropy = normal_a.entropy()
# [1.4189385] with shape: [1]
lp = normal_a.log_prob(value_tensor)
# [-1.2389386] with shape: [1]
p = normal_a.probs(value_tensor)
# [0.28969154] with shape: [1]
kl = normal_a.kl_divergence(normal_b)
# [0.34939718] with shape: [1]
"""
def __init__(self, loc, scale, name=None):
if not in_dygraph_mode():
check_type(loc, 'loc',
(int, float, np.ndarray, tensor.Variable, list, tuple),
'Normal')
check_type(scale, 'scale',
(int, float, np.ndarray, tensor.Variable, list, tuple),
'Normal')
self.batch_size_unknown = False
self.all_arg_is_float = False
self.name = name if name is not None else 'Normal'
self.dtype = 'float32'
if isinstance(loc, int):
loc = float(loc)
if isinstance(scale, int):
scale = float(scale)
if self._validate_args(loc, scale):
self.batch_size_unknown = True
self.loc = loc
self.scale = scale
self.dtype = convert_dtype(loc.dtype)
else:
if isinstance(loc, float) and isinstance(scale, float):
self.all_arg_is_float = True
if isinstance(
loc,
np.ndarray) and str(loc.dtype) in ['float32', 'float64']:
self.dtype = loc.dtype
elif isinstance(
scale,
np.ndarray) and str(scale.dtype) in ['float32', 'float64']:
self.dtype = scale.dtype
self.loc, self.scale = self._to_tensor(loc, scale)
if self.dtype != convert_dtype(self.loc.dtype):
self.loc = tensor.cast(self.loc, dtype=self.dtype)
self.scale = tensor.cast(self.scale, dtype=self.dtype)
def sample(self, shape, seed=0):
"""Generate samples of the specified shape.
Args:
shape (list): 1D `int32`. Shape of the generated samples.
seed (int): Python integer number.
Returns:
Tensor: A tensor with prepended dimensions shape.The data type is float32.
"""
if not in_dygraph_mode():
check_type(shape, 'shape', (list), 'sample')
check_type(seed, 'seed', (int), 'sample')
batch_shape = list((self.loc + self.scale).shape)
name = self.name + '_sample'
if self.batch_size_unknown:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape + shape, self.dtype, 0.)
zero_tmp_reshape = nn.reshape(zero_tmp, output_shape)
zero_tmp_shape = nn.shape(zero_tmp_reshape)
normal_random_tmp = nn.gaussian_random(
zero_tmp_shape, mean=0., std=1., seed=seed, dtype=self.dtype)
output = normal_random_tmp * (zero_tmp_reshape + self.scale)
output = elementwise_add(output, self.loc, name=name)
return output
else:
output_shape = shape + batch_shape
output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * \
(tensor.zeros(output_shape, dtype=self.dtype) + self.scale)
output = elementwise_add(output, self.loc, name=name)
if self.all_arg_is_float:
return nn.reshape(output, shape, name=name)
else:
return output
def entropy(self):
r"""Shannon entropy in nats.
The entropy is
.. math::
entropy(\sigma) = 0.5 \\log (2 \pi e \sigma^2)
In the above equation:
* :math:`scale = \sigma`: is the std.
Returns:
Tensor: Shannon entropy of normal distribution.The data type is float32.
"""
name = self.name + '_entropy'
batch_shape = list((self.loc + self.scale).shape)
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape, self.dtype, 0.)
return elementwise_add(
0.5 + zero_tmp,
0.5 * math.log(2 * math.pi) + nn.log((self.scale + zero_tmp)),
name=name)
def log_prob(self, value):
"""Log probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: log probability.The data type is same with value.
"""
name = self.name + '_log_prob'
value = self._check_values_dtype_in_probs(self.loc, value)
var = self.scale * self.scale
log_scale = nn.log(self.scale)
return elementwise_sub(
-1. * ((value - self.loc) * (value - self.loc)) / (2. * var),
log_scale + math.log(math.sqrt(2. * math.pi)),
name=name)
def probs(self, value):
"""Probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: probability.The data type is same with value.
"""
name = self.name + '_probs'
value = self._check_values_dtype_in_probs(self.loc, value)
var = self.scale * self.scale
return elementwise_div(
ops.exp(-1. * ((value - self.loc) * (value - self.loc)) /
(2. * var)), (math.sqrt(2 * math.pi) * self.scale),
name=name)
def kl_divergence(self, other):
r"""The KL-divergence between two normal distributions.
The probability density function (pdf) is
.. math::
KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\\frac{diff}{\sigma_1})^2 - 1 - 2 \\ln {ratio})
.. math::
ratio = \\frac{\sigma_0}{\sigma_1}
.. math::
diff = \mu_1 - \mu_0
In the above equation:
* :math:`loc = \mu_0`: is the mean of current Normal distribution.
* :math:`scale = \sigma_0`: is the std of current Normal distribution.
* :math:`loc = \mu_1`: is the mean of other Normal distribution.
* :math:`scale = \sigma_1`: is the std of other Normal distribution.
* :math:`ratio`: is the ratio of scales.
* :math:`diff`: is the difference between means.
Args:
other (Normal): instance of Normal.
Returns:
Tensor: kl-divergence between two normal distributions.The data type is float32.
"""
if not in_dygraph_mode():
check_type(other, 'other', Normal, 'kl_divergence')
name = self.name + '_kl_divergence'
var_ratio = self.scale / other.scale
var_ratio = (var_ratio * var_ratio)
t1 = (self.loc - other.loc) / other.scale
t1 = (t1 * t1)
return elementwise_add(
0.5 * var_ratio, 0.5 * (t1 - 1. - nn.log(var_ratio)), name=name)
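# Cross-check (not part of the diff): the closed-form numbers quoted in the Normal
# docstring above, reproduced with plain math; loc/scale/value are that example's
# values (normal_a = Normal([0.], [1.]), normal_b = Normal([0.5], [2.]), value = 0.8).
import math
loc_a, scale_a = 0.0, 1.0
loc_b, scale_b = 0.5, 2.0
value = 0.8
entropy = 0.5 * math.log(2 * math.pi * math.e * scale_a ** 2)            # 1.4189385
log_prob = (-((value - loc_a) ** 2) / (2 * scale_a ** 2)
            - math.log(scale_a) - 0.5 * math.log(2 * math.pi))           # -1.2389385
ratio = scale_a / scale_b
kl = 0.5 * (ratio ** 2 + ((loc_a - loc_b) / scale_b) ** 2
            - 1.0 - 2.0 * math.log(ratio))                               # 0.3493972
print(entropy, log_prob, kl)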
class Categorical(Distribution):
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: define the distribution functions
# __all__ = ['Categorical',
# 'MultivariateNormalDiag',
# 'Normal',
# 'sampling_id',
# 'Uniform']
from __future__ import print_function
import math
import warnings
import numpy as np
from paddle import _C_ops
from ..fluid import core
from ..fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from ..fluid.framework import in_dygraph_mode
from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops, tensor)
from ..tensor import arange, concat, gather_nd, multinomial
class Distribution(object):
"""
The abstract base class for probability distributions. Functions are
implemented in specific distributions.
"""
def __init__(self):
super(Distribution, self).__init__()
def sample(self):
"""Sampling from the distribution."""
raise NotImplementedError
def entropy(self):
"""The entropy of the distribution."""
raise NotImplementedError
def kl_divergence(self, other):
"""The KL-divergence between self distributions and other."""
raise NotImplementedError
def log_prob(self, value):
"""Log probability density/mass function."""
raise NotImplementedError
def probs(self, value):
"""Probability density/mass function."""
raise NotImplementedError
def _validate_args(self, *args):
"""
Argument validation for distribution args
Args:
value (float, list, numpy.ndarray, Tensor)
Raises
ValueError: if one argument is Tensor, all arguments should be Tensor
"""
is_variable = False
is_number = False
for arg in args:
if isinstance(arg, tensor.Variable):
is_variable = True
else:
is_number = True
if is_variable and is_number:
raise ValueError(
'if one argument is Tensor, all arguments should be Tensor')
return is_variable
def _to_tensor(self, *args):
"""
Argument convert args to Tensor
Args:
value (float, list, numpy.ndarray, Tensor)
Returns:
Tensor of args.
"""
numpy_args = []
variable_args = []
tmp = 0.
for arg in args:
if isinstance(arg, float):
arg = [arg]
if not isinstance(arg, (list, tuple, np.ndarray, tensor.Variable)):
raise TypeError(
"Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}".
format(type(arg)))
arg_np = np.array(arg)
arg_dtype = arg_np.dtype
if str(arg_dtype) != 'float32':
if str(arg_dtype) != 'float64':
# "assign" op doesn't support float64. if dtype is float64, float32 variable will be generated
# and converted to float64 later using "cast".
warnings.warn(
"data type of argument only support float32 and float64, your argument will be convert to float32."
)
arg_np = arg_np.astype('float32')
# tmp is used to support broadcast, it summarizes shapes of all the args and get the mixed shape.
tmp = tmp + arg_np
numpy_args.append(arg_np)
dtype = tmp.dtype
for arg in numpy_args:
arg_broadcasted, _ = np.broadcast_arrays(arg, tmp)
arg_variable = tensor.create_tensor(dtype=dtype)
tensor.assign(arg_broadcasted, arg_variable)
variable_args.append(arg_variable)
return tuple(variable_args)
def _check_values_dtype_in_probs(self, param, value):
"""
Log_prob and probs methods have input ``value``, if value's dtype is different from param,
convert value's dtype to be consistent with param's dtype.
Args:
param (Tensor): low and high in Uniform class, loc and scale in Normal class.
value (Tensor): The input tensor.
Returns:
value (Tensor): Change value's dtype if value's dtype is different from param.
"""
if in_dygraph_mode():
if value.dtype != param.dtype and convert_dtype(
value.dtype) in ['float32', 'float64']:
warnings.warn(
"dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
)
return _C_ops.cast(value, 'in_dtype', value.dtype, 'out_dtype',
param.dtype)
return value
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
if value.dtype != param.dtype:
warnings.warn(
"dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
)
return tensor.cast(value, dtype=param.dtype)
return value
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
import numpy as np
from paddle import _C_ops
from ..fluid import core
from ..fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from ..fluid.framework import in_dygraph_mode
from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops, tensor)
from ..tensor import arange, concat, gather_nd, multinomial
from .distribution import Distribution
class Normal(Distribution):
r"""The Normal distribution with location `loc` and `scale` parameters.
Mathematical details
The probability density function (pdf) is
.. math::
pdf(x; \mu, \sigma) = \\frac{1}{Z}e^{\\frac {-0.5 (x - \mu)^2} {\sigma^2} }
.. math::
Z = (2 \pi \sigma^2)^{0.5}
In the above equation:
* :math:`loc = \mu`: is the mean.
* :math:`scale = \sigma`: is the std.
* :math:`Z`: is the normalization constant.
Args:
loc(int|float|list|tuple|numpy.ndarray|Tensor): The mean of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor.
scale(int|float|list|tuple|numpy.ndarray|Tensor): The std of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Normal
# Define a single scalar Normal distribution.
dist = Normal(loc=0., scale=3.)
# Define a batch of two scalar valued Normals.
# The first has mean 1 and standard deviation 11, the second 2 and 22.
dist = Normal(loc=[1., 2.], scale=[11., 22.])
# Get 3 samples, returning a 3 x 2 tensor.
dist.sample([3])
# Define a batch of two scalar valued Normals.
# Both have mean 1, but different standard deviations.
dist = Normal(loc=1., scale=[11., 22.])
# Complete example
value_tensor = paddle.to_tensor([0.8], dtype="float32")
normal_a = Normal([0.], [1.])
normal_b = Normal([0.5], [2.])
sample = normal_a.sample([2])
# a random tensor created by normal distribution with shape: [2, 1]
entropy = normal_a.entropy()
# [1.4189385] with shape: [1]
lp = normal_a.log_prob(value_tensor)
# [-1.2389386] with shape: [1]
p = normal_a.probs(value_tensor)
# [0.28969154] with shape: [1]
kl = normal_a.kl_divergence(normal_b)
# [0.34939718] with shape: [1]
"""
def __init__(self, loc, scale, name=None):
if not in_dygraph_mode():
check_type(loc, 'loc',
(int, float, np.ndarray, tensor.Variable, list, tuple),
'Normal')
check_type(scale, 'scale',
(int, float, np.ndarray, tensor.Variable, list, tuple),
'Normal')
self.batch_size_unknown = False
self.all_arg_is_float = False
self.name = name if name is not None else 'Normal'
self.dtype = 'float32'
if isinstance(loc, int):
loc = float(loc)
if isinstance(scale, int):
scale = float(scale)
if self._validate_args(loc, scale):
self.batch_size_unknown = True
self.loc = loc
self.scale = scale
self.dtype = convert_dtype(loc.dtype)
else:
if isinstance(loc, float) and isinstance(scale, float):
self.all_arg_is_float = True
if isinstance(
loc,
np.ndarray) and str(loc.dtype) in ['float32', 'float64']:
self.dtype = loc.dtype
elif isinstance(
scale,
np.ndarray) and str(scale.dtype) in ['float32', 'float64']:
self.dtype = scale.dtype
# pylint: disable=unbalanced-tuple-unpacking
self.loc, self.scale = self._to_tensor(loc, scale)
if self.dtype != convert_dtype(self.loc.dtype):
self.loc = tensor.cast(self.loc, dtype=self.dtype)
self.scale = tensor.cast(self.scale, dtype=self.dtype)
def sample(self, shape, seed=0):
"""Generate samples of the specified shape.
Args:
shape (list): 1D `int32`. Shape of the generated samples.
seed (int): Python integer number.
Returns:
Tensor: A tensor with prepended dimensions shape.The data type is float32.
"""
if not in_dygraph_mode():
check_type(shape, 'shape', (list), 'sample')
check_type(seed, 'seed', (int), 'sample')
batch_shape = list((self.loc + self.scale).shape)
name = self.name + '_sample'
if self.batch_size_unknown:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape + shape, self.dtype, 0.)
zero_tmp_reshape = nn.reshape(zero_tmp, output_shape)
zero_tmp_shape = nn.shape(zero_tmp_reshape)
normal_random_tmp = nn.gaussian_random(
zero_tmp_shape, mean=0., std=1., seed=seed, dtype=self.dtype)
output = normal_random_tmp * (zero_tmp_reshape + self.scale)
output = elementwise_add(output, self.loc, name=name)
return output
else:
output_shape = shape + batch_shape
output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * \
(tensor.zeros(output_shape, dtype=self.dtype) + self.scale)
output = elementwise_add(output, self.loc, name=name)
if self.all_arg_is_float:
return nn.reshape(output, shape, name=name)
else:
return output
def entropy(self):
r"""Shannon entropy in nats.
The entropy is
.. math::
entropy(\sigma) = 0.5 \\log (2 \pi e \sigma^2)
In the above equation:
* :math:`scale = \sigma`: is the std.
Returns:
Tensor: Shannon entropy of normal distribution.The data type is float32.
"""
name = self.name + '_entropy'
batch_shape = list((self.loc + self.scale).shape)
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape, self.dtype, 0.)
return elementwise_add(
0.5 + zero_tmp,
0.5 * math.log(2 * math.pi) + nn.log((self.scale + zero_tmp)),
name=name)
def log_prob(self, value):
"""Log probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: log probability.The data type is same with value.
"""
name = self.name + '_log_prob'
value = self._check_values_dtype_in_probs(self.loc, value)
var = self.scale * self.scale
log_scale = nn.log(self.scale)
return elementwise_sub(
-1. * ((value - self.loc) * (value - self.loc)) / (2. * var),
log_scale + math.log(math.sqrt(2. * math.pi)),
name=name)
def probs(self, value):
"""Probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: probability.The data type is same with value.
"""
name = self.name + '_probs'
value = self._check_values_dtype_in_probs(self.loc, value)
var = self.scale * self.scale
return elementwise_div(
ops.exp(-1. * ((value - self.loc) * (value - self.loc)) /
(2. * var)), (math.sqrt(2 * math.pi) * self.scale),
name=name)
def kl_divergence(self, other):
r"""The KL-divergence between two normal distributions.
The probability density function (pdf) is
.. math::
KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\\frac{diff}{\sigma_1})^2 - 1 - 2 \\ln {ratio})
.. math::
ratio = \\frac{\sigma_0}{\sigma_1}
.. math::
diff = \mu_1 - \mu_0
In the above equation:
* :math:`loc = \mu_0`: is the mean of current Normal distribution.
* :math:`scale = \sigma_0`: is the std of current Normal distribution.
* :math:`loc = \mu_1`: is the mean of other Normal distribution.
* :math:`scale = \sigma_1`: is the std of other Normal distribution.
* :math:`ratio`: is the ratio of scales.
* :math:`diff`: is the difference between means.
Args:
other (Normal): instance of Normal.
Returns:
Tensor: kl-divergence between two normal distributions.The data type is float32.
"""
if not in_dygraph_mode():
check_type(other, 'other', Normal, 'kl_divergence')
name = self.name + '_kl_divergence'
var_ratio = self.scale / other.scale
var_ratio = (var_ratio * var_ratio)
t1 = (self.loc - other.loc) / other.scale
t1 = (t1 * t1)
return elementwise_add(
0.5 * var_ratio, 0.5 * (t1 - 1. - nn.log(var_ratio)), name=name)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
import numpy as np
from paddle import _C_ops
from ..fluid import core
from ..fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from ..fluid.framework import in_dygraph_mode
from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops, tensor)
from ..tensor import arange, concat, gather_nd, multinomial
from .distribution import Distribution
class Uniform(Distribution):
r"""Uniform distribution with `low` and `high` parameters.
Mathematical Details
The probability density function (pdf) is
.. math::
pdf(x; a, b) = \\frac{1}{Z}, \ a <=x <b
.. math::
Z = b - a
In the above equation:
* :math:`low = a`,
* :math:`high = b`,
* :math:`Z`: is the normalizing constant.
The parameters `low` and `high` must be shaped in a way that supports
[broadcasting](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/beginners_guide/basic_concept/broadcasting_en.html) (e.g., `high - low` is a valid operation).
Args:
low(int|float|list|tuple|numpy.ndarray|Tensor): The lower boundary of uniform distribution.The data type is int, float, list, numpy.ndarray or Tensor
high(int|float|list|tuple|numpy.ndarray|Tensor): The higher boundary of uniform distribution.The data type is int, float, list, numpy.ndarray or Tensor
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Uniform
# Without broadcasting, a single uniform distribution [3, 4]:
u1 = Uniform(low=3.0, high=4.0)
# 2 distributions [1, 3], [2, 4]
u2 = Uniform(low=[1.0, 2.0], high=[3.0, 4.0])
# 4 distributions
u3 = Uniform(low=[[1.0, 2.0], [3.0, 4.0]],
high=[[1.5, 2.5], [3.5, 4.5]])
# With broadcasting:
u4 = Uniform(low=3.0, high=[5.0, 6.0, 7.0])
# Complete example
value_tensor = paddle.to_tensor([0.8], dtype="float32")
uniform = Uniform([0.], [2.])
sample = uniform.sample([2])
# a random tensor created by uniform distribution with shape: [2, 1]
entropy = uniform.entropy()
# [0.6931472] with shape: [1]
lp = uniform.log_prob(value_tensor)
# [-0.6931472] with shape: [1]
p = uniform.probs(value_tensor)
# [0.5] with shape: [1]
"""
def __init__(self, low, high, name=None):
if not in_dygraph_mode():
check_type(low, 'low',
(int, float, np.ndarray, tensor.Variable, list, tuple),
'Uniform')
check_type(high, 'high',
(int, float, np.ndarray, tensor.Variable, list, tuple),
'Uniform')
self.all_arg_is_float = False
self.batch_size_unknown = False
self.name = name if name is not None else 'Uniform'
self.dtype = 'float32'
if isinstance(low, int):
low = float(low)
if isinstance(high, int):
high = float(high)
if self._validate_args(low, high):
self.batch_size_unknown = True
self.low = low
self.high = high
self.dtype = convert_dtype(low.dtype)
else:
if isinstance(low, float) and isinstance(high, float):
self.all_arg_is_float = True
if isinstance(
low,
np.ndarray) and str(low.dtype) in ['float32', 'float64']:
self.dtype = low.dtype
elif isinstance(
high,
np.ndarray) and str(high.dtype) in ['float32', 'float64']:
self.dtype = high.dtype
# pylint: disable=unbalanced-tuple-unpacking
self.low, self.high = self._to_tensor(low, high)
if self.dtype != convert_dtype(self.low.dtype):
self.low = tensor.cast(self.low, dtype=self.dtype)
self.high = tensor.cast(self.high, dtype=self.dtype)
def sample(self, shape, seed=0):
"""Generate samples of the specified shape.
Args:
shape (list): 1D `int32`. Shape of the generated samples.
seed (int): Python integer number.
Returns:
Tensor: A tensor with prepended dimensions shape.The data type is float32.
"""
if not in_dygraph_mode():
check_type(shape, 'shape', (list), 'sample')
check_type(seed, 'seed', (int), 'sample')
name = self.name + '_sample'
batch_shape = list((self.low + self.high).shape)
if self.batch_size_unknown:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.low + self.high, batch_shape + shape, self.dtype, 0.)
uniform_random_tmp = nn.uniform_random_batch_size_like(
zero_tmp,
zero_tmp.shape,
dtype=self.dtype,
min=0.,
max=1.,
seed=seed)
zero_tmp_reshape = nn.reshape(zero_tmp, output_shape)
uniform_random_tmp_reshape = nn.reshape(uniform_random_tmp,
output_shape)
output = uniform_random_tmp_reshape * (
zero_tmp_reshape + self.high - self.low)
output = elementwise_add(output, self.low, name=name)
return output
else:
output_shape = shape + batch_shape
output = nn.uniform_random(
output_shape, dtype=self.dtype, min=0., max=1.,
seed=seed) * (tensor.zeros(
output_shape, dtype=self.dtype) + (self.high - self.low))
output = elementwise_add(output, self.low, name=name)
if self.all_arg_is_float:
return nn.reshape(output, shape, name=name)
else:
return output
def log_prob(self, value):
"""Log probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: log probability.The data type is same with value.
"""
value = self._check_values_dtype_in_probs(self.low, value)
if in_dygraph_mode():
# ensure value in [low, high]
lb_bool = self.low < value
ub_bool = value < self.high
lb = _C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
value.dtype)
ub = _C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
value.dtype)
return nn.log(lb * ub) - nn.log(self.high - self.low)
name = self.name + '_log_prob'
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_sub(
nn.log(lb * ub), nn.log(self.high - self.low), name=name)
def probs(self, value):
"""Probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: probability.The data type is same with value.
"""
value = self._check_values_dtype_in_probs(self.low, value)
if in_dygraph_mode():
lb_bool = self.low < value
ub_bool = value < self.high
lb = _C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
value.dtype)
ub = _C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
value.dtype)
return (lb * ub) / (self.high - self.low)
name = self.name + '_probs'
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_div((lb * ub), (self.high - self.low), name=name)
def entropy(self):
r"""Shannon entropy in nats.
The entropy is
.. math::
entropy(low, high) = \\log (high - low)
Returns:
Tensor: Shannon entropy of uniform distribution.The data type is float32.
"""
name = self.name + '_entropy'
return nn.log(self.high - self.low, name=name)
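# Cross-check (not part of the diff): the Uniform docstring numbers above,
# using low = 0., high = 2., value = 0.8 from that example.
import math
low, high, value = 0.0, 2.0, 0.8
entropy = math.log(high - low)                      # 0.6931472
log_prob = math.log(1.0) - math.log(high - low)     # -0.6931472, since low < value < high
prob = 1.0 / (high - low)                           # 0.5
print(entropy, log_prob, prob)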
@@ -772,6 +772,7 @@ add_subdirectory(sequence)
add_subdirectory(dygraph_to_static)
add_subdirectory(rnn)
add_subdirectory(autograd)
add_subdirectory(distribution)
if (NOT WIN32 OR NOT WITH_GPU)
add_subdirectory(fft)
...
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest
import numpy as np
import paddle
from paddle import fluid
from paddle.distribution import *
from paddle.fluid import layers
paddle.enable_static()
class DistributionNumpy():
def sample(self):
raise NotImplementedError
def entropy(self):
raise NotImplementedError
def kl_divergence(self, other):
raise NotImplementedError
def log_prob(self, value):
raise NotImplementedError
def probs(self, value):
raise NotImplementedError
class DistributionTestName(unittest.TestCase):
def get_prefix(self, string):
return (string.split('.')[0])
def test_normal_name(self):
name = 'test_normal'
normal1 = Normal(0.0, 1.0, name=name)
self.assertEqual(normal1.name, name)
normal2 = Normal(0.0, 1.0)
self.assertEqual(normal2.name, 'Normal')
paddle.enable_static()
sample = normal1.sample([2])
self.assertEqual(self.get_prefix(sample.name), name + '_sample')
entropy = normal1.entropy()
self.assertEqual(self.get_prefix(entropy.name), name + '_entropy')
value_npdata = np.array([0.8], dtype="float32")
value_tensor = layers.create_tensor(dtype="float32")
layers.assign(value_npdata, value_tensor)
lp = normal1.log_prob(value_tensor)
self.assertEqual(self.get_prefix(lp.name), name + '_log_prob')
p = normal1.probs(value_tensor)
self.assertEqual(self.get_prefix(p.name), name + '_probs')
kl = normal1.kl_divergence(normal2)
self.assertEqual(self.get_prefix(kl.name), name + '_kl_divergence')
def test_uniform_name(self):
name = 'test_uniform'
uniform1 = Uniform(0.0, 1.0, name=name)
self.assertEqual(uniform1.name, name)
uniform2 = Uniform(0.0, 1.0)
self.assertEqual(uniform2.name, 'Uniform')
paddle.enable_static()
sample = uniform1.sample([2])
self.assertEqual(self.get_prefix(sample.name), name + '_sample')
entropy = uniform1.entropy()
self.assertEqual(self.get_prefix(entropy.name), name + '_entropy')
value_npdata = np.array([0.8], dtype="float32")
value_tensor = layers.create_tensor(dtype="float32")
layers.assign(value_npdata, value_tensor)
lp = uniform1.log_prob(value_tensor)
self.assertEqual(self.get_prefix(lp.name), name + '_log_prob')
p = uniform1.probs(value_tensor)
self.assertEqual(self.get_prefix(p.name), name + '_probs')
def test_categorical_name(self):
name = 'test_categorical'
categorical1 = Categorical([0.4, 0.6], name=name)
self.assertEqual(categorical1.name, name)
categorical2 = Categorical([0.5, 0.5])
self.assertEqual(categorical2.name, 'Categorical')
paddle.enable_static()
sample = categorical1.sample([2])
self.assertEqual(self.get_prefix(sample.name), name + '_sample')
entropy = categorical1.entropy()
self.assertEqual(self.get_prefix(entropy.name), name + '_entropy')
kl = categorical1.kl_divergence(categorical2)
self.assertEqual(self.get_prefix(kl.name), name + '_kl_divergence')
value_npdata = np.array([0], dtype="int64")
value_tensor = layers.create_tensor(dtype="int64")
layers.assign(value_npdata, value_tensor)
p = categorical1.probs(value_tensor)
self.assertEqual(self.get_prefix(p.name), name + '_probs')
lp = categorical1.log_prob(value_tensor)
self.assertEqual(self.get_prefix(lp.name), name + '_log_prob')
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest
import numpy as np
import paddle
from paddle import fluid
from paddle.distribution import *
from paddle.fluid import layers
from test_distribution import DistributionNumpy
class CategoricalNumpy(DistributionNumpy):
def __init__(self, logits):
self.logits = np.array(logits).astype('float32')
def entropy(self):
logits = self.logits - np.max(self.logits, axis=-1, keepdims=True)
e_logits = np.exp(logits)
z = np.sum(e_logits, axis=-1, keepdims=True)
prob = e_logits / z
return -1. * np.sum(prob * (logits - np.log(z)), axis=-1, keepdims=True)
def kl_divergence(self, other):
logits = self.logits - np.max(self.logits, axis=-1, keepdims=True)
other_logits = other.logits - np.max(
other.logits, axis=-1, keepdims=True)
e_logits = np.exp(logits)
other_e_logits = np.exp(other_logits)
z = np.sum(e_logits, axis=-1, keepdims=True)
other_z = np.sum(other_e_logits, axis=-1, keepdims=True)
prob = e_logits / z
return np.sum(prob *
(logits - np.log(z) - other_logits + np.log(other_z)),
axis=-1,
keepdims=True)
class CategoricalTest(unittest.TestCase):
def setUp(self, use_gpu=False, batch_size=3, dims=5):
self.use_gpu = use_gpu
if not use_gpu:
self.place = fluid.CPUPlace()
self.gpu_id = -1
else:
self.place = fluid.CUDAPlace(0)
self.gpu_id = 0
self.batch_size = batch_size
self.dims = dims
self.init_numpy_data(batch_size, dims)
paddle.disable_static(self.place)
self.init_dynamic_data(batch_size, dims)
paddle.enable_static()
self.test_program = fluid.Program()
self.executor = fluid.Executor(self.place)
self.init_static_data(batch_size, dims)
def init_numpy_data(self, batch_size, dims):
# input logits is 2-D Tensor
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(batch_size, dims).astype('float32')
self.other_logits_np = np.random.rand(batch_size,
dims).astype('float32')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [batch_size, dims]
# dist_shape = logits_shape[:-1], it represents the number of
# different distributions.
self.dist_shape = [batch_size]
# sample shape represents the number of samples
self.sample_shape = [2, 4]
# value used in probs and log_prob method
# If value is 1-D and logits is 2-D or higher dimension, value will be
# broadcast to have the same number of distributions as logits.
# If value is 2-D or higher dimension, it should have the same number
# of distributions as logits: ``value.shape[:-1] == logits.shape[:-1]``
self.value_shape = [3]
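# Shape convention illustrated (not part of the diff), using the defaults
# batch_size=3, dims=5 from setUp:
#   logits shape         [3, 5]  -> dist_shape = [3] (one categorical per row)
#   value shape          [3]     -> broadcast to every distribution
#   probs(value) shape   dist_shape + value_shape = [3, 3]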
def init_dynamic_data(self, batch_size, dims):
self.logits = paddle.to_tensor(self.logits_np)
self.other_logits = paddle.to_tensor(self.other_logits_np)
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = fluid.data(
name='logits', shape=self.logits_shape, dtype='float32')
self.other_logits_static = fluid.data(
name='other_logits', shape=self.logits_shape, dtype='float32')
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.dist_shape + self.value_shape)
for i in range(self.batch_size):
for j in range(3):
np_probs[i][j] = probability[i][self.value_np[j]]
return np_probs
def compare_with_numpy(self, fetch_list, tolerance=1e-6):
sample, entropy, kl, probs, log_prob = fetch_list
log_tolerance = 1e-4
np.testing.assert_equal(sample.shape,
self.sample_shape + self.dist_shape)
np_categorical = CategoricalNumpy(self.logits_np)
np_other_categorical = CategoricalNumpy(self.other_logits_np)
np_entropy = np_categorical.entropy()
np_kl = np_categorical.kl_divergence(np_other_categorical)
np.testing.assert_allclose(
entropy, np_entropy, rtol=log_tolerance, atol=log_tolerance)
np.testing.assert_allclose(
kl, np_kl, rtol=log_tolerance, atol=log_tolerance)
sum_dist = np.sum(self.logits_np, axis=-1, keepdims=True)
probability = self.logits_np / sum_dist
np_probs = self.get_numpy_selected_probs(probability)
np_log_prob = np.log(np_probs)
np.testing.assert_allclose(
probs, np_probs, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
log_prob, np_log_prob, rtol=tolerance, atol=tolerance)
def test_categorical_distribution_dygraph(self, tolerance=1e-6):
paddle.disable_static(self.place)
categorical = Categorical(self.logits)
other_categorical = Categorical(self.other_logits)
sample = categorical.sample(self.sample_shape).numpy()
entropy = categorical.entropy().numpy()
kl = categorical.kl_divergence(other_categorical).numpy()
probs = categorical.probs(self.value).numpy()
log_prob = categorical.log_prob(self.value).numpy()
fetch_list = [sample, entropy, kl, probs, log_prob]
self.compare_with_numpy(fetch_list)
def test_categorical_distribution_static(self, tolerance=1e-6):
paddle.enable_static()
with fluid.program_guard(self.test_program):
categorical = Categorical(self.logits_static)
other_categorical = Categorical(self.other_logits_static)
sample = categorical.sample(self.sample_shape)
entropy = categorical.entropy()
kl = categorical.kl_divergence(other_categorical)
probs = categorical.probs(self.value_static)
log_prob = categorical.log_prob(self.value_static)
fetch_list = [sample, entropy, kl, probs, log_prob]
feed_vars = {
'logits': self.logits_np,
'other_logits': self.other_logits_np,
'value': self.value_np
}
self.executor.run(fluid.default_startup_program())
fetch_list = self.executor.run(program=self.test_program,
feed=feed_vars,
fetch_list=fetch_list)
self.compare_with_numpy(fetch_list)
class CategoricalTest2(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logits is 2-D Tensor with dtype Float64
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(batch_size, dims).astype('float64')
self.other_logits_np = np.random.rand(batch_size,
dims).astype('float64')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [batch_size, dims]
self.dist_shape = [batch_size]
self.sample_shape = [2, 4]
self.value_shape = [3]
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = fluid.data(
name='logits', shape=self.logits_shape, dtype='float64')
self.other_logits_static = fluid.data(
name='other_logits', shape=self.logits_shape, dtype='float64')
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
class CategoricalTest3(CategoricalTest):
def init_dynamic_data(self, batch_size, dims):
# input logits is 2-D numpy.ndarray with dtype Float32
# value used in probs and log_prob method is 1-D Tensor
self.logits = self.logits_np
self.other_logits = self.other_logits_np
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = self.logits_np
self.other_logits_static = self.other_logits_np
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
class CategoricalTest4(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logits is 2-D numpy.ndarray with dtype Float64
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(batch_size, dims).astype('float64')
self.other_logits_np = np.random.rand(batch_size,
dims).astype('float64')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [batch_size, dims]
self.dist_shape = [batch_size]
self.sample_shape = [2, 4]
self.value_shape = [3]
def init_dynamic_data(self, batch_size, dims):
self.logits = self.logits_np
self.other_logits = self.other_logits_np
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = self.logits_np
self.other_logits_static = self.other_logits_np
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
# test shape of logits and value used in probs and log_prob method
class CategoricalTest5(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logits is 1-D Tensor
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(dims).astype('float32')
self.other_logits_np = np.random.rand(dims).astype('float32')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [dims]
self.dist_shape = []
self.sample_shape = [2, 4]
self.value_shape = [3]
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.value_shape)
for i in range(3):
np_probs[i] = probability[self.value_np[i]]
return np_probs
class CategoricalTest6(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logits is 2-D Tensor
# value used in probs and log_prob method has the same number of batches as the input
self.logits_np = np.random.rand(3, 5).astype('float32')
self.other_logits_np = np.random.rand(3, 5).astype('float32')
self.value_np = np.array([[2, 1], [0, 3], [2, 3]]).astype('int64')
self.logits_shape = [3, 5]
self.dist_shape = [3]
self.sample_shape = [2, 4]
self.value_shape = [3, 2]
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.value_shape)
for i in range(3):
for j in range(2):
np_probs[i][j] = probability[i][self.value_np[i][j]]
return np_probs
class CategoricalTest7(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logits is 3-D Tensor
# value used in probs and log_prob method has the same number of distributions as the input
self.logits_np = np.random.rand(3, 2, 5).astype('float32')
self.other_logits_np = np.random.rand(3, 2, 5).astype('float32')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [3, 2, 5]
self.dist_shape = [3, 2]
self.sample_shape = [2, 4]
self.value_shape = [3]
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.dist_shape + self.value_shape)
for i in range(3):
for j in range(2):
for k in range(3):
np_probs[i][j][k] = probability[i][j][self.value_np[k]]
return np_probs
class CategoricalTest8(CategoricalTest):
def init_dynamic_data(self, batch_size, dims):
# input logits is 2-D list
# value used in probs and log_prob method is 1-D Tensor
self.logits = self.logits_np.tolist()
self.other_logits = self.other_logits_np.tolist()
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = self.logits_np.tolist()
self.other_logits_static = self.other_logits_np.tolist()
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
class CategoricalTest9(CategoricalTest):
def init_dynamic_data(self, batch_size, dims):
# input logits is 2-D tuple
# value used in probs and log_prob method is 1-D Tensor
self.logits = tuple(self.logits_np.tolist())
self.other_logits = tuple(self.other_logits_np.tolist())
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = tuple(self.logits_np.tolist())
self.other_logits_static = tuple(self.other_logits_np.tolist())
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
class DistributionTestError(unittest.TestCase):
def test_distribution_error(self):
distribution = Distribution()
self.assertRaises(NotImplementedError, distribution.sample)
self.assertRaises(NotImplementedError, distribution.entropy)
normal = Normal(0.0, 1.0)
self.assertRaises(NotImplementedError, distribution.kl_divergence,
normal)
value_npdata = np.array([0.8], dtype="float32")
value_tensor = layers.create_tensor(dtype="float32")
self.assertRaises(NotImplementedError, distribution.log_prob,
value_tensor)
self.assertRaises(NotImplementedError, distribution.probs, value_tensor)
def test_normal_error(self):
paddle.enable_static()
normal = Normal(0.0, 1.0)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, normal.log_prob, value)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, normal.probs, value)
shape = 1.0
# type of shape must be list
self.assertRaises(TypeError, normal.sample, shape)
seed = 1.0
# type of seed must be int
self.assertRaises(TypeError, normal.sample, [2, 3], seed)
normal_other = Uniform(1.0, 2.0)
# type of other must be an instance of Normal
self.assertRaises(TypeError, normal.kl_divergence, normal_other)
def test_uniform_error(self):
paddle.enable_static()
uniform = Uniform(0.0, 1.0)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, uniform.log_prob, value)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, uniform.probs, value)
shape = 1.0
# type of shape must be list
self.assertRaises(TypeError, uniform.sample, shape)
seed = 1.0
# type of seed must be int
self.assertRaises(TypeError, uniform.sample, [2, 3], seed)
def test_categorical_error(self):
paddle.enable_static()
categorical = Categorical([0.4, 0.6])
value = [1, 0]
# type of value must be variable
self.assertRaises(AttributeError, categorical.log_prob, value)
value = [1, 0]
# type of value must be variable
self.assertRaises(AttributeError, categorical.probs, value)
shape = 1.0
# type of shape must be list
self.assertRaises(TypeError, categorical.sample, shape)
categorical_other = Uniform(1.0, 2.0)
# type of other must be an instance of Categorical
self.assertRaises(TypeError, categorical.kl_divergence,
categorical_other)
def test_shape_not_match_error():
# shape of value must match shape of logits
# value_shape[:-1] == logits_shape[:-1]
paddle.disable_static()
logits = paddle.rand([3, 5])
cat = Categorical(logits)
value = paddle.to_tensor([[2, 1, 3], [3, 2, 1]], dtype='int64')
cat.log_prob(value)
self.assertRaises(ValueError, test_shape_not_match_error)
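The failing case above violates the shape contract for Categorical: every leading dimension of value must match the leading dimensions of logits, and only the trailing dimension (categories for logits, query entries for value) may differ. A minimal dygraph sketch of a valid call, assuming the same public API these tests exercise (values are illustrative only):

import paddle
from paddle.distribution import Categorical

paddle.disable_static()
logits = paddle.rand([3, 5])                 # 3 distributions over 5 categories
cat = Categorical(logits)

# value.shape[:-1] == logits.shape[:-1], so this call is accepted
value = paddle.to_tensor([[2, 1], [0, 3], [4, 0]], dtype='int64')
print(cat.log_prob(value).shape)             # [3, 2]
# a [2, 3] value, as in test_shape_not_match_error, raises ValueError instead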
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -12,351 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest
import numpy as np
import paddle
from paddle import fluid
from paddle.distribution import *
from paddle.fluid import layers
class DistributionNumpy():
def sample(self):
raise NotImplementedError
def entropy(self):
raise NotImplementedError
def kl_divergence(self, other):
raise NotImplementedError
def log_prob(self, value):
raise NotImplementedError
def probs(self, value):
raise NotImplementedError
class UniformNumpy(DistributionNumpy):
def __init__(self, low, high):
self.low = np.array(low)
self.high = np.array(high)
if str(self.low.dtype) not in ['float32', 'float64']:
self.low = self.low.astype('float32')
self.high = self.high.astype('float32')
def sample(self, shape):
shape = tuple(shape) + (self.low + self.high).shape
return self.low + (np.random.uniform(size=shape) *
(self.high - self.low))
def log_prob(self, value):
lb = np.less(self.low, value).astype(self.low.dtype)
ub = np.less(value, self.high).astype(self.low.dtype)
return np.log(lb * ub) - np.log(self.high - self.low)
def probs(self, value):
lb = np.less(self.low, value).astype(self.low.dtype)
ub = np.less(value, self.high).astype(self.low.dtype)
return (lb * ub) / (self.high - self.low)
def entropy(self):
return np.log(self.high - self.low)
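For reference, the closed-form expressions this NumPy class encodes are p(x) = 1 / (high - low) inside the support and entropy = log(high - low). A hedged sketch checking them against paddle.distribution.Uniform in dygraph mode (the literal values are illustrative, not part of the test suite):

import numpy as np
import paddle
from paddle.distribution import Uniform

paddle.disable_static()
low, high = 1.0, 3.0
uniform = Uniform(low, high)
value = paddle.to_tensor([2.0])              # strictly inside (low, high)

np.testing.assert_allclose(uniform.probs(value).numpy(), 1.0 / (high - low), rtol=1e-6)
np.testing.assert_allclose(uniform.log_prob(value).numpy(), -np.log(high - low), rtol=1e-6)
np.testing.assert_allclose(uniform.entropy().numpy(), np.log(high - low), rtol=1e-6)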
class UniformTest(unittest.TestCase):
def setUp(self, use_gpu=False, batch_size=5, dims=6):
self.use_gpu = use_gpu
if not use_gpu:
self.place = fluid.CPUPlace()
self.gpu_id = -1
else:
self.place = fluid.CUDAPlace(0)
self.gpu_id = 0
self.init_numpy_data(batch_size, dims)
paddle.disable_static(self.place)
self.init_dynamic_data(batch_size, dims)
paddle.enable_static()
self.test_program = fluid.Program()
self.executor = fluid.Executor(self.place)
self.init_static_data(batch_size, dims)
def init_numpy_data(self, batch_size, dims):
# low and high are 'float'
self.low_np = np.random.uniform(-2, 1)
self.high_np = np.random.uniform(2, 4)
self.values_np = np.array([1.0]).astype('float32')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = self.low_np
self.dynamic_high = self.high_np
self.dynamic_values = paddle.to_tensor(self.values_np)
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[], dtype='float32')
def compare_with_numpy(self, fetch_list, sample_shape=7, tolerance=1e-6):
sample, entropy, log_prob, probs = fetch_list
np_uniform = UniformNumpy(self.low_np, self.high_np)
np_sample = np_uniform.sample([sample_shape])
np_entropy = np_uniform.entropy()
np_lp = np_uniform.log_prob(self.values_np)
np_p = np_uniform.probs(self.values_np)
np.testing.assert_equal(sample.shape, np_sample.shape)
np.testing.assert_allclose(
entropy, np_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
log_prob, np_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(probs, np_p, rtol=tolerance, atol=tolerance)
def test_uniform_distribution_dygraph(self, sample_shape=7, tolerance=1e-6):
paddle.disable_static(self.place)
uniform = Uniform(self.dynamic_low, self.dynamic_high)
sample = uniform.sample([sample_shape]).numpy()
entropy = uniform.entropy().numpy()
log_prob = uniform.log_prob(self.dynamic_values).numpy()
probs = uniform.probs(self.dynamic_values).numpy()
fetch_list = [sample, entropy, log_prob, probs]
self.compare_with_numpy(fetch_list)
def test_uniform_distribution_static(self, sample_shape=7, tolerance=1e-6):
paddle.enable_static()
with fluid.program_guard(self.test_program):
uniform = Uniform(self.static_low, self.static_high)
sample = uniform.sample([sample_shape])
entropy = uniform.entropy()
log_prob = uniform.log_prob(self.static_values)
probs = uniform.probs(self.static_values)
fetch_list = [sample, entropy, log_prob, probs]
feed_vars = {
'low': self.low_np,
'high': self.high_np,
'values': self.values_np
}
self.executor.run(fluid.default_startup_program())
fetch_list = self.executor.run(program=self.test_program,
feed=feed_vars,
fetch_list=fetch_list)
self.compare_with_numpy(fetch_list)
class UniformTest2(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are 'int'
self.low_np = int(np.random.uniform(-2, 1))
self.high_np = int(np.random.uniform(2, 4))
self.values_np = np.array([1.0]).astype('float32')
class UniformTest3(UniformTest):
def init_numpy_data(self, batch_size, dims):
# test broadcast: low is float, high is numpy.ndarray with dtype 'float32'.
self.low_np = np.random.uniform(-2, 1)
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest4(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are numpy.ndarray with dtype 'float32'.
self.low_np = np.random.randn(batch_size, dims).astype('float32')
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest5(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are numpy.ndarray with dtype 'float64'.
self.low_np = np.random.randn(batch_size, dims).astype('float64')
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float64')
self.values_np = np.random.randn(batch_size, dims).astype('float64')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = self.low_np
self.dynamic_high = self.high_np
self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float64')
class UniformTest6(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are Tensor with dtype 'VarType.FP32'.
self.low_np = np.random.randn(batch_size, dims).astype('float32')
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = paddle.to_tensor(self.low_np)
self.dynamic_high = paddle.to_tensor(self.high_np)
self.dynamic_values = paddle.to_tensor(self.values_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.static_low = layers.data(
name='low', shape=[dims], dtype='float32')
self.static_high = layers.data(
name='high', shape=[dims], dtype='float32')
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest7(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are Tensor with dtype 'VarType.FP64'.
self.low_np = np.random.randn(batch_size, dims).astype('float64')
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float64')
self.values_np = np.random.randn(batch_size, dims).astype('float64')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = paddle.to_tensor(self.low_np, dtype='float64')
self.dynamic_high = paddle.to_tensor(self.high_np, dtype='float64')
self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64')
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.static_low = layers.data(
name='low', shape=[dims], dtype='float64')
self.static_high = layers.data(
name='high', shape=[dims], dtype='float64')
self.static_values = layers.data(
name='values', shape=[dims], dtype='float64')
class UniformTest8(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are Tensor with dtype 'VarType.FP64'. value's dtype is 'VarType.FP32'.
self.low_np = np.random.randn(batch_size, dims).astype('float64')
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float64')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = paddle.to_tensor(self.low_np, dtype='float64')
self.dynamic_high = paddle.to_tensor(self.high_np, dtype='float64')
self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float32')
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.static_low = layers.data(
name='low', shape=[dims], dtype='float64')
self.static_high = layers.data(
name='high', shape=[dims], dtype='float64')
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest9(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are numpy.ndarray with dtype 'float32'.
# high < low.
self.low_np = np.random.randn(batch_size, dims).astype('float32')
self.high_np = np.random.uniform(-10.0, -5.0,
(batch_size, dims)).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest10(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are list.
self.low_np = np.random.randn(batch_size,
dims).astype('float32').tolist()
self.high_np = np.random.uniform(
5.0, 15.0, (batch_size, dims)).astype('float32').tolist()
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest11(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are tuple.
self.low_np = tuple(
np.random.randn(batch_size, dims).astype('float32').tolist())
self.high_np = tuple(
np.random.uniform(5.0, 15.0, (batch_size, dims)).astype('float32')
.tolist())
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTestSample(unittest.TestCase):
def setUp(self):
self.init_param()
def init_param(self):
self.low = 3.0
self.high = 4.0
def test_uniform_sample(self):
paddle.disable_static()
uniform = Uniform(low=self.low, high=self.high)
s = uniform.sample([100])
self.assertTrue((s >= self.low).all())
self.assertTrue((s < self.high).all())
paddle.enable_static()
class UniformTestSample2(UniformTestSample):
def init_param(self):
self.low = -5.0
self.high = 2.0
from test_distribution import DistributionNumpy
class NormalNumpy(DistributionNumpy):
...@@ -789,511 +454,3 @@ class NormalTest10(NormalTest):
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class CategoricalNumpy(DistributionNumpy):
def __init__(self, logits):
self.logits = np.array(logits).astype('float32')
def entropy(self):
logits = self.logits - np.max(self.logits, axis=-1, keepdims=True)
e_logits = np.exp(logits)
z = np.sum(e_logits, axis=-1, keepdims=True)
prob = e_logits / z
return -1. * np.sum(prob * (logits - np.log(z)), axis=-1, keepdims=True)
def kl_divergence(self, other):
logits = self.logits - np.max(self.logits, axis=-1, keepdims=True)
other_logits = other.logits - np.max(
other.logits, axis=-1, keepdims=True)
e_logits = np.exp(logits)
other_e_logits = np.exp(other_logits)
z = np.sum(e_logits, axis=-1, keepdims=True)
other_z = np.sum(other_e_logits, axis=-1, keepdims=True)
prob = e_logits / z
return np.sum(prob * (logits - np.log(z) - other_logits \
+ np.log(other_z)), axis=-1, keepdims=True)
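CategoricalNumpy stabilizes the softmax by subtracting the per-row maximum before exponentiating; the shift leaves the probabilities and the entropy unchanged and only guards against overflow. A small NumPy sketch (illustrative values, not part of the test suite) confirming the shifted and unshifted computations agree:

import numpy as np

logits = np.random.rand(3, 5).astype('float32')

# plain softmax entropy: H = -sum_i p_i * log(p_i)
p = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
h_plain = -(p * np.log(p)).sum(-1, keepdims=True)

# max-shifted version, mirroring CategoricalNumpy.entropy above
shifted = logits - logits.max(-1, keepdims=True)
z = np.exp(shifted).sum(-1, keepdims=True)
h_stable = -((np.exp(shifted) / z) * (shifted - np.log(z))).sum(-1, keepdims=True)

np.testing.assert_allclose(h_plain, h_stable, rtol=1e-5)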
class CategoricalTest(unittest.TestCase):
def setUp(self, use_gpu=False, batch_size=3, dims=5):
self.use_gpu = use_gpu
if not use_gpu:
self.place = fluid.CPUPlace()
self.gpu_id = -1
else:
self.place = fluid.CUDAPlace(0)
self.gpu_id = 0
self.batch_size = batch_size
self.dims = dims
self.init_numpy_data(batch_size, dims)
paddle.disable_static(self.place)
self.init_dynamic_data(batch_size, dims)
paddle.enable_static()
self.test_program = fluid.Program()
self.executor = fluid.Executor(self.place)
self.init_static_data(batch_size, dims)
def init_numpy_data(self, batch_size, dims):
# input logits is 2-D Tensor
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(batch_size, dims).astype('float32')
self.other_logits_np = np.random.rand(batch_size,
dims).astype('float32')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [batch_size, dims]
# dist_shape = logits_shape[:-1]; it represents the number of
# different distributions.
self.dist_shape = [batch_size]
# sample shape represents the number of samples
self.sample_shape = [2, 4]
# value used in probs and log_prob method
# If value is 1-D and logits is 2-D or higher dimensional, value will be
# broadcast to have the same number of distributions as logits.
# If value is 2-D or higher dimensional, it should have the same number
# of distributions as logits: ``value.shape[:-1] == logits.shape[:-1]``
# (a standalone sketch follows this test class)
self.value_shape = [3]
def init_dynamic_data(self, batch_size, dims):
self.logits = paddle.to_tensor(self.logits_np)
self.other_logits = paddle.to_tensor(self.other_logits_np)
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = fluid.data(
name='logits', shape=self.logits_shape, dtype='float32')
self.other_logits_static = fluid.data(
name='other_logits', shape=self.logits_shape, dtype='float32')
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.dist_shape + self.value_shape)
for i in range(self.batch_size):
for j in range(3):
np_probs[i][j] = probability[i][self.value_np[j]]
return np_probs
def compare_with_numpy(self, fetch_list, tolerance=1e-6):
sample, entropy, kl, probs, log_prob = fetch_list
log_tolerance = 1e-4
np.testing.assert_equal(sample.shape,
self.sample_shape + self.dist_shape)
np_categorical = CategoricalNumpy(self.logits_np)
np_other_categorical = CategoricalNumpy(self.other_logits_np)
np_entropy = np_categorical.entropy()
np_kl = np_categorical.kl_divergence(np_other_categorical)
np.testing.assert_allclose(
entropy, np_entropy, rtol=log_tolerance, atol=log_tolerance)
np.testing.assert_allclose(
kl, np_kl, rtol=log_tolerance, atol=log_tolerance)
sum_dist = np.sum(self.logits_np, axis=-1, keepdims=True)
probability = self.logits_np / sum_dist
np_probs = self.get_numpy_selected_probs(probability)
np_log_prob = np.log(np_probs)
np.testing.assert_allclose(
probs, np_probs, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
log_prob, np_log_prob, rtol=tolerance, atol=tolerance)
def test_categorical_distribution_dygraph(self, tolerance=1e-6):
paddle.disable_static(self.place)
categorical = Categorical(self.logits)
other_categorical = Categorical(self.other_logits)
sample = categorical.sample(self.sample_shape).numpy()
entropy = categorical.entropy().numpy()
kl = categorical.kl_divergence(other_categorical).numpy()
probs = categorical.probs(self.value).numpy()
log_prob = categorical.log_prob(self.value).numpy()
fetch_list = [sample, entropy, kl, probs, log_prob]
self.compare_with_numpy(fetch_list)
def test_categorical_distribution_static(self, tolerance=1e-6):
paddle.enable_static()
with fluid.program_guard(self.test_program):
categorical = Categorical(self.logits_static)
other_categorical = Categorical(self.other_logits_static)
sample = categorical.sample(self.sample_shape)
entropy = categorical.entropy()
kl = categorical.kl_divergence(other_categorical)
probs = categorical.probs(self.value_static)
log_prob = categorical.log_prob(self.value_static)
fetch_list = [sample, entropy, kl, probs, log_prob]
feed_vars = {
'logits': self.logits_np,
'other_logits': self.other_logits_np,
'value': self.value_np
}
self.executor.run(fluid.default_startup_program())
fetch_list = self.executor.run(program=self.test_program,
feed=feed_vars,
fetch_list=fetch_list)
self.compare_with_numpy(fetch_list)
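To make the broadcast rule from init_numpy_data concrete: per the reference computation in compare_with_numpy, probs normalizes logits by their sum along the last axis, and a 1-D value is looked up in every distribution, giving an output of shape dist_shape + value_shape. A hedged dygraph sketch of that behavior (illustrative values only):

import numpy as np
import paddle
from paddle.distribution import Categorical

paddle.disable_static()
logits_np = np.random.rand(3, 5).astype('float32')      # 3 distributions, 5 categories
cat = Categorical(paddle.to_tensor(logits_np))

value = paddle.to_tensor([2, 1, 3], dtype='int64')       # 1-D value, broadcast to each distribution
probs = cat.probs(value).numpy()                         # shape [3, 3]

# the same selection done by hand, mirroring get_numpy_selected_probs above
expected = (logits_np / logits_np.sum(-1, keepdims=True))[:, [2, 1, 3]]
np.testing.assert_allclose(probs, expected, rtol=1e-6)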
class CategoricalTest2(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logits is 2-D Tensor with dtype Float64
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(batch_size, dims).astype('float64')
self.other_logits_np = np.random.rand(batch_size,
dims).astype('float64')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [batch_size, dims]
self.dist_shape = [batch_size]
self.sample_shape = [2, 4]
self.value_shape = [3]
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = fluid.data(
name='logits', shape=self.logits_shape, dtype='float64')
self.other_logits_static = fluid.data(
name='other_logits', shape=self.logits_shape, dtype='float64')
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
class CategoricalTest3(CategoricalTest):
def init_dynamic_data(self, batch_size, dims):
# input logits is 2-D numpy.ndarray with dtype Float32
# value used in probs and log_prob method is 1-D Tensor
self.logits = self.logits_np
self.other_logits = self.other_logits_np
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = self.logits_np
self.other_logits_static = self.other_logits_np
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
class CategoricalTest4(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logits is 2-D numpy.ndarray with dtype Float64
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(batch_size, dims).astype('float64')
self.other_logits_np = np.random.rand(batch_size,
dims).astype('float64')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [batch_size, dims]
self.dist_shape = [batch_size]
self.sample_shape = [2, 4]
self.value_shape = [3]
def init_dynamic_data(self, batch_size, dims):
self.logits = self.logits_np
self.other_logits = self.other_logits_np
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = self.logits_np
self.other_logits_static = self.other_logits_np
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
# test shapes of logits and of the value used in the probs and log_prob methods
class CategoricalTest5(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logits is 1-D Tensor
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(dims).astype('float32')
self.other_logits_np = np.random.rand(dims).astype('float32')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [dims]
self.dist_shape = []
self.sample_shape = [2, 4]
self.value_shape = [3]
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.value_shape)
for i in range(3):
np_probs[i] = probability[self.value_np[i]]
return np_probs
class CategoricalTest6(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logits is 2-D Tensor
# value used in probs and log_prob method has the same number of batches as the input
self.logits_np = np.random.rand(3, 5).astype('float32')
self.other_logits_np = np.random.rand(3, 5).astype('float32')
self.value_np = np.array([[2, 1], [0, 3], [2, 3]]).astype('int64')
self.logits_shape = [3, 5]
self.dist_shape = [3]
self.sample_shape = [2, 4]
self.value_shape = [3, 2]
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.value_shape)
for i in range(3):
for j in range(2):
np_probs[i][j] = probability[i][self.value_np[i][j]]
return np_probs
class CategoricalTest7(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logits is 3-D Tensor
# value used in probs and log_prob method has the same number of distributions as the input
self.logits_np = np.random.rand(3, 2, 5).astype('float32')
self.other_logits_np = np.random.rand(3, 2, 5).astype('float32')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [3, 2, 5]
self.dist_shape = [3, 2]
self.sample_shape = [2, 4]
self.value_shape = [3]
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.dist_shape + self.value_shape)
for i in range(3):
for j in range(2):
for k in range(3):
np_probs[i][j][k] = probability[i][j][self.value_np[k]]
return np_probs
class CategoricalTest8(CategoricalTest):
def init_dynamic_data(self, batch_size, dims):
# input logits is 2-D list
# value used in probs and log_prob method is 1-D Tensor
self.logits = self.logits_np.tolist()
self.other_logits = self.other_logits_np.tolist()
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = self.logits_np.tolist()
self.other_logits_static = self.other_logits_np.tolist()
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
class CategoricalTest9(CategoricalTest):
def init_dynamic_data(self, batch_size, dims):
# input logits is 2-D tuple
# value used in probs and log_prob method is 1-D Tensor
self.logits = tuple(self.logits_np.tolist())
self.other_logits = tuple(self.other_logits_np.tolist())
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = tuple(self.logits_np.tolist())
self.other_logits_static = tuple(self.other_logits_np.tolist())
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
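CategoricalTest3/4/8/9 only vary how logits are supplied: in dygraph mode the constructor is exercised with a Tensor, a numpy.ndarray, a nested list, and a tuple. A hedged sketch of that flexibility (values are illustrative only):

import numpy as np
import paddle
from paddle.distribution import Categorical

paddle.disable_static()
logits_np = np.random.rand(3, 5).astype('float32')

for logits in (paddle.to_tensor(logits_np),   # Tensor
               logits_np,                     # numpy.ndarray
               logits_np.tolist(),            # nested list
               tuple(logits_np.tolist())):    # tuple
    print(Categorical(logits).entropy())      # same result for every input form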
class DistributionTestError(unittest.TestCase):
def test_distribution_error(self):
distribution = Distribution()
self.assertRaises(NotImplementedError, distribution.sample)
self.assertRaises(NotImplementedError, distribution.entropy)
normal = Normal(0.0, 1.0)
self.assertRaises(NotImplementedError, distribution.kl_divergence,
normal)
value_npdata = np.array([0.8], dtype="float32")
value_tensor = layers.create_tensor(dtype="float32")
self.assertRaises(NotImplementedError, distribution.log_prob,
value_tensor)
self.assertRaises(NotImplementedError, distribution.probs, value_tensor)
def test_normal_error(self):
paddle.enable_static()
normal = Normal(0.0, 1.0)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, normal.log_prob, value)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, normal.probs, value)
shape = 1.0
# type of shape must be list
self.assertRaises(TypeError, normal.sample, shape)
seed = 1.0
# type of seed must be int
self.assertRaises(TypeError, normal.sample, [2, 3], seed)
normal_other = Uniform(1.0, 2.0)
# type of other must be an instance of Normal
self.assertRaises(TypeError, normal.kl_divergence, normal_other)
def test_uniform_error(self):
paddle.enable_static()
uniform = Uniform(0.0, 1.0)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, uniform.log_prob, value)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, uniform.probs, value)
shape = 1.0
# type of shape must be list
self.assertRaises(TypeError, uniform.sample, shape)
seed = 1.0
# type of seed must be int
self.assertRaises(TypeError, uniform.sample, [2, 3], seed)
def test_categorical_error(self):
paddle.enable_static()
categorical = Categorical([0.4, 0.6])
value = [1, 0]
# type of value must be variable
self.assertRaises(AttributeError, categorical.log_prob, value)
value = [1, 0]
# type of value must be variable
self.assertRaises(AttributeError, categorical.probs, value)
shape = 1.0
# type of shape must be list
self.assertRaises(TypeError, categorical.sample, shape)
categorical_other = Uniform(1.0, 2.0)
# type of other must be an instance of Categorical
self.assertRaises(TypeError, categorical.kl_divergence,
categorical_other)
def test_shape_not_match_error():
# shape of value must match shape of logits
# value_shape[:-1] == logits_shape[:-1]
paddle.disable_static()
logits = paddle.rand([3, 5])
cat = Categorical(logits)
value = paddle.to_tensor([[2, 1, 3], [3, 2, 1]], dtype='int64')
cat.log_prob(value)
self.assertRaises(ValueError, test_shape_not_match_error)
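The error cases above (run in static mode) pin down the expected argument types: value must already be a Tensor/Variable rather than a plain Python list, sample's shape must be a list, and seed must be an int. A hedged dygraph sketch of the accepted forms, assuming the same sample(shape, seed) signature these tests rely on:

import numpy as np
import paddle
from paddle.distribution import Normal

paddle.disable_static()
normal = Normal(0.0, 1.0)

value = paddle.to_tensor(np.array([1.0, 2.0], dtype='float32'))  # a Tensor, not a Python list
print(normal.log_prob(value))
print(normal.sample([2, 3], seed=1))   # shape is a list, seed is an int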
class DistributionTestName(unittest.TestCase):
def get_prefix(self, string):
return (string.split('.')[0])
def test_normal_name(self):
name = 'test_normal'
normal1 = Normal(0.0, 1.0, name=name)
self.assertEqual(normal1.name, name)
normal2 = Normal(0.0, 1.0)
self.assertEqual(normal2.name, 'Normal')
paddle.enable_static()
sample = normal1.sample([2])
self.assertEqual(self.get_prefix(sample.name), name + '_sample')
entropy = normal1.entropy()
self.assertEqual(self.get_prefix(entropy.name), name + '_entropy')
value_npdata = np.array([0.8], dtype="float32")
value_tensor = layers.create_tensor(dtype="float32")
layers.assign(value_npdata, value_tensor)
lp = normal1.log_prob(value_tensor)
self.assertEqual(self.get_prefix(lp.name), name + '_log_prob')
p = normal1.probs(value_tensor)
self.assertEqual(self.get_prefix(p.name), name + '_probs')
kl = normal1.kl_divergence(normal2)
self.assertEqual(self.get_prefix(kl.name), name + '_kl_divergence')
def test_uniform_name(self):
name = 'test_uniform'
uniform1 = Uniform(0.0, 1.0, name=name)
self.assertEqual(uniform1.name, name)
uniform2 = Uniform(0.0, 1.0)
self.assertEqual(uniform2.name, 'Uniform')
paddle.enable_static()
sample = uniform1.sample([2])
self.assertEqual(self.get_prefix(sample.name), name + '_sample')
entropy = uniform1.entropy()
self.assertEqual(self.get_prefix(entropy.name), name + '_entropy')
value_npdata = np.array([0.8], dtype="float32")
value_tensor = layers.create_tensor(dtype="float32")
layers.assign(value_npdata, value_tensor)
lp = uniform1.log_prob(value_tensor)
self.assertEqual(self.get_prefix(lp.name), name + '_log_prob')
p = uniform1.probs(value_tensor)
self.assertEqual(self.get_prefix(p.name), name + '_probs')
def test_categorical_name(self):
name = 'test_categorical'
categorical1 = Categorical([0.4, 0.6], name=name)
self.assertEqual(categorical1.name, name)
categorical2 = Categorical([0.5, 0.5])
self.assertEqual(categorical2.name, 'Categorical')
paddle.enable_static()
sample = categorical1.sample([2])
self.assertEqual(self.get_prefix(sample.name), name + '_sample')
entropy = categorical1.entropy()
self.assertEqual(self.get_prefix(entropy.name), name + '_entropy')
kl = categorical1.kl_divergence(categorical2)
self.assertEqual(self.get_prefix(kl.name), name + '_kl_divergence')
value_npdata = np.array([0], dtype="int64")
value_tensor = layers.create_tensor(dtype="int64")
layers.assign(value_npdata, value_tensor)
p = categorical1.probs(value_tensor)
self.assertEqual(self.get_prefix(p.name), name + '_probs')
lp = categorical1.log_prob(value_tensor)
self.assertEqual(self.get_prefix(lp.name), name + '_log_prob')
if __name__ == '__main__':
unittest.main()
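DistributionTestName verifies a naming convention rather than numerics: under the static graph, output variables inherit the name passed to the distribution as a prefix ('<name>_sample', '<name>_entropy', and so on, before the first '.'). A brief hedged sketch of what that looks like:

import paddle
from paddle.distribution import Normal

paddle.enable_static()
normal = Normal(0.0, 1.0, name='test_normal')
sample = normal.sample([2])
# e.g. 'test_normal_sample.tmp_0'; only the prefix before the '.' is asserted by the tests
print(sample.name)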
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest
import numpy as np
import paddle
from paddle import fluid
from paddle.distribution import *
from paddle.fluid import layers
from test_distribution import DistributionNumpy
class UniformNumpy(DistributionNumpy):
def __init__(self, low, high):
self.low = np.array(low)
self.high = np.array(high)
if str(self.low.dtype) not in ['float32', 'float64']:
self.low = self.low.astype('float32')
self.high = self.high.astype('float32')
def sample(self, shape):
shape = tuple(shape) + (self.low + self.high).shape
return self.low + (np.random.uniform(size=shape) *
(self.high - self.low))
def log_prob(self, value):
lb = np.less(self.low, value).astype(self.low.dtype)
ub = np.less(value, self.high).astype(self.low.dtype)
return np.log(lb * ub) - np.log(self.high - self.low)
def probs(self, value):
lb = np.less(self.low, value).astype(self.low.dtype)
ub = np.less(value, self.high).astype(self.low.dtype)
return (lb * ub) / (self.high - self.low)
def entropy(self):
return np.log(self.high - self.low)
class UniformTest(unittest.TestCase):
def setUp(self, use_gpu=False, batch_size=5, dims=6):
self.use_gpu = use_gpu
if not use_gpu:
self.place = fluid.CPUPlace()
self.gpu_id = -1
else:
self.place = fluid.CUDAPlace(0)
self.gpu_id = 0
self.init_numpy_data(batch_size, dims)
paddle.disable_static(self.place)
self.init_dynamic_data(batch_size, dims)
paddle.enable_static()
self.test_program = fluid.Program()
self.executor = fluid.Executor(self.place)
self.init_static_data(batch_size, dims)
def init_numpy_data(self, batch_size, dims):
# low and high are 'float'
self.low_np = np.random.uniform(-2, 1)
self.high_np = np.random.uniform(2, 4)
self.values_np = np.array([1.0]).astype('float32')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = self.low_np
self.dynamic_high = self.high_np
self.dynamic_values = paddle.to_tensor(self.values_np)
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[], dtype='float32')
def compare_with_numpy(self, fetch_list, sample_shape=7, tolerance=1e-6):
sample, entropy, log_prob, probs = fetch_list
np_uniform = UniformNumpy(self.low_np, self.high_np)
np_sample = np_uniform.sample([sample_shape])
np_entropy = np_uniform.entropy()
np_lp = np_uniform.log_prob(self.values_np)
np_p = np_uniform.probs(self.values_np)
np.testing.assert_equal(sample.shape, np_sample.shape)
np.testing.assert_allclose(
entropy, np_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
log_prob, np_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(probs, np_p, rtol=tolerance, atol=tolerance)
def test_uniform_distribution_dygraph(self, sample_shape=7, tolerance=1e-6):
paddle.disable_static(self.place)
uniform = Uniform(self.dynamic_low, self.dynamic_high)
sample = uniform.sample([sample_shape]).numpy()
entropy = uniform.entropy().numpy()
log_prob = uniform.log_prob(self.dynamic_values).numpy()
probs = uniform.probs(self.dynamic_values).numpy()
fetch_list = [sample, entropy, log_prob, probs]
self.compare_with_numpy(fetch_list)
def test_uniform_distribution_static(self, sample_shape=7, tolerance=1e-6):
paddle.enable_static()
with fluid.program_guard(self.test_program):
uniform = Uniform(self.static_low, self.static_high)
sample = uniform.sample([sample_shape])
entropy = uniform.entropy()
log_prob = uniform.log_prob(self.static_values)
probs = uniform.probs(self.static_values)
fetch_list = [sample, entropy, log_prob, probs]
feed_vars = {
'low': self.low_np,
'high': self.high_np,
'values': self.values_np
}
self.executor.run(fluid.default_startup_program())
fetch_list = self.executor.run(program=self.test_program,
feed=feed_vars,
fetch_list=fetch_list)
self.compare_with_numpy(fetch_list)
class UniformTest2(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are 'int'
self.low_np = int(np.random.uniform(-2, 1))
self.high_np = int(np.random.uniform(2, 4))
self.values_np = np.array([1.0]).astype('float32')
class UniformTest3(UniformTest):
def init_numpy_data(self, batch_size, dims):
# test broadcast: low is float, high is numpy.ndarray with dtype 'float32'.
self.low_np = np.random.uniform(-2, 1)
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest4(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are numpy.ndarray with dtype 'float32'.
self.low_np = np.random.randn(batch_size, dims).astype('float32')
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest5(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are numpy.ndarray with dtype 'float64'.
self.low_np = np.random.randn(batch_size, dims).astype('float64')
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float64')
self.values_np = np.random.randn(batch_size, dims).astype('float64')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = self.low_np
self.dynamic_high = self.high_np
self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float64')
class UniformTest6(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are Tensor with dtype 'VarType.FP32'.
self.low_np = np.random.randn(batch_size, dims).astype('float32')
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = paddle.to_tensor(self.low_np)
self.dynamic_high = paddle.to_tensor(self.high_np)
self.dynamic_values = paddle.to_tensor(self.values_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.static_low = layers.data(
name='low', shape=[dims], dtype='float32')
self.static_high = layers.data(
name='high', shape=[dims], dtype='float32')
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest7(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are Tensor with dtype 'VarType.FP64'.
self.low_np = np.random.randn(batch_size, dims).astype('float64')
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float64')
self.values_np = np.random.randn(batch_size, dims).astype('float64')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = paddle.to_tensor(self.low_np, dtype='float64')
self.dynamic_high = paddle.to_tensor(self.high_np, dtype='float64')
self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64')
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.static_low = layers.data(
name='low', shape=[dims], dtype='float64')
self.static_high = layers.data(
name='high', shape=[dims], dtype='float64')
self.static_values = layers.data(
name='values', shape=[dims], dtype='float64')
class UniformTest8(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are Tensor with dtype 'VarType.FP64'. value's dtype is 'VarType.FP32'.
self.low_np = np.random.randn(batch_size, dims).astype('float64')
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float64')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = paddle.to_tensor(self.low_np, dtype='float64')
self.dynamic_high = paddle.to_tensor(self.high_np, dtype='float64')
self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float32')
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.static_low = layers.data(
name='low', shape=[dims], dtype='float64')
self.static_high = layers.data(
name='high', shape=[dims], dtype='float64')
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest9(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are numpy.ndarray with dtype 'float32'.
# high < low.
self.low_np = np.random.randn(batch_size, dims).astype('float32')
self.high_np = np.random.uniform(-10.0, -5.0,
(batch_size, dims)).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest10(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are list.
self.low_np = np.random.randn(batch_size,
dims).astype('float32').tolist()
self.high_np = np.random.uniform(
5.0, 15.0, (batch_size, dims)).astype('float32').tolist()
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest11(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are tuple.
self.low_np = tuple(
np.random.randn(batch_size, dims).astype('float32').tolist())
self.high_np = tuple(
np.random.uniform(5.0, 15.0, (batch_size, dims)).astype('float32')
.tolist())
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTestSample(unittest.TestCase):
def setUp(self):
self.init_param()
def init_param(self):
self.low = 3.0
self.high = 4.0
def test_uniform_sample(self):
paddle.disable_static()
uniform = Uniform(low=self.low, high=self.high)
s = uniform.sample([100])
self.assertTrue((s >= self.low).all())
self.assertTrue((s < self.high).all())
paddle.enable_static()
class UniformTestSample2(UniformTestSample):
def init_param(self):
self.low = -5.0
self.high = 2.0
...@@ -276,6 +276,7 @@ packages=['paddle',
'paddle.incubate.tensor',
'paddle.incubate.nn',
'paddle.incubate.passes',
'paddle.distribution',
'paddle.distributed.fleet',
'paddle.distributed.fleet.base',
'paddle.distributed.fleet.elastic',
......
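With 'paddle.distribution' now listed as its own package in setup.py, the split stays invisible to users: the public classes remain importable from the same namespace. A minimal sanity check, assuming a build that includes this change:

from paddle.distribution import Categorical, Distribution, Normal, Uniform

# the refactor only moves files; each concrete class still derives from Distribution
assert issubclass(Normal, Distribution)
assert issubclass(Uniform, Distribution)
assert issubclass(Categorical, Distribution)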