未验证 提交 fc8bc1ba 编写于 作者: P pangyoki 提交者: GitHub

Cherry pick 26767 (#27102)

* fix dtype not matching bug in log_prob and probs method of Distribution class (#26767)

* fix _to_tensor method of Distribution class

* Add unittest

* let dtype be consistent with value in log_prob and probs

* fix format

* fix dtype problem and change unittest

* fix dtype of Numpy class in unittest

* add formula for entropy and kl

* change formula

* fix kl formula format

* fix kl formula format 2

* change gt to np in unittest

* optimize unittest format

* delete dumplicate

* delete dumplicate 2

* extract common function used to convert dtype value

* cherry pick 27046
上级 eed05e1a
...@@ -102,21 +102,24 @@ class Distribution(object): ...@@ -102,21 +102,24 @@ class Distribution(object):
tmp = 0. tmp = 0.
for arg in args: for arg in args:
valid_arg = False
for cls in [float, list, np.ndarray, tensor.Variable]:
if isinstance(arg, cls):
valid_arg = True
break
assert valid_arg, "type of input args must be float, list, numpy.ndarray or Tensor."
if isinstance(arg, float): if isinstance(arg, float):
arg = np.zeros(1) + arg arg = [arg]
if not isinstance(arg, (list, np.ndarray, tensor.Variable)):
raise TypeError(
"Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}".
format(type(arg)))
arg_np = np.array(arg) arg_np = np.array(arg)
arg_dtype = arg_np.dtype arg_dtype = arg_np.dtype
if str(arg_dtype) not in ['float32']: if str(arg_dtype) != 'float32':
warnings.warn( if str(arg_dtype) != 'float64':
"data type of argument only support float32, your argument will be convert to float32." # "assign" op doesn't support float64. if dtype is float64, float32 variable will be generated
) # and converted to float64 later using "cast".
warnings.warn(
"data type of argument only support float32 and float64, your argument will be convert to float32."
)
arg_np = arg_np.astype('float32') arg_np = arg_np.astype('float32')
# tmp is used to support broadcast, it summarizes shapes of all the args and get the mixed shape.
tmp = tmp + arg_np tmp = tmp + arg_np
numpy_args.append(arg_np) numpy_args.append(arg_np)
...@@ -129,6 +132,37 @@ class Distribution(object): ...@@ -129,6 +132,37 @@ class Distribution(object):
return tuple(variable_args) return tuple(variable_args)
def _check_values_dtype_in_probs(self, param, value):
"""
Log_prob and probs methods have input ``value``, if value's dtype is different from param,
convert value's dtype to be consistent with param's dtype.
Args:
param (Tensor): low and high in Uniform class, loc and scale in Normal class.
value (Tensor): The input tensor.
Returns:
value (Tensor): Change value's dtype if value's dtype is different from param.
"""
if in_dygraph_mode():
if value.dtype != param.dtype and convert_dtype(
value.dtype) in ['float32', 'float64']:
warnings.warn(
"dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
)
return core.ops.cast(value, 'in_dtype', value.dtype,
'out_dtype', param.dtype)
return value
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
if value.dtype != param.dtype:
warnings.warn(
"dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
)
return tensor.cast(value, dtype=param.dtype)
return value
class Uniform(Distribution): class Uniform(Distribution):
"""Uniform distribution with `low` and `high` parameters. """Uniform distribution with `low` and `high` parameters.
...@@ -155,8 +189,8 @@ class Uniform(Distribution): ...@@ -155,8 +189,8 @@ class Uniform(Distribution):
[broadcasting](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/beginners_guide/basic_concept/broadcasting_en.html) (e.g., `high - low` is a valid operation). [broadcasting](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/beginners_guide/basic_concept/broadcasting_en.html) (e.g., `high - low` is a valid operation).
Args: Args:
low(int|float|list|numpy.ndarray|Tensor): The lower boundary of uniform distribution.The data type is int, float32, list, numpy.ndarray or Tensor low(int|float|list|numpy.ndarray|Tensor): The lower boundary of uniform distribution.The data type is int, float, list, numpy.ndarray or Tensor
high(int|float|list|numpy.ndarray|Tensor): The higher boundary of uniform distribution.The data type is int, float32, list, numpy.ndarray or Tensor high(int|float|list|numpy.ndarray|Tensor): The higher boundary of uniform distribution.The data type is int, float, list, numpy.ndarray or Tensor
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples: Examples:
...@@ -206,6 +240,7 @@ class Uniform(Distribution): ...@@ -206,6 +240,7 @@ class Uniform(Distribution):
self.all_arg_is_float = False self.all_arg_is_float = False
self.batch_size_unknown = False self.batch_size_unknown = False
self.name = name if name is not None else 'Uniform' self.name = name if name is not None else 'Uniform'
self.dtype = 'float32'
if isinstance(low, int): if isinstance(low, int):
low = float(low) low = float(low)
...@@ -216,10 +251,22 @@ class Uniform(Distribution): ...@@ -216,10 +251,22 @@ class Uniform(Distribution):
self.batch_size_unknown = True self.batch_size_unknown = True
self.low = low self.low = low
self.high = high self.high = high
self.dtype = convert_dtype(low.dtype)
else: else:
if isinstance(low, float) and isinstance(high, float): if isinstance(low, float) and isinstance(high, float):
self.all_arg_is_float = True self.all_arg_is_float = True
if isinstance(
low,
np.ndarray) and str(low.dtype) in ['float32', 'float64']:
self.dtype = low.dtype
elif isinstance(
high,
np.ndarray) and str(high.dtype) in ['float32', 'float64']:
self.dtype = high.dtype
self.low, self.high = self._to_tensor(low, high) self.low, self.high = self._to_tensor(low, high)
if self.dtype != convert_dtype(self.low.dtype):
self.low = tensor.cast(self.low, dtype=self.dtype)
self.high = tensor.cast(self.high, dtype=self.dtype)
def sample(self, shape, seed=0): def sample(self, shape, seed=0):
"""Generate samples of the specified shape. """Generate samples of the specified shape.
...@@ -241,11 +288,11 @@ class Uniform(Distribution): ...@@ -241,11 +288,11 @@ class Uniform(Distribution):
if self.batch_size_unknown: if self.batch_size_unknown:
output_shape = shape + batch_shape output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like( zero_tmp = tensor.fill_constant_batch_size_like(
self.low + self.high, batch_shape + shape, self.low.dtype, 0.) self.low + self.high, batch_shape + shape, self.dtype, 0.)
uniform_random_tmp = nn.uniform_random_batch_size_like( uniform_random_tmp = nn.uniform_random_batch_size_like(
zero_tmp, zero_tmp,
zero_tmp.shape, zero_tmp.shape,
dtype=convert_dtype(zero_tmp.dtype), dtype=self.dtype,
min=0., min=0.,
max=1., max=1.,
seed=seed) seed=seed)
...@@ -259,9 +306,8 @@ class Uniform(Distribution): ...@@ -259,9 +306,8 @@ class Uniform(Distribution):
else: else:
output_shape = shape + batch_shape output_shape = shape + batch_shape
output = nn.uniform_random( output = nn.uniform_random(
output_shape, seed=seed) * (tensor.zeros( output_shape, seed=seed, dtype=self.dtype) * (tensor.zeros(
output_shape, dtype=self.low.dtype) + output_shape, dtype=self.dtype) + (self.high - self.low))
(self.high - self.low))
output = elementwise_add(output, self.low, name=name) output = elementwise_add(output, self.low, name=name)
if self.all_arg_is_float: if self.all_arg_is_float:
return nn.reshape(output, shape, name=name) return nn.reshape(output, shape, name=name)
...@@ -279,22 +325,20 @@ class Uniform(Distribution): ...@@ -279,22 +325,20 @@ class Uniform(Distribution):
""" """
name = self.name + '_log_prob' name = self.name + '_log_prob'
value = self._check_values_dtype_in_probs(self.low, value)
if in_dygraph_mode(): if in_dygraph_mode():
# ensure value in [low, high]
lb_bool = self.low < value lb_bool = self.low < value
ub_bool = value < self.high ub_bool = value < self.high
dtype = value.dtype
lb = core.ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype', lb = core.ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
dtype) value.dtype)
ub = core.ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype', ub = core.ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
dtype) value.dtype)
return nn.log(lb * ub) - nn.log(self.high - self.low) return nn.log(lb * ub) - nn.log(self.high - self.low)
check_variable_and_dtype(value, 'value', ['float32', 'float64'], lb_bool = self.low < value
'log_prob') ub_bool = value < self.high
lb_bool = control_flow.less_than(self.low, value)
ub_bool = control_flow.less_than(value, self.high)
lb = tensor.cast(lb_bool, dtype=value.dtype) lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype) ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_sub( return elementwise_sub(
...@@ -311,22 +355,19 @@ class Uniform(Distribution): ...@@ -311,22 +355,19 @@ class Uniform(Distribution):
""" """
name = self.name + '_probs' name = self.name + '_probs'
value = self._check_values_dtype_in_probs(self.low, value)
if in_dygraph_mode(): if in_dygraph_mode():
lb_bool = self.low < value lb_bool = self.low < value
ub_bool = value < self.high ub_bool = value < self.high
dtype = value.dtype
lb = core.ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype', lb = core.ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
dtype) value.dtype)
ub = core.ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype', ub = core.ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
dtype) value.dtype)
return (lb * ub) / (self.high - self.low) return (lb * ub) / (self.high - self.low)
check_variable_and_dtype(value, 'value', ['float32', 'float64'], lb_bool = self.low < value
'log_prob') ub_bool = value < self.high
lb_bool = control_flow.less_than(self.low, value)
ub_bool = control_flow.less_than(value, self.high)
lb = tensor.cast(lb_bool, dtype=value.dtype) lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype) ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_div((lb * ub), (self.high - self.low), name=name) return elementwise_div((lb * ub), (self.high - self.low), name=name)
...@@ -334,6 +375,12 @@ class Uniform(Distribution): ...@@ -334,6 +375,12 @@ class Uniform(Distribution):
def entropy(self): def entropy(self):
"""Shannon entropy in nats. """Shannon entropy in nats.
The entropy is
.. math::
entropy(low, high) = \\log (high - low)
Returns: Returns:
Tensor: Shannon entropy of uniform distribution.The data type is float32. Tensor: Shannon entropy of uniform distribution.The data type is float32.
...@@ -364,8 +411,8 @@ class Normal(Distribution): ...@@ -364,8 +411,8 @@ class Normal(Distribution):
* :math:`Z`: is the normalization constant. * :math:`Z`: is the normalization constant.
Args: Args:
loc(int|float|list|numpy.ndarray|Tensor): The mean of normal distribution.The data type is int, float32, list, numpy.ndarray or Tensor. loc(int|float|list|numpy.ndarray|Tensor): The mean of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor.
scale(int|float|list|numpy.ndarray|Tensor): The std of normal distribution.The data type is int, float32, list, numpy.ndarray or Tensor. scale(int|float|list|numpy.ndarray|Tensor): The std of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples: Examples:
...@@ -418,6 +465,7 @@ class Normal(Distribution): ...@@ -418,6 +465,7 @@ class Normal(Distribution):
self.batch_size_unknown = False self.batch_size_unknown = False
self.all_arg_is_float = False self.all_arg_is_float = False
self.name = name if name is not None else 'Normal' self.name = name if name is not None else 'Normal'
self.dtype = 'float32'
if isinstance(loc, int): if isinstance(loc, int):
loc = float(loc) loc = float(loc)
...@@ -428,10 +476,22 @@ class Normal(Distribution): ...@@ -428,10 +476,22 @@ class Normal(Distribution):
self.batch_size_unknown = True self.batch_size_unknown = True
self.loc = loc self.loc = loc
self.scale = scale self.scale = scale
self.dtype = convert_dtype(loc.dtype)
else: else:
if isinstance(loc, float) and isinstance(scale, float): if isinstance(loc, float) and isinstance(scale, float):
self.all_arg_is_float = True self.all_arg_is_float = True
if isinstance(
loc,
np.ndarray) and str(loc.dtype) in ['float32', 'float64']:
self.dtype = loc.dtype
elif isinstance(
scale,
np.ndarray) and str(scale.dtype) in ['float32', 'float64']:
self.dtype = scale.dtype
self.loc, self.scale = self._to_tensor(loc, scale) self.loc, self.scale = self._to_tensor(loc, scale)
if self.dtype != convert_dtype(self.loc.dtype):
self.loc = tensor.cast(self.loc, dtype=self.dtype)
self.scale = tensor.cast(self.scale, dtype=self.dtype)
def sample(self, shape, seed=0): def sample(self, shape, seed=0):
"""Generate samples of the specified shape. """Generate samples of the specified shape.
...@@ -454,22 +514,18 @@ class Normal(Distribution): ...@@ -454,22 +514,18 @@ class Normal(Distribution):
if self.batch_size_unknown: if self.batch_size_unknown:
output_shape = shape + batch_shape output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like( zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape + shape, self.loc.dtype, 0.) self.loc + self.scale, batch_shape + shape, self.dtype, 0.)
zero_tmp_reshape = nn.reshape(zero_tmp, output_shape) zero_tmp_reshape = nn.reshape(zero_tmp, output_shape)
zero_tmp_shape = nn.shape(zero_tmp_reshape) zero_tmp_shape = nn.shape(zero_tmp_reshape)
normal_random_tmp = nn.gaussian_random( normal_random_tmp = nn.gaussian_random(
zero_tmp_shape, zero_tmp_shape, mean=0., std=1., seed=seed, dtype=self.dtype)
mean=0.,
std=1.,
seed=seed,
dtype=convert_dtype(self.loc.dtype))
output = normal_random_tmp * (zero_tmp_reshape + self.scale) output = normal_random_tmp * (zero_tmp_reshape + self.scale)
output = elementwise_add(output, self.loc, name=name) output = elementwise_add(output, self.loc, name=name)
return output return output
else: else:
output_shape = shape + batch_shape output_shape = shape + batch_shape
output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed) * \ output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * \
(tensor.zeros(output_shape, dtype=self.loc.dtype) + self.scale) (tensor.zeros(output_shape, dtype=self.dtype) + self.scale)
output = elementwise_add(output, self.loc, name=name) output = elementwise_add(output, self.loc, name=name)
if self.all_arg_is_float: if self.all_arg_is_float:
return nn.reshape(output, shape, name=name) return nn.reshape(output, shape, name=name)
...@@ -479,6 +535,16 @@ class Normal(Distribution): ...@@ -479,6 +535,16 @@ class Normal(Distribution):
def entropy(self): def entropy(self):
"""Shannon entropy in nats. """Shannon entropy in nats.
The entropy is
.. math::
entropy(\sigma) = 0.5 \\log (2 \pi e \sigma^2)
In the above equation:
* :math:`scale = \sigma`: is the std.
Returns: Returns:
Tensor: Shannon entropy of normal distribution.The data type is float32. Tensor: Shannon entropy of normal distribution.The data type is float32.
...@@ -486,7 +552,7 @@ class Normal(Distribution): ...@@ -486,7 +552,7 @@ class Normal(Distribution):
name = self.name + '_entropy' name = self.name + '_entropy'
batch_shape = list((self.loc + self.scale).shape) batch_shape = list((self.loc + self.scale).shape)
zero_tmp = tensor.fill_constant_batch_size_like( zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape, self.loc.dtype, 0.) self.loc + self.scale, batch_shape, self.dtype, 0.)
return elementwise_add( return elementwise_add(
0.5 + zero_tmp, 0.5 + zero_tmp,
0.5 * math.log(2 * math.pi) + nn.log((self.scale + zero_tmp)), 0.5 * math.log(2 * math.pi) + nn.log((self.scale + zero_tmp)),
...@@ -502,11 +568,9 @@ class Normal(Distribution): ...@@ -502,11 +568,9 @@ class Normal(Distribution):
Tensor: log probability.The data type is same with value. Tensor: log probability.The data type is same with value.
""" """
if not in_dygraph_mode():
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
name = self.name + '_log_prob' name = self.name + '_log_prob'
value = self._check_values_dtype_in_probs(self.loc, value)
var = self.scale * self.scale var = self.scale * self.scale
log_scale = nn.log(self.scale) log_scale = nn.log(self.scale)
return elementwise_sub( return elementwise_sub(
...@@ -524,11 +588,9 @@ class Normal(Distribution): ...@@ -524,11 +588,9 @@ class Normal(Distribution):
Tensor: probability.The data type is same with value. Tensor: probability.The data type is same with value.
""" """
if not in_dygraph_mode():
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
name = self.name + '_probs' name = self.name + '_probs'
value = self._check_values_dtype_in_probs(self.loc, value)
var = self.scale * self.scale var = self.scale * self.scale
return elementwise_div( return elementwise_div(
ops.exp(-1. * ((value - self.loc) * (value - self.loc)) / ops.exp(-1. * ((value - self.loc) * (value - self.loc)) /
...@@ -538,6 +600,29 @@ class Normal(Distribution): ...@@ -538,6 +600,29 @@ class Normal(Distribution):
def kl_divergence(self, other): def kl_divergence(self, other):
"""The KL-divergence between two normal distributions. """The KL-divergence between two normal distributions.
The probability density function (pdf) is
.. math::
KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\\frac{diff}{\sigma_1})^2 - 1 - 2 \\ln {ratio})
.. math::
ratio = \\frac{\sigma_0}{\sigma_1}
.. math::
diff = \mu_1 - \mu_0
In the above equation:
* :math:`loc = \mu_0`: is the mean of current Normal distribution.
* :math:`scale = \sigma_0`: is the std of current Normal distribution.
* :math:`loc = \mu_1`: is the mean of other Normal distribution.
* :math:`scale = \sigma_1`: is the std of other Normal distribution.
* :math:`ratio`: is the ratio of scales.
* :math:`diff`: is the difference between means.
Args: Args:
other (Normal): instance of Normal. other (Normal): instance of Normal.
......
...@@ -40,8 +40,11 @@ class DistributionNumpy(): ...@@ -40,8 +40,11 @@ class DistributionNumpy():
class UniformNumpy(DistributionNumpy): class UniformNumpy(DistributionNumpy):
def __init__(self, low, high): def __init__(self, low, high):
self.low = np.array(low).astype('float32') self.low = np.array(low)
self.high = np.array(high).astype('float32') self.high = np.array(high)
if str(self.low.dtype) not in ['float32', 'float64']:
self.low = self.low.astype('float32')
self.high = self.high.astype('float32')
def sample(self, shape): def sample(self, shape):
shape = tuple(shape) + (self.low + self.high).shape shape = tuple(shape) + (self.low + self.high).shape
...@@ -49,13 +52,13 @@ class UniformNumpy(DistributionNumpy): ...@@ -49,13 +52,13 @@ class UniformNumpy(DistributionNumpy):
(self.high - self.low)) (self.high - self.low))
def log_prob(self, value): def log_prob(self, value):
lb = np.less(self.low, value).astype('float32') lb = np.less(self.low, value).astype(self.low.dtype)
ub = np.less(value, self.high).astype('float32') ub = np.less(value, self.high).astype(self.low.dtype)
return np.log(lb * ub) - np.log(self.high - self.low) return np.log(lb * ub) - np.log(self.high - self.low)
def probs(self, value): def probs(self, value):
lb = np.less(self.low, value).astype('float32') lb = np.less(self.low, value).astype(self.low.dtype)
ub = np.less(value, self.high).astype('float32') ub = np.less(value, self.high).astype(self.low.dtype)
return (lb * ub) / (self.high - self.low) return (lb * ub) / (self.high - self.low)
def entropy(self): def entropy(self):
...@@ -64,8 +67,11 @@ class UniformNumpy(DistributionNumpy): ...@@ -64,8 +67,11 @@ class UniformNumpy(DistributionNumpy):
class NormalNumpy(DistributionNumpy): class NormalNumpy(DistributionNumpy):
def __init__(self, loc, scale): def __init__(self, loc, scale):
self.loc = np.array(loc).astype('float32') self.loc = np.array(loc)
self.scale = np.array(scale).astype('float32') self.scale = np.array(scale)
if str(self.loc.dtype) not in ['float32', 'float64']:
self.loc = self.loc.astype('float32')
self.scale = self.scale.astype('float32')
def sample(self, shape): def sample(self, shape):
shape = tuple(shape) + (self.loc + self.scale).shape shape = tuple(shape) + (self.loc + self.scale).shape
...@@ -83,8 +89,8 @@ class NormalNumpy(DistributionNumpy): ...@@ -83,8 +89,8 @@ class NormalNumpy(DistributionNumpy):
(2. * var)) / (math.sqrt(2 * math.pi) * self.scale) (2. * var)) / (math.sqrt(2 * math.pi) * self.scale)
def entropy(self): def entropy(self):
return 0.5 + 0.5 * np.log(np.array(2. * math.pi).astype( return 0.5 + 0.5 * np.log(
'float32')) + np.log(self.scale) np.array(2. * math.pi).astype(self.loc.dtype)) + np.log(self.scale)
def kl_divergence(self, other): def kl_divergence(self, other):
var_ratio = (self.scale / other.scale) var_ratio = (self.scale / other.scale)
...@@ -94,724 +100,571 @@ class NormalNumpy(DistributionNumpy): ...@@ -94,724 +100,571 @@ class NormalNumpy(DistributionNumpy):
return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio)) return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio))
class DistributionTest(unittest.TestCase): class UniformTest(unittest.TestCase):
def setUp(self, use_gpu=False): def setUp(self, use_gpu=False, batch_size=5, dims=6):
self.use_gpu = use_gpu self.use_gpu = use_gpu
if not use_gpu: if not use_gpu:
place = fluid.CPUPlace() self.place = fluid.CPUPlace()
self.gpu_id = -1 self.gpu_id = -1
else: else:
place = fluid.CUDAPlace(0) self.place = fluid.CUDAPlace(0)
self.gpu_id = 0 self.gpu_id = 0
self.executor = fluid.Executor(place)
def build_normal_common_net(self, batch_size, dims, sample_shape, loc_float,
scale_float, other_loc_float, other_scale_float,
scale_np, other_scale_np, loc_np, other_loc_np,
loc, scale, other_loc, other_scale, values):
"""Generate Normal object and get the output of its methods including
``sample``, ``entropy``, ``log_prob``, ``probs`` and ``kl_divergence``.
Parameters ``loc`` and ``scale`` have different data types to test different situations.
Args:
batch_size(int): The first dimension of the shape of parameters(loc and scale).
dims(int): The second dimension of the shape of parameters.
sample_shape(int): The sample value used in ``sample`` method.
loc_float(float): Generated in function ``get_normal_random_input``, loc is a float number.
scale_float(float): Generated in function ``get_normal_random_input``, scale is a float number.
other_loc_float(float): Generated in function ``get_normal_random_input``, other_loc is a
float number. It is the first parameter in another Normal object used in ``kl_divergence``
method.
other_scale_float(float): Generated in function ``get_normal_random_input``, other_scale is a
float number. It is the second parameter in another Normal object used in ``kl_divergence``
method.
scale_np(numpy.ndarray): Generated in function ``get_normal_random_input``, An numpy array
whose shape is [batch_size, dims].
other_scale_np(numpy.ndarray): Generated in function ``get_normal_random_input``, other_scale_np
is an numpy array. It is the second parameter in another Normal object used in ``kl_divergence``
method.
loc_np(numpy.ndarray): Generated in function ``get_normal_random_input``, An numpy array
whose shape is [batch_size, dims].
other_loc_np(numpy.ndarray): Generated in function ``get_normal_random_input``, other_loc_np
is an numpy array. It is the first parameter in another Normal object used in ``kl_divergence``
method.
loc(Tensor): In dynamic mode, loc is generated in ``build_normal_dygraph``, it's a Tensor filled
with ``loc_np`` data. In static mode, loc is generated in ``build_normal_static``, ``layers.data``
method is used to get a Placeholder whose shape is [dims].
scale(Tensor): In dynamic mode, scale is generated in ``build_normal_dygraph``, it's a Tensor filled
with ``scale_np`` data. In static mode, scale is generated in ``build_normal_static``, ``layers.data``
method is used to get a Placeholder whose shape is [dims].
other_loc(Tensor): In dynamic mode, other_loc is generated in ``build_normal_dygraph``, it's a Tensor
filled with ``other_loc_np`` data. In static mode, other_loc is generated in ``build_normal_static``,
``layers.data`` method is used to get a Placeholder whose shape is [dims]. It is the first parameter
in another Normal object used in ``kl_divergence`` method.
other_scale(Tensor): In dynamic mode, other_scale is generated in ``build_normal_dygraph``, it's a Tensor
filled with ``other_scale_np`` data. In static mode, other_scale is generated in ``build_normal_static``,
``layers.data`` method is used to get a Placeholder whose shape is [dims]. It is the second parameter
in another Normal object used in ``kl_divergence`` method.
values(Tensor): In dynamic mode, values is generated in ``build_normal_dygraph``, it's a Tensor filled with
``values_np`` data. In static mode, values is generated in ``build_normal_static``, ``layers.data``
method is used to get a Placeholder whose shape is [dims].
Returns:
List: The elements of the list are the output of sample, entropy, log_prob, probs, kl_divergence methods.
The inputs' type of these methods can be float, np.ndarray and Tensor. And broadcast will be considered.
"""
normal_int = Normal(int(loc_float), int(scale_float))
normal_float = Normal(loc_float, scale_float)
other_normal_float = Normal(other_loc_float, other_scale_float)
normal_float_np_broadcast = Normal(loc_float, scale_np)
other_normal_float_np_broadcast = Normal(other_loc_float,
other_scale_np)
normal_np = Normal(loc_np, scale_np)
other_normal_np = Normal(other_loc_np, other_scale_np)
normal_variable = Normal(loc, scale)
other_normal_variable = Normal(other_loc, other_scale)
sample_int = normal_int.sample([batch_size, dims])
sample_float = normal_float.sample([batch_size, dims])
sample_float_np_broadcast = normal_float_np_broadcast.sample(
[batch_size, dims])
sample_np = normal_np.sample([batch_size, dims])
sample_variable = normal_variable.sample([batch_size, dims])
sample_int_diff = normal_int.sample([sample_shape])
sample_float_diff = normal_float.sample([sample_shape])
sample_float_np_broadcast_diff = normal_float_np_broadcast.sample(
[sample_shape])
sample_np_diff = normal_np.sample([sample_shape])
sample_variable_diff = normal_variable.sample([sample_shape])
entropy_int = normal_int.entropy()
entropy_float = normal_float.entropy()
entropy_float_np_broadcast = normal_float_np_broadcast.entropy()
entropy_np = normal_np.entropy()
entropy_variable = normal_variable.entropy()
lp_float_np_broadcast = normal_float_np_broadcast.log_prob(values)
lp_np = normal_np.log_prob(values)
lp_variable = normal_variable.log_prob(values)
p_float_np_broadcast = normal_float_np_broadcast.probs(values)
p_np = normal_np.probs(values)
p_variable = normal_variable.probs(values)
kl_float = normal_float.kl_divergence(other_normal_float)
kl_float_np_broadcast = normal_float_np_broadcast.kl_divergence(
other_normal_float_np_broadcast)
kl_np = normal_np.kl_divergence(other_normal_np)
kl_variable = normal_variable.kl_divergence(other_normal_variable)
fetch_list = [
sample_int, sample_float, sample_float_np_broadcast, sample_np,
sample_variable, sample_int_diff, sample_float_diff,
sample_float_np_broadcast_diff, sample_np_diff,
sample_variable_diff, entropy_int, entropy_float,
entropy_float_np_broadcast, entropy_np, entropy_variable,
lp_float_np_broadcast, lp_np, lp_variable, p_float_np_broadcast,
p_np, p_variable, kl_float, kl_float_np_broadcast, kl_np,
kl_variable
]
return fetch_list
def build_normal_static(self, test_program, batch_size, dims, sample_shape,
loc_float, scale_float, other_loc_float,
other_scale_float, scale_np, other_scale_np, loc_np,
other_loc_np, values_np):
"""
In static mode, generate feed data of Normal network, and get output fetch_list using
``build_normal_common_net``.
Args:
test_program: In static mode, the Program object.
other args can refer to function ``build_normal_common_net``.
Returns:
feed_vars: The feed data of Normal network in static mode.
fetch_list: The output is generated by function ``build_normal_common_net``.
"""
with fluid.program_guard(test_program):
loc = layers.data(name='loc', shape=[dims], dtype='float32')
scale = layers.data(name='scale', shape=[dims], dtype='float32')
other_loc = layers.data(
name='other_loc', shape=[dims], dtype='float32')
other_scale = layers.data(
name='other_scale', shape=[dims], dtype='float32')
values = layers.data(name='values', shape=[dims], dtype='float32') self.init_numpy_data(batch_size, dims)
fetch_list = self.build_normal_common_net( paddle.disable_static(self.place)
batch_size, dims, sample_shape, loc_float, scale_float, self.init_dynamic_data(batch_size, dims)
other_loc_float, other_scale_float, scale_np, other_scale_np,
loc_np, other_loc_np, loc, scale, other_loc, other_scale,
values)
feed_vars = { paddle.enable_static()
'loc': loc_np, self.test_program = fluid.Program()
'scale': scale_np, self.executor = fluid.Executor(self.place)
'other_loc': other_loc_np, self.init_static_data(batch_size, dims)
'other_scale': other_scale_np,
'values': values_np def init_numpy_data(self, batch_size, dims):
} # low ans high are 'float'
return feed_vars, fetch_list self.low_np = np.random.uniform(-2, 1)
self.high_np = np.random.uniform(1, 3)
def build_normal_dygraph(self, batch_size, dims, sample_shape, loc_float, self.values_np = np.array([1.0]).astype('float32')
scale_float, other_loc_float, other_scale_float,
scale_np, other_scale_np, loc_np, other_loc_np, def init_dynamic_data(self, batch_size, dims):
values_np): self.dynamic_low = self.low_np
""" self.dynamic_high = self.high_np
In dynamic mode, generate input data of Normal network, and get output fetch_list using self.dynamic_values = paddle.to_tensor(self.values_np)
``build_normal_common_net``.
def init_static_data(self, batch_size, dims):
Args: self.static_low = self.low_np
refer to function ``build_normal_common_net``. self.static_high = self.high_np
with fluid.program_guard(self.test_program):
Returns: self.static_values = layers.data(
fetch_list_numpy: The output is generated by function ``build_normal_common_net``. Transform name='values', shape=[], dtype='float32')
these tensor to numpy.ndarray.
""" def compare_with_numpy(self, fetch_list, sample_shape=7, tolerance=1e-6):
loc = paddle.to_tensor(loc_np) sample, entropy, log_prob, probs = fetch_list
scale = paddle.to_tensor(scale_np)
other_loc = paddle.to_tensor(other_loc_np) np_uniform = UniformNumpy(self.low_np, self.high_np)
other_scale = paddle.to_tensor(other_scale_np) np_sample = np_uniform.sample([sample_shape])
values = paddle.to_tensor(values_np) np_entropy = np_uniform.entropy()
np_lp = np_uniform.log_prob(self.values_np)
fetch_list = self.build_normal_common_net( np_p = np_uniform.probs(self.values_np)
batch_size, dims, sample_shape, loc_float, scale_float,
other_loc_float, other_scale_float, scale_np, other_scale_np, np.testing.assert_equal(sample.shape, np_sample.shape)
loc_np, other_loc_np, loc, scale, other_loc, other_scale, values)
fetch_list_numpy = [t.numpy() for t in fetch_list]
return fetch_list_numpy
def get_normal_random_input(self, batch_size, dims):
"""
Generate input data ``loc`` and ``scale`` used in Normal network.
Args:
refer to function ``build_normal_common_net``.
Returns:
List: Different data type of ``loc`` and ``scale``, including float, numpy.ndarray.
By the way, ``other_loc`` and ``other_scale`` are used in ``kl_divergence`` method.
refer to ``args`` in function ``build_normal_common_net``.
"""
loc_np = np.random.randn(batch_size, dims).astype('float32')
other_loc_np = np.random.randn(batch_size, dims).astype('float32')
loc_float = (np.random.ranf() - 0.5) * 4
scale_float = (np.random.ranf() - 0.5) * 4
while scale_float < 0:
scale_float = (np.random.ranf() - 0.5) * 4
other_loc_float = (np.random.ranf() - 0.5) * 4
other_scale_float = (np.random.ranf() - 0.5) * 4
while other_scale_float < 0:
other_scale_float = (np.random.ranf() - 0.5) * 4
scale_np = np.random.randn(batch_size, dims).astype('float32')
other_scale_np = np.random.randn(batch_size, dims).astype('float32')
values_np = np.random.randn(batch_size, dims).astype('float32')
while not np.all(scale_np > 0):
scale_np = np.random.randn(batch_size, dims).astype('float32')
while not np.all(other_scale_np > 0):
other_scale_np = np.random.randn(batch_size, dims).astype('float32')
return [
loc_np, other_loc_np, loc_float, scale_float, other_loc_float,
other_scale_float, scale_np, other_scale_np, values_np
]
def compare_normal_with_numpy(self,
data_list,
output_list,
batch_size=2,
dims=3,
sample_shape=7,
tolerance=1e-6):
"""
Compare the outputs of Normal's methods in paddle and numpy. If the outputs are not consistent,
raise errors.
Args:
data_list: Input data generated by function ``get_normal_random_input``.
output_list: The outputs of Normal's methods in static or dynamic mode.
batch_size(int): The first dimension of the shape of parameters(loc and scale).
dims(int): The second dimension of the shape of parameters.
sample_shape(int): The sample value used in ``sample`` method.
tolerance(float): The tolerance of the error.
"""
loc_np, other_loc_np, loc_float, scale_float, other_loc_float, other_scale_float, scale_np, other_scale_np, values_np = data_list
np_normal_int = NormalNumpy(int(loc_float), int(scale_float))
np_normal_float = NormalNumpy(loc_float, scale_float)
np_other_normal_float = NormalNumpy(other_loc_float, other_scale_float)
np_normal_float_np_broadcast = NormalNumpy(loc_float, scale_np)
np_other_normal_float_np_broadcast = NormalNumpy(other_loc_float,
other_scale_np)
np_normal = NormalNumpy(loc_np, scale_np)
np_other_normal = NormalNumpy(other_loc_np, other_scale_np)
gt_sample_int = np_normal_int.sample([batch_size, dims])
gt_sample_float = np_normal_float.sample([batch_size, dims])
gt_sample_float_np_broadcast = np_normal_float_np_broadcast.sample(
[batch_size, dims])
gt_sample_np = np_normal.sample([batch_size, dims])
gt_sample_int_diff = np_normal_int.sample([sample_shape])
gt_sample_float_diff = np_normal_float.sample([sample_shape])
gt_sample_float_np_broadcast_diff = np_normal_float_np_broadcast.sample(
[sample_shape])
gt_sample_np_diff = np_normal.sample([sample_shape])
gt_entropy_int = np_normal_int.entropy()
gt_entropy_float = np_normal_float.entropy()
gt_entropy_float_np_broadcast = np_normal_float_np_broadcast.entropy()
gt_entropy = np_normal.entropy()
gt_lp_float_np_broadcast = np_normal_float_np_broadcast.log_prob(
values_np)
gt_lp = np_normal.log_prob(values_np)
gt_p_float_np_broadcast = np_normal_float_np_broadcast.probs(values_np)
gt_p = np_normal.probs(values_np)
gt_kl_float = np_normal_float.kl_divergence(np_other_normal_float)
gt_kl_float_np_broadcast = np_normal_float_np_broadcast.kl_divergence(
np_other_normal_float_np_broadcast)
gt_kl = np_normal.kl_divergence(np_other_normal)
[
output_sample_int, output_sample_float,
output_sample_float_np_broadcast, output_sample_np,
output_sample_variable, output_sample_int_diff,
output_sample_float_diff, output_sample_float_np_broadcast_diff,
output_sample_np_diff, output_sample_variable_diff,
output_entropy_int, output_entropy_float,
output_entropy_float_np_broadcast, output_entropy_np,
output_entropy_variable, output_lp_float_np_broadcast, output_lp_np,
output_lp_variable, output_p_float_np_broadcast, output_p_np,
output_p_variable, output_kl_float, output_kl_float_np_broadcast,
output_kl_np, output_kl_variable
] = output_list
np.testing.assert_equal(output_sample_int.shape, gt_sample_int.shape)
np.testing.assert_equal(output_sample_float.shape,
gt_sample_float.shape)
np.testing.assert_equal(output_sample_float_np_broadcast.shape,
gt_sample_float_np_broadcast.shape)
np.testing.assert_equal(output_sample_np.shape, gt_sample_np.shape)
np.testing.assert_equal(output_sample_variable.shape,
gt_sample_np.shape)
np.testing.assert_equal(output_sample_int_diff.shape,
gt_sample_int_diff.shape)
np.testing.assert_equal(output_sample_float_diff.shape,
gt_sample_float_diff.shape)
np.testing.assert_equal(output_sample_float_np_broadcast_diff.shape,
gt_sample_float_np_broadcast_diff.shape)
np.testing.assert_equal(output_sample_np_diff.shape,
gt_sample_np_diff.shape)
np.testing.assert_equal(output_sample_variable_diff.shape,
gt_sample_np_diff.shape)
np.testing.assert_allclose(
output_entropy_int, gt_entropy_int, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_entropy_float,
gt_entropy_float,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_float_np_broadcast,
gt_entropy_float_np_broadcast,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_np, gt_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_entropy_variable, gt_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_lp_float_np_broadcast,
gt_lp_float_np_broadcast,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_lp_np, gt_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_lp_variable, gt_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_p_float_np_broadcast,
gt_p_float_np_broadcast,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_p_np, gt_p, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_p_variable, gt_p, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_kl_float, gt_kl_float, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_kl_float_np_broadcast,
gt_kl_float_np_broadcast,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose( np.testing.assert_allclose(
output_kl_np, gt_kl, rtol=tolerance, atol=tolerance) entropy, np_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose( np.testing.assert_allclose(
output_kl_variable, gt_kl, rtol=tolerance, atol=tolerance) log_prob, np_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(probs, np_p, rtol=tolerance, atol=tolerance)
def test_normal_distribution_static(self,
batch_size=2,
dims=3,
sample_shape=7,
tolerance=1e-6):
"""
Test Normal's methods in static mode.
Args:
refer to ``compare_normal_with_numpy`` function.
"""
test_program = fluid.Program()
data_list = self.get_normal_random_input(batch_size, dims)
loc_np, other_loc_np, loc_float, scale_float, other_loc_float, other_scale_float, scale_np, other_scale_np, values_np = data_list
feed_vars, fetch_list = self.build_normal_static(
test_program, batch_size, dims, sample_shape, loc_float,
scale_float, other_loc_float, other_scale_float, scale_np,
other_scale_np, loc_np, other_loc_np, values_np)
self.executor.run(fluid.default_startup_program())
output_list = self.executor.run(program=test_program, def test_uniform_distribution_dygraph(self, sample_shape=7, tolerance=1e-6):
feed=feed_vars, paddle.disable_static(self.place)
fetch_list=fetch_list) uniform = Uniform(self.dynamic_low, self.dynamic_high)
sample = uniform.sample([sample_shape]).numpy()
self.compare_normal_with_numpy(data_list, output_list, batch_size, dims, entropy = uniform.entropy().numpy()
sample_shape, tolerance) log_prob = uniform.log_prob(self.dynamic_values).numpy()
probs = uniform.probs(self.dynamic_values).numpy()
def test_normal_distribution_dygraph(self, fetch_list = [sample, entropy, log_prob, probs]
batch_size=2,
dims=3, self.compare_with_numpy(fetch_list)
sample_shape=7,
tolerance=1e-6): def test_uniform_distribution_static(self, sample_shape=7, tolerance=1e-6):
"""
Test Normal's methods in dynamic mode.
Args:
refer to ``compare_normal_with_numpy`` function.
"""
paddle.disable_static()
data_list = self.get_normal_random_input(batch_size, dims)
loc_np, other_loc_np, loc_float, scale_float, other_loc_float, other_scale_float, scale_np, other_scale_np, values_np = data_list
output_list = self.build_normal_dygraph(
batch_size, dims, sample_shape, loc_float, scale_float,
other_loc_float, other_scale_float, scale_np, other_scale_np,
loc_np, other_loc_np, values_np)
self.compare_normal_with_numpy(data_list, output_list, batch_size, dims,
sample_shape, tolerance)
paddle.enable_static() paddle.enable_static()
with fluid.program_guard(self.test_program):
uniform = Uniform(self.static_low, self.static_high)
sample = uniform.sample([sample_shape])
entropy = uniform.entropy()
log_prob = uniform.log_prob(self.static_values)
probs = uniform.probs(self.static_values)
fetch_list = [sample, entropy, log_prob, probs]
def build_uniform_common_net(self, batch_size, dims, sample_shape, feed_vars = {
low_float, high_float, high_np, low_np, 'low': self.low_np,
values_np, low, high, values): 'high': self.high_np,
"""Generate Uniform object and get the output of its methods including ``sample``, ``entropy``, 'values': self.values_np
``log_prob`` and ``probs``. }
Parameters ``low`` and ``high`` have different data types to test different situations.
self.executor.run(fluid.default_startup_program())
Args: fetch_list = self.executor.run(program=self.test_program,
batch_size(int): The first dimension of the shape of parameters(low and high). feed=feed_vars,
dims(int): The second dimension of the shape of parameters. fetch_list=fetch_list)
sample_shape(int): The sample value used in ``sample`` method.
low_float(float): Parameter ``low`` is a float number. self.compare_with_numpy(fetch_list)
high_float(float): Parameter ``high`` is a float number.
high_np(numpy.ndarray): An numpy array whose shape is [batch_size, dims].
low_np(numpy.ndarray): An numpy array whose shape is [batch_size, dims]. class UniformTest2(UniformTest):
values_np(numpy.ndarray): The input of ``log_prob`` and ``probs`` methods. An numpy array whose def init_numpy_data(self, batch_size, dims):
shape is [batch_size, dims]. # low ans high are 'int'
low(Tensor): In dynamic mode, low is generated in ``build_uniform_dygraph``, it's a Tensor filled self.low_np = int(np.random.uniform(-2, 1))
with ``low_np`` data. In static mode, low is generated in ``build_uniform_static``. self.high_np = int(np.random.uniform(1, 3))
high(Tensor): In dynamic mode, high is generated in ``build_uniform_dygraph``, it's a Tensor filled self.values_np = np.array([1.0]).astype('float32')
with ``high_np`` data. In static mode, high is generated in ``build_uniform_static``.
values(Tensor): In dynamic mode, values is generated in ``build_uniform_dygraph``, it's a Tensor
filled with ``values_np`` data. In static mode, values is generated in ``build_uniform_static``. class UniformTest3(UniformTest):
def init_numpy_data(self, batch_size, dims):
Returns: # test broadcast: low is float, high is numpy.ndarray with dtype 'float32'.
List: The elements of the list are the output of sample, entropy, log_prob, probs methods. self.low_np = np.random.uniform(-2, 1)
The inputs' type of these methods can be float, np.ndarray and Tensor. And broadcast will be self.high_np = np.random.uniform(-5.0, 5.0,
considered. (batch_size, dims)).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
"""
uniform_int = Uniform(int(low_float), int(high_float)) def init_static_data(self, batch_size, dims):
uniform_float = Uniform(low_float, high_float) self.static_low = self.low_np
uniform_float_np_broadcast = Uniform(low_float, high_np) self.static_high = self.high_np
uniform_np = Uniform(low_np, high_np) with fluid.program_guard(self.test_program):
uniform_variable = Uniform(low, high) self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
sample_int = uniform_int.sample([batch_size, dims])
sample_float = uniform_float.sample([batch_size, dims])
sample_float_np_broadcast = uniform_float_np_broadcast.sample( class UniformTest4(UniformTest):
[batch_size, dims]) def init_numpy_data(self, batch_size, dims):
sample_np = uniform_np.sample([batch_size, dims]) # low and high are numpy.ndarray with dtype 'float32'.
sample_variable = uniform_variable.sample([batch_size, dims]) self.low_np = np.random.randn(batch_size, dims).astype('float32')
self.high_np = np.random.uniform(-5.0, 5.0,
sample_int_diff = uniform_int.sample([sample_shape]) (batch_size, dims)).astype('float32')
sample_float_diff = uniform_float.sample([sample_shape]) self.values_np = np.random.randn(batch_size, dims).astype('float32')
sample_float_np_broadcast_diff = uniform_float_np_broadcast.sample(
[sample_shape]) def init_static_data(self, batch_size, dims):
sample_np_diff = uniform_np.sample([sample_shape]) self.static_low = self.low_np
sample_variable_diff = uniform_variable.sample([sample_shape]) self.static_high = self.high_np
with fluid.program_guard(self.test_program):
entropy_int = uniform_int.entropy() self.static_values = layers.data(
entropy_float = uniform_float.entropy() name='values', shape=[dims], dtype='float32')
entropy_float_np_broadcast = uniform_float_np_broadcast.entropy()
entropy_np = uniform_np.entropy()
entropy_variable = uniform_variable.entropy() class UniformTest5(UniformTest):
def init_numpy_data(self, batch_size, dims):
lp_float_np_broadcast = uniform_float_np_broadcast.log_prob(values) # low and high are numpy.ndarray with dtype 'float64'.
lp_np = uniform_np.log_prob(values) self.low_np = np.random.randn(batch_size, dims).astype('float64')
lp_variable = uniform_variable.log_prob(values) self.high_np = np.random.uniform(-5.0, 5.0,
(batch_size, dims)).astype('float64')
p_float_np_broadcast = uniform_float_np_broadcast.probs(values) self.values_np = np.random.randn(batch_size, dims).astype('float64')
p_np = uniform_np.probs(values)
p_variable = uniform_variable.probs(values) def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = self.low_np
fetch_list = [ self.dynamic_high = self.high_np
sample_int, sample_float, sample_float_np_broadcast, sample_np, self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64')
sample_variable, sample_int_diff, sample_float_diff,
sample_float_np_broadcast_diff, sample_np_diff, def init_static_data(self, batch_size, dims):
sample_variable_diff, entropy_int, entropy_float, self.static_low = self.low_np
entropy_float_np_broadcast, entropy_np, entropy_variable, self.static_high = self.high_np
lp_float_np_broadcast, lp_np, lp_variable, p_float_np_broadcast, with fluid.program_guard(self.test_program):
p_np, p_variable self.static_values = layers.data(
] name='values', shape=[dims], dtype='float64')
return fetch_list
def build_uniform_static(self, test_program, batch_size, dims, sample_shape, class UniformTest6(UniformTest):
low_float, high_float, high_np, low_np, values_np): def init_numpy_data(self, batch_size, dims):
""" # low and high are Tensor with dtype 'VarType.FP32'.
In static mode, generate feed data of Uniform network, and get output fetch_list using self.low_np = np.random.randn(batch_size, dims).astype('float32')
``build_uniform_common_net``. self.high_np = np.random.uniform(-5.0, 5.0,
(batch_size, dims)).astype('float32')
Args: self.values_np = np.random.randn(batch_size, dims).astype('float32')
test_program: In static mode, the Program object.
other args can refer to function ``build_uniform_common_net``. def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = paddle.to_tensor(self.low_np)
Returns: self.dynamic_high = paddle.to_tensor(self.high_np)
feed_vars: The feed data of Uniform network in static mode. self.dynamic_values = paddle.to_tensor(self.values_np)
fetch_list: The output is generated by function ``build_uniform_common_net``.
""" def init_static_data(self, batch_size, dims):
with fluid.program_guard(test_program): with fluid.program_guard(self.test_program):
low = layers.data(name='low', shape=[dims], dtype='float32') self.static_low = layers.data(
high = layers.data(name='high', shape=[dims], dtype='float32') name='low', shape=[dims], dtype='float32')
self.static_high = layers.data(
values = layers.data(name='values', shape=[dims], dtype='float32') name='high', shape=[dims], dtype='float32')
self.static_values = layers.data(
fetch_list = self.build_uniform_common_net( name='values', shape=[dims], dtype='float32')
batch_size, dims, sample_shape, low_float, high_float, high_np,
low_np, values_np, low, high, values)
class UniformTest7(UniformTest):
feed_vars = {'low': low_np, 'high': high_np, 'values': values_np} def init_numpy_data(self, batch_size, dims):
return feed_vars, fetch_list # low and high are Tensor with dtype 'VarType.FP64'.
self.low_np = np.random.randn(batch_size, dims).astype('float64')
def build_uniform_dygraph(self, batch_size, dims, sample_shape, low_float, self.high_np = np.random.uniform(-5.0, 5.0,
high_float, high_np, low_np, values_np): (batch_size, dims)).astype('float64')
""" self.values_np = np.random.randn(batch_size, dims).astype('float64')
In dynamic mode, generate input data of Uniform network, and get output fetch_list using
``build_uniform_common_net``. def init_dynamic_data(self, batch_size, dims):
self.dynamic_low = paddle.to_tensor(self.low_np, dtype='float64')
Args: self.dynamic_high = paddle.to_tensor(self.high_np, dtype='float64')
refer to function ``build_uniform_common_net``. self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64')
Returns: def init_static_data(self, batch_size, dims):
fetch_list_numpy: The output is generated by function ``build_uniform_common_net``. Transform with fluid.program_guard(self.test_program):
these tensor to numpy.ndarray. self.static_low = layers.data(
""" name='low', shape=[dims], dtype='float64')
low = paddle.to_tensor(low_np) self.static_high = layers.data(
high = paddle.to_tensor(high_np) name='high', shape=[dims], dtype='float64')
values = paddle.to_tensor(values_np) self.static_values = layers.data(
name='values', shape=[dims], dtype='float64')
fetch_list = self.build_uniform_common_net(
batch_size, dims, sample_shape, low_float, high_float, high_np,
low_np, values_np, low, high, values) class UniformTest8(UniformTest):
fetch_list_numpy = [t.numpy() for t in fetch_list] def init_numpy_data(self, batch_size, dims):
return fetch_list_numpy # low and high are Tensor with dtype 'VarType.FP64'. value's dtype is 'VarType.FP32'.
self.low_np = np.random.randn(batch_size, dims).astype('float64')
def compare_uniform_with_numpy(self, self.high_np = np.random.uniform(-5.0, 5.0,
data_list, (batch_size, dims)).astype('float64')
output_list, self.values_np = np.random.randn(batch_size, dims).astype('float32')
batch_size=2,
dims=3, def init_dynamic_data(self, batch_size, dims):
sample_shape=7, self.dynamic_low = paddle.to_tensor(self.low_np, dtype='float64')
tolerance=1e-6): self.dynamic_high = paddle.to_tensor(self.high_np, dtype='float64')
""" self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float32')
Compare the outputs of Uniform's methods in paddle and numpy. If the outputs are not consistent,
raise errors. def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
Args: self.static_low = layers.data(
data_list: Input data including float and numpy.ndarray type of ``low`` and ``high`` parameters. name='low', shape=[dims], dtype='float64')
output_list: The outputs of Uniform's methods in static or dynamic mode. self.static_high = layers.data(
batch_size(int): The first dimension of the shape of parameters(low and high). name='high', shape=[dims], dtype='float64')
dims(int): The second dimension of the shape of parameters. self.static_values = layers.data(
sample_shape(int): The sample value used in ``sample`` method. name='values', shape=[dims], dtype='float32')
tolerance(float): The tolerance of the error.
"""
[low_np, low_float, high_float, high_np, values_np] = data_list class NormalTest(unittest.TestCase):
def setUp(self, use_gpu=False, batch_size=2, dims=3):
np_uniform_int = UniformNumpy(int(low_float), int(high_float)) self.use_gpu = use_gpu
np_uniform_float = UniformNumpy(low_float, high_float) if not use_gpu:
np_uniform_float_np_broadcast = UniformNumpy(low_float, high_np) self.place = fluid.CPUPlace()
np_uniform = UniformNumpy(low_np, high_np) self.gpu_id = -1
else:
gt_sample_int = np_uniform_int.sample([batch_size, dims]) self.place = fluid.CUDAPlace(0)
gt_sample_float = np_uniform_float.sample([batch_size, dims]) self.gpu_id = 0
gt_sample_float_np_broadcast = np_uniform_float_np_broadcast.sample(
[batch_size, dims]) self.init_numpy_data(batch_size, dims)
gt_sample_np = np_uniform.sample([batch_size, dims])
gt_sample_int_diff = np_uniform_int.sample([sample_shape]) paddle.disable_static(self.place)
gt_sample_float_diff = np_uniform_float.sample([sample_shape]) self.init_dynamic_data(batch_size, dims)
gt_sample_float_np_broadcast_diff = np_uniform_float_np_broadcast.sample(
[sample_shape]) paddle.enable_static()
gt_sample_np_diff = np_uniform.sample([sample_shape]) self.test_program = fluid.Program()
gt_entropy_int = np_uniform_int.entropy() self.executor = fluid.Executor(self.place)
gt_entropy_float = np_uniform_float.entropy() self.init_static_data(batch_size, dims)
gt_entropy_float_np_broadcast = np_uniform_float_np_broadcast.entropy()
gt_entropy = np_uniform.entropy() def init_numpy_data(self, batch_size, dims):
gt_lp_float_np_broadcast = np_uniform_float_np_broadcast.log_prob( # loc ans scale are 'float'
values_np) self.loc_np = (np.random.ranf() - 0.5) * 4
gt_lp = np_uniform.log_prob(values_np) self.scale_np = (np.random.ranf() - 0.5) * 4
gt_p_float_np_broadcast = np_uniform_float_np_broadcast.probs(values_np) while self.scale_np < 0:
gt_p = np_uniform.probs(values_np) self.scale_np = (np.random.ranf() - 0.5) * 4
# used to construct another Normal object to calculate kl_divergence
[ self.other_loc_np = (np.random.ranf() - 0.5) * 4
output_sample_int, output_sample_float, self.other_scale_np = (np.random.ranf() - 0.5) * 4
output_sample_float_np_broadcast, output_sample_np, while self.other_scale_np < 0:
output_sample_variable, output_sample_int_diff, self.other_scale_np = (np.random.ranf() - 0.5) * 4
output_sample_float_diff, output_sample_float_np_broadcast_diff, self.values_np = np.random.ranf(1).astype('float32')
output_sample_np_diff, output_sample_variable_diff,
output_entropy_int, output_entropy_float, def init_dynamic_data(self, batch_size, dims):
output_entropy_float_np_broadcast, output_entropy_np, self.dynamic_loc = self.loc_np
output_entropy_variable, output_lp_float_np_broadcast, output_lp_np, self.dynamic_scale = self.scale_np
output_lp_variable, output_p_float_np_broadcast, output_p_np, self.dynamic_other_loc = self.other_loc_np
output_p_variable self.dynamic_other_scale = self.other_scale_np
] = output_list self.dynamic_values = paddle.to_tensor(self.values_np)
np.testing.assert_equal(output_sample_int.shape, gt_sample_int.shape) def init_static_data(self, batch_size, dims):
np.testing.assert_equal(output_sample_float.shape, self.static_loc = self.loc_np
gt_sample_float.shape) self.static_scale = self.scale_np
np.testing.assert_equal(output_sample_float_np_broadcast.shape, self.static_other_loc = self.other_loc_np
gt_sample_float_np_broadcast.shape) self.static_other_scale = self.other_scale_np
np.testing.assert_equal(output_sample_np.shape, gt_sample_np.shape) with fluid.program_guard(self.test_program):
np.testing.assert_equal(output_sample_variable.shape, self.static_values = layers.data(
gt_sample_np.shape) name='values', shape=[], dtype='float32')
np.testing.assert_equal(output_sample_int_diff.shape,
gt_sample_int_diff.shape) def compare_with_numpy(self, fetch_list, sample_shape=7, tolerance=1e-6):
np.testing.assert_equal(output_sample_float_diff.shape, sample, entropy, log_prob, probs, kl = fetch_list
gt_sample_float_diff.shape)
np.testing.assert_equal(output_sample_float_np_broadcast_diff.shape, np_normal = NormalNumpy(self.loc_np, self.scale_np)
gt_sample_float_np_broadcast_diff.shape) np_sample = np_normal.sample([sample_shape])
np.testing.assert_equal(output_sample_np_diff.shape, np_entropy = np_normal.entropy()
gt_sample_np_diff.shape) np_lp = np_normal.log_prob(self.values_np)
np.testing.assert_equal(output_sample_variable_diff.shape, np_p = np_normal.probs(self.values_np)
gt_sample_np_diff.shape) np_other_normal = NormalNumpy(self.other_loc_np, self.other_scale_np)
np.testing.assert_allclose( np_kl = np_normal.kl_divergence(np_other_normal)
output_entropy_int, gt_entropy_int, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose( np.testing.assert_equal(sample.shape, np_sample.shape)
output_entropy_float,
gt_entropy_float,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_float_np_broadcast,
gt_entropy_float_np_broadcast,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_np, gt_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_entropy_variable, gt_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_lp_float_np_broadcast,
gt_lp_float_np_broadcast,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_lp_np, gt_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_lp_variable, gt_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_p_float_np_broadcast,
gt_p_float_np_broadcast,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose( np.testing.assert_allclose(
output_p_np, gt_p, rtol=tolerance, atol=tolerance) entropy, np_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose( np.testing.assert_allclose(
output_p_variable, gt_p, rtol=tolerance, atol=tolerance) log_prob, np_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(probs, np_p, rtol=tolerance, atol=tolerance)
def test_uniform_distribution_static(self, np.testing.assert_allclose(kl, np_kl, rtol=tolerance, atol=tolerance)
batch_size=2,
dims=3,
sample_shape=7,
tolerance=1e-6):
"""
Test Uniform's methods in static mode.
Args:
refer to ``compare_uniform_with_numpy`` function.
"""
test_program = fluid.Program()
low_np = np.random.randn(batch_size, dims).astype('float32')
low_float = np.random.uniform(-2, 1)
high_float = np.random.uniform(1, 3)
high_np = np.random.uniform(-5.0, 5.0,
(batch_size, dims)).astype('float32')
values_np = np.random.randn(batch_size, dims).astype('float32')
data_list = [low_np, low_float, high_float, high_np, values_np]
feed_vars, fetch_list = self.build_uniform_static(
test_program, batch_size, dims, sample_shape, low_float, high_float,
high_np, low_np, values_np)
self.executor.run(fluid.default_startup_program()) def test_normal_distribution_dygraph(self, sample_shape=7, tolerance=1e-6):
paddle.disable_static(self.place)
normal = Normal(self.dynamic_loc, self.dynamic_scale)
sample = normal.sample([sample_shape]).numpy()
entropy = normal.entropy().numpy()
log_prob = normal.log_prob(self.dynamic_values).numpy()
probs = normal.probs(self.dynamic_values).numpy()
other_normal = Normal(self.dynamic_other_loc, self.dynamic_other_scale)
kl = normal.kl_divergence(other_normal).numpy()
# result calculated by paddle fetch_list = [sample, entropy, log_prob, probs, kl]
output_list = self.executor.run(program=test_program, self.compare_with_numpy(fetch_list)
feed=feed_vars,
fetch_list=fetch_list) def test_normal_distribution_static(self, sample_shape=7, tolerance=1e-6):
self.compare_uniform_with_numpy(data_list, output_list, batch_size,
dims, sample_shape, tolerance)
def test_uniform_distribution_dygraph(self,
batch_size=2,
dims=3,
sample_shape=7,
tolerance=1e-6):
"""
Test Uniform's methods in dynamic mode.
Args:
refer to ``compare_uniform_with_numpy`` function.
"""
paddle.disable_static()
low_np = np.random.randn(batch_size, dims).astype('float32')
low_float = np.random.uniform(-2, 1)
high_float = np.random.uniform(1, 3)
high_np = np.random.uniform(-5.0, 5.0,
(batch_size, dims)).astype('float32')
values_np = np.random.randn(batch_size, dims).astype('float32')
data_list = [low_np, low_float, high_float, high_np, values_np]
output_list = self.build_uniform_dygraph(batch_size, dims, sample_shape,
low_float, high_float, high_np,
low_np, values_np)
self.compare_uniform_with_numpy(data_list, output_list, batch_size,
dims, sample_shape, tolerance)
paddle.enable_static() paddle.enable_static()
with fluid.program_guard(self.test_program):
normal = Normal(self.static_loc, self.static_scale)
sample = normal.sample([sample_shape])
entropy = normal.entropy()
log_prob = normal.log_prob(self.static_values)
probs = normal.probs(self.static_values)
other_normal = Normal(self.static_other_loc,
self.static_other_scale)
kl = normal.kl_divergence(other_normal)
fetch_list = [sample, entropy, log_prob, probs, kl]
feed_vars = {
'loc': self.loc_np,
'scale': self.scale_np,
'values': self.values_np,
'other_loc': self.other_loc_np,
'other_scale': self.other_scale_np
}
self.executor.run(fluid.default_startup_program())
fetch_list = self.executor.run(program=self.test_program,
feed=feed_vars,
fetch_list=fetch_list)
self.compare_with_numpy(fetch_list)
class NormalTest2(NormalTest):
def init_numpy_data(self, batch_size, dims):
# loc ans scale are 'int'
self.loc_np = int((np.random.ranf() - 0.5) * 8)
self.scale_np = int((np.random.ranf() - 0.5) * 8)
while self.scale_np < 0:
self.scale_np = int((np.random.ranf() - 0.5) * 8)
# used to construct another Normal object to calculate kl_divergence
self.other_loc_np = int((np.random.ranf() - 0.5) * 8)
self.other_scale_np = int((np.random.ranf() - 0.5) * 8)
while self.other_scale_np < 0:
self.other_scale_np = int((np.random.ranf() - 0.5) * 8)
self.values_np = np.random.ranf(1).astype('float32')
class NormalTest3(NormalTest):
def init_numpy_data(self, batch_size, dims):
# test broadcast: loc is float, scale is numpy.ndarray with dtype 'float32'.
self.loc_np = (np.random.ranf() - 0.5) * 4
self.scale_np = np.random.randn(batch_size, dims).astype('float32')
while not np.all(self.scale_np > 0):
self.scale_np = np.random.randn(batch_size, dims).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
# used to construct another Normal object to calculate kl_divergence
self.other_loc_np = (np.random.ranf() - 0.5) * 4
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float32')
while not np.all(self.scale_np > 0):
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_loc = self.loc_np
self.static_scale = self.scale_np
self.static_other_loc = self.other_loc_np
self.static_other_scale = self.other_scale_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class NormalTest4(NormalTest):
def init_numpy_data(self, batch_size, dims):
# loc and scale are numpy.ndarray with dtype 'float32'.
self.loc_np = np.random.randn(batch_size, dims).astype('float32')
self.scale_np = np.random.randn(batch_size, dims).astype('float32')
while not np.all(self.scale_np > 0):
self.scale_np = np.random.randn(batch_size, dims).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
# used to construct another Normal object to calculate kl_divergence
self.other_loc_np = np.random.randn(batch_size, dims).astype('float32')
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float32')
while not np.all(self.scale_np > 0):
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float32')
def init_static_data(self, batch_size, dims):
self.static_loc = self.loc_np
self.static_scale = self.scale_np
self.static_other_loc = self.other_loc_np
self.static_other_scale = self.other_scale_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
class NormalTest5(NormalTest):
def init_numpy_data(self, batch_size, dims):
# loc and scale are numpy.ndarray with dtype 'float64'.
self.loc_np = np.random.randn(batch_size, dims).astype('float64')
self.scale_np = np.random.randn(batch_size, dims).astype('float64')
while not np.all(self.scale_np > 0):
self.scale_np = np.random.randn(batch_size, dims).astype('float64')
self.values_np = np.random.randn(batch_size, dims).astype('float64')
# used to construct another Normal object to calculate kl_divergence
self.other_loc_np = np.random.randn(batch_size, dims).astype('float64')
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float64')
while not np.all(self.scale_np > 0):
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float64')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_loc = self.loc_np
self.dynamic_scale = self.scale_np
self.dynamic_other_loc = self.other_loc_np
self.dynamic_other_scale = self.other_scale_np
self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64')
def init_static_data(self, batch_size, dims):
self.static_loc = self.loc_np
self.static_scale = self.scale_np
self.static_other_loc = self.other_loc_np
self.static_other_scale = self.other_scale_np
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float64')
class NormalTest6(NormalTest):
def init_data(self, batch_size=2, dims=3):
# loc and scale are Tensor with dtype 'VarType.FP32'.
self.loc_np = np.random.randn(batch_size, dims).astype('float32')
self.scale_np = np.random.randn(batch_size, dims).astype('float32')
while not np.all(self.scale_np > 0):
self.scale_np = np.random.randn(batch_size, dims).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
self.loc = paddle.to_tensor(self.loc_np)
self.scale = paddle.to_tensor(self.scale_np)
self.values = paddle.to_tensor(self.values_np)
# used to construct another Normal object to calculate kl_divergence
self.other_loc_np = np.random.randn(batch_size, dims).astype('float32')
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float32')
while not np.all(self.scale_np > 0):
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float32')
self.other_loc = paddle.to_tensor(self.other_loc_np)
self.other_scale = paddle.to_tensor(self.other_scale_np)
def init_numpy_data(self, batch_size, dims):
# loc and scale are Tensor with dtype 'VarType.FP32'.
self.loc_np = np.random.randn(batch_size, dims).astype('float32')
self.scale_np = np.random.randn(batch_size, dims).astype('float32')
while not np.all(self.scale_np > 0):
self.scale_np = np.random.randn(batch_size, dims).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
# used to construct another Normal object to calculate kl_divergence
self.other_loc_np = np.random.randn(batch_size, dims).astype('float32')
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float32')
while not np.all(self.scale_np > 0):
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float32')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_loc = paddle.to_tensor(self.loc_np)
self.dynamic_scale = paddle.to_tensor(self.scale_np)
self.dynamic_values = paddle.to_tensor(self.values_np)
self.dynamic_other_loc = paddle.to_tensor(self.other_loc_np)
self.dynamic_other_scale = paddle.to_tensor(self.other_scale_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.static_loc = layers.data(
name='loc', shape=[dims], dtype='float32')
self.static_scale = layers.data(
name='scale', shape=[dims], dtype='float32')
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
self.static_other_loc = layers.data(
name='other_loc', shape=[dims], dtype='float32')
self.static_other_scale = layers.data(
name='other_scale', shape=[dims], dtype='float32')
class NormalTest7(NormalTest):
def init_numpy_data(self, batch_size, dims):
# loc and scale are Tensor with dtype 'VarType.FP64'.
self.loc_np = np.random.randn(batch_size, dims).astype('float64')
self.scale_np = np.random.randn(batch_size, dims).astype('float64')
while not np.all(self.scale_np > 0):
self.scale_np = np.random.randn(batch_size, dims).astype('float64')
self.values_np = np.random.randn(batch_size, dims).astype('float64')
# used to construct another Normal object to calculate kl_divergence
self.other_loc_np = np.random.randn(batch_size, dims).astype('float64')
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float64')
while not np.all(self.scale_np > 0):
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float64')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_loc = paddle.to_tensor(self.loc_np, dtype='float64')
self.dynamic_scale = paddle.to_tensor(self.scale_np, dtype='float64')
self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64')
self.dynamic_other_loc = paddle.to_tensor(
self.other_loc_np, dtype='float64')
self.dynamic_other_scale = paddle.to_tensor(
self.other_scale_np, dtype='float64')
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.static_loc = layers.data(
name='loc', shape=[dims], dtype='float64')
self.static_scale = layers.data(
name='scale', shape=[dims], dtype='float64')
self.static_values = layers.data(
name='values', shape=[dims], dtype='float64')
self.static_other_loc = layers.data(
name='other_loc', shape=[dims], dtype='float64')
self.static_other_scale = layers.data(
name='other_scale', shape=[dims], dtype='float64')
class NormalTest8(NormalTest):
def init_numpy_data(self, batch_size, dims):
# loc and scale are Tensor with dtype 'VarType.FP64'. value's dtype is 'VarType.FP32'.
self.loc_np = np.random.randn(batch_size, dims).astype('float64')
self.scale_np = np.random.randn(batch_size, dims).astype('float64')
while not np.all(self.scale_np > 0):
self.scale_np = np.random.randn(batch_size, dims).astype('float64')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
# used to construct another Normal object to calculate kl_divergence
self.other_loc_np = np.random.randn(batch_size, dims).astype('float64')
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float64')
while not np.all(self.scale_np > 0):
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float64')
def init_dynamic_data(self, batch_size, dims):
self.dynamic_loc = paddle.to_tensor(self.loc_np, dtype='float64')
self.dynamic_scale = paddle.to_tensor(self.scale_np, dtype='float64')
self.dynamic_values = paddle.to_tensor(self.values_np)
self.dynamic_other_loc = paddle.to_tensor(
self.other_loc_np, dtype='float64')
self.dynamic_other_scale = paddle.to_tensor(
self.other_scale_np, dtype='float64')
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.static_loc = layers.data(
name='loc', shape=[dims], dtype='float64')
self.static_scale = layers.data(
name='scale', shape=[dims], dtype='float64')
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
self.static_other_loc = layers.data(
name='other_loc', shape=[dims], dtype='float64')
self.static_other_scale = layers.data(
name='other_scale', shape=[dims], dtype='float64')
class DistributionTestError(unittest.TestCase): class DistributionTestError(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册