提交 43e17c79 编写于 作者: L LielinJiang 提交者: whs

Add distributions of normal and uniform (#18023)

* add_distributions_of_normal_and_uniform

* paddle/fluid/API.spec

* modify API.spec

* modified paddle/fluid/API.spec, test=develop

* modify paddle/fluid/API.spec, test=develop

* modify paddle/fluid/API.spec, test=develop

* fix some comment, test=develop

* modify API.spec, test=develop

* add comment for init function, modify hard code, test=develop

* modify API.spec, test=develop

* modify API.spec, test=develop

* make unit test function shorter, test=develop

* modify paddle/fluid/API.spec
上级 3fe6bf5e
......@@ -407,6 +407,18 @@ paddle.fluid.layers.piecewise_decay (ArgSpec(args=['boundaries', 'values'], vara
paddle.fluid.layers.noam_decay (ArgSpec(args=['d_model', 'warmup_steps'], varargs=None, keywords=None, defaults=None), ('document', 'fd57228fb76195e66bbcc8d8e42c494d'))
paddle.fluid.layers.cosine_decay (ArgSpec(args=['learning_rate', 'step_each_epoch', 'epochs'], varargs=None, keywords=None, defaults=None), ('document', 'f0d65d8c89d0fe78051ca689daa15e35'))
paddle.fluid.layers.linear_lr_warmup (ArgSpec(args=['learning_rate', 'warmup_steps', 'start_lr', 'end_lr'], varargs=None, keywords=None, defaults=None), ('document', 'dc7292c456847ba41cfd318e9f7f4363'))
paddle.fluid.layers.Uniform ('paddle.fluid.layers.distributions.Uniform', ('document', 'af70e7003f437e7a8a9e28cded35c433'))
paddle.fluid.layers.Uniform.__init__ (ArgSpec(args=['self', 'low', 'high'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.layers.Uniform.entropy (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'ba59f9ce77af3c93e2b4c8af1801a24e'))
paddle.fluid.layers.Uniform.kl_divergence (ArgSpec(args=['self', 'other'], varargs=None, keywords=None, defaults=None), ('document', '3baee52abbed82d47e9588d9dfe2f42f'))
paddle.fluid.layers.Uniform.log_prob (ArgSpec(args=['self', 'value'], varargs=None, keywords=None, defaults=None), ('document', 'b79091014ceaffb6a7372a198a341c23'))
paddle.fluid.layers.Uniform.sample (ArgSpec(args=['self', 'shape', 'seed'], varargs=None, keywords=None, defaults=(0,)), ('document', 'adac334af13f6984e991b3ecf12b8cb7'))
paddle.fluid.layers.Normal ('paddle.fluid.layers.distributions.Normal', ('document', '3265262d0d8b3b32c6245979a5cdced9'))
paddle.fluid.layers.Normal.__init__ (ArgSpec(args=['self', 'loc', 'scale'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.layers.Normal.entropy (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'd2db47b1e62c037a2570fc526b93f518'))
paddle.fluid.layers.Normal.kl_divergence (ArgSpec(args=['self', 'other'], varargs=None, keywords=None, defaults=None), ('document', '2e8845cdf1129647e6fa6e816876cd3b'))
paddle.fluid.layers.Normal.log_prob (ArgSpec(args=['self', 'value'], varargs=None, keywords=None, defaults=None), ('document', 'b79091014ceaffb6a7372a198a341c23'))
paddle.fluid.layers.Normal.sample (ArgSpec(args=['self', 'shape', 'seed'], varargs=None, keywords=None, defaults=(0,)), ('document', 'adac334af13f6984e991b3ecf12b8cb7'))
paddle.fluid.contrib.InitState ('paddle.fluid.contrib.decoder.beam_search_decoder.InitState', ('document', '3afd1f84232718e628e9e566941c5f05'))
paddle.fluid.contrib.InitState.__init__ (ArgSpec(args=['self', 'init', 'shape', 'value', 'init_boot', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, None, False, 'float32')), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.StateCell ('paddle.fluid.contrib.decoder.beam_search_decoder.StateCell', ('document', 'ecd0066c02867d445d7b461e28220c50'))
......
......@@ -34,6 +34,7 @@ from . import metric_op
from .metric_op import *
from .learning_rate_scheduler import *
from .collective import *
from .distributions import *
__all__ = []
__all__ += nn.__all__
......@@ -45,3 +46,4 @@ __all__ += device.__all__
__all__ += detection.__all__
__all__ += metric_op.__all__
__all__ += learning_rate_scheduler.__all__
__all__ += distributions.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from . import control_flow
from . import tensor
from . import ops
from . import nn
import math
import numpy as np
import warnings
__all__ = ['Uniform', 'Normal']
class Distribution(object):
"""
Distribution is the abstract base class for probability distributions.
"""
def sample(self):
"""Sampling from the distribution."""
raise NotImplementedError
def entropy(self):
"""The entropy of the distribution."""
raise NotImplementedError
def kl_divergence(self, other):
"""The KL-divergence between self distributions and other."""
raise NotImplementedError
def log_prob(self, value):
"""Log probability density/mass function."""
raise NotImplementedError
def _validate_args(self, *args):
"""
Argument validation for distribution args
Args:
value (float, list, numpy.ndarray, Variable)
Raises
ValueError: if one argument is Variable, all arguments should be Variable
"""
is_variable = False
is_number = False
for arg in args:
if isinstance(arg, tensor.Variable):
is_variable = True
else:
is_number = True
if is_variable and is_number:
raise ValueError(
'if one argument is Variable, all arguments should be Variable')
return is_variable
def _to_variable(self, *args):
"""
Argument convert args to Variable
Args:
value (float, list, numpy.ndarray, Variable)
Returns:
Variable of args.
"""
numpy_args = []
variable_args = []
tmp = 0.
for arg in args:
valid_arg = False
for cls in [float, list, np.ndarray, tensor.Variable]:
if isinstance(arg, cls):
valid_arg = True
break
assert valid_arg, "type of input args must be float, list, numpy.ndarray or Variable."
if isinstance(arg, float):
arg = np.zeros(1) + arg
arg_np = np.array(arg)
arg_dtype = arg_np.dtype
if str(arg_dtype) not in ['float32']:
warnings.warn(
"data type of argument only support float32, your argument will be convert to float32."
)
arg_np = arg_np.astype('float32')
tmp = tmp + arg_np
numpy_args.append(arg_np)
dtype = tmp.dtype
for arg in numpy_args:
arg_broadcasted, _ = np.broadcast_arrays(arg, tmp)
arg_variable = tensor.create_tensor(dtype=dtype)
tensor.assign(arg_broadcasted, arg_variable)
variable_args.append(arg_variable)
return tuple(variable_args)
class Uniform(Distribution):
"""Uniform distribution with `low` and `high` parameters.
Mathematical Details
The probability density function (pdf) is,
.. math::
pdf(x; a, b) = \\frac{1}{Z}, \ a <=x <b
.. math::
Z = b - a
In the above equation:
* :math:`low = a`,
* :math:`high = b`,
* :math:`Z`: is the normalizing constant.
The parameters `low` and `high` must be shaped in a way that supports
broadcasting (e.g., `high - low` is a valid operation).
Args:
low(float|list|numpy.ndarray|Variable): The lower boundary of uniform distribution.
high(float|list|numpy.ndarray|Variable): The higher boundary of uniform distribution.
Examples:
.. code-block:: python
from paddle.fluid import layers
from paddle.fluid.layers import Uniform
# Without broadcasting, a single uniform distribution [3, 4]:
u1 = Uniform(low=3.0, high=4.0)
# 2 distributions [1, 3], [2, 4]
u2 = Uniform(low=[1.0, 2.0],
high=[3.0, 4.0])
# 4 distributions
u3 = Uniform(low=[[1.0, 2.0],
[3.0, 4.0]],
high=[[1.5, 2.5],
[3.5, 4.5]])
# With broadcasting:
u4 = Uniform(low=3.0, high=[5.0, 6.0, 7.0])
# Variable as input
dims = 3
low = layers.data(name='low', shape=[dims], dtype='float32')
high = layers.data(name='high', shape=[dims], dtype='float32')
values = layers.data(name='values', shape=[dims], dtype='float32')
uniform = Uniform(low, high)
sample = uniform.sample([2, 3])
entropy = uniform.entropy()
lp = uniform.log_prob(values)
"""
def __init__(self, low, high):
self.all_arg_is_float = False
self.batch_size_unknown = False
if self._validate_args(low, high):
self.batch_size_unknown = True
self.low = low
self.high = high
else:
if isinstance(low, float) and isinstance(high, float):
self.all_arg_is_float = True
self.low, self.high = self._to_variable(low, high)
def sample(self, shape, seed=0):
"""Generate samples of the specified shape.
Args:
shape (list): 1D `int32`. Shape of the generated samples.
seed (int): Python integer number.
Returns:
Variable: A tensor with prepended dimensions shape.
"""
batch_shape = list((self.low + self.high).shape)
if self.batch_size_unknown:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.low + self.high, batch_shape + shape, self.low.dtype, 0.)
uniform_random_tmp = nn.uniform_random_batch_size_like(
zero_tmp, zero_tmp.shape, min=0., max=1., seed=seed)
output = uniform_random_tmp * (zero_tmp + self.high - self.low
) + self.low
return nn.reshape(output, output_shape)
else:
output_shape = shape + batch_shape
output = ops.uniform_random(
output_shape, seed=seed) * (tensor.zeros(
output_shape, dtype=self.low.dtype) +
(self.high - self.low)) + self.low
if self.all_arg_is_float:
return nn.reshape(output, shape)
else:
return output
def log_prob(self, value):
"""Log probability density/mass function.
Args:
value (Variable): The input tensor.
Returns:
Variable: log probability.
"""
lb_bool = control_flow.less_than(self.low, value)
ub_bool = control_flow.less_than(value, self.high)
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return nn.log(lb * ub) - nn.log(self.high - self.low)
def entropy(self):
"""Shannon entropy in nats.
Returns:
Variable: Shannon entropy of uniform distribution.
"""
return nn.log(self.high - self.low)
class Normal(Distribution):
"""The Normal distribution with location `loc` and `scale` parameters.
Mathematical details
The probability density function (pdf) is,
.. math::
pdf(x; \mu, \sigma) = \\frac{1}{Z}e^{\\frac {-0.5 (x - \mu)^2} {\sigma^2} }
.. math::
Z = (2 \pi \sigma^2)^{0.5}
In the above equation:
* :math:`loc = \mu`: is the mean.
* :math:`scale = \sigma`: is the std.
* :math:`Z`: is the normalization constant.
Args:
loc(float|list|numpy.ndarray|Variable): The mean of normal distribution.
scale(float|list|numpy.ndarray|Variable): The std of normal distribution.
Examples:
.. code-block:: python
from paddle.fluid import layers
from paddle.fluid.layers import Normal
# Define a single scalar Normal distribution.
dist = Normal(loc=0., scale=3.)
# Define a batch of two scalar valued Normals.
# The first has mean 1 and standard deviation 11, the second 2 and 22.
dist = Normal(loc=[1, 2.], scale=[11, 22.])
# Get 3 samples, returning a 3 x 2 tensor.
dist.sample([3])
# Define a batch of two scalar valued Normals.
# Both have mean 1, but different standard deviations.
dist = Normal(loc=1., scale=[11, 22.])
# Define a batch of two scalar valued Normals.
# Both have mean 1, but different standard deviations.
dist = Normal(loc=1., scale=[11, 22.])
# Variable as input
dims = 3
loc = layers.data(name='loc', shape=[dims], dtype='float32')
scale = layers.data(name='scale', shape=[dims], dtype='float32')
other_loc = layers.data(
name='other_loc', shape=[dims], dtype='float32')
other_scale = layers.data(
name='other_scale', shape=[dims], dtype='float32')
values = layers.data(name='values', shape=[dims], dtype='float32')
normal = Normal(loc, scale)
other_normal = Normal(other_loc, other_scale)
sample = normal.sample([2, 3])
entropy = normal.entropy()
lp = normal.log_prob(values)
kl = normal.kl_divergence(other_normal)
"""
def __init__(self, loc, scale):
self.batch_size_unknown = False
self.all_arg_is_float = False
if self._validate_args(loc, scale):
self.batch_size_unknown = True
self.loc = loc
self.scale = scale
else:
if isinstance(loc, float) and isinstance(scale, float):
self.all_arg_is_float = True
self.loc, self.scale = self._to_variable(loc, scale)
def sample(self, shape, seed=0):
"""Generate samples of the specified shape.
Args:
shape (list): 1D `int32`. Shape of the generated samples.
seed (int): Python integer number.
Returns:
Variable: A tensor with prepended dimensions shape.
"""
batch_shape = list((self.loc + self.scale).shape)
if self.batch_size_unknown:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape + shape, self.loc.dtype, 0.)
normal_random_tmp = nn.gaussian_random_batch_size_like(
zero_tmp, zero_tmp.shape, mean=0., std=1., seed=seed)
output = normal_random_tmp * (zero_tmp + self.scale) + self.loc
return nn.reshape(output, output_shape)
else:
output_shape = shape + batch_shape
output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed) * \
(tensor.zeros(output_shape, dtype=self.loc.dtype) + self.scale) + self.loc
if self.all_arg_is_float:
return nn.reshape(output, shape)
else:
return output
def entropy(self):
"""Shannon entropy in nats.
Returns:
Variable: Shannon entropy of normal distribution.
"""
batch_shape = list((self.loc + self.scale).shape)
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape, self.loc.dtype, 0.)
return 0.5 + 0.5 * math.log(2 * math.pi) + nn.log(
(self.scale + zero_tmp))
def log_prob(self, value):
"""Log probability density/mass function.
Args:
value (Variable): The input tensor.
Returns:
Variable: log probability.
"""
var = self.scale * self.scale
log_scale = nn.log(self.scale)
return -1. * ((value - self.loc) * (value - self.loc)) / (
2. * var) - log_scale - math.log(math.sqrt(2. * math.pi))
def kl_divergence(self, other):
"""The KL-divergence between two normal distributions.
Args:
other (Normal): instance of Normal.
Returns:
Variable: kl-divergence between two normal distributions.
"""
assert isinstance(other, Normal), "another distribution must be Normal"
var_ratio = self.scale / other.scale
var_ratio = (var_ratio * var_ratio)
t1 = (self.loc - other.loc) / other.scale
t1 = (t1 * t1)
return 0.5 * (var_ratio + t1 - 1. - nn.log(var_ratio))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import unittest
from paddle import fluid
from paddle.fluid import layers
from paddle.fluid.layers.distributions import *
import math
class DistributionNumpy():
"""
Distribution is the abstract base class for probability distributions.
"""
def sample(self):
"""Sampling from the distribution."""
raise NotImplementedError
def entropy(self):
"""The entropy of the distribution."""
raise NotImplementedError
def kl_divergence(self, other):
"""The KL-divergence between self distributions and other."""
raise NotImplementedError
def log_prob(self, value):
"""Log probability density/mass function."""
raise NotImplementedError
class UniformNumpy(DistributionNumpy):
def __init__(self, low, high):
self.low = np.array(low).astype('float32')
self.high = np.array(high).astype('float32')
def sample(self, shape):
shape = tuple(shape) + (self.low + self.high).shape
return self.low + (np.random.uniform(size=shape) *
(self.high - self.low))
def log_prob(self, value):
lb = np.less(self.low, value).astype('float32')
ub = np.less(value, self.high).astype('float32')
return np.log(lb * ub) - np.log(self.high - self.low)
def entropy(self):
return np.log(self.high - self.low)
class NormalNumpy(DistributionNumpy):
def __init__(self, loc, scale):
self.loc = np.array(loc).astype('float32')
self.scale = np.array(scale).astype('float32')
def sample(self, shape):
shape = tuple(shape) + (self.loc + self.scale).shape
return self.loc + (np.random.randn(*shape) * self.scale)
def log_prob(self, value):
var = self.scale * self.scale
log_scale = np.log(self.scale)
return -((value - self.loc) * (value - self.loc)) / (
2. * var) - log_scale - math.log(math.sqrt(2. * math.pi))
def entropy(self):
return 0.5 + 0.5 * np.log(np.array(2. * math.pi).astype(
'float32')) + np.log(self.scale)
def kl_divergence(self, other):
var_ratio = (self.scale / other.scale)
var_ratio = var_ratio * var_ratio
t1 = ((self.loc - other.loc) / other.scale)
t1 = (t1 * t1)
return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio))
class DistributionTest(unittest.TestCase):
def setUp(self, use_gpu=False):
self.use_gpu = use_gpu
if not use_gpu:
place = fluid.CPUPlace()
self.gpu_id = -1
else:
place = fluid.CUDAPlace(0)
self.gpu_id = 0
self.executor = fluid.Executor(place)
def build_normal_program(self, test_program, batch_size, dims, loc_float,
scale_float, other_loc_float, other_scale_float,
scale_np, other_scale_np, loc_np, other_loc_np,
values_np):
with fluid.program_guard(test_program):
loc = layers.data(name='loc', shape=[dims], dtype='float32')
scale = layers.data(name='scale', shape=[dims], dtype='float32')
other_loc = layers.data(
name='other_loc', shape=[dims], dtype='float32')
other_scale = layers.data(
name='other_scale', shape=[dims], dtype='float32')
values = layers.data(name='values', shape=[dims], dtype='float32')
normal_float = Normal(loc_float, scale_float)
other_normal_float = Normal(other_loc_float, other_scale_float)
normal_float_np_broadcast = Normal(loc_float, scale_np)
other_normal_float_np_broadcast = Normal(other_loc_float,
other_scale_np)
normal_np = Normal(loc_np, scale_np)
other_normal_np = Normal(other_loc_np, other_scale_np)
normal_variable = Normal(loc, scale)
other_normal_variable = Normal(other_loc, other_scale)
sample_float = normal_float.sample([batch_size, dims])
sample_float_np_broadcast = normal_float_np_broadcast.sample(
[batch_size, dims])
sample_np = normal_np.sample([batch_size, dims])
sample_variable = normal_variable.sample([batch_size, dims])
entropy_float = normal_float.entropy()
entropy_float_np_broadcast = normal_float_np_broadcast.entropy()
entropy_np = normal_np.entropy()
entropy_variable = normal_variable.entropy()
lp_float_np_broadcast = normal_float_np_broadcast.log_prob(values)
lp_np = normal_np.log_prob(values)
lp_variable = normal_variable.log_prob(values)
kl_float = normal_float.kl_divergence(other_normal_float)
kl_float_np_broadcast = normal_float_np_broadcast.kl_divergence(
other_normal_float_np_broadcast)
kl_np = normal_np.kl_divergence(other_normal_np)
kl_variable = normal_variable.kl_divergence(other_normal_variable)
fetch_list = [
sample_float, sample_float_np_broadcast, sample_np, sample_variable,
entropy_float, entropy_float_np_broadcast, entropy_np,
entropy_variable, lp_float_np_broadcast, lp_np, lp_variable,
kl_float, kl_float_np_broadcast, kl_np, kl_variable
]
feed_vars = {
'loc': loc_np,
'scale': scale_np,
'other_loc': other_loc_np,
'other_scale': other_scale_np,
'values': values_np
}
return feed_vars, fetch_list
def get_normal_random_input(self, batch_size, dims):
loc_np = np.random.randn(batch_size, dims).astype('float32')
other_loc_np = np.random.randn(batch_size, dims).astype('float32')
loc_float = (np.random.ranf() - 0.5) * 4
scale_float = (np.random.ranf() - 0.5) * 4
while scale_float < 0:
scale_float = (np.random.ranf() - 0.5) * 4
other_loc_float = (np.random.ranf() - 0.5) * 4
other_scale_float = (np.random.ranf() - 0.5) * 4
while other_scale_float < 0:
other_scale_float = (np.random.ranf() - 0.5) * 4
scale_np = np.random.randn(batch_size, dims).astype('float32')
other_scale_np = np.random.randn(batch_size, dims).astype('float32')
values_np = np.random.randn(batch_size, dims).astype('float32')
while not np.all(scale_np > 0):
scale_np = np.random.randn(batch_size, dims).astype('float32')
while not np.all(other_scale_np > 0):
other_scale_np = np.random.randn(batch_size, dims).astype('float32')
return loc_np, other_loc_np, loc_float, scale_float, other_loc_float, \
other_scale_float, scale_np, other_scale_np, values_np
def test_normal_distribution(self, batch_size=2, dims=3, tolerance=1e-6):
test_program = fluid.Program()
loc_np, other_loc_np, loc_float, scale_float, other_loc_float, other_scale_float, scale_np, other_scale_np, values_np = self.get_normal_random_input(
batch_size, dims)
feed_vars, fetch_list = self.build_normal_program(
test_program, batch_size, dims, loc_float, scale_float,
other_loc_float, other_scale_float, scale_np, other_scale_np,
loc_np, other_loc_np, values_np)
self.executor.run(fluid.default_startup_program())
np_normal_float = NormalNumpy(loc_float, scale_float)
np_other_normal_float = NormalNumpy(other_loc_float, other_scale_float)
np_normal_float_np_broadcast = NormalNumpy(loc_float, scale_np)
np_other_normal_float_np_broadcast = NormalNumpy(other_loc_float,
other_scale_np)
np_normal = NormalNumpy(loc_np, scale_np)
np_other_normal = NormalNumpy(other_loc_np, other_scale_np)
gt_sample_float = np_normal_float.sample([batch_size, dims])
gt_sample_float_np_broadcast = np_normal_float_np_broadcast.sample(
[batch_size, dims])
gt_sample_np = np_normal.sample([batch_size, dims])
gt_entropy_float = np_normal_float.entropy()
gt_entropy_float_np_broadcast = np_normal_float_np_broadcast.entropy()
gt_entropy = np_normal.entropy()
gt_lp_float_np_broadcast = np_normal_float_np_broadcast.log_prob(
values_np)
gt_lp = np_normal.log_prob(values_np)
gt_kl_float = np_normal_float.kl_divergence(np_other_normal_float)
gt_kl_float_np_broadcast = np_normal_float_np_broadcast.kl_divergence(
np_other_normal_float_np_broadcast)
gt_kl = np_normal.kl_divergence(np_other_normal)
[
output_sample_float, output_sample_float_np_broadcast,
output_sample_np, output_sample_variable, output_entropy_float,
output_entropy_float_np_broadcast, output_entropy_np,
output_entropy_variable, output_lp_float_np_broadcast, output_lp_np,
output_lp_variable, output_kl_float, output_kl_float_np_broadcast,
output_kl_np, output_kl_variable
] = self.executor.run(program=test_program,
feed=feed_vars,
fetch_list=fetch_list)
np.testing.assert_allclose(
output_sample_float.shape, gt_sample_float.shape, rtol=tolerance)
np.testing.assert_allclose(
output_sample_float_np_broadcast.shape,
gt_sample_float_np_broadcast.shape,
rtol=tolerance)
np.testing.assert_allclose(
output_sample_np.shape, gt_sample_np.shape, rtol=tolerance)
np.testing.assert_allclose(
output_sample_variable.shape, gt_sample_np.shape, rtol=tolerance)
np.testing.assert_allclose(
output_entropy_float, gt_entropy_float, rtol=tolerance)
np.testing.assert_allclose(
output_entropy_float_np_broadcast,
gt_entropy_float_np_broadcast,
rtol=tolerance)
np.testing.assert_allclose(
output_entropy_np, gt_entropy, rtol=tolerance)
np.testing.assert_allclose(
output_entropy_variable, gt_entropy, rtol=tolerance)
np.testing.assert_allclose(
output_lp_float_np_broadcast,
gt_lp_float_np_broadcast,
rtol=tolerance)
np.testing.assert_allclose(output_lp_np, gt_lp, rtol=tolerance)
np.testing.assert_allclose(output_lp_variable, gt_lp, rtol=tolerance)
np.testing.assert_allclose(output_kl_float, gt_kl_float, rtol=tolerance)
np.testing.assert_allclose(
output_kl_float_np_broadcast,
gt_kl_float_np_broadcast,
rtol=tolerance)
np.testing.assert_allclose(output_kl_np, gt_kl, rtol=tolerance)
np.testing.assert_allclose(output_kl_variable, gt_kl, rtol=tolerance)
def build_uniform_program(self, test_program, batch_size, dims, low_float,
high_float, high_np, low_np, values_np):
with fluid.program_guard(test_program):
low = layers.data(name='low', shape=[dims], dtype='float32')
high = layers.data(name='high', shape=[dims], dtype='float32')
values = layers.data(name='values', shape=[dims], dtype='float32')
uniform_float = Uniform(low_float, high_float)
uniform_float_np_broadcast = Uniform(low_float, high_np)
uniform_np = Uniform(low_np, high_np)
uniform_variable = Uniform(low, high)
sample_float = uniform_float.sample([batch_size, dims])
sample_float_np_broadcast = uniform_float_np_broadcast.sample(
[batch_size, dims])
sample_np = uniform_np.sample([batch_size, dims])
sample_variable = uniform_variable.sample([batch_size, dims])
entropy_float = uniform_float.entropy()
entropy_float_np_broadcast = uniform_float_np_broadcast.entropy()
entropy_np = uniform_np.entropy()
entropy_variable = uniform_variable.entropy()
lp_float_np_broadcast = uniform_float_np_broadcast.log_prob(values)
lp_np = uniform_np.log_prob(values)
lp_variable = uniform_variable.log_prob(values)
fetch_list = [
sample_float, sample_float_np_broadcast, sample_np, sample_variable,
entropy_float, entropy_float_np_broadcast, entropy_np,
entropy_variable, lp_float_np_broadcast, lp_np, lp_variable
]
feed_vars = {'low': low_np, 'high': high_np, 'values': values_np}
return feed_vars, fetch_list
def test_uniform_distribution(self, batch_size=2, dims=3, tolerance=1e-6):
test_program = fluid.Program()
low_np = np.random.randn(batch_size, dims).astype('float32')
low_float = np.random.uniform(-2, 1)
high_float = np.random.uniform(1, 3)
high_np = np.random.uniform(-5.0, 5.0,
(batch_size, dims)).astype('float32')
values_np = np.random.randn(batch_size, dims).astype('float32')
feed_vars, fetch_list = self.build_uniform_program(
test_program, batch_size, dims, low_float, high_float, high_np,
low_np, values_np)
self.executor.run(fluid.default_startup_program())
np_uniform_float = UniformNumpy(low_float, high_float)
np_uniform_float_np_broadcast = UniformNumpy(low_float, high_np)
np_uniform = UniformNumpy(low_np, high_np)
gt_sample_float = np_uniform_float.sample([batch_size, dims])
gt_sample_float_np_broadcast = np_uniform_float_np_broadcast.sample(
[batch_size, dims])
gt_sample_np = np_uniform.sample([batch_size, dims])
gt_entropy_float = np_uniform_float.entropy()
gt_entropy_float_np_broadcast = np_uniform_float_np_broadcast.entropy()
gt_entropy = np_uniform.entropy()
gt_lp_float_np_broadcast = np_uniform_float_np_broadcast.log_prob(
values_np)
gt_lp = np_uniform.log_prob(values_np)
# result calculated by paddle
[
output_sample_float, output_sample_float_np_broadcast,
output_sample_np, output_sample_variable, output_entropy_float,
output_entropy_float_np_broadcast, output_entropy_np,
output_entropy_variable, output_lp_float_np_broadcast, output_lp_np,
output_lp_variable
] = self.executor.run(program=test_program,
feed=feed_vars,
fetch_list=fetch_list)
np.testing.assert_allclose(
output_sample_float.shape, gt_sample_float.shape, rtol=tolerance)
np.testing.assert_allclose(
output_sample_float_np_broadcast.shape,
gt_sample_float_np_broadcast.shape,
rtol=tolerance)
np.testing.assert_allclose(
output_sample_np.shape, gt_sample_np.shape, rtol=tolerance)
np.testing.assert_allclose(
output_sample_variable.shape, gt_sample_np.shape, rtol=tolerance)
np.testing.assert_allclose(
output_entropy_float, gt_entropy_float, rtol=tolerance)
np.testing.assert_allclose(
output_entropy_float_np_broadcast,
gt_entropy_float_np_broadcast,
rtol=tolerance)
np.testing.assert_allclose(
output_entropy_np, gt_entropy, rtol=tolerance)
np.testing.assert_allclose(
output_entropy_variable, gt_entropy, rtol=tolerance)
np.testing.assert_allclose(
output_lp_float_np_broadcast,
gt_lp_float_np_broadcast,
rtol=tolerance)
np.testing.assert_allclose(output_lp_np, gt_lp, rtol=tolerance)
np.testing.assert_allclose(output_lp_variable, gt_lp, rtol=tolerance)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册