diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py
index 8aca31921808552b6ca7d905911dae06ce323037..1a5fc109805e05ba71aca967b599b680fa302c9c 100644
--- a/python/paddle/nn/functional/norm.py
+++ b/python/paddle/nn/functional/norm.py
@@ -186,24 +186,25 @@ def batch_norm(x,
     else:
         trainable_statistics = not use_global_stats
 
-    if _non_static_mode():
-        if in_dygraph_mode():
-            batch_norm_out, _, _, _, _, _ = _C_ops.final_state_batch_norm(
-                x, weight, bias, running_mean, running_var, momentum, epsilon,
-                data_format, not training, use_global_stats,
-                trainable_statistics, False)
-
-        elif _in_legacy_dygraph():
-            # for dygraph need tuple
-            attrs = ("momentum", momentum, "epsilon", epsilon, "is_test",
-                     not training, "data_layout", data_format, "use_mkldnn",
-                     False, "fuse_with_relu", False, "use_global_stats",
-                     use_global_stats, "trainable_statistics",
-                     trainable_statistics)
-
-            batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm(
-                x, weight, bias, running_mean, running_var, None, mean_out,
-                variance_out, *attrs)
+    if in_dygraph_mode():
+        batch_norm_out, _, _, _, _, _ = _C_ops.final_state_batch_norm(
+            x, weight, bias, running_mean, running_var, momentum, epsilon,
+            data_format, not training, use_global_stats, trainable_statistics,
+            False)
+
+        return dygraph_utils._append_activation_in_dygraph(
+            batch_norm_out, act=None)
+
+    elif _in_legacy_dygraph():
+        # for dygraph need tuple
+        attrs = ("momentum", momentum, "epsilon", epsilon, "is_test",
+                 not training, "data_layout", data_format, "use_mkldnn", False,
+                 "fuse_with_relu", False, "use_global_stats", use_global_stats,
+                 "trainable_statistics", trainable_statistics)
+
+        batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm(
+            x, weight, bias, running_mean, running_var, None, mean_out,
+            variance_out, *attrs)
 
         return dygraph_utils._append_activation_in_dygraph(
             batch_norm_out, act=None)
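
Note: this hunk hoists the two dygraph branches out of the enclosing `_non_static_mode()` check, so `in_dygraph_mode()` and `_in_legacy_dygraph()` are tested at the top level and the new-IR branch now returns early through `dygraph_utils._append_activation_in_dygraph` instead of falling through to the shared return. The public API is unchanged; the following is a minimal sketch (not part of the diff, input shapes and statistics tensors are illustrative only) of exercising `paddle.nn.functional.batch_norm` under eager mode, which after this change dispatches straight into the top-level `in_dygraph_mode()` branch:

import paddle

# Illustrative NCHW input with 3 channels.
x = paddle.rand([2, 3, 8, 8])

# Per-channel running statistics and affine parameters.
running_mean = paddle.zeros([3])
running_var = paddle.ones([3])
weight = paddle.ones([3])   # scale (gamma)
bias = paddle.zeros([3])    # shift (beta)

out = paddle.nn.functional.batch_norm(
    x, running_mean, running_var, weight, bias, training=True)
print(out.shape)  # [2, 3, 8, 8]

The behavior of this call is identical before and after the patch; only the internal dispatch between the final-state (`_C_ops.final_state_batch_norm`) and legacy (`_C_ops.batch_norm`) kernels is restructured.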