diff --git a/paddle/fluid/operators/conv_op_npu.cc b/paddle/fluid/operators/conv_op_npu.cc
index bc62bb5c81570ccdf375b5cdab5c2bf316cb5c40..5fc39b5fb4dc530a56fab88880ea74b09a2cb326 100644
--- a/paddle/fluid/operators/conv_op_npu.cc
+++ b/paddle/fluid/operators/conv_op_npu.cc
@@ -126,6 +126,117 @@ class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename T>
+class DepthwiseConvGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    // input
+    const Tensor* input = context.Input<Tensor>("Input");
+    const Tensor* filter = context.Input<Tensor>("Filter");
+    // output
+    auto output_grad = context.Input<Tensor>(framework::GradVarName("Output"));
+    auto input_grad = context.Output<Tensor>(framework::GradVarName("Input"));
+    auto filter_grad = context.Output<Tensor>(framework::GradVarName("Filter"));
+    // attr
+    const std::vector<int> stride = context.Attr<std::vector<int>>("strides");
+    std::vector<int> padding = context.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilation = context.Attr<std::vector<int>>("dilations");
+    const std::string data_format = context.Attr<std::string>("data_format");
+    const std::string padding_algorithm =
+        context.Attr<std::string>("padding_algorithm");
+
+    // npu stream
+    auto stream =
+        context.template device_context<platform::NPUDeviceContext>().stream();
+
+    // check dimension
+    const bool channel_last = data_format == "NHWC";
+
+    // update padding and dilation
+    auto in_dims = input->dims();
+    auto filter_dims = filter->dims();
+    framework::DDim in_data_dims;
+    framework::DDim filter_data_dims;
+
+    if (channel_last) {
+      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
+    } else {
+      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+    }
+    filter_data_dims = framework::slice_ddim(filter_dims, 2, in_dims.size());
+
+    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
+    UpdatePaddingAndDilation(&padding, &dilation, padding_algorithm,
+                             in_data_dims, stride, ksize);
+
+    // Transform filter (n, 1, h, w) --> (1, n, h, w)
+    Tensor transformed_filter(filter->type());
+    transformed_filter.mutable_data<T>({filter->dims()[1], filter->dims()[0],
+                                        filter->dims()[2], filter->dims()[3]},
+                                       context.device_context().GetPlace());
+    std::vector<int> perm = {1, 0, 2, 3};
+    const auto& runner_trans = NpuOpRunner(
+        "TransposeD", {*filter}, {transformed_filter}, {{"perm", perm}});
+    runner_trans.Run(stream);
+
+    // construct NPU attr
+    std::vector<int> strides(4, 1);
+    std::vector<int> dilations(4, 1);
+
+    Tensor input_tensor, output_grad_tensor;
+    input_tensor.ShareDataWith(*input);
+    output_grad_tensor.ShareDataWith(*output_grad);
+    if (channel_last) {
+      input_tensor.set_layout(DataLayout::kNHWC);
+      output_grad_tensor.set_layout(DataLayout::kNHWC);
+      strides[1] = stride[0];
+      strides[2] = stride[1];
+      dilations[1] = dilation[0];
+      dilations[2] = dilation[1];
+    } else {
+      strides[2] = stride[0];
+      strides[3] = stride[1];
+      dilations[2] = dilation[0];
+      dilations[3] = dilation[1];
+    }
+
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(context.GetPlace());
+      std::vector<int> filter_shape_vec =
+          framework::vectorize<int>(transformed_filter.dims());
+
+      const auto& runner = NpuOpRunner(
+          "DepthwiseConv2DBackpropFilterD", {input_tensor, output_grad_tensor},
+          {*filter_grad}, {{"filter_size", filter_shape_vec},
+                           {"strides", strides},
+                           {"pads", padding},
+                           {"dilations", dilations},
+                           {"data_format", data_format}});
+      runner.Run(stream);
+    }
+    if (input_grad) {
+      input_grad->mutable_data<T>(context.GetPlace());
+      std::vector<int> input_shape_vec =
+          framework::vectorize<int>(input->dims());
+
+      Tensor input_grad_tensor;
+      input_grad_tensor.ShareDataWith(*input_grad);
+      if (channel_last) {
+        input_grad_tensor.set_layout(DataLayout::kNHWC);
+      }
+      const auto& runner =
+          NpuOpRunner("DepthwiseConv2DBackpropInputD",
+                      {transformed_filter, output_grad_tensor},
+                      {input_grad_tensor}, {{"input_size", input_shape_vec},
+                                            {"strides", strides},
+                                            {"pads", padding},
+                                            {"dilations", dilations},
+                                            {"data_format", data_format}});
+      runner.Run(stream);
+    }
+  }
+};
+
 template <typename T>
 class NPUConvOpKernel : public framework::OpKernel<T> {
  public:
@@ -298,6 +409,9 @@ REGISTER_OP_NPU_KERNEL(
     depthwise_conv2d,
     ops::DepthwiseConvNPUKernel<paddle::platform::float16>);
 
+REGISTER_OP_NPU_KERNEL(
+    depthwise_conv2d_grad,
+    ops::DepthwiseConvGradNPUKernel<paddle::platform::float16>);
 REGISTER_OP_NPU_KERNEL(conv2d, ops::NPUConvOpKernel<float>,
                        ops::NPUConvOpKernel<paddle::platform::float16>);
 REGISTER_OP_NPU_KERNEL(conv2d_grad, ops::NPUConvGradOpKernel<float>,
diff --git a/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_depthwise_conv_npu.py b/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_depthwise_conv_npu.py
index b62ad1b8b8e55206a6dccb6891f51ba03038f00c..d1c1e80c218a9246741e1962ec71d709ae5cd3bf 100755
--- a/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_depthwise_conv_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_depthwise_conv_npu.py
@@ -22,8 +22,12 @@ import sys
 sys.path.append("..")
 from op_test import OpTest, skip_check_grad_ci
 from test_conv2d_op import conv2d_forward_naive
+from paddle import ParamAttr
+from paddle.regularizer import L2Decay
+from paddle.nn.initializer import KaimingNormal
 
 paddle.enable_static()
+SEED = 2021
 
 
 def create_test_channel_last_class(parent):
@@ -279,5 +283,151 @@ create_test_padding_VALID_class(TestDepthwiseConvNPU_Padding)
 create_test_padding_VALID_class(TestDepthwiseConvNPU2_Padding)
 create_test_padding_VALID_class(TestDepthwiseConvNPU3_Padding)
 
+
+class TestDepthwiseConvNet(unittest.TestCase):
+    def __init__(self, methodName='runTest'):
+        super().__init__(methodName=methodName)
+
+    def _test(self, run_npu=True):
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        main_prog.random_seed = SEED
+        startup_prog.random_seed = SEED
+        np.random.seed(SEED)
+
+        a_np = np.random.random(size=(2, 4, 16, 16)).astype('float16')
+        b_np = np.random.random(size=(4, 1, 3, 3)).astype('float16')
+        if not run_npu:
+            a_np = a_np.astype('float32')
+            b_np = b_np.astype('float32')
+        label_np = np.random.randint(10, size=(2, 10)).astype('float32')
+        with paddle.static.program_guard(main_prog, startup_prog):
+            if run_npu:
+                a = paddle.static.data(
+                    name="a", shape=[2, 4, 16, 16], dtype='float16')
+                b = paddle.static.data(
+                    name="b", shape=[4, 1, 3, 3], dtype='float16')
+            else:
+                a = paddle.static.data(
+                    name="a", shape=[2, 4, 16, 16], dtype='float32')
+                b = paddle.static.data(
+                    name="b", shape=[4, 1, 3, 3], dtype='float32')
+            label = paddle.static.data(
+                name="label", shape=[2, 10], dtype='float32')
+
+            a *= 2.0
+            b += 0.01
+            fc_1 = paddle.nn.functional.conv2d(a, b, bias=None, groups=4)
+            if run_npu:
+                fc_1 = paddle.cast(fc_1, dtype='float32')
+            fc_1 = paddle.nn.functional.relu(fc_1)
+            prediction = fluid.layers.fc(input=fc_1, size=10, act='softmax')
+
+            cost = paddle.nn.functional.smooth_l1_loss(
+                input=prediction, label=label)
+            loss = paddle.sum(cost)
+            sgd = fluid.optimizer.SGD(learning_rate=0.00001)
+            sgd.minimize(loss)
+
+        if run_npu:
+            place = paddle.NPUPlace(0)
+        else:
+            place = paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+        exe.run(startup_prog)
+
+        print("Start run on {}".format(place))
+        for epoch in range(100):
+
+            pred_res, loss_res = exe.run(
+                main_prog,
+                feed={"a": a_np,
+                      "b": b_np,
+                      "label": label_np},
+                fetch_list=[prediction, loss])
+
+        return pred_res, loss_res
+
+    def test_npu(self):
+        cpu_pred, cpu_loss = self._test(False)
+        npu_pred, npu_loss = self._test(True)
+
+        self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-04, atol=1e-03))
+        self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-04, atol=1e-03))
+
+
+class TestDepthwiseConvNet_NHWC(unittest.TestCase):
+    def __init__(self, methodName='runTest'):
+        super().__init__(methodName=methodName)
+
+    def _test(self, run_npu=True):
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        main_prog.random_seed = SEED
+        startup_prog.random_seed = SEED
+        np.random.seed(SEED)
+
+        a_np = np.random.random(size=(2, 16, 16, 4)).astype('float16')
+        b_np = np.random.random(size=(4, 1, 3, 3)).astype('float16')
+        if not run_npu:
+            a_np = a_np.astype('float32')
+            b_np = b_np.astype('float32')
+        label_np = np.random.randint(10, size=(2, 10)).astype('float32')
+        with paddle.static.program_guard(main_prog, startup_prog):
+            if run_npu:
+                a = paddle.static.data(
+                    name="a", shape=[2, 16, 16, 4], dtype='float16')
+                b = paddle.static.data(
+                    name="b", shape=[4, 1, 3, 3], dtype='float16')
+            else:
+                a = paddle.static.data(
+                    name="a", shape=[2, 16, 16, 4], dtype='float32')
+                b = paddle.static.data(
+                    name="b", shape=[4, 1, 3, 3], dtype='float32')
+            label = paddle.static.data(
+                name="label", shape=[2, 10], dtype='float32')
+
+            a *= 2.0
+            b += 0.01
+            fc_1 = paddle.nn.functional.conv2d(
+                a, b, bias=None, groups=4, data_format='NHWC')
+            if run_npu:
+                fc_1 = paddle.cast(fc_1, dtype='float32')
+            fc_1 = paddle.nn.functional.relu(fc_1)
+            prediction = fluid.layers.fc(input=fc_1, size=10, act='softmax')
+
+            cost = paddle.nn.functional.smooth_l1_loss(
+                input=prediction, label=label)
+            loss = paddle.sum(cost)
+            sgd = fluid.optimizer.SGD(learning_rate=0.00001)
+            sgd.minimize(loss)
+
+        if run_npu:
+            place = paddle.NPUPlace(0)
+        else:
+            place = paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+        exe.run(startup_prog)
+
+        print("Start run on {}".format(place))
+        for epoch in range(100):
+
+            pred_res, loss_res = exe.run(
+                main_prog,
+                feed={"a": a_np,
+                      "b": b_np,
+                      "label": label_np},
+                fetch_list=[prediction, loss])
+
+        return pred_res, loss_res
+
+    def test_npu(self):
+        cpu_pred, cpu_loss = self._test(False)
+        npu_pred, npu_loss = self._test(True)
+
+        self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-04, atol=1e-03))
+        self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-04, atol=1e-03))
+
+
 if __name__ == '__main__':
     unittest.main()
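
Note (not part of the patch): the added unit tests exercise the new grad kernel through a full training loop. For a quicker sanity check, the snippet below is a minimal sketch of the dispatch path only: a static-graph depthwise convolution (groups == in_channels selects depthwise_conv2d) whose gradients are appended with paddle.static.gradients, so the backward pass runs the new depthwise_conv2d_grad NPU kernel (DepthwiseConv2DBackpropInputD / DepthwiseConv2DBackpropFilterD). The shapes and the NPUPlace(0) device are illustrative assumptions and require an NPU-enabled Paddle build.

import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[2, 4, 16, 16], dtype='float16')
    w = paddle.static.data(name="w", shape=[4, 1, 3, 3], dtype='float16')
    x.stop_gradient = False
    w.stop_gradient = False
    # groups == number of input channels -> depthwise_conv2d op
    y = paddle.nn.functional.conv2d(x, w, bias=None, groups=4)
    loss = paddle.sum(paddle.cast(y, 'float32'))
    # Appending backward inserts depthwise_conv2d_grad into the program,
    # which is what the new NPU kernel implements.
    grads = paddle.static.gradients([loss], [x, w])

exe = paddle.static.Executor(paddle.NPUPlace(0))
exe.run(startup_prog)
dx, dw = exe.run(main_prog,
                 feed={"x": np.random.rand(2, 4, 16, 16).astype('float16'),
                       "w": np.random.rand(4, 1, 3, 3).astype('float16')},
                 fetch_list=grads)
print(dx.shape, dw.shape)  # expected: (2, 4, 16, 16) (4, 1, 3, 3)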