未验证 提交 644dfd7d 编写于 作者: P pangyoki 提交者: GitHub

Release Distribution base class and Normal, Uniform class (#26355)

* fixed static module

* solve conflict

* Add Distribution base class, Uniform class and Normal class

* release Distribution class and Normal, Uniform class

* Add Doc args explaination

* save distributions.py and process in distribution.py

* delete useless function in test_distribution

* Add NormalNumpy test class

* Add probs in NormalNumpy

* add distribution to paddle init

* Add Distribution base class and name attribute unittest

* Change Doc

* Change Doc

* adjust format

* Fixed Doc Code

* implement probs and change Variable to Tensor

* Add name for all functions and add name unittest

* support int datatype

* Add dynamic mode

* optimize test_distribution static and dygraph
上级 e1245f5c
......@@ -38,6 +38,7 @@ import paddle.compat
import paddle.distributed
import paddle.sysconfig
import paddle.tensor
import paddle.distribution
import paddle.nn
import paddle.distributed.fleet
import paddle.optimizer
......
......@@ -18,3 +18,517 @@
# 'Normal',
# 'sampling_id',
# 'Uniform']
from __future__ import print_function
from .fluid.layers import control_flow
from .fluid.layers import tensor
from .fluid.layers import ops
from .fluid.layers import nn
from .fluid.framework import in_dygraph_mode
from .tensor.math import elementwise_mul, elementwise_div, elementwise_add, elementwise_sub
import math
import numpy as np
import warnings
from .fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
__all__ = ['Distribution', 'Uniform', 'Normal']
class Distribution(object):
"""
The abstract base class for probability distributions. Functions are
implemented in specific distributions.
"""
def __init__(self):
super(Distribution, self).__init__()
def sample(self):
"""Sampling from the distribution."""
raise NotImplementedError
def entropy(self):
"""The entropy of the distribution."""
raise NotImplementedError
def kl_divergence(self, other):
"""The KL-divergence between self distributions and other."""
raise NotImplementedError
def log_prob(self, value):
"""Log probability density/mass function."""
raise NotImplementedError
def probs(self, value):
"""Probability density/mass function."""
raise NotImplementedError
def _validate_args(self, *args):
"""
Argument validation for distribution args
Args:
value (float, list, numpy.ndarray, Tensor)
Raises
ValueError: if one argument is Tensor, all arguments should be Tensor
"""
is_variable = False
is_number = False
for arg in args:
if isinstance(arg, tensor.Variable):
is_variable = True
else:
is_number = True
if is_variable and is_number:
raise ValueError(
'if one argument is Tensor, all arguments should be Tensor')
return is_variable
def _to_variable(self, *args):
"""
Argument convert args to Tensor
Args:
value (float, list, numpy.ndarray, Tensor)
Returns:
Tensor of args.
"""
numpy_args = []
variable_args = []
tmp = 0.
for arg in args:
valid_arg = False
for cls in [float, list, np.ndarray, tensor.Variable]:
if isinstance(arg, cls):
valid_arg = True
break
assert valid_arg, "type of input args must be float, list, numpy.ndarray or Tensor."
if isinstance(arg, float):
arg = np.zeros(1) + arg
arg_np = np.array(arg)
arg_dtype = arg_np.dtype
if str(arg_dtype) not in ['float32']:
warnings.warn(
"data type of argument only support float32, your argument will be convert to float32."
)
arg_np = arg_np.astype('float32')
tmp = tmp + arg_np
numpy_args.append(arg_np)
dtype = tmp.dtype
for arg in numpy_args:
arg_broadcasted, _ = np.broadcast_arrays(arg, tmp)
arg_variable = tensor.create_tensor(dtype=dtype)
tensor.assign(arg_broadcasted, arg_variable)
variable_args.append(arg_variable)
return tuple(variable_args)
class Uniform(Distribution):
"""Uniform distribution with `low` and `high` parameters.
Mathematical Details
The probability density function (pdf) is,
.. math::
pdf(x; a, b) = \\frac{1}{Z}, \ a <=x <b
.. math::
Z = b - a
In the above equation:
* :math:`low = a`,
* :math:`high = b`,
* :math:`Z`: is the normalizing constant.
The parameters `low` and `high` must be shaped in a way that supports
[broadcasting](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/beginners_guide/basic_concept/broadcasting_en.html) (e.g., `high - low` is a valid operation).
Args:
low(int|float|list|numpy.ndarray|Tensor): The lower boundary of uniform distribution.The data type is float32 or int
high(int|float|list|numpy.ndarray|Tensor): The higher boundary of uniform distribution.The data type is float32 or int
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.distribution import Uniform
paddle.disable_static()
# Without broadcasting, a single uniform distribution [3, 4]:
u1 = Uniform(low=3.0, high=4.0)
# 2 distributions [1, 3], [2, 4]
u2 = Uniform(low=[1.0, 2.0], high=[3.0, 4.0])
# 4 distributions
u3 = Uniform(low=[[1.0, 2.0], [3.0, 4.0]],
high=[[1.5, 2.5], [3.5, 4.5]])
# With broadcasting:
u4 = Uniform(low=3.0, high=[5.0, 6.0, 7.0])
# Complete example
value_npdata = np.array([0.8], dtype="float32")
value_tensor = paddle.to_tensor(value_npdata)
uniform = Uniform([0.], [2.])
sample = uniform.sample([2])
# a random tensor created by uniform distribution with shape: [2, 1]
entropy = uniform.entropy()
# [0.6931472] with shape: [1]
lp = uniform.log_prob(value_tensor)
# [-0.6931472] with shape: [1]
p = uniform.probs(value_tensor)
# [0.5] with shape: [1]
"""
def __init__(self, low, high, name=None):
if not in_dygraph_mode():
check_type(low, 'low',
(int, float, np.ndarray, tensor.Variable, list),
'Uniform')
check_type(high, 'high',
(int, float, np.ndarray, tensor.Variable, list),
'Uniform')
self.all_arg_is_float = False
self.batch_size_unknown = False
self.name = name if name is not None else 'Uniform'
if isinstance(low, int):
low = float(low)
if isinstance(high, int):
high = float(high)
if self._validate_args(low, high):
self.batch_size_unknown = True
self.low = low
self.high = high
else:
if isinstance(low, float) and isinstance(high, float):
self.all_arg_is_float = True
self.low, self.high = self._to_variable(low, high)
def sample(self, shape, seed=0):
"""Generate samples of the specified shape.
Args:
shape (list): 1D `int32`. Shape of the generated samples.
seed (int): Python integer number.
Returns:
Tensor: A tensor with prepended dimensions shape.The data type is float32.
"""
if not in_dygraph_mode():
check_type(shape, 'shape', (list), 'sample')
check_type(seed, 'seed', (int), 'sample')
name = self.name + '_sample'
batch_shape = list((self.low + self.high).shape)
if self.batch_size_unknown:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.low + self.high, batch_shape + shape, self.low.dtype, 0.)
uniform_random_tmp = nn.uniform_random_batch_size_like(
zero_tmp, zero_tmp.shape, min=0., max=1., seed=seed)
output = uniform_random_tmp * (zero_tmp + self.high - self.low
) + self.low
return nn.reshape(output, output_shape, name=name)
else:
output_shape = shape + batch_shape
output = nn.uniform_random(
output_shape, seed=seed) * (tensor.zeros(
output_shape, dtype=self.low.dtype) +
(self.high - self.low))
output = elementwise_add(output, self.low, name=name)
if self.all_arg_is_float:
return nn.reshape(output, shape, name=name)
else:
return output
def log_prob(self, value):
"""Log probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: log probability.The data type is same with value.
"""
name = self.name + '_log_prob'
if in_dygraph_mode():
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_sub(
nn.log(lb * ub), nn.log(self.high - self.low), name=name)
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
lb_bool = control_flow.less_than(self.low, value)
ub_bool = control_flow.less_than(value, self.high)
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_sub(
nn.log(lb * ub), nn.log(self.high - self.low), name=name)
def probs(self, value):
"""Probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: probability.The data type is same with value.
"""
name = self.name + '_probs'
if in_dygraph_mode():
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_div((lb * ub), (self.high - self.low), name=name)
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
lb_bool = control_flow.less_than(self.low, value)
ub_bool = control_flow.less_than(value, self.high)
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_div((lb * ub), (self.high - self.low), name=name)
def entropy(self):
"""Shannon entropy in nats.
Returns:
Tensor: Shannon entropy of uniform distribution.The data type is float32.
"""
name = self.name + '_entropy'
return nn.log(self.high - self.low, name=name)
class Normal(Distribution):
"""The Normal distribution with location `loc` and `scale` parameters.
Mathematical details
The probability density function (pdf) is,
.. math::
pdf(x; \mu, \sigma) = \\frac{1}{Z}e^{\\frac {-0.5 (x - \mu)^2} {\sigma^2} }
.. math::
Z = (2 \pi \sigma^2)^{0.5}
In the above equation:
* :math:`loc = \mu`: is the mean.
* :math:`scale = \sigma`: is the std.
* :math:`Z`: is the normalization constant.
Args:
loc(int|float|list|numpy.ndarray|Tensor): The mean of normal distribution.The data type is float32 or int.
scale(int|float|list|numpy.ndarray|Tensor): The std of normal distribution.The data type is float32 or int.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.distribution import Normal
paddle.disable_static()
# Define a single scalar Normal distribution.
dist = Normal(loc=0., scale=3.)
# Define a batch of two scalar valued Normals.
# The first has mean 1 and standard deviation 11, the second 2 and 22.
dist = Normal(loc=[1., 2.], scale=[11., 22.])
# Get 3 samples, returning a 3 x 2 tensor.
dist.sample([3])
# Define a batch of two scalar valued Normals.
# Both have mean 1, but different standard deviations.
dist = Normal(loc=1., scale=[11., 22.])
# Complete example
value_npdata = np.array([0.8], dtype="float32")
value_tensor = paddle.to_tensor(value_npdata)
normal_a = Normal([0.], [1.])
normal_b = Normal([0.5], [2.])
sample = normal_a.sample([2])
# a random tensor created by normal distribution with shape: [2, 1]
entropy = normal_a.entropy()
# [1.4189385] with shape: [1]
lp = normal_a.log_prob(value_tensor)
# [-1.2389386] with shape: [1]
p = normal_a.probs(value_tensor)
# [0.28969154] with shape: [1]
kl = normal_a.kl_divergence(normal_b)
# [0.34939718] with shape: [1]
"""
def __init__(self, loc, scale, name=None):
if not in_dygraph_mode():
check_type(loc, 'loc',
(int, float, np.ndarray, tensor.Variable, list),
'Normal')
check_type(scale, 'scale',
(int, float, np.ndarray, tensor.Variable, list),
'Normal')
self.batch_size_unknown = False
self.all_arg_is_float = False
self.name = name if name is not None else 'Normal'
if isinstance(loc, int):
loc = float(loc)
if isinstance(scale, int):
scale = float(scale)
if self._validate_args(loc, scale):
self.batch_size_unknown = True
self.loc = loc
self.scale = scale
else:
if isinstance(loc, float) and isinstance(scale, float):
self.all_arg_is_float = True
self.loc, self.scale = self._to_variable(loc, scale)
def sample(self, shape, seed=0):
"""Generate samples of the specified shape.
Args:
shape (list): 1D `int32`. Shape of the generated samples.
seed (int): Python integer number.
Returns:
Tensor: A tensor with prepended dimensions shape.The data type is float32.
"""
if not in_dygraph_mode():
check_type(shape, 'shape', (list), 'sample')
check_type(seed, 'seed', (int), 'sample')
batch_shape = list((self.loc + self.scale).shape)
name = self.name + '_sample'
if self.batch_size_unknown:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape + shape, self.loc.dtype, 0.)
zero_tmp_shape = nn.shape(zero_tmp)
normal_random_tmp = nn.gaussian_random(
zero_tmp_shape, mean=0., std=1., seed=seed)
output = normal_random_tmp * (zero_tmp + self.scale) + self.loc
return nn.reshape(output, output_shape, name=name)
else:
output_shape = shape + batch_shape
output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed) * \
(tensor.zeros(output_shape, dtype=self.loc.dtype) + self.scale)
output = elementwise_add(output, self.loc, name=name)
if self.all_arg_is_float:
return nn.reshape(output, shape, name=name)
else:
return output
def entropy(self):
"""Shannon entropy in nats.
Returns:
Tensor: Shannon entropy of normal distribution.The data type is float32.
"""
name = self.name + '_entropy'
batch_shape = list((self.loc + self.scale).shape)
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape, self.loc.dtype, 0.)
return elementwise_add(
0.5 + zero_tmp,
0.5 * math.log(2 * math.pi) + nn.log((self.scale + zero_tmp)),
name=name)
def log_prob(self, value):
"""Log probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: log probability.The data type is same with value.
"""
if not in_dygraph_mode():
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
name = self.name + '_log_prob'
var = self.scale * self.scale
log_scale = nn.log(self.scale)
return elementwise_sub(
-1. * ((value - self.loc) * (value - self.loc)) / (2. * var),
log_scale + math.log(math.sqrt(2. * math.pi)),
name=name)
def probs(self, value):
"""Probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: probability.The data type is same with value.
"""
if not in_dygraph_mode():
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
name = self.name + '_probs'
var = self.scale * self.scale
return elementwise_div(
ops.exp(-1. * ((value - self.loc) * (value - self.loc)) /
(2. * var)), (math.sqrt(2 * math.pi) * self.scale),
name=name)
def kl_divergence(self, other):
"""The KL-divergence between two normal distributions.
Args:
other (Normal): instance of Normal.
Returns:
Tensor: kl-divergence between two normal distributions.The data type is float32.
"""
if not in_dygraph_mode():
check_type(other, 'other', Normal, 'kl_divergence')
name = self.name + '_kl_divergence'
var_ratio = self.scale / other.scale
var_ratio = (var_ratio * var_ratio)
t1 = (self.loc - other.loc) / other.scale
t1 = (t1 * t1)
return elementwise_add(
0.5 * var_ratio, 0.5 * (t1 - 1. - nn.log(var_ratio)), name=name)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import unittest
import paddle
from paddle import fluid
from paddle.fluid import layers
from paddle.distribution import *
import math
class DistributionNumpy():
def sample(self):
raise NotImplementedError
def entropy(self):
raise NotImplementedError
def kl_divergence(self, other):
raise NotImplementedError
def log_prob(self, value):
raise NotImplementedError
def probs(self, value):
raise NotImplementedError
class UniformNumpy(DistributionNumpy):
def __init__(self, low, high):
self.low = np.array(low).astype('float32')
self.high = np.array(high).astype('float32')
def sample(self, shape):
shape = tuple(shape) + (self.low + self.high).shape
return self.low + (np.random.uniform(size=shape) *
(self.high - self.low))
def log_prob(self, value):
lb = np.less(self.low, value).astype('float32')
ub = np.less(value, self.high).astype('float32')
return np.log(lb * ub) - np.log(self.high - self.low)
def probs(self, value):
lb = np.less(self.low, value).astype('float32')
ub = np.less(value, self.high).astype('float32')
return (lb * ub) / (self.high - self.low)
def entropy(self):
return np.log(self.high - self.low)
class NormalNumpy(DistributionNumpy):
def __init__(self, loc, scale):
self.loc = np.array(loc).astype('float32')
self.scale = np.array(scale).astype('float32')
def sample(self, shape):
shape = tuple(shape) + (self.loc + self.scale).shape
return self.loc + (np.random.randn(*shape) * self.scale)
def log_prob(self, value):
var = self.scale * self.scale
log_scale = np.log(self.scale)
return -((value - self.loc) * (value - self.loc)) / (
2. * var) - log_scale - math.log(math.sqrt(2. * math.pi))
def probs(self, value):
var = self.scale * self.scale
return np.exp(-1. * ((value - self.loc) * (value - self.loc)) /
(2. * var)) / (math.sqrt(2 * math.pi) * self.scale)
def entropy(self):
return 0.5 + 0.5 * np.log(np.array(2. * math.pi).astype(
'float32')) + np.log(self.scale)
def kl_divergence(self, other):
var_ratio = (self.scale / other.scale)
var_ratio = var_ratio * var_ratio
t1 = ((self.loc - other.loc) / other.scale)
t1 = (t1 * t1)
return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio))
class DistributionTest(unittest.TestCase):
def setUp(self, use_gpu=False):
self.use_gpu = use_gpu
if not use_gpu:
place = fluid.CPUPlace()
self.gpu_id = -1
else:
place = fluid.CUDAPlace(0)
self.gpu_id = 0
self.executor = fluid.Executor(place)
def build_normal_common_net(self, batch_size, dims, loc_float, scale_float,
other_loc_float, other_scale_float, scale_np,
other_scale_np, loc_np, other_loc_np, loc,
scale, other_loc, other_scale, values):
normal_int = Normal(int(loc_float), int(scale_float))
normal_float = Normal(loc_float, scale_float)
other_normal_float = Normal(other_loc_float, other_scale_float)
normal_float_np_broadcast = Normal(loc_float, scale_np)
other_normal_float_np_broadcast = Normal(other_loc_float,
other_scale_np)
normal_np = Normal(loc_np, scale_np)
other_normal_np = Normal(other_loc_np, other_scale_np)
normal_variable = Normal(loc, scale)
other_normal_variable = Normal(other_loc, other_scale)
sample_int = normal_int.sample([batch_size, dims])
sample_float = normal_float.sample([batch_size, dims])
sample_float_np_broadcast = normal_float_np_broadcast.sample(
[batch_size, dims])
sample_np = normal_np.sample([batch_size, dims])
sample_variable = normal_variable.sample([batch_size, dims])
entropy_int = normal_int.entropy()
entropy_float = normal_float.entropy()
entropy_float_np_broadcast = normal_float_np_broadcast.entropy()
entropy_np = normal_np.entropy()
entropy_variable = normal_variable.entropy()
lp_float_np_broadcast = normal_float_np_broadcast.log_prob(values)
lp_np = normal_np.log_prob(values)
lp_variable = normal_variable.log_prob(values)
p_float_np_broadcast = normal_float_np_broadcast.probs(values)
p_np = normal_np.probs(values)
p_variable = normal_variable.probs(values)
kl_float = normal_float.kl_divergence(other_normal_float)
kl_float_np_broadcast = normal_float_np_broadcast.kl_divergence(
other_normal_float_np_broadcast)
kl_np = normal_np.kl_divergence(other_normal_np)
kl_variable = normal_variable.kl_divergence(other_normal_variable)
fetch_list = [
sample_int, sample_float, sample_float_np_broadcast, sample_np,
sample_variable, entropy_int, entropy_float,
entropy_float_np_broadcast, entropy_np, entropy_variable,
lp_float_np_broadcast, lp_np, lp_variable, p_float_np_broadcast,
p_np, p_variable, kl_float, kl_float_np_broadcast, kl_np,
kl_variable
]
return fetch_list
def build_normal_static(self, test_program, batch_size, dims, loc_float,
scale_float, other_loc_float, other_scale_float,
scale_np, other_scale_np, loc_np, other_loc_np,
values_np):
with fluid.program_guard(test_program):
loc = layers.data(name='loc', shape=[dims], dtype='float32')
scale = layers.data(name='scale', shape=[dims], dtype='float32')
other_loc = layers.data(
name='other_loc', shape=[dims], dtype='float32')
other_scale = layers.data(
name='other_scale', shape=[dims], dtype='float32')
values = layers.data(name='values', shape=[dims], dtype='float32')
fetch_list = self.build_normal_common_net(
batch_size, dims, loc_float, scale_float, other_loc_float,
other_scale_float, scale_np, other_scale_np, loc_np,
other_loc_np, loc, scale, other_loc, other_scale, values)
feed_vars = {
'loc': loc_np,
'scale': scale_np,
'other_loc': other_loc_np,
'other_scale': other_scale_np,
'values': values_np
}
return feed_vars, fetch_list
def build_normal_dygraph(self, batch_size, dims, loc_float, scale_float,
other_loc_float, other_scale_float, scale_np,
other_scale_np, loc_np, other_loc_np, values_np):
loc = paddle.to_tensor(loc_np)
scale = paddle.to_tensor(scale_np)
other_loc = paddle.to_tensor(other_loc_np)
other_scale = paddle.to_tensor(other_scale_np)
values = paddle.to_tensor(values_np)
fetch_list = self.build_normal_common_net(
batch_size, dims, loc_float, scale_float, other_loc_float,
other_scale_float, scale_np, other_scale_np, loc_np, other_loc_np,
loc, scale, other_loc, other_scale, values)
fetch_list_numpy = [t.numpy() for t in fetch_list]
return fetch_list_numpy
def get_normal_random_input(self, batch_size, dims):
loc_np = np.random.randn(batch_size, dims).astype('float32')
other_loc_np = np.random.randn(batch_size, dims).astype('float32')
loc_float = (np.random.ranf() - 0.5) * 4
scale_float = (np.random.ranf() - 0.5) * 4
while scale_float < 0:
scale_float = (np.random.ranf() - 0.5) * 4
other_loc_float = (np.random.ranf() - 0.5) * 4
other_scale_float = (np.random.ranf() - 0.5) * 4
while other_scale_float < 0:
other_scale_float = (np.random.ranf() - 0.5) * 4
scale_np = np.random.randn(batch_size, dims).astype('float32')
other_scale_np = np.random.randn(batch_size, dims).astype('float32')
values_np = np.random.randn(batch_size, dims).astype('float32')
while not np.all(scale_np > 0):
scale_np = np.random.randn(batch_size, dims).astype('float32')
while not np.all(other_scale_np > 0):
other_scale_np = np.random.randn(batch_size, dims).astype('float32')
return [
loc_np, other_loc_np, loc_float, scale_float, other_loc_float,
other_scale_float, scale_np, other_scale_np, values_np
]
def compare_normal_with_numpy(self,
data_list,
output_list,
batch_size=2,
dims=3,
tolerance=1e-6):
loc_np, other_loc_np, loc_float, scale_float, other_loc_float, other_scale_float, scale_np, other_scale_np, values_np = data_list
np_normal_int = NormalNumpy(int(loc_float), int(scale_float))
np_normal_float = NormalNumpy(loc_float, scale_float)
np_other_normal_float = NormalNumpy(other_loc_float, other_scale_float)
np_normal_float_np_broadcast = NormalNumpy(loc_float, scale_np)
np_other_normal_float_np_broadcast = NormalNumpy(other_loc_float,
other_scale_np)
np_normal = NormalNumpy(loc_np, scale_np)
np_other_normal = NormalNumpy(other_loc_np, other_scale_np)
gt_sample_int = np_normal_int.sample([batch_size, dims])
gt_sample_float = np_normal_float.sample([batch_size, dims])
gt_sample_float_np_broadcast = np_normal_float_np_broadcast.sample(
[batch_size, dims])
gt_sample_np = np_normal.sample([batch_size, dims])
gt_entropy_int = np_normal_int.entropy()
gt_entropy_float = np_normal_float.entropy()
gt_entropy_float_np_broadcast = np_normal_float_np_broadcast.entropy()
gt_entropy = np_normal.entropy()
gt_lp_float_np_broadcast = np_normal_float_np_broadcast.log_prob(
values_np)
gt_lp = np_normal.log_prob(values_np)
gt_p_float_np_broadcast = np_normal_float_np_broadcast.probs(values_np)
gt_p = np_normal.probs(values_np)
gt_kl_float = np_normal_float.kl_divergence(np_other_normal_float)
gt_kl_float_np_broadcast = np_normal_float_np_broadcast.kl_divergence(
np_other_normal_float_np_broadcast)
gt_kl = np_normal.kl_divergence(np_other_normal)
[
output_sample_int, output_sample_float,
output_sample_float_np_broadcast, output_sample_np,
output_sample_variable, output_entropy_int, output_entropy_float,
output_entropy_float_np_broadcast, output_entropy_np,
output_entropy_variable, output_lp_float_np_broadcast, output_lp_np,
output_lp_variable, output_p_float_np_broadcast, output_p_np,
output_p_variable, output_kl_float, output_kl_float_np_broadcast,
output_kl_np, output_kl_variable
] = output_list
np.testing.assert_allclose(
output_sample_int.shape,
gt_sample_int.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_sample_float.shape,
gt_sample_float.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_sample_float_np_broadcast.shape,
gt_sample_float_np_broadcast.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_sample_np.shape,
gt_sample_np.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_sample_variable.shape,
gt_sample_np.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_int, gt_entropy_int, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_entropy_float,
gt_entropy_float,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_float_np_broadcast,
gt_entropy_float_np_broadcast,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_np, gt_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_entropy_variable, gt_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_lp_float_np_broadcast,
gt_lp_float_np_broadcast,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_lp_np, gt_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_lp_variable, gt_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_p_float_np_broadcast,
gt_p_float_np_broadcast,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_p_np, gt_p, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_p_variable, gt_p, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_kl_float, gt_kl_float, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_kl_float_np_broadcast,
gt_kl_float_np_broadcast,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_kl_np, gt_kl, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_kl_variable, gt_kl, rtol=tolerance, atol=tolerance)
def test_normal_distribution_static(self,
batch_size=2,
dims=3,
tolerance=1e-6):
test_program = fluid.Program()
data_list = self.get_normal_random_input(batch_size, dims)
loc_np, other_loc_np, loc_float, scale_float, other_loc_float, other_scale_float, scale_np, other_scale_np, values_np = data_list
feed_vars, fetch_list = self.build_normal_static(
test_program, batch_size, dims, loc_float, scale_float,
other_loc_float, other_scale_float, scale_np, other_scale_np,
loc_np, other_loc_np, values_np)
self.executor.run(fluid.default_startup_program())
output_list = self.executor.run(program=test_program,
feed=feed_vars,
fetch_list=fetch_list)
self.compare_normal_with_numpy(data_list, output_list, batch_size, dims,
tolerance)
def test_normal_distribution_dygraph(self,
batch_size=2,
dims=3,
tolerance=1e-6):
paddle.disable_static()
data_list = self.get_normal_random_input(batch_size, dims)
loc_np, other_loc_np, loc_float, scale_float, other_loc_float, other_scale_float, scale_np, other_scale_np, values_np = data_list
output_list = self.build_normal_dygraph(
batch_size, dims, loc_float, scale_float, other_loc_float,
other_scale_float, scale_np, other_scale_np, loc_np, other_loc_np,
values_np)
self.compare_normal_with_numpy(data_list, output_list, batch_size, dims,
tolerance)
paddle.enable_static()
def build_uniform_common_net(self, batch_size, dims, low_float, high_float,
high_np, low_np, values_np, low, high, values):
uniform_int = Uniform(int(low_float), int(high_float))
uniform_float = Uniform(low_float, high_float)
uniform_float_np_broadcast = Uniform(low_float, high_np)
uniform_np = Uniform(low_np, high_np)
uniform_variable = Uniform(low, high)
sample_int = uniform_int.sample([batch_size, dims])
sample_float = uniform_float.sample([batch_size, dims])
sample_float_np_broadcast = uniform_float_np_broadcast.sample(
[batch_size, dims])
sample_np = uniform_np.sample([batch_size, dims])
sample_variable = uniform_variable.sample([batch_size, dims])
entropy_int = uniform_int.entropy()
entropy_float = uniform_float.entropy()
entropy_float_np_broadcast = uniform_float_np_broadcast.entropy()
entropy_np = uniform_np.entropy()
entropy_variable = uniform_variable.entropy()
lp_float_np_broadcast = uniform_float_np_broadcast.log_prob(values)
lp_np = uniform_np.log_prob(values)
lp_variable = uniform_variable.log_prob(values)
p_float_np_broadcast = uniform_float_np_broadcast.probs(values)
p_np = uniform_np.probs(values)
p_variable = uniform_variable.probs(values)
fetch_list = [
sample_int, sample_float, sample_float_np_broadcast, sample_np,
sample_variable, entropy_int, entropy_float,
entropy_float_np_broadcast, entropy_np, entropy_variable,
lp_float_np_broadcast, lp_np, lp_variable, p_float_np_broadcast,
p_np, p_variable
]
return fetch_list
def build_uniform_static(self, test_program, batch_size, dims, low_float,
high_float, high_np, low_np, values_np):
with fluid.program_guard(test_program):
low = layers.data(name='low', shape=[dims], dtype='float32')
high = layers.data(name='high', shape=[dims], dtype='float32')
values = layers.data(name='values', shape=[dims], dtype='float32')
fetch_list = self.build_uniform_common_net(
batch_size, dims, low_float, high_float, high_np, low_np,
values_np, low, high, values)
feed_vars = {'low': low_np, 'high': high_np, 'values': values_np}
return feed_vars, fetch_list
def build_uniform_dygraph(self, batch_size, dims, low_float, high_float,
high_np, low_np, values_np):
low = paddle.to_tensor(low_np)
high = paddle.to_tensor(high_np)
values = paddle.to_tensor(values_np)
fetch_list = self.build_uniform_common_net(batch_size, dims, low_float,
high_float, high_np, low_np,
values_np, low, high, values)
fetch_list_numpy = [t.numpy() for t in fetch_list]
return fetch_list_numpy
def compare_uniform_with_numpy(self,
data_list,
output_list,
batch_size=2,
dims=3,
tolerance=1e-6):
[low_np, low_float, high_float, high_np, values_np] = data_list
np_uniform_int = UniformNumpy(int(low_float), int(high_float))
np_uniform_float = UniformNumpy(low_float, high_float)
np_uniform_float_np_broadcast = UniformNumpy(low_float, high_np)
np_uniform = UniformNumpy(low_np, high_np)
gt_sample_int = np_uniform_int.sample([batch_size, dims])
gt_sample_float = np_uniform_float.sample([batch_size, dims])
gt_sample_float_np_broadcast = np_uniform_float_np_broadcast.sample(
[batch_size, dims])
gt_sample_np = np_uniform.sample([batch_size, dims])
gt_entropy_int = np_uniform_int.entropy()
gt_entropy_float = np_uniform_float.entropy()
gt_entropy_float_np_broadcast = np_uniform_float_np_broadcast.entropy()
gt_entropy = np_uniform.entropy()
gt_lp_float_np_broadcast = np_uniform_float_np_broadcast.log_prob(
values_np)
gt_lp = np_uniform.log_prob(values_np)
gt_p_float_np_broadcast = np_uniform_float_np_broadcast.probs(values_np)
gt_p = np_uniform.probs(values_np)
[
output_sample_int, output_sample_float,
output_sample_float_np_broadcast, output_sample_np,
output_sample_variable, output_entropy_int, output_entropy_float,
output_entropy_float_np_broadcast, output_entropy_np,
output_entropy_variable, output_lp_float_np_broadcast, output_lp_np,
output_lp_variable, output_p_float_np_broadcast, output_p_np,
output_p_variable
] = output_list
np.testing.assert_allclose(
output_sample_int.shape,
gt_sample_int.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_sample_float.shape,
gt_sample_float.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_sample_float_np_broadcast.shape,
gt_sample_float_np_broadcast.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_sample_np.shape,
gt_sample_np.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_sample_variable.shape,
gt_sample_np.shape,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_int, gt_entropy_int, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_entropy_float,
gt_entropy_float,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_float_np_broadcast,
gt_entropy_float_np_broadcast,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_entropy_np, gt_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_entropy_variable, gt_entropy, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_lp_float_np_broadcast,
gt_lp_float_np_broadcast,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_lp_np, gt_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_lp_variable, gt_lp, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_p_float_np_broadcast,
gt_p_float_np_broadcast,
rtol=tolerance,
atol=tolerance)
np.testing.assert_allclose(
output_p_np, gt_p, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
output_p_variable, gt_p, rtol=tolerance, atol=tolerance)
def test_uniform_distribution_static(self,
batch_size=2,
dims=3,
tolerance=1e-6):
test_program = fluid.Program()
low_np = np.random.randn(batch_size, dims).astype('float32')
low_float = np.random.uniform(-2, 1)
high_float = np.random.uniform(1, 3)
high_np = np.random.uniform(-5.0, 5.0,
(batch_size, dims)).astype('float32')
values_np = np.random.randn(batch_size, dims).astype('float32')
data_list = [low_np, low_float, high_float, high_np, values_np]
feed_vars, fetch_list = self.build_uniform_static(
test_program, batch_size, dims, low_float, high_float, high_np,
low_np, values_np)
self.executor.run(fluid.default_startup_program())
# result calculated by paddle
output_list = self.executor.run(program=test_program,
feed=feed_vars,
fetch_list=fetch_list)
self.compare_uniform_with_numpy(data_list, output_list, batch_size,
dims, tolerance)
def test_uniform_distribution_dygraph(self,
batch_size=2,
dims=3,
tolerance=1e-6):
paddle.disable_static()
low_np = np.random.randn(batch_size, dims).astype('float32')
low_float = np.random.uniform(-2, 1)
high_float = np.random.uniform(1, 3)
high_np = np.random.uniform(-5.0, 5.0,
(batch_size, dims)).astype('float32')
values_np = np.random.randn(batch_size, dims).astype('float32')
data_list = [low_np, low_float, high_float, high_np, values_np]
output_list = self.build_uniform_dygraph(
batch_size, dims, low_float, high_float, high_np, low_np, values_np)
self.compare_uniform_with_numpy(data_list, output_list, batch_size,
dims, tolerance)
paddle.enable_static()
class DistributionTestError(unittest.TestCase):
def test_distribution_error(self):
distribution = Distribution()
self.assertRaises(NotImplementedError, distribution.sample)
self.assertRaises(NotImplementedError, distribution.entropy)
normal = Normal(0.0, 1.0)
self.assertRaises(NotImplementedError, distribution.kl_divergence,
normal)
value_npdata = np.array([0.8], dtype="float32")
value_tensor = layers.create_tensor(dtype="float32")
self.assertRaises(NotImplementedError, distribution.log_prob,
value_tensor)
self.assertRaises(NotImplementedError, distribution.probs, value_tensor)
def test_normal_error(self):
normal = Normal(0.0, 1.0)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, normal.log_prob, value)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, normal.probs, value)
shape = 1.0
# type of shape must be list
self.assertRaises(TypeError, normal.sample, shape)
seed = 1.0
# type of seed must be int
self.assertRaises(TypeError, normal.sample, [2, 3], seed)
normal_other = Uniform(1.0, 2.0)
# type of other must be an instance of Normal
self.assertRaises(TypeError, normal.kl_divergence, normal_other)
def test_uniform_error(self):
uniform = Uniform(0.0, 1.0)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, uniform.log_prob, value)
value = [1.0, 2.0]
# type of value must be variable
self.assertRaises(TypeError, uniform.probs, value)
shape = 1.0
# type of shape must be list
self.assertRaises(TypeError, uniform.sample, shape)
seed = 1.0
# type of seed must be int
self.assertRaises(TypeError, uniform.sample, [2, 3], seed)
class DistributionTestName(unittest.TestCase):
def get_prefix(self, string):
return (string.split('.')[0])
def test_normal_name(self):
name = 'test_normal'
normal1 = Normal(0.0, 1.0, name=name)
self.assertEqual(normal1.name, name)
normal2 = Normal(0.0, 1.0)
self.assertEqual(normal2.name, 'Normal')
paddle.enable_static()
sample = normal1.sample([2])
self.assertEqual(self.get_prefix(sample.name), name + '_sample')
entropy = normal1.entropy()
self.assertEqual(self.get_prefix(entropy.name), name + '_entropy')
value_npdata = np.array([0.8], dtype="float32")
value_tensor = layers.create_tensor(dtype="float32")
layers.assign(value_npdata, value_tensor)
lp = normal1.log_prob(value_tensor)
self.assertEqual(self.get_prefix(lp.name), name + '_log_prob')
p = normal1.probs(value_tensor)
self.assertEqual(self.get_prefix(p.name), name + '_probs')
kl = normal1.kl_divergence(normal2)
self.assertEqual(self.get_prefix(kl.name), name + '_kl_divergence')
def test_uniform_name(self):
name = 'test_uniform'
uniform1 = Uniform(0.0, 1.0, name=name)
self.assertEqual(uniform1.name, name)
uniform2 = Uniform(0.0, 1.0)
self.assertEqual(uniform2.name, 'Uniform')
paddle.enable_static()
sample = uniform1.sample([2])
self.assertEqual(self.get_prefix(sample.name), name + '_sample')
entropy = uniform1.entropy()
self.assertEqual(self.get_prefix(entropy.name), name + '_entropy')
value_npdata = np.array([0.8], dtype="float32")
value_tensor = layers.create_tensor(dtype="float32")
layers.assign(value_npdata, value_tensor)
lp = uniform1.log_prob(value_tensor)
self.assertEqual(self.get_prefix(lp.name), name + '_log_prob')
p = uniform1.probs(value_tensor)
self.assertEqual(self.get_prefix(p.name), name + '_probs')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册