未验证 提交 731d45a3 编写于 作者: Q qingqing01 提交者: GitHub

Enable BatchNorm to use global mean and variance during training (#14630)

* Enable BatchNorm to use global mean and variance during training
* Update doc and follow comments.
上级 400cf19f
...@@ -69,7 +69,7 @@ paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'] ...@@ -69,7 +69,7 @@ paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name']
paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None)) paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)) paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)) paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False)) paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False))
paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)) paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)) paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
......
...@@ -146,7 +146,9 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -146,7 +146,9 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
const float momentum = ctx.Attr<float>("momentum"); const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu"); const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
bool global_stats = is_test || use_global_stats;
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
const auto *mean = ctx.Input<Tensor>("Mean"); const auto *mean = ctx.Input<Tensor>("Mean");
...@@ -177,13 +179,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -177,13 +179,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
T *batch_mean_data = nullptr; T *batch_mean_data = nullptr;
T *batch_variance_data = nullptr; T *batch_variance_data = nullptr;
if (!is_test) { if (!global_stats) {
batch_mean_data = batch_mean->mutable_data<T>(ctx.GetPlace()); batch_mean_data = batch_mean->mutable_data<T>(ctx.GetPlace());
batch_variance_data = batch_variance->mutable_data<T>(ctx.GetPlace()); batch_variance_data = batch_variance->mutable_data<T>(ctx.GetPlace());
} }
auto propagation = is_test == true ? mkldnn::prop_kind::forward_scoring auto propagation = global_stats == true
: mkldnn::prop_kind::forward_training; ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training;
auto src_tz = paddle::framework::vectorize2int(x->dims()); auto src_tz = paddle::framework::vectorize2int(x->dims());
auto scale_tz = paddle::framework::vectorize2int(scale->dims()); auto scale_tz = paddle::framework::vectorize2int(scale->dims());
...@@ -199,7 +202,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -199,7 +202,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
shift->data<T>() + ic, &scaleshift_data); shift->data<T>() + ic, &scaleshift_data);
unsigned flags = mkldnn::use_scale_shift; unsigned flags = mkldnn::use_scale_shift;
if (is_test) flags |= mkldnn::use_global_stats; if (global_stats) flags |= mkldnn::use_global_stats;
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu; if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
// create mkldnn memory from input x tensor // create mkldnn memory from input x tensor
...@@ -208,7 +211,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -208,7 +211,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// keys for backward pass // keys for backward pass
const std::string key = BatchNormMKLDNNHandler::GetHash( const std::string key = BatchNormMKLDNNHandler::GetHash(
src_tz, epsilon, flags, is_test, input_format, src_tz, epsilon, flags, global_stats, input_format,
ctx.op().Output("SavedMean")); ctx.op().Output("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd"; const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
...@@ -239,7 +242,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -239,7 +242,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
batch_norm_fwd_pd->dst_primitive_desc().desc(), y_data); batch_norm_fwd_pd->dst_primitive_desc().desc(), y_data);
std::shared_ptr<batch_norm_fwd> batch_norm_p; std::shared_ptr<batch_norm_fwd> batch_norm_p;
if (is_test) { if (global_stats) {
// create mkldnn memory for stats (as input) // create mkldnn memory for stats (as input)
std::shared_ptr<memory> mean_memory = std::shared_ptr<memory> mean_memory =
handler.AcquireMeanMemoryFromPrimitive(to_void_cast(mean_data)); handler.AcquireMeanMemoryFromPrimitive(to_void_cast(mean_data));
...@@ -269,7 +272,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -269,7 +272,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*batch_norm_p); pipeline.push_back(*batch_norm_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
if (!is_test) { if (!global_stats) {
// mkldnn only compute stats for current batch // mkldnn only compute stats for current batch
// so we need compute momentum stats via Eigen lib // so we need compute momentum stats via Eigen lib
EigenVectorArrayMap<T> batch_mean_e(batch_mean_data, ic); EigenVectorArrayMap<T> batch_mean_e(batch_mean_data, ic);
......
...@@ -159,6 +159,14 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -159,6 +159,14 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("fuse_with_relu", AddAttr<bool>("fuse_with_relu",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("use_global_stats",
"(bool, default false) Whether to use global mean and "
"variance. In inference or test mode, set use_global_stats "
"to true or is_test true. the behavior is equivalent. "
"In train mode, when setting use_global_stats True, the "
"global mean and variance are also used during train time, "
"the BN acts as scaling and shiffting.")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Batch Normalization. Batch Normalization.
...@@ -190,6 +198,10 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -190,6 +198,10 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
const float momentum = ctx.Attr<float>("momentum"); const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
bool global_stats = is_test || use_global_stats;
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str); framework::StringToDataLayout(data_layout_str);
...@@ -217,7 +229,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -217,7 +229,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
saved_mean->mutable_data<T>(ctx.GetPlace()); saved_mean->mutable_data<T>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace()); saved_variance->mutable_data<T>(ctx.GetPlace());
if (!is_test) { if (!global_stats) {
// saved_xx is use just in this batch of data // saved_xx is use just in this batch of data
EigenVectorArrayMap<T> saved_mean_e( EigenVectorArrayMap<T> saved_mean_e(
saved_mean->mutable_data<T>(ctx.GetPlace()), C); saved_mean->mutable_data<T>(ctx.GetPlace()), C);
...@@ -234,7 +246,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -234,7 +246,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
if ((N * sample_size) == 1) { if ((N * sample_size) == 1) {
LOG(WARNING) << "Only 1 element in normalization dimension, " LOG(WARNING) << "Only 1 element in normalization dimension, "
<< "we skip the batch norm calculation, let y = x."; << "we skip the batch norm calculation, let y = x.";
framework::TensorCopySync(*x, ctx.GetPlace(), y); framework::TensorCopy(*x, ctx.GetPlace(), y);
return; return;
} }
...@@ -277,7 +289,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -277,7 +289,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
// use SavedMean and SavedVariance to do normalize // use SavedMean and SavedVariance to do normalize
Eigen::Array<T, Eigen::Dynamic, 1> inv_std(C); Eigen::Array<T, Eigen::Dynamic, 1> inv_std(C);
if (is_test) { if (global_stats) {
ConstEigenVectorArrayMap<T> var_arr( ConstEigenVectorArrayMap<T> var_arr(
ctx.Input<Tensor>("Variance")->data<T>(), C); ctx.Input<Tensor>("Variance")->data<T>(), C);
inv_std = (var_arr + epsilon).sqrt().inverse(); inv_std = (var_arr + epsilon).sqrt().inverse();
...@@ -289,8 +301,8 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -289,8 +301,8 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
inv_std = saved_inv_std; inv_std = saved_inv_std;
} }
ConstEigenVectorArrayMap<T> mean_arr( ConstEigenVectorArrayMap<T> mean_arr(
is_test ? ctx.Input<Tensor>("Mean")->data<T>() global_stats ? ctx.Input<Tensor>("Mean")->data<T>()
: ctx.Output<Tensor>("SavedMean")->data<T>(), : ctx.Output<Tensor>("SavedMean")->data<T>(),
C); C);
// ((x - est_mean) * (inv_var) * scale + bias // ((x - est_mean) * (inv_var) * scale + bias
...@@ -336,15 +348,27 @@ class BatchNormGradOp : public framework::OperatorWithKernel { ...@@ -336,15 +348,27 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
// check input // check input
PADDLE_ENFORCE(ctx->HasInput("X")); PADDLE_ENFORCE(ctx->HasInput("X"));
PADDLE_ENFORCE(ctx->HasInput("Scale"), ""); PADDLE_ENFORCE(ctx->HasInput("Scale"), "Input(scale) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), ""); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
PADDLE_ENFORCE(ctx->HasInput("SavedMean"), ""); "Input(Y@GRAD) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("SavedVariance"), ""); PADDLE_ENFORCE(ctx->HasInput("SavedMean"),
"Input(SavedMean) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("SavedVariance"),
"Input(SavedVariance) should not be null");
// check output // check output
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), ""); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Scale")), ""); if (ctx->HasOutput(framework::GradVarName("Scale"))) {
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")), ""); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
"Output(Scale@GRAD) and Output(Bias@GRAD) should not be "
"null at same time");
}
const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats");
if (use_global_stats) {
PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_mkldnn"),
"Using global stats during training is not supported "
"in gradient op kernel of batch_norm_mkldnn_op now.");
}
const auto x_dims = ctx->GetInputDim("X"); const auto x_dims = ctx->GetInputDim("X");
const DataLayout data_layout = framework::StringToDataLayout( const DataLayout data_layout = framework::StringToDataLayout(
...@@ -354,8 +378,10 @@ class BatchNormGradOp : public framework::OperatorWithKernel { ...@@ -354,8 +378,10 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
ctx->SetOutputDim(framework::GradVarName("X"), x_dims); ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); if (ctx->HasOutput(framework::GradVarName("Scale"))) {
ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
}
} }
protected: protected:
...@@ -405,6 +431,8 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -405,6 +431,8 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
// SavedVariance have been reverted in forward operator // SavedVariance have been reverted in forward operator
const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance"); const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const float epsilon = ctx.Attr<float>("epsilon");
const DataLayout data_layout = const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str); framework::StringToDataLayout(data_layout_str);
...@@ -419,38 +447,60 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -419,38 +447,60 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
const int sample_size = x->numel() / N / C; const int sample_size = x->numel() / N / C;
ConstEigenVectorArrayMap<T> scale_arr(scale->data<T>(), C);
ConstEigenVectorArrayMap<T> mean_arr(saved_mean->data<T>(), C);
ConstEigenVectorArrayMap<T> inv_var_arr(saved_inv_variance->data<T>(), C);
// init output // init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X")); auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale")); auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias")); auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace()); d_x->mutable_data<T>(ctx.GetPlace());
d_scale->mutable_data<T>(ctx.GetPlace());
d_bias->mutable_data<T>(ctx.GetPlace()); const T *mean_data = saved_mean->data<T>();
const T *inv_var_data = saved_inv_variance->data<T>();
Tensor inv_var_tensor;
if (use_global_stats) {
const auto *running_mean = ctx.Input<Tensor>("Mean");
const auto *running_variance = ctx.Input<Tensor>("Variance");
mean_data = running_mean->data<T>();
T *running_inv_var_data = inv_var_tensor.mutable_data<T>(ctx.GetPlace());
EigenVectorArrayMap<T> inv_var_tmp(running_inv_var_data, C);
ConstEigenVectorArrayMap<T> var_arr(running_variance->data<T>(), C);
inv_var_tmp = (var_arr + epsilon).sqrt().inverse().eval();
inv_var_data = running_inv_var_data;
}
ConstEigenVectorArrayMap<T> scale_arr(scale->data<T>(), C);
ConstEigenVectorArrayMap<T> mean_arr(mean_data, C);
ConstEigenVectorArrayMap<T> inv_var_arr(inv_var_data, C);
T *d_bias_data = nullptr;
T *d_scale_data = nullptr;
if (d_scale && d_bias) {
d_scale->mutable_data<T>(ctx.GetPlace());
d_bias->mutable_data<T>(ctx.GetPlace());
d_bias_data = d_bias->mutable_data<T>(ctx.GetPlace());
d_scale_data = d_scale->mutable_data<T>(ctx.GetPlace());
}
// d_bias = np.sum(d_y, axis=0) // d_bias = np.sum(d_y, axis=0)
// d_scale = np.sum((X - mean) / inv_std * dy, axis=0) // d_scale = np.sum((X - mean) / inv_std * dy, axis=0)
// d_x = (1. / N) * scale * inv_var * (N * d_y - np.sum(d_y, axis=0) // d_x = (1. / N) * scale * inv_var * (N * d_y - np.sum(d_y, axis=0)
// - (X - mean) * inv_var * inv_var * np.sum(d_y * (X - mean), axis=0)) // - (X - mean) * inv_var * inv_var * np.sum(d_y * (X - mean), axis=0))
EigenVectorArrayMap<T> d_bias_arr(d_bias_data, C);
EigenVectorArrayMap<T> d_scale_arr(d_scale_data, C);
EigenVectorArrayMap<T> d_bias_arr(d_bias->mutable_data<T>(ctx.GetPlace()), if (d_scale && d_bias) {
C); d_bias_arr.setZero();
EigenVectorArrayMap<T> d_scale_arr(d_scale->mutable_data<T>(ctx.GetPlace()), d_scale_arr.setZero();
C); }
d_bias_arr.setZero();
d_scale_arr.setZero();
if ((N * sample_size) == 1) { if ((N * sample_size) == 1 && !use_global_stats) {
framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x); framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
return; return;
} }
const auto scale_inv_var_nhw = scale_arr * inv_var_arr / (N * sample_size); int scale_coefff = use_global_stats ? 1 : N * sample_size;
const auto scale_inv_var_nhw = scale_arr * inv_var_arr / scale_coefff;
switch (data_layout) { switch (data_layout) {
case DataLayout::kNCHW: { case DataLayout::kNCHW: {
...@@ -460,19 +510,29 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -460,19 +510,29 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
sample_size, N * C); sample_size, N * C);
d_x_arr.setZero(); d_x_arr.setZero();
for (int nc = 0; nc < N * C; ++nc) { if (d_scale && d_bias) {
int c = nc % C; for (int nc = 0; nc < N * C; ++nc) {
d_bias_arr(c) += d_y_arr.col(nc).sum(); int c = nc % C;
d_scale_arr(c) += d_bias_arr(c) += d_y_arr.col(nc).sum();
((x_arr.col(nc) - mean_arr(c)) * inv_var_arr(c) * d_y_arr.col(nc)) d_scale_arr(c) += ((x_arr.col(nc) - mean_arr(c)) * inv_var_arr(c) *
.sum(); d_y_arr.col(nc))
.sum();
}
} }
for (int nc = 0; nc < N * C; ++nc) { if (!use_global_stats) {
int c = nc % C; for (int nc = 0; nc < N * C; ++nc) {
d_x_arr.col(nc) += int c = nc % C;
scale_inv_var_nhw(c) * d_x_arr.col(nc) +=
(d_y_arr.col(nc) * N * sample_size - d_bias_arr(c) - scale_inv_var_nhw(c) *
(x_arr.col(nc) - mean_arr[c]) * d_scale_arr(c) * inv_var_arr(c)); (d_y_arr.col(nc) * N * sample_size - d_bias_arr(c) -
(x_arr.col(nc) - mean_arr[c]) * d_scale_arr(c) *
inv_var_arr(c));
}
} else {
for (int nc = 0; nc < N * C; ++nc) {
int c = nc % C;
d_x_arr.col(nc) += scale_inv_var_nhw(c) * d_y_arr.col(nc);
}
} }
break; break;
} }
...@@ -488,15 +548,27 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -488,15 +548,27 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
const auto d_y_mul_x_minus_mean_row_sum = const auto d_y_mul_x_minus_mean_row_sum =
(d_y_arr * x_minus_mean).rowwise().sum(); (d_y_arr * x_minus_mean).rowwise().sum();
const auto inv_var_sqr = inv_var_arr * inv_var_arr; const auto inv_var_sqr = inv_var_arr * inv_var_arr;
for (int nhw = 0; nhw < N * sample_size; ++nhw) {
d_bias_arr += d_y_arr.col(nhw); if (d_scale && d_bias) {
d_scale_arr += for (int nhw = 0; nhw < N * sample_size; ++nhw) {
(x_arr.col(nhw) - mean_arr) * inv_var_arr * d_y_arr.col(nhw); d_bias_arr += d_y_arr.col(nhw);
d_x_arr.col(nhw) += d_scale_arr +=
scale_inv_var_nhw * (x_arr.col(nhw) - mean_arr) * inv_var_arr * d_y_arr.col(nhw);
(d_y_arr.col(nhw) * N * sample_size - d_y_row_sum - }
x_minus_mean.col(nhw) * inv_var_sqr * }
d_y_mul_x_minus_mean_row_sum);
if (!use_global_stats) {
for (int nhw = 0; nhw < N * sample_size; ++nhw) {
d_x_arr.col(nhw) +=
scale_inv_var_nhw *
(d_y_arr.col(nhw) * N * sample_size - d_y_row_sum -
x_minus_mean.col(nhw) * inv_var_sqr *
d_y_mul_x_minus_mean_row_sum);
}
} else {
for (int nhw = 0; nhw < N * sample_size; ++nhw) {
d_x_arr.col(nhw) += scale_inv_var_nhw * d_y_arr.col(nhw);
}
} }
break; break;
} }
...@@ -522,6 +594,10 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker { ...@@ -522,6 +594,10 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
op->SetInput("SavedMean", Output("SavedMean")); op->SetInput("SavedMean", Output("SavedMean"));
op->SetInput("SavedVariance", Output("SavedVariance")); op->SetInput("SavedVariance", Output("SavedVariance"));
// used when setting use_global_stats True during training
op->SetInput("Mean", Output("MeanOut"));
op->SetInput("Variance", Output("VarianceOut"));
op->SetAttrMap(Attrs()); op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
......
...@@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/batch_norm_op.h" #include <algorithm>
#include <cfloat> #include <cfloat>
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
...@@ -59,6 +63,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -59,6 +63,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon")); double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const float momentum = ctx.Attr<float>("momentum"); const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str); framework::StringToDataLayout(data_layout_str);
...@@ -121,7 +126,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -121,7 +126,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
// Now, depending on whether we are running test or not, we have two paths. // Now, depending on whether we are running test or not, we have two paths.
if (is_test) { if (is_test || use_global_stats) {
// only when test we use input to do computation. // only when test we use input to do computation.
const auto *est_mean = ctx.Input<Tensor>("Mean"); const auto *est_mean = ctx.Input<Tensor>("Mean");
const auto *est_var = ctx.Input<Tensor>("Variance"); const auto *est_var = ctx.Input<Tensor>("Variance");
...@@ -163,7 +168,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -163,7 +168,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
if ((N * H * W * D) == 1) { if ((N * H * W * D) == 1) {
LOG(WARNING) << "Only 1 element in normalization dimension, " LOG(WARNING) << "Only 1 element in normalization dimension, "
<< "we skip the batch norm calculation, let y = x."; << "we skip the batch norm calculation, let y = x.";
framework::TensorCopySync(*x, ctx.GetPlace(), y); framework::TensorCopy(*x, ctx.GetPlace(), y);
} else { } else {
double this_factor = 1. - momentum; double this_factor = 1. - momentum;
...@@ -191,6 +196,58 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -191,6 +196,58 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
} }
}; };
// Computes the input gradient element-wise as
//   dx[i] = dy[i] * scale[c] * 1 / sqrt(variance[c] + epsilon)
// where c is the channel that flat element i belongs to.
//
//   dy       - upstream gradient, `num` elements
//   scale    - per-channel scale, indexed by channel
//   variance - per-channel variance, indexed by channel
//   epsilon  - added to the variance before taking the square root
//   C        - number of channels
//   HxW      - divisor used to recover the channel index in the NCHW layout
//   num      - total number of elements to process
//   dx       - output gradient, `num` elements
template <typename T, framework::DataLayout layout>
static __global__ void KeBNBackwardData(const T *dy,
                                        const BatchNormParamType<T> *scale,
                                        const BatchNormParamType<T> *variance,
                                        const double epsilon, const int C,
                                        const int HxW, const int num, T *dx) {
  // Grid-stride loop: each thread starts at its global id and advances by
  // the total number of launched threads.
  const int first = blockIdx.x * blockDim.x + threadIdx.x;
  const int step = blockDim.x * gridDim.x;
  for (int idx = first; idx < num; idx += step) {
    // Recover the channel of this flat index for the compiled-in layout.
    int channel;
    if (layout == framework::DataLayout::kNCHW) {
      channel = idx / HxW % C;
    } else {
      channel = idx % C;
    }
    const BatchNormParamType<T> inv_std =
        1.0 / sqrt(variance[channel] + epsilon);
    // Do the arithmetic in the parameter type, then cast back to T.
    const BatchNormParamType<T> grad =
        static_cast<BatchNormParamType<T>>(dy[idx]);
    dx[idx] = static_cast<T>(grad * scale[channel] * inv_std);
  }
}
// Computes the per-channel scale and bias gradients from the supplied mean
// and variance:
//   dscale[c] = sum(dy * (x - mean[c])) * 1 / sqrt(variance[c] + epsilon)
//   dbias[c]  = sum(dy)
// where each sum runs over the N * HxW elements belonging to channel c.
//
//   dy, x          - upstream gradient and forward input, indexed identically
//   mean, variance - per-channel statistics, indexed by channel
//   epsilon        - added to the variance before taking the square root
//   N, C, HxW      - sizes used to build the flat index of each element
//   dscale, dbias  - per-channel outputs, C elements each
//
// NOTE(review): cub::BlockReduce is instantiated with BlockDim, so the kernel
// is presumably launched with blockDim.x == BlockDim -- confirm at the launch
// site.
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ void KeBNBackwardScaleBias(
const T *dy, const T *x, const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *variance, const double epsilon, const int N,
const int C, const int HxW, BatchNormParamType<T> *dscale,
BatchNormParamType<T> *dbias) {
const int outer_size = C;
const int inner_size = N * HxW;
typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
// Separate shared-memory scratch areas for the two reductions.
__shared__ typename BlockReduce::TempStorage ds_storage;
__shared__ typename BlockReduce::TempStorage db_storage;
// Each block handles channels blockIdx.x, blockIdx.x + gridDim.x, ...
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
BatchNormParamType<T> ds_sum = static_cast<BatchNormParamType<T>>(0);
BatchNormParamType<T> db_sum = static_cast<BatchNormParamType<T>>(0);
BatchNormParamType<T> inv_var_i = 1.0 / sqrt(variance[i] + epsilon);
BatchNormParamType<T> mean_i = mean[i];
// The threads of the block split the N * HxW elements of channel i and
// accumulate thread-local partial sums.
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
// Flat offset of the j-th element of channel i for the compiled-in layout.
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
: j * outer_size + i;
ds_sum += static_cast<BatchNormParamType<T>>(dy[index]) *
(static_cast<BatchNormParamType<T>>(x[index]) - mean_i);
db_sum += static_cast<BatchNormParamType<T>>(dy[index]);
}
// Combine the thread-local partial sums across the block.
ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum());
db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum());
// Only thread 0 writes the result for this channel.
if (threadIdx.x == 0) {
dscale[i] = ds_sum * inv_var_i;
dbias[i] = db_sum;
}
// Barrier before the next channel; presumably needed so the shared
// TempStorage is not reused while still in use -- see the CUB docs.
__syncthreads();
}
}
template <typename T> template <typename T>
class BatchNormGradKernel<platform::CUDADeviceContext, T> class BatchNormGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
...@@ -200,6 +257,8 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -200,6 +257,8 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
"It must use CUDAPlace."); "It must use CUDAPlace.");
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon")); double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const DataLayout data_layout = const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str); framework::StringToDataLayout(data_layout_str);
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
...@@ -219,42 +278,13 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -219,42 +278,13 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias")); auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace()); d_x->mutable_data<T>(ctx.GetPlace());
d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace()); if (d_scale && d_bias) {
d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace()); d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if ((N * H * W * D) == 1) {
framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x);
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
functor;
functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
return;
} }
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL); PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(scale->dims()[0], C); PADDLE_ENFORCE_EQ(scale->dims()[0], C);
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
cudnnBatchNormMode_t mode_;
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#if CUDNN_VERSION_MIN(7, 0, 0)
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#else
mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif
std::vector<int> dims; std::vector<int> dims;
std::vector<int> strides; std::vector<int> strides;
if (data_layout == DataLayout::kNCHW) { if (data_layout == DataLayout::kNCHW) {
...@@ -264,34 +294,114 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -264,34 +294,114 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
dims = {N, C, H, W, D}; dims = {N, C, H, W, D};
strides = {H * W * C * D, 1, W * D * C, D * C, C}; strides = {H * W * C * D, 1, W * D * C, D * C, C};
} }
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_));
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
const void *saved_mean_data =
saved_mean->template data<BatchNormParamType<T>>();
const void *saved_var_data =
saved_var->template data<BatchNormParamType<T>>();
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
data_desc_, d_y->template data<T>(), data_desc_,
d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<BatchNormParamType<T>>(),
d_scale->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
epsilon, saved_mean_data, saved_var_data));
// clean when exit. auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); if (!use_global_stats) {
CUDNN_ENFORCE( if ((N * H * W * D) == 1) {
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
functor;
functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
return;
}
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
cudnnBatchNormMode_t mode_;
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#if CUDNN_VERSION_MIN(7, 0, 0)
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#else
mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_));
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
const void *saved_mean_data =
saved_mean->template data<BatchNormParamType<T>>();
const void *saved_var_data =
saved_var->template data<BatchNormParamType<T>>();
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
data_desc_, d_y->template data<T>(), data_desc_,
d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<BatchNormParamType<T>>(),
d_scale->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
epsilon, saved_mean_data, saved_var_data));
// clean when exit.
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
} else {
const auto *running_mean = ctx.Input<Tensor>("Mean");
const auto *running_var = ctx.Input<Tensor>("Variance");
const auto *running_mean_data =
running_mean->template data<BatchNormParamType<T>>();
const auto *running_var_data =
running_var->template data<BatchNormParamType<T>>();
const int num = x->numel();
const int block = 512;
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid1 = (num + block - 1) / block;
int grid2 = std::min(C, max_blocks);
if (data_layout == framework::DataLayout::kNCHW) {
if (d_x) {
KeBNBackwardData<T, framework::DataLayout::kNCHW><<<
grid1, block, 0, dev_ctx.stream()>>>(
d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
running_var_data, epsilon, C, H * W, num, d_x->data<T>());
}
if (d_scale && d_bias) {
KeBNBackwardScaleBias<T, block, framework::DataLayout::kNCHW><<<
grid2, block, 0, dev_ctx.stream()>>>(
d_y->data<T>(), x->data<T>(), running_mean_data, running_var_data,
epsilon, C, H * W, num, d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
} else {
if (d_x) {
KeBNBackwardData<T, framework::DataLayout::kNHWC><<<
grid1, block, 0, dev_ctx.stream()>>>(
d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
running_var_data, epsilon, C, H * W, num, d_x->data<T>());
}
// This is the non-NCHW (else) branch: the KeBNBackwardData launch in this
// same branch is instantiated with kNHWC, so the scale/bias reduction must
// use the same layout tag. It was previously instantiated with kNCHW.
if (d_scale && d_bias) {
  KeBNBackwardScaleBias<T, block, framework::DataLayout::kNHWC><<<
      grid2, block, 0, dev_ctx.stream()>>>(
      d_y->data<T>(), x->data<T>(), running_mean_data, running_var_data,
      epsilon, C, H * W, num, d_scale->data<BatchNormParamType<T>>(),
      d_bias->data<BatchNormParamType<T>>());
}
}
}
} }
}; };
......
...@@ -2300,7 +2300,8 @@ def batch_norm(input, ...@@ -2300,7 +2300,8 @@ def batch_norm(input,
moving_mean_name=None, moving_mean_name=None,
moving_variance_name=None, moving_variance_name=None,
do_model_average_for_mean_and_var=False, do_model_average_for_mean_and_var=False,
fuse_with_relu=False): fuse_with_relu=False,
use_global_stats=False):
""" """
**Batch Normalization Layer** **Batch Normalization Layer**
...@@ -2327,6 +2328,19 @@ def batch_norm(input, ...@@ -2327,6 +2328,19 @@ def batch_norm(input,
\\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\
y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift
When use_global_stats = True, the :math:`\\mu_{\\beta}`
and :math:`\\sigma_{\\beta}^{2}` are not the statistics of one mini-batch.
They are global (or running) statistics, usually obtained from a
pre-trained model.
The training and testing (or inference) have the same behavior:
.. math::
\\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\
\\sigma_{\\beta}^{2} + \\epsilon}} \\\\
y_i &\\gets \\gamma \\hat{x_i} + \\beta
Args: Args:
input(variable): The input variable which is a LoDTensor. input(variable): The input variable which is a LoDTensor.
act(string, Default None): Activation type, linear|relu|prelu|... act(string, Default None): Activation type, linear|relu|prelu|...
...@@ -2349,6 +2363,11 @@ def batch_norm(input, ...@@ -2349,6 +2363,11 @@ def batch_norm(input,
moving_variance_name(string, Default None): The name of the moving_variance which store the global Variance. moving_variance_name(string, Default None): The name of the moving_variance which store the global Variance.
do_model_average_for_mean_and_var(bool, Default False): Do model average for mean and variance or not. do_model_average_for_mean_and_var(bool, Default False): Do model average for mean and variance or not.
fuse_with_relu (bool): if True, this OP performs relu after batch norm. fuse_with_relu (bool): if True, this OP performs relu after batch norm.
use_global_stats(bool, Default False): Whether to use global mean and
variance. In inference or test mode, setting use_global_stats to True
is equivalent to setting is_test to True.
In train mode, when use_global_stats is True, the global mean
and variance are also used during training.
Returns: Returns:
Variable: A tensor variable which is the result after applying batch normalization on the input. Variable: A tensor variable which is the result after applying batch normalization on the input.
...@@ -2381,9 +2400,15 @@ def batch_norm(input, ...@@ -2381,9 +2400,15 @@ def batch_norm(input,
shape=param_shape, shape=param_shape,
dtype=dtype, dtype=dtype,
default_initializer=Constant(1.0)) default_initializer=Constant(1.0))
# setting stop_gradient=True to reduce computation
if use_global_stats and helper.param_attr.learning_rate == 0.:
scale.stop_gradient = True
bias = helper.create_parameter( bias = helper.create_parameter(
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True) attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
# Setting stop_gradient=True to reduce computation: this check is about the
# bias parameter (bias_attr learning rate), so it must freeze `bias`, not
# `scale` (scale is handled by the analogous check on param_attr).
if use_global_stats and helper.bias_attr.learning_rate == 0.:
    bias.stop_gradient = True
mean = helper.create_parameter( mean = helper.create_parameter(
attr=ParamAttr( attr=ParamAttr(
...@@ -2439,7 +2464,8 @@ def batch_norm(input, ...@@ -2439,7 +2464,8 @@ def batch_norm(input,
"epsilon": epsilon, "epsilon": epsilon,
"is_test": is_test, "is_test": is_test,
"use_mkldnn": False, "use_mkldnn": False,
"fuse_with_relu": fuse_with_relu "fuse_with_relu": fuse_with_relu,
"use_global_stats": use_global_stats
}) })
return helper.append_activation(batch_norm_out) return helper.append_activation(batch_norm_out)
......
...@@ -54,6 +54,19 @@ def _reference_testing(x, scale, offset, mean, var, epsilon, data_format): ...@@ -54,6 +54,19 @@ def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
return y return y
def _cal_mean_variance(x, epsilon, data_format):
assert data_format in ['NCHW', 'NHWC']
x_square = x * x
axis = (0, 2, 3) if data_format == 'NCHW' else (0, 1, 2)
C = x.shape[1] if data_format == 'NCHW' else x.shape[-1]
x_square_sum = np.sum(x_square, axis)
x_sum = np.sum(x, axis=axis)
element_count = np.size(x) / C
mean = x_sum / element_count
var = x_square_sum / element_count - mean * mean
return mean, var
def _reference_training(x, scale, offset, epsilon, data_format): def _reference_training(x, scale, offset, epsilon, data_format):
x_shape = x.shape x_shape = x.shape
...@@ -294,7 +307,18 @@ class TestBatchNormOpTraining(unittest.TestCase): ...@@ -294,7 +307,18 @@ class TestBatchNormOpTraining(unittest.TestCase):
self.use_mkldnn = False self.use_mkldnn = False
self.fuse_with_relu = False self.fuse_with_relu = False
self.data_formats = ["NCHW", "NHWC"] self.data_formats = ["NCHW", "NHWC"]
self.momentum = 0.9
self.epsilon = 0.00001
self.init_kernel_type() self.init_kernel_type()
self.init_test_case()
def init_test_case(self):
    """Configure the default case: batch statistics, every gradient fetched."""
    self.fetch_list = [
        'y',
        'mean',
        'variance',
        'saved_mean',
        'saved_variance',
        'x@GRAD',
        'scale@GRAD',
        'bias@GRAD',
    ]
    # Empty set: no gradient output is excluded when building the grad op.
    self.no_grad_set = set()
    self.use_global_stats = False
def __assert_close(self, tensor, np_array, msg, atol=1e-4): def __assert_close(self, tensor, np_array, msg, atol=1e-4):
np.allclose(np.array(tensor), np_array, atol=atol) np.allclose(np.array(tensor), np_array, atol=atol)
...@@ -313,11 +337,22 @@ class TestBatchNormOpTraining(unittest.TestCase): ...@@ -313,11 +337,22 @@ class TestBatchNormOpTraining(unittest.TestCase):
return y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad return y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad
def set_mean_variance(self, scale_shape, x, data_layout):
    """Return the (mean, variance) arrays used as the op's moving statistics.

    Starts from a float32 zero mean and unit variance of shape scale_shape.
    When self.use_global_stats is set, the statistics of x are blended in for
    one step using self.momentum.
    """
    mean = np.zeros(scale_shape).astype(np.float32)
    variance = np.ones(scale_shape).astype(np.float32)
    if not self.use_global_stats:
        return mean, variance
    # computing global mean/variance for one step
    decay = self.momentum
    batch_mean, batch_var = _cal_mean_variance(x, self.epsilon, data_layout)
    mean = batch_mean * (1. - decay) + decay * mean
    variance = batch_var * (1. - decay) + decay * variance
    return mean, variance
def test_forward_backward(self): def test_forward_backward(self):
def test_with_place(place, data_layout, shape): def test_with_place(place, data_layout, shape):
# attr # attr
epsilon = 0.00001 epsilon = self.epsilon
momentum = 0.9 momentum = self.momentum
if data_layout == "NCHW": if data_layout == "NCHW":
n, c, h, w = shape[0], shape[1], shape[2], shape[3] n, c, h, w = shape[0], shape[1], shape[2], shape[3]
else: else:
...@@ -328,9 +363,7 @@ class TestBatchNormOpTraining(unittest.TestCase): ...@@ -328,9 +363,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
x = np.random.random_sample(shape).astype(np.float32) x = np.random.random_sample(shape).astype(np.float32)
scale = np.random.random_sample(scale_shape).astype(np.float32) scale = np.random.random_sample(scale_shape).astype(np.float32)
bias = np.random.random_sample(scale_shape).astype(np.float32) bias = np.random.random_sample(scale_shape).astype(np.float32)
mean = np.zeros(scale_shape).astype(np.float32) mean, variance = self.set_mean_variance(scale_shape, x, data_layout)
variance = np.ones(scale_shape).astype(np.float32)
y_grad = np.random.random_sample(shape).astype(np.float32) y_grad = np.random.random_sample(shape).astype(np.float32)
y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad = self.ref_forward_backward( y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad = self.ref_forward_backward(
...@@ -339,6 +372,9 @@ class TestBatchNormOpTraining(unittest.TestCase): ...@@ -339,6 +372,9 @@ class TestBatchNormOpTraining(unittest.TestCase):
var_dict = locals() var_dict = locals()
var_dict['y@GRAD'] = y_grad var_dict['y@GRAD'] = y_grad
var_dict['x@GRAD'] = x_grad
var_dict['scale@GRAD'] = scale_grad
var_dict['bias@GRAD'] = bias_grad
var_names = [ var_names = [
'x', 'scale', 'bias', 'mean', 'variance', 'y', 'saved_mean', 'x', 'scale', 'bias', 'mean', 'variance', 'y', 'saved_mean',
...@@ -365,9 +401,8 @@ class TestBatchNormOpTraining(unittest.TestCase): ...@@ -365,9 +401,8 @@ class TestBatchNormOpTraining(unittest.TestCase):
}, },
outputs={ outputs={
"Y": block.var('y'), "Y": block.var('y'),
"MeanOut": block.var('mean'), # share the same memory "MeanOut": block.var('mean'), # share memory
"VarianceOut": "VarianceOut": block.var('variance'), # share memory
block.var('variance'), # share the same memory
"SavedMean": block.var('saved_mean'), "SavedMean": block.var('saved_mean'),
"SavedVariance": block.var('saved_variance') "SavedVariance": block.var('saved_variance')
}, },
...@@ -377,13 +412,14 @@ class TestBatchNormOpTraining(unittest.TestCase): ...@@ -377,13 +412,14 @@ class TestBatchNormOpTraining(unittest.TestCase):
"is_test": False, "is_test": False,
"data_layout": data_layout, "data_layout": data_layout,
"use_mkldnn": self.use_mkldnn, "use_mkldnn": self.use_mkldnn,
"fuse_with_relu": self.fuse_with_relu "fuse_with_relu": self.fuse_with_relu,
"use_global_stats": self.use_global_stats
}) })
block.create_var(name='y@GRAD', dtype='float32', shape=y.shape) block.create_var(name='y@GRAD', dtype='float32', shape=y.shape)
# generate backward op_desc # generate backward op_desc
grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc( grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(
bn_op.desc, set(), []) bn_op.desc, self.no_grad_set, [])
grad_op_desc = grad_op_desc_list[0] grad_op_desc = grad_op_desc_list[0]
new_op_desc = block.desc.append_op() new_op_desc = block.desc.append_op()
new_op_desc.copy_from(grad_op_desc) new_op_desc.copy_from(grad_op_desc)
...@@ -403,20 +439,10 @@ class TestBatchNormOpTraining(unittest.TestCase): ...@@ -403,20 +439,10 @@ class TestBatchNormOpTraining(unittest.TestCase):
for name in for name in
['x', 'scale', 'bias', 'mean', 'variance', 'y@GRAD'] ['x', 'scale', 'bias', 'mean', 'variance', 'y@GRAD']
}, },
fetch_list=[ fetch_list=self.fetch_list)
'y', 'mean', 'variance', 'saved_mean', 'saved_variance',
'x@GRAD', 'scale@GRAD', 'bias@GRAD'
])
self.__assert_close(y, out[0], "y")
self.__assert_close(mean_out, out[1], "mean")
self.__assert_close(variance_out, out[2], "variance", 1e-3)
self.__assert_close(saved_mean, out[3], "saved_mean")
self.__assert_close(saved_variance, out[4], "saved_variance", 1e-3)
self.__assert_close(x_grad, out[5], "x_grad")
self.__assert_close(scale_grad, out[6], "scale_grad")
self.__assert_close(bias_grad, out[7], "bias_grad")
for id, name in enumerate(self.fetch_list):
self.__assert_close(var_dict[name], out[id], name)
print("op test forward passed: ", str(place), data_layout) print("op test forward passed: ", str(place), data_layout)
places = [core.CPUPlace()] places = [core.CPUPlace()]
...@@ -432,5 +458,66 @@ class TestBatchNormOpTraining(unittest.TestCase): ...@@ -432,5 +458,66 @@ class TestBatchNormOpTraining(unittest.TestCase):
pass pass
class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
    """Training-mode batch_norm case with use_global_stats switched on.

    The reference implementation below normalizes with the mean/variance it is
    given rather than with statistics computed from the batch.
    """

    def init_test_case(self):
        """Enable global stats; saved_mean/saved_variance are not fetched."""
        self.use_global_stats = True
        self.no_grad_set = set()
        self.fetch_list = [
            'y', 'mean', 'variance', 'x@GRAD', 'scale@GRAD', 'bias@GRAD'
        ]

    def reference_grad(self, x, y_grad, scale, mean, var, epsilon, data_format):
        """Reference gradients w.r.t. x, scale and bias for fixed mean/var.

        Returns (x_grad, grad_scale, grad_offset); x_grad has the layout given
        by data_format, the other two are reduced over all non-channel axes.
        """
        channels_first = data_format == "NCHW"
        if channels_first:
            # Work channel-last so the per-channel vectors broadcast.
            x = np.transpose(x, (0, 2, 3, 1))
            y_grad = np.transpose(y_grad, (0, 2, 3, 1))

        std = np.sqrt(var + epsilon)
        x_grad = scale * y_grad / std
        grad_scale = np.sum(y_grad * (x - mean) / std, axis=(0, 1, 2))
        grad_offset = np.sum(y_grad, axis=(0, 1, 2))

        if channels_first:
            # transfer back to N, C, H, W
            x_grad = np.transpose(x_grad, (0, 3, 1, 2))
        return x_grad, grad_scale, grad_offset

    def ref_forward_backward(self, x, y_grad, scale, bias, mean, variance,
                             epsilon, momentum, shape, data_layout):
        """Reference forward and backward pass using the given mean/variance.

        Returns the tuple (y, mean_out, variance_out, saved_mean,
        saved_variance, x_grad, scale_grad, bias_grad); mean and variance are
        passed through unchanged as the *_out values and as saved_mean.

        Raises:
            ValueError: if data_layout is neither "NCHW" nor "NHWC".
        """
        if data_layout not in ("NCHW", "NHWC"):
            raise ValueError("Unknown data order.")
        channels_first = data_layout == "NCHW"

        # run normalization on a channel-last view of the input
        x_last = np.transpose(x, (0, 2, 3, 1)) if channels_first else x
        normalized = (x_last - mean) / np.sqrt(variance + epsilon)
        y = normalized * scale + bias
        if channels_first:
            # transfer back to N, C, H, W
            y = np.transpose(y, (0, 3, 1, 2))

        saved_variance = 1. / np.sqrt(variance + epsilon)

        # run backward
        x_grad, scale_grad, bias_grad = self.reference_grad(
            x, y_grad, scale, mean, variance, epsilon, data_layout)

        return (y, mean, variance, mean, saved_variance, x_grad, scale_grad,
                bias_grad)
class TestBatchNormOpFreezeStatsAndScaleBiasTraining(
        TestBatchNormOpFreezeStatsTraining):
    """Global-stats case where the scale and bias gradients are excluded."""

    def init_test_case(self):
        """Put scale/bias grads in no_grad_set and fetch only x@GRAD."""
        self.fetch_list = ['y', 'mean', 'variance', 'x@GRAD']
        self.no_grad_set = {'scale@GRAD', 'bias@GRAD'}
        self.use_global_stats = True
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -955,6 +955,15 @@ class TestBook(unittest.TestCase): ...@@ -955,6 +955,15 @@ class TestBook(unittest.TestCase):
print(str(program)) print(str(program))
def test_batch_norm(self):
    """Build a batch_norm layer with default arguments and print the program."""
    program = Program()
    with program_guard(program):
        image = layers.data(
            name='data', shape=[32, 128, 128], dtype="float32")
        layers.batch_norm(image)

    print(str(program))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册