未验证 提交 c75b091b 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager] fix sync batch norm to inplace (#45028)

* fix sync batch norm to inplace
上级 05f7d0c5
...@@ -2560,7 +2560,7 @@ ...@@ -2560,7 +2560,7 @@
backward : swish_grad backward : swish_grad
# sync_batch_norm # sync_batch_norm
- api : sync_batch_norm - api : sync_batch_norm_
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
infer_meta : infer_meta :
...@@ -2569,6 +2569,7 @@ ...@@ -2569,6 +2569,7 @@
func : sync_batch_norm func : sync_batch_norm
data_type : x data_type : x
backward : sync_batch_norm_grad backward : sync_batch_norm_grad
inplace : (mean -> mean_out), (variance -> variance_out)
# take_along_axis # take_along_axis
- api : take_along_axis - api : take_along_axis
......
...@@ -1110,7 +1110,7 @@ class SyncBatchNorm(_BatchNormBase): ...@@ -1110,7 +1110,7 @@ class SyncBatchNorm(_BatchNormBase):
### train mode: use mini-batch stats, eval mode: use global stats ### train mode: use mini-batch stats, eval mode: use global stats
### use_global_stats only support False in sync_batch_norm ### use_global_stats only support False in sync_batch_norm
if in_dygraph_mode(): if in_dygraph_mode():
sync_batch_norm_out, _, _, _, _, _ = _C_ops.final_state_sync_batch_norm( sync_batch_norm_out, _, _, _, _, _ = _C_ops.final_state_sync_batch_norm_(
x, self.weight, self.bias, self._mean, self._variance, x, self.weight, self.bias, self._mean, self._variance,
self._momentum, self._epsilon, self._data_format, self._momentum, self._epsilon, self._data_format,
not self.training, False, False, False) not self.training, False, False, False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册