diff --git a/paddle/fluid/operators/conv_op_npu.cc b/paddle/fluid/operators/conv_op_npu.cc
index 4065394effa47b72b58dc2319b475a1205b32357..bc62bb5c81570ccdf375b5cdab5c2bf316cb5c40 100644
--- a/paddle/fluid/operators/conv_op_npu.cc
+++ b/paddle/fluid/operators/conv_op_npu.cc
@@ -126,6 +126,169 @@ class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename T>
+class NPUConvOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx = ctx.template device_context<platform::NPUDeviceContext>();
+    const Tensor* input = ctx.Input<Tensor>("Input");
+    auto* filter = ctx.Input<Tensor>("Filter");
+    auto* output = ctx.Output<Tensor>("Output");
+    output->mutable_data<T>(ctx.GetPlace());
+    const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+    int groups = ctx.Attr<int>("groups");
+    const std::string padding_algorithm =
+        ctx.Attr<std::string>("padding_algorithm");
+    const std::string data_format = ctx.Attr<std::string>("data_format");
+
+    const bool channel_last = data_format == "NHWC";
+
+    // Update paddings and dilations according to padding_algorithm
+    // (EXPLICIT/SAME/VALID), using only the spatial (H, W) dims.
+    auto in_dims = input->dims();
+    auto filter_dims = filter->dims();
+    framework::DDim in_data_dims;
+    framework::DDim filter_data_dims;
+
+    if (channel_last) {
+      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
+    } else {
+      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+    }
+    filter_data_dims = framework::slice_ddim(filter_dims, 2, in_dims.size());
+
+    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
+    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                             in_data_dims, strides, ksize);
+
+    // The ACL Conv2D op takes 4-D strides/dilations: the H/W values go to
+    // positions (1, 2) for NHWC and (2, 3) for NCHW; other entries stay 1.
+    std::vector<int> strides_vec(4, 1);
+    std::vector<int> dilations_vec(4, 1);
+
+    Tensor input_tensor, output_tensor;
+    input_tensor.ShareDataWith(*input);
+    output_tensor.ShareDataWith(*output);
+    if (channel_last) {
+      input_tensor.set_layout(DataLayout::kNHWC);
+      output_tensor.set_layout(DataLayout::kNHWC);
+      strides_vec[1] = strides[0];
+      strides_vec[2] = strides[1];
+      dilations_vec[1] = dilations[0];
+      dilations_vec[2] = dilations[1];
+    } else {
+      strides_vec[2] = strides[0];
+      strides_vec[3] = strides[1];
+      dilations_vec[2] = dilations[0];
+      dilations_vec[3] = dilations[1];
+    }
+
+    const auto& runner =
+        NpuOpRunner("Conv2D", {input_tensor, *filter}, {output_tensor},
+                    {{"strides", strides_vec},
+                     {"pads", paddings},
+                     {"dilations", dilations_vec},
+                     {"groups", groups},
+                     {"data_format", data_format}});
+    runner.Run(dev_ctx.stream());
+  }
+};
+
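+// Backward kernel: produces the filter and/or input gradients with the ACL
+// ops Conv2DBackpropFilterD and Conv2DBackpropInputD, mirroring the layout
+// handling of the forward kernel above.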
+template <typename T>
+class NPUConvGradOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx = ctx.template device_context<platform::NPUDeviceContext>();
+
+    auto input = ctx.Input<Tensor>("Input");
+    auto filter = ctx.Input<Tensor>("Filter");
+    auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
+    auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
+    auto filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
+
+    const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+    int groups = ctx.Attr<int>("groups");
+    const std::string padding_algorithm =
+        ctx.Attr<std::string>("padding_algorithm");
+    const std::string data_format = ctx.Attr<std::string>("data_format");
+
+    const bool channel_last = data_format == "NHWC";
+
+    // Update paddings and dilations, as in the forward kernel.
+    auto in_dims = input->dims();
+    auto filter_dims = filter->dims();
+    framework::DDim in_data_dims;
+    framework::DDim filter_data_dims;
+
+    if (channel_last) {
+      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
+    } else {
+      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+    }
+    filter_data_dims = framework::slice_ddim(filter_dims, 2, in_dims.size());
+
+    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
+    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                             in_data_dims, strides, ksize);
+
+    std::vector<int> strides_vec(4, 1);
+    std::vector<int> dilations_vec(4, 1);
+
+    Tensor input_tensor, output_grad_tensor;
+    input_tensor.ShareDataWith(*input);
+    output_grad_tensor.ShareDataWith(*output_grad);
+    if (channel_last) {
+      input_tensor.set_layout(DataLayout::kNHWC);
+      output_grad_tensor.set_layout(DataLayout::kNHWC);
+      strides_vec[1] = strides[0];
+      strides_vec[2] = strides[1];
+      dilations_vec[1] = dilations[0];
+      dilations_vec[2] = dilations[1];
+    } else {
+      strides_vec[2] = strides[0];
+      strides_vec[3] = strides[1];
+      dilations_vec[2] = dilations[0];
+      dilations_vec[3] = dilations[1];
+    }
+
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(ctx.GetPlace());
+      std::vector<int> filter_shape_vec =
+          framework::vectorize<int>(filter->dims());
+
+      const auto& runner = NpuOpRunner(
+          "Conv2DBackpropFilterD", {input_tensor, output_grad_tensor},
+          {*filter_grad}, {{"filter_size", filter_shape_vec},
+                           {"strides", strides_vec},
+                           {"pads", paddings},
+                           {"dilations", dilations_vec},
+                           {"groups", groups},
+                           {"data_format", data_format}});
+      runner.Run(dev_ctx.stream());
+    }
+    if (input_grad) {
+      input_grad->mutable_data<T>(ctx.GetPlace());
+      std::vector<int> input_shape_vec =
+          framework::vectorize<int>(input->dims());
+
+      Tensor input_grad_tensor;
+      input_grad_tensor.ShareDataWith(*input_grad);
+      if (channel_last) {
+        input_grad_tensor.set_layout(DataLayout::kNHWC);
+      }
+      const auto& runner =
+          NpuOpRunner("Conv2DBackpropInputD", {*filter, output_grad_tensor},
+                      {input_grad_tensor}, {{"input_size", input_shape_vec},
+                                            {"strides", strides_vec},
+                                            {"pads", paddings},
+                                            {"dilations", dilations_vec},
+                                            {"groups", groups},
+                                            {"data_format", data_format}});
+      runner.Run(dev_ctx.stream());
+    }
+  }
+};
 }  // namespace operators
 }  // namespace paddle
 
@@ -135,3 +298,7 @@ REGISTER_OP_NPU_KERNEL(
     depthwise_conv2d,
     ops::DepthwiseConvNPUKernel<paddle::platform::float16>);
+REGISTER_OP_NPU_KERNEL(conv2d, ops::NPUConvOpKernel<float>,
+                       ops::NPUConvOpKernel<paddle::platform::float16>);
+REGISTER_OP_NPU_KERNEL(conv2d_grad, ops::NPUConvGradOpKernel<float>,
+                       ops::NPUConvGradOpKernel<paddle::platform::float16>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_npu.py
new file mode 100644
index 0000000000000000000000000000000000000000..dff7438702d36b5829109bb97c638c0b1fe0428d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_npu.py
@@ -0,0 +1,529 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
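+
+# NPU-side tests for the conv2d/conv2d_grad kernels added in conv_op_npu.cc.
+# Expected outputs come from the naive CPU reference conv2d_forward_naive
+# in test_conv2d_op.py.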
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+sys.path.append("..")
+import paddle
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+from op_test import OpTest
+
+from test_conv2d_op import conv2d_forward_naive
+
+paddle.enable_static()
+
+
+def create_test_channel_last_class(parent):
+    class TestChannelLastCase(parent):
+        def init_data_format(self):
+            self.data_format = "NHWC"
+
+        def init_test_case_2(self):
+            N, C, H, W = self.input_size
+            self.input_size = [N, H, W, C]
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ChannelLast")
+    TestChannelLastCase.__name__ = cls_name
+    globals()[cls_name] = TestChannelLastCase
+
+
+def create_test_padding_SAME_class(parent):
+    class TestPaddingSAMECase(parent):
+        def init_paddings(self):
+            self.pad = [0, 0]
+            self.padding_algorithm = "SAME"
+
+    cls_name = "{0}_{1}".format(parent.__name__, "PaddingSAMEOp")
+    TestPaddingSAMECase.__name__ = cls_name
+    globals()[cls_name] = TestPaddingSAMECase
+
+
+def create_test_padding_VALID_class(parent):
+    class TestPaddingVALIDCase(parent):
+        def init_paddings(self):
+            self.pad = [1, 1]
+            self.padding_algorithm = "VALID"
+
+    cls_name = "{0}_{1}".format(parent.__name__, "PaddingVALIDOp")
+    TestPaddingVALIDCase.__name__ = cls_name
+    globals()[cls_name] = TestPaddingVALIDCase
+
+
+def create_test_fp16_class(parent):
+    class TestFp16Case(parent):
+        def init_dtype(self):
+            self.dtype = np.float16
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
+    TestFp16Case.__name__ = cls_name
+    globals()[cls_name] = TestFp16Case
+
+
+class TestConv2DOp(OpTest):
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def init_data_format(self):
+        self.data_format = "NCHW"
+
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "conv2d"
+        self.init_data_format()
+        self.init_dtype()
+        self.init_group()
+        self.init_dilation()
+        self.init_test_case()
+
+        conv2d_param = {
+            'stride': self.stride,
+            'pad': self.pad,
+            'dilation': self.dilations
+        }
+
+        input = np.random.random(self.input_size).astype(self.dtype)
+        filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype)
+
+        output, _, _, _, _ = conv2d_forward_naive(
+            input,
+            filter,
+            self.groups,
+            conv2d_param,
+            data_format=self.data_format)
+        output = output.astype(self.dtype)
+
+        self.inputs = {
+            'Input': OpTest.np_dtype_to_fluid_dtype(input),
+            'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
+        }
+        self.attrs = {
+            'strides': self.stride,
+            'paddings': self.pad,
+            'groups': self.groups,
+            'dilations': self.dilations,
+            'data_format': self.data_format,
+        }
+        self.outputs = {'Output': output}
+
+    def test_check_output(self):
+        self.check_output_with_place(fluid.NPUPlace(0), atol=1e-2)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            fluid.NPUPlace(0), ['Input', 'Filter'],
+            'Output',
+            max_relative_error=0.03)
+
+    def test_check_grad_no_filter(self):
+        self.check_grad_with_place(
+            fluid.NPUPlace(0), ['Input'],
+            'Output',
+            max_relative_error=0.03,
+            no_grad_set=set(['Filter']))
+
+    def test_check_grad_no_input(self):
+        self.check_grad_with_place(
+            fluid.NPUPlace(0), ['Filter'],
+            'Output',
+            max_relative_error=0.03,
+            no_grad_set=set(['Input']))
+
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+
+    def init_dilation(self):
+        self.dilations = [1, 1]
+
+    def init_group(self):
+        self.groups = 1
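+
+
+# Each subclass below overrides one knob of TestConv2DOp (padding, stride,
+# groups, dilation, or filter shape) to cover the corresponding paths of the
+# NPU kernel.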
+
+
+class TestWithPad(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+
+
+class TestWithStride(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [2, 2]
+        self.input_size = [2, 3, 6, 6]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+
+
+class TestWithGroup(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [18, f_c, 3, 3]
+
+
+class TestWith1x1(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [120, f_c, 1, 1]
+
+    def init_group(self):
+        # FIXME: support group = 3 in this case.
+        # NOTE(wangran16): There is an unknown error (acl error code is: 507015)
+        # when group = 3, which needs to be fixed.
+        self.groups = 1
+
+
+class TestWithDepthWise5x5(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [2, 4, 10, 10]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [8, f_c, 5, 5]
+
+    def init_group(self):
+        self.groups = 4
+
+
+class TestWithDepthWise7x7(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [2, 2]
+        self.input_size = [2, 8, 10, 10]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [16, f_c, 7, 7]
+
+    def init_group(self):
+        self.groups = 8
+
+
+class TestWithDilation(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 10, 10]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [12, f_c, 3, 3]
+
+    def init_dilation(self):
+        self.dilations = [2, 2]
+
+    def init_group(self):
+        self.groups = 3
+
+
+class TestWithInput1x1Filter1x1(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [100, 1, 1, 1]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [120, f_c, 1, 1]
+
+    def init_group(self):
+        self.groups = 1
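+
+
+# TestConv2DOp_v2 exercises the padding_algorithm attribute (EXPLICIT, SAME,
+# VALID) and asymmetric, 4-element paddings on top of the cases above.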
+
+
+class TestConv2DOp_v2(OpTest):
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "conv2d"
+        self.init_dtype()
+        self.init_kernel_type()
+        self.init_group()
+        self.init_dilation()
+        self.init_data_format()
+        self.init_test_case()
+        self.init_paddings()
+        self.init_test_case_2()
+
+        conv2d_param = {
+            'stride': self.stride,
+            'pad': self.pad,
+            'dilation': self.dilations
+        }
+
+        input = np.random.random(self.input_size).astype(self.dtype)
+        filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype)
+        output, _, _, _, _ = conv2d_forward_naive(
+            input, filter, self.groups, conv2d_param, self.padding_algorithm,
+            self.data_format)
+        output = output.astype(self.dtype)
+
+        self.inputs = {
+            'Input': OpTest.np_dtype_to_fluid_dtype(input),
+            'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
+        }
+        self.attrs = {
+            'strides': self.stride,
+            'paddings': self.pad,
+            'padding_algorithm': self.padding_algorithm,
+            'groups': self.groups,
+            'dilations': self.dilations,
+            'data_format': self.data_format,
+        }
+        self.outputs = {'Output': output}
+
+    def test_check_output(self):
+        self.check_output_with_place(paddle.NPUPlace(0), atol=1e-2)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            paddle.NPUPlace(0), ['Input', 'Filter'],
+            'Output',
+            max_relative_error=0.02)
+
+    def test_check_grad_no_filter(self):
+        self.check_grad_with_place(
+            paddle.NPUPlace(0), ['Input'],
+            'Output',
+            max_relative_error=0.02,
+            no_grad_set=set(['Filter']))
+
+    def test_check_grad_no_input(self):
+        self.check_grad_with_place(
+            paddle.NPUPlace(0), ['Filter'],
+            'Output',
+            no_grad_set=set(['Input']))
+
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 4, 3]
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def init_dilation(self):
+        self.dilations = [1, 1]
+
+    def init_group(self):
+        self.groups = 1
+
+    def init_kernel_type(self):
+        pass
+
+    def init_paddings(self):
+        self.pad = [0, 0]
+        self.padding_algorithm = "EXPLICIT"
+
+    def init_data_format(self):
+        self.data_format = "NCHW"
+
+    def init_test_case_2(self):
+        pass
+
+
+class TestConv2DOp_AsyPadding(TestConv2DOp_v2):
+    def init_paddings(self):
+        self.pad = [0, 0, 1, 2]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithPad_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+
+    def init_paddings(self):
+        self.pad = [2, 1, 3, 2]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithStride_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [2, 2]
+        self.input_size = [2, 3, 6, 6]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+
+    def init_paddings(self):
+        self.pad = [2, 1, 3, 2]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithGroup_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [24, f_c, 4, 3]
+
+
+class TestWith1x1_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [120, f_c, 1, 1]
+
+    def init_group(self):
+        self.groups = 1
+
+    def init_paddings(self):
+        self.pad = [2, 2, 4, 0]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithDepthWise3x3_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.input_size = [3, 4, 10, 10]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [16, f_c, 3, 3]
+
+    def init_dilation(self):
+        self.dilations = [2, 2]
+
+    def init_group(self):
+        self.groups = 4
+
+    def init_paddings(self):
+        self.pad = [1, 3, 2, 1]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithDepthWise5x5_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.input_size = [2, 4, 10, 10]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [8, f_c, 5, 5]
+
+    def init_group(self):
+        self.groups = 4
+
+    def init_paddings(self):
+        self.pad = [0, 1, 1, 0]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithDepthWise7x7_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [2, 2]
+        self.input_size = [2, 8, 10, 10]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [16, f_c, 7, 7]
+
+    def init_group(self):
+        self.groups = 8
+
+    def init_paddings(self):
+        self.pad = [1, 3, 4, 1]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithDilation_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 10, 10]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [24, f_c, 3, 3]
+
+    def init_dilation(self):
+        self.dilations = [2, 2]
+
+    def init_group(self):
+        self.groups = 3
+
+    def init_paddings(self):
+        self.pad = [0, 1, 3, 0]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithInput1x1Filter1x1_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.input_size = [100, 1, 1, 1]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [120, f_c, 1, 1]
+
+    def init_group(self):
+        self.groups = 1
+
+    def init_paddings(self):
+        self.pad = [0, 3, 4, 0]
+        self.padding_algorithm = "EXPLICIT"
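+
+
+# Instantiate the SAME/VALID padding, channel-last (NHWC), and fp16 variants
+# of the asymmetric-padding cases; each call registers a new test class in
+# the module namespace.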
+
+
+create_test_padding_SAME_class(TestConv2DOp_AsyPadding)
+create_test_padding_SAME_class(TestWithPad_AsyPadding)
+create_test_padding_SAME_class(TestWithStride_AsyPadding)
+create_test_padding_SAME_class(TestWithGroup_AsyPadding)
+create_test_padding_SAME_class(TestWithInput1x1Filter1x1_AsyPadding)
+
+create_test_padding_VALID_class(TestConv2DOp_AsyPadding)
+create_test_padding_VALID_class(TestWithPad_AsyPadding)
+create_test_padding_VALID_class(TestWithStride_AsyPadding)
+create_test_padding_VALID_class(TestWithGroup_AsyPadding)
+create_test_padding_VALID_class(TestWithInput1x1Filter1x1_AsyPadding)
+
+create_test_channel_last_class(TestConv2DOp_AsyPadding)
+create_test_channel_last_class(TestWithPad_AsyPadding)
+create_test_channel_last_class(TestWithGroup_AsyPadding)
+create_test_channel_last_class(TestWith1x1_AsyPadding)
+create_test_channel_last_class(TestWithInput1x1Filter1x1_AsyPadding)
+
+create_test_fp16_class(TestConv2DOp_AsyPadding)
+create_test_fp16_class(TestWithPad_AsyPadding)
+create_test_fp16_class(TestWithStride_AsyPadding)
+create_test_fp16_class(TestWithGroup_AsyPadding)
+create_test_fp16_class(TestWithInput1x1Filter1x1_AsyPadding)
+
+if __name__ == "__main__":
+    unittest.main()
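
A minimal usage sketch (not part of the patch): on a Paddle build with Ascend
NPU support, running a static-graph conv2d program on NPUPlace(0) dispatches
to the NPUConvOpKernel registered above. The variable names ('x', 'y') and
shapes here are illustrative only.

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        x = fluid.data(name='x', shape=[2, 3, 5, 5], dtype='float32')
        # 6 output channels, 3x3 filter, stride 1, no padding.
        y = fluid.layers.conv2d(input=x, num_filters=6, filter_size=3)

    exe = fluid.Executor(fluid.NPUPlace(0))
    exe.run(startup)
    out, = exe.run(main,
                   feed={'x': np.random.random([2, 3, 5, 5]).astype('float32')},
                   fetch_list=[y])
    print(out.shape)  # (2, 6, 3, 3)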