未验证 提交 d84eb9b3 编写于 作者: L LielinJiang 提交者: GitHub

keep network mode unchange when use summary api (#27754)

* keep summary mode unchange

* add no grad decorator
上级 b9c7c66e
...@@ -106,6 +106,12 @@ def summary(net, input_size, dtypes=None): ...@@ -106,6 +106,12 @@ def summary(net, input_size, dtypes=None):
warnings.warn( warnings.warn(
"Your model was created in static mode, this may not get correct summary information!" "Your model was created in static mode, this may not get correct summary information!"
) )
in_train_mode = False
else:
in_train_mode = net.training
if in_train_mode:
net.eval()
def _is_shape(shape): def _is_shape(shape):
for item in shape: for item in shape:
...@@ -143,9 +149,13 @@ def summary(net, input_size, dtypes=None): ...@@ -143,9 +149,13 @@ def summary(net, input_size, dtypes=None):
result, params_info = summary_string(net, _input_size, dtypes) result, params_info = summary_string(net, _input_size, dtypes)
print(result) print(result)
if in_train_mode:
net.train()
return params_info return params_info
@paddle.no_grad()
def summary_string(model, input_size, dtypes=None): def summary_string(model, input_size, dtypes=None):
def _all_is_numper(items): def _all_is_numper(items):
for item in items: for item in items:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册