Unverified commit ca520280, authored by: Y ykkk2333, committed by: GitHub

add xpu batch norm ncdhw layout, test=kunlun (#50384)

Parent c1f2c52c
......@@ -128,6 +128,9 @@ void BatchNormGradKernel(const Context &dev_ctx,
C = (C == 0) ? 1 : C;
H = (H == 0) ? 1 : H;
W = (W == 0) ? 1 : W;
D = (D == 0) ? 1 : D;
W = W * D;
const auto *x_data = x.data<T>();
const auto *d_y_data = y_grad.data<T>();
......
......@@ -64,6 +64,9 @@ void BatchNormKernel(const Context& dev_ctx,
C = (C == 0) ? 1 : C;
H = (H == 0) ? 1 : H;
W = (W == 0) ? 1 : W;
D = (D == 0) ? 1 : D;
W = W * D;
const auto* x_data = x.data<T>();
const auto* scale_data = scale.data<float>();
......@@ -76,6 +79,14 @@ void BatchNormKernel(const Context& dev_ctx,
dev_ctx.template Alloc<float>(saved_mean);
dev_ctx.template Alloc<float>(saved_variance);
PADDLE_ENFORCE_LE(
x_dims.size(),
5,
phi::errors::InvalidArgument(
"The size of input X's dimensions should be less than 6."
"But received: the size of input X's dimensionss is [%d]",
x_dims.size()));
bool is_nchw = data_layout_str == "NCHW";
if (!global_stats) {
......
......@@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
namespace phi {
......@@ -105,7 +106,34 @@ void ConvGradKernel(const Context& dev_ctx,
filter_grad_data_ptr = filter_grad_data_tmp;
}
}
int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
int fccal_type = FCCalcType<XPUT>();
if (fccal_type == 1) {
int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_grad_data,
input_grad_data,
filter_grad_data_ptr,
batch_size,
img_c,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
} else if (fccal_type == 2) {
int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_grad_data,
......@@ -127,7 +155,33 @@ void ConvGradKernel(const Context& dev_ctx,
nullptr,
nullptr,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
} else {
int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_grad_data,
input_grad_data,
filter_grad_data_ptr,
batch_size,
img_c,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
}
if ((filter_grad_data_ptr != nullptr) && (data_format == "NHWC")) {
std::vector<int> filter_shape_fhwc = {
......@@ -250,7 +304,34 @@ void Conv3DGradKernel(const Context& dev_ctx,
filter_grad_data_ptr = filter_grad_data_tmp;
}
}
int r = xpu::conv3d_grad<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
int fccal_type = FCCalcType<XPUT>();
if (fccal_type == 1) {
int r = xpu::conv3d_grad<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_grad_data,
input_grad_data,
filter_grad_data_ptr,
batch_size,
img_c,
img_d,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad");
} else if (fccal_type == 2) {
int r = xpu::conv3d_grad<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_grad_data,
......@@ -273,7 +354,33 @@ void Conv3DGradKernel(const Context& dev_ctx,
nullptr,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad");
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad");
} else {
int r = xpu::conv3d_grad<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_grad_data,
input_grad_data,
filter_grad_data_ptr,
batch_size,
img_c,
img_d,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad");
}
if ((filter_grad_data_ptr != nullptr) && (data_format == "NDHWC")) {
std::vector<int> filter_shape_fhwc = {filter_shape[0],
......
......@@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
namespace phi {
......@@ -87,7 +88,29 @@ void ConvKernel(const Context& dev_ctx,
filter_data_ptr = reinterpret_cast<const XPUT*>(filter_data_tmp);
}
int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
int fccal_type = FCCalcType<XPUT>();
if (fccal_type == 1) {
int r = xpu::conv2d<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_data,
batch_size,
img_c,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
} else if (fccal_type == 2) {
int r = xpu::conv2d<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_data,
......@@ -105,7 +128,28 @@ void ConvKernel(const Context& dev_ctx,
nullptr,
nullptr,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
} else {
int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_data,
batch_size,
img_c,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
}
}
template <typename T, typename Context>
......@@ -194,7 +238,30 @@ void Conv3DKernel(const Context& dev_ctx,
filter_data_ptr = reinterpret_cast<const XPUT*>(filter_data_tmp);
}
int r = xpu::conv3d<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
int fccal_type = FCCalcType<XPUT>();
if (fccal_type == 1) {
int r = xpu::conv3d<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_data,
batch_size,
img_c,
img_d,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d");
} else if (fccal_type == 2) {
int r = xpu::conv3d<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_data,
......@@ -213,7 +280,30 @@ void Conv3DKernel(const Context& dev_ctx,
nullptr,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d");
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d");
} else {
int r = xpu::conv3d<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_data,
batch_size,
img_c,
img_d,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d");
}
}
} // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register