diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py index 3ead3fc295c0b2e8772b16e2aeb3a4fd1f2be75a..c46a53e910df079bc9eddc655010084b26f09684 100644 --- a/python/paddle/hapi/model_summary.py +++ b/python/paddle/hapi/model_summary.py @@ -106,6 +106,12 @@ def summary(net, input_size, dtypes=None): warnings.warn( "Your model was created in static mode, this may not get correct summary information!" ) + in_train_mode = False + else: + in_train_mode = net.training + + if in_train_mode: + net.eval() def _is_shape(shape): for item in shape: @@ -143,9 +149,13 @@ def summary(net, input_size, dtypes=None): result, params_info = summary_string(net, _input_size, dtypes) print(result) + if in_train_mode: + net.train() + return params_info +@paddle.no_grad() def summary_string(model, input_size, dtypes=None): def _all_is_numper(items): for item in items: