From 559d9f2bb91dffa1ade8b2bff337c9729955a133 Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Thu, 3 Sep 2020 13:43:48 +0800 Subject: [PATCH] Add summary for hapi (#26386) * add summary for hapi --- python/paddle/__init__.py | 1 + python/paddle/hapi/__init__.py | 5 +- python/paddle/hapi/model.py | 45 +++++- python/paddle/hapi/model_summary.py | 225 ++++++++++++++++++++++++++++ python/paddle/tests/test_model.py | 20 +++ python/paddle/utils/__init__.py | 1 + 6 files changed, 294 insertions(+), 3 deletions(-) create mode 100644 python/paddle/hapi/model_summary.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 5f1ccf3f858..fab3a97486c 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -267,5 +267,6 @@ from . import static # high-level api from .hapi import Model from .hapi import callbacks +from .hapi import summary import paddle.text import paddle.vision diff --git a/python/paddle/hapi/__init__.py b/python/paddle/hapi/__init__.py index 87f5a82525c..fb16b829d5b 100644 --- a/python/paddle/hapi/__init__.py +++ b/python/paddle/hapi/__init__.py @@ -14,14 +14,15 @@ from . import logger from . import callbacks +from . import model_summary from . import model from .model import * - +from .model_summary import summary from .dygraph_layer_patch import monkey_patch_layer logger.setup_logger() -__all__ = ['callbacks'] + model.__all__ +__all__ = ['callbacks'] + model.__all__ + ['summary'] monkey_patch_layer() diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index bba94d56cca..019e526dfbb 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -47,10 +47,10 @@ from paddle.io import DataLoader, Dataset, DistributedBatchSampler from paddle.fluid.executor import scope_guard, Executor from paddle.fluid.dygraph.layers import Layer from paddle.metric import Metric - from paddle.static import InputSpec as Input from .callbacks import config_callbacks +from .model_summary import summary __all__ = ['Model', ] @@ -1828,6 +1828,49 @@ class Model(object): return logs, outputs return logs + def summary(self, input_size=None, batch_size=None, dtype=None): + """Prints a string summary of the network. + + Args: + input_size (tuple|InputSpec|list[tuple|InputSpec], optional): size of input tensor. + if not set, input_size will get from ``self._inputs`` if network only have + one input, input_size can be tuple or InputSpec. if model have multiple + input, input_size must be a list which contain every input's shape. + Default: None. + batch_size (int, optional): batch size of input tensor, Default: None. + dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None. + + Returns: + Dict: a summary of the network including total params and total trainable params. + + Examples: + .. code-block:: python + + import paddle + from paddle.static import InputSpec + + dynamic = True + device = paddle.set_device('cpu') + paddle.disable_static(device) if dynamic else None + + input = InputSpec([None, 1, 28, 28], 'float32', 'image') + label = InputSpec([None, 1], 'int64', 'label') + + model = paddle.Model(paddle.vision.LeNet(classifier_activation=None), + input, label) + optim = paddle.optimizer.Adam( + learning_rate=0.001, parameters=model.parameters()) + model.prepare( + optim, + paddle.nn.CrossEntropyLoss()) + + params_info = model.summary() + print(params_info) + + """ + + return summary(self.network, self._inputs, batch_size, dtype) + def _verify_spec(self, specs, is_input=False): out_specs = [] diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py new file mode 100644 index 00000000000..ddafcbed8ec --- /dev/null +++ b/python/paddle/hapi/model_summary.py @@ -0,0 +1,225 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import paddle +import paddle.nn as nn +from paddle.static import InputSpec + +from collections import OrderedDict + +__all__ = ['summary'] + + +def summary(net, input_size, batch_size=None, dtypes=None): + """Prints a string summary of the network. + + Args: + net (Layer): the network which must be a subinstance of Layer. + input_size (tuple|InputSpec|list[tuple|InputSpec]): size of input tensor. if model only + have one input, input_size can be tuple or InputSpec. if model + have multiple input, input_size must be a list which contain + every input's shape. + batch_size (int, optional): batch size of input tensor, Default: None. + dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None. + + Returns: + Dict: a summary of the network including total params and total trainable params. + + Examples: + .. code-block:: python + + import paddle + import paddle.nn as nn + + class LeNet(nn.Layer): + def __init__(self, num_classes=10): + super(LeNet, self).__init__() + self.num_classes = num_classes + self.features = nn.Sequential( + nn.Conv2d( + 1, 6, 3, stride=1, padding=1), + nn.ReLU(), + nn.MaxPool2d(2, 2), + nn.Conv2d( + 6, 16, 5, stride=1, padding=0), + nn.ReLU(), + nn.MaxPool2d(2, 2)) + + if num_classes > 0: + self.fc = nn.Sequential( + nn.Linear(400, 120), + nn.Linear(120, 84), + nn.Linear( + 84, 10)) + + def forward(self, inputs): + x = self.features(inputs) + + if self.num_classes > 0: + x = paddle.flatten(x, 1) + x = self.fc(x) + return x + + lenet = LeNet() + + params_info = paddle.summary(lenet, (1, 28, 28)) + print(params_info) + + """ + if isinstance(input_size, InputSpec): + _input_size = tuple(input_size.shape[1:]) + if batch_size is None: + batch_size = input_size.shape[0] + elif isinstance(input_size, list): + _input_size = [] + for item in input_size: + assert isinstance(item, + (list, InputSpec)), 'When input_size is list, \ + expect item in input_size is a tuple or InputSpec, but got {}'.format( + type(item)) + + if isinstance(item, InputSpec): + _input_size.append(tuple(item.shape[1:])) + if batch_size is None: + batch_size = item.shape[0] + else: + _input_size.append(item) + else: + _input_size = input_size + + if batch_size is None: + batch_size = -1 + + result, params_info = summary_string(net, _input_size, batch_size, dtypes) + print(result) + + return params_info + + +def summary_string(model, input_size, batch_size=-1, dtypes=None): + if dtypes == None: + dtypes = ['float32'] * len(input_size) + + summary_str = '' + + depth = len(list(model.sublayers())) + + def register_hook(module): + def hook(module, input, output): + class_name = str(module.__class__).split(".")[-1].split("'")[0] + + try: + module_idx = int(module._full_name.split('_')[-1]) + except: + module_idx = len(summary) + + m_key = "%s-%i" % (class_name, module_idx + 1) + summary[m_key] = OrderedDict() + summary[m_key]["input_shape"] = list(input[0].shape) + summary[m_key]["input_shape"][0] = batch_size + if isinstance(output, (list, tuple)): + summary[m_key]["output_shape"] = [[-1] + list(o.shape)[1:] + for o in output] + else: + summary[m_key]["output_shape"] = list(output.shape) + summary[m_key]["output_shape"][0] = batch_size + + params = 0 + if hasattr(module, "weight"): + params += np.prod(module.weight.shape) + summary[m_key]["trainable"] = module.weight.trainable or ( + not module.weight.stop_gradient) + if hasattr(module, "bias"): + params += np.prod(module.bias.shape) + summary[m_key]["nb_params"] = params + + if (not isinstance(module, nn.Sequential) and + not isinstance(module, nn.LayerList) and + (not (module == model) or depth < 1)): + + hooks.append(module.register_forward_post_hook(hook)) + + if isinstance(input_size, tuple): + input_size = [input_size] + + x = [ + paddle.rand( + [2] + list(in_size), dtype=dtype) + for in_size, dtype in zip(input_size, dtypes) + ] + + # create properties + summary = OrderedDict() + hooks = [] + + # register hook + model.apply(register_hook) + + # make a forward pass + model(*x) + + # remove these hooks + for h in hooks: + h.remove() + + table_width = 80 + summary_str += "-" * table_width + "\n" + line_new = "{:>15} {:>20} {:>20} {:>15}".format( + "Layer (type)", "Input Shape", "Output Shape", "Param #") + summary_str += line_new + "\n" + summary_str += "=" * table_width + "\n" + total_params = 0 + total_output = 0 + trainable_params = 0 + for layer in summary: + # input_shape, output_shape, trainable, nb_params + line_new = "{:>15} {:>20} {:>20} {:>15}".format( + layer, + str(summary[layer]["input_shape"]), + str(summary[layer]["output_shape"]), + "{0:,}".format(summary[layer]["nb_params"]), ) + total_params += summary[layer]["nb_params"] + + total_output += np.prod(summary[layer]["output_shape"]) + if "trainable" in summary[layer]: + if summary[layer]["trainable"] == True: + trainable_params += summary[layer]["nb_params"] + summary_str += line_new + "\n" + + # assume 4 bytes/number (float on cuda). + total_input_size = abs( + np.prod(sum(input_size, ())) * batch_size * 4. / (1024**2.)) + total_output_size = abs(2. * total_output * 4. / + (1024**2.)) # x2 for gradients + total_params_size = abs(total_params * 4. / (1024**2.)) + total_size = total_params_size + total_output_size + total_input_size + + summary_str += "=" * table_width + "\n" + summary_str += "Total params: {0:,}".format(total_params) + "\n" + summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n" + summary_str += "Non-trainable params: {0:,}".format(total_params - + trainable_params) + "\n" + summary_str += "-" * table_width + "\n" + summary_str += "Input size (MB): %0.2f" % total_input_size + "\n" + summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n" + summary_str += "Params size (MB): %0.2f" % total_params_size + "\n" + summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n" + summary_str += "-" * table_width + "\n" + # return summary + return summary_str, { + 'total_params': total_params, + 'trainable_params': trainable_params + } diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index e078595dc95..0267cd5dbc1 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -499,6 +499,26 @@ class TestModelFunction(unittest.TestCase): self.assertTrue(params[0].shape[1] == 10) fluid.disable_dygraph() if dynamic else None + def test_summary(self): + def _get_param_from_state_dict(state_dict): + params = 0 + for k, v in state_dict.items(): + params += np.prod(v.numpy().shape) + return params + + for dynamic in [True, False]: + device = paddle.set_device('cpu') + fluid.enable_dygraph(device) if dynamic else None + net = MyModel() + inputs = [InputSpec([None, 20], 'float32', 'x')] + model = Model(net, inputs) + model.prepare() + params_info = model.summary() + gt_params = _get_param_from_state_dict(net.state_dict()) + + np.testing.assert_allclose(params_info['total_params'], gt_params) + print(params_info) + def test_export_deploy_model(self): for dynamic in [True, False]: fluid.enable_dygraph() if dynamic else None diff --git a/python/paddle/utils/__init__.py b/python/paddle/utils/__init__.py index 009e5586e6c..2a649c776b4 100644 --- a/python/paddle/utils/__init__.py +++ b/python/paddle/utils/__init__.py @@ -16,6 +16,7 @@ from .profiler import ProfilerOptions from .profiler import Profiler from .profiler import get_profiler from .deprecated import deprecated + from . import download __all__ = ['dump_config', 'deprecated', 'download'] -- GitLab