From 4c115a82feed0a44c840f73c736953d4ab93823d Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+lili0826@users.noreply.github.com>
Date: Fri, 20 Aug 2021 14:12:00 +0800
Subject: [PATCH] [NPU] Support npu op depthwise_conv2d (#34853)

* add depthwise_conv2d npu
* add some tests
* Delete test_unique_op_npu.py
* delete trans input
---
 paddle/fluid/operators/conv_op.cc             |   9 +-
 paddle/fluid/operators/conv_op_npu.cc         | 137 +++++++++
 .../npu/test_conv2d_op_depthwise_conv_npu.py  | 283 ++++++++++++++++++
 3 files changed, 426 insertions(+), 3 deletions(-)
 create mode 100644 paddle/fluid/operators/conv_op_npu.cc
 create mode 100755 python/paddle/fluid/tests/unittests/npu/test_conv2d_op_depthwise_conv_npu.py

diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 1266cfe6081..9defe3262ff 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -194,11 +194,14 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
         paddle::framework::DataTypeToString(input_data_type),
         paddle::framework::DataTypeToString(filter_data_type)));
   }
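+  // On Ascend (PADDLE_WITH_ASCEND_CL) builds, fp16 conv is served by the NPU
+  // kernels added in this patch, so the CUDNN-only check below is compiled
+  // out.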
+#ifndef PADDLE_WITH_ASCEND_CL
   if (input_data_type == framework::proto::VarType::FP16) {
-    PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN,
-                      platform::errors::InvalidArgument(
-                          "float16 can only be used when CUDNN is used"));
+    PADDLE_ENFORCE_EQ(
+        library, framework::LibraryType::kCUDNN,
+        platform::errors::InvalidArgument(
+            "float16 can only be used when CUDNN or NPU is used"));
   }
+#endif
 #if PADDLE_WITH_CUDA
   if (input_data_type == framework::proto::VarType::BF16 &&
       library == framework::LibraryType::kCUDNN) {
diff --git a/paddle/fluid/operators/conv_op_npu.cc b/paddle/fluid/operators/conv_op_npu.cc
new file mode 100644
index 00000000000..4065394effa
--- /dev/null
+++ b/paddle/fluid/operators/conv_op_npu.cc
@@ -0,0 +1,137 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/conv_op.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    // input
+    const Tensor* input = context.Input<Tensor>("Input");
+    const Tensor* filter = context.Input<Tensor>("Filter");
+    // output
+    Tensor* output = context.Output<Tensor>("Output");
+    output->mutable_data<T>(context.GetPlace());
+    // attr
+    const std::vector<int> stride = context.Attr<std::vector<int>>("strides");
+    std::vector<int> padding = context.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilation = context.Attr<std::vector<int>>("dilations");
+    const std::string data_format = context.Attr<std::string>("data_format");
+    const std::string padding_algorithm =
+        context.Attr<std::string>("padding_algorithm");
+
+    // npu stream
+    auto stream =
+        context.template device_context<platform::NPUDeviceContext>().stream();
+
+    // check dimension
+    const bool channel_last = data_format == "NHWC";
+    if (channel_last) {
+      // NHWC
+      PADDLE_ENFORCE_EQ(
+          output->dims()[output->dims().size() - 1],
+          input->dims()[input->dims().size() - 1],
+          platform::errors::InvalidArgument(
+              "ShapeError: The output channels must be equal to the "
+              "input channels. But received output channel number is %d "
+              "and input channel number is %d",
+              output->dims()[output->dims().size() - 1],
+              input->dims()[input->dims().size() - 1]));
+    } else {
+      // NCHW
+      PADDLE_ENFORCE_EQ(
+          output->dims()[1], input->dims()[1],
+          platform::errors::InvalidArgument(
+              "ShapeError: The output channels must be equal to the "
+              "input channels. But received output channel number is %d "
+              "and input channel number is %d",
+              output->dims()[1], input->dims()[1]));
+    }
+
+    // update padding and dilation
+    auto in_dims = input->dims();
+    auto filter_dims = filter->dims();
+    framework::DDim in_data_dims;
+    framework::DDim filter_data_dims;
+
+    if (channel_last) {
+      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
+    } else {
+      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+    }
+    filter_data_dims = framework::slice_ddim(filter_dims, 2, in_dims.size());
+
+    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
+    UpdatePaddingAndDilation(&padding, &dilation, padding_algorithm,
+                             in_data_dims, stride, ksize);
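+
+    // Hedged note: the CANN DepthwiseConv2D kernel appears to expect the
+    // depthwise filter with the channel-multiplier axis leading, so the
+    // Paddle filter laid out as (C_out, 1, H, W) is transposed to
+    // (1, C_out, H, W) below before the op is launched.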
+    // Transform filter (n, 1, h, w) --> (1, n, h, w)
+    Tensor transformed_filter(filter->type());
+    transformed_filter.mutable_data<T>({filter->dims()[1], filter->dims()[0],
+                                        filter->dims()[2], filter->dims()[3]},
+                                       context.device_context().GetPlace());
+    std::vector<int> perm = {1, 0, 2, 3};
+    const auto& runner_trans = NpuOpRunner(
+        "TransposeD", {*filter}, {transformed_filter}, {{"perm", perm}});
+    runner_trans.Run(stream);
+
+    // construct NPU attr
+    std::vector<int> strides(4, 1);
+    std::vector<int> dilations(4, 1);
+
+    Tensor input_tensor, output_tensor;
+    input_tensor.ShareDataWith(*input);
+    output_tensor.ShareDataWith(*output);
+
+    if (channel_last) {
+      input_tensor.set_layout(DataLayout::kNHWC);
+      output_tensor.set_layout(DataLayout::kNHWC);
+      strides[1] = stride[0];
+      strides[2] = stride[1];
+      dilations[1] = dilation[0];
+      dilations[2] = dilation[1];
+    } else {
+      strides[2] = stride[0];
+      strides[3] = stride[1];
+      dilations[2] = dilation[0];
+      dilations[3] = dilation[1];
+    }
+
+    // CANN OP
+    const auto& runner =
+        NpuOpRunner("DepthwiseConv2D", {input_tensor, transformed_filter},
+                    {output_tensor}, {{"strides", strides},
+                                      {"dilations", dilations},
+                                      {"pads", padding},
+                                      {"data_format", data_format}});
+    runner.Run(stream);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_NPU_KERNEL(
+    depthwise_conv2d,
+    ops::DepthwiseConvNPUKernel<paddle::platform::float16>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_depthwise_conv_npu.py b/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_depthwise_conv_npu.py
new file mode 100755
index 00000000000..b62ad1b8b8e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_depthwise_conv_npu.py
@@ -0,0 +1,283 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import print_function
+
+import unittest
+import numpy as np
+
+import paddle
+import paddle.fluid as fluid
+import sys
+sys.path.append("..")
+from op_test import OpTest, skip_check_grad_ci
+from test_conv2d_op import conv2d_forward_naive
+
+paddle.enable_static()
+
+
+def create_test_channel_last_class(parent):
+    class TestChannelLastCase(parent):
+        def init_data_format(self):
+            self.data_format = "NHWC"
+
+        def init_test_case_2(self):
+            N, C, H, W = self.input_size
+            self.input_size = [N, H, W, C]
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ChannelLast")
+    TestChannelLastCase.__name__ = cls_name
+    globals()[cls_name] = TestChannelLastCase
+
+
+def create_test_padding_SAME_class(parent):
+    class TestPaddingSAMECase(parent):
+        def init_paddings(self):
+            self.pad = [0, 0]
+            self.padding_algorithm = "SAME"
+
+    cls_name = "{0}_{1}".format(parent.__name__, "PaddingSAMEOp")
+    TestPaddingSAMECase.__name__ = cls_name
+    globals()[cls_name] = TestPaddingSAMECase
+
+
+def create_test_padding_VALID_class(parent):
+    class TestPaddingVALIDCase(parent):
+        def init_paddings(self):
+            self.pad = [1, 1]
+            self.padding_algorithm = "VALID"
+
+    cls_name = "{0}_{1}".format(parent.__name__, "PaddingVALIDOp")
+    TestPaddingVALIDCase.__name__ = cls_name
+    globals()[cls_name] = TestPaddingVALIDCase
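+
+
+# The factory functions above parameterize an existing test class: each one
+# subclasses the given parent, overrides a single init_* hook (data layout or
+# padding algorithm), and registers the subclass in globals() under a derived
+# name so that unittest discovery runs it as an independent case.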
+
+
+@skip_check_grad_ci(
+    reason='''Inference only, it doesn't need to call check_grad.''')
+class TestDepthwiseConvNPU(OpTest):
+    def setUp(self):
+        self.op_type = "depthwise_conv2d"
+        self.dtype = np.float16
+        self.set_npu()
+        self.init_data_format()
+        self.init_test_case()
+        self.init_test_case_2()
+
+        conv2d_param = {
+            'stride': self.stride,
+            'pad': self.pad,
+            'dilation': self.dilations
+        }
+
+        input = np.random.random(self.input_size).astype(self.dtype)
+        filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype)
+
+        output, _, _, _, _ = conv2d_forward_naive(input, filter, self.groups,
+                                                  conv2d_param, "EXPLICIT",
+                                                  self.data_format)
+
+        output = output.astype(self.dtype)
+
+        self.inputs = {
+            'Input': OpTest.np_dtype_to_fluid_dtype(input),
+            'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
+        }
+
+        self.attrs = {
+            'strides': self.stride,
+            'paddings': self.pad,
+            'groups': self.groups,
+            'dilations': self.dilations,
+            'data_format': self.data_format,
+        }
+        self.outputs = {'Output': output}
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+        self.place = paddle.NPUPlace(0)
+
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.dilations = [1, 1]
+        self.stride = [2, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [3, f_c, 3, 3]
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def init_data_format(self):
+        self.data_format = "NCHW"
+
+    def init_test_case_2(self):
+        pass
+
+
+class TestDepthwiseConvNPU2(TestDepthwiseConvNPU):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.dilations = [1, 1]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [3, f_c, 3, 3]
+
+
+class TestDepthwiseConvNPU3(TestDepthwiseConvNPU):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.dilations = [2, 2]
+        self.stride = [2, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [3, f_c, 3, 3]
+
+
+class TestDepthwiseConvNPU4(TestDepthwiseConvNPU):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.dilations = [2, 2]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [3, f_c, 3, 3]
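+
+
+# The *_Padding cases below pass explicit 4-element paddings; in Paddle's
+# conv2d convention this reads [pad_top, pad_bottom, pad_left, pad_right],
+# whereas the 2-element form used above is [pad_height, pad_width].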
+
+
+@skip_check_grad_ci(
+    reason='''Inference only, it doesn't need to call check_grad.''')
+class TestDepthwiseConvNPU_Padding(OpTest):
+    def setUp(self):
+        self.op_type = "depthwise_conv2d"
+        self.dtype = np.float16
+        self.set_npu()
+        self.init_data_format()
+        self.init_paddings()
+        self.init_test_case()
+        self.init_test_case_2()
+
+        conv2d_param = {
+            'stride': self.stride,
+            'pad': self.pad,
+            'dilation': self.dilations
+        }
+
+        input = np.random.random(self.input_size).astype(self.dtype)
+        filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype)
+
+        output, _, _, _, _ = conv2d_forward_naive(
+            input, filter, self.groups, conv2d_param, self.padding_algorithm,
+            self.data_format)
+        output = output.astype(self.dtype)
+
+        self.inputs = {
+            'Input': OpTest.np_dtype_to_fluid_dtype(input),
+            'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
+        }
+
+        self.attrs = {
+            'strides': self.stride,
+            'paddings': self.pad,
+            'padding_algorithm': self.padding_algorithm,
+            'groups': self.groups,
+            'dilations': self.dilations,
+            'data_format': self.data_format
+        }
+        self.outputs = {'Output': output}
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+        self.place = paddle.NPUPlace(0)
+
+    def init_test_case(self):
+        self.pad = [1, 1, 0, 1]
+        self.dilations = [1, 1]
+        self.stride = [2, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [3, f_c, 3, 3]
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def init_data_format(self):
+        self.data_format = "NCHW"
+
+    def init_paddings(self):
+        self.pad = [1, 1, 0, 1]
+        self.padding_algorithm = "EXPLICIT"
+
+    def init_test_case_2(self):
+        pass
+
+
+class TestDepthwiseConvNPU2_Padding(TestDepthwiseConvNPU_Padding):
+    def init_test_case(self):
+        self.pad = [1, 1, 0, 1]
+        self.dilations = [1, 1]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [3, f_c, 3, 3]
+
+    def init_paddings(self):
+        self.pad = [0, 1, 0, 2]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestDepthwiseConvNPU3_Padding(TestDepthwiseConvNPU_Padding):
+    def init_test_case(self):
+        self.pad = [1, 1, 0, 1]
+        self.dilations = [1, 1]
+        self.stride = [2, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [3, f_c, 3, 3]
+
+    def init_paddings(self):
+        self.pad = [2, 1, 2, 3]
+        self.padding_algorithm = "EXPLICIT"
+
+
+# test channel last
+create_test_channel_last_class(TestDepthwiseConvNPU)
+create_test_channel_last_class(TestDepthwiseConvNPU2)
+create_test_channel_last_class(TestDepthwiseConvNPU_Padding)
+create_test_channel_last_class(TestDepthwiseConvNPU2_Padding)
+
+# test padding SAME
+create_test_padding_SAME_class(TestDepthwiseConvNPU_Padding)
+create_test_padding_SAME_class(TestDepthwiseConvNPU2_Padding)
+create_test_padding_SAME_class(TestDepthwiseConvNPU3_Padding)
+
+# test padding VALID
+create_test_padding_VALID_class(TestDepthwiseConvNPU_Padding)
+create_test_padding_VALID_class(TestDepthwiseConvNPU2_Padding)
+create_test_padding_VALID_class(TestDepthwiseConvNPU3_Padding)
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
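
For reference, a minimal sketch of how the new kernel might be exercised from
the Python side. Illustrative only: it assumes an Ascend build
(PADDLE_WITH_ASCEND_CL) with an NPU visible as device 0, and relies on
paddle.nn.functional.conv2d choosing the depthwise_conv2d op when groups
equals the input channel count.

    import numpy as np
    import paddle

    # Assumption: an Ascend-enabled build exposing the device string "npu:0".
    paddle.set_device("npu:0")

    # NCHW fp16 input and a depthwise filter: one 3x3 kernel per channel.
    x = paddle.to_tensor(np.random.rand(2, 3, 5, 5).astype("float16"))
    w = paddle.to_tensor(np.random.uniform(-1, 1, (3, 1, 3, 3)).astype("float16"))

    # groups == in_channels (3) selects the depthwise path.
    y = paddle.nn.functional.conv2d(x, w, stride=2, padding=1, groups=3)
    print(y.shape)  # [2, 3, 3, 3]: (5 + 2*1 - 3) // 2 + 1 = 3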