From ad91bfe5d4531c1a95dc968bed25ad99a0cf779f Mon Sep 17 00:00:00 2001 From: Tao Luo Date: Fri, 20 Apr 2018 20:51:27 +0800 Subject: [PATCH] fix a bug in test_batch_norm_op.py (#10094) --- python/paddle/fluid/tests/unittests/test_batch_norm_op.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 7ecf9a1459f..6afb6fa6e75 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -100,6 +100,9 @@ def _reference_grad(x, y_grad, scale, mean, var, epsilon, data_format): # (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon)) # transfer from (N, C, H, W) to (N, H, W, C) to simplify computation + if data_format != "NCHW" and data_format != "NHWC": + raise ValueError("Unknown data order.") + if data_format == "NCHW": x = np.transpose(x, (0, 2, 3, 1)) y_grad = np.transpose(y_grad, (0, 2, 3, 1)) @@ -304,7 +307,7 @@ class TestBatchNormOpTraining(unittest.TestCase): # run backward y_grad = np.random.random_sample(shape).astype(np.float32) x_grad, scale_grad, bias_grad = _reference_grad( - x, y_grad, scale, saved_mean, var_ref, epsilon, data_format) + x, y_grad, scale, saved_mean, var_ref, epsilon, data_layout) var_dict = locals() var_dict['y@GRAD'] = y_grad -- GitLab