未验证 提交 e9f8baa6 编写于 作者: L lijin23 提交者: GitHub

add int quantization for xpu (#54802)

上级 3371f98b
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
...@@ -68,33 +69,64 @@ void Conv2dTransposeGradKernel(const Context& ctx, ...@@ -68,33 +69,64 @@ void Conv2dTransposeGradKernel(const Context& ctx,
if (dfilter) { if (dfilter) {
ctx.template Alloc<T>(dfilter); ctx.template Alloc<T>(dfilter);
} }
int fccal_type = FCCalcType<T>();
int r = xpu::conv2d_transpose_grad<float, float, float, int16_t>( if (fccal_type == XPUFCCalcType::FC_INT32 ||
ctx.x_context(), fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
x.data<T>(), // The XPU API does not support int31 quantization yet, so FC_INT32 also uses int_with_ll_t.
filter_.data<T>(), int r = xpu::conv2d_transpose_grad<float, float, float, int_with_ll_t>(
dout.data<T>(), ctx.x_context(),
dx ? dx->data<T>() : nullptr, x.data<T>(),
dfilter ? dfilter->data<T>() : nullptr, filter_.data<T>(),
batch_size, dout.data<T>(),
img_yc, dx ? dx->data<T>() : nullptr,
img_yh, dfilter ? dfilter->data<T>() : nullptr,
img_yw, batch_size,
img_xc, img_yc,
img_xh, img_yh,
img_xw, img_yw,
ksize, img_xc,
strides, img_xh,
paddings_, img_xw,
dilations_, ksize,
groups, strides,
nullptr, paddings_,
nullptr, dilations_,
nullptr, groups,
nullptr, nullptr,
nullptr, nullptr,
true); nullptr,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_grad"); nullptr,
nullptr,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_grad");
} else {
int r = xpu::conv2d_transpose_grad<float, float, float, int16_t>(
ctx.x_context(),
x.data<T>(),
filter_.data<T>(),
dout.data<T>(),
dx ? dx->data<T>() : nullptr,
dfilter ? dfilter->data<T>() : nullptr,
batch_size,
img_yc,
img_yh,
img_yw,
img_xc,
img_xh,
img_xw,
ksize,
strides,
paddings_,
dilations_,
groups,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_grad");
}
} }
template <typename T, typename Context> template <typename T, typename Context>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册