From 5577f4117865921175eeb4cf9fd3759747a929a7 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Wed, 6 Apr 2022 14:36:16 +0800 Subject: [PATCH] [Eager] Remove non static mode (#41422) * [Eager] Support test_layers's test cases switch to eager mode * Update batch_norm _C_ops action to fix CI * Use None instead of new EmptyTensor * Updated var name * Make sure to switch eager mode, Fix Coverage_CI * Remove _non_static_mode statement * Remove batch_norm dispensable input statement * Polish batch_norm code * Fix CI issue * Remove _non_static_mode() --- python/paddle/nn/functional/norm.py | 37 +++++++++++++++-------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 8aca319218..1a5fc10980 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -186,24 +186,25 @@ def batch_norm(x, else: trainable_statistics = not use_global_stats - if _non_static_mode(): - if in_dygraph_mode(): - batch_norm_out, _, _, _, _, _ = _C_ops.final_state_batch_norm( - x, weight, bias, running_mean, running_var, momentum, epsilon, - data_format, not training, use_global_stats, - trainable_statistics, False) - - elif _in_legacy_dygraph(): - # for dygraph need tuple - attrs = ("momentum", momentum, "epsilon", epsilon, "is_test", - not training, "data_layout", data_format, "use_mkldnn", - False, "fuse_with_relu", False, "use_global_stats", - use_global_stats, "trainable_statistics", - trainable_statistics) - - batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm( - x, weight, bias, running_mean, running_var, None, mean_out, - variance_out, *attrs) + if in_dygraph_mode(): + batch_norm_out, _, _, _, _, _ = _C_ops.final_state_batch_norm( + x, weight, bias, running_mean, running_var, momentum, epsilon, + data_format, not training, use_global_stats, trainable_statistics, + False) + + return dygraph_utils._append_activation_in_dygraph( + batch_norm_out, act=None) + + elif _in_legacy_dygraph(): + # for dygraph need tuple + attrs = ("momentum", momentum, "epsilon", epsilon, "is_test", + not training, "data_layout", data_format, "use_mkldnn", False, + "fuse_with_relu", False, "use_global_stats", use_global_stats, + "trainable_statistics", trainable_statistics) + + batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm( + x, weight, bias, running_mean, running_var, None, mean_out, + variance_out, *attrs) return dygraph_utils._append_activation_in_dygraph( batch_norm_out, act=None) -- GitLab