未验证 提交 78a27a2b 编写于 作者: L LielinJiang 提交者: GitHub

Reproduce summary api (#27367)

* reproduce summary api
上级 29f1560d
...@@ -1813,7 +1813,7 @@ class Model(object): ...@@ -1813,7 +1813,7 @@ class Model(object):
return logs, outputs return logs, outputs
return logs return logs
def summary(self, input_size=None, batch_size=None, dtype=None): def summary(self, input_size=None, dtype=None):
"""Prints a string summary of the network. """Prints a string summary of the network.
Args: Args:
...@@ -1822,7 +1822,6 @@ class Model(object): ...@@ -1822,7 +1822,6 @@ class Model(object):
one input, input_size can be tuple or InputSpec. if model have multiple one input, input_size can be tuple or InputSpec. if model have multiple
input, input_size must be a list which contain every input's shape. input, input_size must be a list which contain every input's shape.
Default: None. Default: None.
batch_size (int, optional): batch size of input tensor, Default: None.
dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None. dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None.
Returns: Returns:
...@@ -1859,7 +1858,7 @@ class Model(object): ...@@ -1859,7 +1858,7 @@ class Model(object):
_input_size = input_size _input_size = input_size
else: else:
_input_size = self._inputs _input_size = self._inputs
return summary(self.network, _input_size, batch_size, dtype) return summary(self.network, _input_size, dtype)
def _verify_spec(self, specs, is_input=False): def _verify_spec(self, specs, is_input=False):
out_specs = [] out_specs = []
......
...@@ -25,7 +25,7 @@ from collections import OrderedDict ...@@ -25,7 +25,7 @@ from collections import OrderedDict
__all__ = ['summary'] __all__ = ['summary']
def summary(net, input_size, batch_size=None, dtypes=None): def summary(net, input_size, dtypes=None):
"""Prints a string summary of the network. """Prints a string summary of the network.
Args: Args:
...@@ -33,8 +33,8 @@ def summary(net, input_size, batch_size=None, dtypes=None): ...@@ -33,8 +33,8 @@ def summary(net, input_size, batch_size=None, dtypes=None):
input_size (tuple|InputSpec|list[tuple|InputSpec]): size of input tensor. if model only input_size (tuple|InputSpec|list[tuple|InputSpec]): size of input tensor. if model only
have one input, input_size can be tuple or InputSpec. if model have one input, input_size can be tuple or InputSpec. if model
have multiple input, input_size must be a list which contain have multiple input, input_size must be a list which contain
every input's shape. every input's shape. Note that input_size only dim of
batch_size (int, optional): batch size of input tensor, Default: None. batch_size can be None or -1.
dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None. dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None.
Returns: Returns:
...@@ -77,14 +77,12 @@ def summary(net, input_size, batch_size=None, dtypes=None): ...@@ -77,14 +77,12 @@ def summary(net, input_size, batch_size=None, dtypes=None):
lenet = LeNet() lenet = LeNet()
params_info = paddle.summary(lenet, (1, 28, 28)) params_info = paddle.summary(lenet, (1, 1, 28, 28))
print(params_info) print(params_info)
""" """
if isinstance(input_size, InputSpec): if isinstance(input_size, InputSpec):
_input_size = tuple(input_size.shape[1:]) _input_size = tuple(input_size.shape)
if batch_size is None:
batch_size = input_size.shape[0]
elif isinstance(input_size, list): elif isinstance(input_size, list):
_input_size = [] _input_size = []
for item in input_size: for item in input_size:
...@@ -96,9 +94,7 @@ def summary(net, input_size, batch_size=None, dtypes=None): ...@@ -96,9 +94,7 @@ def summary(net, input_size, batch_size=None, dtypes=None):
type(item)) type(item))
if isinstance(item, InputSpec): if isinstance(item, InputSpec):
_input_size.append(tuple(item.shape[1:])) _input_size.append(tuple(item.shape))
if batch_size is None:
batch_size = item.shape[0]
else: else:
_input_size.append(item) _input_size.append(item)
elif isinstance(input_size, int): elif isinstance(input_size, int):
...@@ -106,28 +102,88 @@ def summary(net, input_size, batch_size=None, dtypes=None): ...@@ -106,28 +102,88 @@ def summary(net, input_size, batch_size=None, dtypes=None):
else: else:
_input_size = input_size _input_size = input_size
if batch_size is None:
batch_size = -1
if not paddle.in_dynamic_mode(): if not paddle.in_dynamic_mode():
warnings.warn( warnings.warn(
"Your model was created in static mode, this may not get correct summary information!" "Your model was created in static mode, this may not get correct summary information!"
) )
result, params_info = summary_string(net, _input_size, batch_size, dtypes) def _is_shape(shape):
for item in shape:
if isinstance(item, (list, tuple)):
return False
return True
def _check_shape(shape):
num_unknown = 0
new_shape = []
for i in range(len(shape)):
item = shape[i]
if item is None or item == -1:
num_unknown += 1
if num_unknown > 1:
raise ValueError(
'Option input_size only the dim of batch_size can be None or -1.'
)
item = 1
elif isinstance(item, numbers.Number):
if item <= 0:
raise ValueError(
"Expected element in input size greater than zero, but got {}".
format(item))
new_shape.append(item)
return tuple(new_shape)
def _check_input(input_size):
    """Validate a shape, or a nested list/tuple of shapes, recursively.

    A list/tuple that `_is_shape` accepts is validated by `_check_shape`;
    anything else is iterated and each member is validated in turn, so the
    nesting of the input is mirrored by nested lists in the result.
    """
    is_flat_shape = isinstance(input_size,
                               (list, tuple)) and _is_shape(input_size)
    if not is_flat_shape:
        return [_check_input(entry) for entry in input_size]
    return _check_shape(input_size)
_input_size = _check_input(_input_size)
result, params_info = summary_string(net, _input_size, dtypes)
print(result) print(result)
return params_info return params_info
def summary_string(model, input_size, batch_size=-1, dtypes=None): def summary_string(model, input_size, dtypes=None):
if dtypes == None: def _all_is_numper(items):
dtypes = ['float32'] * len(input_size) for item in items:
if not isinstance(item, numbers.Number):
return False
return True
def _build_dtypes(input_size, dtype):
    """Build a dtype structure that follows the nesting of `input_size`.

    `dtype` falls back to 'float32' when it is None. A list/tuple made only
    of numbers yields ``[dtype]``; any other input is iterated and the
    result is a list with one recursively built entry per member.
    """
    dtype = 'float32' if dtype is None else dtype
    is_single_shape = isinstance(input_size,
                                 (list, tuple)) and _all_is_numper(input_size)
    if not is_single_shape:
        return [_build_dtypes(sub_size, dtype) for sub_size in input_size]
    return [dtype]
if not isinstance(dtypes, (list, tuple)):
dtypes = _build_dtypes(input_size, dtypes)
batch_size = 1
summary_str = '' summary_str = ''
depth = len(list(model.sublayers())) depth = len(list(model.sublayers()))
def _get_shape_from_tensor(x):
    """Return the shape of `x` as a list, recursing into lists and tuples.

    A `paddle.fluid.Variable` or `paddle.fluid.core.VarBase` yields
    ``list(x.shape)``; a list/tuple yields a list with one entry per member.
    Any other value yields None.
    """
    tensor_types = (paddle.fluid.Variable, paddle.fluid.core.VarBase)
    if isinstance(x, tensor_types):
        return list(x.shape)
    if isinstance(x, (list, tuple)):
        return [_get_shape_from_tensor(member) for member in x]
    # Neither a tensor nor a container: no shape to report.
    return None
def _get_output_shape(output):
if isinstance(output, (list, tuple)):
output_shape = [_get_output_shape(o) for o in output]
else:
output_shape = list(output.shape)
return output_shape
def register_hook(layer): def register_hook(layer):
def hook(layer, input, output): def hook(layer, input, output):
class_name = str(layer.__class__).split(".")[-1].split("'")[0] class_name = str(layer.__class__).split(".")[-1].split("'")[0]
...@@ -139,14 +195,18 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None): ...@@ -139,14 +195,18 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
m_key = "%s-%i" % (class_name, layer_idx + 1) m_key = "%s-%i" % (class_name, layer_idx + 1)
summary[m_key] = OrderedDict() summary[m_key] = OrderedDict()
summary[m_key]["input_shape"] = list(input[0].shape)
summary[m_key]["input_shape"][0] = batch_size try:
if isinstance(output, (list, tuple)): summary[m_key]["input_shape"] = _get_shape_from_tensor(input)
summary[m_key]["output_shape"] = [[-1] + list(o.shape)[1:] except:
for o in output] warnings.warn('Get layer {} input shape failed!')
else: summary[m_key]["input_shape"] = []
summary[m_key]["output_shape"] = list(output.shape)
summary[m_key]["output_shape"][0] = batch_size try:
summary[m_key]["output_shape"] = _get_output_shape(output)
except:
warnings.warn('Get layer {} output shape failed!')
summary[m_key]["output_shape"]
params = 0 params = 0
...@@ -175,29 +235,22 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None): ...@@ -175,29 +235,22 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
hooks.append(layer.register_forward_post_hook(hook)) hooks.append(layer.register_forward_post_hook(hook))
def _check_input_size(input_sizes):
for input_size in input_sizes:
for item in input_size:
if not isinstance(item, numbers.Number):
raise TypeError(
"Expected item in input size be a number, but got {}".
format(type(item)))
if item <= 0:
raise ValueError(
"Expected item in input size greater than zero, but got {}".
format(item))
if isinstance(input_size, tuple): if isinstance(input_size, tuple):
input_size = [input_size] input_size = [input_size]
_check_input_size(input_size) def build_input(input_size, dtypes):
if isinstance(input_size, (list, tuple)) and _all_is_numper(input_size):
if isinstance(dtypes, (list, tuple)):
dtype = dtypes[0]
else:
dtype = dtypes
return paddle.rand(list(input_size), dtype)
else:
return [
build_input(i, dtype) for i, dtype in zip(input_size, dtypes)
]
x = [ x = build_input(input_size, dtypes)
paddle.rand(
[2] + list(in_size), dtype=dtype)
for in_size, dtype in zip(input_size, dtypes)
]
# create properties # create properties
summary = OrderedDict() summary = OrderedDict()
...@@ -213,22 +266,65 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None): ...@@ -213,22 +266,65 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
for h in hooks: for h in hooks:
h.remove() h.remove()
table_width = 80 def _get_str_length(summary):
summary_str += "-" * table_width + "\n" head_length = {
line_new = "{:>15} {:>20} {:>20} {:>15}".format( 'layer_width': 15,
"Layer (type)", "Input Shape", "Output Shape", "Param #") 'input_shape_width': 20,
'output_shape_width': 20,
'params_width': 15,
'table_width': 75
}
for layer in summary:
if head_length['output_shape_width'] < len(
str(summary[layer]["output_shape"])):
head_length['output_shape_width'] = len(
str(summary[layer]["output_shape"]))
if head_length['input_shape_width'] < len(
str(summary[layer]["input_shape"])):
head_length['input_shape_width'] = len(
str(summary[layer]["input_shape"]))
if head_length['layer_width'] < len(str(layer)):
head_length['layer_width'] = len(str(layer))
if head_length['params_width'] < len(
str(summary[layer]["nb_params"])):
head_length['params_width'] = len(
str(summary[layer]["nb_params"]))
_temp_width = 0
for k, v in head_length.items():
if k != 'table_width':
_temp_width += v
if head_length['table_width'] < _temp_width + 5:
head_length['table_width'] = _temp_width + 5
return head_length
table_width = _get_str_length(summary)
summary_str += "-" * table_width['table_width'] + "\n"
line_new = "{:^{}} {:^{}} {:^{}} {:^{}}".format(
"Layer (type)", table_width['layer_width'], "Input Shape",
table_width['input_shape_width'], "Output Shape",
table_width['output_shape_width'], "Param #",
table_width['params_width'])
summary_str += line_new + "\n" summary_str += line_new + "\n"
summary_str += "=" * table_width + "\n" summary_str += "=" * table_width['table_width'] + "\n"
total_params = 0 total_params = 0
total_output = 0 total_output = 0
trainable_params = 0 trainable_params = 0
max_length = 0
for layer in summary: for layer in summary:
# input_shape, output_shape, trainable, nb_params # input_shape, output_shape, trainable, nb_params
line_new = "{:>15} {:>20} {:>20} {:>15}".format( line_new = "{:^{}} {:^{}} {:^{}} {:^{}}".format(
layer, layer, table_width['layer_width'],
str(summary[layer]["input_shape"]), str(summary[layer]["input_shape"]),
table_width['input_shape_width'],
str(summary[layer]["output_shape"]), str(summary[layer]["output_shape"]),
"{0:,}".format(summary[layer]["nb_params"]), ) table_width['output_shape_width'],
"{0:,}".format(summary[layer]["nb_params"]),
table_width['params_width'])
total_params += summary[layer]["nb_params"] total_params += summary[layer]["nb_params"]
try: try:
...@@ -242,25 +338,32 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None): ...@@ -242,25 +338,32 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
trainable_params += summary[layer]["nb_params"] trainable_params += summary[layer]["nb_params"]
summary_str += line_new + "\n" summary_str += line_new + "\n"
# assume 4 bytes/number (float on cuda). def _get_input_size(input_size, size):
total_input_size = abs( if isinstance(input_size, (list, tuple)) and _all_is_numper(input_size):
np.prod(sum(input_size, ())) * batch_size * 4. / (1024**2.)) size = abs(np.prod(input_size) * 4. / (1024**2.))
else:
size = sum([_get_input_size(i, size) for i in input_size])
return size
total_input_size = _get_input_size(input_size, 0)
total_output_size = abs(2. * total_output * 4. / total_output_size = abs(2. * total_output * 4. /
(1024**2.)) # x2 for gradients (1024**2.)) # x2 for gradients
total_params_size = abs(total_params * 4. / (1024**2.)) total_params_size = abs(total_params * 4. / (1024**2.))
total_size = total_params_size + total_output_size + total_input_size total_size = total_params_size + total_output_size + total_input_size
summary_str += "=" * table_width + "\n" summary_str += "=" * table_width['table_width'] + "\n"
summary_str += "Total params: {0:,}".format(total_params) + "\n" summary_str += "Total params: {0:,}".format(total_params) + "\n"
summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n" summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n"
summary_str += "Non-trainable params: {0:,}".format(total_params - summary_str += "Non-trainable params: {0:,}".format(total_params -
trainable_params) + "\n" trainable_params) + "\n"
summary_str += "-" * table_width + "\n" summary_str += "-" * table_width['table_width'] + "\n"
summary_str += "Input size (MB): %0.2f" % total_input_size + "\n" summary_str += "Input size (MB): %0.2f" % total_input_size + "\n"
summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n" summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n"
summary_str += "Params size (MB): %0.2f" % total_params_size + "\n" summary_str += "Params size (MB): %0.2f" % total_params_size + "\n"
summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n" summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n"
summary_str += "-" * table_width + "\n" summary_str += "-" * table_width['table_width'] + "\n"
# return summary # return summary
return summary_str, { return summary_str, {
'total_params': total_params, 'total_params': total_params,
......
...@@ -494,17 +494,22 @@ class TestModelFunction(unittest.TestCase): ...@@ -494,17 +494,22 @@ class TestModelFunction(unittest.TestCase):
model.summary(input_size=(20)) model.summary(input_size=(20))
model.summary(input_size=[(20)]) model.summary(input_size=[(20)])
model.summary(input_size=(20), batch_size=2) model.summary(input_size=(20), dtype='float32')
def test_summary_nlp(self): def test_summary_nlp(self):
paddle.enable_static() paddle.enable_static()
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3) nlp_net = paddle.nn.GRU(input_size=2,
paddle.summary(nlp_net, (1, 2)) hidden_size=3,
num_layers=3,
direction="bidirectional")
paddle.summary(nlp_net, (1, 1, 2))
rnn = paddle.nn.LSTM(16, 32, 2)
paddle.summary(rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))])
def test_summary_error(self): def test_summary_error(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3) nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
paddle.summary(nlp_net, (1, '2')) paddle.summary(nlp_net, (1, 1, '2'))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3) nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
...@@ -512,7 +517,7 @@ class TestModelFunction(unittest.TestCase): ...@@ -512,7 +517,7 @@ class TestModelFunction(unittest.TestCase):
paddle.disable_static() paddle.disable_static()
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3) nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
paddle.summary(nlp_net, (1, 2)) paddle.summary(nlp_net, (1, 1, 2))
def test_export_deploy_model(self): def test_export_deploy_model(self):
for dynamic in [True, False]: for dynamic in [True, False]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册