提交 545afb2d 编写于 作者: X xiaoting 提交者: Hongyu Liu

Add trainable_statist attr for bn in dygraph (#17881)

* add import, test=develop

* fix fill_constant

* fix deconv

* add trainable_statist for bn in dygraph
上级 c1379bf2
......@@ -1024,6 +1024,8 @@ class BatchNorm(layers.Layer):
or is_test to true, and the behavior is equivalent.
In train mode, when setting use_global_stats True, the global mean
and variance are also used during train period.
trainable_statistics(bool, Default False): Whether to calculate mean and var in eval mode. In eval mode, when
setting trainable_statistics True, mean and variance will be calculated by current batch statistics.
Returns:
Variable: A tensor variable which is the result after applying batch normalization on the input.
......@@ -1053,7 +1055,8 @@ class BatchNorm(layers.Layer):
moving_variance_name=None,
do_model_average_for_mean_and_var=False,
fuse_with_relu=False,
use_global_stats=False):
use_global_stats=False,
trainable_statistics=False):
super(BatchNorm, self).__init__(name_scope, dtype)
self._param_attr = param_attr
self._bias_attr = bias_attr
......@@ -1111,6 +1114,7 @@ class BatchNorm(layers.Layer):
self._is_test = is_test
self._fuse_with_relu = fuse_with_relu
self._use_global_stats = use_global_stats
self._trainable_statistics = trainable_statistics
def _build_once(self, input):
pass
......@@ -1151,7 +1155,8 @@ class BatchNorm(layers.Layer):
"is_test": self._is_test,
"use_mkldnn": False,
"fuse_with_relu": self._fuse_with_relu,
"use_global_stats": self._use_global_stats
"use_global_stats": self._use_global_stats,
"trainable_statistics": self._trainable_statistics
})
# Currently, we don't support inplace in dygraph mode
......
......@@ -1677,7 +1677,11 @@ class Block(object):
attrs = kwargs.get("attrs", {})
if _dygraph_tracer_._train_mode == False:
# eval mode
attrs['is_test'] = True
if ('trainable_statistics' not in attrs
) or not attrs['trainable_statistics']:
attrs['is_test'] = True
else:
attrs['is_test'] = False
op = Operator(
block=self,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册