diff --git a/paddle/fluid/operators/conv_transpose_op_mlu.cc b/paddle/fluid/operators/conv_transpose_op_mlu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..160c16c3de995b8b16a8640c9bd06e539845729e
--- /dev/null
+++ b/paddle/fluid/operators/conv_transpose_op_mlu.cc
@@ -0,0 +1,266 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/conv_transpose_op.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+#include "paddle/phi/kernels/cpu/conv_util.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using DataLayout = framework::DataLayout;
+
+template <typename T>
+class Conv2DTransposeMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const Tensor* input = ctx.Input<Tensor>("Input");
+    const Tensor* filter = ctx.Input<Tensor>("Filter");
+    Tensor* output = ctx.Output<Tensor>("Output");
+    output->mutable_data<T>(ctx.GetPlace());
+    std::vector<int> output_padding =
+        ctx.Attr<std::vector<int>>("output_padding");
+    const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+    const std::string data_format = ctx.Attr<std::string>("data_format");
+    int groups = ctx.Attr<int>("groups");
+    const std::string padding_algorithm =
+        ctx.Attr<std::string>("padding_algorithm");
+
+    // check dimension
+    const bool channel_last = data_format == "NHWC";
+
+    auto in_dims = input->dims();
+    auto filter_dims = filter->dims();
+    auto in_dims_size = in_dims.size();
+    framework::DDim in_data_dims;
+    framework::DDim filter_data_dims;
+
+    if (channel_last) {
+      in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
+    } else {
+      in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
+    }
+    filter_data_dims = phi::slice_ddim(filter_dims, 2, filter_dims.size());
+
+    std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
+    phi::UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                                  in_data_dims, strides, ksize);
+
+    Tensor input_tensor(input->type());
+    Tensor output_tensor(output->type());
+    input_tensor.set_layout(DataLayout::kNHWC);
+    output_tensor.set_layout(DataLayout::kNHWC);
+    const std::vector<int> perm_to_nhwc = {0, 2, 3, 1};
+
+    if (channel_last) {
+      input_tensor.ShareDataWith(*input);
+      output_tensor.ShareDataWith(*output);
+    } else {
+      // transpose input from NCHW to NHWC
+      TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, input, &input_tensor,
+                                true /*need_reshape_or_alloc*/);
+      auto output_dims = output->dims();
+      output_tensor.mutable_data<T>(
+          {output_dims[0], output_dims[2], output_dims[3], output_dims[1]},
+          ctx.GetPlace());
+    }
+
+    // transpose filter from MCHW to MHWC
+    Tensor trans_filter(filter->type());
+    TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, filter, &trans_filter,
+                              true /*need_reshape_or_alloc*/);
+
+    // construct MLU attr
+    cnnlTensorLayout_t data_layout = CNNL_LAYOUT_NHWC;
+    MLUCnnlTensorDesc input_desc(input_tensor, data_layout,
+                                 ToCnnlDataType(input_tensor.dtype()));
+    MLUCnnlTensorDesc filter_desc(trans_filter, data_layout,
+                                  ToCnnlDataType(trans_filter.type()));
+    MLUCnnlTensorDesc output_desc(output_tensor, data_layout,
+                                  ToCnnlDataType(output_tensor.dtype()));
+    MLUCnnlConvolutionDesc conv_desc(in_dims_size, paddings.data(),
+                                     strides.data(), dilations.data(), groups,
+                                     ToCnnlDataType<T>());
+
+    // conv2d_transpose is the data-gradient of a regular conv2d, so the
+    // forward pass is computed with CNNL's ConvBackpropInput.
+    MLUCnnl::ConvBackpropInput(ctx, conv_desc.get(), filter_desc.get(),
+                               GetBasePtr(&trans_filter), input_desc.get(),
+                               GetBasePtr(&input_tensor), output_desc.get(),
+                               GetBasePtr(&output_tensor));
+
+    if (!channel_last) {
+      // transpose output from NHWC to NCHW
+      const std::vector<int> perm_to_nchw = {0, 3, 1, 2};
+      TransposeFromMLUTensor<T>(ctx, perm_to_nchw, &output_tensor, output,
+                                false /*need_reshape_or_alloc*/);
+    }
+  }
+};
+
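+// The grad kernel below follows the same layout strategy as the forward
+// kernel: every CNNL call runs in NHWC, with NCHW inputs/outputs (and the
+// MCHW filter) transposed on the fly.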
+template <typename T>
+class Conv2DTransposeGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const Tensor* input = ctx.Input<Tensor>("Input");
+    const Tensor* filter = ctx.Input<Tensor>("Filter");
+    const Tensor* output_grad =
+        ctx.Input<Tensor>(framework::GradVarName("Output"));
+    Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
+    Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
+
+    if ((!input_grad) && (!filter_grad)) return;
+
+    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+    const int groups = ctx.Attr<int>("groups");
+    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
+    const std::string data_format = ctx.Attr<std::string>("data_format");
+    const framework::DataLayout data_layout =
+        framework::StringToDataLayout(data_format);
+
+    auto in_dims = input->dims();
+    auto filter_dims = filter->dims();
+    auto in_dims_size = in_dims.size();
+
+    const bool channel_last = (data_layout == framework::DataLayout::kNHWC);
+
+    framework::DDim in_data_dims;
+    if (channel_last) {
+      in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
+    } else {
+      in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
+    }
+    framework::DDim filter_data_dims =
+        phi::slice_ddim(filter_dims, 2, filter_dims.size());
+    std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
+    phi::UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                                  in_data_dims, strides, ksize);
+
+    Tensor input_tensor(input->type());
+    Tensor output_grad_tensor(output_grad->type());
+    output_grad_tensor.set_layout(DataLayout::kNHWC);
+
+    const std::vector<int> perm_to_nhwc = {0, 2, 3, 1};
+    if (channel_last) {
+      input_tensor.ShareDataWith(*input);
+      output_grad_tensor.ShareDataWith(*output_grad);
+    } else {
+      // transpose input and output_grad from NCHW to NHWC
+      TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, input, &input_tensor,
+                                true /*need_reshape_or_alloc*/);
+      TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, output_grad,
+                                &output_grad_tensor,
+                                true /*need_reshape_or_alloc*/);
+    }
+
+    // transpose filter from MCHW to MHWC
+    Tensor trans_filter(filter->type());
+    TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, filter, &trans_filter,
+                              true /*need_reshape_or_alloc*/);
+
+    // MLU descs
+    cnnlTensorLayout_t data_layout_mlu = CNNL_LAYOUT_NHWC;
+    MLUCnnlTensorDesc input_desc(input_tensor, data_layout_mlu,
+                                 ToCnnlDataType(input_tensor.dtype()));
+    MLUCnnlTensorDesc trans_filter_desc(trans_filter, data_layout_mlu,
+                                        ToCnnlDataType(trans_filter.type()));
+    MLUCnnlTensorDesc output_grad_desc(
+        output_grad_tensor, data_layout_mlu,
+        ToCnnlDataType(output_grad_tensor.dtype()));
+    MLUCnnlConvolutionDesc conv_desc(in_dims_size, paddings.data(),
+                                     strides.data(), dilations.data(), groups,
+                                     ToCnnlDataType<T>());
+
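+    // dFilter is a regular ConvBackpropFilter over (input, output_grad);
+    // dInput is a plain forward convolution of output_grad with the filter,
+    // since conv2d_transpose is itself the data-gradient of a forward conv.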
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(ctx.GetPlace());
+      Tensor filter_grad_tensor(filter_grad->type());
+      // filter_grad is always MCHW; filter_grad_tensor is its MHWC staging
+      // buffer
+      auto filter_grad_dims = filter_grad->dims();
+      filter_grad_tensor.mutable_data<T>(
+          {filter_grad_dims[0], filter_grad_dims[2], filter_grad_dims[3],
+           filter_grad_dims[1]},
+          ctx.GetPlace());
+      filter_grad_tensor.set_layout(DataLayout::kNHWC);
+
+      MLUCnnlTensorDesc filter_grad_desc(
+          filter_grad_tensor, data_layout_mlu,
+          ToCnnlDataType(filter_grad_tensor.dtype()));
+
+      MLUCnnl::ConvBackpropFilter(
+          ctx, conv_desc.get(), output_grad_desc.get(),
+          GetBasePtr(&output_grad_tensor), input_desc.get(),
+          GetBasePtr(&input_tensor), filter_grad_desc.get(),
+          GetBasePtr(&filter_grad_tensor));
+      // transpose filter_grad from MHWC back to MCHW
+      const std::vector<int> perm_to_mchw = {0, 3, 1, 2};
+      TransposeFromMLUTensor<T>(ctx, perm_to_mchw, &filter_grad_tensor,
+                                filter_grad, false /*need_reshape_or_alloc*/);
+    }
+
+    if (input_grad) {
+      input_grad->mutable_data<T>(ctx.GetPlace());
+      Tensor input_grad_tensor(input_grad->type());
+      input_grad_tensor.set_layout(DataLayout::kNHWC);
+
+      if (channel_last) {
+        input_grad_tensor.ShareDataWith(*input_grad);
+      } else {
+        auto input_grad_dims = input_grad->dims();
+        input_grad_tensor.mutable_data<T>(
+            {input_grad_dims[0], input_grad_dims[2], input_grad_dims[3],
+             input_grad_dims[1]},
+            ctx.GetPlace());
+      }
+
+      MLUCnnlTensorDesc input_grad_desc(
+          input_grad_tensor, data_layout_mlu,
+          ToCnnlDataType(input_grad_tensor.dtype()));
+
+      cnnlDataType_t tensor_dtype = ToCnnlDataType<T>();
+      cnnlDataType_t dt_onchip = ToCnnlDataType<T>();
+      MLUCnnl::Conv2D(ctx, conv_desc.get(), tensor_dtype, dt_onchip,
+                      nullptr /* input_position */, nullptr /* input_scale */,
+                      nullptr /* input_offset */, nullptr /* filter_position */,
+                      nullptr /* filter_scale */, nullptr /* filter_offset */,
+                      output_grad_desc.get(), GetBasePtr(&output_grad_tensor),
+                      trans_filter_desc.get(), GetBasePtr(&trans_filter),
+                      nullptr /* bias_desc */, nullptr /* bias */,
+                      input_grad_desc.get(), GetBasePtr(&input_grad_tensor));
+      if (!channel_last) {
+        // transpose input_grad from NHWC to NCHW
+        const std::vector<int> perm_to_nchw = {0, 3, 1, 2};
+        TransposeFromMLUTensor<T>(ctx, perm_to_nchw, &input_grad_tensor,
+                                  input_grad, false /*need_reshape_or_alloc*/);
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_MLU_KERNEL(conv2d_transpose, ops::Conv2DTransposeMLUKernel<float>,
+                       ops::Conv2DTransposeMLUKernel<plat::float16>);
+
+REGISTER_OP_MLU_KERNEL(conv2d_transpose_grad,
+                       ops::Conv2DTransposeGradMLUKernel<float>,
+                       ops::Conv2DTransposeGradMLUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h
index ebb8aae1eb329e72231387e5c28710d6260fa3b1..1763fc56cebf982e9dd2d22203d50edbfc882121 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.h
+++ b/paddle/fluid/operators/mlu/mlu_baseop.h
@@ -1159,7 +1159,7 @@ class MLUCnnl {
   static void ConvBackpropInput(
       const ExecutionContext& ctx, const cnnlConvolutionDescriptor_t conv_desc,
-      const cnnlTensorDescriptor_t input_desc, const void* filter,
+      const cnnlTensorDescriptor_t filter_desc, const void* filter,
       const cnnlTensorDescriptor_t out_backprop_desc, const void* out_backprop,
       const cnnlTensorDescriptor_t in_backprop_desc, void* in_backprop);
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_conv2d_transposed_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_conv2d_transposed_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..08485978a5f644fff1509319f90616ccd6dca1db
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_conv2d_transposed_op_mlu.py
@@ -0,0 +1,661 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+
+import paddle
+import paddle.nn as nn
+
+paddle.enable_static()
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+from paddle.fluid.tests.unittests.op_test import OpTest
+
+
+def conv2dtranspose_forward_naive(input_, filter_, attrs):
+    padding_algorithm = attrs['padding_algorithm']
+    if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]:
+        raise ValueError("Unknown Attr(padding_algorithm): '%s'. "
+                         "It can only be 'SAME', 'VALID' or 'EXPLICIT'." %
+                         str(padding_algorithm))
+
+    if attrs['data_format'] == 'NHWC':
+        input_ = np.transpose(input_, [0, 3, 1, 2])
+    in_n, in_c, in_h, in_w = input_.shape
+    f_c, f_out_c, f_h, f_w = filter_.shape
+    groups = attrs['groups']
+    assert in_c == f_c
+    out_c = f_out_c * groups
+    sub_in_c = in_c // groups
+
+    stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[
+        'dilations']
+
+    # update pad and dilation
+    def _get_padding_with_SAME(input_shape, kernel_size, kernel_stride):
+        padding = []
+        for input_size, filter_size, stride_size in zip(input_shape,
+                                                        kernel_size,
+                                                        kernel_stride):
+            out_size = int((input_size + stride_size - 1) / stride_size)
+            pad_sum = np.max(
+                ((out_size - 1) * stride_size + filter_size - input_size, 0))
+            pad_0 = int(pad_sum / 2)
+            pad_1 = int(pad_sum - pad_0)
+            padding.append(pad_0)
+            padding.append(pad_1)
+        return padding
+
+    ksize = filter_.shape[2:4]
+    if padding_algorithm == "VALID":
+        pad = [0, 0, 0, 0]
+    elif padding_algorithm == "SAME":
+        dilations = [1, 1]
+        input_data_shape = input_.shape[2:4]
+        pad = _get_padding_with_SAME(input_data_shape, ksize, stride)
+
+    pad_h_0, pad_h_1 = pad[0], pad[0]
+    pad_w_0, pad_w_1 = pad[1], pad[1]
+    if len(pad) == 4:
+        pad_h_0, pad_h_1 = pad[0], pad[1]
+        pad_w_0, pad_w_1 = pad[2], pad[3]
+
+    d_block_h = dilations[0] * (f_h - 1) + 1
+    d_block_w = dilations[1] * (f_w - 1) + 1
+    out_h = (in_h - 1) * stride[0] + d_block_h
+    out_w = (in_w - 1) * stride[1] + d_block_w
+    if 'output_size' in attrs:
+        output_size = attrs['output_size']
+        out_h = output_size[0] + pad_h_0 + pad_h_1
+        out_w = output_size[1] + pad_w_0 + pad_w_1
+    out_pad_h = 0
+    out_pad_w = 0
+    if 'output_padding' in attrs:
+        out_pad_h = attrs['output_padding'][0]
+        out_pad_w = attrs['output_padding'][1]
+    out = np.zeros((in_n, out_c, out_h + out_pad_h, out_w + out_pad_w),
+                   dtype=input_.dtype)
+
+    for n in range(in_n):
+        for i in range(in_h):
+            for j in range(in_w):
+                for g in range(groups):
+                    input_masked = input_[n, g * sub_in_c:(g + 1) * sub_in_c, i,
+                                          j]  # (c)
+                    input_masked = np.reshape(input_masked, (sub_in_c, 1, 1))
+                    input_masked = np.tile(input_masked, (1, f_h, f_w))
+
+                    for k in range(f_out_c):
+                        tmp_out = np.sum(
+                            input_masked *
+                            filter_[g * sub_in_c:(g + 1) * sub_in_c, k, :, :],
+                            axis=0)
+                        i1, i2 = i * stride[0], i * stride[0] + d_block_h
+                        j1, j2 = j * stride[1], j * stride[1] + d_block_w
+                        out[n, g * f_out_c + k, i1:i2:dilations[0],
+                            j1:j2:dilations[1]] += tmp_out
+
+    out = out[:, :, pad_h_0:out_h - pad_h_1 + out_pad_h,
+              pad_w_0:out_w - pad_w_1 + out_pad_w]
+    if attrs['data_format'] == 'NHWC':
+        out = np.transpose(out, [0, 2, 3, 1])
+    return out
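+
+
+# Shape bookkeeping used by the reference above, per spatial dim:
+#   dilated_k = dilation * (k - 1) + 1
+#   out = (in - 1) * stride + dilated_k - pad_begin - pad_end + output_padding
+# i.e. the inverse of the forward-conv output-size formula.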
+
+
+class TestConv2DTransposeOp(OpTest):
+
+    def setUp(self):
+        # init as conv transpose
+        self.dtype = np.float32
+        self.set_mlu()
+        self.need_check_grad = True
+        self.is_test = False
+        self.use_cudnn = False
+        self.use_mkldnn = False
+        self.output_size = None
+        self.output_padding = []
+        self.data_format = "NCHW"
+        self.pad = [0, 0]
+        self.padding_algorithm = "EXPLICIT"
+        self.init_op_type()
+        self.init_test_case()
+
+        input_ = np.random.random(self.input_size).astype(self.dtype)
+        filter_ = np.random.random(self.filter_size).astype(self.dtype)
+
+        self.inputs = {'Input': input_, 'Filter': filter_}
+        self.attrs = {
+            'strides': self.stride,
+            'paddings': self.pad,
+            'padding_algorithm': self.padding_algorithm,
+            'groups': self.groups,
+            'dilations': self.dilations,
+            'use_cudnn': self.use_cudnn,
+            'is_test': self.is_test,
+            'use_mkldnn': self.use_mkldnn,
+            'data_format': self.data_format
+        }
+        if self.output_size is not None:
+            self.attrs['output_size'] = self.output_size
+
+        if len(self.output_padding) > 0:
+            self.attrs['output_padding'] = self.output_padding
+
+        output = conv2dtranspose_forward_naive(input_, filter_,
+                                               self.attrs).astype(self.dtype)
+
+        self.outputs = {'Output': output}
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad_no_input(self):
+        if self.need_check_grad:
+            self.check_grad_with_place(self.place, ['Filter'],
+                                       'Output',
+                                       max_relative_error=0.02,
+                                       no_grad_set=set(['Input']))
+
+    def test_check_grad_no_filter(self):
+        if self.need_check_grad:
+            self.check_grad_with_place(self.place, ['Input'],
+                                       'Output',
+                                       no_grad_set=set(['Filter']))
+
+    def test_check_grad(self):
+        if self.need_check_grad:
+            self.check_grad_with_place(self.place,
+                                       set(['Input', 'Filter']),
+                                       'Output',
+                                       max_relative_error=0.02)
+
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+
+    def init_op_type(self):
+        self.op_type = "conv2d_transpose"
+
+
+class TestWithSymmetricPad(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+
+
+class TestWithAsymmetricPad(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.pad = [1, 0, 1, 2]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+
+
+class TestWithSAMEPad(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.stride = [2, 1]
+        self.dilations = [1, 2]
+        self.groups = 1
+        self.input_size = [2, 3, 6, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 4, 3]
+        self.padding_algorithm = 'SAME'
+
+
+class TestWithVALIDPad(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+        self.padding_algorithm = 'VALID'
+
+
+class TestWithGroups(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 2
+        self.input_size = [2, 4, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 3, 3, 3]
+
+
+class TestWithStride(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [2, 2]
+        self.dilations = [1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+
+
+class TestWithDilation(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.groups = 1
+        self.dilations = [2, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+
+
+class TestWithEvenUpsample(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.pad = [2, 2]
+        self.stride = [2, 2]
+        self.groups = 1
+        self.dilations = [1, 1]
+        self.output_size = [14, 14]
+        self.input_size = [2, 3, 7, 7]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 5, 5]
+
+
+class TestWithEvenUpsampleOutputPadding(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.pad = [2, 2]
+        self.stride = [2, 2]
+        self.groups = 1
+        self.dilations = [1, 1]
+        self.output_padding = [1, 1]
+        self.input_size = [2, 3, 7, 7]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 5, 5]
+
+
+class Test_NHWC(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 1
+        self.input_size = [2, 5, 5, 3]  # NHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 3, 3]
+        self.data_format = 'NHWC'
+
+
+class TestWithSymmetricPad_NHWC(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 1
+        self.input_size = [2, 5, 5, 3]  # NHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 3, 3]
+        self.data_format = 'NHWC'
+
+
+class TestWithAsymmetricPad_NHWC(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.pad = [1, 0, 1, 2]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 1
+        self.input_size = [2, 5, 5, 3]  # NHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 3, 3]
+        self.data_format = 'NHWC'
+
+
+class TestWithGroups_NHWC(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 2
+        self.input_size = [2, 5, 5, 4]  # NHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 3, 3, 3]
+        self.data_format = 'NHWC'
+
+
+class TestWithStride_NHWC(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [2, 2]
+        self.dilations = [1, 1]
+        self.groups = 1
+        self.input_size = [2, 5, 5, 3]  # NHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 3, 3]
+        self.data_format = 'NHWC'
+
+
+class TestWithDilation_NHWC(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.groups = 1
+        self.dilations = [2, 2]
+        self.input_size = [2, 5, 5, 3]  # NHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 3, 3]
+        self.data_format = 'NHWC'
+
+
+class TestWithEvenUpsample_NHWC(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.pad = [2, 2]
+        self.stride = [2, 2]
+        self.groups = 1
+        self.dilations = [1, 1]
+        self.output_size = [14, 14]
+        self.input_size = [2, 7, 7, 3]  # NHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 5, 5]
+        self.data_format = 'NHWC'
+
+
+class TestWithEvenUpsample_NHWC_output_padding(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.pad = [2, 2]
+        self.stride = [2, 2]
+        self.groups = 1
+        self.dilations = [1, 1]
+        self.output_padding = [1, 1]
+        self.input_size = [2, 7, 7, 3]  # NHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 5, 5]
+        self.data_format = 'NHWC'
+
+
+class TestMLU_FP16(TestConv2DTransposeOp):
+
+    def init_test_case(self):
+        self.dtype = np.float16
+        self.set_mlu()
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.groups = 1
+        self.dilations = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def init_op_type(self):
+        self.need_check_grad = False
+        self.op_type = "conv2d_transpose"
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-2)
+
+
+class TestMLU_NHWC_FP16(TestMLU_FP16):
+
+    def init_test_case(self):
+        self.dtype = np.float16
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 1
+        self.input_size = [2, 5, 5, 3]  # NHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 3, 3]
+        self.data_format = 'NHWC'
+
+
+class TestMLUWithGroups_NHWC_FP16(TestMLU_FP16):
+
+    def init_test_case(self):
+        self.dtype = np.float16
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 2
+        self.input_size = [2, 5, 5, 4]  # NHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 3, 3, 3]
+        self.data_format = 'NHWC'
+
+
+class TestMLUWithEvenUpsample_NHWC_FP16(TestMLU_FP16):
+
+    def init_test_case(self):
+        self.dtype = np.float16
+        self.pad = [2, 2]
+        self.stride = [2, 2]
+        self.groups = 1
+        self.dilations = [1, 1]
+        self.output_size = [14, 14]
+        self.input_size = [2, 7, 7, 3]  # NHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 5, 5]
+        self.data_format = 'NHWC'
+
+
+class TestConv2DTransposeAPI(unittest.TestCase):
+
+    def setUp(self):
+        self.set_mlu()
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def test_case1(self):
+        data1 = fluid.layers.data(name='data1',
+                                  shape=[3, 5, 5],
+                                  dtype='float32')
+        data2 = fluid.layers.data(name='data2',
+                                  shape=[5, 5, 3],
+                                  dtype='float32')
+        out1 = fluid.layers.conv2d_transpose(input=data1,
+                                             groups=1,
+                                             num_filters=6,
+                                             filter_size=3,
+                                             data_format='NCHW')
+        out2 = fluid.layers.conv2d_transpose(input=data2,
+                                             groups=1,
+                                             num_filters=6,
+                                             filter_size=3,
+                                             data_format='NHWC')
+        out3 = fluid.layers.conv2d_transpose(input=data1,
+                                             groups=1,
+                                             num_filters=6,
+                                             filter_size=3,
+                                             padding=[[0, 0], [1, 1], [1, 1],
+                                                      [0, 0]],
+                                             data_format='NHWC')
+        out4 = fluid.layers.conv2d_transpose(input=data1,
+                                             groups=3,
+                                             num_filters=6,
+                                             filter_size=3,
+                                             padding=[[0, 0], [0, 0], [2, 1],
+                                                      [0, 0]],
+                                             data_format='NCHW')
+        out5 = fluid.layers.conv2d_transpose(input=data2,
+                                             groups=1,
+                                             num_filters=6,
+                                             filter_size=3,
+                                             padding='SAME',
+                                             data_format='NCHW')
+        out6 = fluid.layers.conv2d_transpose(input=data1,
+                                             groups=1,
+                                             num_filters=6,
+                                             filter_size=3,
+                                             padding='VALID',
+                                             data_format='NHWC')
+        out7 = fluid.layers.conv2d_transpose(input=data1,
+                                             groups=1,
+                                             num_filters=6,
+                                             output_size=[7, 7],
+                                             padding=[0, 0],
+                                             data_format='NHWC')
+
+        data1_np = np.random.random((2, 3, 5, 5)).astype("float32")
+        data2_np = np.random.random((2, 5, 5, 3)).astype("float32")
+
+        exe = fluid.Executor(self.place)
+        exe.run(fluid.default_startup_program())
+        results = exe.run(fluid.default_main_program(),
+                          feed={
+                              "data1": data1_np,
+                              "data2": data2_np
+                          },
+                          fetch_list=[out1, out2, out3, out4, out5, out6, out7],
+                          return_numpy=True)
+        self.assertIsNotNone(results[0])
+        self.assertIsNotNone(results[1])
+        self.assertIsNotNone(results[2])
+        self.assertIsNotNone(results[3])
+        self.assertIsNotNone(results[4])
+        self.assertIsNotNone(results[5])
+        self.assertIsNotNone(results[6])
+
+
+class TestConv2DTransposeOpException(unittest.TestCase):
+
+    def setUp(self):
+        self.set_mlu()
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def test_exception(self):
+        data = fluid.layers.data(name='data', shape=[3, 5, 5], dtype="float32")
+
+        def attr_data_format():
+            out = fluid.layers.conv2d_transpose(input=data,
+                                                groups=1,
+                                                num_filters=6,
+                                                filter_size=3,
+                                                data_format="NCDHW")
+
+        self.assertRaises(ValueError, attr_data_format)
+
+        def attr_padding_str():
+            out = fluid.layers.conv2d_transpose(input=data,
+                                                groups=1,
+                                                num_filters=6,
+                                                filter_size=3,
+                                                padding='Vald')
+
+        self.assertRaises(ValueError, attr_padding_str)
+
+        def attr_padding_list():
+            out = fluid.layers.conv2d_transpose(input=data,
+                                                groups=1,
+                                                num_filters=6,
+                                                filter_size=3,
+                                                padding=[[1, 1], [1, 1], [0, 0],
+                                                         [0, 0]])
+
+        self.assertRaises(ValueError, attr_padding_list)
+
+        def attr_padding_with_data_format():
+            out = fluid.layers.conv2d_transpose(input=data,
+                                                groups=1,
+                                                num_filters=6,
+                                                filter_size=3,
+                                                padding=[[1, 1], [0, 0], [0, 0],
+                                                         [1, 1]],
+                                                data_format='NHWC')
+
+        self.assertRaises(ValueError, attr_padding_with_data_format)
+
+        error_input = fluid.layers.data(name='error_data',
+                                        shape=[1],
+                                        dtype="float32")
+
+        def error_input_size():
+            out = fluid.layers.conv2d_transpose(input=error_input,
+                                                groups=1,
+                                                num_filters=6,
+                                                filter_size=3)
+
+        self.assertRaises(ValueError, error_input_size)
+
+        def error_groups():
+            out = fluid.layers.conv2d_transpose(input=data,
+                                                groups=0,
+                                                num_filters=6,
+                                                filter_size=3,
+                                                data_format='NHWC')
+
+        self.assertRaises(ValueError, error_groups)
+
+
+class TestConv2DTransposeRepr(unittest.TestCase):
+
+    def setUp(self):
+        self.set_mlu()
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def test_case(self):
+        paddle.disable_static()
+        x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
+        conv = nn.Conv2DTranspose(4, 6, (3, 3), output_padding=1, stride=2)
+        print(conv)
+        y_var = conv(x_var)
+        y_np = y_var.numpy()
+        self.assertIsNotNone(y_np)
+        paddle.enable_static()
+
+
+if __name__ == '__main__':
+    unittest.main()