未验证 提交 f95e8ffc 编写于 作者: L LielinJiang 提交者: GitHub

Fix conv and summary api bug (#27023)

* fix conv output_size has no default value bug
* fix summary bug
上级 0dfe26d0
......@@ -1868,8 +1868,13 @@ class Model(object):
print(params_info)
"""
return summary(self.network, self._inputs, batch_size, dtype)
assert (input_size is not None or self._inputs is not None
), "'input_size' or 'self._input' must be set"
if input_size is not None:
_input_size = input_size
else:
_input_size = self._inputs
return summary(self.network, _input_size, batch_size, dtype)
def _verify_spec(self, specs, is_input=False):
out_specs = []
......
......@@ -86,8 +86,10 @@ def summary(net, input_size, batch_size=None, dtypes=None):
elif isinstance(input_size, list):
_input_size = []
for item in input_size:
if isinstance(item, int):
item = (item, )
assert isinstance(item,
(list, InputSpec)), 'When input_size is list, \
(tuple, InputSpec)), 'When input_size is list, \
expect item in input_size is a tuple or InputSpec, but got {}'.format(
type(item))
......@@ -97,6 +99,8 @@ def summary(net, input_size, batch_size=None, dtypes=None):
batch_size = item.shape[0]
else:
_input_size.append(item)
elif isinstance(input_size, int):
_input_size = (input_size, )
else:
_input_size = input_size
......@@ -138,11 +142,11 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
summary[m_key]["output_shape"][0] = batch_size
params = 0
if hasattr(module, "weight"):
if hasattr(module, "weight") and hasattr(module.weight, "shape"):
params += np.prod(module.weight.shape)
summary[m_key]["trainable"] = module.weight.trainable or (
not module.weight.stop_gradient)
if hasattr(module, "bias"):
if hasattr(module, "bias") and hasattr(module.bias, "shape"):
params += np.prod(module.bias.shape)
summary[m_key]["nb_params"] = params
......
......@@ -1084,7 +1084,7 @@ class ConvTranspose3d(_ConvNd):
bias_attr=bias_attr,
data_format=data_format)
def forward(self, x, output_size):
def forward(self, x, output_size=None):
if output_size is None:
output_padding = self.output_padding
else:
......
......@@ -519,6 +519,10 @@ class TestModelFunction(unittest.TestCase):
np.testing.assert_allclose(params_info['total_params'], gt_params)
print(params_info)
model.summary(input_size=(20))
model.summary(input_size=[(20)])
model.summary(input_size=(20), batch_size=2)
def test_export_deploy_model(self):
for dynamic in [True, False]:
fluid.enable_dygraph() if dynamic else None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册