diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py
index 4bc361102845824f0161922824a9207e1073cbcf..9bdfef4029e49d4967bcbcdf1bf77fa8b1479e8a 100644
--- a/python/paddle/fluid/dygraph/nn.py
+++ b/python/paddle/fluid/dygraph/nn.py
@@ -1024,6 +1024,8 @@ class BatchNorm(layers.Layer):
             or is_test to true, and the behavior is equivalent.
             In train mode, when setting use_global_stats True, the global mean
             and variance are also used during train period.
+        trainable_statistics(bool, Default False): Whether to calculate mean and var in eval mode. In eval mode, when
+            setting trainable_statistics True, mean and variance will be calculated by current batch statistics.
 
     Returns:
         Variable: A tensor variable which is the result after applying batch normalization on the input.
@@ -1053,7 +1055,8 @@ class BatchNorm(layers.Layer):
                  moving_variance_name=None,
                  do_model_average_for_mean_and_var=False,
                  fuse_with_relu=False,
-                 use_global_stats=False):
+                 use_global_stats=False,
+                 trainable_statistics=False):
         super(BatchNorm, self).__init__(name_scope, dtype)
         self._param_attr = param_attr
         self._bias_attr = bias_attr
@@ -1111,6 +1114,7 @@ class BatchNorm(layers.Layer):
         self._is_test = is_test
         self._fuse_with_relu = fuse_with_relu
         self._use_global_stats = use_global_stats
+        self._trainable_statistics = trainable_statistics
 
     def _build_once(self, input):
         pass
@@ -1151,7 +1155,8 @@ class BatchNorm(layers.Layer):
                 "is_test": self._is_test,
                 "use_mkldnn": False,
                 "fuse_with_relu": self._fuse_with_relu,
-                "use_global_stats": self._use_global_stats
+                "use_global_stats": self._use_global_stats,
+                "trainable_statistics": self._trainable_statistics
             })
 
         # Currently, we don't support inplace in dygraph mode
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index ab2d1d4049709c80ed3c4bea682155d8462b2750..53c9c83b2c473e6a56a661aa57c69266c823eab3 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -1677,7 +1677,11 @@ class Block(object):
             attrs = kwargs.get("attrs", {})
             if _dygraph_tracer_._train_mode == False:
                 # eval mode
-                attrs['is_test'] = True
+                if ('trainable_statistics' not in attrs
+                    ) or not attrs['trainable_statistics']:
+                    attrs['is_test'] = True
+                else:
+                    attrs['is_test'] = False
             op = Operator(
                 block=self,
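
For context, a minimal usage sketch of the new flag (not part of the patch). It assumes this Paddle version's dygraph API surface: `fluid.dygraph.BatchNorm` taking a name scope and `num_channels`, `fluid.dygraph.to_variable` for feeding numpy data, and `Layer.eval()` switching the dygraph tracer out of train mode; adjust names if your checkout differs.

```python
# Sketch only; API names (BatchNorm(name_scope, num_channels, ...), to_variable,
# Layer.eval) are assumed from this Paddle version, not guaranteed by the patch.
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    data = np.random.rand(4, 10, 8, 8).astype('float32')
    x = fluid.dygraph.to_variable(data)

    # With trainable_statistics=True, mean and variance are computed from the
    # current batch even when the layer runs in eval mode, instead of using
    # the accumulated moving statistics.
    bn = fluid.dygraph.BatchNorm('bn', num_channels=10, trainable_statistics=True)

    bn.eval()      # eval mode: append_op would normally force is_test=True
    out = bn(x)    # batch statistics are still used because the flag is set
```

The framework.py change above is what enables this: when `trainable_statistics` is present and truthy in the op attributes, `Block.append_op` leaves `is_test` as False in eval mode, so the batch_norm op keeps computing statistics from the current batch.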