diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py index 8cd95a5ea583b65a114169a26ce037f5b85c53c3..c3c043bd3fc2b827aab7284a78a39c59c28382ce 100644 --- a/python/paddle/hapi/model_summary.py +++ b/python/paddle/hapi/model_summary.py @@ -301,14 +301,18 @@ def summary_string(model, input_size=None, dtypes=None, input=None): else: layer_state_dict = layer.state_dict() + summary[m_key]["trainable_params"] = 0 + trainable_flag = False for k, v in layer_state_dict.items(): params += np.prod(v.shape) try: if (getattr(getattr(layer, k), 'trainable')) and ( not getattr(getattr(layer, k), 'stop_gradient')): + summary[m_key]["trainable_params"] += np.prod(v.shape) summary[m_key]["trainable"] = True - else: + trainable_flag = True + elif not trainable_flag: summary[m_key]["trainable"] = False except: summary[m_key]["trainable"] = True @@ -427,7 +431,7 @@ def summary_string(model, input_size=None, dtypes=None, input=None): if "trainable" in summary[layer]: if summary[layer]["trainable"] == True: - trainable_params += summary[layer]["nb_params"] + trainable_params += summary[layer]["trainable_params"] summary_str += line_new + "\n" def _get_input_size(input_size, size): diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index fd3cb83d24e8a0bba23ee54bb3bb91de76964ed3..41de8ae189f85ecee534f990d3227b14786bd65e 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -90,7 +90,26 @@ class ModelOutter(paddle.nn.Layer): return y, 3 -class LeNetListInput(LeNetDygraph): +class LeNetListInput(paddle.nn.Layer): + def __init__(self, num_classes=10): + super(LeNetListInput, self).__init__() + self.num_classes = num_classes + self.cov = Conv2D(1, 6, 3, stride=1, padding=1) + for param in self.cov.parameters(): + param.trainable = False + self.features = Sequential( + self.cov, + ReLU(), + paddle.fluid.dygraph.Pool2D(2, 'max', 2), + Conv2D( + 6, 16, 5, stride=1, padding=0), + ReLU(), + paddle.fluid.dygraph.Pool2D(2, 'max', 2)) + + if num_classes > 0: + self.fc = Sequential( + Linear(400, 120), Linear(120, 84), Linear(84, 10)) + def forward(self, inputs): x = inputs[0] x = self.features(x)