From c7de7440fea357489947abfbc013c3ae8aacada7 Mon Sep 17 00:00:00 2001
From: zhangyikun02 <48021248+zhangyk0314@users.noreply.github.com>
Date: Wed, 19 Jan 2022 10:02:07 +0800
Subject: [PATCH] Add conv2d_transpose and conv2d_transpose_grad for XPU,test=kunlun (#38956)

---
 .../fluid/operators/conv_transpose_op_xpu.cc  | 175 +++++++++++
 .../fluid/platform/device/xpu/xpu2_op_list.h  |   5 +-
 .../xpu/test_conv2d_transpose_op_xpu.py       | 271 ++++++++++++++++++
 3 files changed, 450 insertions(+), 1 deletion(-)
 create mode 100644 paddle/fluid/operators/conv_transpose_op_xpu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_conv2d_transpose_op_xpu.py

diff --git a/paddle/fluid/operators/conv_transpose_op_xpu.cc b/paddle/fluid/operators/conv_transpose_op_xpu.cc
new file mode 100644
index 0000000000..7cefb9298e
--- /dev/null
+++ b/paddle/fluid/operators/conv_transpose_op_xpu.cc
@@ -0,0 +1,175 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/conv_transpose_op.h"
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/platform/device/device_wrapper.h"
+
+#ifdef PADDLE_WITH_XPU
+namespace paddle {
+namespace operators {
+
+// target_len == 2 || target_len == 4
+inline std::vector<int> vector_extend(const std::vector<int>& src,
+                                      int target_len) {
+  if (target_len == 2 && src.size() == 1) {
+    return {src[0], src[0]};
+  }
+  if (target_len == 4 && src.size() == 1) {
+    return {src[0], src[0], src[0], src[0]};
+  }
+  if (target_len == 4 && src.size() == 2) {
+    return {src[0], src[0], src[1], src[1]};
+  }
+  return src;
+}
+
+template <typename DeviceContext, typename T>
+class Conv2DTransposeXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    // The filter will be reshaped in the calculations,
+    // so here use an assignment operation,
+    // that avoids modifying the variable in the Scope.
+    Tensor filter = *context.Input<Tensor>("Filter");
+    Tensor* output = context.Output<Tensor>("Output");
+    output->mutable_data<T>(context.GetPlace());
+    int groups = context.Attr<int>("groups");
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
+    const std::string data_format = context.Attr<std::string>("data_format");
+    const std::string padding_algorithm =
+        context.Attr<std::string>("padding_algorithm");
+
+    PADDLE_ENFORCE_EQ(
+        data_format == "NHWC" || data_format == "NDHWC", false,
+        platform::errors::InvalidArgument(
+            "XPU only supports data_format NCHW in conv_transpose op."));
+
+    framework::DDim in_data_dims =
+        framework::slice_ddim(input->dims(), 2, input->dims().size());
+    framework::DDim filter_data_dims =
+        framework::slice_ddim(filter.dims(), 2, filter.dims().size());
+    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
+    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                             in_data_dims, strides, ksize);
+
+    const int batch_size = static_cast<int>(input->dims()[0]);
+    const int img_yc = static_cast<int>(input->dims()[1]);
+    const int img_yh = static_cast<int>(input->dims()[2]);
+    const int img_yw = static_cast<int>(input->dims()[3]);
+    const int img_xc = static_cast<int>(output->dims()[1]);
+    const int img_xh = static_cast<int>(output->dims()[2]);
+    const int img_xw = static_cast<int>(output->dims()[3]);
+
+    {
+      std::vector<int> ksize_check = vector_extend(ksize, 2);
+      std::vector<int> stride_check = vector_extend(strides, 2);
+      std::vector<int> pad_check = vector_extend(paddings, 4);
+      std::vector<int> dilation_check = vector_extend(dilations, 2);
+
+      int xh_check = (img_yh - 1) * stride_check[0] - pad_check[0] -
+                     pad_check[1] +
+                     (dilation_check[0] * (ksize_check[0] - 1) + 1);
+      int xw_check = (img_yw - 1) * stride_check[1] - pad_check[2] -
+                     pad_check[3] +
+                     (dilation_check[1] * (ksize_check[1] - 1) + 1);
+
+      PADDLE_ENFORCE_EQ(
+          xh_check == img_xh && xw_check == img_xw, true,
+          platform::errors::InvalidArgument(
+              "XPU output size check error in conv_transpose op."));
+    }
+
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    int r = xpu::conv2d_transpose<float, float, float, int16_t>(
+        dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
+        output->data<float>(), batch_size, img_yc, img_yh, img_yw, img_xc,
+        ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr,
+        true);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose");
+  }
+};
+
+template <typename DeviceContext, typename T>
+class Conv2DTransposeGradXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    const Tensor* output_grad =
+        context.Input<Tensor>(framework::GradVarName("Output"));
+    Tensor* input_grad =
+        context.Output<Tensor>(framework::GradVarName("Input"));
+    Tensor* filter_grad =
+        context.Output<Tensor>(framework::GradVarName("Filter"));
+    // The filter and filter_grad will be reshaped in the calculations,
+    // so here use an assignment operation,
+    // that avoids modifying the variable in the Scope.
+    Tensor filter = *context.Input<Tensor>("Filter");
+    if (!input_grad && !filter_grad) return;
+
+    int groups = context.Attr<int>("groups");
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
+    const std::string data_format = context.Attr<std::string>("data_format");
+    const std::string padding_algorithm =
+        context.Attr<std::string>("padding_algorithm");
+
+    PADDLE_ENFORCE_EQ(
+        data_format == "NHWC" || data_format == "NDHWC", false,
+        platform::errors::InvalidArgument(
+            "XPU only supports data_format NCHW in conv grad op."));
+
+    framework::DDim in_data_dims =
+        framework::slice_ddim(input->dims(), 2, input->dims().size());
+    framework::DDim filter_data_dims =
+        framework::slice_ddim(filter.dims(), 2, filter.dims().size());
+    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
+    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                             in_data_dims, strides, ksize);
+
+    const int batch_size = static_cast<int>(input->dims()[0]);
+    const int img_yc = static_cast<int>(input->dims()[1]);
+    const int img_yh = static_cast<int>(input->dims()[2]);
+    const int img_yw = static_cast<int>(input->dims()[3]);
+    const int img_xc = static_cast<int>(output_grad->dims()[1]);
+    const int img_xh = static_cast<int>(output_grad->dims()[2]);
+    const int img_xw = static_cast<int>(output_grad->dims()[3]);
+
+    if (input_grad) {
+      input_grad->mutable_data<T>(context.GetPlace());
+    }
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(context.GetPlace());
+    }
+
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    int r = xpu::conv2d_transpose_grad<float, float, float, int16_t>(
+        dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
+        output_grad->data<float>(),
+        input_grad ? input_grad->data<float>() : nullptr,
+        filter_grad ? filter_grad->data<float>() : nullptr, batch_size, img_yc,
+        img_yh, img_yw, img_xc, img_xh, img_xw, ksize, strides, paddings,
+        dilations, groups, nullptr, nullptr, nullptr, nullptr, nullptr, true);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_grad");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_XPU_KERNEL(
+    conv2d_transpose,
+    ops::Conv2DTransposeXPUKernel<paddle::platform::XPUDeviceContext, float>);
+REGISTER_OP_XPU_KERNEL(conv2d_transpose_grad,
+                       ops::Conv2DTransposeGradXPUKernel<
+                           paddle::platform::XPUDeviceContext, float>);
+#endif
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 3d140b4693..f83e3f6d0d 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -53,6 +53,10 @@ XPUOpMap& get_kl2_ops() {
     {"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"conv2d_transpose_grad",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"conv2d_transpose",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"depthwise_conv2d_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"depthwise_conv2d",
@@ -283,7 +287,6 @@ XPUOpMap& get_kl2_ops() {
     {"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"roi_align_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                             pOpKernelType(vartype::FP16, XPUPlace()),
                             pOpKernelType(vartype::INT64, XPUPlace())})},
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_conv2d_transpose_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_conv2d_transpose_op_xpu.py
new file mode 100644
index 0000000000..b4f9f639ac
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_conv2d_transpose_op_xpu.py
@@ -0,0 +1,271 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import sys
+sys.path.append("..")
+import unittest
+import numpy as np
+
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+from op_test_xpu import XPUOpTest
+import paddle
+import paddle.nn as nn
+from paddle.fluid import Program, program_guard
+
+
+def conv2dtranspose_forward_naive(input_, filter_, attrs):
+    padding_algorithm = attrs['padding_algorithm']
+    if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]:
+        raise ValueError("Unknown Attr(padding_algorithm): '%s'. "
+                         "It can only be 'SAME', 'VALID' or 'EXPLICIT'." %
+                         str(padding_algorithm))
+
+    if attrs['data_format'] == 'NHWC':
+        input_ = np.transpose(input_, [0, 3, 1, 2])
+    in_n, in_c, in_h, in_w = input_.shape
+    f_c, f_out_c, f_h, f_w = filter_.shape
+    groups = attrs['groups']
+    assert in_c == f_c
+    out_c = f_out_c * groups
+    sub_in_c = in_c // groups
+
+    stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[
+        'dilations']
+
+    # update pad and dilation
+    def _get_padding_with_SAME(input_shape, kernel_size, kernel_stride):
+        padding = []
+        for input_size, filter_size, stride_size in zip(
+                input_shape, kernel_size, kernel_stride):
+            out_size = int((input_size + stride_size - 1) / stride_size)
+            pad_sum = np.max((
+                (out_size - 1) * stride_size + filter_size - input_size, 0))
+            pad_0 = int(pad_sum / 2)
+            pad_1 = int(pad_sum - pad_0)
+            padding.append(pad_0)
+            padding.append(pad_1)
+        return padding
+
+    ksize = filter_.shape[2:4]
+    if padding_algorithm == "VALID":
+        pad = [0, 0, 0, 0]
+    elif padding_algorithm == "SAME":
+        dilations = [1, 1]
+        input_data_shape = input_.shape[2:4]
+        pad = _get_padding_with_SAME(input_data_shape, ksize, stride)
+
+    pad_h_0, pad_h_1 = pad[0], pad[0]
+    pad_w_0, pad_w_1 = pad[1], pad[1]
+    if len(pad) == 4:
+        pad_h_0, pad_h_1 = pad[0], pad[1]
+        pad_w_0, pad_w_1 = pad[2], pad[3]
+
+    d_bolck_h = dilations[0] * (f_h - 1) + 1
+    d_bolck_w = dilations[1] * (f_w - 1) + 1
+    out_h = (in_h - 1) * stride[0] + d_bolck_h
+    out_w = (in_w - 1) * stride[1] + d_bolck_w
+    if 'output_size' in attrs:
+        output_size = attrs['output_size']
+        out_h = output_size[0] + pad_h_0 + pad_h_1
+        out_w = output_size[1] + pad_w_0 + pad_w_1
+    out_pad_h = 0
+    out_pad_w = 0
+    if 'output_padding' in attrs:
+        out_pad_h = attrs['output_padding'][0]
+        out_pad_w = attrs['output_padding'][1]
+    out = np.zeros(
+        (in_n, out_c, out_h + out_pad_h, out_w + out_pad_w),
+        dtype=input_.dtype)
+
+    for n in range(in_n):
+        for i in range(in_h):
+            for j in range(in_w):
+                for g in range(groups):
+                    input_masked = input_[n, g * sub_in_c:(g + 1) * sub_in_c,
+                                          i, j]  # (c)
+                    input_masked = np.reshape(input_masked, (sub_in_c, 1, 1))
+                    input_masked = np.tile(input_masked, (1, f_h, f_w))
+
+                    for k in range(f_out_c):
+                        tmp_out = np.sum(
+                            input_masked *
+                            filter_[g * sub_in_c:(g + 1) * sub_in_c, k, :, :],
+                            axis=0)
+                        i1, i2 = i * stride[0], i * stride[0] + d_bolck_h
+                        j1, j2 = j * stride[1], j * stride[1] + d_bolck_w
+                        out[n, g * f_out_c + k, i1:i2:dilations[0], j1:j2:
+                            dilations[1]] += tmp_out
+
+    out = out[:, :, pad_h_0:out_h - pad_h_1 + out_pad_h,
+              pad_w_0:out_w - pad_w_1 + out_pad_w]
+    if attrs['data_format'] == 'NHWC':
+        out = np.transpose(out, [0, 2, 3, 1])
+    return out
+
+
+class TestConv2DTransposeOp(XPUOpTest):
+    def setUp(self):
+        # init as conv transpose
+        self.dtype = np.float32
+        self.need_check_grad = True
+        self.is_test = False
+        self.use_cudnn = False
+        self.use_mkldnn = False
+        self.output_size = None
+        self.output_padding = []
+        self.data_format = "NCHW"
+        self.pad = [0, 0]
+        self.padding_algorithm = "EXPLICIT"
+        self.init_op_type()
+        self.init_test_case()
+        self.__class__.op_type = "conv2d_transpose"
+
+        input_ = np.random.random(self.input_size).astype(self.dtype)
+        filter_ = np.random.random(self.filter_size).astype(self.dtype)
+
+        self.inputs = {'Input': input_, 'Filter': filter_}
+        self.attrs = {
+            'strides': self.stride,
+            'paddings': self.pad,
+            'padding_algorithm': self.padding_algorithm,
+            'groups': self.groups,
+            'dilations': self.dilations,
+            'use_cudnn': self.use_cudnn,
+            'is_test': self.is_test,
+            'use_mkldnn': self.use_mkldnn,
+            'data_format': self.data_format
+        }
+        if self.output_size is not None:
+            self.attrs['output_size'] = self.output_size
+
+        if len(self.output_padding) > 0:
+            self.attrs['output_padding'] = self.output_padding
+
+        output = conv2dtranspose_forward_naive(
+            input_, filter_, self.attrs).astype(self.dtype)
+
+        self.outputs = {'Output': output}
+
+    def test_check_output(self):
+        if core.is_compiled_with_xpu():
+            paddle.enable_static()
+            place = paddle.XPUPlace(0)
+            self.check_output_with_place(place)
+
+    def test_check_grad_no_input(self):
+        if self.need_check_grad:
+            if core.is_compiled_with_xpu():
+                paddle.enable_static()
+                place = paddle.XPUPlace(0)
+                self.check_grad_with_place(
+                    place, ['Filter'], 'Output', no_grad_set=set(['Input']))
+
+    def test_check_grad_no_filter(self):
+        if self.need_check_grad:
+            if core.is_compiled_with_xpu():
+                paddle.enable_static()
+                place = paddle.XPUPlace(0)
+                self.check_grad_with_place(
+                    place, ['Input'], 'Output', no_grad_set=set(['Filter']))
+
+    def test_check_grad(self):
+        if self.need_check_grad:
+            if core.is_compiled_with_xpu():
+                paddle.enable_static()
+                place = paddle.XPUPlace(0)
+                self.check_grad_with_place(place,
+                                           set(['Input', 'Filter']), 'Output')
+
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+
+    def init_op_type(self):
+        self.op_type = "conv2d_transpose"
+
+
+class TestWithSymmetricPad(TestConv2DTransposeOp):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+
+
+class TestWithAsymmetricPad(TestConv2DTransposeOp):
+    def init_test_case(self):
+        self.pad = [1, 0, 1, 2]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+
+
+class TestWithSAMEPad(TestConv2DTransposeOp):
+    def init_test_case(self):
+        self.stride = [2, 1]
+        self.dilations = [1, 2]
+        self.groups = 1
+        self.input_size = [2, 3, 6, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 4, 3]
+        self.padding_algorithm = 'SAME'
+
+
+class TestWithVALIDPad(TestConv2DTransposeOp):
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+        self.padding_algorithm = 'VALID'
+
+
+class TestWithGroups(TestConv2DTransposeOp):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 2
+        self.input_size = [2, 4, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 3, 3, 3]
+
+
+class TestWithStride(TestConv2DTransposeOp):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [2, 2]
+        self.dilations = [1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
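
Editorial note (not part of the patch): the xh_check/xw_check guard in conv_transpose_op_xpu.cc and the NumPy reference above both rely on the standard conv2d_transpose output-size relation. A minimal sketch of that relation, assuming NCHW layout, no output_padding, and a hypothetical helper name:

    # Sketch of the output-size relation enforced by xh_check/xw_check in
    # conv_transpose_op_xpu.cc (illustrative helper, not part of the patch).
    def conv2d_transpose_out_size(in_size, kernel, stride, pad_0, pad_1,
                                  dilation):
        return ((in_size - 1) * stride - pad_0 - pad_1 +
                (dilation * (kernel - 1) + 1))

    # Example matching TestWithStride above: 5x5 input, 3x3 filter, stride 2,
    # symmetric padding 1, dilation 1 -> a 9x9 output.
    assert conv2d_transpose_out_size(5, 3, 2, 1, 1, 1) == 9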