From e9f8baa674333c4c4e7ebf86b14d9c36ed3cb461 Mon Sep 17 00:00:00 2001 From: lijin23 <41257772+lj970926@users.noreply.github.com> Date: Wed, 21 Jun 2023 15:12:57 +0800 Subject: [PATCH] add int quantization for xpu (#54802) --- .../kernels/xpu/conv_transpose_grad_kernel.cc | 86 +++++++++++++------ 1 file changed, 59 insertions(+), 27 deletions(-) diff --git a/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc b/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc index 4b61b61b5e2..f6090980745 100644 --- a/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc @@ -17,6 +17,7 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/cpu/conv_util.h" +#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h" namespace phi { template @@ -68,33 +69,64 @@ void Conv2dTransposeGradKernel(const Context& ctx, if (dfilter) { ctx.template Alloc(dfilter); } - - int r = xpu::conv2d_transpose_grad( - ctx.x_context(), - x.data(), - filter_.data(), - dout.data(), - dx ? dx->data() : nullptr, - dfilter ? dfilter->data() : nullptr, - batch_size, - img_yc, - img_yh, - img_yw, - img_xc, - img_xh, - img_xw, - ksize, - strides, - paddings_, - dilations_, - groups, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - true); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_grad"); + int fccal_type = FCCalcType(); + if (fccal_type == XPUFCCalcType::FC_INT32 || + fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) { + // xpu api do not support int31 quantization now. + int r = xpu::conv2d_transpose_grad( + ctx.x_context(), + x.data(), + filter_.data(), + dout.data(), + dx ? dx->data() : nullptr, + dfilter ? dfilter->data() : nullptr, + batch_size, + img_yc, + img_yh, + img_yw, + img_xc, + img_xh, + img_xw, + ksize, + strides, + paddings_, + dilations_, + groups, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + true); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_grad"); + } else { + int r = xpu::conv2d_transpose_grad( + ctx.x_context(), + x.data(), + filter_.data(), + dout.data(), + dx ? dx->data() : nullptr, + dfilter ? dfilter->data() : nullptr, + batch_size, + img_yc, + img_yh, + img_yw, + img_xc, + img_xh, + img_xw, + ksize, + strides, + paddings_, + dilations_, + groups, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + true); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_grad"); + } } template -- GitLab