Unverified commit 7cb19f57, authored by Qi Li, committed by GitHub

[NPU] BatchNorm support layout of NCL and NLC, test=develop (#35668)

* [NPU] support NCL and NLC for BatchNorm, test=develop

* [NPU] remove debug files, test=develop

* update, test=develop
Parent a29ff4c7
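The change lets the NPU batch-norm kernels accept rank-3 inputs by temporarily viewing them as rank-4, since the underlying NPU operators (BNTrainingReduce, BNInferGrad) only accept 4-D tensors. Below is a minimal numpy sketch of that expansion; the helper name `expand_to_4d` is hypothetical and only illustrates the reshape, it is not part of the kernel code.

```python
import numpy as np

def expand_to_4d(x, data_layout):
    """Mirror the kernel's rank-3 -> rank-4 expansion (illustration only)."""
    if data_layout == "NCHW":   # NCL -> NCL1: append a trailing length-1 axis
        return x.reshape(x.shape[0], x.shape[1], x.shape[2], 1)
    else:                       # NLC -> NL1C: insert a length-1 axis before C
        return x.reshape(x.shape[0], x.shape[1], 1, x.shape[2])

x_ncl = np.random.rand(3, 8, 5).astype("float32")   # read as N=3, C=8, L=5
x_nlc = np.random.rand(3, 8, 5).astype("float32")   # read as N=3, L=8, C=5
print(expand_to_4d(x_ncl, "NCHW").shape)   # (3, 8, 5, 1)
print(expand_to_4d(x_nlc, "NHWC").shape)   # (3, 8, 1, 5)
```

For NCHW the extra axis goes last (NCL -> NCL1); for NHWC it goes in front of the channel axis (NLC -> NL1C), matching the push_back / insert logic in the kernel changes below.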
@@ -38,11 +38,13 @@ class NPUBatchNormOpKernel : public framework::OpKernel<T> {
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
PADDLE_ENFORCE_EQ(x_dims.size(), 4,
PADDLE_ENFORCE_EQ(
(x_dims.size() == 4UL || x_dims.size() == 3UL), true,
platform::errors::InvalidArgument(
"The input tensor X's dimension must equal to 4. But "
"received X's shape = [%s], X's dimension = [%d].",
x_dims, x_dims.size()));
"The input tensor X's dimension must equal to 3 or 4. "
" But got X's shape = [%s], X's dimension = [%d].",
x_dims.to_str(), x_dims.size()));
const auto *running_mean = ctx.Input<Tensor>("Mean");
const auto *running_var = ctx.Input<Tensor>("Variance");
const auto *scale = ctx.Input<Tensor>("Scale");
@@ -51,8 +53,11 @@ class NPUBatchNormOpKernel : public framework::OpKernel<T> {
auto *y = ctx.Output<Tensor>("Y");
y->mutable_data<T>(ctx.GetPlace());
Tensor x_tensor(x->type());
Tensor y_tensor(y->type());
auto &dev_ctx = ctx.template device_context<NPUDeviceContext>();
auto x_tensor =
ctx.AllocateTmpTensor<T, NPUDeviceContext>(x->dims(), dev_ctx);
auto y_tensor =
ctx.AllocateTmpTensor<T, NPUDeviceContext>(y->dims(), dev_ctx);
x_tensor.ShareDataWith(*x);
y_tensor.ShareDataWith(*y);
if (data_layout == DataLayout::kNHWC) {
@@ -89,6 +94,18 @@ class NPUBatchNormOpKernel : public framework::OpKernel<T> {
sum.mutable_data<float>(running_mean->dims(), ctx.GetPlace());
square_sum.mutable_data<float>(running_mean->dims(), ctx.GetPlace());
// BNTrainingReduce ONLY supports rank = 4
if (x->dims().size() == 3) {
auto x_shape_vec = framework::vectorize(x->dims());
if (data_layout == DataLayout::kNCHW) {
x_shape_vec.push_back(1); // expand NCL -> NCL1
} else {
x_shape_vec.insert(x_shape_vec.begin() + 2, 1); // expand NLC -> NL1C
}
auto x_new_shape = framework::make_ddim(x_shape_vec);
x_tensor.Resize(x_new_shape);
}
const auto &runner_reduce =
NpuOpRunner("BNTrainingReduce", {x_tensor}, {sum, square_sum},
{{"epsilon", epsilon}});
@@ -127,8 +144,11 @@ class NPUBatchNormGradOpKernel : public framework::OpKernel<T> {
use_global_stats = is_test || use_global_stats;
Tensor x_tensor(x->type());
Tensor dy_tensor(d_y->type());
auto &dev_ctx = ctx.template device_context<NPUDeviceContext>();
auto x_tensor =
ctx.AllocateTmpTensor<T, NPUDeviceContext>(x->dims(), dev_ctx);
auto dy_tensor =
ctx.AllocateTmpTensor<T, NPUDeviceContext>(d_y->dims(), dev_ctx);
x_tensor.ShareDataWith(*x);
dy_tensor.ShareDataWith(*d_y);
if (data_layout == DataLayout::kNHWC) {
@@ -136,14 +156,14 @@ class NPUBatchNormGradOpKernel : public framework::OpKernel<T> {
dy_tensor.set_layout(DataLayout::kNHWC);
}
Tensor scale_grad_tmp(scale->type());
Tensor bias_grad_tmp(bias->type());
auto scale_grad_tmp =
ctx.AllocateTmpTensor<T, NPUDeviceContext>(scale->dims(), dev_ctx);
auto bias_grad_tmp =
ctx.AllocateTmpTensor<T, NPUDeviceContext>(bias->dims(), dev_ctx);
if (d_scale == nullptr) {
scale_grad_tmp.Resize(scale->dims());
d_scale = &scale_grad_tmp;
}
if (d_bias == nullptr) {
bias_grad_tmp.Resize(bias->dims());
d_bias = &bias_grad_tmp;
}
@@ -169,9 +189,23 @@ class NPUBatchNormGradOpKernel : public framework::OpKernel<T> {
}
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
Tensor dx_tensor(d_x->type());
auto dx_tensor =
ctx.AllocateTmpTensor<T, NPUDeviceContext>(d_x->dims(), dev_ctx);
dx_tensor.ShareDataWith(*d_x);
if (use_global_stats) {
if (x->dims().size() == 3) {
// BNInferGrad only supports rank = 4
auto x_shape_vec = framework::vectorize(d_x->dims());
if (data_layout == DataLayout::kNCHW) {
x_shape_vec.push_back(1); // expand NCL -> NCL1
} else {
x_shape_vec.insert(x_shape_vec.begin() + 2,
1); // expand NLC -> NL1C
}
auto x_new_shape = framework::make_ddim(x_shape_vec);
dx_tensor.Resize(x_new_shape);
dy_tensor.Resize(x_new_shape);
}
const auto *running_var = ctx.Input<Tensor>("Variance");
const auto &runner_infer =
NpuOpRunner("BNInferGrad", {dy_tensor, *scale, *running_var},
......
@@ -186,11 +186,6 @@ class DepthwiseConvGradNPUKernel : public framework::OpKernel<T> {
dilations[3] = dilation[1];
}
// LOG(INFO) << "strides = " << framework::make_ddim(strides).to_str();
// LOG(INFO) << "dilations = " << framework::make_ddim(dilations).to_str();
// LOG(INFO) << "padding = " << framework::make_ddim(padding).to_str();
// LOG(INFO) << "data_format = " << data_format;
if (filter_grad) {
filter_grad->mutable_data<T>(ctx.GetPlace());
......
@@ -45,6 +45,14 @@ class TestBatchNormOpInference(unittest.TestCase):
if len(shape) == 2:
x_shape = shape
c = x_shape[1]
elif len(shape) == 3:
n, l, c = shape[0], shape[1], shape[2]
if data_layout == "NHWC": # NLC
x_shape = [n, l, c]
elif data_layout == "NCHW": # NCL
x_shape = [n, c, l]
else:
raise ValueError("Unknown data layout.")
else:
n, h, w, c = shape[0], shape[1], shape[2], shape[3]
if data_layout == "NHWC":
@@ -117,6 +125,7 @@ class TestBatchNormOpInference(unittest.TestCase):
place = core.NPUPlace(0)
for data_format in self.data_formats:
self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])
self.check_with_place(place, data_format, self.dtype, [3, 8, 5])
def init_kernel_type(self):
pass
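The new `[3, 8, 5]` case added above exercises the 3-D path for both layouts: under "NCHW" it is read as NCL (N=3, C=8, L=5), under "NHWC" as NLC (N=3, L=8, C=5). As a rough sketch of what the inference check expects (using a hypothetical `ref_infer` helper for illustration, not the test file's own reference function):

```python
import numpy as np

def ref_infer(x, scale, offset, mean, var, eps, data_layout):
    # Eval-mode batch norm on a 3-D input: normalize along the channel axis
    # with the running statistics, then apply the affine scale/offset.
    c_axis = 1 if data_layout == "NCHW" else -1   # NCL vs NLC
    shape = [1, 1, 1]
    shape[c_axis] = x.shape[c_axis]
    inv_std = 1.0 / np.sqrt(var.reshape(shape) + eps)
    return (x - mean.reshape(shape)) * inv_std * scale.reshape(shape) \
        + offset.reshape(shape)

x = np.random.rand(3, 8, 5).astype("float32")     # NCL: N=3, C=8, L=5
c = 8
y = ref_infer(x, np.ones(c), np.zeros(c), np.zeros(c), np.ones(c), 1e-5, "NCHW")
print(y.shape)   # (3, 8, 5)
```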
@@ -185,6 +194,15 @@ class TestBatchNormOpTraining(unittest.TestCase):
# attr
epsilon = self.epsilon
momentum = self.momentum
if len(shape) == 3:
if data_layout == "NHWC": # NLC
n, l, c = shape[0], shape[1], shape[2]
elif data_layout == "NCHW": # NCL
n, c, l = shape[0], shape[1], shape[2]
else:
raise ValueError("Unknown data layout.")
else:
if data_layout == "NCHW":
n, c, h, w = shape[0], shape[1], shape[2], shape[3]
else:
@@ -296,6 +314,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
for data_format in self.data_formats:
test_with_place(core.NPUPlace(0), data_format, [2, 3, 4, 5])
test_with_place(core.NPUPlace(0), data_format, [3, 8, 5])
def init_kernel_type(self):
pass
@@ -328,6 +347,17 @@ class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
]
def reference_grad(self, x, y_grad, scale, mean, var, epsilon, data_format):
x_shape = x.shape
if len(x_shape) == 3:
if data_format == "NCHW": # NCL -> NCL1
x = np.reshape(x, (x_shape[0], x_shape[1], x_shape[2], 1))
y_grad = np.reshape(y_grad,
(x_shape[0], x_shape[1], x_shape[2], 1))
else: # NLC -> NL1C
x = np.reshape(x, (x_shape[0], x_shape[1], 1, x_shape[2]))
y_grad = np.reshape(y_grad,
(x_shape[0], x_shape[1], 1, x_shape[2]))
if data_format == "NCHW":
x = np.transpose(x, (0, 2, 3, 1))
y_grad = np.transpose(y_grad, (0, 2, 3, 1))
@@ -343,6 +373,9 @@ class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
x = np.transpose(x, (0, 3, 1, 2))
y_grad = np.transpose(y_grad, (0, 3, 1, 2))
if len(x_shape) == 3:
x_grad = np.reshape(x_grad, x_shape)
return x_grad, grad_scale, grad_offset
def ref_forward_backward(self, x, y_grad, scale, bias, mean, variance,
@@ -350,6 +383,17 @@ class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
if data_layout != "NCHW" and data_layout != "NHWC":
raise ValueError("Unknown data order.")
x_shape = x.shape
if len(x_shape) == 3:
if data_layout == "NCHW": # NCL -> NCL1
x = np.reshape(x, (x_shape[0], x_shape[1], x_shape[2], 1))
y_grad = np.reshape(y_grad,
(x_shape[0], x_shape[1], x_shape[2], 1))
else: # NLC -> NL1C
x = np.reshape(x, (x_shape[0], x_shape[1], 1, x_shape[2]))
y_grad = np.reshape(y_grad,
(x_shape[0], x_shape[1], 1, x_shape[2]))
if data_layout == "NCHW":
x = np.transpose(x, (0, 2, 3, 1))
@@ -369,6 +413,10 @@ class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
x_grad, scale_grad, bias_grad = self.reference_grad(
x, y_grad, scale, mean, variance, epsilon, data_layout)
if len(x_shape) == 3:
y = np.reshape(y, x_shape)
x_grad = np.reshape(x_grad, x_shape)
return y, mean_out, variance_out, mean, saved_variance, x_grad, scale_grad, bias_grad
......
@@ -36,6 +36,11 @@ def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
else:
x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
if len(x_shape) == 3:
if data_format == "NCHW": # NCL -> NCL1
x = np.reshape(x, (x_shape[0], x_shape[1], x_shape[2], 1))
else: # NLC -> NL1C
x = np.reshape(x, (x_shape[0], x_shape[1], 1, x_shape[2]))
if data_format == "NCHW":
n, c, h, w = x.shape
@@ -55,13 +60,19 @@ def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
else:
raise ValueError("Unknown data order.")
if len(x_shape) == 2:
if len(x_shape) == 2 or len(x_shape) == 3:
y = np.reshape(y, x_shape)
return y
def _cal_mean_variance(x, epsilon, data_format):
assert data_format in ['NCHW', 'NHWC']
x_shape = x.shape
if len(x_shape) == 3:
if data_format == "NCHW": # NCL -> NCL1
x = np.reshape(x, (x_shape[0], x_shape[1], x_shape[2], 1))
else: # NLC -> NL1C
x = np.reshape(x, (x_shape[0], x_shape[1], 1, x_shape[2]))
x_square = x * x
axis = (0, 2, 3) if data_format == 'NCHW' else (0, 1, 2)
C = x.shape[1] if data_format == 'NCHW' else x.shape[-1]
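Because the reduction axes are (0, 2, 3) for NCHW and (0, 1, 2) for NHWC, appending or inserting a length-1 axis leaves the per-channel statistics unchanged, which is why the NCL -> NCL1 and NLC -> NL1C expansions used throughout this patch are safe. A quick numpy check (illustration only, not part of the test file):

```python
import numpy as np

x = np.random.rand(3, 8, 5).astype("float32")         # NCL: N=3, C=8, L=5
mean_3d = x.mean(axis=(0, 2))                          # per-channel mean, rank-3 view
mean_4d = x.reshape(3, 8, 5, 1).mean(axis=(0, 2, 3))   # same mean after NCL -> NCL1
var_3d = x.var(axis=(0, 2))
var_4d = x.reshape(3, 8, 5, 1).var(axis=(0, 2, 3))
print(np.allclose(mean_3d, mean_4d), np.allclose(var_3d, var_4d))  # True True
```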
@@ -76,6 +87,12 @@ def _cal_mean_variance(x, epsilon, data_format):
def _reference_training(x, scale, offset, epsilon, data_format):
x_shape = x.shape
if len(x_shape) == 3:
if data_format == "NCHW": # NCL -> NCL1
x = np.reshape(x, (x_shape[0], x_shape[1], x_shape[2], 1))
else: # NLC -> NL1C
x = np.reshape(x, (x_shape[0], x_shape[1], 1, x_shape[2]))
if data_format == "NCHW":
n, c, h, w = x.shape
x_square = x * x
@@ -94,7 +111,6 @@ def _reference_training(x, scale, offset, epsilon, data_format):
offset_tile = np.reshape(offset, (1, c, 1, 1))
offset_tile = np.reshape(offset_tile, (1, c, 1, 1))
y = normalized * scale_tile + offset_tile
return y, mean, var
elif data_format == "NHWC":
x_square = x * x
x_square_sum = np.sum(x_square, (0, 1, 2))
@@ -104,10 +120,13 @@ def _reference_training(x, scale, offset, epsilon, data_format):
var = x_square_sum / element_count - mean * mean
normalized = (x - mean) / np.sqrt(var + epsilon)
y = normalized * scale + offset
return y, mean, var
else:
raise ValueError("Unknown data order.")
if len(x_shape) == 3:
y = np.reshape(y, x_shape)
return y, mean, var
def _reference_grad(x, y_grad, scale, mean, var, epsilon, data_format):
# Use the following formulas to calculate gradients:
@@ -124,6 +143,15 @@ def _reference_grad(x, y_grad, scale, mean, var, epsilon, data_format):
if data_format != "NCHW" and data_format != "NHWC":
raise ValueError("Unknown data order.")
x_shape = x.shape
if len(x_shape) == 3:
if data_format == "NCHW": # NCL -> NCL1
x = np.reshape(x, (x_shape[0], x_shape[1], x_shape[2], 1))
y_grad = np.reshape(y_grad, (x_shape[0], x_shape[1], x_shape[2], 1))
else: # NLC -> NL1C
x = np.reshape(x, (x_shape[0], x_shape[1], 1, x_shape[2]))
y_grad = np.reshape(y_grad, (x_shape[0], x_shape[1], 1, x_shape[2]))
if data_format == "NCHW":
x = np.transpose(x, (0, 2, 3, 1))
y_grad = np.transpose(y_grad, (0, 2, 3, 1))
@@ -142,6 +170,9 @@ def _reference_grad(x, y_grad, scale, mean, var, epsilon, data_format):
x = np.transpose(x, (0, 3, 1, 2))
y_grad = np.transpose(y_grad, (0, 3, 1, 2))
if len(x_shape) == 3:
x_grad = np.reshape(x_grad, x_shape)
return x_grad, grad_scale, grad_offset
......