diff --git a/paddle/fluid/operators/batch_norm_op_mlu.cc b/paddle/fluid/operators/batch_norm_op_mlu.cc
index 534af63d2a03fb0fe71769e32e3e9377be5ba68b..0e64b461786cce845f7388a520c09101dcba9c09 100644
--- a/paddle/fluid/operators/batch_norm_op_mlu.cc
+++ b/paddle/fluid/operators/batch_norm_op_mlu.cc
@@ -106,7 +106,7 @@ class MLUBatchNormOpKernel : public framework::OpKernel<T> {
     if (ctx.HasInput("MomentumTensor")) {
       const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
       Tensor mom_cpu;
-      TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu);
+      framework::TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu);
       momentum = mom_cpu.data<float>()[0];
     }
 
diff --git a/paddle/fluid/operators/conv_op_mlu.cc b/paddle/fluid/operators/conv_op_mlu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..88698c02dd5daf11d6c5b7d68446d292696977ec
--- /dev/null
+++ b/paddle/fluid/operators/conv_op_mlu.cc
@@ -0,0 +1,251 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/conv_op.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using DataLayout = framework::DataLayout;
+
+template <typename T>
+class MLUConvOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const Tensor* input = ctx.Input<Tensor>("Input");
+    auto* filter = ctx.Input<Tensor>("Filter");
+    auto* output = ctx.Output<Tensor>("Output");
+    output->mutable_data<T>(ctx.GetPlace());
+    const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+    int groups = ctx.Attr<int>("groups");
+    const std::string padding_algorithm =
+        ctx.Attr<std::string>("padding_algorithm");
+    const std::string data_format = ctx.Attr<std::string>("data_format");
+
+    const bool channel_last = data_format == "NHWC";
+
+    // update padding and dilation
+    auto in_dims = input->dims();
+    auto filter_dims = filter->dims();
+    auto in_dims_size = in_dims.size();
+    framework::DDim in_data_dims;
+    framework::DDim filter_data_dims;
+
+    if (channel_last) {
+      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
+    } else {
+      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+    }
+    filter_data_dims = framework::slice_ddim(filter_dims, 2, in_dims.size());
+    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
+    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                             in_data_dims, strides, ksize);
+
+    Tensor input_tensor(input->type());
+    Tensor output_tensor(output->type());
+    const std::vector<int> perm_to_nhwc = {0, 2, 3, 1};
+    if (channel_last) {
+      input_tensor.ShareDataWith(*input);
+      output_tensor.ShareDataWith(*output);
+    } else {
+      // transpose input from NCHW to NHWC
+      TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, input, &input_tensor,
+                                true /*need_reshape_or_alloc*/);
+      auto output_dims = output->dims();
+      output_tensor.mutable_data<T>(
+          {output_dims[0], output_dims[2], output_dims[3], output_dims[1]},
+          ctx.GetPlace());
+    }
+    input_tensor.set_layout(DataLayout::kNHWC);
+    output_tensor.set_layout(DataLayout::kNHWC);
+
+    // transpose filter from MCHW to MHWC
+    Tensor trans_filter(filter->type());
+    TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, filter, &trans_filter,
+                              true /*need_reshape_or_alloc*/);
+
+    cnnlTensorLayout_t data_layout = CNNL_LAYOUT_NHWC;
+    MLUCnnlTensorDesc input_desc(input_tensor, data_layout,
+                                 ToCnnlDataType(input_tensor.type()));
+    MLUCnnlTensorDesc filter_desc(trans_filter, data_layout,
+                                  ToCnnlDataType(trans_filter.type()));
+    MLUCnnlTensorDesc output_desc(output_tensor, data_layout,
+                                  ToCnnlDataType(output_tensor.type()));
+
+    MLUCnnlConvolutionDesc conv_desc(in_dims_size, paddings.data(),
+                                     strides.data(), dilations.data(), groups,
+                                     ToCnnlDataType<T>());
+
+    MLUCnnl::ConvolutionForward(
+        ctx, conv_desc.get(), nullptr /*alpha*/, nullptr /*beta*/,
+        nullptr /*bias_desc*/, nullptr /*bias_ptr*/, input_desc.get(),
+        GetBasePtr(&input_tensor), filter_desc.get(), GetBasePtr(&trans_filter),
+        output_desc.get(), GetBasePtr(&output_tensor));
+
+    if (!channel_last) {
+      // transpose output from NHWC to NCHW
+      const std::vector<int> perm_to_nchw = {0, 3, 1, 2};
+      TransposeFromMLUTensor<T>(ctx, perm_to_nchw, &output_tensor, output,
+                                false /*need_reshape_or_alloc*/);
+    }
+  }
+};
+
+template <typename T>
+class MLUConvGradOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto input = ctx.Input<Tensor>("Input");
+    auto filter = ctx.Input<Tensor>("Filter");
+    auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
+    auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
+    auto filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
+
+    const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+    int groups = ctx.Attr<int>("groups");
+    const std::string padding_algorithm =
+        ctx.Attr<std::string>("padding_algorithm");
+    const std::string data_format = ctx.Attr<std::string>("data_format");
+
+    const bool channel_last = data_format == "NHWC";
+
+    // update padding and dilation
+    auto in_dims = input->dims();
+    auto filter_dims = filter->dims();
+    auto in_dims_size = in_dims.size();
+    framework::DDim in_data_dims;
+    framework::DDim filter_data_dims;
+
+    if (channel_last) {
+      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
+    } else {
+      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+    }
+    filter_data_dims = framework::slice_ddim(filter_dims, 2, in_dims.size());
+
+    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
+    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                             in_data_dims, strides, ksize);
+
+    Tensor input_tensor(input->type());
+    Tensor output_grad_tensor(output_grad->type());
+    const std::vector<int> perm_to_nhwc = {0, 2, 3, 1};
+    const std::vector<int> perm_to_nchw = {0, 3, 1, 2};
+    if (channel_last) {
+      input_tensor.ShareDataWith(*input);
+      output_grad_tensor.ShareDataWith(*output_grad);
+    } else {
+      // transpose input and output_grad from NCHW to NHWC
+      TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, input, &input_tensor,
+                                true /*need_reshape_or_alloc*/);
+      TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, output_grad,
+                                &output_grad_tensor,
+                                true /*need_reshape_or_alloc*/);
+    }
+    input_tensor.set_layout(DataLayout::kNHWC);
+    output_grad_tensor.set_layout(DataLayout::kNHWC);
+
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(ctx.GetPlace());
+
+      auto filter_grad_dims = filter_grad->dims();
+      Tensor temp_filter_grad(filter_grad->type());
+      temp_filter_grad.mutable_data<T>(
+          {filter_grad_dims[0], filter_grad_dims[2], filter_grad_dims[3],
+           filter_grad_dims[1]},
+          ctx.GetPlace());
+
+      cnnlDataType_t tensor_dtype = ToCnnlDataType<T>();
+      cnnlTensorLayout_t data_layout = CNNL_LAYOUT_NHWC;
+      MLUCnnlTensorDesc input_desc(input_tensor, data_layout, tensor_dtype);
+      MLUCnnlTensorDesc out_grad_desc(output_grad_tensor, data_layout,
+                                      tensor_dtype);
+      MLUCnnlTensorDesc temp_filter_grad_desc(temp_filter_grad, data_layout,
+                                              tensor_dtype);
+
+      MLUCnnlConvolutionDesc conv_desc(in_dims_size, paddings.data(),
+                                       strides.data(), dilations.data(), groups,
+                                       tensor_dtype);
+
+      MLUCnnl::ConvBackpropFilter(
+          ctx, conv_desc.get(), input_desc.get(), GetBasePtr(&input_tensor),
+          out_grad_desc.get(), GetBasePtr(&output_grad_tensor),
+          temp_filter_grad_desc.get(), GetBasePtr(&temp_filter_grad));
+
+      // transpose filter_grad from MHWC to MCHW
+      TransposeFromMLUTensor<T>(ctx, perm_to_nchw, &temp_filter_grad,
+                                filter_grad, false /*need_reshape_or_alloc*/);
+    }
+    if (input_grad) {
+      input_grad->mutable_data<T>(ctx.GetPlace());
+
+      Tensor input_grad_tensor(input_grad->type());
+      if (channel_last) {
+        input_grad_tensor.ShareDataWith(*input_grad);
+      } else {
+        auto input_grad_dims = input_grad->dims();
+        input_grad_tensor.mutable_data<T>(
+            {input_grad_dims[0], input_grad_dims[2], input_grad_dims[3],
+             input_grad_dims[1]},
+            ctx.GetPlace());
+      }
+      input_grad_tensor.set_layout(DataLayout::kNHWC);
+
+      // transpose filter from MCHW to MHWC
+      Tensor trans_filter(filter->type());
+      TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, filter, &trans_filter,
+                                true /*need_reshape_or_alloc*/);
+
+      cnnlDataType_t tensor_dtype = ToCnnlDataType<T>();
+      cnnlTensorLayout_t data_layout = CNNL_LAYOUT_NHWC;
+      MLUCnnlTensorDesc filter_desc(trans_filter, data_layout, tensor_dtype);
+      MLUCnnlTensorDesc out_grad_desc(output_grad_tensor, data_layout,
+                                      tensor_dtype);
+      MLUCnnlTensorDesc in_grad_desc(input_grad_tensor, data_layout,
+                                     tensor_dtype);
+
+      MLUCnnlConvolutionDesc conv_desc(in_dims_size, paddings.data(),
+                                       strides.data(), dilations.data(), groups,
+                                       tensor_dtype);
+
+      MLUCnnl::ConvBackpropInput(
+          ctx, conv_desc.get(), filter_desc.get(), GetBasePtr(&trans_filter),
+          out_grad_desc.get(), GetBasePtr(&output_grad_tensor),
+          in_grad_desc.get(), GetBasePtr(&input_grad_tensor));
+
+      if (!channel_last) {
+        // transpose input_grad from NHWC to NCHW
+        TransposeFromMLUTensor<T>(ctx, perm_to_nchw, &input_grad_tensor,
+                                  input_grad, false /*need_reshape_or_alloc*/);
+      }
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_MLU_KERNEL(conv2d, ops::MLUConvOpKernel<float>,
+                       ops::MLUConvOpKernel<plat::float16>);
+
+REGISTER_OP_MLU_KERNEL(conv2d_grad, ops::MLUConvGradOpKernel<float>,
+                       ops::MLUConvGradOpKernel<plat::float16>);
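Reviewer note (not part of the patch): both kernels above feed cnnl NHWC tensors and transpose back for NCHW callers, relying on perm_to_nchw = {0, 3, 1, 2} being the inverse of perm_to_nhwc = {0, 2, 3, 1}. A minimal self-contained NumPy check of that assumption:

    import numpy as np

    x = np.random.rand(2, 3, 5, 5)   # NCHW, same shape as the unit tests below
    to_nhwc = (0, 2, 3, 1)           # permutation applied before calling cnnl
    to_nchw = (0, 3, 1, 2)           # permutation applied to restore the caller's layout
    assert np.array_equal(x, x.transpose(to_nhwc).transpose(to_nchw))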
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h
index 67b6b3ec1614dd51adc62cf418d9eadadf276ca9..82d7c56aea1234bec8c1d22cc10717c100fbe369 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.h
+++ b/paddle/fluid/operators/mlu/mlu_baseop.h
@@ -1137,5 +1137,28 @@ class MLUCnnl {
                          void* output);
 };
 
+template <typename T>
+inline void TransposeFromMLUTensor(const ExecutionContext& ctx,
+                                   const std::vector<int> perm,
+                                   const Tensor* transformed_input,
+                                   Tensor* transformed_output,
+                                   bool need_reshape_or_alloc) {
+  auto in_dims_vec = framework::vectorize(transformed_input->dims());
+  if (need_reshape_or_alloc) {
+    transformed_output->mutable_data<T>(
+        {in_dims_vec[perm[0]], in_dims_vec[perm[1]], in_dims_vec[perm[2]],
+         in_dims_vec[perm[3]]},
+        ctx.GetPlace());
+  }
+  MLUCnnlTensorDesc trans_in_desc(*transformed_input, CNNL_LAYOUT_ARRAY,
+                                  ToCnnlDataType<T>());
+  MLUCnnlTensorDesc trans_out_desc(*transformed_output, CNNL_LAYOUT_ARRAY,
+                                   ToCnnlDataType<T>());
+
+  MLUCnnl::Transpose(ctx, perm, in_dims_vec.size(), trans_in_desc.get(),
+                     GetBasePtr(transformed_input), trans_out_desc.get(),
+                     GetBasePtr(transformed_output));
+}
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/top_k_op_mlu.cc b/paddle/fluid/operators/top_k_op_mlu.cc
index affe5a4bc6c2dc603fe5a4cc4ef91c297ec81d59..e5064ed90d5d718c63aaf68c5630f12b7032483a 100644
--- a/paddle/fluid/operators/top_k_op_mlu.cc
+++ b/paddle/fluid/operators/top_k_op_mlu.cc
@@ -33,8 +33,7 @@ class TopkMLUKernel : public framework::OpKernel<T> {
       auto k_t_ptr = static_cast<const void*>(k_t->data<int>());
       auto size = k_t->numel() * sizeof(int);
       memory::Copy(platform::CPUPlace(), reinterpret_cast<void*>(&k),
-                   BOOST_GET_CONST(platform::MLUPlace, k_t->place()), k_t_ptr,
-                   size, nullptr);
+                   k_t->place(), k_t_ptr, size, nullptr);
       framework::DDim output_dims = output->dims();
       output_dims[output_dims.size() - 1] = k;
       output->Resize(output_dims);
diff --git a/paddle/fluid/operators/top_k_v2_op_mlu.cc b/paddle/fluid/operators/top_k_v2_op_mlu.cc
index 08c960186bafeb59ab6657e2445a3d5a9c58b6ab..cc05e11495b7bbe278cd79aa09cb35077e659d05 100644
--- a/paddle/fluid/operators/top_k_v2_op_mlu.cc
+++ b/paddle/fluid/operators/top_k_v2_op_mlu.cc
@@ -43,8 +43,7 @@ class TopkV2MLUKernel : public framework::OpKernel<T> {
       auto k_t_ptr = static_cast<const void*>(k_t->data<int>());
       auto size = k_t->numel() * sizeof(int);
       memory::Copy(platform::CPUPlace(), reinterpret_cast<void*>(&k),
-                   BOOST_GET_CONST(platform::MLUPlace, k_t->place()), k_t_ptr,
-                   size, nullptr);
+                   k_t->place(), k_t_ptr, size, nullptr);
       framework::DDim output_dims = output->dims();
       // accroding to axis to set K value in the dim
       output_dims[axis] = k;
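Reviewer note (not part of the patch): a minimal sketch of exercising the new conv2d MLU kernel end to end through the static-graph API, assuming a Paddle build with MLU support; the MLUPlace usage mirrors the unit test below, everything else is a generic static-graph example.

    import numpy as np
    import paddle
    import paddle.static as static

    paddle.enable_static()
    main_prog, startup_prog = static.Program(), static.Program()
    with static.program_guard(main_prog, startup_prog):
        x = static.data(name='x', shape=[2, 3, 5, 5], dtype='float32')
        y = paddle.static.nn.conv2d(x, num_filters=6, filter_size=3, data_format='NCHW')

    exe = static.Executor(paddle.device.MLUPlace(0))
    exe.run(startup_prog)
    out, = exe.run(main_prog,
                   feed={'x': np.random.rand(2, 3, 5, 5).astype('float32')},
                   fetch_list=[y])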
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_conv2d_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_conv2d_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..b09d892554bab6dc2951a72d72773935e5f60ddb
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_conv2d_op_mlu.py
@@ -0,0 +1,555 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+sys.path.append("..")
+import paddle
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+from op_test import OpTest
+
+from test_conv2d_op import conv2d_forward_naive
+
+paddle.enable_static()
+
+
+def create_test_channel_last_class(parent):
+    class TestChannelLastCase(parent):
+        def init_data_format(self):
+            self.data_format = "NHWC"
+
+        def init_test_case_2(self):
+            N, C, H, W = self.input_size
+            self.input_size = [N, H, W, C]
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ChannelLast")
+    TestChannelLastCase.__name__ = cls_name
+    globals()[cls_name] = TestChannelLastCase
+
+
+def create_test_padding_SAME_class(parent):
+    class TestPaddingSMAECase(parent):
+        def init_paddings(self):
+            self.pad = [0, 0]
+            self.padding_algorithm = "SAME"
+
+    cls_name = "{0}_{1}".format(parent.__name__, "PaddingSAMEOp")
+    TestPaddingSMAECase.__name__ = cls_name
+    globals()[cls_name] = TestPaddingSMAECase
+
+
+def create_test_padding_VALID_class(parent):
+    class TestPaddingVALIDCase(parent):
+        def init_paddings(self):
+            self.pad = [1, 1]
+            self.padding_algorithm = "VALID"
+
+    cls_name = "{0}_{1}".format(parent.__name__, "PaddingVALIDOp")
+    TestPaddingVALIDCase.__name__ = cls_name
+    globals()[cls_name] = TestPaddingVALIDCase
+
+
+def create_test_fp16_class(parent):
+    class TestFp16Case(parent):
+        def init_dtype(self):
+            self.dtype = np.float16
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
+    TestFp16Case.__name__ = cls_name
+    globals()[cls_name] = TestFp16Case
+
+
+class TestConv2DOp(OpTest):
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def init_data_format(self):
+        self.data_format = "NCHW"
+
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "conv2d"
+        self.init_data_format()
+        self.init_dtype()
+        self.init_group()
+        self.init_dilation()
+        self.init_test_case()
+
+        conv2d_param = {
+            'stride': self.stride,
+            'pad': self.pad,
+            'dilation': self.dilations
+        }
+
+        input = np.random.random(self.input_size).astype(self.dtype)
+        filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype)
+
+        output, _, _, _, _ = conv2d_forward_naive(
+            input,
+            filter,
+            self.groups,
+            conv2d_param,
+            data_format=self.data_format)
+        output = output.astype(self.dtype)
+
+        self.inputs = {
+            'Input': OpTest.np_dtype_to_fluid_dtype(input),
+            'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
+        }
+        self.attrs = {
+            'strides': self.stride,
+            'paddings': self.pad,
+            'groups': self.groups,
+            'dilations': self.dilations,
+            'data_format': self.data_format,
+        }
+        self.outputs = {'Output': output}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-2)
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(
+            self.place, {'Input', 'Filter'},
+            'Output',
+            max_relative_error=0.03,
+            numeric_place=paddle.CPUPlace())
+
+    def test_check_grad_no_filter(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(
+            self.place, ['Input'],
+            'Output',
+            max_relative_error=0.03,
+            no_grad_set=set(['Filter']),
+            numeric_place=paddle.CPUPlace())
+
+    def test_check_grad_no_input(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(
+            self.place, ['Filter'],
+            'Output',
+            max_relative_error=0.03,
+            no_grad_set=set(['Input']),
+            numeric_place=paddle.CPUPlace())
+
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+
+    def init_dilation(self):
+        self.dilations = [1, 1]
+
+    def init_group(self):
+        self.groups = 1
+
+
+class TestWithPad(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+
+
+class TestWithStride(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [2, 2]
+        self.input_size = [2, 3, 6, 6]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+
+
+class TestWithGroup(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.group = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [18, f_c, 3, 3]
+
+
+class TestWith1x1(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [120, f_c, 1, 1]
+
+    def init_group(self):
+        # FIXME: Supporting group = 3 in this case.
+        # NOTE(wangran16): There is an unknown error (acl error code is : 507015)
+        # when group = 3, which needs to be fixed.
+        self.groups = 1
+
+
+class TestWithDepthWise5x5(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [2, 4, 10, 10]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [8, f_c, 5, 5]
+
+    def init_group(self):
+        self.groups = 4
+
+
+class TestWithDepthWise7x7(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [2, 2]
+        self.input_size = [2, 8, 10, 10]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [16, f_c, 7, 7]
+
+    def init_group(self):
+        self.groups = 8
+
+
+class TestWithDilation(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 10, 10]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [12, f_c, 3, 3]
+
+    def init_dilation(self):
+        self.dilations = [2, 2]
+
+    # TODO(MLU): Depthwise operation does not support dilation yet;
+    # it will throw an error of CNNL_STATUS_NOT_SUPPORTED.
+    # def init_group(self):
+    #     self.groups = 3
+
+
+class TestWithInput1x1Filter1x1(TestConv2DOp):
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [100, 1, 1, 1]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [120, f_c, 1, 1]
+
+    def init_group(self):
+        self.groups = 1
+
+
+class TestConv2DOp_v2(OpTest):
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "conv2d"
+        self.dtype = np.float32
+        self.init_kernel_type()
+        self.init_group()
+        self.init_dilation()
+        self.init_data_format()
+        self.init_test_case()
+        self.init_paddings()
+        self.init_test_case_2()
+
+        conv2d_param = {
+            'stride': self.stride,
+            'pad': self.pad,
+            'dilation': self.dilations
+        }
+
+        input = np.random.random(self.input_size).astype(self.dtype)
+        filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype)
+        output, _, _, _, _ = conv2d_forward_naive(
+            input, filter, self.groups, conv2d_param, self.padding_algorithm,
+            self.data_format)
+        output = output.astype(self.dtype)
+
+        self.inputs = {
+            'Input': OpTest.np_dtype_to_fluid_dtype(input),
+            'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
+        }
+        self.attrs = {
+            'strides': self.stride,
+            'paddings': self.pad,
+            'padding_algorithm': self.padding_algorithm,
+            'groups': self.groups,
+            'dilations': self.dilations,
+            'data_format': self.data_format,
+        }
+        self.outputs = {'Output': output}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-2)
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(
+            self.place, {'Input', 'Filter'},
+            'Output',
+            max_relative_error=0.02,
+            numeric_place=paddle.CPUPlace())
+
+    def test_check_grad_no_filter(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(
+            self.place, ['Input'],
+            'Output',
+            max_relative_error=0.02,
+            no_grad_set=set(['Filter']),
+            numeric_place=paddle.CPUPlace())
+
+    def test_check_grad_no_input(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(
+            self.place, ['Filter'],
+            'Output',
+            no_grad_set=set(['Input']),
+            numeric_place=paddle.CPUPlace())
+
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 4, 3]
+
+    def init_dilation(self):
+        self.dilations = [1, 1]
+
+    def init_group(self):
+        self.groups = 1
+
+    def init_kernel_type(self):
+        pass
+
+    def init_paddings(self):
+        self.pad = [0, 0]
+        self.padding_algorithm = "EXPLICIT"
+
+    def init_data_format(self):
+        self.data_format = "NCHW"
+
+    def init_test_case_2(self):
+        pass
+
+
+class TestConv2DOp_AsyPadding(TestConv2DOp_v2):
+    def init_paddings(self):
+        self.pad = [0, 0, 1, 2]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithPad_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+
+    def init_paddings(self):
+        self.pad = [2, 1, 3, 2]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithStride_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [2, 2]
+        self.input_size = [2, 3, 6, 6]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+
+    def init_paddings(self):
+        self.pad = [2, 1, 3, 2]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithGroup_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.group = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [24, f_c, 4, 3]
+
+
+class TestWith1x1_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [120, f_c, 1, 1]
+
+    def init_group(self):
+        self.groups = 1
+
+    def init_paddings(self):
+        self.pad = [2, 2, 4, 0]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithDepthWise3x3_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.input_size = [3, 4, 10, 10]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [16, f_c, 3, 3]
+
+    # TODO(MLU): Depthwise operation does not support dilation yet;
+    # it will throw an error of CNNL_STATUS_NOT_SUPPORTED.
+    # def init_dilation(self):
+    #     self.dilations = [2, 2]
+
+    def init_group(self):
+        self.groups = 4
+
+    def init_paddings(self):
+        self.pad = [1, 3, 2, 1]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithDepthWise5x5_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.input_size = [2, 4, 10, 10]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [8, f_c, 5, 5]
+
+    def init_group(self):
+        self.groups = 4
+
+    def init_paddings(self):
+        self.pad = [0, 1, 1, 0]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithDepthWise7x7_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [2, 2]
+        self.input_size = [2, 8, 10, 10]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [16, f_c, 7, 7]
+
+    def init_group(self):
+        self.groups = 8
+
+    def init_paddings(self):
+        self.pad = [1, 3, 4, 1]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithDilation_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 10, 10]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [24, f_c, 3, 3]
+
+    def init_dilation(self):
+        self.dilations = [2, 2]
+
+    # TODO(MLU): Depthwise operation does not support dilation yet;
+    # it will throw an error of CNNL_STATUS_NOT_SUPPORTED.
+    # def init_group(self):
+    #     self.groups = 3
+
+    def init_paddings(self):
+        self.pad = [0, 1, 3, 0]
+        self.padding_algorithm = "EXPLICIT"
+
+
+class TestWithInput1x1Filter1x1_AsyPadding(TestConv2DOp_v2):
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.input_size = [100, 1, 1, 1]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [120, f_c, 1, 1]
+
+    def init_group(self):
+        self.groups = 1
+
+    def init_paddings(self):
+        self.pad = [0, 3, 4, 0]
+        self.padding_algorithm = "EXPLICIT"
+
+
+create_test_padding_SAME_class(TestConv2DOp_AsyPadding)
+create_test_padding_SAME_class(TestWithPad_AsyPadding)
+create_test_padding_SAME_class(TestWithStride_AsyPadding)
+create_test_padding_SAME_class(TestWithGroup_AsyPadding)
+create_test_padding_SAME_class(TestWithInput1x1Filter1x1_AsyPadding)
+
+create_test_padding_VALID_class(TestConv2DOp_AsyPadding)
+create_test_padding_VALID_class(TestWithPad_AsyPadding)
+create_test_padding_VALID_class(TestWithStride_AsyPadding)
+create_test_padding_VALID_class(TestWithGroup_AsyPadding)
+create_test_padding_VALID_class(TestWithInput1x1Filter1x1_AsyPadding)
+
+create_test_channel_last_class(TestConv2DOp_AsyPadding)
+create_test_channel_last_class(TestWithPad_AsyPadding)
+create_test_channel_last_class(TestWithGroup_AsyPadding)
+create_test_channel_last_class(TestWith1x1_AsyPadding)
+create_test_channel_last_class(TestWithInput1x1Filter1x1_AsyPadding)
+
+create_test_fp16_class(TestConv2DOp_AsyPadding)
+create_test_fp16_class(TestWithPad_AsyPadding)
+create_test_fp16_class(TestWithStride_AsyPadding)
+create_test_fp16_class(TestWithGroup_AsyPadding)
+create_test_fp16_class(TestWithInput1x1Filter1x1_AsyPadding)
+
+if __name__ == "__main__":
+    unittest.main()
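Reviewer note (not part of the patch): the create_test_* factories above register dynamically named classes in globals(), e.g. TestConv2DOp_AsyPadding_PaddingSAMEOp. A minimal sketch of loading one of those generated cases directly, assuming an MLU-enabled build and that this test's directory is on sys.path:

    import unittest

    # the class name is parent.__name__ plus the factory suffix
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "test_conv2d_op_mlu.TestConv2DOp_AsyPadding_PaddingSAMEOp")
    unittest.TextTestRunner(verbosity=2).run(suite)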