From 151cfff90b4e860baba5a374cf044d6ce8b8c7ff Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sat, 17 Mar 2018 20:51:20 -0700 Subject: [PATCH] add more tests --- .../tests/unittests/test_batch_norm_op.py | 94 ++++++++++++++++--- 1 file changed, 80 insertions(+), 14 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index d5a57bdd73a..f631050e2a7 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -188,14 +188,27 @@ def set_output_grad(scope, outputs, place, feed_dict=None): class TestBatchNormOpInference(OpTest): def setUp(self): + self.op_type = "conv2d" + self.is_test = True self.dtype = np.float32 + self.data_layout = "NCHW" + init_dtype() + init_data_layout() + init_test_case() - def test_python(self): - data_format = "NHWC" epsilon = 0.00001 - - n, h, w, c = 2, 3, 4, 5 - x_shape = [n, h, w, c] + shape = self.shape + if len(shape) == 2: + x_shape = shape + c = x_shape[1] + else: + n, h, w, c = shape[0], shape[1], shape[2], shape[3] + if self.data_layout == "NHWC": + x_shape = [n, h, w, c] + elif self.data_layout == "NCHW": + x_shape = [n, c, h, w] + else: + raise ValueError("Unknown data layout.") scale_shape = [c] x_val = np.random.random_sample(x_shape).astype(self.dtype) @@ -205,7 +218,64 @@ class TestBatchNormOpInference(OpTest): mean = np.zeros(scale_shape).astype(self.dtype) variance = np.ones(scale_shape).astype(self.dtype) - # run forward + saved_mean = np.zeros(scale_shape).astype(self.dtype) + saved_variance = np.ones(scale_shape).astype(self.dtype) + + y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance, + epsilon, self.data_layout).astype(self.dtype) + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(x_val), + 'Scale': OpTest.np_dtype_to_fluid_dtype(scale_val), + 'Bias': OpTest.np_dtype_to_fluid_dtype(bias_val), + 'Mean': OpTest.np_dtype_to_fluid_dtype(mean), + 'Variance': OpTest.np_dtype_to_fluid_dtype(variance) + } + self.attrs = { + 'is_test': self.is_test, + 'epsilon': epsilon, + 'data_layout': self.data_layout + } + self.outputs = { + 'Y': y_out, + 'MeanOut': mean, + 'VarianceOut': variance, + 'SavedMean': saved_mean, + 'SavedVariance': saved_variance + } + + def test_check_output(self): + self.check_output() + + def init_dtype(self): + pass + + def init_data_layout(self): + pass + + def init_test_case(self): + self.shape = [2, 3, 4, 5] + + +class TestBatchNormOpTraining(OpTest): + def __assert_close(self, tensor, np_array, msg, atol=1e-4): + self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) + + def test_python_testing(self): + data_format = "NHWC" + epsilon = 0.00001 + + n, h, w, c = 2, 3, 4, 5 + x_shape = [n, h, w, c] + scale_shape = [c] + + x_val = np.random.random_sample(x_shape).astype(np.float32) + scale_val = np.random.random_sample(scale_shape).astype(np.float32) + bias_val = np.random.random_sample(scale_shape).astype(np.float32) + + mean = np.zeros(scale_shape).astype(np.float32) + variance = np.ones(scale_shape).astype(np.float32) + y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance, epsilon, "NHWC") @@ -218,15 +288,11 @@ class TestBatchNormOpInference(OpTest): # transfer (N, C, H, W) back to (N, H, W, C) y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1)) - self.__assert_close(y_out, y_out2_trans, "inference output") + self.__assert_close(y_out, y_out2_trans, + "inference outputs of two formats have differences") print 'python: NHWC, NCHW, inference checking passed' - -class TestBatchNormOpTraining(OpTest): - def __assert_close(self, tensor, np_array, msg, atol=1e-4): - self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) - - def test_python(self): + def test_python_training(self): data_format = "NHWC" epsilon = 0.00001 momentum = 0.9 @@ -264,7 +330,7 @@ class TestBatchNormOpTraining(OpTest): # transfer (N, C, H, W) back to (N, H, W, C) y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1)) - self.__assert_close(y_out, y_out2_trans, "batch variance") + self.__assert_close(y_out, y_out2_trans, "batch output") print 'python: NHWC, NCHW, forward checking passed' # test backward now -- GitLab