# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers
import warnings
from collections import OrderedDict

import numpy as np

import paddle
import paddle.nn as nn
from paddle.static import InputSpec

__all__ = []


def summary(net, input_size, dtypes=None):
    """Prints a string summary of the network.

    Args:
        net (Layer): the network which must be a subinstance of Layer.
        input_size (tuple|InputSpec|list[tuple|InputSpec]): size of input tensor. if model only 
                    have one input, input_size can be tuple or InputSpec. if model
                    have multiple input, input_size must be a list which contain 
L
LielinJiang 已提交
36 37
                    every input's shape. Note that input_size only dim of
                    batch_size can be None or -1.
L
LielinJiang 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
        dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None.

    Returns:
        Dict: a summary of the network including total params and total trainable params.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            class LeNet(nn.Layer):
                def __init__(self, num_classes=10):
                    super(LeNet, self).__init__()
                    self.num_classes = num_classes
                    self.features = nn.Sequential(
                        nn.Conv2D(1, 6, 3, stride=1, padding=1),
                        nn.ReLU(),
                        nn.MaxPool2D(2, 2),
                        nn.Conv2D(6, 16, 5, stride=1, padding=0),
                        nn.ReLU(),
                        nn.MaxPool2D(2, 2))

                    if num_classes > 0:
                        self.fc = nn.Sequential(
                            nn.Linear(400, 120),
                            nn.Linear(120, 84),
                            nn.Linear(84, 10))

                def forward(self, inputs):
                    x = self.features(inputs)

                    if self.num_classes > 0:
                        x = paddle.flatten(x, 1)
                        x = self.fc(x)
                    return x

            lenet = LeNet()

            params_info = paddle.summary(lenet, (1, 1, 28, 28))
            print(params_info)

            # multi-input demo
            class LeNetMultiInput(LeNet):

                def forward(self, inputs, y):
                    x = self.features(inputs)

                    if self.num_classes > 0:
                        x = paddle.flatten(x, 1)
                        x = self.fc(x + y)
                    return x
            
            lenet_multi_input = LeNetMultiInput()

            params_info = paddle.summary(lenet_multi_input,
                                         [(1, 1, 28, 28), (1, 400)],
                                         ['float32', 'float32'])
            print(params_info)
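
            # InputSpec demo (a sketch): the batch dimension may be left as
            # None or -1; 'image' is just an illustrative spec name.
            input_spec = paddle.static.InputSpec(
                [None, 1, 28, 28], 'float32', 'image')
            params_info = paddle.summary(lenet, input_spec)
            print(params_info)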

    """
    if isinstance(input_size, InputSpec):
        _input_size = tuple(input_size.shape)
    elif isinstance(input_size, list):
        _input_size = []
        for item in input_size:
            if isinstance(item, int):
                item = (item, )
            assert isinstance(item, (tuple, InputSpec)), (
                'When input_size is a list, each item must be a tuple or '
                'InputSpec, but got {}'.format(type(item)))

            if isinstance(item, InputSpec):
                _input_size.append(tuple(item.shape))
            else:
                _input_size.append(item)
    elif isinstance(input_size, int):
        _input_size = (input_size, )
    else:
        _input_size = input_size

    if not paddle.in_dynamic_mode():
        warnings.warn(
            "Your model was created in static graph mode; the summary "
            "information may not be correct!")
        in_train_mode = False
    else:
        in_train_mode = net.training

    if in_train_mode:
        net.eval()

    def _is_shape(shape):
        for item in shape:
            if isinstance(item, (list, tuple)):
                return False
        return True

    def _check_shape(shape):
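        # Replace a single unknown (None/-1) batch dimension with 1; more
        # than one unknown dimension or any non-positive entry is rejected.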
        num_unknown = 0
        new_shape = []
        for i in range(len(shape)):
            item = shape[i]
            if item is None or item == -1:
                num_unknown += 1
                if num_unknown > 1:
                    raise ValueError(
                        'In input_size, only the batch_size dimension can be '
                        'None or -1.')
                item = 1
            elif isinstance(item, numbers.Number):
                if item <= 0:
                    raise ValueError(
                        "Expected every element of input_size to be greater "
                        "than zero, but got {}".format(item))
            new_shape.append(item)
        return tuple(new_shape)

    def _check_input(input_size):
        if isinstance(input_size, (list, tuple)) and _is_shape(input_size):
            return _check_shape(input_size)
        else:
            return [_check_input(i) for i in input_size]

    _input_size = _check_input(_input_size)
    result, params_info = summary_string(net, _input_size, dtypes)
    print(result)

    if in_train_mode:
        net.train()

    return params_info


@paddle.no_grad()
def summary_string(model, input_size, dtypes=None):
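    """Builds the summary table for ``model`` with a dummy forward pass and
    returns it as a string, together with a dict of parameter counts; this is
    the helper behind ``summary``.
    """
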
    def _all_is_number(items):
        for item in items:
            if not isinstance(item, numbers.Number):
                return False
        return True

    def _build_dtypes(input_size, dtype):
        if dtype is None:
            dtype = 'float32'

        if isinstance(input_size, (list, tuple)) and _all_is_number(input_size):
            return [dtype]
        else:
            return [_build_dtypes(i, dtype) for i in input_size]

    if not isinstance(dtypes, (list, tuple)):
        dtypes = _build_dtypes(input_size, dtypes)

    summary_str = ''

    depth = len(list(model.sublayers()))

    def _get_shape_from_tensor(x):
        if isinstance(x, (paddle.fluid.Variable, paddle.fluid.core.VarBase)):
            return list(x.shape)
        elif isinstance(x, (list, tuple)):
            return [_get_shape_from_tensor(xx) for xx in x]

    def _get_output_shape(output):
        if isinstance(output, (list, tuple)):
            output_shape = [_get_output_shape(o) for o in output]
        else:
            output_shape = list(output.shape)
        return output_shape

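    # Each forward post-hook records one layer's input/output shapes and
    # parameter count in ``summary``. Hooks go on every sublayer except
    # Sequential/LayerList containers and the top-level model itself
    # (unless the model has no sublayers at all).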
    def register_hook(layer):
        def hook(layer, input, output):
            class_name = str(layer.__class__).split(".")[-1].split("'")[0]

            try:
                layer_idx = int(layer._full_name.split('_')[-1])
            except Exception:
                layer_idx = len(summary)

            m_key = "%s-%i" % (class_name, layer_idx + 1)
            summary[m_key] = OrderedDict()

            try:
                summary[m_key]["input_shape"] = _get_shape_from_tensor(input)
            except Exception:
                warnings.warn(
                    'Get layer {} input shape failed!'.format(class_name))
                summary[m_key]["input_shape"] = []

            try:
                summary[m_key]["output_shape"] = _get_output_shape(output)
            except Exception:
                warnings.warn(
                    'Get layer {} output shape failed!'.format(class_name))
                summary[m_key]["output_shape"] = []

            params = 0

            if paddle.in_dynamic_mode():
                layer_state_dict = layer._parameters
            else:
                layer_state_dict = layer.state_dict()

            for k, v in layer_state_dict.items():
                params += np.prod(v.shape)

                try:
                    param = getattr(layer, k)
                    summary[m_key]["trainable"] = (param.trainable and
                                                   not param.stop_gradient)
                except Exception:
                    summary[m_key]["trainable"] = True

            summary[m_key]["nb_params"] = params

        if (not isinstance(layer, (nn.Sequential, nn.LayerList)) and
                (layer != model or depth < 1)):
            hooks.append(layer.register_forward_post_hook(hook))
        # RNN, GRU and LSTM layers would otherwise be skipped by the check
        # above; hook them so they are summarized as a single layer.
        elif hasattr(layer, 'could_use_cudnn') and layer.could_use_cudnn:
            hooks.append(layer.register_forward_post_hook(hook))

    if isinstance(input_size, tuple):
        input_size = [input_size]

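    # Build dummy inputs: random tensors cast to the requested dtypes,
    # mirroring the (possibly nested) structure of input_size.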
    def build_input(input_size, dtypes):
        if isinstance(input_size, (list, tuple)) and _all_is_number(input_size):
            if isinstance(dtypes, (list, tuple)):
                dtype = dtypes[0]
            else:
                dtype = dtypes
            return paddle.cast(paddle.rand(list(input_size)), dtype)
        else:
            return [
                build_input(i, dtype) for i, dtype in zip(input_size, dtypes)
            ]

    x = build_input(input_size, dtypes)

    # create properties
    summary = OrderedDict()
    hooks = []

    # register hook
    model.apply(register_hook)

    # make a forward pass
    model(*x)

    # remove these hooks
    for h in hooks:
        h.remove()

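    # Column widths grow to fit the longest entry in each column; the full
    # table is never narrower than 75 characters.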
    def _get_str_length(summary):
        head_length = {
            'layer_width': 15,
            'input_shape_width': 20,
            'output_shape_width': 20,
            'params_width': 15,
            'table_width': 75
        }

        for layer in summary:
            if head_length['output_shape_width'] < len(
                    str(summary[layer]["output_shape"])):
                head_length['output_shape_width'] = len(
                    str(summary[layer]["output_shape"]))
            if head_length['input_shape_width'] < len(
                    str(summary[layer]["input_shape"])):
                head_length['input_shape_width'] = len(
                    str(summary[layer]["input_shape"]))
            if head_length['layer_width'] < len(str(layer)):
                head_length['layer_width'] = len(str(layer))
            if head_length['params_width'] < len(
                    str(summary[layer]["nb_params"])):
                head_length['params_width'] = len(
                    str(summary[layer]["nb_params"]))

        _temp_width = 0
        for k, v in head_length.items():
            if k != 'table_width':
                _temp_width += v

        if head_length['table_width'] < _temp_width + 5:
            head_length['table_width'] = _temp_width + 5

        return head_length

    table_width = _get_str_length(summary)

    summary_str += "-" * table_width['table_width'] + "\n"
    line_new = "{:^{}} {:^{}} {:^{}} {:^{}}".format(
        "Layer (type)", table_width['layer_width'], "Input Shape",
        table_width['input_shape_width'], "Output Shape",
        table_width['output_shape_width'], "Param #",
        table_width['params_width'])
    summary_str += line_new + "\n"
    summary_str += "=" * table_width['table_width'] + "\n"
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:^{}} {:^{}} {:^{}} {:^{}}".format(
            layer, table_width['layer_width'],
            str(summary[layer]["input_shape"]),
            table_width['input_shape_width'],
            str(summary[layer]["output_shape"]),
            table_width['output_shape_width'],
            "{0:,}".format(summary[layer]["nb_params"]),
            table_width['params_width'])
        total_params += summary[layer]["nb_params"]

        try:
            total_output += np.sum(
                np.prod(summary[layer]["output_shape"], axis=-1))
        except Exception:
            # Nested output shapes (e.g. a layer returning a list of tensors)
            # are accumulated shape by shape.
            for output_shape in summary[layer]["output_shape"]:
                total_output += np.sum(np.prod(output_shape, axis=-1))

        if "trainable" in summary[layer]:
            if summary[layer]["trainable"] == True:
                trainable_params += summary[layer]["nb_params"]
        summary_str += line_new + "\n"

    def _get_input_size(input_size, size):
        if isinstance(input_size, (list, tuple)) and _all_is_number(input_size):
            size = abs(np.prod(input_size) * 4. / (1024**2.))
        else:
            size = sum([_get_input_size(i, size) for i in input_size])
        return size

    total_input_size = _get_input_size(input_size, 0)

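    # Size estimates assume 4 bytes per element (float32) and are reported
    # in MB; the forward/backward term is doubled to account for gradients.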
    total_output_size = abs(2. * total_output * 4. /
                            (1024**2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024**2.))
    total_size = total_params_size + total_output_size + total_input_size

    summary_str += "=" * table_width['table_width'] + "\n"
    summary_str += "Total params: {0:,}".format(total_params) + "\n"
    summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n"
    summary_str += "Non-trainable params: {0:,}".format(total_params -
                                                        trainable_params) + "\n"
    summary_str += "-" * table_width['table_width'] + "\n"
    summary_str += "Input size (MB): %0.2f" % total_input_size + "\n"
    summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n"
    summary_str += "Params size (MB): %0.2f" % total_params_size + "\n"
    summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n"
    summary_str += "-" * table_width['table_width'] + "\n"

    return summary_str, {
        'total_params': total_params,
        'trainable_params': trainable_params
    }