From 2e417b6011b05662602e70f9564681c7e4a7cfd1 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Wed, 25 Oct 2017 16:23:46 -0700 Subject: [PATCH] batch norm --- .../v2/framework/tests/test_batch_norm_op.py | 143 +++++++++++++++--- 1 file changed, 121 insertions(+), 22 deletions(-) diff --git a/python/paddle/v2/framework/tests/test_batch_norm_op.py b/python/paddle/v2/framework/tests/test_batch_norm_op.py index b7b071c24da..76c1ff018ac 100644 --- a/python/paddle/v2/framework/tests/test_batch_norm_op.py +++ b/python/paddle/v2/framework/tests/test_batch_norm_op.py @@ -6,16 +6,36 @@ from paddle.v2.framework.op import Operator def _reference_training(x, scale, offset, epsilon, data_format): - if data_format != "NHWC": - raise ValueError("data_format must be NHWC, got %s." % data_format) - x_square = x * x - x_square_sum = np.sum(x_square, (0, 1, 2)) - x_sum = np.sum(x, axis=(0, 1, 2)) - element_count = np.size(x) / int(np.shape(x)[-1]) - mean = x_sum / element_count - var = x_square_sum / element_count - mean * mean - normalized = (x - mean) / np.sqrt(var + epsilon) - return (normalized * scale + offset), mean, var + if data_format == "NCHW": + n, c, h, w = x.shape + x_square = x * x + x_square_sum = np.sum(x_square, (0, 2, 3)) + x_sum = np.sum(x, axis=(0, 2, 3)) + element_count = np.size(x) / int(np.shape(x)[1]) + mean = x_sum / element_count + var = x_square_sum / element_count - mean * mean + mean_tile = np.reshape(mean, (1, c, 1, 1)) + mean_tile = np.tile(mean_tile, (n, 1, h, w)) + var_tile = np.reshape(var, (1, c, 1, 1)) + var_tile = np.tile(var_tile, (n, 1, h, w)) + normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon) + scale_tile = np.reshape(scale, (1, c, 1, 1)) + scale_tile = np.tile(scale_tile, (n, 1, h, w)) + offset_tile = np.reshape(offset, (1, c, 1, 1)) + offset_tile = np.reshape(offset_tile, (1, c, 1, 1)) + y = normalized * scale_tile + offset_tile + return y, mean, var + elif data_format == "NHWC": + x_square = x * x + x_square_sum = np.sum(x_square, (0, 1, 2)) + x_sum = np.sum(x, axis=(0, 1, 2)) + element_count = np.size(x) / int(np.shape(x)[-1]) + mean = x_sum / element_count + var = x_square_sum / element_count - mean * mean + normalized = (x - mean) / np.sqrt(var + epsilon) + return (normalized * scale + offset), mean, var + else: + raise ValueError("Unknown data order.") def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format): @@ -28,8 +48,13 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format): # grad_x = # 1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) - # (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon)) - if data_format != "NHWC": - raise ValueError("data_format must be NHWC, got %s." % data_format) + + # transfer from (N, C, H, W) to (N, H, W, C) to simplify computation + if data_format == "NCHW": + x = np.transpose(x, (0, 2, 3, 1)) + grad_y = np.transpose(grad_y, (0, 2, 3, 1)) + + # raise ValueError("data_format must be NHWC, got %s." % data_format) grad_x = scale * (grad_y - np.mean( grad_y, axis=(0, 1, 2)) - (x - mean) * np.mean( grad_y * (x - mean), axis=(0, 1, 2)) / @@ -37,6 +62,12 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format): grad_scale = np.sum(grad_y * (x - mean) / np.sqrt(var + epsilon), axis=(0, 1, 2)) grad_offset = np.sum(grad_y, axis=(0, 1, 2)) + + # transfer back to N, C, H, W + if data_format == "NCHW": + grad_x = np.transpose(grad_x, (0, 3, 1, 2)) + x = np.transpose(x, (0, 3, 1, 2)) + grad_y = np.transpose(grad_y, (0, 3, 1, 2)) return grad_x, grad_scale, grad_offset @@ -72,39 +103,104 @@ class TestBatchNormOp(OpTest): def __assert_close(self, tensor, np_array, msg, atol=1e-4): self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) - def test_forward_backward(self): - # attr + def test_python(self): data_format = "NHWC" epsilon = 0.00001 momentum = 0.9 + # N, H, W, C: 2, 3, 4, 2 channel_num = 2 x_shape = [2, 3, 4, channel_num] scale_shape = [channel_num] - # input x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) bias_val = np.random.random_sample(scale_shape).astype(np.float32) mean = np.zeros(scale_shape).astype(np.float32) - variance = np.zeros(scale_shape).astype(np.float32) + variance = np.ones(scale_shape).astype(np.float32) + + # run forward + y_out, saved_mean, var_ref = _reference_training( + x_val, scale_val, bias_val, epsilon, "NHWC") + + # + mean_out = saved_mean * (1. - momentum) + momentum * mean + variance_out = var_ref * (1. - momentum) + momentum * variance + saved_variance = 1. / np.sqrt(var_ref + epsilon) + + # running N, C, H, W case + # should produce the same results + x_shape2 = [2, channel_num, 3, 4] + x_val2 = np.transpose(x_val, (0, 3, 1, 2)) + y_out2, saved_mean2, var_ref2 = _reference_training( + x_val2, scale_val, bias_val, epsilon, "NCHW") + + self.__assert_close(saved_mean, saved_mean2, "batch mean") + self.__assert_close(var_ref, var_ref2, "batch variance") + + # transfer (N, C, H, W) back to (N, H, W, C) + y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1)) + self.__assert_close(y_out, y_out2_trans, "batch variance") + print 'python: NHWC, NCHW, forward checking passed' + + # test backward now + # NHWC + y_grad = np.ones(x_shape).astype(np.float32) + x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad( + x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, "NHWC") + + # NCHW + y_grad2 = np.ones(x_shape2).astype(np.float32) + x_grad_ref2, scale_grad_ref2, bias_grad_ref2 = _reference_grad( + x_val2, y_grad2, scale_val, saved_mean2, var_ref2, epsilon, "NCHW") + + self.__assert_close(scale_grad_ref, scale_grad_ref2, "scale gradient") + self.__assert_close(bias_grad_ref, bias_grad_ref2, "bias gradient") + + x_grad_transpose = np.transpose(x_grad_ref2, (0, 2, 3, 1)) + self.__assert_close(x_grad_ref, x_grad_transpose, "x gradient") + print 'python: NHWC, NCHW, backward checking passed' + + def test_forward_backward(self): + # attr + data_format = "NCHW" + epsilon = 0.00001 + momentum = 0.9 + + # N, H, W, C: 2, 3, 4, 2 + n, h, w, c = 2, 3, 4, 2 + + if data_format == "NHWC": + x_shape = [n, h, w, c] + elif data_format == "NCHW": + x_shape = [n, c, h, w] + else: + raise ValueError("Unknown data type.") + scale_shape = [c] + + x_val = np.random.random_sample(x_shape).astype(np.float32) + scale_val = np.random.random_sample(scale_shape).astype(np.float32) + bias_val = np.random.random_sample(scale_shape).astype(np.float32) + + mean = np.zeros(scale_shape).astype(np.float32) + variance = np.ones(scale_shape).astype(np.float32) # run forward y_out, saved_mean, var_ref = _reference_training( x_val, scale_val, bias_val, epsilon, data_format) - # run backward - mean_out = saved_mean * (1 - momentum) - variance_out = var_ref * (1 - momentum) - saved_variance = 1 / np.sqrt(var_ref + epsilon) + # update moving mean and variance + mean_out = saved_mean * (1. - momentum) + momentum * mean + variance_out = var_ref * (1. - momentum) + momentum * variance + saved_variance = 1. / np.sqrt(var_ref + epsilon) # for gradient test y_grad = np.ones(x_shape).astype(np.float32) x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad( x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, data_format) - def test_with_place(place): + def test_with_place(place, tensor_format=data_format): scope = core.Scope() # create input @@ -142,7 +238,7 @@ class TestBatchNormOp(OpTest): SavedVariance="saved_variance", # attrs is_test=False, - tensor_format=data_format, + tensor_format=tensor_format, momentum=momentum, epsilon=epsilon) @@ -162,6 +258,7 @@ class TestBatchNormOp(OpTest): atol = 1e-4 self.__assert_close(variance_out_tensor, variance_out, "variance_out", atol) + print "op test forward passed: ", tensor_format # run backward batch_norm_op_grad = get_backward_op(scope, batch_norm_op, set()) @@ -185,12 +282,14 @@ class TestBatchNormOp(OpTest): self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad") self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad") self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad") + print "op test backward passed: ", tensor_format places = [core.CPUPlace()] if core.is_compile_gpu() and core.op_support_gpu("batch_norm"): places.append(core.GPUPlace(0)) for place in places: test_with_place(place) + print "test forward passed" if __name__ == '__main__': -- GitLab