diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 9aeccde1256c1803c713cd4c7d28b4a4a7781095..684d2ef8628fdb453f1c97b5133ceed499bde170 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -33,7 +33,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "BatchNorm"); bool is_test = ctx->Attrs().Get("is_test"); - if (!is_test) { + bool trainable_stats = ctx->Attrs().Get("trainable_statistics"); + bool test_mode = is_test && (!trainable_stats); + if (!test_mode) { OP_INOUT_CHECK(ctx->HasOutput("MeanOut"), "Output", "MeanOut", "BatchNorm"); OP_INOUT_CHECK(ctx->HasOutput("VarianceOut"), "Output", "VarianceOut", "BatchNorm"); @@ -258,7 +260,11 @@ void BatchNormOpMaker::Make() { "global mean and variance are also used during train time, " "the BN acts as scaling and shiffting.") .SetDefault(false); - + AddAttr("trainable_statistics", + "(bool, default false) Whether to calculate mean and variance " + "in test mode. If setting true in test mode, mean and variace " + "will be calculated by current batch statistics.") + .SetDefault(false); AddComment(R"DOC( Batch Normalization. @@ -281,8 +287,10 @@ class BatchNormKernel float momentum = ctx.Attr("momentum"); const bool is_test = ctx.Attr("is_test"); const bool use_global_stats = ctx.Attr("use_global_stats"); + const bool trainable_stats = ctx.Attr("trainable_statistics"); + bool test_mode = is_test && (!trainable_stats); - bool global_stats = is_test || use_global_stats; + bool global_stats = test_mode || use_global_stats; const std::string data_layout_str = ctx.Attr("data_layout"); const DataLayout data_layout = diff --git a/paddle/fluid/operators/batch_norm_op.cu b/paddle/fluid/operators/batch_norm_op.cu index 99612e4d4338fbb11ebb153df2446a38a9f3b9aa..e40049a51d194bf08ed31cb30bb9cc94c226d7f6 100644 --- a/paddle/fluid/operators/batch_norm_op.cu +++ b/paddle/fluid/operators/batch_norm_op.cu @@ -47,10 +47,13 @@ class BatchNormKernel float momentum = ctx.Attr("momentum"); const bool is_test = ctx.Attr("is_test"); const bool use_global_stats = ctx.Attr("use_global_stats"); + const bool trainable_stats = ctx.Attr("trainable_statistics"); const std::string data_layout_str = ctx.Attr("data_layout"); const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); + bool test_mode = is_test && (!trainable_stats); + // Get the size for each dimension. // NCHW [batch_size, in_channels, in_height, in_width] const auto *x = ctx.Input("X"); @@ -66,7 +69,7 @@ class BatchNormKernel auto dtype = platform::CudnnDataType::type; const bool fast_nhwc_batch_norm = - is_test || + test_mode || (dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent); auto compute_format = @@ -133,7 +136,7 @@ class BatchNormKernel PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDeriveBNTensorDescriptor( bn_param_desc_, data_desc_, - is_test ? CUDNN_BATCHNORM_SPATIAL : mode_)); + test_mode ? CUDNN_BATCHNORM_SPATIAL : mode_)); const auto *scale = ctx.Input("Scale"); const auto *bias = ctx.Input("Bias"); @@ -143,7 +146,7 @@ class BatchNormKernel auto handle = dev_ctx.cudnn_handle(); // Now, depending on whether we are running test or not, we have two paths. - if (is_test || use_global_stats) { + if (test_mode || use_global_stats) { // only when test we use input to do computation. const auto *est_mean = ctx.Input("Mean"); const auto *est_var = ctx.Input("Variance"); diff --git a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc index 919bdabe1b3abf4e75b19662d65c98d453504181..b7be0045258e7aafb64912f2cc75c9c9e05413b6 100644 --- a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc @@ -120,8 +120,10 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { const bool is_test = ctx.Attr("is_test"); const bool use_global_stats = ctx.Attr("use_global_stats"); const bool fuse_with_relu = ctx.Attr("fuse_with_relu"); + const bool trainable_stats = ctx.Attr("trainable_statistics"); + bool test_mode = is_test && (!trainable_stats); - bool global_stats = is_test || use_global_stats; + bool global_stats = test_mode || use_global_stats; auto &dev_ctx = ctx.template device_context(); @@ -156,7 +158,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { auto flags = mkldnn::normalization_flags::use_scale_shift; // 001 if (global_stats) flags |= mkldnn::normalization_flags::use_global_stats; // 010 - if (fuse_with_relu && is_test) + if (fuse_with_relu && test_mode) flags |= mkldnn::normalization_flags::fuse_norm_relu; // 100 BatchNormMKLDNNHandler handler( diff --git a/paddle/fluid/operators/sync_batch_norm_op.cu b/paddle/fluid/operators/sync_batch_norm_op.cu index d79667cfdcbb37c99d1c6bab65fbea5830911d43..26fbe39a3c3691b6ce6414d75f6da216bc888017 100644 --- a/paddle/fluid/operators/sync_batch_norm_op.cu +++ b/paddle/fluid/operators/sync_batch_norm_op.cu @@ -28,6 +28,7 @@ class SyncBatchNormKernel const std::string layout_str = ctx.Attr("data_layout"); const DataLayout layout = framework::StringToDataLayout(layout_str); const bool use_global_stats = ctx.Attr("use_global_stats"); + const bool trainable_stats = ctx.Attr("trainable_statistics"); PADDLE_ENFORCE_EQ(use_global_stats, false, platform::errors::InvalidArgument( "sync_batch_norm doesn't support " @@ -47,9 +48,10 @@ class SyncBatchNormKernel auto *saved_mean = ctx.Output("SavedMean"); auto *saved_inv_variance = ctx.Output("SavedVariance"); + bool test_mode = is_test && (!trainable_stats); SyncBatchNormFunctor( ctx, layout, x, y, est_mean, est_var, mean_out, variance_out, - saved_mean, saved_inv_variance, epsilon, momentum, is_test, + saved_mean, saved_inv_variance, epsilon, momentum, test_mode, use_global_stats); } }; diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 8f6b0a7d5a30ce84daf21fc8122f7e7439349a8f..e9139156a14919af1ab26872f27f75cc9d985989 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -1276,12 +1276,12 @@ class BatchNorm(layers.Layer): variance_out = self._variance if in_dygraph_mode(): - _is_test = not self.training and not self._trainable_statistics attrs = ("momentum", self._momentum, "epsilon", self._epsilon, - "is_test", _is_test, "data_layout", self._data_layout, - "use_mkldnn", False, "fuse_with_relu", + "is_test", not self.training, "data_layout", + self._data_layout, "use_mkldnn", False, "fuse_with_relu", self._fuse_with_relu, "use_global_stats", - self._use_global_stats) + self._use_global_stats, 'trainable_statistics', + self._trainable_statistics) batch_norm_out, _, _, _, _ = core.ops.batch_norm( input, self.weight, self.bias, self._mean, self._variance, mean_out, variance_out, *attrs) @@ -1298,7 +1298,8 @@ class BatchNorm(layers.Layer): "data_layout": self._data_layout, "use_mkldnn": False, "fuse_with_relu": self._fuse_with_relu, - "use_global_stats": self._use_global_stats + "use_global_stats": self._use_global_stats, + "trainable_statistics": self._trainable_statistics, } inputs = { diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 487067d48ba4e3c0b49b14819bea1801036194ee..0ebf78884516a35a7a995e11944a52c39d370233 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2605,16 +2605,7 @@ class Block(object): """ if in_dygraph_mode(): attrs = kwargs.get("attrs", {}) - if _dygraph_tracer_._train_mode == False: - # eval mode - if ('trainable_statistics' not in attrs - ) or not attrs['trainable_statistics']: - attrs['is_test'] = True - else: - attrs['is_test'] = False - type = kwargs.get("type", None) - op = Operator( block=self, desc=None, diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 7beea896da0549526027fedfdba0bd1ddfa78ba5..a8c5b991b029192832c1efb33dec230a1929b871 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -623,5 +623,53 @@ class TestDygraphBatchNormAPIError(unittest.TestCase): self.assertRaises(TypeError, batch_norm, x2) +class TestDygraphBatchNormTrainableStats(unittest.TestCase): + def test_dygraph(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"): + places.append(fluid.CUDAPlace(0)) + for p in places: + shape = [4, 10, 4, 4] + + def compute(x, is_test, trainable_statistics): + with fluid.dygraph.guard(p): + bn = fluid.dygraph.BatchNorm( + shape[1], + is_test=is_test, + trainable_statistics=trainable_statistics) + y = bn(fluid.dygraph.to_variable(x)) + return y.numpy() + + x = np.random.randn(*shape).astype("float32") + y1 = compute(x, False, False) + y2 = compute(x, True, True) + self.assertTrue(np.allclose(y1, y2)) + + def test_static(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"): + places.append(fluid.CUDAPlace(0)) + for p in places: + exe = fluid.Executor(p) + shape = [4, 10, 16, 16] + + def compute(x_np, is_test, trainable_statistics): + with program_guard(Program(), Program()): + bn = fluid.dygraph.BatchNorm( + shape[1], + is_test=is_test, + trainable_statistics=trainable_statistics) + x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype) + y = bn(x) + exe.run(fluid.default_startup_program()) + r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] + return r + + x = np.random.randn(*shape).astype("float32") + y1 = compute(x, False, False) + y2 = compute(x, True, True) + self.assertTrue(np.allclose(y1, y2)) + + if __name__ == '__main__': unittest.main()