From 88d68c0813b85054121aa923683bd26786ce82c3 Mon Sep 17 00:00:00 2001
From: QingshuChen
Date: Tue, 26 Apr 2022 22:18:34 +0800
Subject: [PATCH] support nhwc format for kunlun conv/batch_norm (#42195)

* support nhwc format for kunlun conv/batch_norm
*test=kunlun

* minor
*test=kunlun
---
 cmake/external/xpu.cmake                    |  2 +-
 paddle/fluid/operators/batch_norm_op_xpu.cc | 18 ++++----
 paddle/fluid/operators/conv_op_xpu.cc       | 49 ++++++++++++++-------
 3 files changed, 43 insertions(+), 26 deletions(-)

diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake
index cda8029bfe..be911eb7ea 100644
--- a/cmake/external/xpu.cmake
+++ b/cmake/external/xpu.cmake
@@ -17,7 +17,7 @@ endif()
 # ubuntu and centos: use output by XDNN API team
 if(NOT DEFINED XPU_XDNN_BASE_URL)
   SET(XPU_XDNN_BASE_URL_WITHOUT_DATE "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
-  SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220412")
+  SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220425")
 else()
   SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
 endif()
diff --git a/paddle/fluid/operators/batch_norm_op_xpu.cc b/paddle/fluid/operators/batch_norm_op_xpu.cc
index da138fb482..0893324c60 100644
--- a/paddle/fluid/operators/batch_norm_op_xpu.cc
+++ b/paddle/fluid/operators/batch_norm_op_xpu.cc
@@ -53,8 +53,12 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
             "But received: the size of input's dimensions is [%d]",
             x_dims.size()));
 
-    int N, C, H, W, D;
+    int N = -1, C = -1, H = -1, W = -1, D = -1;
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
+    N = (N == 0) ? 1 : N;
+    C = (C == 0) ? 1 : C;
+    H = (H == 0) ? 1 : H;
+    W = (W == 0) ? 1 : W;
 
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *bias = ctx.Input<Tensor>("Bias");
@@ -103,12 +107,6 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
             "The batch_norm XPU API return wrong value[%d %s]", r,
             XPUAPIErrorMsg[r]));
     } else {
-      PADDLE_ENFORCE_EQ(
-          data_layout_str == "NCHW", true,
-          platform::errors::InvalidArgument(
-              "The batch_norm_infer 'data_layout' attribute must be NCHW. "
-              "But recevived 'data_layout' is [%s].",
-              data_layout_str));
       const auto *mean = ctx.Input<Tensor>("Mean");
       const auto *variance = ctx.Input<Tensor>("Variance");
       const auto *mean_data = mean->data<float>();
@@ -222,8 +220,12 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
             "But received: the size of input's dimensions is [%d]",
             x_dims.size()));
 
-    int N, C, H, W, D;
+    int N = -1, C = -1, H = -1, W = -1, D = -1;
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
+    N = (N == 0) ? 1 : N;
+    C = (C == 0) ? 1 : C;
+    H = (H == 0) ? 1 : H;
+    W = (W == 0) ? 1 : W;
 
     const auto *x_data = x->data<T>();
     const auto *d_y_data = d_y->data<T>();
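
Note on the batch_norm hunks above: ExtractNCWHD fills N/C/H/W/D from the input
shape according to data_layout, and the patch now initializes those extents to
-1 and clamps any zero extent to 1 before handing them to the XDNN call, which
lets the inference branch drop its NCHW-only enforce. A minimal standalone
sketch of that mapping plus clamp; the helper name and the 4-D restriction are
illustrative only, not the actual ExtractNCWHD from batch_norm_op.h:

    // Illustrative only: mirrors the NCHW/NHWC index mapping and the
    // zero-to-one clamp added by the patch; not the Paddle helper itself.
    #include <cassert>
    #include <string>
    #include <vector>

    void ExtractNchwLike(const std::vector<int> &dims,
                         const std::string &layout, int *N, int *C, int *H,
                         int *W) {
      assert(dims.size() == 4 && (layout == "NCHW" || layout == "NHWC"));
      *N = dims[0];
      if (layout == "NCHW") {
        *C = dims[1]; *H = dims[2]; *W = dims[3];
      } else {  // NHWC: channels move to the last axis
        *H = dims[1]; *W = dims[2]; *C = dims[3];
      }
      // The clamp from the patch: never pass a 0 extent to the XPU kernel.
      *N = (*N == 0) ? 1 : *N;
      *C = (*C == 0) ? 1 : *C;
      *H = (*H == 0) ? 1 : *H;
      *W = (*W == 0) ? 1 : *W;
    }
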
diff --git a/paddle/fluid/operators/conv_op_xpu.cc b/paddle/fluid/operators/conv_op_xpu.cc
index e4751f1f26..cc5c20d392 100644
--- a/paddle/fluid/operators/conv_op_xpu.cc
+++ b/paddle/fluid/operators/conv_op_xpu.cc
@@ -38,9 +38,10 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     const std::string padding_algorithm =
         context.Attr<std::string>("padding_algorithm");
 
-    PADDLE_ENFORCE_EQ(data_format == "NHWC" || data_format == "NDHWC", false,
-                      platform::errors::InvalidArgument(
-                          ("XPU do support data_format is NCHW in conv op.")));
+    PADDLE_ENFORCE_EQ(
+        data_format == "NDHWC", false,
+        platform::errors::InvalidArgument(
+            ("XPU does not support data_format NDHWC in conv op.")));
 
     framework::DDim in_data_dims =
         phi::slice_ddim(input->dims(), 2, input->dims().size());
@@ -50,11 +51,18 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                              in_data_dims, strides, ksize);
 
-    const int batch_size = static_cast<int>(input->dims()[0]);
-    const int img_c = static_cast<int>(input->dims()[1]);
-    const int img_h = static_cast<int>(input->dims()[2]);
-    const int img_w = static_cast<int>(input->dims()[3]);
-    const int f = static_cast<int>(filter.dims()[0]);
+    int batch_size = static_cast<int>(input->dims()[0]);
+    int img_c = static_cast<int>(input->dims()[1]);
+    int img_h = static_cast<int>(input->dims()[2]);
+    int img_w = static_cast<int>(input->dims()[3]);
+    int f = static_cast<int>(filter.dims()[0]);
+    bool is_nchw = true;
+    if (data_format == "NHWC") {
+      img_c = static_cast<int>(input->dims()[3]);
+      img_h = static_cast<int>(input->dims()[1]);
+      img_w = static_cast<int>(input->dims()[2]);
+      is_nchw = false;
+    }
 
     const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
     const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
@@ -64,7 +72,7 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(
        dev_ctx.x_context(), input_data, filter_data, output_data, batch_size,
        img_c, img_h, img_w, f, ksize, strides, paddings, dilations, groups,
-        nullptr, nullptr, nullptr, true);
+        nullptr, nullptr, nullptr, is_nchw);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU conv kernel return wrong value[%d %s]",
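
The gradient kernel below receives the same treatment as the forward kernel
above: only the four extents handed to XDNN and the trailing layout flag
change with data_format. A hedged standalone sketch of that remapping, with
hypothetical names; only the index arithmetic mirrors the patch:

    // Illustrative sketch of the NCHW/NHWC extent selection the patch adds
    // before the xpu::conv2d / xpu::conv2d_grad call sites.
    #include <array>
    #include <string>

    struct ConvExtents {
      int n, c, h, w;
      bool is_nchw;
    };

    ConvExtents SelectConvExtents(const std::array<int, 4> &in_dims,
                                  const std::string &data_format) {
      if (data_format == "NHWC") {
        return {in_dims[0], in_dims[3], in_dims[1], in_dims[2], false};
      }
      // NCHW: the previously hard-coded behaviour.
      return {in_dims[0], in_dims[1], in_dims[2], in_dims[3], true};
    }

For an NHWC input of shape {8, 224, 224, 3} this yields n=8, c=3, h=224,
w=224 and is_nchw=false, which the call sites now forward in place of the
hard-coded true.
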
@@ -99,9 +107,9 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
         context.Attr<std::string>("padding_algorithm");
 
     PADDLE_ENFORCE_EQ(
-        data_format == "NHWC" || data_format == "NDHWC", false,
+        data_format == "NDHWC", false,
         platform::errors::InvalidArgument(
-            ("XPU do support data_format is NCHW in conv grad op.")));
+            ("XPU does not support data_format NDHWC in conv grad op.")));
 
     framework::DDim in_data_dims =
         phi::slice_ddim(input->dims(), 2, input->dims().size());
@@ -111,11 +119,18 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
     UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                              in_data_dims, strides, ksize);
 
-    const int batch_size = static_cast<int>(input->dims()[0]);
-    const int img_c = static_cast<int>(input->dims()[1]);
-    const int img_h = static_cast<int>(input->dims()[2]);
-    const int img_w = static_cast<int>(input->dims()[3]);
-    const int f = static_cast<int>(filter.dims()[0]);
+    int batch_size = static_cast<int>(input->dims()[0]);
+    int img_c = static_cast<int>(input->dims()[1]);
+    int img_h = static_cast<int>(input->dims()[2]);
+    int img_w = static_cast<int>(input->dims()[3]);
+    int f = static_cast<int>(filter.dims()[0]);
+    bool is_nchw = true;
+    if (data_format == "NHWC") {
+      img_c = static_cast<int>(input->dims()[3]);
+      img_h = static_cast<int>(input->dims()[1]);
+      img_w = static_cast<int>(input->dims()[2]);
+      is_nchw = false;
+    }
 
     const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
     const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
@@ -136,7 +151,7 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
         dev_ctx.x_context(), input_data, filter_data, output_grad_data,
         input_grad_data, filter_grad_data, batch_size, img_c, img_h, img_w, f,
         ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr,
-        nullptr, nullptr, true);
+        nullptr, nullptr, is_nchw);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU conv kernel return wrong value[%d %s]",
-- 
GitLab
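
Taken together, the conv changes relax the format check from rejecting both
NHWC and NDHWC to rejecting only NDHWC, and thread the resulting layout
decision through to XDNN as the final boolean argument. A compact sketch of
that decision logic; the free function is hypothetical and not part of the
patch:

    // Hypothetical summary of the new call-site behaviour: NDHWC is
    // rejected, NCHW and NHWC are accepted, and the returned flag is what
    // the patch passes as the last argument of the XDNN conv calls.
    #include <stdexcept>
    #include <string>

    bool LayoutFlagForXdnn(const std::string &data_format) {
      if (data_format == "NDHWC") {
        throw std::invalid_argument(
            "XPU does not support data_format NDHWC in conv op.");
      }
      return data_format != "NHWC";  // true => NCHW, false => NHWC
    }
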