Unverified · commit 88d68c08, authored by QingshuChen, committed by GitHub

support nhwc format for kunlun conv/batch_norm (#42195)

* support nhwc format for kunlun conv/batch_norm
*test=kunlun

* minor
*test=kunlun
Parent 5be9b824
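
For orientation: the two formats differ only in where the channel axis sits in a 4-D activation tensor, which is why the kernels below must reindex dims() when data_format == "NHWC". A minimal standalone sketch with a hypothetical shape (not part of the commit):

#include <cstdio>

int main() {
  // Hypothetical activation: batch 8, 3 channels, 224x224 spatial.
  const int dims_nchw[4] = {8, 3, 224, 224};  // N, C, H, W
  const int dims_nhwc[4] = {8, 224, 224, 3};  // N, H, W, C
  // NCHW keeps channels at index 1; NHWC moves them to index 3.
  std::printf("NCHW channels: %d, NHWC channels: %d\n",
              dims_nchw[1], dims_nhwc[3]);
  return 0;
}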
@@ -17,7 +17,7 @@ endif()
 # ubuntu and centos: use output by XDNN API team
 if(NOT DEFINED XPU_XDNN_BASE_URL)
   SET(XPU_XDNN_BASE_URL_WITHOUT_DATE "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
-  SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220412")
+  SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220425")
 else()
   SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
 endif()
@@ -53,8 +53,12 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
             "But received: the size of input's dimensions is [%d]",
             x_dims.size()));
-    int N, C, H, W, D;
+    int N = -1, C = -1, H = -1, W = -1, D = -1;
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
+    N = (N == 0) ? 1 : N;
+    C = (C == 0) ? 1 : C;
+    H = (H == 0) ? 1 : H;
+    W = (W == 0) ? 1 : W;
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *bias = ctx.Input<Tensor>("Bias");
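
Why the initializers and clamps were added: for inputs of rank below 5, the extraction helper can leave trailing extents unset, and a zero extent passed to the XDNN kernel would describe an empty tensor. The sketch below, using a simplified, assumed stand-in for the real ExtractNCWHD, shows the 0-to-1 promotion:

#include <cstdio>

// Simplified stand-in for ExtractNCWHD (assumed semantics): axes that are
// absent from a low-rank shape are reported as 0.
void ExtractDims(const int *dims, int rank, bool nchw, int *N, int *C,
                 int *H, int *W) {
  *N = dims[0];
  *C = nchw ? dims[1] : dims[rank - 1];
  *H = rank > 2 ? dims[nchw ? 2 : 1] : 0;
  *W = rank > 3 ? dims[nchw ? 3 : 2] : 0;
}

int main() {
  const int dims[3] = {8, 16, 32};  // a rank-3 NCHW input
  int N = -1, C = -1, H = -1, W = -1;
  ExtractDims(dims, 3, /*nchw=*/true, &N, &C, &H, &W);
  // Mirror of the patch: treat a missing axis as extent 1.
  N = (N == 0) ? 1 : N;
  C = (C == 0) ? 1 : C;
  H = (H == 0) ? 1 : H;
  W = (W == 0) ? 1 : W;
  std::printf("N=%d C=%d H=%d W=%d\n", N, C, H, W);  // N=8 C=16 H=32 W=1
  return 0;
}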
@@ -103,12 +107,6 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
               "The batch_norm XPU API return wrong value[%d %s]",
               r, XPUAPIErrorMsg[r]));
     } else {
-      PADDLE_ENFORCE_EQ(
-          data_layout_str == "NCHW", true,
-          platform::errors::InvalidArgument(
-              "The batch_norm_infer 'data_layout' attribute must be NCHW. "
-              "But recevived 'data_layout' is [%s].",
-              data_layout_str));
       const auto *mean = ctx.Input<Tensor>("Mean");
       const auto *variance = ctx.Input<Tensor>("Variance");
       const auto *mean_data = mean->data<float>();
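
With the NCHW-only guard deleted, the inference branch accepts NHWC too; the layout now reaches the XDNN call as a boolean flag instead of being rejected up front. A minimal sketch of that mapping, assuming batch_norm derives the flag the same way the conv kernels in this commit do (the helper name is hypothetical):

#include <string>

// Hypothetical helper: fold the operator's data_layout attribute into the
// boolean is_nchw flag consumed by the XDNN entry points.
inline bool LayoutIsNCHW(const std::string &data_layout_str) {
  return data_layout_str == "NCHW";  // "NHWC" and anything else -> false
}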
@@ -222,8 +220,12 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
             "But received: the size of input's dimensions is [%d]",
             x_dims.size()));
-    int N, C, H, W, D;
+    int N = -1, C = -1, H = -1, W = -1, D = -1;
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
+    N = (N == 0) ? 1 : N;
+    C = (C == 0) ? 1 : C;
+    H = (H == 0) ? 1 : H;
+    W = (W == 0) ? 1 : W;
     const auto *x_data = x->data<T>();
     const auto *d_y_data = d_y->data<T>();
@@ -38,9 +38,10 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     const std::string padding_algorithm =
         context.Attr<std::string>("padding_algorithm");
-    PADDLE_ENFORCE_EQ(data_format == "NHWC" || data_format == "NDHWC", false,
-                      platform::errors::InvalidArgument(
-                          ("XPU do support data_format is NCHW in conv op.")));
+    PADDLE_ENFORCE_EQ(
+        data_format == "NDHWC", false,
+        platform::errors::InvalidArgument(
+            ("XPU does not support data_format is NDHWC in conv op.")));
     framework::DDim in_data_dims =
         phi::slice_ddim(input->dims(), 2, input->dims().size());
@@ -50,11 +51,18 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                              in_data_dims, strides, ksize);
-    const int batch_size = static_cast<int>(input->dims()[0]);
-    const int img_c = static_cast<int>(input->dims()[1]);
-    const int img_h = static_cast<int>(input->dims()[2]);
-    const int img_w = static_cast<int>(input->dims()[3]);
-    const int f = static_cast<int>(filter.dims()[0]);
+    int batch_size = static_cast<int>(input->dims()[0]);
+    int img_c = static_cast<int>(input->dims()[1]);
+    int img_h = static_cast<int>(input->dims()[2]);
+    int img_w = static_cast<int>(input->dims()[3]);
+    int f = static_cast<int>(filter.dims()[0]);
+    bool is_nchw = true;
+    if (data_format == "NHWC") {
+      img_c = static_cast<int>(input->dims()[3]);
+      img_h = static_cast<int>(input->dims()[1]);
+      img_w = static_cast<int>(input->dims()[2]);
+      is_nchw = false;
+    }
     const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
     const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
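
To make the reshuffle above concrete: with a hypothetical NHWC input of shape [8, 224, 224, 3], the channel extent is read from dims()[3] instead of dims()[1], and both spatial extents shift down one index. A self-contained check:

#include <cassert>

int main() {
  const int dims[4] = {8, 224, 224, 3};  // NHWC: N, H, W, C
  // Unpatched (NCHW) reading of the same shape:
  int img_c = dims[1], img_h = dims[2], img_w = dims[3];
  bool is_nchw = true;
  const bool is_nhwc_format = true;  // stands in for data_format == "NHWC"
  if (is_nhwc_format) {              // same branch as the hunk above
    img_c = dims[3];
    img_h = dims[1];
    img_w = dims[2];
    is_nchw = false;
  }
  assert(img_c == 3 && img_h == 224 && img_w == 224 && !is_nchw);
  return 0;
}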
@@ -64,7 +72,7 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(
         dev_ctx.x_context(), input_data, filter_data, output_data, batch_size,
         img_c, img_h, img_w, f, ksize, strides, paddings, dilations, groups,
-        nullptr, nullptr, nullptr, true);
+        nullptr, nullptr, nullptr, is_nchw);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU conv kernel return wrong value[%d %s]",
@@ -99,9 +107,9 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
         context.Attr<std::string>("padding_algorithm");
     PADDLE_ENFORCE_EQ(
-        data_format == "NHWC" || data_format == "NDHWC", false,
+        data_format == "NDHWC", false,
         platform::errors::InvalidArgument(
-            ("XPU do support data_format is NCHW in conv grad op.")));
+            ("XPU doesn't support data_format is NDHWC in conv grad op.")));
     framework::DDim in_data_dims =
         phi::slice_ddim(input->dims(), 2, input->dims().size());
@@ -111,11 +119,18 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
     UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                              in_data_dims, strides, ksize);
-    const int batch_size = static_cast<int>(input->dims()[0]);
-    const int img_c = static_cast<int>(input->dims()[1]);
-    const int img_h = static_cast<int>(input->dims()[2]);
-    const int img_w = static_cast<int>(input->dims()[3]);
-    const int f = static_cast<int>(filter.dims()[0]);
+    int batch_size = static_cast<int>(input->dims()[0]);
+    int img_c = static_cast<int>(input->dims()[1]);
+    int img_h = static_cast<int>(input->dims()[2]);
+    int img_w = static_cast<int>(input->dims()[3]);
+    int f = static_cast<int>(filter.dims()[0]);
+    bool is_nchw = true;
+    if (data_format == "NHWC") {
+      img_c = static_cast<int>(input->dims()[3]);
+      img_h = static_cast<int>(input->dims()[1]);
+      img_w = static_cast<int>(input->dims()[2]);
+      is_nchw = false;
+    }
     const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
     const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
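
A design point the grad hunk shares with the forward one: both kernels derive is_nchw from data_format with the identical branch, so xpu::conv2d and xpu::conv2d_grad always see the activations under the same layout; disagreeing flags would compute gradients against a transposed view. A compact consistency check (hypothetical helper, not in the commit):

#include <cassert>
#include <string>

// Hypothetical shared derivation, mirroring the branch both kernels use.
static bool FlagFor(const std::string &data_format) {
  bool is_nchw = true;
  if (data_format == "NHWC") is_nchw = false;
  return is_nchw;
}

int main() {
  // Forward and backward must agree for every format value.
  assert(FlagFor("NCHW") && !FlagFor("NHWC"));
  return 0;
}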
@@ -136,7 +151,7 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
         dev_ctx.x_context(), input_data, filter_data, output_grad_data,
         input_grad_data, filter_grad_data, batch_size, img_c, img_h, img_w, f,
         ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr,
-        nullptr, nullptr, true);
+        nullptr, nullptr, is_nchw);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU conv kernel return wrong value[%d %s]",