Unverified — commit f95e8ffc, authored by LielinJiang, committed via GitHub

Fix conv and summary api bug (#27023)

* fix conv output_size has no default value bug
* fix summary bug
Parent commit: 0dfe26d0
...@@ -1868,8 +1868,13 @@ class Model(object): ...@@ -1868,8 +1868,13 @@ class Model(object):
print(params_info) print(params_info)
""" """
assert (input_size is not None or self._inputs is not None
return summary(self.network, self._inputs, batch_size, dtype) ), "'input_size' or 'self._input' must be set"
if input_size is not None:
_input_size = input_size
else:
_input_size = self._inputs
return summary(self.network, _input_size, batch_size, dtype)
def _verify_spec(self, specs, is_input=False): def _verify_spec(self, specs, is_input=False):
out_specs = [] out_specs = []
......
...@@ -86,8 +86,10 @@ def summary(net, input_size, batch_size=None, dtypes=None): ...@@ -86,8 +86,10 @@ def summary(net, input_size, batch_size=None, dtypes=None):
elif isinstance(input_size, list): elif isinstance(input_size, list):
_input_size = [] _input_size = []
for item in input_size: for item in input_size:
if isinstance(item, int):
item = (item, )
assert isinstance(item, assert isinstance(item,
(list, InputSpec)), 'When input_size is list, \ (tuple, InputSpec)), 'When input_size is list, \
expect item in input_size is a tuple or InputSpec, but got {}'.format( expect item in input_size is a tuple or InputSpec, but got {}'.format(
type(item)) type(item))
...@@ -97,6 +99,8 @@ def summary(net, input_size, batch_size=None, dtypes=None): ...@@ -97,6 +99,8 @@ def summary(net, input_size, batch_size=None, dtypes=None):
batch_size = item.shape[0] batch_size = item.shape[0]
else: else:
_input_size.append(item) _input_size.append(item)
elif isinstance(input_size, int):
_input_size = (input_size, )
else: else:
_input_size = input_size _input_size = input_size
...@@ -138,11 +142,11 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None): ...@@ -138,11 +142,11 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
summary[m_key]["output_shape"][0] = batch_size summary[m_key]["output_shape"][0] = batch_size
params = 0 params = 0
if hasattr(module, "weight"): if hasattr(module, "weight") and hasattr(module.weight, "shape"):
params += np.prod(module.weight.shape) params += np.prod(module.weight.shape)
summary[m_key]["trainable"] = module.weight.trainable or ( summary[m_key]["trainable"] = module.weight.trainable or (
not module.weight.stop_gradient) not module.weight.stop_gradient)
if hasattr(module, "bias"): if hasattr(module, "bias") and hasattr(module.bias, "shape"):
params += np.prod(module.bias.shape) params += np.prod(module.bias.shape)
summary[m_key]["nb_params"] = params summary[m_key]["nb_params"] = params
......
...@@ -1084,7 +1084,7 @@ class ConvTranspose3d(_ConvNd): ...@@ -1084,7 +1084,7 @@ class ConvTranspose3d(_ConvNd):
bias_attr=bias_attr, bias_attr=bias_attr,
data_format=data_format) data_format=data_format)
def forward(self, x, output_size): def forward(self, x, output_size=None):
if output_size is None: if output_size is None:
output_padding = self.output_padding output_padding = self.output_padding
else: else:
......
...@@ -519,6 +519,10 @@ class TestModelFunction(unittest.TestCase): ...@@ -519,6 +519,10 @@ class TestModelFunction(unittest.TestCase):
np.testing.assert_allclose(params_info['total_params'], gt_params) np.testing.assert_allclose(params_info['total_params'], gt_params)
print(params_info) print(params_info)
model.summary(input_size=(20))
model.summary(input_size=[(20)])
model.summary(input_size=(20), batch_size=2)
def test_export_deploy_model(self): def test_export_deploy_model(self):
for dynamic in [True, False]: for dynamic in [True, False]:
fluid.enable_dygraph() if dynamic else None fluid.enable_dygraph() if dynamic else None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.