未验证 提交 bd67209f 编写于 作者: lijin23 提交者: GitHub

[XPU][PHI Kernels] add int_with_ll quantization for conv kernels (#54827)

* add int_with_ll to conv

* fix bugs when output_size is specified for conv2d_transpose
上级 9c2dae1a
......@@ -107,7 +107,7 @@ void ConvGradKernel(const Context& dev_ctx,
}
}
int fccal_type = FCCalcType<XPUT>();
if (fccal_type == 1) {
if (fccal_type == XPUFCCalcType::FC_INT32) {
int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
......@@ -132,7 +132,7 @@ void ConvGradKernel(const Context& dev_ctx,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
} else if (fccal_type == 2) {
} else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
......@@ -157,6 +157,31 @@ void ConvGradKernel(const Context& dev_ctx,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
} else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
int r =
xpu::conv2d_grad<XPUT, XPUT, XPUT, int_with_ll_t>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_grad_data,
input_grad_data,
filter_grad_data_ptr,
batch_size,
img_c,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
} else {
int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
input_data,
......@@ -305,7 +330,7 @@ void Conv3DGradKernel(const Context& dev_ctx,
}
}
int fccal_type = FCCalcType<XPUT>();
if (fccal_type == 1) {
if (fccal_type == XPUFCCalcType::FC_INT32) {
int r = xpu::conv3d_grad<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
......@@ -330,7 +355,7 @@ void Conv3DGradKernel(const Context& dev_ctx,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad");
} else if (fccal_type == 2) {
} else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
int r = xpu::conv3d_grad<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
......@@ -355,6 +380,32 @@ void Conv3DGradKernel(const Context& dev_ctx,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad");
} else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
int r =
xpu::conv3d_grad<XPUT, XPUT, XPUT, int_with_ll_t>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_grad_data,
input_grad_data,
filter_grad_data_ptr,
batch_size,
img_c,
img_d,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad");
} else {
int r = xpu::conv3d_grad<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
input_data,
......
......@@ -89,7 +89,7 @@ void ConvKernel(const Context& dev_ctx,
}
int fccal_type = FCCalcType<XPUT>();
if (fccal_type == 1) {
if (fccal_type == XPUFCCalcType::FC_INT32) {
int r = xpu::conv2d<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
......@@ -109,7 +109,7 @@ void ConvKernel(const Context& dev_ctx,
nullptr,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
} else if (fccal_type == 2) {
} else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
int r = xpu::conv2d<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
......@@ -129,6 +129,26 @@ void ConvKernel(const Context& dev_ctx,
nullptr,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
} else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
int r = xpu::conv2d<XPUT, XPUT, XPUT, int_with_ll_t>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_data,
batch_size,
img_c,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
} else {
int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
input_data,
......@@ -239,7 +259,7 @@ void Conv3DKernel(const Context& dev_ctx,
}
int fccal_type = FCCalcType<XPUT>();
if (fccal_type == 1) {
if (fccal_type == XPUFCCalcType::FC_INT32) {
int r = xpu::conv3d<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
......@@ -260,7 +280,7 @@ void Conv3DKernel(const Context& dev_ctx,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d");
} else if (fccal_type == 2) {
} else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
int r = xpu::conv3d<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
......@@ -282,6 +302,27 @@ void Conv3DKernel(const Context& dev_ctx,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d");
} else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
int r = xpu::conv3d<XPUT, XPUT, XPUT, int_with_ll_t>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_data,
batch_size,
img_c,
img_d,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d");
} else {
int r = xpu::conv3d<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
input_data,
......
......@@ -14,6 +14,8 @@
#include "paddle/phi/kernels/conv_transpose_kernel.h"
#include "glog/logging.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
......@@ -122,6 +124,57 @@ void Conv2dTransposeKernel(const Context& ctx,
nullptr,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_v2");
} else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
if (output_size.size()) {
VLOG(4) << "int_with_ll quantization is not supported when output_size "
"is specified, "
<< "use int31 instead";
int r = xpu::conv2d_transpose_v2<float, float, float, int32_t>(
ctx.x_context(),
x.data<float>(),
filter_.data<float>(),
out->data<float>(),
batch_size,
img_yc,
img_xh,
img_xw,
img_xc,
ksize,
strides,
paddings_,
dilations_,
groups,
nullptr,
nullptr,
nullptr,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_v2");
} else {
      // xpu::conv2d_transpose_v2 does not support int_with_ll yet,
      // so fall back to xpu::conv2d_transpose
int img_yh = static_cast<int>(x.dims()[2]);
int img_yw = static_cast<int>(x.dims()[3]);
int r = xpu::conv2d_transpose<float, float, float, int_with_ll_t>(
ctx.x_context(),
x.data<float>(),
filter_.data<float>(),
out->data<float>(),
batch_size,
img_yc,
img_yh,
img_yw,
img_xc,
ksize,
strides,
paddings_,
dilations_,
groups,
nullptr,
nullptr,
nullptr,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose");
}
} else {
int r = xpu::conv2d_transpose_v2<XPUT, XPUT, XPUT, int16_t>(
ctx.x_context(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册