未验证 提交 f33ae206 编写于 作者: L Leo Guo 提交者: GitHub

Adapt to batch_norm_grad op and add align function in roi_align op for kunlun (#39685)

* Adapt to batch_norm_grad op and add align function in
roi_align op for kunlun, *test=kunlun

* Adapt to batch_norm, batch_norm_grad op api for kunlun, and add unit-tests of batch_norm, roi_align. *test=kunlun
上级 728c0624
...@@ -38,23 +38,25 @@ class BatchNormXPUKernel : public framework::OpKernel<T> { ...@@ -38,23 +38,25 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
bool global_stats = test_mode || use_global_stats; bool global_stats = test_mode || use_global_stats;
const auto &data_layout_str = ctx.Attr<std::string>("data_layout"); const auto &data_layout_str = ctx.Attr<std::string>("data_layout");
const auto data_layout = framework::StringToDataLayout(data_layout_str); const auto data_layout = framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW, PADDLE_ENFORCE_EQ(data_layout_str == "NCHW" || data_layout_str == "NHWC",
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The 'data_layout' attribute must be NCHW. But " "The 'data_layout' attribute must be NCHW or NHWC. "
"recevived 'data_layout' is [%s].", "But recevived 'data_layout' is [%s].",
data_layout_str)); data_layout_str));
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims(); const auto &x_dims = x->dims();
PADDLE_ENFORCE_EQ(x_dims.size(), 4, PADDLE_ENFORCE_EQ(
x_dims.size() >= 2 && x_dims.size() <= 5, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The input tensor X's dimension must equal to 4. But " "The size of input's dimensions should be between 2 and 5"
"received X's shape = [%s], X's dimension = [%d].", "But received: the size of input's dimensions is [%d]",
x_dims, x_dims.size())); x_dims.size()));
const int N = x_dims[0];
const int C = x_dims[1]; int N, C, H, W, D;
const int H = x_dims[2]; ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
const int W = x_dims[3];
const auto *scale = ctx.Input<Tensor>("Scale"); const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias"); const auto *bias = ctx.Input<Tensor>("Bias");
const auto *x_data = x->data<T>(); const auto *x_data = x->data<T>();
...@@ -75,6 +77,7 @@ class BatchNormXPUKernel : public framework::OpKernel<T> { ...@@ -75,6 +77,7 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
saved_variance->mutable_data<float>(ctx.GetPlace()); saved_variance->mutable_data<float>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<DeviceContext>(); auto &dev_ctx = ctx.template device_context<DeviceContext>();
bool is_nchw = data_layout_str == "NCHW";
if (!global_stats) { if (!global_stats) {
auto *mean_out_data = mean_out->data<float>(); auto *mean_out_data = mean_out->data<float>();
...@@ -95,7 +98,7 @@ class BatchNormXPUKernel : public framework::OpKernel<T> { ...@@ -95,7 +98,7 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
int r = xpu::batch_norm<T>(dev_ctx.x_context(), x_data, y_data, N, C, H, int r = xpu::batch_norm<T>(dev_ctx.x_context(), x_data, y_data, N, C, H,
W, epsilon, momentum, scale_data, bias_data, W, epsilon, momentum, scale_data, bias_data,
saved_mean_data, saved_variance_data, saved_mean_data, saved_variance_data,
mean_out_data, variance_out_data, true); mean_out_data, variance_out_data, is_nchw);
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External( platform::errors::External(
"The batch_norm XPU API return wrong value[%d %s]", "The batch_norm XPU API return wrong value[%d %s]",
...@@ -107,7 +110,7 @@ class BatchNormXPUKernel : public framework::OpKernel<T> { ...@@ -107,7 +110,7 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
const auto *variance_data = variance->data<float>(); const auto *variance_data = variance->data<float>();
int r = xpu::batch_norm_infer(dev_ctx.x_context(), x_data, y_data, N, C, int r = xpu::batch_norm_infer(dev_ctx.x_context(), x_data, y_data, N, C,
H, W, epsilon, scale_data, bias_data, H, W, epsilon, scale_data, bias_data,
mean_data, variance_data, true); mean_data, variance_data, is_nchw);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS, r, xpu::Error_t::SUCCESS,
platform::errors::External( platform::errors::External(
...@@ -168,11 +171,11 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> { ...@@ -168,11 +171,11 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
const auto data_layout = framework::StringToDataLayout(data_layout_str); const auto data_layout = framework::StringToDataLayout(data_layout_str);
// TODO(guozbin): Transform input tensor from NHWC to NCHW PADDLE_ENFORCE_EQ(data_layout_str == "NCHW" || data_layout_str == "NHWC",
PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The 'data_layout' attribute must be NCHW. But " "The 'data_layout' attribute must be NCHW or NHWC. "
"recevived 'data_layout' is [%s].", "But recevived 'data_layout' is [%s].",
data_layout_str)); data_layout_str));
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X")); auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
...@@ -207,15 +210,15 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> { ...@@ -207,15 +210,15 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
} }
const auto &x_dims = x->dims(); const auto &x_dims = x->dims();
PADDLE_ENFORCE_EQ(x_dims.size(), 4, PADDLE_ENFORCE_EQ(
x_dims.size() >= 2 && x_dims.size() <= 5, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The input tensor X's dimension must equal to 4. But " "The size of input's dimensions should be between 2 and 5"
"received X's shape = [%s], X's dimension = [%d].", "But received: the size of input's dimensions is [%d]",
x_dims, x_dims.size())); x_dims.size()));
const int N = x_dims[0];
const int C = x_dims[1]; int N, C, H, W, D;
const int H = x_dims[2]; ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
const int W = x_dims[3];
const auto *x_data = x->data<T>(); const auto *x_data = x->data<T>();
const auto *d_y_data = d_y->data<T>(); const auto *d_y_data = d_y->data<T>();
...@@ -250,38 +253,35 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> { ...@@ -250,38 +253,35 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
auto &dev_ctx = ctx.template device_context<DeviceContext>(); auto &dev_ctx = ctx.template device_context<DeviceContext>();
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
const T *mean_data = nullptr; const auto *batch_mean = ctx.Input<Tensor>("SavedMean");
const T *inv_var_data = nullptr; const auto *batch_inv_std = ctx.Input<Tensor>("SavedVariance");
const auto *global_mean = ctx.Input<Tensor>("Mean");
const auto *global_var = ctx.Input<Tensor>("Variance");
// TODO(guozibin): hadle the situation case of N * H * W = 1 // TODO(guozibin): hadle the situation case of N * H * W = 1
if (!use_global_stats) { if (is_inplace) {
const auto *saved_mean = ctx.Input<Tensor>("SavedMean"); float *global_inv_std_data;
// SavedVariance have been reverted in forward operator if (use_global_stats) {
const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance"); global_inv_std_data =
mean_data = saved_mean->data<float>(); RAII_GUARD.alloc_l3_or_gm<float>(global_var->numel());
inv_var_data = saved_inv_variance->data<float>();
} else {
const auto *running_mean = ctx.Input<Tensor>("Mean");
const auto *running_variance = ctx.Input<Tensor>("Variance");
mean_data = running_mean->data<float>();
inv_var_data = running_variance->data<float>();
float *running_inv_var_data =
RAII_GUARD.alloc_l3_or_gm<float>(running_variance->numel());
float *epsilon_data = RAII_GUARD.alloc_l3_or_gm<float>(1); float *epsilon_data = RAII_GUARD.alloc_l3_or_gm<float>(1);
int r1 = calculate_inv_var(dev_ctx.x_context(), inv_var_data, epsilon, C, int r1 =
epsilon_data, running_inv_var_data); calculate_inv_var(dev_ctx.x_context(), global_var->data<float>(),
epsilon, C, epsilon_data, global_inv_std_data);
PADDLE_ENFORCE_EQ(r1, XPU_SUCCESS, platform::errors::External( PADDLE_ENFORCE_EQ(r1, XPU_SUCCESS, platform::errors::External(
"XPU API(batch_norm_grad " "XPU API(batch_norm_grad "
"calculate_inv_var function) " "calculate_inv_var function) "
"return wrong value[%d %s]", "return wrong value[%d %s]",
r1, XPUAPIErrorMsg[r1])); r1, XPUAPIErrorMsg[r1]));
inv_var_data = running_inv_var_data;
} }
if (is_inplace) {
auto px = *x; auto px = *x;
auto *inv_std_data =
use_global_stats ? global_inv_std_data : batch_inv_std->data<float>();
auto mean_data = use_global_stats ? global_mean->data<float>()
: batch_mean->data<float>();
int r2 = calculate_inv_BN_Y( int r2 = calculate_inv_BN_Y(
dev_ctx.x_context(), px.mutable_data<T>(ctx.GetPlace()), dev_ctx.x_context(), px.mutable_data<T>(ctx.GetPlace()),
scale->data<float>(), bias->data<float>(), mean_data, inv_var_data, N, scale->data<float>(), bias->data<float>(), mean_data, inv_std_data, N,
C, H * W, x->data<T>()); C, H * W, x->data<T>());
PADDLE_ENFORCE_EQ(r2, XPU_SUCCESS, platform::errors::External( PADDLE_ENFORCE_EQ(r2, XPU_SUCCESS, platform::errors::External(
"XPU API(batch_norm_grad " "XPU API(batch_norm_grad "
...@@ -289,6 +289,15 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> { ...@@ -289,6 +289,15 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
"return wrong value[%d %s]", "return wrong value[%d %s]",
r2, XPUAPIErrorMsg[r2])); r2, XPUAPIErrorMsg[r2]));
} }
int r3;
bool is_nchw = data_layout_str == "NCHW";
if (use_global_stats) {
r3 = xpu::batch_norm_grad<T>(
dev_ctx.x_context(), x_data, d_y_data, d_x_data, N, C, H, W,
scale_data, nullptr, nullptr, d_scale_data, d_bias_data, is_nchw,
global_mean->data<float>(), global_var->data<float>(), epsilon);
} else {
if (!d_x) { if (!d_x) {
d_x_data = RAII_GUARD.alloc_l3_or_gm<T>(x->numel()); d_x_data = RAII_GUARD.alloc_l3_or_gm<T>(x->numel());
} }
...@@ -298,10 +307,11 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> { ...@@ -298,10 +307,11 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
if (!d_bias_data) { if (!d_bias_data) {
d_bias_data = RAII_GUARD.alloc_l3_or_gm<float>(C); d_bias_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
} }
r3 = xpu::batch_norm_grad<T>(
int r3 = xpu::batch_norm_grad<T>( dev_ctx.x_context(), x_data, d_y_data, d_x_data, N, C, H, W,
dev_ctx.x_context(), x_data, d_y_data, d_x_data, N, C, H, W, scale_data, scale_data, batch_mean->data<float>(), batch_inv_std->data<float>(),
mean_data, inv_var_data, d_scale_data, d_bias_data, true); d_scale_data, d_bias_data, is_nchw);
}
PADDLE_ENFORCE_EQ(r3, XPU_SUCCESS, platform::errors::External( PADDLE_ENFORCE_EQ(r3, XPU_SUCCESS, platform::errors::External(
"XPU API(batch_norm_grad) return " "XPU API(batch_norm_grad) return "
"wrong value[%d %s]", "wrong value[%d %s]",
......
...@@ -32,6 +32,7 @@ class XPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -32,6 +32,7 @@ class XPUROIAlignOpKernel : public framework::OpKernel<T> {
auto pooled_width = ctx.Attr<int>("pooled_width"); auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale"); auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto sampling_ratio = ctx.Attr<int>("sampling_ratio"); auto sampling_ratio = ctx.Attr<int>("sampling_ratio");
auto aligned = ctx.Attr<bool>("aligned");
auto in_dims = in->dims(); auto in_dims = in->dims();
int batch_size = in_dims[0]; int batch_size = in_dims[0];
...@@ -117,7 +118,7 @@ class XPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -117,7 +118,7 @@ class XPUROIAlignOpKernel : public framework::OpKernel<T> {
dev_ctx.x_context(), in->data<T>(), dev_ctx.x_context(), in->data<T>(),
out->mutable_data<T>(ctx.GetPlace()), rois->data<T>(), roi_id_data, out->mutable_data<T>(ctx.GetPlace()), rois->data<T>(), roi_id_data,
batch_size, channels, height, width, out->dims()[0], pooled_height, batch_size, channels, height, width, out->dims()[0], pooled_height,
pooled_width, spatial_scale, sampling_ratio, true); pooled_width, spatial_scale, sampling_ratio, true, aligned);
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External( platform::errors::External(
"The roi_align XPU OP return wrong value[%d %s]", r, "The roi_align XPU OP return wrong value[%d %s]", r,
...@@ -143,6 +144,7 @@ class XPUROIAlignGradOpKernel : public framework::OpKernel<T> { ...@@ -143,6 +144,7 @@ class XPUROIAlignGradOpKernel : public framework::OpKernel<T> {
auto pooled_width = ctx.Attr<int>("pooled_width"); auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale"); auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto sampling_ratio = ctx.Attr<int>("sampling_ratio"); auto sampling_ratio = ctx.Attr<int>("sampling_ratio");
auto aligned = ctx.Attr<bool>("aligned");
int rois_num = rois->dims()[0]; int rois_num = rois->dims()[0];
int channels = in->dims()[1]; int channels = in->dims()[1];
...@@ -197,7 +199,7 @@ class XPUROIAlignGradOpKernel : public framework::OpKernel<T> { ...@@ -197,7 +199,7 @@ class XPUROIAlignGradOpKernel : public framework::OpKernel<T> {
dev_ctx.x_context(), out_grad->data<T>(), in_grad->data<T>(), dev_ctx.x_context(), out_grad->data<T>(), in_grad->data<T>(),
rois->data<T>(), roi_id_data, in->dims()[0], channels, height, width, rois->data<T>(), roi_id_data, in->dims()[0], channels, height, width,
out_grad->dims()[0], pooled_height, pooled_width, spatial_scale, out_grad->dims()[0], pooled_height, pooled_width, spatial_scale,
sampling_ratio, true); sampling_ratio, true, aligned);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS, r, xpu::Error_t::SUCCESS,
platform::errors::External( platform::errors::External(
......
...@@ -296,7 +296,9 @@ class TestXPUBatchNormOpUseGlobalStats(unittest.TestCase): ...@@ -296,7 +296,9 @@ class TestXPUBatchNormOpUseGlobalStats(unittest.TestCase):
net2.training = False net2.training = False
y1 = net1(x) y1 = net1(x)
y2 = net2(x) y2 = net2(x)
self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True) self.assertEqual(
np.allclose(
y1.numpy(), y2.numpy(), atol=1e-4), True)
class TestXPUBatchNormUseGlobalStatsCase1(TestXPUBatchNormOpUseGlobalStats): class TestXPUBatchNormUseGlobalStatsCase1(TestXPUBatchNormOpUseGlobalStats):
...@@ -320,5 +322,12 @@ class TestXPUBatchNormUseGlobalStatsCase3(TestXPUBatchNormOpUseGlobalStats): ...@@ -320,5 +322,12 @@ class TestXPUBatchNormUseGlobalStatsCase3(TestXPUBatchNormOpUseGlobalStats):
self.trainable_statistics = True self.trainable_statistics = True
class TestXPUBatchNormUseGlobalStatsCase4(TestXPUBatchNormOpUseGlobalStats):
### train mode
def init_test(self):
self.use_global_stats = True
self.trainable_statistics = False
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -40,7 +40,8 @@ class TestROIAlignOp(XPUOpTest): ...@@ -40,7 +40,8 @@ class TestROIAlignOp(XPUOpTest):
'spatial_scale': self.spatial_scale, 'spatial_scale': self.spatial_scale,
'pooled_height': self.pooled_height, 'pooled_height': self.pooled_height,
'pooled_width': self.pooled_width, 'pooled_width': self.pooled_width,
'sampling_ratio': self.sampling_ratio 'sampling_ratio': self.sampling_ratio,
'aligned': self.continuous_coordinate
} }
self.outputs = {'Out': self.out_data} self.outputs = {'Out': self.out_data}
...@@ -51,6 +52,8 @@ class TestROIAlignOp(XPUOpTest): ...@@ -51,6 +52,8 @@ class TestROIAlignOp(XPUOpTest):
self.height = 8 self.height = 8
self.width = 6 self.width = 6
self.xpu_version = core.get_xpu_device_version(0)
# n, c, h, w # n, c, h, w
self.x_dim = (self.batch_size, self.channels, self.height, self.width) self.x_dim = (self.batch_size, self.channels, self.height, self.width)
...@@ -58,7 +61,10 @@ class TestROIAlignOp(XPUOpTest): ...@@ -58,7 +61,10 @@ class TestROIAlignOp(XPUOpTest):
self.pooled_height = 2 self.pooled_height = 2
self.pooled_width = 2 self.pooled_width = 2
self.sampling_ratio = -1 self.sampling_ratio = -1
if self.xpu_version == core.XPUVersion.XPU1:
self.continuous_coordinate = False
else:
self.continuous_coordinate = bool(np.random.randint(2))
self.x = np.random.random(self.x_dim).astype('float32') self.x = np.random.random(self.x_dim).astype('float32')
def pre_calc(self, x_i, roi_xmin, roi_ymin, roi_bin_grid_h, roi_bin_grid_w, def pre_calc(self, x_i, roi_xmin, roi_ymin, roi_bin_grid_h, roi_bin_grid_w,
...@@ -124,12 +130,16 @@ class TestROIAlignOp(XPUOpTest): ...@@ -124,12 +130,16 @@ class TestROIAlignOp(XPUOpTest):
roi = self.rois[i] roi = self.rois[i]
roi_batch_id = int(roi[0]) roi_batch_id = int(roi[0])
x_i = self.x[roi_batch_id] x_i = self.x[roi_batch_id]
roi_xmin = roi[1] * self.spatial_scale roi_offset = 0.5 if self.continuous_coordinate else 0
roi_ymin = roi[2] * self.spatial_scale roi_xmin = roi[1] * self.spatial_scale - roi_offset
roi_xmax = roi[3] * self.spatial_scale roi_ymin = roi[2] * self.spatial_scale - roi_offset
roi_ymax = roi[4] * self.spatial_scale roi_xmax = roi[3] * self.spatial_scale - roi_offset
roi_width = max(roi_xmax - roi_xmin, 1) roi_ymax = roi[4] * self.spatial_scale - roi_offset
roi_height = max(roi_ymax - roi_ymin, 1) roi_width = roi_xmax - roi_xmin
roi_height = roi_ymax - roi_ymin
if not self.continuous_coordinate:
roi_width = max(roi_width, 1)
roi_height = max(roi_height, 1)
bin_size_h = float(roi_height) / float(self.pooled_height) bin_size_h = float(roi_height) / float(self.pooled_height)
bin_size_w = float(roi_width) / float(self.pooled_width) bin_size_w = float(roi_width) / float(self.pooled_width)
roi_bin_grid_h = self.sampling_ratio if self.sampling_ratio > 0 else \ roi_bin_grid_h = self.sampling_ratio if self.sampling_ratio > 0 else \
...@@ -203,7 +213,8 @@ class TestROIAlignInLodOp(TestROIAlignOp): ...@@ -203,7 +213,8 @@ class TestROIAlignInLodOp(TestROIAlignOp):
'spatial_scale': self.spatial_scale, 'spatial_scale': self.spatial_scale,
'pooled_height': self.pooled_height, 'pooled_height': self.pooled_height,
'pooled_width': self.pooled_width, 'pooled_width': self.pooled_width,
'sampling_ratio': self.sampling_ratio 'sampling_ratio': self.sampling_ratio,
'aligned': self.continuous_coordinate
} }
self.outputs = {'Out': self.out_data} self.outputs = {'Out': self.out_data}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册