From 077f3788b96e6e9c397d61d32e753bc6b110ce24 Mon Sep 17 00:00:00 2001
From: cambriconhsq <106155938+cambriconhsq@users.noreply.github.com>
Date: Tue, 14 Jun 2022 19:55:43 +0800
Subject: [PATCH] [MLU] add mlu kernel for depthwise conv2d op (#43359)

---
 paddle/fluid/operators/conv_op_mlu.cc        | 229 +++++++++++++++++
 .../mlu/test_conv2d_op_depthwise_conv_mlu.py | 238 ++++++++++++++++++
 2 files changed, 467 insertions(+)
 create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_conv2d_op_depthwise_conv_mlu.py

diff --git a/paddle/fluid/operators/conv_op_mlu.cc b/paddle/fluid/operators/conv_op_mlu.cc
index c1517dbe16f..b1b39608d62 100644
--- a/paddle/fluid/operators/conv_op_mlu.cc
+++ b/paddle/fluid/operators/conv_op_mlu.cc
@@ -238,6 +238,228 @@ class MLUConvGradOpKernel : public framework::OpKernel<T> {
     }
   }
 };
+
+template <typename T>
+class MLUDepthwiseConvOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const Tensor* input = ctx.Input<Tensor>("Input");
+    auto* filter = ctx.Input<Tensor>("Filter");
+    auto* output = ctx.Output<Tensor>("Output");
+    output->mutable_data<T>(ctx.GetPlace());
+    const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+    const std::string padding_algorithm =
+        ctx.Attr<std::string>("padding_algorithm");
+    const std::string data_format = ctx.Attr<std::string>("data_format");
+
+    const bool channel_last = data_format == "NHWC";
+    int groups;
+
+    // update padding and dilation
+    auto in_dims = input->dims();
+    auto filter_dims = filter->dims();
+    auto in_dims_size = in_dims.size();
+    framework::DDim in_data_dims;
+    framework::DDim filter_data_dims;
+
+    if (channel_last) {
+      in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
+    } else {
+      in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
+    }
+    filter_data_dims = phi::slice_ddim(filter_dims, 2, in_dims.size());
+    std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
+    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                             in_data_dims, strides, ksize);
+
+    Tensor input_tensor(input->type());
+    Tensor output_tensor(output->type());
+    const std::vector<int> perm_to_nhwc = {0, 2, 3, 1};
+    if (channel_last) {
+      groups = in_dims[3];
+      input_tensor.ShareDataWith(*input);
+      output_tensor.ShareDataWith(*output);
+    } else {
+      // transpose input from NCHW to NHWC
+      groups = in_dims[1];
+      TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, input, &input_tensor,
+                                true /*need_reshape_or_alloc*/);
+      auto output_dims = output->dims();
+      output_tensor.mutable_data<T>(
+          {output_dims[0], output_dims[2], output_dims[3], output_dims[1]},
+          ctx.GetPlace());
+    }
+    input_tensor.set_layout(DataLayout::kNHWC);
+    output_tensor.set_layout(DataLayout::kNHWC);
+
+    // transpose filter from MCHW to MHWC
+    Tensor trans_filter(filter->type());
+    TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, filter, &trans_filter,
+                              true /*need_reshape_or_alloc*/);
+
+    cnnlTensorLayout_t data_layout = CNNL_LAYOUT_NHWC;
+    MLUCnnlTensorDesc input_desc(input_tensor, data_layout,
+                                 ToCnnlDataType(input_tensor.dtype()));
+    MLUCnnlTensorDesc filter_desc(trans_filter, data_layout,
+                                  ToCnnlDataType(trans_filter.type()));
+    MLUCnnlTensorDesc output_desc(output_tensor, data_layout,
+                                  ToCnnlDataType(output_tensor.dtype()));
+
+    MLUCnnlConvolutionDesc conv_desc(in_dims_size, paddings.data(),
+                                     strides.data(), dilations.data(), groups,
+                                     ToCnnlDataType<T>());
+
+    MLUCnnl::ConvolutionForward(
+        ctx, conv_desc.get(), nullptr /*alpha*/, nullptr /*beta*/,
+        nullptr /*bias_desc*/, nullptr /*bias_ptr*/, input_desc.get(),
+        GetBasePtr(&input_tensor), filter_desc.get(), GetBasePtr(&trans_filter),
+        output_desc.get(), GetBasePtr(&output_tensor));
+
+    if (!channel_last) {
+      // transpose output from NHWC to NCHW
+      const std::vector<int> perm_to_nchw = {0, 3, 1, 2};
+      TransposeFromMLUTensor<T>(ctx, perm_to_nchw, &output_tensor, output,
+                                false /*need_reshape_or_alloc*/);
+    }
+  }
+};
+
+template <typename T>
+class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto input = ctx.Input<Tensor>("Input");
+    auto filter = ctx.Input<Tensor>("Filter");
+    auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
+    auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
+    auto filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
+
+    const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+    const std::string padding_algorithm =
+        ctx.Attr<std::string>("padding_algorithm");
+    const std::string data_format = ctx.Attr<std::string>("data_format");
+
+    const bool channel_last = data_format == "NHWC";
+
+    // update padding and dilation
+    auto in_dims = input->dims();
+    auto filter_dims = filter->dims();
+    auto in_dims_size = in_dims.size();
+    framework::DDim in_data_dims;
+    framework::DDim filter_data_dims;
+    int groups;
+
+    if (channel_last) {
+      in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
+    } else {
+      in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
+    }
+    filter_data_dims = phi::slice_ddim(filter_dims, 2, in_dims.size());
+
+    std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
+    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                             in_data_dims, strides, ksize);
+
+    Tensor input_tensor(input->type());
+    Tensor output_grad_tensor(output_grad->type());
+    const std::vector<int> perm_to_nhwc = {0, 2, 3, 1};
+    const std::vector<int> perm_to_nchw = {0, 3, 1, 2};
+    if (channel_last) {
+      input_tensor.ShareDataWith(*input);
+      output_grad_tensor.ShareDataWith(*output_grad);
+      groups = in_dims[3];
+    } else {
+      groups = in_dims[1];
+      // transpose input and output_grad from NCHW to NHWC
+      TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, input, &input_tensor,
+                                true /*need_reshape_or_alloc*/);
+      TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, output_grad,
+                                &output_grad_tensor,
+                                true /*need_reshape_or_alloc*/);
+    }
+    input_tensor.set_layout(DataLayout::kNHWC);
+    output_grad_tensor.set_layout(DataLayout::kNHWC);
+
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(ctx.GetPlace());
+
+      auto filter_grad_dims = filter_grad->dims();
+      Tensor temp_filter_grad(filter_grad->type());
+      temp_filter_grad.mutable_data<T>(
+          {filter_grad_dims[0], filter_grad_dims[2], filter_grad_dims[3],
+           filter_grad_dims[1]},
+          ctx.GetPlace());
+
+      cnnlDataType_t tensor_dtype = ToCnnlDataType<T>();
+      cnnlTensorLayout_t data_layout = CNNL_LAYOUT_NHWC;
+      MLUCnnlTensorDesc input_desc(input_tensor, data_layout, tensor_dtype);
+      MLUCnnlTensorDesc out_grad_desc(output_grad_tensor, data_layout,
+                                      tensor_dtype);
+      MLUCnnlTensorDesc temp_filter_grad_desc(temp_filter_grad, data_layout,
+                                              tensor_dtype);
+
+      MLUCnnlConvolutionDesc conv_desc(in_dims_size, paddings.data(),
+                                       strides.data(), dilations.data(), groups,
+                                       tensor_dtype);
+
+      MLUCnnl::ConvBackpropFilter(
+          ctx, conv_desc.get(), input_desc.get(), GetBasePtr(&input_tensor),
+          out_grad_desc.get(), GetBasePtr(&output_grad_tensor),
+          temp_filter_grad_desc.get(), GetBasePtr(&temp_filter_grad));
+
+      // transpose filter_grad from MHWC to MCHW
+      TransposeFromMLUTensor<T>(ctx, perm_to_nchw, &temp_filter_grad,
+                                filter_grad, false /*need_reshape_or_alloc*/);
+    }
+    if (input_grad) {
+      input_grad->mutable_data<T>(ctx.GetPlace());
+
+      Tensor input_grad_tensor(input_grad->type());
+      if (channel_last) {
+        input_grad_tensor.ShareDataWith(*input_grad);
+      } else {
+        auto input_grad_dims = input_grad->dims();
+        input_grad_tensor.mutable_data<T>(
+            {input_grad_dims[0], input_grad_dims[2], input_grad_dims[3],
+             input_grad_dims[1]},
+            ctx.GetPlace());
+      }
+      input_grad_tensor.set_layout(DataLayout::kNHWC);
+
+      // transpose filter from MCHW to MHWC
+      Tensor trans_filter(filter->type());
+      TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, filter, &trans_filter,
+                                true /*need_reshape_or_alloc*/);
+
+      cnnlDataType_t tensor_dtype = ToCnnlDataType<T>();
+      cnnlTensorLayout_t data_layout = CNNL_LAYOUT_NHWC;
+      MLUCnnlTensorDesc filter_desc(trans_filter, data_layout, tensor_dtype);
+      MLUCnnlTensorDesc out_grad_desc(output_grad_tensor, data_layout,
+                                      tensor_dtype);
+      MLUCnnlTensorDesc in_grad_desc(input_grad_tensor, data_layout,
+                                     tensor_dtype);
+
+      MLUCnnlConvolutionDesc conv_desc(in_dims_size, paddings.data(),
+                                       strides.data(), dilations.data(), groups,
+                                       tensor_dtype);
+
+      MLUCnnl::ConvBackpropInput(
+          ctx, conv_desc.get(), filter_desc.get(), GetBasePtr(&trans_filter),
+          out_grad_desc.get(), GetBasePtr(&output_grad_tensor),
+          in_grad_desc.get(), GetBasePtr(&input_grad_tensor));
+
+      if (!channel_last) {
+        // transpose input_grad from NHWC to NCHW
+        TransposeFromMLUTensor<T>(ctx, perm_to_nchw, &input_grad_tensor,
+                                  input_grad, false /*need_reshape_or_alloc*/);
+      }
+    }
+  }
+};
 }  // namespace operators
 }  // namespace paddle
@@ -249,3 +471,10 @@ REGISTER_OP_MLU_KERNEL(conv2d, ops::MLUConvOpKernel<float>,
                        ops::MLUConvOpKernel<plat::float16>);
 REGISTER_OP_MLU_KERNEL(conv2d_grad, ops::MLUConvGradOpKernel<float>,
                        ops::MLUConvGradOpKernel<plat::float16>);
+
+REGISTER_OP_MLU_KERNEL(depthwise_conv2d, ops::MLUDepthwiseConvOpKernel<float>,
+                       ops::MLUDepthwiseConvOpKernel<plat::float16>);
+
+REGISTER_OP_MLU_KERNEL(depthwise_conv2d_grad,
+                       ops::MLUDepthwiseConvGradOpKernel<float>,
+                       ops::MLUDepthwiseConvGradOpKernel<plat::float16>);
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_conv2d_op_depthwise_conv_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_conv2d_op_depthwise_conv_mlu.py
new file mode 100644
index 00000000000..8d239732e73
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_conv2d_op_depthwise_conv_mlu.py
@@ -0,0 +1,238 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+
+sys.path.append("..")
+
+import paddle
+
+paddle.enable_static()
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+from op_test import OpTest
+from paddle.fluid import Program, program_guard
+from test_conv2d_op_mlu import TestConv2DOp, TestConv2DOp_v2, create_test_padding_SAME_class, create_test_padding_VALID_class, create_test_channel_last_class, create_test_fp16_class
+
+#----------------TestDepthwiseConv -----
+
+
+class TestDepthwiseConv(TestConv2DOp):
+
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [2, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [12, f_c, 3, 3]
+        self.op_type = "depthwise_conv2d"
+
+
+class TestDepthwiseConv2(TestConv2DOp):
+
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [12, f_c, 3, 3]
+        self.op_type = "depthwise_conv2d"
+
+
+class TestDepthwiseConv3(TestConv2DOp):
+
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [24, f_c, 3, 3]
+        self.op_type = "depthwise_conv2d"
+
+
+class TestDepthwiseConvandFuse(TestConv2DOp):
+
+    def init_test_case(self):
+        self.fuse_relu_before_depthwise_conv = True
+        self.pad = [1, 1]
+        self.stride = [2, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [12, f_c, 3, 3]
+        self.op_type = "depthwise_conv2d"
+
+
+class TestDepthwiseConv2andFuse(TestConv2DOp):
+
+    def init_test_case(self):
+        self.fuse_relu_before_depthwise_conv = True
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [12, f_c, 3, 3]
+        self.op_type = "depthwise_conv2d"
+
+
+class TestDepthwiseConv3andFuse(TestConv2DOp):
+
+    def init_test_case(self):
+        self.fuse_relu_before_depthwise_conv = True
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [24, f_c, 3, 3]
+        self.op_type = "depthwise_conv2d"
+
+
+class TestDepthwiseConv_AsyPadding(TestConv2DOp_v2):
+
+    def init_test_case(self):
+        self.stride = [2, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [12, f_c, 3, 3]
+        self.op_type = "depthwise_conv2d"
+
+    def init_paddings(self):
+        self.pad = [1, 1, 0, 1]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestDepthwiseConv2_AsyPadding(TestConv2DOp_v2):
+
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [12, f_c, 3, 3]
+        self.op_type = "depthwise_conv2d"
+
+    def init_paddings(self):
+        self.pad = [0, 1, 0, 2]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestDepthwiseConv3_AsyPadding(TestConv2DOp_v2):
+
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [24, f_c, 3, 3]
+        self.op_type = "depthwise_conv2d"
+
+    def init_paddings(self):
+        self.pad = [1, 1, 0, 0]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestDepthwiseConvandFuse_AsyPadding(TestConv2DOp_v2):
+
+    def init_test_case(self):
+        self.fuse_relu_before_depthwise_conv = True
+        self.pad = [1, 1]
+        self.stride = [2, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [12, f_c, 3, 3]
+        self.op_type = "depthwise_conv2d"
+
+    def init_paddings(self):
+        self.pad = [2, 1, 2, 3]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestDepthwiseConv2andFuse_AsyPadding(TestConv2DOp_v2):
+
+    def init_test_case(self):
+        self.fuse_relu_before_depthwise_conv = True
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [12, f_c, 3, 3]
+        self.op_type = "depthwise_conv2d"
+
+    def init_paddings(self):
+        self.pad = [1, 1, 1, 2]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestDepthwiseConv3andFuse_AsyPadding(TestConv2DOp_v2):
+
+    def init_test_case(self):
+        self.fuse_relu_before_depthwise_conv = True
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [24, f_c, 3, 3]
+        self.op_type = "depthwise_conv2d"
+
+    def init_paddings(self):
+        self.pad = [1, 2, 0, 2]
+        self.padding_algorithm = "EXPLICIT"
+
+
+# depthwise conv2d
+
+create_test_padding_SAME_class(TestDepthwiseConv_AsyPadding)
+create_test_padding_SAME_class(TestDepthwiseConvandFuse_AsyPadding)
+
+create_test_padding_VALID_class(TestDepthwiseConv_AsyPadding)
+create_test_padding_VALID_class(TestDepthwiseConvandFuse_AsyPadding)
+
+# channel last
+
+create_test_channel_last_class(TestDepthwiseConv_AsyPadding)
+create_test_channel_last_class(TestDepthwiseConvandFuse_AsyPadding)
+
+create_test_fp16_class(TestDepthwiseConv_AsyPadding)
+create_test_fp16_class(TestDepthwiseConvandFuse_AsyPadding)
+
+# TODO(MLU): Depthwise operation does not support dilation yet;
+# it will throw an error of CNNL_STATUS_NOT_SUPPORTED.
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
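
For context, Paddle's Python frontend typically lowers paddle.nn.functional.conv2d to the depthwise_conv2d op when the group count equals the number of input channels, which is how the kernels registered in this patch get exercised outside the unit tests. The sketch below is illustrative only and not part of the patch; it assumes a PaddlePaddle build with Cambricon MLU support and a visible MLU device (the device string is an assumption, fall back to "cpu" elsewhere).

# Illustrative sketch only -- assumes a PaddlePaddle wheel built with MLU
# support; with groups == in_channels this grouped conv maps to the
# depthwise_conv2d op and, on an MLU device, to the kernels added above.
import numpy as np
import paddle

paddle.set_device("mlu")  # assumption: an MLU device is available

x = paddle.to_tensor(np.random.rand(2, 3, 5, 5).astype("float32"))  # NCHW input
w = paddle.to_tensor(np.random.rand(3, 1, 3, 3).astype("float32"))  # one 3x3 filter per channel

# groups equal to the input channel count (3) makes this a depthwise convolution
y = paddle.nn.functional.conv2d(x, w, stride=2, padding=1, groups=3)
print(y.shape)  # [2, 3, 3, 3]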