diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 8404e82c544955732fd44040a9f57fd6eeb398bd..b61ea8ed679eef11202d5ce0c1179d627564d178 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -38,6 +38,7 @@ import paddle.compat import paddle.distributed import paddle.sysconfig import paddle.tensor +import paddle.distribution import paddle.nn import paddle.distributed.fleet import paddle.optimizer diff --git a/python/paddle/distribution.py b/python/paddle/distribution.py index fff10c5b2a9ee497cccff94346314db2c8011eb5..ba4bfa8708b2ae3d8ae6643393e0f87cc9c6b360 100644 --- a/python/paddle/distribution.py +++ b/python/paddle/distribution.py @@ -18,3 +18,517 @@ # 'Normal', # 'sampling_id', # 'Uniform'] + +from __future__ import print_function + +from .fluid.layers import control_flow +from .fluid.layers import tensor +from .fluid.layers import ops +from .fluid.layers import nn +from .fluid.framework import in_dygraph_mode +from .tensor.math import elementwise_mul, elementwise_div, elementwise_add, elementwise_sub +import math +import numpy as np +import warnings + +from .fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype + +__all__ = ['Distribution', 'Uniform', 'Normal'] + + +class Distribution(object): + """ + The abstract base class for probability distributions. Functions are + implemented in specific distributions. + """ + + def __init__(self): + super(Distribution, self).__init__() + + def sample(self): + """Sampling from the distribution.""" + raise NotImplementedError + + def entropy(self): + """The entropy of the distribution.""" + raise NotImplementedError + + def kl_divergence(self, other): + """The KL-divergence between self distributions and other.""" + raise NotImplementedError + + def log_prob(self, value): + """Log probability density/mass function.""" + raise NotImplementedError + + def probs(self, value): + """Probability density/mass function.""" + raise NotImplementedError + + def _validate_args(self, *args): + """ + Argument validation for distribution args + Args: + value (float, list, numpy.ndarray, Tensor) + Raises + ValueError: if one argument is Tensor, all arguments should be Tensor + """ + is_variable = False + is_number = False + for arg in args: + if isinstance(arg, tensor.Variable): + is_variable = True + else: + is_number = True + + if is_variable and is_number: + raise ValueError( + 'if one argument is Tensor, all arguments should be Tensor') + + return is_variable + + def _to_variable(self, *args): + """ + Argument convert args to Tensor + + Args: + value (float, list, numpy.ndarray, Tensor) + Returns: + Tensor of args. + """ + numpy_args = [] + variable_args = [] + tmp = 0. + + for arg in args: + valid_arg = False + for cls in [float, list, np.ndarray, tensor.Variable]: + if isinstance(arg, cls): + valid_arg = True + break + assert valid_arg, "type of input args must be float, list, numpy.ndarray or Tensor." + if isinstance(arg, float): + arg = np.zeros(1) + arg + arg_np = np.array(arg) + arg_dtype = arg_np.dtype + if str(arg_dtype) not in ['float32']: + warnings.warn( + "data type of argument only support float32, your argument will be convert to float32." + ) + arg_np = arg_np.astype('float32') + tmp = tmp + arg_np + numpy_args.append(arg_np) + + dtype = tmp.dtype + for arg in numpy_args: + arg_broadcasted, _ = np.broadcast_arrays(arg, tmp) + arg_variable = tensor.create_tensor(dtype=dtype) + tensor.assign(arg_broadcasted, arg_variable) + variable_args.append(arg_variable) + + return tuple(variable_args) + + +class Uniform(Distribution): + """Uniform distribution with `low` and `high` parameters. + + Mathematical Details + + The probability density function (pdf) is, + + .. math:: + + pdf(x; a, b) = \\frac{1}{Z}, \ a <=x 0): + scale_np = np.random.randn(batch_size, dims).astype('float32') + while not np.all(other_scale_np > 0): + other_scale_np = np.random.randn(batch_size, dims).astype('float32') + return [ + loc_np, other_loc_np, loc_float, scale_float, other_loc_float, + other_scale_float, scale_np, other_scale_np, values_np + ] + + def compare_normal_with_numpy(self, + data_list, + output_list, + batch_size=2, + dims=3, + tolerance=1e-6): + loc_np, other_loc_np, loc_float, scale_float, other_loc_float, other_scale_float, scale_np, other_scale_np, values_np = data_list + + np_normal_int = NormalNumpy(int(loc_float), int(scale_float)) + np_normal_float = NormalNumpy(loc_float, scale_float) + np_other_normal_float = NormalNumpy(other_loc_float, other_scale_float) + np_normal_float_np_broadcast = NormalNumpy(loc_float, scale_np) + np_other_normal_float_np_broadcast = NormalNumpy(other_loc_float, + other_scale_np) + np_normal = NormalNumpy(loc_np, scale_np) + np_other_normal = NormalNumpy(other_loc_np, other_scale_np) + + gt_sample_int = np_normal_int.sample([batch_size, dims]) + gt_sample_float = np_normal_float.sample([batch_size, dims]) + gt_sample_float_np_broadcast = np_normal_float_np_broadcast.sample( + [batch_size, dims]) + gt_sample_np = np_normal.sample([batch_size, dims]) + gt_entropy_int = np_normal_int.entropy() + gt_entropy_float = np_normal_float.entropy() + gt_entropy_float_np_broadcast = np_normal_float_np_broadcast.entropy() + gt_entropy = np_normal.entropy() + gt_lp_float_np_broadcast = np_normal_float_np_broadcast.log_prob( + values_np) + gt_lp = np_normal.log_prob(values_np) + gt_p_float_np_broadcast = np_normal_float_np_broadcast.probs(values_np) + gt_p = np_normal.probs(values_np) + gt_kl_float = np_normal_float.kl_divergence(np_other_normal_float) + gt_kl_float_np_broadcast = np_normal_float_np_broadcast.kl_divergence( + np_other_normal_float_np_broadcast) + gt_kl = np_normal.kl_divergence(np_other_normal) + + [ + output_sample_int, output_sample_float, + output_sample_float_np_broadcast, output_sample_np, + output_sample_variable, output_entropy_int, output_entropy_float, + output_entropy_float_np_broadcast, output_entropy_np, + output_entropy_variable, output_lp_float_np_broadcast, output_lp_np, + output_lp_variable, output_p_float_np_broadcast, output_p_np, + output_p_variable, output_kl_float, output_kl_float_np_broadcast, + output_kl_np, output_kl_variable + ] = output_list + + np.testing.assert_allclose( + output_sample_int.shape, + gt_sample_int.shape, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_sample_float.shape, + gt_sample_float.shape, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_sample_float_np_broadcast.shape, + gt_sample_float_np_broadcast.shape, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_sample_np.shape, + gt_sample_np.shape, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_sample_variable.shape, + gt_sample_np.shape, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_entropy_int, gt_entropy_int, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + output_entropy_float, + gt_entropy_float, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_entropy_float_np_broadcast, + gt_entropy_float_np_broadcast, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_entropy_np, gt_entropy, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + output_entropy_variable, gt_entropy, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + output_lp_float_np_broadcast, + gt_lp_float_np_broadcast, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_lp_np, gt_lp, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + output_lp_variable, gt_lp, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + output_p_float_np_broadcast, + gt_p_float_np_broadcast, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_p_np, gt_p, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + output_p_variable, gt_p, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + output_kl_float, gt_kl_float, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + output_kl_float_np_broadcast, + gt_kl_float_np_broadcast, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_kl_np, gt_kl, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + output_kl_variable, gt_kl, rtol=tolerance, atol=tolerance) + + def test_normal_distribution_static(self, + batch_size=2, + dims=3, + tolerance=1e-6): + test_program = fluid.Program() + data_list = self.get_normal_random_input(batch_size, dims) + loc_np, other_loc_np, loc_float, scale_float, other_loc_float, other_scale_float, scale_np, other_scale_np, values_np = data_list + + feed_vars, fetch_list = self.build_normal_static( + test_program, batch_size, dims, loc_float, scale_float, + other_loc_float, other_scale_float, scale_np, other_scale_np, + loc_np, other_loc_np, values_np) + self.executor.run(fluid.default_startup_program()) + + output_list = self.executor.run(program=test_program, + feed=feed_vars, + fetch_list=fetch_list) + + self.compare_normal_with_numpy(data_list, output_list, batch_size, dims, + tolerance) + + def test_normal_distribution_dygraph(self, + batch_size=2, + dims=3, + tolerance=1e-6): + paddle.disable_static() + data_list = self.get_normal_random_input(batch_size, dims) + loc_np, other_loc_np, loc_float, scale_float, other_loc_float, other_scale_float, scale_np, other_scale_np, values_np = data_list + + output_list = self.build_normal_dygraph( + batch_size, dims, loc_float, scale_float, other_loc_float, + other_scale_float, scale_np, other_scale_np, loc_np, other_loc_np, + values_np) + + self.compare_normal_with_numpy(data_list, output_list, batch_size, dims, + tolerance) + paddle.enable_static() + + def build_uniform_common_net(self, batch_size, dims, low_float, high_float, + high_np, low_np, values_np, low, high, values): + uniform_int = Uniform(int(low_float), int(high_float)) + uniform_float = Uniform(low_float, high_float) + uniform_float_np_broadcast = Uniform(low_float, high_np) + uniform_np = Uniform(low_np, high_np) + uniform_variable = Uniform(low, high) + + sample_int = uniform_int.sample([batch_size, dims]) + sample_float = uniform_float.sample([batch_size, dims]) + sample_float_np_broadcast = uniform_float_np_broadcast.sample( + [batch_size, dims]) + sample_np = uniform_np.sample([batch_size, dims]) + sample_variable = uniform_variable.sample([batch_size, dims]) + + entropy_int = uniform_int.entropy() + entropy_float = uniform_float.entropy() + entropy_float_np_broadcast = uniform_float_np_broadcast.entropy() + entropy_np = uniform_np.entropy() + entropy_variable = uniform_variable.entropy() + + lp_float_np_broadcast = uniform_float_np_broadcast.log_prob(values) + lp_np = uniform_np.log_prob(values) + lp_variable = uniform_variable.log_prob(values) + + p_float_np_broadcast = uniform_float_np_broadcast.probs(values) + p_np = uniform_np.probs(values) + p_variable = uniform_variable.probs(values) + + fetch_list = [ + sample_int, sample_float, sample_float_np_broadcast, sample_np, + sample_variable, entropy_int, entropy_float, + entropy_float_np_broadcast, entropy_np, entropy_variable, + lp_float_np_broadcast, lp_np, lp_variable, p_float_np_broadcast, + p_np, p_variable + ] + return fetch_list + + def build_uniform_static(self, test_program, batch_size, dims, low_float, + high_float, high_np, low_np, values_np): + with fluid.program_guard(test_program): + low = layers.data(name='low', shape=[dims], dtype='float32') + high = layers.data(name='high', shape=[dims], dtype='float32') + + values = layers.data(name='values', shape=[dims], dtype='float32') + + fetch_list = self.build_uniform_common_net( + batch_size, dims, low_float, high_float, high_np, low_np, + values_np, low, high, values) + + feed_vars = {'low': low_np, 'high': high_np, 'values': values_np} + return feed_vars, fetch_list + + def build_uniform_dygraph(self, batch_size, dims, low_float, high_float, + high_np, low_np, values_np): + low = paddle.to_tensor(low_np) + high = paddle.to_tensor(high_np) + values = paddle.to_tensor(values_np) + + fetch_list = self.build_uniform_common_net(batch_size, dims, low_float, + high_float, high_np, low_np, + values_np, low, high, values) + fetch_list_numpy = [t.numpy() for t in fetch_list] + return fetch_list_numpy + + def compare_uniform_with_numpy(self, + data_list, + output_list, + batch_size=2, + dims=3, + tolerance=1e-6): + [low_np, low_float, high_float, high_np, values_np] = data_list + + np_uniform_int = UniformNumpy(int(low_float), int(high_float)) + np_uniform_float = UniformNumpy(low_float, high_float) + np_uniform_float_np_broadcast = UniformNumpy(low_float, high_np) + np_uniform = UniformNumpy(low_np, high_np) + + gt_sample_int = np_uniform_int.sample([batch_size, dims]) + gt_sample_float = np_uniform_float.sample([batch_size, dims]) + gt_sample_float_np_broadcast = np_uniform_float_np_broadcast.sample( + [batch_size, dims]) + gt_sample_np = np_uniform.sample([batch_size, dims]) + gt_entropy_int = np_uniform_int.entropy() + gt_entropy_float = np_uniform_float.entropy() + gt_entropy_float_np_broadcast = np_uniform_float_np_broadcast.entropy() + gt_entropy = np_uniform.entropy() + gt_lp_float_np_broadcast = np_uniform_float_np_broadcast.log_prob( + values_np) + gt_lp = np_uniform.log_prob(values_np) + gt_p_float_np_broadcast = np_uniform_float_np_broadcast.probs(values_np) + gt_p = np_uniform.probs(values_np) + + [ + output_sample_int, output_sample_float, + output_sample_float_np_broadcast, output_sample_np, + output_sample_variable, output_entropy_int, output_entropy_float, + output_entropy_float_np_broadcast, output_entropy_np, + output_entropy_variable, output_lp_float_np_broadcast, output_lp_np, + output_lp_variable, output_p_float_np_broadcast, output_p_np, + output_p_variable + ] = output_list + + np.testing.assert_allclose( + output_sample_int.shape, + gt_sample_int.shape, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_sample_float.shape, + gt_sample_float.shape, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_sample_float_np_broadcast.shape, + gt_sample_float_np_broadcast.shape, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_sample_np.shape, + gt_sample_np.shape, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_sample_variable.shape, + gt_sample_np.shape, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_entropy_int, gt_entropy_int, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + output_entropy_float, + gt_entropy_float, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_entropy_float_np_broadcast, + gt_entropy_float_np_broadcast, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_entropy_np, gt_entropy, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + output_entropy_variable, gt_entropy, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + output_lp_float_np_broadcast, + gt_lp_float_np_broadcast, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_lp_np, gt_lp, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + output_lp_variable, gt_lp, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + output_p_float_np_broadcast, + gt_p_float_np_broadcast, + rtol=tolerance, + atol=tolerance) + np.testing.assert_allclose( + output_p_np, gt_p, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + output_p_variable, gt_p, rtol=tolerance, atol=tolerance) + + def test_uniform_distribution_static(self, + batch_size=2, + dims=3, + tolerance=1e-6): + test_program = fluid.Program() + + low_np = np.random.randn(batch_size, dims).astype('float32') + low_float = np.random.uniform(-2, 1) + high_float = np.random.uniform(1, 3) + high_np = np.random.uniform(-5.0, 5.0, + (batch_size, dims)).astype('float32') + values_np = np.random.randn(batch_size, dims).astype('float32') + + data_list = [low_np, low_float, high_float, high_np, values_np] + + feed_vars, fetch_list = self.build_uniform_static( + test_program, batch_size, dims, low_float, high_float, high_np, + low_np, values_np) + + self.executor.run(fluid.default_startup_program()) + + # result calculated by paddle + output_list = self.executor.run(program=test_program, + feed=feed_vars, + fetch_list=fetch_list) + self.compare_uniform_with_numpy(data_list, output_list, batch_size, + dims, tolerance) + + def test_uniform_distribution_dygraph(self, + batch_size=2, + dims=3, + tolerance=1e-6): + paddle.disable_static() + + low_np = np.random.randn(batch_size, dims).astype('float32') + low_float = np.random.uniform(-2, 1) + high_float = np.random.uniform(1, 3) + high_np = np.random.uniform(-5.0, 5.0, + (batch_size, dims)).astype('float32') + values_np = np.random.randn(batch_size, dims).astype('float32') + + data_list = [low_np, low_float, high_float, high_np, values_np] + output_list = self.build_uniform_dygraph( + batch_size, dims, low_float, high_float, high_np, low_np, values_np) + + self.compare_uniform_with_numpy(data_list, output_list, batch_size, + dims, tolerance) + paddle.enable_static() + + +class DistributionTestError(unittest.TestCase): + def test_distribution_error(self): + distribution = Distribution() + + self.assertRaises(NotImplementedError, distribution.sample) + self.assertRaises(NotImplementedError, distribution.entropy) + + normal = Normal(0.0, 1.0) + self.assertRaises(NotImplementedError, distribution.kl_divergence, + normal) + + value_npdata = np.array([0.8], dtype="float32") + value_tensor = layers.create_tensor(dtype="float32") + self.assertRaises(NotImplementedError, distribution.log_prob, + value_tensor) + self.assertRaises(NotImplementedError, distribution.probs, value_tensor) + + def test_normal_error(self): + normal = Normal(0.0, 1.0) + + value = [1.0, 2.0] + # type of value must be variable + self.assertRaises(TypeError, normal.log_prob, value) + + value = [1.0, 2.0] + # type of value must be variable + self.assertRaises(TypeError, normal.probs, value) + + shape = 1.0 + # type of shape must be list + self.assertRaises(TypeError, normal.sample, shape) + + seed = 1.0 + # type of seed must be int + self.assertRaises(TypeError, normal.sample, [2, 3], seed) + + normal_other = Uniform(1.0, 2.0) + # type of other must be an instance of Normal + self.assertRaises(TypeError, normal.kl_divergence, normal_other) + + def test_uniform_error(self): + uniform = Uniform(0.0, 1.0) + + value = [1.0, 2.0] + # type of value must be variable + self.assertRaises(TypeError, uniform.log_prob, value) + + value = [1.0, 2.0] + # type of value must be variable + self.assertRaises(TypeError, uniform.probs, value) + + shape = 1.0 + # type of shape must be list + self.assertRaises(TypeError, uniform.sample, shape) + + seed = 1.0 + # type of seed must be int + self.assertRaises(TypeError, uniform.sample, [2, 3], seed) + + +class DistributionTestName(unittest.TestCase): + def get_prefix(self, string): + return (string.split('.')[0]) + + def test_normal_name(self): + name = 'test_normal' + normal1 = Normal(0.0, 1.0, name=name) + self.assertEqual(normal1.name, name) + + normal2 = Normal(0.0, 1.0) + self.assertEqual(normal2.name, 'Normal') + + paddle.enable_static() + + sample = normal1.sample([2]) + self.assertEqual(self.get_prefix(sample.name), name + '_sample') + + entropy = normal1.entropy() + self.assertEqual(self.get_prefix(entropy.name), name + '_entropy') + + value_npdata = np.array([0.8], dtype="float32") + value_tensor = layers.create_tensor(dtype="float32") + layers.assign(value_npdata, value_tensor) + + lp = normal1.log_prob(value_tensor) + self.assertEqual(self.get_prefix(lp.name), name + '_log_prob') + + p = normal1.probs(value_tensor) + self.assertEqual(self.get_prefix(p.name), name + '_probs') + + kl = normal1.kl_divergence(normal2) + self.assertEqual(self.get_prefix(kl.name), name + '_kl_divergence') + + def test_uniform_name(self): + name = 'test_uniform' + uniform1 = Uniform(0.0, 1.0, name=name) + self.assertEqual(uniform1.name, name) + + uniform2 = Uniform(0.0, 1.0) + self.assertEqual(uniform2.name, 'Uniform') + + paddle.enable_static() + + sample = uniform1.sample([2]) + self.assertEqual(self.get_prefix(sample.name), name + '_sample') + + entropy = uniform1.entropy() + self.assertEqual(self.get_prefix(entropy.name), name + '_entropy') + + value_npdata = np.array([0.8], dtype="float32") + value_tensor = layers.create_tensor(dtype="float32") + layers.assign(value_npdata, value_tensor) + + lp = uniform1.log_prob(value_tensor) + self.assertEqual(self.get_prefix(lp.name), name + '_log_prob') + + p = uniform1.probs(value_tensor) + self.assertEqual(self.get_prefix(p.name), name + '_probs') + + +if __name__ == '__main__': + unittest.main()