diff --git a/paddle/fluid/operators/conv_transpose_op_xpu.cc b/paddle/fluid/operators/conv_transpose_op_xpu.cc
index 08a58678a2ea8528d9b2740088e9e444c835be7e..882bd0b091e67ece82c4e8cfd8db7105c86b3d90 100644
--- a/paddle/fluid/operators/conv_transpose_op_xpu.cc
+++ b/paddle/fluid/operators/conv_transpose_op_xpu.cc
@@ -24,106 +24,6 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 
-// target_len == 2 || target_len == 4
-inline std::vector<int> vector_extend(const std::vector<int>& src,
-                                      int target_len) {
-  if (target_len == 2 && src.size() == 1) {
-    return {src[0], src[0]};
-  }
-  if (target_len == 4 && src.size() == 1) {
-    return {src[0], src[0], src[0], src[0]};
-  }
-  if (target_len == 4 && src.size() == 2) {
-    return {src[0], src[0], src[1], src[1]};
-  }
-  return src;
-}
-
-template <typename DeviceContext, typename T>
-class Conv2DTransposeXPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("Input");
-    // The filter will be reshaped in the calculations,
-    // so here use an assignment operation,
-    // that avoids modifying the variable in the Scope.
-    Tensor filter = *context.Input<Tensor>("Filter");
-    Tensor* output = context.Output<Tensor>("Output");
-    output->mutable_data<T>(context.GetPlace());
-    int groups = context.Attr<int>("groups");
-    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
-    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
-    const std::string data_format = context.Attr<std::string>("data_format");
-    const std::string padding_algorithm =
-        context.Attr<std::string>("padding_algorithm");
-
-    PADDLE_ENFORCE_EQ(
-        data_format == "NHWC" || data_format == "NDHWC",
-        false,
-        platform::errors::InvalidArgument(
-            ("XPU do support data_format is NCHW in conv_transpose op.")));
-
-    framework::DDim in_data_dims =
-        phi::slice_ddim(input->dims(), 2, input->dims().size());
-    framework::DDim filter_data_dims =
-        phi::slice_ddim(filter.dims(), 2, filter.dims().size());
-    std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
-    phi::UpdatePaddingAndDilation(
-        &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
-
-    const int batch_size = static_cast<int>(input->dims()[0]);
-    const int img_yc = static_cast<int>(input->dims()[1]);
-    const int img_yh = static_cast<int>(input->dims()[2]);
-    const int img_yw = static_cast<int>(input->dims()[3]);
-    const int img_xc = static_cast<int>(output->dims()[1]);
-    const int img_xh = static_cast<int>(output->dims()[2]);
-    const int img_xw = static_cast<int>(output->dims()[3]);
-
-    {
-      std::vector<int> ksize_check = vector_extend(ksize, 2);
-      std::vector<int> stride_check = vector_extend(strides, 2);
-      std::vector<int> pad_check = vector_extend(paddings, 4);
-      std::vector<int> dilation_check = vector_extend(dilations, 2);
-
-      int xh_check = (img_yh - 1) * stride_check[0] - pad_check[0] -
-                     pad_check[1] +
-                     (dilation_check[0] * (ksize_check[0] - 1) + 1);
-      int xw_check = (img_yw - 1) * stride_check[1] - pad_check[2] -
-                     pad_check[3] +
-                     (dilation_check[1] * (ksize_check[1] - 1) + 1);
-
-      PADDLE_ENFORCE_EQ(
-          xh_check == img_xh && xw_check == img_xw,
-          true,
-          platform::errors::InvalidArgument(
-              ("XPU output size check error in conv_transpose op.")));
-    }
-
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    int r = xpu::conv2d_transpose<float, float, float, int16_t>(
-        dev_ctx.x_context(),
-        input->data<float>(),
-        filter.data<float>(),
-        output->data<float>(),
-        batch_size,
-        img_yc,
-        img_yh,
-        img_yw,
-        img_xc,
-        ksize,
-        strides,
-        paddings,
-        dilations,
-        groups,
-        nullptr,
-        nullptr,
-        nullptr,
-        true);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose");
-  }
-};
-
 template <typename DeviceContext, typename T>
 class Conv2DTransposeGradXPUKernel : public framework::OpKernel<T> {
  public:
@@ -209,9 +109,6 @@ class Conv2DTransposeGradXPUKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(
-    conv2d_transpose,
-    ops::Conv2DTransposeXPUKernel<paddle::platform::XPUDeviceContext, float>);
 REGISTER_OP_XPU_KERNEL(
     conv2d_transpose_grad,
     ops::Conv2DTransposeGradXPUKernel<paddle::platform::XPUDeviceContext,
diff --git a/paddle/phi/kernels/xpu/conv_transpose_kernel.cc b/paddle/phi/kernels/xpu/conv_transpose_kernel.cc
new file mode 100644
--- /dev/null
+++ b/paddle/phi/kernels/xpu/conv_transpose_kernel.cc
@@ -0,0 +1,126 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/conv_transpose_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/cpu/conv_util.h"
+
+namespace phi {
+
+// target_len == 2 || target_len == 4
+inline std::vector<int> vector_extend(const std::vector<int>& src,
+                                      int target_len) {
+  if (target_len == 2 && src.size() == 1) {
+    return {src[0], src[0]};
+  }
+  if (target_len == 4 && src.size() == 1) {
+    return {src[0], src[0], src[0], src[0]};
+  }
+  if (target_len == 4 && src.size() == 2) {
+    return {src[0], src[0], src[1], src[1]};
+  }
+  return src;
+}
+
+template <typename T, typename Context>
+void Conv2dTransposeKernel(const Context& ctx,
+                           const DenseTensor& x,
+                           const DenseTensor& filter,
+                           const std::vector<int>& strides,
+                           const std::vector<int>& paddings,
+                           const std::vector<int>& output_padding,
+                           const std::vector<int>& output_size,
+                           const std::string& padding_algorithm,
+                           int groups,
+                           const std::vector<int>& dilations,
+                           const std::string& data_format,
+                           DenseTensor* out) {
+  // The filter will be reshaped in the calculations,
+  // so here use an assignment operation,
+  // that avoids modifying the variable in the Scope.
+  DenseTensor filter_ = filter;
+
+  ctx.template Alloc<T>(out);
+
+  PADDLE_ENFORCE_EQ(
+      data_format == "NHWC" || data_format == "NDHWC",
+      false,
+      errors::InvalidArgument(
+          ("XPU only supports data_format NCHW in conv_transpose op.")));
+
+  DDim in_data_dims = slice_ddim(x.dims(), 2, x.dims().size());
+  DDim filter_data_dims = slice_ddim(filter_.dims(), 2, filter_.dims().size());
+  std::vector<int> ksize = vectorize<int>(filter_data_dims);
+
+  std::vector<int> paddings_ = paddings;
+  std::vector<int> dilations_ = dilations;
+  UpdatePaddingAndDilation(
+      &paddings_, &dilations_, padding_algorithm, in_data_dims, strides, ksize);
+
+  const int batch_size = static_cast<int>(x.dims()[0]);
+  const int img_yc = static_cast<int>(x.dims()[1]);
+  const int img_yh = static_cast<int>(x.dims()[2]);
+  const int img_yw = static_cast<int>(x.dims()[3]);
+  const int img_xc = static_cast<int>(out->dims()[1]);
+  const int img_xh = static_cast<int>(out->dims()[2]);
+  const int img_xw = static_cast<int>(out->dims()[3]);
+
+  {
+    std::vector<int> ksize_check = vector_extend(ksize, 2);
+    std::vector<int> stride_check = vector_extend(strides, 2);
+    std::vector<int> pad_check = vector_extend(paddings_, 4);
+    std::vector<int> dilation_check = vector_extend(dilations_, 2);
+
+    int xh_check = (img_yh - 1) * stride_check[0] - pad_check[0] -
+                   pad_check[1] +
+                   (dilation_check[0] * (ksize_check[0] - 1) + 1);
+    int xw_check = (img_yw - 1) * stride_check[1] - pad_check[2] -
+                   pad_check[3] +
+                   (dilation_check[1] * (ksize_check[1] - 1) + 1);
+
+    PADDLE_ENFORCE_EQ(
+        xh_check == img_xh && xw_check == img_xw,
+        true,
+        errors::InvalidArgument(
+            ("XPU output size check error in conv_transpose op.")));
+  }
+
+  int r =
+      xpu::conv2d_transpose<float, float, float, int16_t>(ctx.x_context(),
+                                                          x.data<float>(),
+                                                          filter_.data<float>(),
+                                                          out->data<float>(),
+                                                          batch_size,
+                                                          img_yc,
+                                                          img_yh,
+                                                          img_yw,
+                                                          img_xc,
+                                                          ksize,
+                                                          strides,
+                                                          paddings_,
+                                                          dilations_,
+                                                          groups,
+                                                          nullptr,
+                                                          nullptr,
+                                                          nullptr,
+                                                          true);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose");
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    conv2d_transpose, XPU, ALL_LAYOUT, phi::Conv2dTransposeKernel, float) {}