未验证 提交 ca520280 编写于 作者: Y ykkk2333 提交者: GitHub

add xpu batch norm ncdhw layout, test=kunlun (#50384)

上级 c1f2c52c
...@@ -128,6 +128,9 @@ void BatchNormGradKernel(const Context &dev_ctx, ...@@ -128,6 +128,9 @@ void BatchNormGradKernel(const Context &dev_ctx,
C = (C == 0) ? 1 : C; C = (C == 0) ? 1 : C;
H = (H == 0) ? 1 : H; H = (H == 0) ? 1 : H;
W = (W == 0) ? 1 : W; W = (W == 0) ? 1 : W;
D = (D == 0) ? 1 : D;
W = W * D;
const auto *x_data = x.data<T>(); const auto *x_data = x.data<T>();
const auto *d_y_data = y_grad.data<T>(); const auto *d_y_data = y_grad.data<T>();
......
...@@ -64,6 +64,9 @@ void BatchNormKernel(const Context& dev_ctx, ...@@ -64,6 +64,9 @@ void BatchNormKernel(const Context& dev_ctx,
C = (C == 0) ? 1 : C; C = (C == 0) ? 1 : C;
H = (H == 0) ? 1 : H; H = (H == 0) ? 1 : H;
W = (W == 0) ? 1 : W; W = (W == 0) ? 1 : W;
D = (D == 0) ? 1 : D;
W = W * D;
const auto* x_data = x.data<T>(); const auto* x_data = x.data<T>();
const auto* scale_data = scale.data<float>(); const auto* scale_data = scale.data<float>();
...@@ -76,6 +79,14 @@ void BatchNormKernel(const Context& dev_ctx, ...@@ -76,6 +79,14 @@ void BatchNormKernel(const Context& dev_ctx,
dev_ctx.template Alloc<float>(saved_mean); dev_ctx.template Alloc<float>(saved_mean);
dev_ctx.template Alloc<float>(saved_variance); dev_ctx.template Alloc<float>(saved_variance);
// Reject inputs with more than 5 dimensions before dispatching on layout.
// The message literals are concatenated by the compiler, so the first one
// must end with a space to keep the two sentences separated.
PADDLE_ENFORCE_LE(
x_dims.size(),
5,
phi::errors::InvalidArgument(
"The size of input X's dimensions should be less than 6. "
"But received: the size of input X's dimensions is [%d]",
x_dims.size()));
bool is_nchw = data_layout_str == "NCHW"; bool is_nchw = data_layout_str == "NCHW";
if (!global_stats) { if (!global_stats) {
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
namespace phi { namespace phi {
...@@ -105,7 +106,34 @@ void ConvGradKernel(const Context& dev_ctx, ...@@ -105,7 +106,34 @@ void ConvGradKernel(const Context& dev_ctx,
filter_grad_data_ptr = filter_grad_data_tmp; filter_grad_data_ptr = filter_grad_data_tmp;
} }
} }
int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(), int fccal_type = FCCalcType<XPUT>();
if (fccal_type == 1) {
int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_grad_data,
input_grad_data,
filter_grad_data_ptr,
batch_size,
img_c,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
} else if (fccal_type == 2) {
int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
input_data, input_data,
filter_data_ptr, filter_data_ptr,
output_grad_data, output_grad_data,
...@@ -127,7 +155,33 @@ void ConvGradKernel(const Context& dev_ctx, ...@@ -127,7 +155,33 @@ void ConvGradKernel(const Context& dev_ctx,
nullptr, nullptr,
nullptr, nullptr,
is_nchw); is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
} else {
int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_grad_data,
input_grad_data,
filter_grad_data_ptr,
batch_size,
img_c,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
}
if ((filter_grad_data_ptr != nullptr) && (data_format == "NHWC")) { if ((filter_grad_data_ptr != nullptr) && (data_format == "NHWC")) {
std::vector<int> filter_shape_fhwc = { std::vector<int> filter_shape_fhwc = {
...@@ -250,7 +304,34 @@ void Conv3DGradKernel(const Context& dev_ctx, ...@@ -250,7 +304,34 @@ void Conv3DGradKernel(const Context& dev_ctx,
filter_grad_data_ptr = filter_grad_data_tmp; filter_grad_data_ptr = filter_grad_data_tmp;
} }
} }
int r = xpu::conv3d_grad<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(), int fccal_type = FCCalcType<XPUT>();
if (fccal_type == 1) {
int r = xpu::conv3d_grad<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_grad_data,
input_grad_data,
filter_grad_data_ptr,
batch_size,
img_c,
img_d,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad");
} else if (fccal_type == 2) {
int r = xpu::conv3d_grad<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
input_data, input_data,
filter_data_ptr, filter_data_ptr,
output_grad_data, output_grad_data,
...@@ -273,7 +354,33 @@ void Conv3DGradKernel(const Context& dev_ctx, ...@@ -273,7 +354,33 @@ void Conv3DGradKernel(const Context& dev_ctx,
nullptr, nullptr,
nullptr, nullptr,
is_ncdhw); is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad");
} else {
int r = xpu::conv3d_grad<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_grad_data,
input_grad_data,
filter_grad_data_ptr,
batch_size,
img_c,
img_d,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad");
}
if ((filter_grad_data_ptr != nullptr) && (data_format == "NDHWC")) { if ((filter_grad_data_ptr != nullptr) && (data_format == "NDHWC")) {
std::vector<int> filter_shape_fhwc = {filter_shape[0], std::vector<int> filter_shape_fhwc = {filter_shape[0],
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
namespace phi { namespace phi {
...@@ -87,7 +88,29 @@ void ConvKernel(const Context& dev_ctx, ...@@ -87,7 +88,29 @@ void ConvKernel(const Context& dev_ctx,
filter_data_ptr = reinterpret_cast<const XPUT*>(filter_data_tmp); filter_data_ptr = reinterpret_cast<const XPUT*>(filter_data_tmp);
} }
int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(), int fccal_type = FCCalcType<XPUT>();
if (fccal_type == 1) {
int r = xpu::conv2d<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_data,
batch_size,
img_c,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
} else if (fccal_type == 2) {
int r = xpu::conv2d<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
input_data, input_data,
filter_data_ptr, filter_data_ptr,
output_data, output_data,
...@@ -105,7 +128,28 @@ void ConvKernel(const Context& dev_ctx, ...@@ -105,7 +128,28 @@ void ConvKernel(const Context& dev_ctx,
nullptr, nullptr,
nullptr, nullptr,
is_nchw); is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
} else {
int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_data,
batch_size,
img_c,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
}
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -194,7 +238,30 @@ void Conv3DKernel(const Context& dev_ctx, ...@@ -194,7 +238,30 @@ void Conv3DKernel(const Context& dev_ctx,
filter_data_ptr = reinterpret_cast<const XPUT*>(filter_data_tmp); filter_data_ptr = reinterpret_cast<const XPUT*>(filter_data_tmp);
} }
int r = xpu::conv3d<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(), int fccal_type = FCCalcType<XPUT>();
if (fccal_type == 1) {
int r = xpu::conv3d<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_data,
batch_size,
img_c,
img_d,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d");
} else if (fccal_type == 2) {
int r = xpu::conv3d<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
input_data, input_data,
filter_data_ptr, filter_data_ptr,
output_data, output_data,
...@@ -213,7 +280,30 @@ void Conv3DKernel(const Context& dev_ctx, ...@@ -213,7 +280,30 @@ void Conv3DKernel(const Context& dev_ctx,
nullptr, nullptr,
nullptr, nullptr,
is_ncdhw); is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d");
} else {
int r = xpu::conv3d<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_data,
batch_size,
img_c,
img_d,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d");
}
} }
} // namespace phi } // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册