Unverified commit e9f8baa6, authored by lijin23, committed by GitHub

add int quantization for xpu (#54802)

Parent 3371f98b
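For context, the diff below follows a runtime-dispatch pattern: FCCalcType<T>() reports which quantization mode was requested, and the kernel instantiates xpu::conv2d_transpose_grad with the matching quantization template argument (int_with_ll_t for the int32 modes, int16_t otherwise). The following standalone sketch illustrates only that pattern; QuantType, select_calc_type, run_conv_grad, and the USE_INT32_QUANT flag are hypothetical names for illustration, not Paddle or XPU APIs.

#include <cstdint>
#include <cstdlib>
#include <iostream>

// Hypothetical stand-in for the calc-type enum; for illustration only.
enum class QuantType { kInt16, kInt32 };

// Templated "kernel" parameterized on the quantization type, mirroring how
// xpu::conv2d_transpose_grad is instantiated with int16_t or int_with_ll_t.
template <typename TQuant>
int run_conv_grad() {
  std::cout << "running with a " << sizeof(TQuant) * 8
            << "-bit quantization type\n";
  return 0;  // 0 == success, matching the XDNN return-code convention
}

// Runtime selection of the calc type, analogous in spirit to FCCalcType<T>();
// here we simply read a hypothetical environment flag.
QuantType select_calc_type() {
  return std::getenv("USE_INT32_QUANT") ? QuantType::kInt32 : QuantType::kInt16;
}

int main() {
  int r = (select_calc_type() == QuantType::kInt32)
              ? run_conv_grad<int32_t>()   // int32 quantization branch
              : run_conv_grad<int16_t>();  // default int16 branch
  return r;
}

The design choice in the commit is the same: keep int16 as the default path and add a separate branch only when an int32-style calc type is requested.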
@@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
namespace phi {
template <typename T, typename Context>
@@ -68,7 +69,37 @@ void Conv2dTransposeGradKernel(const Context& ctx,
  if (dfilter) {
    ctx.template Alloc<T>(dfilter);
  }
  int fccal_type = FCCalcType<T>();
  if (fccal_type == XPUFCCalcType::FC_INT32 ||
      fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
    // The XPU API does not support int31 quantization yet.
    int r = xpu::conv2d_transpose_grad<float, float, float, int_with_ll_t>(
        ctx.x_context(),
        x.data<T>(),
        filter_.data<T>(),
        dout.data<T>(),
        dx ? dx->data<T>() : nullptr,
        dfilter ? dfilter->data<T>() : nullptr,
        batch_size,
        img_yc,
        img_yh,
        img_yw,
        img_xc,
        img_xh,
        img_xw,
        ksize,
        strides,
        paddings_,
        dilations_,
        groups,
        nullptr,
        nullptr,
        nullptr,
        nullptr,
        nullptr,
        true);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_grad");
  } else {
    int r = xpu::conv2d_transpose_grad<float, float, float, int16_t>(
        ctx.x_context(),
        x.data<T>(),
@@ -95,6 +126,7 @@ void Conv2dTransposeGradKernel(const Context& ctx,
        nullptr,
        true);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_grad");
  }
}
template <typename T, typename Context>
......