提交 545afb2d 编写于 作者: X xiaoting 提交者: Hongyu Liu

Add trainable_statist attr for bn in dygraph (#17881)

* add import, test=develop

* fix fill_constant

* fix deconv

* add trainable_statist for bn in dygraph
上级 c1379bf2
...@@ -1024,6 +1024,8 @@ class BatchNorm(layers.Layer): ...@@ -1024,6 +1024,8 @@ class BatchNorm(layers.Layer):
or is_test to true, and the behavior is equivalent. or is_test to true, and the behavior is equivalent.
In train mode, when setting use_global_stats True, the global mean In train mode, when setting use_global_stats True, the global mean
and variance are also used during train period. and variance are also used during train period.
trainable_statistics(bool, Default False): Whether to calculate mean and var in eval mode. In eval mode, when
setting trainable_statistics True, mean and variance will be calculated by current batch statistics.
Returns: Returns:
Variable: A tensor variable which is the result after applying batch normalization on the input. Variable: A tensor variable which is the result after applying batch normalization on the input.
...@@ -1053,7 +1055,8 @@ class BatchNorm(layers.Layer): ...@@ -1053,7 +1055,8 @@ class BatchNorm(layers.Layer):
moving_variance_name=None, moving_variance_name=None,
do_model_average_for_mean_and_var=False, do_model_average_for_mean_and_var=False,
fuse_with_relu=False, fuse_with_relu=False,
use_global_stats=False): use_global_stats=False,
trainable_statistics=False):
super(BatchNorm, self).__init__(name_scope, dtype) super(BatchNorm, self).__init__(name_scope, dtype)
self._param_attr = param_attr self._param_attr = param_attr
self._bias_attr = bias_attr self._bias_attr = bias_attr
...@@ -1111,6 +1114,7 @@ class BatchNorm(layers.Layer): ...@@ -1111,6 +1114,7 @@ class BatchNorm(layers.Layer):
self._is_test = is_test self._is_test = is_test
self._fuse_with_relu = fuse_with_relu self._fuse_with_relu = fuse_with_relu
self._use_global_stats = use_global_stats self._use_global_stats = use_global_stats
self._trainable_statistics = trainable_statistics
def _build_once(self, input): def _build_once(self, input):
pass pass
...@@ -1151,7 +1155,8 @@ class BatchNorm(layers.Layer): ...@@ -1151,7 +1155,8 @@ class BatchNorm(layers.Layer):
"is_test": self._is_test, "is_test": self._is_test,
"use_mkldnn": False, "use_mkldnn": False,
"fuse_with_relu": self._fuse_with_relu, "fuse_with_relu": self._fuse_with_relu,
"use_global_stats": self._use_global_stats "use_global_stats": self._use_global_stats,
"trainable_statistics": self._trainable_statistics
}) })
# Currently, we don't support inplace in dygraph mode # Currently, we don't support inplace in dygraph mode
......
...@@ -1677,7 +1677,11 @@ class Block(object): ...@@ -1677,7 +1677,11 @@ class Block(object):
attrs = kwargs.get("attrs", {}) attrs = kwargs.get("attrs", {})
if _dygraph_tracer_._train_mode == False: if _dygraph_tracer_._train_mode == False:
# eval mode # eval mode
attrs['is_test'] = True if ('trainable_statistics' not in attrs
) or not attrs['trainable_statistics']:
attrs['is_test'] = True
else:
attrs['is_test'] = False
op = Operator( op = Operator(
block=self, block=self,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册