From d84eb9b33f8751e50c73c07ed0d88379d9a406e9 Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Fri, 9 Oct 2020 11:39:52 +0800 Subject: [PATCH] keep network mode unchange when use summary api (#27754) * keep summary mode unchange * add no grad decorator --- python/paddle/hapi/model_summary.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py index 3ead3fc295..c46a53e910 100644 --- a/python/paddle/hapi/model_summary.py +++ b/python/paddle/hapi/model_summary.py @@ -106,6 +106,12 @@ def summary(net, input_size, dtypes=None): warnings.warn( "Your model was created in static mode, this may not get correct summary information!" ) + in_train_mode = False + else: + in_train_mode = net.training + + if in_train_mode: + net.eval() def _is_shape(shape): for item in shape: @@ -143,9 +149,13 @@ def summary(net, input_size, dtypes=None): result, params_info = summary_string(net, _input_size, dtypes) print(result) + if in_train_mode: + net.train() + return params_info +@paddle.no_grad() def summary_string(model, input_size, dtypes=None): def _all_is_numper(items): for item in items: -- GitLab