未验证 提交 a3e6f18c 编写于 作者: X Xiaoxu Chen 提交者: GitHub

move distribution.py into distribution package and split into different file...

move distribution.py into distribution package and split into different file for better scalability (#38047)
上级 7da5368d
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .categorical import Categorical
from .distribution import Distribution
from .normal import Normal
from .uniform import Uniform
__all__ = ['Categorical', 'Distribution', 'Normal', 'Uniform']
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: define the distribution functions
# __all__ = ['Categorical',
# 'MultivariateNormalDiag',
# 'Normal',
# 'sampling_id',
# 'Uniform']
from __future__ import print_function
import math
import warnings
import numpy as np
from paddle import _C_ops
from ..fluid import core
from ..fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from ..fluid.framework import in_dygraph_mode
from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops, tensor)
from ..tensor import arange, concat, gather_nd, multinomial
class Distribution(object):
"""
The abstract base class for probability distributions. Functions are
implemented in specific distributions.
"""
def __init__(self):
super(Distribution, self).__init__()
def sample(self):
"""Sampling from the distribution."""
raise NotImplementedError
def entropy(self):
"""The entropy of the distribution."""
raise NotImplementedError
def kl_divergence(self, other):
"""The KL-divergence between self distributions and other."""
raise NotImplementedError
def log_prob(self, value):
"""Log probability density/mass function."""
raise NotImplementedError
def probs(self, value):
"""Probability density/mass function."""
raise NotImplementedError
def _validate_args(self, *args):
"""
Argument validation for distribution args
Args:
value (float, list, numpy.ndarray, Tensor)
Raises
ValueError: if one argument is Tensor, all arguments should be Tensor
"""
is_variable = False
is_number = False
for arg in args:
if isinstance(arg, tensor.Variable):
is_variable = True
else:
is_number = True
if is_variable and is_number:
raise ValueError(
'if one argument is Tensor, all arguments should be Tensor')
return is_variable
def _to_tensor(self, *args):
"""
Argument convert args to Tensor
Args:
value (float, list, numpy.ndarray, Tensor)
Returns:
Tensor of args.
"""
numpy_args = []
variable_args = []
tmp = 0.
for arg in args:
if isinstance(arg, float):
arg = [arg]
if not isinstance(arg, (list, tuple, np.ndarray, tensor.Variable)):
raise TypeError(
"Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}".
format(type(arg)))
arg_np = np.array(arg)
arg_dtype = arg_np.dtype
if str(arg_dtype) != 'float32':
if str(arg_dtype) != 'float64':
# "assign" op doesn't support float64. if dtype is float64, float32 variable will be generated
# and converted to float64 later using "cast".
warnings.warn(
"data type of argument only support float32 and float64, your argument will be convert to float32."
)
arg_np = arg_np.astype('float32')
# tmp is used to support broadcast, it summarizes shapes of all the args and get the mixed shape.
tmp = tmp + arg_np
numpy_args.append(arg_np)
dtype = tmp.dtype
for arg in numpy_args:
arg_broadcasted, _ = np.broadcast_arrays(arg, tmp)
arg_variable = tensor.create_tensor(dtype=dtype)
tensor.assign(arg_broadcasted, arg_variable)
variable_args.append(arg_variable)
return tuple(variable_args)
def _check_values_dtype_in_probs(self, param, value):
"""
Log_prob and probs methods have input ``value``, if value's dtype is different from param,
convert value's dtype to be consistent with param's dtype.
Args:
param (Tensor): low and high in Uniform class, loc and scale in Normal class.
value (Tensor): The input tensor.
Returns:
value (Tensor): Change value's dtype if value's dtype is different from param.
"""
if in_dygraph_mode():
if value.dtype != param.dtype and convert_dtype(
value.dtype) in ['float32', 'float64']:
warnings.warn(
"dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
)
return _C_ops.cast(value, 'in_dtype', value.dtype, 'out_dtype',
param.dtype)
return value
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
if value.dtype != param.dtype:
warnings.warn(
"dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
)
return tensor.cast(value, dtype=param.dtype)
return value
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
import numpy as np
from paddle import _C_ops
from ..fluid import core
from ..fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from ..fluid.framework import in_dygraph_mode
from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops, tensor)
from ..tensor import arange, concat, gather_nd, multinomial
from .distribution import Distribution
class Normal(Distribution):
r"""The Normal distribution with location `loc` and `scale` parameters.
Mathematical details
The probability density function (pdf) is
.. math::
pdf(x; \mu, \sigma) = \\frac{1}{Z}e^{\\frac {-0.5 (x - \mu)^2} {\sigma^2} }
.. math::
Z = (2 \pi \sigma^2)^{0.5}
In the above equation:
* :math:`loc = \mu`: is the mean.
* :math:`scale = \sigma`: is the std.
* :math:`Z`: is the normalization constant.
Args:
loc(int|float|list|tuple|numpy.ndarray|Tensor): The mean of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor.
scale(int|float|list|tuple|numpy.ndarray|Tensor): The std of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Normal
# Define a single scalar Normal distribution.
dist = Normal(loc=0., scale=3.)
# Define a batch of two scalar valued Normals.
# The first has mean 1 and standard deviation 11, the second 2 and 22.
dist = Normal(loc=[1., 2.], scale=[11., 22.])
# Get 3 samples, returning a 3 x 2 tensor.
dist.sample([3])
# Define a batch of two scalar valued Normals.
# Both have mean 1, but different standard deviations.
dist = Normal(loc=1., scale=[11., 22.])
# Complete example
value_tensor = paddle.to_tensor([0.8], dtype="float32")
normal_a = Normal([0.], [1.])
normal_b = Normal([0.5], [2.])
sample = normal_a.sample([2])
# a random tensor created by normal distribution with shape: [2, 1]
entropy = normal_a.entropy()
# [1.4189385] with shape: [1]
lp = normal_a.log_prob(value_tensor)
# [-1.2389386] with shape: [1]
p = normal_a.probs(value_tensor)
# [0.28969154] with shape: [1]
kl = normal_a.kl_divergence(normal_b)
# [0.34939718] with shape: [1]
"""
def __init__(self, loc, scale, name=None):
if not in_dygraph_mode():
check_type(loc, 'loc',
(int, float, np.ndarray, tensor.Variable, list, tuple),
'Normal')
check_type(scale, 'scale',
(int, float, np.ndarray, tensor.Variable, list, tuple),
'Normal')
self.batch_size_unknown = False
self.all_arg_is_float = False
self.name = name if name is not None else 'Normal'
self.dtype = 'float32'
if isinstance(loc, int):
loc = float(loc)
if isinstance(scale, int):
scale = float(scale)
if self._validate_args(loc, scale):
self.batch_size_unknown = True
self.loc = loc
self.scale = scale
self.dtype = convert_dtype(loc.dtype)
else:
if isinstance(loc, float) and isinstance(scale, float):
self.all_arg_is_float = True
if isinstance(
loc,
np.ndarray) and str(loc.dtype) in ['float32', 'float64']:
self.dtype = loc.dtype
elif isinstance(
scale,
np.ndarray) and str(scale.dtype) in ['float32', 'float64']:
self.dtype = scale.dtype
# pylint: disable=unbalanced-tuple-unpacking
self.loc, self.scale = self._to_tensor(loc, scale)
if self.dtype != convert_dtype(self.loc.dtype):
self.loc = tensor.cast(self.loc, dtype=self.dtype)
self.scale = tensor.cast(self.scale, dtype=self.dtype)
def sample(self, shape, seed=0):
"""Generate samples of the specified shape.
Args:
shape (list): 1D `int32`. Shape of the generated samples.
seed (int): Python integer number.
Returns:
Tensor: A tensor with prepended dimensions shape.The data type is float32.
"""
if not in_dygraph_mode():
check_type(shape, 'shape', (list), 'sample')
check_type(seed, 'seed', (int), 'sample')
batch_shape = list((self.loc + self.scale).shape)
name = self.name + '_sample'
if self.batch_size_unknown:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape + shape, self.dtype, 0.)
zero_tmp_reshape = nn.reshape(zero_tmp, output_shape)
zero_tmp_shape = nn.shape(zero_tmp_reshape)
normal_random_tmp = nn.gaussian_random(
zero_tmp_shape, mean=0., std=1., seed=seed, dtype=self.dtype)
output = normal_random_tmp * (zero_tmp_reshape + self.scale)
output = elementwise_add(output, self.loc, name=name)
return output
else:
output_shape = shape + batch_shape
output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * \
(tensor.zeros(output_shape, dtype=self.dtype) + self.scale)
output = elementwise_add(output, self.loc, name=name)
if self.all_arg_is_float:
return nn.reshape(output, shape, name=name)
else:
return output
def entropy(self):
r"""Shannon entropy in nats.
The entropy is
.. math::
entropy(\sigma) = 0.5 \\log (2 \pi e \sigma^2)
In the above equation:
* :math:`scale = \sigma`: is the std.
Returns:
Tensor: Shannon entropy of normal distribution.The data type is float32.
"""
name = self.name + '_entropy'
batch_shape = list((self.loc + self.scale).shape)
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape, self.dtype, 0.)
return elementwise_add(
0.5 + zero_tmp,
0.5 * math.log(2 * math.pi) + nn.log((self.scale + zero_tmp)),
name=name)
def log_prob(self, value):
"""Log probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: log probability.The data type is same with value.
"""
name = self.name + '_log_prob'
value = self._check_values_dtype_in_probs(self.loc, value)
var = self.scale * self.scale
log_scale = nn.log(self.scale)
return elementwise_sub(
-1. * ((value - self.loc) * (value - self.loc)) / (2. * var),
log_scale + math.log(math.sqrt(2. * math.pi)),
name=name)
def probs(self, value):
"""Probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: probability.The data type is same with value.
"""
name = self.name + '_probs'
value = self._check_values_dtype_in_probs(self.loc, value)
var = self.scale * self.scale
return elementwise_div(
ops.exp(-1. * ((value - self.loc) * (value - self.loc)) /
(2. * var)), (math.sqrt(2 * math.pi) * self.scale),
name=name)
def kl_divergence(self, other):
r"""The KL-divergence between two normal distributions.
The probability density function (pdf) is
.. math::
KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\\frac{diff}{\sigma_1})^2 - 1 - 2 \\ln {ratio})
.. math::
ratio = \\frac{\sigma_0}{\sigma_1}
.. math::
diff = \mu_1 - \mu_0
In the above equation:
* :math:`loc = \mu_0`: is the mean of current Normal distribution.
* :math:`scale = \sigma_0`: is the std of current Normal distribution.
* :math:`loc = \mu_1`: is the mean of other Normal distribution.
* :math:`scale = \sigma_1`: is the std of other Normal distribution.
* :math:`ratio`: is the ratio of scales.
* :math:`diff`: is the difference between means.
Args:
other (Normal): instance of Normal.
Returns:
Tensor: kl-divergence between two normal distributions.The data type is float32.
"""
if not in_dygraph_mode():
check_type(other, 'other', Normal, 'kl_divergence')
name = self.name + '_kl_divergence'
var_ratio = self.scale / other.scale
var_ratio = (var_ratio * var_ratio)
t1 = (self.loc - other.loc) / other.scale
t1 = (t1 * t1)
return elementwise_add(
0.5 * var_ratio, 0.5 * (t1 - 1. - nn.log(var_ratio)), name=name)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
import numpy as np
from paddle import _C_ops
from ..fluid import core
from ..fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from ..fluid.framework import in_dygraph_mode
from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops, tensor)
from ..tensor import arange, concat, gather_nd, multinomial
from .distribution import Distribution
class Uniform(Distribution):
r"""Uniform distribution with `low` and `high` parameters.
Mathematical Details
The probability density function (pdf) is
.. math::
pdf(x; a, b) = \\frac{1}{Z}, \ a <=x <b
.. math::
Z = b - a
In the above equation:
* :math:`low = a`,
* :math:`high = b`,
* :math:`Z`: is the normalizing constant.
The parameters `low` and `high` must be shaped in a way that supports
[broadcasting](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/beginners_guide/basic_concept/broadcasting_en.html) (e.g., `high - low` is a valid operation).
Args:
low(int|float|list|tuple|numpy.ndarray|Tensor): The lower boundary of uniform distribution.The data type is int, float, list, numpy.ndarray or Tensor
high(int|float|list|tuple|numpy.ndarray|Tensor): The higher boundary of uniform distribution.The data type is int, float, list, numpy.ndarray or Tensor
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Uniform
# Without broadcasting, a single uniform distribution [3, 4]:
u1 = Uniform(low=3.0, high=4.0)
# 2 distributions [1, 3], [2, 4]
u2 = Uniform(low=[1.0, 2.0], high=[3.0, 4.0])
# 4 distributions
u3 = Uniform(low=[[1.0, 2.0], [3.0, 4.0]],
high=[[1.5, 2.5], [3.5, 4.5]])
# With broadcasting:
u4 = Uniform(low=3.0, high=[5.0, 6.0, 7.0])
# Complete example
value_tensor = paddle.to_tensor([0.8], dtype="float32")
uniform = Uniform([0.], [2.])
sample = uniform.sample([2])
# a random tensor created by uniform distribution with shape: [2, 1]
entropy = uniform.entropy()
# [0.6931472] with shape: [1]
lp = uniform.log_prob(value_tensor)
# [-0.6931472] with shape: [1]
p = uniform.probs(value_tensor)
# [0.5] with shape: [1]
"""
def __init__(self, low, high, name=None):
if not in_dygraph_mode():
check_type(low, 'low',
(int, float, np.ndarray, tensor.Variable, list, tuple),
'Uniform')
check_type(high, 'high',
(int, float, np.ndarray, tensor.Variable, list, tuple),
'Uniform')
self.all_arg_is_float = False
self.batch_size_unknown = False
self.name = name if name is not None else 'Uniform'
self.dtype = 'float32'
if isinstance(low, int):
low = float(low)
if isinstance(high, int):
high = float(high)
if self._validate_args(low, high):
self.batch_size_unknown = True
self.low = low
self.high = high
self.dtype = convert_dtype(low.dtype)
else:
if isinstance(low, float) and isinstance(high, float):
self.all_arg_is_float = True
if isinstance(
low,
np.ndarray) and str(low.dtype) in ['float32', 'float64']:
self.dtype = low.dtype
elif isinstance(
high,
np.ndarray) and str(high.dtype) in ['float32', 'float64']:
self.dtype = high.dtype
# pylint: disable=unbalanced-tuple-unpacking
self.low, self.high = self._to_tensor(low, high)
if self.dtype != convert_dtype(self.low.dtype):
self.low = tensor.cast(self.low, dtype=self.dtype)
self.high = tensor.cast(self.high, dtype=self.dtype)
def sample(self, shape, seed=0):
"""Generate samples of the specified shape.
Args:
shape (list): 1D `int32`. Shape of the generated samples.
seed (int): Python integer number.
Returns:
Tensor: A tensor with prepended dimensions shape.The data type is float32.
"""
if not in_dygraph_mode():
check_type(shape, 'shape', (list), 'sample')
check_type(seed, 'seed', (int), 'sample')
name = self.name + '_sample'
batch_shape = list((self.low + self.high).shape)
if self.batch_size_unknown:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.low + self.high, batch_shape + shape, self.dtype, 0.)
uniform_random_tmp = nn.uniform_random_batch_size_like(
zero_tmp,
zero_tmp.shape,
dtype=self.dtype,
min=0.,
max=1.,
seed=seed)
zero_tmp_reshape = nn.reshape(zero_tmp, output_shape)
uniform_random_tmp_reshape = nn.reshape(uniform_random_tmp,
output_shape)
output = uniform_random_tmp_reshape * (
zero_tmp_reshape + self.high - self.low)
output = elementwise_add(output, self.low, name=name)
return output
else:
output_shape = shape + batch_shape
output = nn.uniform_random(
output_shape, dtype=self.dtype, min=0., max=1.,
seed=seed) * (tensor.zeros(
output_shape, dtype=self.dtype) + (self.high - self.low))
output = elementwise_add(output, self.low, name=name)
if self.all_arg_is_float:
return nn.reshape(output, shape, name=name)
else:
return output
def log_prob(self, value):
"""Log probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: log probability.The data type is same with value.
"""
value = self._check_values_dtype_in_probs(self.low, value)
if in_dygraph_mode():
# ensure value in [low, high]
lb_bool = self.low < value
ub_bool = value < self.high
lb = _C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
value.dtype)
ub = _C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
value.dtype)
return nn.log(lb * ub) - nn.log(self.high - self.low)
name = self.name + '_log_prob'
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_sub(
nn.log(lb * ub), nn.log(self.high - self.low), name=name)
def probs(self, value):
"""Probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: probability.The data type is same with value.
"""
value = self._check_values_dtype_in_probs(self.low, value)
if in_dygraph_mode():
lb_bool = self.low < value
ub_bool = value < self.high
lb = _C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
value.dtype)
ub = _C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
value.dtype)
return (lb * ub) / (self.high - self.low)
name = self.name + '_probs'
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_div((lb * ub), (self.high - self.low), name=name)
def entropy(self):
r"""Shannon entropy in nats.
The entropy is
.. math::
entropy(low, high) = \\log (high - low)
Returns:
Tensor: Shannon entropy of uniform distribution.The data type is float32.
"""
name = self.name + '_entropy'
return nn.log(self.high - self.low, name=name)
......@@ -772,6 +772,7 @@ add_subdirectory(sequence)
add_subdirectory(dygraph_to_static)
add_subdirectory(rnn)
add_subdirectory(autograd)
add_subdirectory(distribution)
if (NOT WIN32 OR NOT WITH_GPU)
add_subdirectory(fft)
......
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest
import numpy as np
import paddle
from paddle import fluid
from paddle.distribution import *
from paddle.fluid import layers
paddle.enable_static()
class DistributionNumpy():
def sample(self):
raise NotImplementedError
def entropy(self):
raise NotImplementedError
def kl_divergence(self, other):
raise NotImplementedError
def log_prob(self, value):
raise NotImplementedError
def probs(self, value):
raise NotImplementedError
class DistributionTestName(unittest.TestCase):
def get_prefix(self, string):
return (string.split('.')[0])
def test_normal_name(self):
name = 'test_normal'
normal1 = Normal(0.0, 1.0, name=name)
self.assertEqual(normal1.name, name)
normal2 = Normal(0.0, 1.0)
self.assertEqual(normal2.name, 'Normal')
paddle.enable_static()
sample = normal1.sample([2])
self.assertEqual(self.get_prefix(sample.name), name + '_sample')
entropy = normal1.entropy()
self.assertEqual(self.get_prefix(entropy.name), name + '_entropy')
value_npdata = np.array([0.8], dtype="float32")
value_tensor = layers.create_tensor(dtype="float32")
layers.assign(value_npdata, value_tensor)
lp = normal1.log_prob(value_tensor)
self.assertEqual(self.get_prefix(lp.name), name + '_log_prob')
p = normal1.probs(value_tensor)
self.assertEqual(self.get_prefix(p.name), name + '_probs')
kl = normal1.kl_divergence(normal2)
self.assertEqual(self.get_prefix(kl.name), name + '_kl_divergence')
def test_uniform_name(self):
name = 'test_uniform'
uniform1 = Uniform(0.0, 1.0, name=name)
self.assertEqual(uniform1.name, name)
uniform2 = Uniform(0.0, 1.0)
self.assertEqual(uniform2.name, 'Uniform')
paddle.enable_static()
sample = uniform1.sample([2])
self.assertEqual(self.get_prefix(sample.name), name + '_sample')
entropy = uniform1.entropy()
self.assertEqual(self.get_prefix(entropy.name), name + '_entropy')
value_npdata = np.array([0.8], dtype="float32")
value_tensor = layers.create_tensor(dtype="float32")
layers.assign(value_npdata, value_tensor)
lp = uniform1.log_prob(value_tensor)
self.assertEqual(self.get_prefix(lp.name), name + '_log_prob')
p = uniform1.probs(value_tensor)
self.assertEqual(self.get_prefix(p.name), name + '_probs')
def test_categorical_name(self):
name = 'test_categorical'
categorical1 = Categorical([0.4, 0.6], name=name)
self.assertEqual(categorical1.name, name)
categorical2 = Categorical([0.5, 0.5])
self.assertEqual(categorical2.name, 'Categorical')
paddle.enable_static()
sample = categorical1.sample([2])
self.assertEqual(self.get_prefix(sample.name), name + '_sample')
entropy = categorical1.entropy()
self.assertEqual(self.get_prefix(entropy.name), name + '_entropy')
kl = categorical1.kl_divergence(categorical2)
self.assertEqual(self.get_prefix(kl.name), name + '_kl_divergence')
value_npdata = np.array([0], dtype="int64")
value_tensor = layers.create_tensor(dtype="int64")
layers.assign(value_npdata, value_tensor)
p = categorical1.probs(value_tensor)
self.assertEqual(self.get_prefix(p.name), name + '_probs')
lp = categorical1.log_prob(value_tensor)
self.assertEqual(self.get_prefix(lp.name), name + '_log_prob')
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest
import numpy as np
import paddle
from paddle import fluid
from paddle.distribution import *
from paddle.fluid import layers
from test_distribution import DistributionNumpy
class CategoricalNumpy(DistributionNumpy):
def __init__(self, logits):
self.logits = np.array(logits).astype('float32')
def entropy(self):
logits = self.logits - np.max(self.logits, axis=-1, keepdims=True)
e_logits = np.exp(logits)
z = np.sum(e_logits, axis=-1, keepdims=True)
prob = e_logits / z
return -1. * np.sum(prob * (logits - np.log(z)), axis=-1, keepdims=True)
def kl_divergence(self, other):
logits = self.logits - np.max(self.logits, axis=-1, keepdims=True)
other_logits = other.logits - np.max(
other.logits, axis=-1, keepdims=True)
e_logits = np.exp(logits)
other_e_logits = np.exp(other_logits)
z = np.sum(e_logits, axis=-1, keepdims=True)
other_z = np.sum(other_e_logits, axis=-1, keepdims=True)
prob = e_logits / z
return np.sum(prob *
(logits - np.log(z) - other_logits + np.log(other_z)),
axis=-1,
keepdims=True)
class CategoricalTest(unittest.TestCase):
def setUp(self, use_gpu=False, batch_size=3, dims=5):
self.use_gpu = use_gpu
if not use_gpu:
self.place = fluid.CPUPlace()
self.gpu_id = -1
else:
self.place = fluid.CUDAPlace(0)
self.gpu_id = 0
self.batch_size = batch_size
self.dims = dims
self.init_numpy_data(batch_size, dims)
paddle.disable_static(self.place)
self.init_dynamic_data(batch_size, dims)
paddle.enable_static()
self.test_program = fluid.Program()
self.executor = fluid.Executor(self.place)
self.init_static_data(batch_size, dims)
def init_numpy_data(self, batch_size, dims):
# input logtis is 2-D Tensor
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(batch_size, dims).astype('float32')
self.other_logits_np = np.random.rand(batch_size,
dims).astype('float32')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [batch_size, dims]
# dist_shape = logits_shape[:-1], it represents the number of
# different distributions.
self.dist_shape = [batch_size]
# sample shape represents the number of samples
self.sample_shape = [2, 4]
# value used in probs and log_prob method
# If value is 1-D and logits is 2-D or higher dimension, value will be
# broadcasted to have the same number of distributions with logits.
# If value is 2-D or higher dimentsion, it should have the same number
# of distributions with logtis. ``value[:-1] = logits[:-1]
self.value_shape = [3]
def init_dynamic_data(self, batch_size, dims):
self.logits = paddle.to_tensor(self.logits_np)
self.other_logits = paddle.to_tensor(self.other_logits_np)
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = fluid.data(
name='logits', shape=self.logits_shape, dtype='float32')
self.other_logits_static = fluid.data(
name='other_logits', shape=self.logits_shape, dtype='float32')
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.dist_shape + self.value_shape)
for i in range(self.batch_size):
for j in range(3):
np_probs[i][j] = probability[i][self.value_np[j]]
return np_probs
def compare_with_numpy(self, fetch_list, tolerance=1e-6):
sample, entropy, kl, probs, log_prob = fetch_list
log_tolerance = 1e-4
np.testing.assert_equal(sample.shape,
self.sample_shape + self.dist_shape)
np_categorical = CategoricalNumpy(self.logits_np)
np_other_categorical = CategoricalNumpy(self.other_logits_np)
np_entropy = np_categorical.entropy()
np_kl = np_categorical.kl_divergence(np_other_categorical)
np.testing.assert_allclose(
entropy, np_entropy, rtol=log_tolerance, atol=log_tolerance)
np.testing.assert_allclose(
kl, np_kl, rtol=log_tolerance, atol=log_tolerance)
sum_dist = np.sum(self.logits_np, axis=-1, keepdims=True)
probability = self.logits_np / sum_dist
np_probs = self.get_numpy_selected_probs(probability)
np_log_prob = np.log(np_probs)
np.testing.assert_allclose(
probs, np_probs, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
log_prob, np_log_prob, rtol=tolerance, atol=tolerance)
def test_categorical_distribution_dygraph(self, tolerance=1e-6):
paddle.disable_static(self.place)
categorical = Categorical(self.logits)
other_categorical = Categorical(self.other_logits)
sample = categorical.sample(self.sample_shape).numpy()
entropy = categorical.entropy().numpy()
kl = categorical.kl_divergence(other_categorical).numpy()
probs = categorical.probs(self.value).numpy()
log_prob = categorical.log_prob(self.value).numpy()
fetch_list = [sample, entropy, kl, probs, log_prob]
self.compare_with_numpy(fetch_list)
def test_categorical_distribution_static(self, tolerance=1e-6):
paddle.enable_static()
with fluid.program_guard(self.test_program):
categorical = Categorical(self.logits_static)
other_categorical = Categorical(self.other_logits_static)
sample = categorical.sample(self.sample_shape)
entropy = categorical.entropy()
kl = categorical.kl_divergence(other_categorical)
probs = categorical.probs(self.value_static)
log_prob = categorical.log_prob(self.value_static)
fetch_list = [sample, entropy, kl, probs, log_prob]
feed_vars = {
'logits': self.logits_np,
'other_logits': self.other_logits_np,
'value': self.value_np
}
self.executor.run(fluid.default_startup_program())
fetch_list = self.executor.run(program=self.test_program,
feed=feed_vars,
fetch_list=fetch_list)
self.compare_with_numpy(fetch_list)
class CategoricalTest2(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logtis is 2-D Tensor with dtype Float64
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(batch_size, dims).astype('float64')
self.other_logits_np = np.random.rand(batch_size,
dims).astype('float64')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [batch_size, dims]
self.dist_shape = [batch_size]
self.sample_shape = [2, 4]
self.value_shape = [3]
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = fluid.data(
name='logits', shape=self.logits_shape, dtype='float64')
self.other_logits_static = fluid.data(
name='other_logits', shape=self.logits_shape, dtype='float64')
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
class CategoricalTest3(CategoricalTest):
def init_dynamic_data(self, batch_size, dims):
# input logtis is 2-D numpy.ndarray with dtype Float32
# value used in probs and log_prob method is 1-D Tensor
self.logits = self.logits_np
self.other_logits = self.other_logits_np
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = self.logits_np
self.other_logits_static = self.other_logits_np
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
class CategoricalTest4(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logtis is 2-D numpy.ndarray with dtype Float64
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(batch_size, dims).astype('float64')
self.other_logits_np = np.random.rand(batch_size,
dims).astype('float64')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [batch_size, dims]
self.dist_shape = [batch_size]
self.sample_shape = [2, 4]
self.value_shape = [3]
def init_dynamic_data(self, batch_size, dims):
self.logits = self.logits_np
self.other_logits = self.other_logits_np
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = self.logits_np
self.other_logits_static = self.other_logits_np
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
# test shape of logits and value used in probs and log_prob method
class CategoricalTest5(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logtis is 1-D Tensor
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(dims).astype('float32')
self.other_logits_np = np.random.rand(dims).astype('float32')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [dims]
self.dist_shape = []
self.sample_shape = [2, 4]
self.value_shape = [3]
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.value_shape)
for i in range(3):
np_probs[i] = probability[self.value_np[i]]
return np_probs
class CategoricalTest6(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logtis is 2-D Tensor
# value used in probs and log_prob method has the same number of batches with input
self.logits_np = np.random.rand(3, 5).astype('float32')
self.other_logits_np = np.random.rand(3, 5).astype('float32')
self.value_np = np.array([[2, 1], [0, 3], [2, 3]]).astype('int64')
self.logits_shape = [3, 5]
self.dist_shape = [3]
self.sample_shape = [2, 4]
self.value_shape = [3, 2]
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.value_shape)
for i in range(3):
for j in range(2):
np_probs[i][j] = probability[i][self.value_np[i][j]]
return np_probs
class CategoricalTest7(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logtis is 3-D Tensor
# value used in probs and log_prob method has the same number of distribuions with input
self.logits_np = np.random.rand(3, 2, 5).astype('float32')
self.other_logits_np = np.random.rand(3, 2, 5).astype('float32')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [3, 2, 5]
self.dist_shape = [3, 2]
self.sample_shape = [2, 4]
self.value_shape = [3]
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.dist_shape + self.value_shape)
for i in range(3):
for j in range(2):
for k in range(3):
np_probs[i][j][k] = probability[i][j][self.value_np[k]]
return np_probs
class CategoricalTest8(CategoricalTest):
def init_dynamic_data(self, batch_size, dims):
# input logtis is 2-D list
# value used in probs and log_prob method is 1-D Tensor
self.logits = self.logits_np.tolist()
self.other_logits = self.other_logits_np.tolist()
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = self.logits_np.tolist()
self.other_logits_static = self.other_logits_np.tolist()
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
class CategoricalTest9(CategoricalTest):
def init_dynamic_data(self, batch_size, dims):
# input logtis is 2-D tuple
# value used in probs and log_prob method is 1-D Tensor
self.logits = tuple(self.logits_np.tolist())
self.other_logits = tuple(self.other_logits_np.tolist())
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = tuple(self.logits_np.tolist())
self.other_logits_static = tuple(self.other_logits_np.tolist())
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
class DistributionTestError(unittest.TestCase):
def test_distribution_error(self):
distribution = Distribution()
self.assertRaises(NotImplementedError, distribution.sample)
self.assertRaises(NotImplementedError, distribution.entropy)
normal = Normal(0.0, 1.0)
self.assertRaises(NotImplementedError, distribution.kl_divergence,
normal)
value_npdata = np.array([0.8], dtype="float32")
value_tensor = layers.create_tensor(dtype="float32")
self.assertRaises(NotImplementedError, distribution.log_prob,
value_tensor)
self.assertRaises(NotImplementedError, distribution.probs, value_tensor)
def test_normal_error(self):
paddle.enable_static()
normal = Normal(0.0, 1.0)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, normal.log_prob, value)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, normal.probs, value)
shape = 1.0
# type of shape must be list
self.assertRaises(TypeError, normal.sample, shape)
seed = 1.0
# type of seed must be int
self.assertRaises(TypeError, normal.sample, [2, 3], seed)
normal_other = Uniform(1.0, 2.0)
# type of other must be an instance of Normal
self.assertRaises(TypeError, normal.kl_divergence, normal_other)
def test_uniform_error(self):
paddle.enable_static()
uniform = Uniform(0.0, 1.0)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, uniform.log_prob, value)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, uniform.probs, value)
shape = 1.0
# type of shape must be list
self.assertRaises(TypeError, uniform.sample, shape)
seed = 1.0
# type of seed must be int
self.assertRaises(TypeError, uniform.sample, [2, 3], seed)
def test_categorical_error(self):
paddle.enable_static()
categorical = Categorical([0.4, 0.6])
value = [1, 0]
# type of value must be variable
self.assertRaises(AttributeError, categorical.log_prob, value)
value = [1, 0]
# type of value must be variable
self.assertRaises(AttributeError, categorical.probs, value)
shape = 1.0
# type of shape must be list
self.assertRaises(TypeError, categorical.sample, shape)
categorical_other = Uniform(1.0, 2.0)
# type of other must be an instance of Categorical
self.assertRaises(TypeError, categorical.kl_divergence,
categorical_other)
def test_shape_not_match_error():
# shape of value must match shape of logits
# value_shape[:-1] == logits_shape[:-1]
paddle.disable_static()
logits = paddle.rand([3, 5])
cat = Categorical(logits)
value = paddle.to_tensor([[2, 1, 3], [3, 2, 1]], dtype='int64')
cat.log_prob(value)
self.assertRaises(ValueError, test_shape_not_match_error)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest
import numpy as np
import paddle
from paddle import fluid
from paddle.distribution import *
from paddle.fluid import layers
from test_distribution import DistributionNumpy
class UniformNumpy(DistributionNumpy):
def __init__(self, low, high):
self.low = np.array(low)
self.high = np.array(high)
if str(self.low.dtype) not in ['float32', 'float64']:
self.low = self.low.astype('float32')
self.high = self.high.astype('float32')
def sample(self, shape):
shape = tuple(shape) + (self.low + self.high).shape
return self.low + (np.random.uniform(size=shape) *
(self.high - self.low))
def log_prob(self, value):
lb = np.less(self.low, value).astype(self.low.dtype)
ub = np.less(value, self.high).astype(self.low.dtype)
return np.log(lb * ub) - np.log(self.high - self.low)
def probs(self, value):
lb = np.less(self.low, value).astype(self.low.dtype)
ub = np.less(value, self.high).astype(self.low.dtype)
return (lb * ub) / (self.high - self.low)
def entropy(self):
return np.log(self.high - self.low)
class UniformTest(unittest.TestCase):
def setUp(self, use_gpu=False, batch_size=5, dims=6):
self.use_gpu = use_gpu
if not use_gpu:
self.place = fluid.CPUPlace()
self.gpu_id = -1
else:
self.place = fluid.CUDAPlace(0)
self.gpu_id = 0
self.init_numpy_data(batch_size, dims)
paddle.disable_static(self.place)
self.init_dynamic_data(batch_size, dims)
paddle.enable_static()
self.test_program = fluid.Program()
self.executor = fluid.Executor(self.place)
self.init_static_data(batch_size, dims)
def init_numpy_data(self, batch_size, dims):
# low ans high are 'float'
self.low_np = np.random.uniform(-2, 1)
self.high_np = np.random.uniform(2, 4)
self.values_np = np.array([1.0]).astype('float32')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = self.low_np
self.dynamic_high = self.high_np
self.dynamic_values = paddle.to_tensor(self.values_np)
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[], dtype='float32')
def compare_with_numpy(self, fetch_list, sample_shape=7, tolerance=1e-6):
sample, entropy, log_prob, probs = fetch_list
np_uniform = UniformNumpy(self.low_np, self.high_np)
np_sample = np_uniform.sample([sample_shape])
np_entropy = np_uniform.entropy()
np_lp = np_uniform.log_prob(self.values_np)
np_p = np_uniform.probs(self.values_np)
np.testing.assert_equal(sample.shape, np_sample.shape)
np.testing.assert_allclose(
entropy, np_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
log_prob, np_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(probs, np_p, rtol=tolerance, atol=tolerance)
def test_uniform_distribution_dygraph(self, sample_shape=7, tolerance=1e-6):
paddle.disable_static(self.place)
uniform = Uniform(self.dynamic_low, self.dynamic_high)
sample = uniform.sample([sample_shape]).numpy()
entropy = uniform.entropy().numpy()
log_prob = uniform.log_prob(self.dynamic_values).numpy()
probs = uniform.probs(self.dynamic_values).numpy()
fetch_list = [sample, entropy, log_prob, probs]
self.compare_with_numpy(fetch_list)
def test_uniform_distribution_static(self, sample_shape=7, tolerance=1e-6):
paddle.enable_static()
with fluid.program_guard(self.test_program):
uniform = Uniform(self.static_low, self.static_high)
sample = uniform.sample([sample_shape])
entropy = uniform.entropy()
log_prob = uniform.log_prob(self.static_values)
probs = uniform.probs(self.static_values)
fetch_list = [sample, entropy, log_prob, probs]
feed_vars = {
'low': self.low_np,
'high': self.high_np,
'values': self.values_np
}
self.executor.run(fluid.default_startup_program())
fetch_list = self.executor.run(program=self.test_program,
feed=feed_vars,
fetch_list=fetch_list)
self.compare_with_numpy(fetch_list)
class UniformTest2(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low ans high are 'int'
self.low_np = int(np.random.uniform(-2, 1))
self.high_np = int(np.random.uniform(2, 4))
self.values_np = np.array([1.0]).astype('float32')
class UniformTest3(UniformTest):
def init_numpy_data(self, batch_size, dims):
# test broadcast: low is float, high is numpy.ndarray with dtype 'float32'.
self.low_np = np.random.uniform(-2, 1)
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest4(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are numpy.ndarray with dtype 'float32'.
self.low_np = np.random.randn(batch_size, dims).astype('float32')
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest5(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are numpy.ndarray with dtype 'float64'.
self.low_np = np.random.randn(batch_size, dims).astype('float64')
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float64')
self.values_np = np.random.randn(batch_size, dims).astype('float64')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = self.low_np
self.dynamic_high = self.high_np
self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float64')
class UniformTest6(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are Tensor with dtype 'VarType.FP32'.
self.low_np = np.random.randn(batch_size, dims).astype('float32')
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = paddle.to_tensor(self.low_np)
self.dynamic_high = paddle.to_tensor(self.high_np)
self.dynamic_values = paddle.to_tensor(self.values_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.static_low = layers.data(
name='low', shape=[dims], dtype='float32')
self.static_high = layers.data(
name='high', shape=[dims], dtype='float32')
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest7(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are Tensor with dtype 'VarType.FP64'.
self.low_np = np.random.randn(batch_size, dims).astype('float64')
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float64')
self.values_np = np.random.randn(batch_size, dims).astype('float64')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = paddle.to_tensor(self.low_np, dtype='float64')
self.dynamic_high = paddle.to_tensor(self.high_np, dtype='float64')
self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64')
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.static_low = layers.data(
name='low', shape=[dims], dtype='float64')
self.static_high = layers.data(
name='high', shape=[dims], dtype='float64')
self.static_values = layers.data(
name='values', shape=[dims], dtype='float64')
class UniformTest8(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are Tensor with dtype 'VarType.FP64'. value's dtype is 'VarType.FP32'.
self.low_np = np.random.randn(batch_size, dims).astype('float64')
self.high_np = np.random.uniform(5.0, 15.0,
(batch_size, dims)).astype('float64')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = paddle.to_tensor(self.low_np, dtype='float64')
self.dynamic_high = paddle.to_tensor(self.high_np, dtype='float64')
self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float32')
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.static_low = layers.data(
name='low', shape=[dims], dtype='float64')
self.static_high = layers.data(
name='high', shape=[dims], dtype='float64')
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest9(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are numpy.ndarray with dtype 'float32'.
# high < low.
self.low_np = np.random.randn(batch_size, dims).astype('float32')
self.high_np = np.random.uniform(-10.0, -5.0,
(batch_size, dims)).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest10(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are list.
self.low_np = np.random.randn(batch_size,
dims).astype('float32').tolist()
self.high_np = np.random.uniform(
5.0, 15.0, (batch_size, dims)).astype('float32').tolist()
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTest11(UniformTest):
def init_numpy_data(self, batch_size, dims):
# low and high are tuple.
self.low_np = tuple(
np.random.randn(batch_size, dims).astype('float32').tolist())
self.high_np = tuple(
np.random.uniform(5.0, 15.0, (batch_size, dims)).astype('float32')
.tolist())
self.values_np = np.random.randn(batch_size, dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_low = self.low_np
self.static_high = self.high_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class UniformTestSample(unittest.TestCase):
def setUp(self):
self.init_param()
def init_param(self):
self.low = 3.0
self.high = 4.0
def test_uniform_sample(self):
paddle.disable_static()
uniform = Uniform(low=self.low, high=self.high)
s = uniform.sample([100])
self.assertTrue((s >= self.low).all())
self.assertTrue((s < self.high).all())
paddle.enable_static()
class UniformTestSample2(UniformTestSample):
def init_param(self):
self.low = -5.0
self.high = 2.0
......@@ -276,6 +276,7 @@ packages=['paddle',
'paddle.incubate.tensor',
'paddle.incubate.nn',
'paddle.incubate.passes',
'paddle.distribution',
'paddle.distributed.fleet',
'paddle.distributed.fleet.base',
'paddle.distributed.fleet.elastic',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册