From 545afb2d740b31f42753b75fc39cf0c531e80acb Mon Sep 17 00:00:00 2001 From: xiaoting <31891223+tink2123@users.noreply.github.com> Date: Thu, 6 Jun 2019 20:22:00 +0800 Subject: [PATCH] Add trainable_statist attr for bn in dygraph (#17881) * add import, test=develop * fix fill_constant * fix deconv * add trainable_statist for bn in dygraph --- python/paddle/fluid/dygraph/nn.py | 9 +++++++-- python/paddle/fluid/framework.py | 6 +++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 4bc3611028..9bdfef4029 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -1024,6 +1024,8 @@ class BatchNorm(layers.Layer): or is_test to true, and the behavior is equivalent. In train mode, when setting use_global_stats True, the global mean and variance are also used during train period. + trainable_statistics(bool, Default False): Whether to calculate mean and var in eval mode. In eval mode, when + setting trainable_statistics True, mean and variance will be calculated by current batch statistics. Returns: Variable: A tensor variable which is the result after applying batch normalization on the input. @@ -1053,7 +1055,8 @@ class BatchNorm(layers.Layer): moving_variance_name=None, do_model_average_for_mean_and_var=False, fuse_with_relu=False, - use_global_stats=False): + use_global_stats=False, + trainable_statistics=False): super(BatchNorm, self).__init__(name_scope, dtype) self._param_attr = param_attr self._bias_attr = bias_attr @@ -1111,6 +1114,7 @@ class BatchNorm(layers.Layer): self._is_test = is_test self._fuse_with_relu = fuse_with_relu self._use_global_stats = use_global_stats + self._trainable_statistics = trainable_statistics def _build_once(self, input): pass @@ -1151,7 +1155,8 @@ class BatchNorm(layers.Layer): "is_test": self._is_test, "use_mkldnn": False, "fuse_with_relu": self._fuse_with_relu, - "use_global_stats": self._use_global_stats + "use_global_stats": self._use_global_stats, + "trainable_statistics": self._trainable_statistics }) # Currently, we don't support inplace in dygraph mode diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index ab2d1d4049..53c9c83b2c 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1677,7 +1677,11 @@ class Block(object): attrs = kwargs.get("attrs", {}) if _dygraph_tracer_._train_mode == False: # eval mode - attrs['is_test'] = True + if ('trainable_statistics' not in attrs + ) or not attrs['trainable_statistics']: + attrs['is_test'] = True + else: + attrs['is_test'] = False op = Operator( block=self, -- GitLab