From 78a27a2b0d7ad7b6676dc34ae305faf3ee5b0482 Mon Sep 17 00:00:00 2001
From: LielinJiang <50691816+LielinJiang@users.noreply.github.com>
Date: Thu, 24 Sep 2020 12:54:53 +0800
Subject: [PATCH] Reproduce summary api (#27367)

* reproduce summary api
---
 python/paddle/hapi/model.py         |   5 +-
 python/paddle/hapi/model_summary.py | 219 ++++++++++++++++++++--------
 python/paddle/tests/test_model.py   |  15 +-
 3 files changed, 173 insertions(+), 66 deletions(-)

diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py
index d41852c9d7f..53928ebed1b 100644
--- a/python/paddle/hapi/model.py
+++ b/python/paddle/hapi/model.py
@@ -1813,7 +1813,7 @@ class Model(object):
             return logs, outputs
         return logs
 
-    def summary(self, input_size=None, batch_size=None, dtype=None):
+    def summary(self, input_size=None, dtype=None):
         """Prints a string summary of the network.
 
         Args:
@@ -1822,7 +1822,6 @@ class Model(object):
                 one input, input_size can be tuple or InputSpec. if model
                 have multiple input, input_size must be a list which contain
                 every input's shape. Default: None.
-            batch_size (int, optional): batch size of input tensor, Default: None.
             dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None.
 
         Returns:
@@ -1859,7 +1858,7 @@ class Model(object):
             _input_size = input_size
         else:
             _input_size = self._inputs
-        return summary(self.network, _input_size, batch_size, dtype)
+        return summary(self.network, _input_size, dtype)
 
     def _verify_spec(self, specs, is_input=False):
         out_specs = []
diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py
index d388ba62f2a..3ead3fc295c 100644
--- a/python/paddle/hapi/model_summary.py
+++ b/python/paddle/hapi/model_summary.py
@@ -25,7 +25,7 @@ from collections import OrderedDict
 __all__ = ['summary']
 
 
-def summary(net, input_size, batch_size=None, dtypes=None):
+def summary(net, input_size, dtypes=None):
     """Prints a string summary of the network.
 
     Args:
@@ -33,8 +33,8 @@ def summary(net, input_size, batch_size=None, dtypes=None):
         input_size (tuple|InputSpec|list[tuple|InputSpec]): size of input tensor.
                     if model only have one input, input_size can be tuple or
                     InputSpec. if model have multiple input, input_size must
-                    be a list which contain every input's shape.
-        batch_size (int, optional): batch size of input tensor, Default: None.
+                    be a list which contains every input's shape. Note that
+                    in input_size, only the batch dimension can be None or -1.
         dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None.
 
     Returns:
@@ -77,14 +77,12 @@ def summary(net, input_size, batch_size=None, dtypes=None):
 
             lenet = LeNet()
 
-            params_info = paddle.summary(lenet, (1, 28, 28))
+            params_info = paddle.summary(lenet, (1, 1, 28, 28))
             print(params_info)
 
     """
     if isinstance(input_size, InputSpec):
-        _input_size = tuple(input_size.shape[1:])
-        if batch_size is None:
-            batch_size = input_size.shape[0]
+        _input_size = tuple(input_size.shape)
     elif isinstance(input_size, list):
         _input_size = []
         for item in input_size:
@@ -96,9 +94,7 @@ def summary(net, input_size, batch_size=None, dtypes=None):
                     type(item))
 
             if isinstance(item, InputSpec):
-                _input_size.append(tuple(item.shape[1:]))
-                if batch_size is None:
-                    batch_size = item.shape[0]
+                _input_size.append(tuple(item.shape))
             else:
                 _input_size.append(item)
     elif isinstance(input_size, int):
@@ -106,28 +102,88 @@ def summary(net, input_size, batch_size=None, dtypes=None):
     else:
         _input_size = input_size
 
-    if batch_size is None:
-        batch_size = -1
-
     if not paddle.in_dynamic_mode():
         warnings.warn(
             "Your model was created in static mode, this may not get correct summary information!"
        )
 
-    result, params_info = summary_string(net, _input_size, batch_size, dtypes)
+    def _is_shape(shape):
+        for item in shape:
+            if isinstance(item, (list, tuple)):
+                return False
+        return True
+
+    # validate a single shape; an unknown batch dim (None/-1) is replaced by 1
+    def _check_shape(shape):
+        num_unknown = 0
+        new_shape = []
+        for i in range(len(shape)):
+            item = shape[i]
+            if item is None or item == -1:
+                num_unknown += 1
+                if num_unknown > 1:
+                    raise ValueError(
+                        'In input_size, only the batch dimension can be None or -1.'
+                    )
+                item = 1
+            elif isinstance(item, numbers.Number):
+                if item <= 0:
+                    raise ValueError(
+                        "Expected element in input size to be greater than zero, but got {}".
+                        format(item))
+            new_shape.append(item)
+        return tuple(new_shape)
+
+    def _check_input(input_size):
+        if isinstance(input_size, (list, tuple)) and _is_shape(input_size):
+            return _check_shape(input_size)
+        else:
+            return [_check_input(i) for i in input_size]
+
+    _input_size = _check_input(_input_size)
+    result, params_info = summary_string(net, _input_size, dtypes)
     print(result)
 
     return params_info
 
 
-def summary_string(model, input_size, batch_size=-1, dtypes=None):
-    if dtypes == None:
-        dtypes = ['float32'] * len(input_size)
+def summary_string(model, input_size, dtypes=None):
+    def _all_is_number(items):
+        for item in items:
+            if not isinstance(item, numbers.Number):
+                return False
+        return True
+
+    def _build_dtypes(input_size, dtype):
+        if dtype is None:
+            dtype = 'float32'
+
+        if isinstance(input_size, (list, tuple)) and _all_is_number(input_size):
+            return [dtype]
+        else:
+            return [_build_dtypes(i, dtype) for i in input_size]
+
+    if not isinstance(dtypes, (list, tuple)):
+        dtypes = _build_dtypes(input_size, dtypes)
 
+    batch_size = 1
     summary_str = ''
 
     depth = len(list(model.sublayers()))
 
+    def _get_shape_from_tensor(x):
+        if isinstance(x, (paddle.fluid.Variable, paddle.fluid.core.VarBase)):
+            return list(x.shape)
+        elif isinstance(x, (list, tuple)):
+            return [_get_shape_from_tensor(xx) for xx in x]
+
+    def _get_output_shape(output):
+        if isinstance(output, (list, tuple)):
+            output_shape = [_get_output_shape(o) for o in output]
+        else:
+            output_shape = list(output.shape)
+        return output_shape
+
     def register_hook(layer):
         def hook(layer, input, output):
             class_name = str(layer.__class__).split(".")[-1].split("'")[0]
@@ -139,14 +195,18 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
 
             m_key = "%s-%i" % (class_name, layer_idx + 1)
             summary[m_key] = OrderedDict()
-            summary[m_key]["input_shape"] = list(input[0].shape)
-            summary[m_key]["input_shape"][0] = batch_size
-            if isinstance(output, (list, tuple)):
-                summary[m_key]["output_shape"] = [[-1] + list(o.shape)[1:]
-                                                  for o in output]
-            else:
-                summary[m_key]["output_shape"] = list(output.shape)
-                summary[m_key]["output_shape"][0] = batch_size
+
+            try:
+                summary[m_key]["input_shape"] = _get_shape_from_tensor(input)
+            except:
+                warnings.warn('Get layer {} input shape failed!'.format(m_key))
+                summary[m_key]["input_shape"] = []
+
+            try:
+                summary[m_key]["output_shape"] = _get_output_shape(output)
+            except:
+                warnings.warn('Get layer {} output shape failed!'.format(m_key))
+                summary[m_key]["output_shape"] = []
 
             params = 0
 
@@ -175,29 +235,22 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
 
         hooks.append(layer.register_forward_post_hook(hook))
 
-    def _check_input_size(input_sizes):
-        for input_size in input_sizes:
-            for item in input_size:
-                if not isinstance(item, numbers.Number):
-                    raise TypeError(
-                        "Expected item in input size be a number, but got {}".
-                        format(type(item)))
-
-                if item <= 0:
-                    raise ValueError(
-                        "Expected item in input size greater than zero, but got {}".
-                        format(item))
-
     if isinstance(input_size, tuple):
         input_size = [input_size]
 
-    _check_input_size(input_size)
+    # build dummy input tensors recursively, mirroring the nesting of input_size
+    def build_input(input_size, dtypes):
+        if isinstance(input_size, (list, tuple)) and _all_is_number(input_size):
+            if isinstance(dtypes, (list, tuple)):
+                dtype = dtypes[0]
+            else:
+                dtype = dtypes
+            return paddle.rand(list(input_size), dtype)
+        else:
+            return [
+                build_input(i, dtype) for i, dtype in zip(input_size, dtypes)
+            ]
 
-    x = [
-        paddle.rand(
-            [2] + list(in_size), dtype=dtype)
-        for in_size, dtype in zip(input_size, dtypes)
-    ]
+    x = build_input(input_size, dtypes)
 
     # create properties
     summary = OrderedDict()
@@ -213,22 +266,65 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
     for h in hooks:
         h.remove()
 
-    table_width = 80
-    summary_str += "-" * table_width + "\n"
-    line_new = "{:>15} {:>20} {:>20} {:>15}".format(
-        "Layer (type)", "Input Shape", "Output Shape", "Param #")
+    def _get_str_length(summary):
+        head_length = {
+            'layer_width': 15,
+            'input_shape_width': 20,
+            'output_shape_width': 20,
+            'params_width': 15,
+            'table_width': 75
+        }
+
+        for layer in summary:
+            if head_length['output_shape_width'] < len(
+                    str(summary[layer]["output_shape"])):
+                head_length['output_shape_width'] = len(
+                    str(summary[layer]["output_shape"]))
+            if head_length['input_shape_width'] < len(
+                    str(summary[layer]["input_shape"])):
+                head_length['input_shape_width'] = len(
+                    str(summary[layer]["input_shape"]))
+            if head_length['layer_width'] < len(str(layer)):
+                head_length['layer_width'] = len(str(layer))
+            if head_length['params_width'] < len(
+                    str(summary[layer]["nb_params"])):
+                head_length['params_width'] = len(
+                    str(summary[layer]["nb_params"]))
+
+        _temp_width = 0
+        for k, v in head_length.items():
+            if k != 'table_width':
+                _temp_width += v
+
+        if head_length['table_width'] < _temp_width + 5:
+            head_length['table_width'] = _temp_width + 5
+
+        return head_length
+
+    table_width = _get_str_length(summary)
+
+    summary_str += "-" * table_width['table_width'] + "\n"
+    line_new = "{:^{}} {:^{}} {:^{}} {:^{}}".format(
+        "Layer (type)", table_width['layer_width'], "Input Shape",
+        table_width['input_shape_width'], "Output Shape",
+        table_width['output_shape_width'], "Param #",
+        table_width['params_width'])
     summary_str += line_new + "\n"
-    summary_str += "=" * table_width + "\n"
+    summary_str += "=" * table_width['table_width'] + "\n"
"\n" total_params = 0 total_output = 0 trainable_params = 0 + max_length = 0 for layer in summary: # input_shape, output_shape, trainable, nb_params - line_new = "{:>15} {:>20} {:>20} {:>15}".format( - layer, + line_new = "{:^{}} {:^{}} {:^{}} {:^{}}".format( + layer, table_width['layer_width'], str(summary[layer]["input_shape"]), + table_width['input_shape_width'], str(summary[layer]["output_shape"]), - "{0:,}".format(summary[layer]["nb_params"]), ) + table_width['output_shape_width'], + "{0:,}".format(summary[layer]["nb_params"]), + table_width['params_width']) total_params += summary[layer]["nb_params"] try: @@ -242,25 +338,32 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None): trainable_params += summary[layer]["nb_params"] summary_str += line_new + "\n" - # assume 4 bytes/number (float on cuda). - total_input_size = abs( - np.prod(sum(input_size, ())) * batch_size * 4. / (1024**2.)) + def _get_input_size(input_size, size): + if isinstance(input_size, (list, tuple)) and _all_is_numper(input_size): + size = abs(np.prod(input_size) * 4. / (1024**2.)) + else: + size = sum([_get_input_size(i, size) for i in input_size]) + return size + + total_input_size = _get_input_size(input_size, 0) + total_output_size = abs(2. * total_output * 4. / (1024**2.)) # x2 for gradients total_params_size = abs(total_params * 4. / (1024**2.)) total_size = total_params_size + total_output_size + total_input_size - summary_str += "=" * table_width + "\n" + summary_str += "=" * table_width['table_width'] + "\n" summary_str += "Total params: {0:,}".format(total_params) + "\n" summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n" summary_str += "Non-trainable params: {0:,}".format(total_params - trainable_params) + "\n" - summary_str += "-" * table_width + "\n" + summary_str += "-" * table_width['table_width'] + "\n" summary_str += "Input size (MB): %0.2f" % total_input_size + "\n" summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n" summary_str += "Params size (MB): %0.2f" % total_params_size + "\n" summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n" - summary_str += "-" * table_width + "\n" + summary_str += "-" * table_width['table_width'] + "\n" + # return summary return summary_str, { 'total_params': total_params, diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index 62cc39c1f7b..c89cbbbfbda 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -494,17 +494,22 @@ class TestModelFunction(unittest.TestCase): model.summary(input_size=(20)) model.summary(input_size=[(20)]) - model.summary(input_size=(20), batch_size=2) + model.summary(input_size=(20), dtype='float32') def test_summary_nlp(self): paddle.enable_static() - nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3) - paddle.summary(nlp_net, (1, 2)) + nlp_net = paddle.nn.GRU(input_size=2, + hidden_size=3, + num_layers=3, + direction="bidirectional") + paddle.summary(nlp_net, (1, 1, 2)) + rnn = paddle.nn.LSTM(16, 32, 2) + paddle.summary(rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))]) def test_summary_error(self): with self.assertRaises(TypeError): nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3) - paddle.summary(nlp_net, (1, '2')) + paddle.summary(nlp_net, (1, 1, '2')) with self.assertRaises(ValueError): nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3) @@ -512,7 +517,7 @@ class TestModelFunction(unittest.TestCase): paddle.disable_static() nlp_net = 
         nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
-        paddle.summary(nlp_net, (1, 2))
+        paddle.summary(nlp_net, (1, 1, 2))
 
     def test_export_deploy_model(self):
         for dynamic in [True, False]:
--
GitLab
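
A minimal usage sketch of the reworked summary API, for reference. It assumes
paddle.vision.models.LeNet as a stand-in for the LeNet defined in the
docstring example; the LSTM call mirrors test_summary_nlp above. Input shapes
now include the batch dimension, and only that dimension may be None or -1:

    import paddle
    from paddle.vision.models import LeNet

    # single input: pass the full shape, batch dim first (NCHW here)
    lenet = LeNet()
    params_info = paddle.summary(lenet, (1, 1, 28, 28))
    print(params_info)

    # multiple inputs: pass a list of shapes; the nested tuples describe the
    # LSTM initial states (h, c), with the batch dim left as None or -1
    rnn = paddle.nn.LSTM(16, 32, 2)
    paddle.summary(rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))])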