From 5f275aadaf7bd38001eb9fbf14909547d9ada559 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Thu, 27 Aug 2020 21:37:51 -0500 Subject: [PATCH] fix sample method of Uniform and Normal class (#26713) * fix sample shape error * Add unittest * change assert_allclose to assert_equal * Add unittest doc * fix encoding problem --- python/paddle/distribution.py | 31 +- .../tests/unittests/test_distribution.py | 405 +++++++++++++----- 2 files changed, 332 insertions(+), 104 deletions(-) diff --git a/python/paddle/distribution.py b/python/paddle/distribution.py index acb7251a15..49e98805d2 100644 --- a/python/paddle/distribution.py +++ b/python/paddle/distribution.py @@ -243,10 +243,19 @@ class Uniform(Distribution): zero_tmp = tensor.fill_constant_batch_size_like( self.low + self.high, batch_shape + shape, self.low.dtype, 0.) uniform_random_tmp = nn.uniform_random_batch_size_like( - zero_tmp, zero_tmp.shape, min=0., max=1., seed=seed) - output = uniform_random_tmp * (zero_tmp + self.high - self.low - ) + self.low - return nn.reshape(output, output_shape, name=name) + zero_tmp, + zero_tmp.shape, + dtype=convert_dtype(zero_tmp.dtype), + min=0., + max=1., + seed=seed) + zero_tmp_reshape = nn.reshape(zero_tmp, output_shape) + uniform_random_tmp_reshape = nn.reshape(uniform_random_tmp, + output_shape) + output = uniform_random_tmp_reshape * ( + zero_tmp_reshape + self.high - self.low) + output = elementwise_add(output, self.low, name=name) + return output else: output_shape = shape + batch_shape output = nn.uniform_random( @@ -446,11 +455,17 @@ class Normal(Distribution): output_shape = shape + batch_shape zero_tmp = tensor.fill_constant_batch_size_like( self.loc + self.scale, batch_shape + shape, self.loc.dtype, 0.) - zero_tmp_shape = nn.shape(zero_tmp) + zero_tmp_reshape = nn.reshape(zero_tmp, output_shape) + zero_tmp_shape = nn.shape(zero_tmp_reshape) normal_random_tmp = nn.gaussian_random( - zero_tmp_shape, mean=0., std=1., seed=seed) - output = normal_random_tmp * (zero_tmp + self.scale) + self.loc - return nn.reshape(output, output_shape, name=name) + zero_tmp_shape, + mean=0., + std=1., + seed=seed, + dtype=convert_dtype(self.loc.dtype)) + output = normal_random_tmp * (zero_tmp_reshape + self.scale) + output = elementwise_add(output, self.loc, name=name) + return output else: output_shape = shape + batch_shape output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed) * \ diff --git a/python/paddle/fluid/tests/unittests/test_distribution.py b/python/paddle/fluid/tests/unittests/test_distribution.py index 4ccaa3266e..533ad9604c 100644 --- a/python/paddle/fluid/tests/unittests/test_distribution.py +++ b/python/paddle/fluid/tests/unittests/test_distribution.py @@ -105,10 +105,59 @@ class DistributionTest(unittest.TestCase): self.gpu_id = 0 self.executor = fluid.Executor(place) - def build_normal_common_net(self, batch_size, dims, loc_float, scale_float, - other_loc_float, other_scale_float, scale_np, - other_scale_np, loc_np, other_loc_np, loc, - scale, other_loc, other_scale, values): + def build_normal_common_net(self, batch_size, dims, sample_shape, loc_float, + scale_float, other_loc_float, other_scale_float, + scale_np, other_scale_np, loc_np, other_loc_np, + loc, scale, other_loc, other_scale, values): + """Generate Normal object and get the output of its methods including + ``sample``, ``entropy``, ``log_prob``, ``probs`` and ``kl_divergence``. + Parameters ``loc`` and ``scale`` have different data types to test different situations. + + Args: + batch_size(int): The first dimension of the shape of parameters(loc and scale). + dims(int): The second dimension of the shape of parameters. + sample_shape(int): The sample value used in ``sample`` method. + loc_float(float): Generated in function ``get_normal_random_input``, loc is a float number. + scale_float(float): Generated in function ``get_normal_random_input``, scale is a float number. + other_loc_float(float): Generated in function ``get_normal_random_input``, other_loc is a + float number. It is the first parameter in another Normal object used in ``kl_divergence`` + method. + other_scale_float(float): Generated in function ``get_normal_random_input``, other_scale is a + float number. It is the second parameter in another Normal object used in ``kl_divergence`` + method. + scale_np(numpy.ndarray): Generated in function ``get_normal_random_input``, An numpy array + whose shape is [batch_size, dims]. + other_scale_np(numpy.ndarray): Generated in function ``get_normal_random_input``, other_scale_np + is an numpy array. It is the second parameter in another Normal object used in ``kl_divergence`` + method. + loc_np(numpy.ndarray): Generated in function ``get_normal_random_input``, An numpy array + whose shape is [batch_size, dims]. + other_loc_np(numpy.ndarray): Generated in function ``get_normal_random_input``, other_loc_np + is an numpy array. It is the first parameter in another Normal object used in ``kl_divergence`` + method. + loc(Tensor): In dynamic mode, loc is generated in ``build_normal_dygraph``, it's a Tensor filled + with ``loc_np`` data. In static mode, loc is generated in ``build_normal_static``, ``layers.data`` + method is used to get a Placeholder whose shape is [dims]. + scale(Tensor): In dynamic mode, scale is generated in ``build_normal_dygraph``, it's a Tensor filled + with ``scale_np`` data. In static mode, scale is generated in ``build_normal_static``, ``layers.data`` + method is used to get a Placeholder whose shape is [dims]. + other_loc(Tensor): In dynamic mode, other_loc is generated in ``build_normal_dygraph``, it's a Tensor + filled with ``other_loc_np`` data. In static mode, other_loc is generated in ``build_normal_static``, + ``layers.data`` method is used to get a Placeholder whose shape is [dims]. It is the first parameter + in another Normal object used in ``kl_divergence`` method. + other_scale(Tensor): In dynamic mode, other_scale is generated in ``build_normal_dygraph``, it's a Tensor + filled with ``other_scale_np`` data. In static mode, other_scale is generated in ``build_normal_static``, + ``layers.data`` method is used to get a Placeholder whose shape is [dims]. It is the second parameter + in another Normal object used in ``kl_divergence`` method. + values(Tensor): In dynamic mode, values is generated in ``build_normal_dygraph``, it's a Tensor filled with + ``values_np`` data. In static mode, values is generated in ``build_normal_static``, ``layers.data`` + method is used to get a Placeholder whose shape is [dims]. + + Returns: + List: The elements of the list are the output of sample, entropy, log_prob, probs, kl_divergence methods. + The inputs' type of these methods can be float, np.ndarray and Tensor. And broadcast will be considered. + + """ normal_int = Normal(int(loc_float), int(scale_float)) normal_float = Normal(loc_float, scale_float) other_normal_float = Normal(other_loc_float, other_scale_float) @@ -130,6 +179,13 @@ class DistributionTest(unittest.TestCase): sample_np = normal_np.sample([batch_size, dims]) sample_variable = normal_variable.sample([batch_size, dims]) + sample_int_diff = normal_int.sample([sample_shape]) + sample_float_diff = normal_float.sample([sample_shape]) + sample_float_np_broadcast_diff = normal_float_np_broadcast.sample( + [sample_shape]) + sample_np_diff = normal_np.sample([sample_shape]) + sample_variable_diff = normal_variable.sample([sample_shape]) + entropy_int = normal_int.entropy() entropy_float = normal_float.entropy() entropy_float_np_broadcast = normal_float_np_broadcast.entropy() @@ -152,7 +208,9 @@ class DistributionTest(unittest.TestCase): fetch_list = [ sample_int, sample_float, sample_float_np_broadcast, sample_np, - sample_variable, entropy_int, entropy_float, + sample_variable, sample_int_diff, sample_float_diff, + sample_float_np_broadcast_diff, sample_np_diff, + sample_variable_diff, entropy_int, entropy_float, entropy_float_np_broadcast, entropy_np, entropy_variable, lp_float_np_broadcast, lp_np, lp_variable, p_float_np_broadcast, p_np, p_variable, kl_float, kl_float_np_broadcast, kl_np, @@ -160,10 +218,22 @@ class DistributionTest(unittest.TestCase): ] return fetch_list - def build_normal_static(self, test_program, batch_size, dims, loc_float, - scale_float, other_loc_float, other_scale_float, - scale_np, other_scale_np, loc_np, other_loc_np, - values_np): + def build_normal_static(self, test_program, batch_size, dims, sample_shape, + loc_float, scale_float, other_loc_float, + other_scale_float, scale_np, other_scale_np, loc_np, + other_loc_np, values_np): + """ + In static mode, generate feed data of Normal network, and get output fetch_list using + ``build_normal_common_net``. + + Args: + test_program: In static mode, the Program object. + other args can refer to function ``build_normal_common_net``. + + Returns: + feed_vars: The feed data of Normal network in static mode. + fetch_list: The output is generated by function ``build_normal_common_net``. + """ with fluid.program_guard(test_program): loc = layers.data(name='loc', shape=[dims], dtype='float32') scale = layers.data(name='scale', shape=[dims], dtype='float32') @@ -176,9 +246,10 @@ class DistributionTest(unittest.TestCase): values = layers.data(name='values', shape=[dims], dtype='float32') fetch_list = self.build_normal_common_net( - batch_size, dims, loc_float, scale_float, other_loc_float, - other_scale_float, scale_np, other_scale_np, loc_np, - other_loc_np, loc, scale, other_loc, other_scale, values) + batch_size, dims, sample_shape, loc_float, scale_float, + other_loc_float, other_scale_float, scale_np, other_scale_np, + loc_np, other_loc_np, loc, scale, other_loc, other_scale, + values) feed_vars = { 'loc': loc_np, @@ -189,9 +260,21 @@ class DistributionTest(unittest.TestCase): } return feed_vars, fetch_list - def build_normal_dygraph(self, batch_size, dims, loc_float, scale_float, - other_loc_float, other_scale_float, scale_np, - other_scale_np, loc_np, other_loc_np, values_np): + def build_normal_dygraph(self, batch_size, dims, sample_shape, loc_float, + scale_float, other_loc_float, other_scale_float, + scale_np, other_scale_np, loc_np, other_loc_np, + values_np): + """ + In dynamic mode, generate input data of Normal network, and get output fetch_list using + ``build_normal_common_net``. + + Args: + refer to function ``build_normal_common_net``. + + Returns: + fetch_list_numpy: The output is generated by function ``build_normal_common_net``. Transform + these tensor to numpy.ndarray. + """ loc = paddle.to_tensor(loc_np) scale = paddle.to_tensor(scale_np) other_loc = paddle.to_tensor(other_loc_np) @@ -199,13 +282,24 @@ class DistributionTest(unittest.TestCase): values = paddle.to_tensor(values_np) fetch_list = self.build_normal_common_net( - batch_size, dims, loc_float, scale_float, other_loc_float, - other_scale_float, scale_np, other_scale_np, loc_np, other_loc_np, - loc, scale, other_loc, other_scale, values) + batch_size, dims, sample_shape, loc_float, scale_float, + other_loc_float, other_scale_float, scale_np, other_scale_np, + loc_np, other_loc_np, loc, scale, other_loc, other_scale, values) fetch_list_numpy = [t.numpy() for t in fetch_list] return fetch_list_numpy def get_normal_random_input(self, batch_size, dims): + """ + Generate input data ``loc`` and ``scale`` used in Normal network. + + Args: + refer to function ``build_normal_common_net``. + + Returns: + List: Different data type of ``loc`` and ``scale``, including float, numpy.ndarray. + By the way, ``other_loc`` and ``other_scale`` are used in ``kl_divergence`` method. + refer to ``args`` in function ``build_normal_common_net``. + """ loc_np = np.random.randn(batch_size, dims).astype('float32') other_loc_np = np.random.randn(batch_size, dims).astype('float32') @@ -237,7 +331,20 @@ class DistributionTest(unittest.TestCase): output_list, batch_size=2, dims=3, + sample_shape=7, tolerance=1e-6): + """ + Compare the outputs of Normal's methods in paddle and numpy. If the outputs are not consistent, + raise errors. + + Args: + data_list: Input data generated by function ``get_normal_random_input``. + output_list: The outputs of Normal's methods in static or dynamic mode. + batch_size(int): The first dimension of the shape of parameters(loc and scale). + dims(int): The second dimension of the shape of parameters. + sample_shape(int): The sample value used in ``sample`` method. + tolerance(float): The tolerance of the error. + """ loc_np, other_loc_np, loc_float, scale_float, other_loc_float, other_scale_float, scale_np, other_scale_np, values_np = data_list np_normal_int = NormalNumpy(int(loc_float), int(scale_float)) @@ -254,6 +361,13 @@ class DistributionTest(unittest.TestCase): gt_sample_float_np_broadcast = np_normal_float_np_broadcast.sample( [batch_size, dims]) gt_sample_np = np_normal.sample([batch_size, dims]) + + gt_sample_int_diff = np_normal_int.sample([sample_shape]) + gt_sample_float_diff = np_normal_float.sample([sample_shape]) + gt_sample_float_np_broadcast_diff = np_normal_float_np_broadcast.sample( + [sample_shape]) + gt_sample_np_diff = np_normal.sample([sample_shape]) + gt_entropy_int = np_normal_int.entropy() gt_entropy_float = np_normal_float.entropy() gt_entropy_float_np_broadcast = np_normal_float_np_broadcast.entropy() @@ -271,7 +385,10 @@ class DistributionTest(unittest.TestCase): [ output_sample_int, output_sample_float, output_sample_float_np_broadcast, output_sample_np, - output_sample_variable, output_entropy_int, output_entropy_float, + output_sample_variable, output_sample_int_diff, + output_sample_float_diff, output_sample_float_np_broadcast_diff, + output_sample_np_diff, output_sample_variable_diff, + output_entropy_int, output_entropy_float, output_entropy_float_np_broadcast, output_entropy_np, output_entropy_variable, output_lp_float_np_broadcast, output_lp_np, output_lp_variable, output_p_float_np_broadcast, output_p_np, @@ -279,31 +396,24 @@ class DistributionTest(unittest.TestCase): output_kl_np, output_kl_variable ] = output_list - np.testing.assert_allclose( - output_sample_int.shape, - gt_sample_int.shape, - rtol=tolerance, - atol=tolerance) - np.testing.assert_allclose( - output_sample_float.shape, - gt_sample_float.shape, - rtol=tolerance, - atol=tolerance) - np.testing.assert_allclose( - output_sample_float_np_broadcast.shape, - gt_sample_float_np_broadcast.shape, - rtol=tolerance, - atol=tolerance) - np.testing.assert_allclose( - output_sample_np.shape, - gt_sample_np.shape, - rtol=tolerance, - atol=tolerance) - np.testing.assert_allclose( - output_sample_variable.shape, - gt_sample_np.shape, - rtol=tolerance, - atol=tolerance) + np.testing.assert_equal(output_sample_int.shape, gt_sample_int.shape) + np.testing.assert_equal(output_sample_float.shape, + gt_sample_float.shape) + np.testing.assert_equal(output_sample_float_np_broadcast.shape, + gt_sample_float_np_broadcast.shape) + np.testing.assert_equal(output_sample_np.shape, gt_sample_np.shape) + np.testing.assert_equal(output_sample_variable.shape, + gt_sample_np.shape) + np.testing.assert_equal(output_sample_int_diff.shape, + gt_sample_int_diff.shape) + np.testing.assert_equal(output_sample_float_diff.shape, + gt_sample_float_diff.shape) + np.testing.assert_equal(output_sample_float_np_broadcast_diff.shape, + gt_sample_float_np_broadcast_diff.shape) + np.testing.assert_equal(output_sample_np_diff.shape, + gt_sample_np_diff.shape) + np.testing.assert_equal(output_sample_variable_diff.shape, + gt_sample_np_diff.shape) np.testing.assert_allclose( output_entropy_int, gt_entropy_int, rtol=tolerance, atol=tolerance) np.testing.assert_allclose( @@ -353,15 +463,22 @@ class DistributionTest(unittest.TestCase): def test_normal_distribution_static(self, batch_size=2, dims=3, + sample_shape=7, tolerance=1e-6): + """ + Test Normal's methods in static mode. + + Args: + refer to ``compare_normal_with_numpy`` function. + """ test_program = fluid.Program() data_list = self.get_normal_random_input(batch_size, dims) loc_np, other_loc_np, loc_float, scale_float, other_loc_float, other_scale_float, scale_np, other_scale_np, values_np = data_list feed_vars, fetch_list = self.build_normal_static( - test_program, batch_size, dims, loc_float, scale_float, - other_loc_float, other_scale_float, scale_np, other_scale_np, - loc_np, other_loc_np, values_np) + test_program, batch_size, dims, sample_shape, loc_float, + scale_float, other_loc_float, other_scale_float, scale_np, + other_scale_np, loc_np, other_loc_np, values_np) self.executor.run(fluid.default_startup_program()) output_list = self.executor.run(program=test_program, @@ -369,27 +486,62 @@ class DistributionTest(unittest.TestCase): fetch_list=fetch_list) self.compare_normal_with_numpy(data_list, output_list, batch_size, dims, - tolerance) + sample_shape, tolerance) def test_normal_distribution_dygraph(self, batch_size=2, dims=3, + sample_shape=7, tolerance=1e-6): + """ + Test Normal's methods in dynamic mode. + + Args: + refer to ``compare_normal_with_numpy`` function. + """ paddle.disable_static() data_list = self.get_normal_random_input(batch_size, dims) loc_np, other_loc_np, loc_float, scale_float, other_loc_float, other_scale_float, scale_np, other_scale_np, values_np = data_list output_list = self.build_normal_dygraph( - batch_size, dims, loc_float, scale_float, other_loc_float, - other_scale_float, scale_np, other_scale_np, loc_np, other_loc_np, - values_np) + batch_size, dims, sample_shape, loc_float, scale_float, + other_loc_float, other_scale_float, scale_np, other_scale_np, + loc_np, other_loc_np, values_np) self.compare_normal_with_numpy(data_list, output_list, batch_size, dims, - tolerance) + sample_shape, tolerance) paddle.enable_static() - def build_uniform_common_net(self, batch_size, dims, low_float, high_float, - high_np, low_np, values_np, low, high, values): + def build_uniform_common_net(self, batch_size, dims, sample_shape, + low_float, high_float, high_np, low_np, + values_np, low, high, values): + """Generate Uniform object and get the output of its methods including ``sample``, ``entropy``, + ``log_prob`` and ``probs``. + Parameters ``low`` and ``high`` have different data types to test different situations. + + Args: + batch_size(int): The first dimension of the shape of parameters(low and high). + dims(int): The second dimension of the shape of parameters. + sample_shape(int): The sample value used in ``sample`` method. + low_float(float): Parameter ``low`` is a float number. + high_float(float): Parameter ``high`` is a float number. + high_np(numpy.ndarray): An numpy array whose shape is [batch_size, dims]. + low_np(numpy.ndarray): An numpy array whose shape is [batch_size, dims]. + values_np(numpy.ndarray): The input of ``log_prob`` and ``probs`` methods. An numpy array whose + shape is [batch_size, dims]. + low(Tensor): In dynamic mode, low is generated in ``build_uniform_dygraph``, it's a Tensor filled + with ``low_np`` data. In static mode, low is generated in ``build_uniform_static``. + high(Tensor): In dynamic mode, high is generated in ``build_uniform_dygraph``, it's a Tensor filled + with ``high_np`` data. In static mode, high is generated in ``build_uniform_static``. + values(Tensor): In dynamic mode, values is generated in ``build_uniform_dygraph``, it's a Tensor + filled with ``values_np`` data. In static mode, values is generated in ``build_uniform_static``. + + Returns: + List: The elements of the list are the output of sample, entropy, log_prob, probs methods. + The inputs' type of these methods can be float, np.ndarray and Tensor. And broadcast will be + considered. + + """ uniform_int = Uniform(int(low_float), int(high_float)) uniform_float = Uniform(low_float, high_float) uniform_float_np_broadcast = Uniform(low_float, high_np) @@ -403,6 +555,13 @@ class DistributionTest(unittest.TestCase): sample_np = uniform_np.sample([batch_size, dims]) sample_variable = uniform_variable.sample([batch_size, dims]) + sample_int_diff = uniform_int.sample([sample_shape]) + sample_float_diff = uniform_float.sample([sample_shape]) + sample_float_np_broadcast_diff = uniform_float_np_broadcast.sample( + [sample_shape]) + sample_np_diff = uniform_np.sample([sample_shape]) + sample_variable_diff = uniform_variable.sample([sample_shape]) + entropy_int = uniform_int.entropy() entropy_float = uniform_float.entropy() entropy_float_np_broadcast = uniform_float_np_broadcast.entropy() @@ -419,15 +578,29 @@ class DistributionTest(unittest.TestCase): fetch_list = [ sample_int, sample_float, sample_float_np_broadcast, sample_np, - sample_variable, entropy_int, entropy_float, + sample_variable, sample_int_diff, sample_float_diff, + sample_float_np_broadcast_diff, sample_np_diff, + sample_variable_diff, entropy_int, entropy_float, entropy_float_np_broadcast, entropy_np, entropy_variable, lp_float_np_broadcast, lp_np, lp_variable, p_float_np_broadcast, p_np, p_variable ] return fetch_list - def build_uniform_static(self, test_program, batch_size, dims, low_float, - high_float, high_np, low_np, values_np): + def build_uniform_static(self, test_program, batch_size, dims, sample_shape, + low_float, high_float, high_np, low_np, values_np): + """ + In static mode, generate feed data of Uniform network, and get output fetch_list using + ``build_uniform_common_net``. + + Args: + test_program: In static mode, the Program object. + other args can refer to function ``build_uniform_common_net``. + + Returns: + feed_vars: The feed data of Uniform network in static mode. + fetch_list: The output is generated by function ``build_uniform_common_net``. + """ with fluid.program_guard(test_program): low = layers.data(name='low', shape=[dims], dtype='float32') high = layers.data(name='high', shape=[dims], dtype='float32') @@ -435,21 +608,32 @@ class DistributionTest(unittest.TestCase): values = layers.data(name='values', shape=[dims], dtype='float32') fetch_list = self.build_uniform_common_net( - batch_size, dims, low_float, high_float, high_np, low_np, - values_np, low, high, values) + batch_size, dims, sample_shape, low_float, high_float, high_np, + low_np, values_np, low, high, values) feed_vars = {'low': low_np, 'high': high_np, 'values': values_np} return feed_vars, fetch_list - def build_uniform_dygraph(self, batch_size, dims, low_float, high_float, - high_np, low_np, values_np): + def build_uniform_dygraph(self, batch_size, dims, sample_shape, low_float, + high_float, high_np, low_np, values_np): + """ + In dynamic mode, generate input data of Uniform network, and get output fetch_list using + ``build_uniform_common_net``. + + Args: + refer to function ``build_uniform_common_net``. + + Returns: + fetch_list_numpy: The output is generated by function ``build_uniform_common_net``. Transform + these tensor to numpy.ndarray. + """ low = paddle.to_tensor(low_np) high = paddle.to_tensor(high_np) values = paddle.to_tensor(values_np) - fetch_list = self.build_uniform_common_net(batch_size, dims, low_float, - high_float, high_np, low_np, - values_np, low, high, values) + fetch_list = self.build_uniform_common_net( + batch_size, dims, sample_shape, low_float, high_float, high_np, + low_np, values_np, low, high, values) fetch_list_numpy = [t.numpy() for t in fetch_list] return fetch_list_numpy @@ -458,7 +642,20 @@ class DistributionTest(unittest.TestCase): output_list, batch_size=2, dims=3, + sample_shape=7, tolerance=1e-6): + """ + Compare the outputs of Uniform's methods in paddle and numpy. If the outputs are not consistent, + raise errors. + + Args: + data_list: Input data including float and numpy.ndarray type of ``low`` and ``high`` parameters. + output_list: The outputs of Uniform's methods in static or dynamic mode. + batch_size(int): The first dimension of the shape of parameters(low and high). + dims(int): The second dimension of the shape of parameters. + sample_shape(int): The sample value used in ``sample`` method. + tolerance(float): The tolerance of the error. + """ [low_np, low_float, high_float, high_np, values_np] = data_list np_uniform_int = UniformNumpy(int(low_float), int(high_float)) @@ -471,6 +668,11 @@ class DistributionTest(unittest.TestCase): gt_sample_float_np_broadcast = np_uniform_float_np_broadcast.sample( [batch_size, dims]) gt_sample_np = np_uniform.sample([batch_size, dims]) + gt_sample_int_diff = np_uniform_int.sample([sample_shape]) + gt_sample_float_diff = np_uniform_float.sample([sample_shape]) + gt_sample_float_np_broadcast_diff = np_uniform_float_np_broadcast.sample( + [sample_shape]) + gt_sample_np_diff = np_uniform.sample([sample_shape]) gt_entropy_int = np_uniform_int.entropy() gt_entropy_float = np_uniform_float.entropy() gt_entropy_float_np_broadcast = np_uniform_float_np_broadcast.entropy() @@ -484,38 +686,34 @@ class DistributionTest(unittest.TestCase): [ output_sample_int, output_sample_float, output_sample_float_np_broadcast, output_sample_np, - output_sample_variable, output_entropy_int, output_entropy_float, + output_sample_variable, output_sample_int_diff, + output_sample_float_diff, output_sample_float_np_broadcast_diff, + output_sample_np_diff, output_sample_variable_diff, + output_entropy_int, output_entropy_float, output_entropy_float_np_broadcast, output_entropy_np, output_entropy_variable, output_lp_float_np_broadcast, output_lp_np, output_lp_variable, output_p_float_np_broadcast, output_p_np, output_p_variable ] = output_list - np.testing.assert_allclose( - output_sample_int.shape, - gt_sample_int.shape, - rtol=tolerance, - atol=tolerance) - np.testing.assert_allclose( - output_sample_float.shape, - gt_sample_float.shape, - rtol=tolerance, - atol=tolerance) - np.testing.assert_allclose( - output_sample_float_np_broadcast.shape, - gt_sample_float_np_broadcast.shape, - rtol=tolerance, - atol=tolerance) - np.testing.assert_allclose( - output_sample_np.shape, - gt_sample_np.shape, - rtol=tolerance, - atol=tolerance) - np.testing.assert_allclose( - output_sample_variable.shape, - gt_sample_np.shape, - rtol=tolerance, - atol=tolerance) + np.testing.assert_equal(output_sample_int.shape, gt_sample_int.shape) + np.testing.assert_equal(output_sample_float.shape, + gt_sample_float.shape) + np.testing.assert_equal(output_sample_float_np_broadcast.shape, + gt_sample_float_np_broadcast.shape) + np.testing.assert_equal(output_sample_np.shape, gt_sample_np.shape) + np.testing.assert_equal(output_sample_variable.shape, + gt_sample_np.shape) + np.testing.assert_equal(output_sample_int_diff.shape, + gt_sample_int_diff.shape) + np.testing.assert_equal(output_sample_float_diff.shape, + gt_sample_float_diff.shape) + np.testing.assert_equal(output_sample_float_np_broadcast_diff.shape, + gt_sample_float_np_broadcast_diff.shape) + np.testing.assert_equal(output_sample_np_diff.shape, + gt_sample_np_diff.shape) + np.testing.assert_equal(output_sample_variable_diff.shape, + gt_sample_np_diff.shape) np.testing.assert_allclose( output_entropy_int, gt_entropy_int, rtol=tolerance, atol=tolerance) np.testing.assert_allclose( @@ -554,7 +752,14 @@ class DistributionTest(unittest.TestCase): def test_uniform_distribution_static(self, batch_size=2, dims=3, + sample_shape=7, tolerance=1e-6): + """ + Test Uniform's methods in static mode. + + Args: + refer to ``compare_uniform_with_numpy`` function. + """ test_program = fluid.Program() low_np = np.random.randn(batch_size, dims).astype('float32') @@ -567,8 +772,8 @@ class DistributionTest(unittest.TestCase): data_list = [low_np, low_float, high_float, high_np, values_np] feed_vars, fetch_list = self.build_uniform_static( - test_program, batch_size, dims, low_float, high_float, high_np, - low_np, values_np) + test_program, batch_size, dims, sample_shape, low_float, high_float, + high_np, low_np, values_np) self.executor.run(fluid.default_startup_program()) @@ -577,12 +782,19 @@ class DistributionTest(unittest.TestCase): feed=feed_vars, fetch_list=fetch_list) self.compare_uniform_with_numpy(data_list, output_list, batch_size, - dims, tolerance) + dims, sample_shape, tolerance) def test_uniform_distribution_dygraph(self, batch_size=2, dims=3, + sample_shape=7, tolerance=1e-6): + """ + Test Uniform's methods in dynamic mode. + + Args: + refer to ``compare_uniform_with_numpy`` function. + """ paddle.disable_static() low_np = np.random.randn(batch_size, dims).astype('float32') @@ -593,11 +805,12 @@ class DistributionTest(unittest.TestCase): values_np = np.random.randn(batch_size, dims).astype('float32') data_list = [low_np, low_float, high_float, high_np, values_np] - output_list = self.build_uniform_dygraph( - batch_size, dims, low_float, high_float, high_np, low_np, values_np) + output_list = self.build_uniform_dygraph(batch_size, dims, sample_shape, + low_float, high_float, high_np, + low_np, values_np) self.compare_uniform_with_numpy(data_list, output_list, batch_size, - dims, tolerance) + dims, sample_shape, tolerance) paddle.enable_static() -- GitLab