From 7cb19f575f8ff7e8f4d03fd70a5fc33c76360a36 Mon Sep 17 00:00:00 2001
From: Qi Li
Date: Fri, 8 Oct 2021 16:44:01 +0800
Subject: [PATCH] [NPU] BatchNorm support layout of NCL and NLC, test=develop
 (#35668)

* [NPU] support NCL and NLC for BatchNorm, test=develop

* [NPU] remove debug files, test=develop

* update, test=develop
---
 paddle/fluid/operators/batch_norm_op_npu.cc   | 62 ++++++++++++++-----
 paddle/fluid/operators/conv_op_npu.cc         |  5 --
 .../unittests/npu/test_batch_norm_op_npu.py   | 54 +++++++++++++++-
 .../tests/unittests/test_batch_norm_op.py     | 37 ++++++++++-
 4 files changed, 133 insertions(+), 25 deletions(-)

diff --git a/paddle/fluid/operators/batch_norm_op_npu.cc b/paddle/fluid/operators/batch_norm_op_npu.cc
index dfb620a4e96..791c3656791 100644
--- a/paddle/fluid/operators/batch_norm_op_npu.cc
+++ b/paddle/fluid/operators/batch_norm_op_npu.cc
@@ -38,11 +38,13 @@ class NPUBatchNormOpKernel : public framework::OpKernel<T> {
     const auto *x = ctx.Input<Tensor>("X");
     const auto &x_dims = x->dims();
 
-    PADDLE_ENFORCE_EQ(x_dims.size(), 4,
-                      platform::errors::InvalidArgument(
-                          "The input tensor X's dimension must equal to 4. But "
-                          "received X's shape = [%s], X's dimension = [%d].",
-                          x_dims, x_dims.size()));
+    PADDLE_ENFORCE_EQ(
+        (x_dims.size() == 4UL || x_dims.size() == 3UL), true,
+        platform::errors::InvalidArgument(
+            "The input tensor X's dimension must be 3 or 4. "
+            "But received X's shape = [%s], X's dimension = [%d].",
+            x_dims.to_str(), x_dims.size()));
+
     const auto *running_mean = ctx.Input<Tensor>("Mean");
     const auto *running_var = ctx.Input<Tensor>("Variance");
     const auto *scale = ctx.Input<Tensor>("Scale");
@@ -51,8 +53,11 @@ class NPUBatchNormOpKernel : public framework::OpKernel<T> {
     auto *y = ctx.Output<Tensor>("Y");
     y->mutable_data<T>(ctx.GetPlace());
 
-    Tensor x_tensor(x->type());
-    Tensor y_tesnor(y->type());
+    auto &dev_ctx = ctx.template device_context<NPUDeviceContext>();
+    auto x_tensor =
+        ctx.AllocateTmpTensor<T, NPUDeviceContext>(x->dims(), dev_ctx);
+    auto y_tesnor =
+        ctx.AllocateTmpTensor<T, NPUDeviceContext>(y->dims(), dev_ctx);
     x_tensor.ShareDataWith(*x);
     y_tesnor.ShareDataWith(*y);
     if (data_layout == DataLayout::kNHWC) {
@@ -89,6 +94,18 @@ class NPUBatchNormOpKernel : public framework::OpKernel<T> {
       sum.mutable_data<float>(running_mean->dims(), ctx.GetPlace());
       square_sum.mutable_data<float>(running_mean->dims(), ctx.GetPlace());
 
+      // BNTrainingReduce ONLY supports rank = 4
+      if (x->dims().size() == 3) {
+        auto x_shape_vec = framework::vectorize(x->dims());
+        if (data_layout == DataLayout::kNCHW) {
+          x_shape_vec.push_back(1);  // expand NCL -> NCL1
+        } else {
+          x_shape_vec.insert(x_shape_vec.begin() + 2, 1);  // expand NLC -> NL1C
+        }
+        auto x_new_shape = framework::make_ddim(x_shape_vec);
+        x_tensor.Resize(x_new_shape);
+        y_tesnor.Resize(x_new_shape);
+      }
       const auto &runner_reduce =
           NpuOpRunner("BNTrainingReduce", {x_tensor}, {sum, square_sum},
                       {{"epsilon", epsilon}});
@@ -127,8 +144,11 @@ class NPUBatchNormGradOpKernel : public framework::OpKernel<T> {
 
     use_global_stats = is_test || use_global_stats;
 
-    Tensor x_tensor(x->type());
-    Tensor dy_tensor(d_y->type());
+    auto &dev_ctx = ctx.template device_context<NPUDeviceContext>();
+    auto x_tensor =
+        ctx.AllocateTmpTensor<T, NPUDeviceContext>(x->dims(), dev_ctx);
+    auto dy_tensor =
+        ctx.AllocateTmpTensor<T, NPUDeviceContext>(d_y->dims(), dev_ctx);
     x_tensor.ShareDataWith(*x);
     dy_tensor.ShareDataWith(*d_y);
     if (data_layout == DataLayout::kNHWC) {
@@ -136,14 +156,14 @@ class NPUBatchNormGradOpKernel : public framework::OpKernel<T> {
       x_tensor.set_layout(DataLayout::kNHWC);
       dy_tensor.set_layout(DataLayout::kNHWC);
     }
 
-    Tensor scale_grad_tmp(scale->type());
-    Tensor bias_grad_tmp(bias->type());
+    auto scale_grad_tmp =
+        ctx.AllocateTmpTensor<T, NPUDeviceContext>(scale->dims(), dev_ctx);
+    auto bias_grad_tmp =
+        ctx.AllocateTmpTensor<T, NPUDeviceContext>(bias->dims(), dev_ctx);
     if (d_scale == nullptr) {
-      scale_grad_tmp.Resize(scale->dims());
       d_scale = &scale_grad_tmp;
     }
     if (d_bias == nullptr) {
-      bias_grad_tmp.Resize(bias->dims());
       d_bias = &bias_grad_tmp;
     }
@@ -169,9 +189,23 @@ class NPUBatchNormGradOpKernel : public framework::OpKernel<T> {
     }
     if (d_x) {
       d_x->mutable_data<T>(ctx.GetPlace());
-      Tensor dx_tensor(d_x->type());
+      auto dx_tensor =
+          ctx.AllocateTmpTensor<T, NPUDeviceContext>(d_x->dims(), dev_ctx);
       dx_tensor.ShareDataWith(*d_x);
       if (use_global_stats) {
+        if (x->dims().size() == 3) {
+          // BNInferGrad only supports x with rank = 4
+          auto x_shape_vec = framework::vectorize(d_x->dims());
+          if (data_layout == DataLayout::kNCHW) {
+            x_shape_vec.push_back(1);  // expand NCL -> NCL1
+          } else {
+            x_shape_vec.insert(x_shape_vec.begin() + 2,
+                               1);  // expand NLC -> NL1C
+          }
+          auto x_new_shape = framework::make_ddim(x_shape_vec);
+          dx_tensor.Resize(x_new_shape);
+          dy_tensor.Resize(x_new_shape);
+        }
         const auto *running_var = ctx.Input<Tensor>("Variance");
         const auto &runner_infer =
             NpuOpRunner("BNInferGrad", {dy_tensor, *scale, *running_var},
diff --git a/paddle/fluid/operators/conv_op_npu.cc b/paddle/fluid/operators/conv_op_npu.cc
index 86724e06975..47de843d1ac 100644
--- a/paddle/fluid/operators/conv_op_npu.cc
+++ b/paddle/fluid/operators/conv_op_npu.cc
@@ -186,11 +186,6 @@ class DepthwiseConvGradNPUKernel : public framework::OpKernel<T> {
       dilations[3] = dilation[1];
     }
 
-    // LOG(INFO) << "strides = " << framework::make_ddim(strides).to_str();
-    // LOG(INFO) << "dilations = " << framework::make_ddim(dilations).to_str();
-    // LOG(INFO) << "padding = " << framework::make_ddim(padding).to_str();
-    // LOG(INFO) << "data_format = " << data_format;
-
     if (filter_grad) {
       filter_grad->mutable_data<T>(ctx.GetPlace());
 
diff --git a/python/paddle/fluid/tests/unittests/npu/test_batch_norm_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_batch_norm_op_npu.py
index 1b8b13a0d27..877f9904f34 100644
--- a/python/paddle/fluid/tests/unittests/npu/test_batch_norm_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_batch_norm_op_npu.py
@@ -45,6 +45,14 @@ class TestBatchNormOpInference(unittest.TestCase):
         if len(shape) == 2:
             x_shape = shape
             c = x_shape[1]
+        elif len(shape) == 3:
+            n, l, c = shape[0], shape[1], shape[2]
+            if data_layout == "NHWC":  # NLC
+                x_shape = [n, l, c]
+            elif data_layout == "NCHW":  # NCL
+                x_shape = [n, c, l]
+            else:
+                raise ValueError("Unknown data layout.")
         else:
             n, h, w, c = shape[0], shape[1], shape[2], shape[3]
             if data_layout == "NHWC":
@@ -117,6 +125,7 @@ class TestBatchNormOpInference(unittest.TestCase):
         place = core.NPUPlace(0)
         for data_format in self.data_formats:
             self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])
+            self.check_with_place(place, data_format, self.dtype, [3, 8, 5])
 
     def init_kernel_type(self):
         pass
@@ -185,10 +194,19 @@ class TestBatchNormOpTraining(unittest.TestCase):
             # attr
             epsilon = self.epsilon
             momentum = self.momentum
-            if data_layout == "NCHW":
-                n, c, h, w = shape[0], shape[1], shape[2], shape[3]
+
+            if len(shape) == 3:
+                if data_layout == "NHWC":  # NLC
+                    n, l, c = shape[0], shape[1], shape[2]
+                elif data_layout == "NCHW":  # NCL
+                    n, c, l = shape[0], shape[1], shape[2]
+                else:
+                    raise ValueError("Unknown data layout.")
             else:
-                n, h, w, c = shape[0], shape[1], shape[2], shape[3]
+                if data_layout == "NCHW":
+                    n, c, h, w = shape[0], shape[1], shape[2], shape[3]
+                else:
+                    n, h, w, c = shape[0], shape[1], shape[2], shape[3]
             scale_shape = [c]
 
             np.random.seed(123)
@@ -296,6 +314,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
 
         for data_format in self.data_formats:
             test_with_place(core.NPUPlace(0), data_format, [2, 3, 4, 5])
+            test_with_place(core.NPUPlace(0), data_format, [3, 8, 5])
 
     def init_kernel_type(self):
        pass
@@ -328,6 +347,17 @@ class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
         ]
 
     def reference_grad(self, x, y_grad, scale, mean, var, epsilon, data_format):
+        x_shape = x.shape
+        if len(x_shape) == 3:
+            if data_format == "NCHW":  # NCL -> NCL1
+                x = np.reshape(x, (x_shape[0], x_shape[1], x_shape[2], 1))
+                y_grad = np.reshape(y_grad,
+                                    (x_shape[0], x_shape[1], x_shape[2], 1))
+            else:  # NLC -> NL1C
+                x = np.reshape(x, (x_shape[0], x_shape[1], 1, x_shape[2]))
+                y_grad = np.reshape(y_grad,
+                                    (x_shape[0], x_shape[1], 1, x_shape[2]))
+
         if data_format == "NCHW":
             x = np.transpose(x, (0, 2, 3, 1))
             y_grad = np.transpose(y_grad, (0, 2, 3, 1))
@@ -343,6 +373,9 @@ class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
             x = np.transpose(x, (0, 3, 1, 2))
             y_grad = np.transpose(y_grad, (0, 3, 1, 2))
 
+        if len(x_shape) == 3:
+            x_grad = np.reshape(x_grad, x_shape)
+
         return x_grad, grad_scale, grad_offset
 
     def ref_forward_backward(self, x, y_grad, scale, bias, mean, variance,
@@ -350,6 +383,17 @@ class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
         if data_layout != "NCHW" and data_layout != "NHWC":
             raise ValueError("Unknown data order.")
 
+        x_shape = x.shape
+        if len(x_shape) == 3:
+            if data_layout == "NCHW":  # NCL -> NCL1
+                x = np.reshape(x, (x_shape[0], x_shape[1], x_shape[2], 1))
+                y_grad = np.reshape(y_grad,
+                                    (x_shape[0], x_shape[1], x_shape[2], 1))
+            else:  # NLC -> NL1C
+                x = np.reshape(x, (x_shape[0], x_shape[1], 1, x_shape[2]))
+                y_grad = np.reshape(y_grad,
+                                    (x_shape[0], x_shape[1], 1, x_shape[2]))
+
         if data_layout == "NCHW":
             x = np.transpose(x, (0, 2, 3, 1))
 
@@ -369,6 +413,10 @@ class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
         x_grad, scale_grad, bias_grad = self.reference_grad(
             x, y_grad, scale, mean, variance, epsilon, data_layout)
 
+        if len(x_shape) == 3:
+            y = np.reshape(y, x_shape)
+            x_grad = np.reshape(x_grad, x_shape)
+
         return y, mean_out, variance_out, mean, saved_variance, x_grad, \
             scale_grad, bias_grad
 
diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
index 9eaa69ce644..cce13a8bf3b 100644
--- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
@@ -36,6 +36,11 @@ def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
             x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
         else:
             x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
+    if len(x_shape) == 3:
+        if data_format == "NCHW":  # NCL -> NCL1
+            x = np.reshape(x, (x_shape[0], x_shape[1], x_shape[2], 1))
+        else:  # NLC -> NL1C
+            x = np.reshape(x, (x_shape[0], x_shape[1], 1, x_shape[2]))
 
     if data_format == "NCHW":
         n, c, h, w = x.shape
@@ -55,13 +60,19 @@ def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
     else:
         raise ValueError("Unknown data order.")
 
-    if len(x_shape) == 2:
+    if len(x_shape) == 2 or len(x_shape) == 3:
         y = np.reshape(y, x_shape)
     return y
 
 
 def _cal_mean_variance(x, epsilon, data_format):
     assert data_format in ['NCHW', 'NHWC']
+    x_shape = x.shape
+    if len(x_shape) == 3:
+        if data_format == "NCHW":  # NCL -> NCL1
+            x = np.reshape(x, (x_shape[0], x_shape[1], x_shape[2], 1))
+        else:  # NLC -> NL1C
+            x = np.reshape(x, (x_shape[0], x_shape[1], 1, x_shape[2]))
     x_square = x * x
     axis = (0, 2, 3) if data_format == 'NCHW' else (0, 1, 2)
     C = x.shape[1] if data_format == 'NCHW' else x.shape[-1]
@@ -76,6 +87,12 @@ def _reference_training(x, scale, offset, epsilon, data_format):
 
 def _reference_training(x, scale, offset, epsilon, data_format):
     x_shape = x.shape
+    if len(x_shape) == 3:
+        if data_format == "NCHW":  # NCL -> NCL1
+            x = np.reshape(x, (x_shape[0], x_shape[1], x_shape[2], 1))
+        else:  # NLC -> NL1C
+            x = np.reshape(x, (x_shape[0], x_shape[1], 1, x_shape[2]))
+
     if data_format == "NCHW":
         n, c, h, w = x.shape
         x_square = x * x
@@ -94,7 +111,6 @@ def _reference_training(x, scale, offset, epsilon, data_format):
         offset_tile = np.reshape(offset, (1, c, 1, 1))
         offset_tile = np.reshape(offset_tile, (1, c, 1, 1))
         y = normalized * scale_tile + offset_tile
-        return y, mean, var
     elif data_format == "NHWC":
         x_square = x * x
         x_square_sum = np.sum(x_square, (0, 1, 2))
@@ -104,10 +120,13 @@ def _reference_training(x, scale, offset, epsilon, data_format):
         var = x_square_sum / element_count - mean * mean
         normalized = (x - mean) / np.sqrt(var + epsilon)
         y = normalized * scale + offset
-        return y, mean, var
     else:
         raise ValueError("Unknown data order.")
+    if len(x_shape) == 3:
+        y = np.reshape(y, x_shape)
+    return y, mean, var
+
 
 
 def _reference_grad(x, y_grad, scale, mean, var, epsilon, data_format):
     # Use the following formulas to calculate gradients:
@@ -124,6 +143,15 @@ def _reference_grad(x, y_grad, scale, mean, var, epsilon, data_format):
     if data_format != "NCHW" and data_format != "NHWC":
         raise ValueError("Unknown data order.")
 
+    x_shape = x.shape
+    if len(x_shape) == 3:
+        if data_format == "NCHW":  # NCL -> NCL1
+            x = np.reshape(x, (x_shape[0], x_shape[1], x_shape[2], 1))
+            y_grad = np.reshape(y_grad, (x_shape[0], x_shape[1], x_shape[2], 1))
+        else:  # NLC -> NL1C
+            x = np.reshape(x, (x_shape[0], x_shape[1], 1, x_shape[2]))
+            y_grad = np.reshape(y_grad, (x_shape[0], x_shape[1], 1, x_shape[2]))
+
     if data_format == "NCHW":
         x = np.transpose(x, (0, 2, 3, 1))
         y_grad = np.transpose(y_grad, (0, 2, 3, 1))
@@ -142,6 +170,9 @@ def _reference_grad(x, y_grad, scale, mean, var, epsilon, data_format):
         x = np.transpose(x, (0, 3, 1, 2))
         y_grad = np.transpose(y_grad, (0, 3, 1, 2))
 
+    if len(x_shape) == 3:
+        x_grad = np.reshape(x_grad, x_shape)
+
     return x_grad, grad_scale, grad_offset
 
-- 
GitLab
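
Note on the technique used throughout this patch: both the NPU kernel and the reference helpers rely on a rank-expansion trick. A rank-3 input is reshaped to rank 4 (NCL -> NCL1 on the NCHW path, NLC -> NL1C on the NHWC path) because the named CANN operators (BNTrainingReduce, BNInferGrad) accept only rank-4 tensors, and the result is reshaped back afterwards. The minimal numpy sketch below is not part of the patch; bn_stats_4d is a made-up helper standing in for a rank-4-only reduction, used only to illustrate why the expansion leaves the per-channel statistics unchanged.

# Illustrative sketch only (not part of the patch).
import numpy as np


def bn_stats_4d(x, data_format):
    # Per-channel mean/variance of a rank-4 tensor, as a rank-4-only op would compute them.
    axis = (0, 2, 3) if data_format == "NCHW" else (0, 1, 2)
    return x.mean(axis=axis), x.var(axis=axis)


np.random.seed(0)
n, c, l = 3, 8, 5

# NCL input: expand to NCL1 and reuse the NCHW (rank-4) reduction.
x_ncl = np.random.rand(n, c, l).astype("float32")
mean4, var4 = bn_stats_4d(x_ncl.reshape(n, c, l, 1), "NCHW")
assert np.allclose(x_ncl.mean(axis=(0, 2)), mean4)
assert np.allclose(x_ncl.var(axis=(0, 2)), var4)

# NLC input: expand to NL1C and reuse the NHWC (rank-4) reduction.
x_nlc = np.random.rand(n, l, c).astype("float32")
mean4, var4 = bn_stats_4d(x_nlc.reshape(n, l, 1, c), "NHWC")
assert np.allclose(x_nlc.mean(axis=(0, 1)), mean4)
assert np.allclose(x_nlc.var(axis=(0, 1)), var4)

Because the inserted axis has size one, each channel still reduces over exactly the same N*L elements, so the mean, variance, and therefore the normalized output and gradients are numerically identical to a direct rank-3 computation; the new unit tests exercise this path with shape [3, 8, 5] in both layouts.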