未验证 提交 e33b9db6 编写于 作者: W wangna11BD 提交者: GitHub

fix summary trainable_params bug (#42798)

* fix summary trainable_params bug
上级 bebaee37
...@@ -301,14 +301,18 @@ def summary_string(model, input_size=None, dtypes=None, input=None): ...@@ -301,14 +301,18 @@ def summary_string(model, input_size=None, dtypes=None, input=None):
else: else:
layer_state_dict = layer.state_dict() layer_state_dict = layer.state_dict()
summary[m_key]["trainable_params"] = 0
trainable_flag = False
for k, v in layer_state_dict.items(): for k, v in layer_state_dict.items():
params += np.prod(v.shape) params += np.prod(v.shape)
try: try:
if (getattr(getattr(layer, k), 'trainable')) and ( if (getattr(getattr(layer, k), 'trainable')) and (
not getattr(getattr(layer, k), 'stop_gradient')): not getattr(getattr(layer, k), 'stop_gradient')):
summary[m_key]["trainable_params"] += np.prod(v.shape)
summary[m_key]["trainable"] = True summary[m_key]["trainable"] = True
else: trainable_flag = True
elif not trainable_flag:
summary[m_key]["trainable"] = False summary[m_key]["trainable"] = False
except: except:
summary[m_key]["trainable"] = True summary[m_key]["trainable"] = True
...@@ -427,7 +431,7 @@ def summary_string(model, input_size=None, dtypes=None, input=None): ...@@ -427,7 +431,7 @@ def summary_string(model, input_size=None, dtypes=None, input=None):
if "trainable" in summary[layer]: if "trainable" in summary[layer]:
if summary[layer]["trainable"] == True: if summary[layer]["trainable"] == True:
trainable_params += summary[layer]["nb_params"] trainable_params += summary[layer]["trainable_params"]
summary_str += line_new + "\n" summary_str += line_new + "\n"
def _get_input_size(input_size, size): def _get_input_size(input_size, size):
......
...@@ -90,7 +90,26 @@ class ModelOutter(paddle.nn.Layer): ...@@ -90,7 +90,26 @@ class ModelOutter(paddle.nn.Layer):
return y, 3 return y, 3
class LeNetListInput(LeNetDygraph): class LeNetListInput(paddle.nn.Layer):
def __init__(self, num_classes=10):
super(LeNetListInput, self).__init__()
self.num_classes = num_classes
self.cov = Conv2D(1, 6, 3, stride=1, padding=1)
for param in self.cov.parameters():
param.trainable = False
self.features = Sequential(
self.cov,
ReLU(),
paddle.fluid.dygraph.Pool2D(2, 'max', 2),
Conv2D(
6, 16, 5, stride=1, padding=0),
ReLU(),
paddle.fluid.dygraph.Pool2D(2, 'max', 2))
if num_classes > 0:
self.fc = Sequential(
Linear(400, 120), Linear(120, 84), Linear(84, 10))
def forward(self, inputs): def forward(self, inputs):
x = inputs[0] x = inputs[0]
x = self.features(x) x = self.features(x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册