From e33b9db60cfe2329a7689d04a66ca13cce62e7a9 Mon Sep 17 00:00:00 2001 From: wangna11BD <79366697+wangna11BD@users.noreply.github.com> Date: Wed, 18 May 2022 15:14:00 +0800 Subject: [PATCH] fix summary trainable_params bug (#42798) * fix summary trainable_params bug --- python/paddle/hapi/model_summary.py | 8 ++++++-- python/paddle/tests/test_model.py | 21 ++++++++++++++++++++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py index 8cd95a5ea5..c3c043bd3f 100644 --- a/python/paddle/hapi/model_summary.py +++ b/python/paddle/hapi/model_summary.py @@ -301,14 +301,18 @@ def summary_string(model, input_size=None, dtypes=None, input=None): else: layer_state_dict = layer.state_dict() + summary[m_key]["trainable_params"] = 0 + trainable_flag = False for k, v in layer_state_dict.items(): params += np.prod(v.shape) try: if (getattr(getattr(layer, k), 'trainable')) and ( not getattr(getattr(layer, k), 'stop_gradient')): + summary[m_key]["trainable_params"] += np.prod(v.shape) summary[m_key]["trainable"] = True - else: + trainable_flag = True + elif not trainable_flag: summary[m_key]["trainable"] = False except: summary[m_key]["trainable"] = True @@ -427,7 +431,7 @@ def summary_string(model, input_size=None, dtypes=None, input=None): if "trainable" in summary[layer]: if summary[layer]["trainable"] == True: - trainable_params += summary[layer]["nb_params"] + trainable_params += summary[layer]["trainable_params"] summary_str += line_new + "\n" def _get_input_size(input_size, size): diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index fd3cb83d24..41de8ae189 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -90,7 +90,26 @@ class ModelOutter(paddle.nn.Layer): return y, 3 -class LeNetListInput(LeNetDygraph): +class LeNetListInput(paddle.nn.Layer): + def __init__(self, num_classes=10): + super(LeNetListInput, self).__init__() + self.num_classes = num_classes + self.cov = Conv2D(1, 6, 3, stride=1, padding=1) + for param in self.cov.parameters(): + param.trainable = False + self.features = Sequential( + self.cov, + ReLU(), + paddle.fluid.dygraph.Pool2D(2, 'max', 2), + Conv2D( + 6, 16, 5, stride=1, padding=0), + ReLU(), + paddle.fluid.dygraph.Pool2D(2, 'max', 2)) + + if num_classes > 0: + self.fc = Sequential( + Linear(400, 120), Linear(120, 84), Linear(84, 10)) + def forward(self, inputs): x = inputs[0] x = self.features(x) -- GitLab