未验证 提交 c75b091b 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager] fix sync batch norm to inplace (#45028)

* fix sync batch norm to inplace
上级 05f7d0c5
......@@ -2560,7 +2560,7 @@
backward : swish_grad
# sync_batch_norm
- api : sync_batch_norm
- api : sync_batch_norm_
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
infer_meta :
......@@ -2569,6 +2569,7 @@
func : sync_batch_norm
data_type : x
backward : sync_batch_norm_grad
inplace : (mean -> mean_out), (variance -> variance_out)
# take_along_axis
- api : take_along_axis
......
......@@ -1110,7 +1110,7 @@ class SyncBatchNorm(_BatchNormBase):
### train mode: use mini-batch stats, eval mode: use global stats
### use_global_stats only support False in sync_batch_norm
if in_dygraph_mode():
sync_batch_norm_out, _, _, _, _, _ = _C_ops.final_state_sync_batch_norm(
sync_batch_norm_out, _, _, _, _, _ = _C_ops.final_state_sync_batch_norm_(
x, self.weight, self.bias, self._mean, self._variance,
self._momentum, self._epsilon, self._data_format,
not self.training, False, False, False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册