Unverified · Commit 88d68c08 · Authored by: QingshuChen · Committed by: GitHub

support nhwc format for kunlun conv/batch_norm (#42195)

* support nhwc format for kunlun conv/batch_norm
*test=kunlun

* minor
*test=kunlun
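In brief: the diff below bumps the pinned XDNN package, lets the XPU batch_norm kernels handle NHWC by dropping the NCHW-only check, and teaches the conv kernels to read C/H/W from their NHWC positions before passing an is_nchw flag through to xpu::conv2d (and the matching grad call). As a rough illustration of that layout mapping, here is a minimal self-contained C++ sketch; ConvDims and ExtractConvDims are hypothetical names made up for this note, not Paddle or XDNN API.

// Standalone sketch (not Paddle code): mirrors how the patched conv kernels
// pick the N/C/H/W extents per layout before calling into XDNN.
#include <array>
#include <cassert>
#include <iostream>
#include <string>

struct ConvDims {
  int n, c, h, w;
  bool is_nchw;
};

// For a 4-D input shape, read N/C/H/W according to data_format,
// defaulting to NCHW exactly as the patched kernels do.
ConvDims ExtractConvDims(const std::array<int, 4>& shape,
                         const std::string& data_format) {
  ConvDims d{shape[0], shape[1], shape[2], shape[3], true};
  if (data_format == "NHWC") {
    d.c = shape[3];
    d.h = shape[1];
    d.w = shape[2];
    d.is_nchw = false;
  }
  return d;
}

int main() {
  ConvDims nchw = ExtractConvDims({8, 3, 224, 224}, "NCHW");
  ConvDims nhwc = ExtractConvDims({8, 224, 224, 3}, "NHWC");
  assert(nchw.c == 3 && nhwc.c == 3);  // same logical channel count
  std::cout << nchw.h << "x" << nchw.w << " is_nchw=" << nchw.is_nchw << "\n";
  std::cout << nhwc.h << "x" << nhwc.w << " is_nchw=" << nhwc.is_nchw << "\n";
}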
Parent 5be9b824
@@ -17,7 +17,7 @@ endif()
 # ubuntu and centos: use output by XDNN API team
 if(NOT DEFINED XPU_XDNN_BASE_URL)
   SET(XPU_XDNN_BASE_URL_WITHOUT_DATE "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
-  SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220412")
+  SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220425")
 else()
   SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
 endif()
...
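Note: the only change here is the XDNN snapshot date, 20220412 to 20220425. A reasonable inference (not stated in the commit message) is that the newer XDNN build supplies the NHWC conv support that the kernel changes below rely on.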
@@ -53,8 +53,12 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
                           "But received: the size of input's dimensions is [%d]",
                           x_dims.size()));
-    int N, C, H, W, D;
+    int N = -1, C = -1, H = -1, W = -1, D = -1;
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
+    N = (N == 0) ? 1 : N;
+    C = (C == 0) ? 1 : C;
+    H = (H == 0) ? 1 : H;
+    W = (W == 0) ? 1 : W;
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *bias = ctx.Input<Tensor>("Bias");
@@ -103,12 +107,6 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
               "The batch_norm XPU API return wrong value[%d %s]",
               r, XPUAPIErrorMsg[r]));
     } else {
-      PADDLE_ENFORCE_EQ(
-          data_layout_str == "NCHW", true,
-          platform::errors::InvalidArgument(
-              "The batch_norm_infer 'data_layout' attribute must be NCHW. "
-              "But recevived 'data_layout' is [%s].",
-              data_layout_str));
       const auto *mean = ctx.Input<Tensor>("Mean");
       const auto *variance = ctx.Input<Tensor>("Variance");
       const auto *mean_data = mean->data<float>();
@@ -222,8 +220,12 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
                           "But received: the size of input's dimensions is [%d]",
                           x_dims.size()));
-    int N, C, H, W, D;
+    int N = -1, C = -1, H = -1, W = -1, D = -1;
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
+    N = (N == 0) ? 1 : N;
+    C = (C == 0) ? 1 : C;
+    H = (H == 0) ? 1 : H;
+    W = (W == 0) ? 1 : W;
     const auto *x_data = x->data<T>();
     const auto *d_y_data = d_y->data<T>();
...
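In both the forward and gradient batch_norm kernels above, N/C/H/W/D now start at -1 and any extent left at 0 by ExtractNCWHD (plausible for inputs with fewer than four dimensions) is clamped to 1, so the XPU batch_norm calls always receive positive shape arguments. That reading of the clamp is an inference from the diff, not documented in the commit.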
@@ -38,9 +38,10 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     const std::string padding_algorithm =
         context.Attr<std::string>("padding_algorithm");
-    PADDLE_ENFORCE_EQ(data_format == "NHWC" || data_format == "NDHWC", false,
-                      platform::errors::InvalidArgument(
-                          ("XPU do support data_format is NCHW in conv op.")));
+    PADDLE_ENFORCE_EQ(
+        data_format == "NDHWC", false,
+        platform::errors::InvalidArgument(
+            ("XPU does not support data_format is NDHWC in conv op.")));
     framework::DDim in_data_dims =
         phi::slice_ddim(input->dims(), 2, input->dims().size());
@@ -50,11 +51,18 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                              in_data_dims, strides, ksize);
-    const int batch_size = static_cast<int>(input->dims()[0]);
-    const int img_c = static_cast<int>(input->dims()[1]);
-    const int img_h = static_cast<int>(input->dims()[2]);
-    const int img_w = static_cast<int>(input->dims()[3]);
-    const int f = static_cast<int>(filter.dims()[0]);
+    int batch_size = static_cast<int>(input->dims()[0]);
+    int img_c = static_cast<int>(input->dims()[1]);
+    int img_h = static_cast<int>(input->dims()[2]);
+    int img_w = static_cast<int>(input->dims()[3]);
+    int f = static_cast<int>(filter.dims()[0]);
+    bool is_nchw = true;
+    if (data_format == "NHWC") {
+      img_c = static_cast<int>(input->dims()[3]);
+      img_h = static_cast<int>(input->dims()[1]);
+      img_w = static_cast<int>(input->dims()[2]);
+      is_nchw = false;
+    }
     const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
     const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
@@ -64,7 +72,7 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(
         dev_ctx.x_context(), input_data, filter_data, output_data, batch_size,
         img_c, img_h, img_w, f, ksize, strides, paddings, dilations, groups,
-        nullptr, nullptr, nullptr, true);
+        nullptr, nullptr, nullptr, is_nchw);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU conv kernel return wrong value[%d %s]",
@@ -99,9 +107,9 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
         context.Attr<std::string>("padding_algorithm");
     PADDLE_ENFORCE_EQ(
-        data_format == "NHWC" || data_format == "NDHWC", false,
+        data_format == "NDHWC", false,
         platform::errors::InvalidArgument(
-            ("XPU do support data_format is NCHW in conv grad op.")));
+            ("XPU doesn't support data_format is NDHWC in conv grad op.")));
     framework::DDim in_data_dims =
         phi::slice_ddim(input->dims(), 2, input->dims().size());
@@ -111,11 +119,18 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
     UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                              in_data_dims, strides, ksize);
-    const int batch_size = static_cast<int>(input->dims()[0]);
-    const int img_c = static_cast<int>(input->dims()[1]);
-    const int img_h = static_cast<int>(input->dims()[2]);
-    const int img_w = static_cast<int>(input->dims()[3]);
-    const int f = static_cast<int>(filter.dims()[0]);
+    int batch_size = static_cast<int>(input->dims()[0]);
+    int img_c = static_cast<int>(input->dims()[1]);
+    int img_h = static_cast<int>(input->dims()[2]);
+    int img_w = static_cast<int>(input->dims()[3]);
+    int f = static_cast<int>(filter.dims()[0]);
+    bool is_nchw = true;
+    if (data_format == "NHWC") {
+      img_c = static_cast<int>(input->dims()[3]);
+      img_h = static_cast<int>(input->dims()[1]);
+      img_w = static_cast<int>(input->dims()[2]);
+      is_nchw = false;
+    }
     const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
     const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
@@ -136,7 +151,7 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
         dev_ctx.x_context(), input_data, filter_data, output_grad_data,
         input_grad_data, filter_grad_data, batch_size, img_c, img_h, img_w, f,
         ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr,
-        nullptr, nullptr, true);
+        nullptr, nullptr, is_nchw);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU conv kernel return wrong value[%d %s]",
...
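Net effect: the XPU conv forward and grad kernels now accept data_format == "NHWC" directly, forwarding is_nchw = false to XDNN instead of rejecting everything that is not NCHW; only the 3-D NDHWC layout is still refused by the PADDLE_ENFORCE_EQ checks, and batch_norm inference no longer insists on NCHW.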