From 40bd7a7a76b41f4eba0fb645e182185aea327598 Mon Sep 17 00:00:00 2001 From: wangna11BD <79366697+wangna11BD@users.noreply.github.com> Date: Thu, 29 Jul 2021 13:54:34 +0800 Subject: [PATCH] add parameter of input in model.summary (#34165) * add input option in model.summary --- python/paddle/hapi/model.py | 2 +- python/paddle/hapi/model_summary.py | 78 +++++++++++++++++++++++++---- python/paddle/tests/test_model.py | 37 ++++++++++++++ 3 files changed, 105 insertions(+), 12 deletions(-) diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index 1ac873ce9ca..abc7aedbd8a 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -2145,7 +2145,7 @@ class Model(object): _input_size = input_size else: _input_size = self._inputs - return summary(self.network, _input_size, dtype) + return summary(self.network, _input_size, dtypes=dtype) def _verify_spec(self, specs, shapes=None, dtypes=None, is_input=False): out_specs = [] diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py index 93f1a5a37a6..7e435fdc27b 100644 --- a/python/paddle/hapi/model_summary.py +++ b/python/paddle/hapi/model_summary.py @@ -25,7 +25,7 @@ from collections import OrderedDict __all__ = [] -def summary(net, input_size, dtypes=None): +def summary(net, input_size=None, dtypes=None, input=None): """Prints a string summary of the network. Args: @@ -34,8 +34,10 @@ def summary(net, input_size, dtypes=None): have one input, input_size can be tuple or InputSpec. if model have multiple input, input_size must be a list which contain every input's shape. Note that input_size only dim of - batch_size can be None or -1. + batch_size can be None or -1. Default: None. Note that + input_size and input cannot be None at the same time. dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None. + input: the input tensor. if input is given, input_size and dtype will be ignored, Default: None. Returns: Dict: a summary of the network including total params and total trainable params. @@ -94,10 +96,62 @@ def summary(net, input_size, dtypes=None): lenet_multi_input = LeNetMultiInput() params_info = paddle.summary(lenet_multi_input, [(1, 1, 28, 28), (1, 400)], - ['float32', 'float32']) + dtypes=['float32', 'float32']) + print(params_info) + + # list input demo + class LeNetListInput(LeNet): + + def forward(self, inputs): + x = self.features(inputs[0]) + + if self.num_classes > 0: + x = paddle.flatten(x, 1) + x = self.fc(x + inputs[1]) + return x + + lenet_list_input = LeNetListInput() + input_data = [paddle.rand([1, 1, 28, 28]), paddle.rand([1, 400])] + params_info = paddle.summary(lenet_list_input, input=input_data) + print(params_info) + + # dict input demo + class LeNetDictInput(LeNet): + + def forward(self, inputs): + x = self.features(inputs['x1']) + + if self.num_classes > 0: + x = paddle.flatten(x, 1) + x = self.fc(x + inputs['x2']) + return x + + lenet_dict_input = LeNetDictInput() + input_data = {'x1': paddle.rand([1, 1, 28, 28]), + 'x2': paddle.rand([1, 400])} + params_info = paddle.summary(lenet_dict_input, input=input_data) print(params_info) """ + if input_size is None and input is None: + raise ValueError("input_size and input cannot be None at the same time") + + if input_size is None and input is not None: + if paddle.is_tensor(input): + input_size = tuple(input.shape) + elif isinstance(input, (list, tuple)): + input_size = [] + for x in input: + input_size.append(tuple(x.shape)) + elif isinstance(input, dict): + input_size = [] + for key in input.keys(): + input_size.append(tuple(input[key].shape)) + else: + raise ValueError( + "Input is not tensor, list, tuple and dict, unable to determine input_size, please input input_size." + ) + if isinstance(input_size, InputSpec): _input_size = tuple(input_size.shape) elif isinstance(input_size, list): @@ -163,7 +217,8 @@ def summary(net, input_size, dtypes=None): return [_check_input(i) for i in input_size] _input_size = _check_input(_input_size) - result, params_info = summary_string(net, _input_size, dtypes) + + result, params_info = summary_string(net, _input_size, dtypes, input) print(result) if in_train_mode: @@ -173,7 +228,7 @@ def summary(net, input_size, dtypes=None): @paddle.no_grad() -def summary_string(model, input_size, dtypes=None): +def summary_string(model, input_size=None, dtypes=None, input=None): def _all_is_numper(items): for item in items: if not isinstance(item, numbers.Number): @@ -280,17 +335,18 @@ def summary_string(model, input_size, dtypes=None): build_input(i, dtype) for i, dtype in zip(input_size, dtypes) ] - x = build_input(input_size, dtypes) - # create properties summary = OrderedDict() hooks = [] - # register hook model.apply(register_hook) - - # make a forward pass - model(*x) + if input is not None: + x = input + model(x) + else: + x = build_input(input_size, dtypes) + # make a forward pass + model(*x) # remove these hooks for h in hooks: diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index 0a6675babb2..36478289ccb 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -68,6 +68,27 @@ class LeNetDygraph(paddle.nn.Layer): return x +class LeNetListInput(LeNetDygraph): + def forward(self, inputs): + x = inputs[0] + x = self.features(x) + + if self.num_classes > 0: + x = paddle.flatten(x, 1) + x = self.fc(x + inputs[1]) + return x + + +class LeNetDictInput(LeNetDygraph): + def forward(self, inputs): + x = self.features(inputs['x1']) + + if self.num_classes > 0: + x = paddle.flatten(x, 1) + x = self.fc(x + inputs['x2']) + return x + + class MnistDataset(MNIST): def __init__(self, mode, return_label=True, sample_num=None): super(MnistDataset, self).__init__(mode=mode) @@ -615,6 +636,22 @@ class TestModelFunction(unittest.TestCase): gt_params = _get_param_from_state_dict(rnn.state_dict()) np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0) + def test_summary_input(self): + rnn = paddle.nn.SimpleRNN(16, 32, 2, direction='bidirectional') + input_data = paddle.rand([4, 23, 16]) + paddle.summary(rnn, input=input_data) + + lenet_List_input = LeNetListInput() + input_data = [paddle.rand([1, 1, 28, 28]), paddle.rand([1, 400])] + paddle.summary(lenet_List_input, input=input_data) + + lenet_dict_input = LeNetDictInput() + input_data = { + 'x1': paddle.rand([1, 1, 28, 28]), + 'x2': paddle.rand([1, 400]) + } + paddle.summary(lenet_dict_input, input=input_data) + def test_summary_dtype(self): input_shape = (3, 1) net = paddle.nn.Embedding(10, 3, sparse=True) -- GitLab