# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import numpy as np
import numbers

import paddle
import paddle.nn as nn
from paddle.static import InputSpec

from collections import OrderedDict

__all__ = ['summary']


def summary(net, input_size, dtypes=None):
    """Prints a string summary of the network.

    Args:
        net (Layer): the network which must be a subinstance of Layer.
        input_size (tuple|InputSpec|list[tuple|InputSpec]): size of input tensor. if model only 
                    have one input, input_size can be tuple or InputSpec. if model
                    have multiple input, input_size must be a list which contain 
L
LielinJiang 已提交
36 37
                    every input's shape. Note that input_size only dim of
                    batch_size can be None or -1.
L
LielinJiang 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
        dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None.

    Returns:
        Dict: a summary of the network including total params and total trainable params.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            class LeNet(nn.Layer):
                def __init__(self, num_classes=10):
                    super(LeNet, self).__init__()
                    self.num_classes = num_classes
                    self.features = nn.Sequential(
                        nn.Conv2D(
                            1, 6, 3, stride=1, padding=1),
                        nn.ReLU(),
                        nn.MaxPool2D(2, 2),
                        nn.Conv2D(
                            6, 16, 5, stride=1, padding=0),
                        nn.ReLU(),
                        nn.MaxPool2D(2, 2))

                    if num_classes > 0:
                        self.fc = nn.Sequential(
                            nn.Linear(400, 120),
                            nn.Linear(120, 84),
                            nn.Linear(
                                84, 10))

                def forward(self, inputs):
                    x = self.features(inputs)

                    if self.num_classes > 0:
                        x = paddle.flatten(x, 1)
                        x = self.fc(x)
                    return x

            lenet = LeNet()

            params_info = paddle.summary(lenet, (1, 1, 28, 28))
            print(params_info)
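
            # A hedged sketch (not part of the original example): a model with
            # multiple inputs takes a list of shapes (or InputSpec objects)
            # and one dtype per input; `multi_input_net` is hypothetical.
            # params_info = paddle.summary(
            #     multi_input_net,
            #     input_size=[(1, 1, 28, 28), (1, 10)],
            #     dtypes=['float32', 'float32'])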

    """
    if isinstance(input_size, InputSpec):
        _input_size = tuple(input_size.shape)
    elif isinstance(input_size, list):
        _input_size = []
        for item in input_size:
            if isinstance(item, int):
                item = (item, )
            assert isinstance(item, (tuple, InputSpec)), (
                'When input_size is a list, each item must be a tuple or '
                'InputSpec, but got {}'.format(type(item)))

            if isinstance(item, InputSpec):
                _input_size.append(tuple(item.shape))
            else:
                _input_size.append(item)
    elif isinstance(input_size, int):
        _input_size = (input_size, )
    else:
        _input_size = input_size

    if not paddle.in_dynamic_mode():
        warnings.warn(
            "Your model was created in static mode, so the summary may not report correct information!"
        )
        in_train_mode = False
    else:
        in_train_mode = net.training

    if in_train_mode:
        net.eval()

    def _is_shape(shape):
        for item in shape:
            if isinstance(item, (list, tuple)):
                return False
        return True

    def _check_shape(shape):
        num_unknown = 0
        new_shape = []
        for i in range(len(shape)):
            item = shape[i]
            if item is None or item == -1:
                num_unknown += 1
                if num_unknown > 1:
                    raise ValueError(
                        'Option input_size only supports None or -1 for the batch_size dimension.'
                    )
                item = 1
            elif isinstance(item, numbers.Number):
                if item <= 0:
                    raise ValueError(
                        "Expected every element of input_size to be greater than zero, but got {}".
                        format(item))
            new_shape.append(item)
        return tuple(new_shape)

    def _check_input(input_size):
        if isinstance(input_size, (list, tuple)) and _is_shape(input_size):
            return _check_shape(input_size)
        else:
            return [_check_input(i) for i in input_size]
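
    # Illustrative (added comment, not in the original source):
    #   _check_input((None, 1, 28, 28))             -> (1, 1, 28, 28)
    #   _check_input([(-1, 3, 32, 32), (None, 10)]) -> [(1, 3, 32, 32), (1, 10)]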

    _input_size = _check_input(_input_size)
    result, params_info = summary_string(net, _input_size, dtypes)
    print(result)

    if in_train_mode:
        net.train()

    return params_info


@paddle.no_grad()
def summary_string(model, input_size, dtypes=None):
    def _all_is_number(items):
        for item in items:
            if not isinstance(item, numbers.Number):
                return False
        return True

    def _build_dtypes(input_size, dtype):
        if dtype is None:
            dtype = 'float32'

        if isinstance(input_size, (list, tuple)) and _all_is_number(input_size):
            return [dtype]
        else:
            return [_build_dtypes(i, dtype) for i in input_size]
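
    # Illustrative (added comment): _build_dtypes mirrors the nesting of
    # input_size, e.g. _build_dtypes([(1, 28, 28), (1, 10)], None) returns
    # [['float32'], ['float32']].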

    if not isinstance(dtypes, (list, tuple)):
        dtypes = _build_dtypes(input_size, dtypes)

    summary_str = ''

    depth = len(list(model.sublayers()))

    def _get_shape_from_tensor(x):
        if isinstance(x, (paddle.fluid.Variable, paddle.fluid.core.VarBase)):
            return list(x.shape)
        elif isinstance(x, (list, tuple)):
            return [_get_shape_from_tensor(xx) for xx in x]

    def _get_output_shape(output):
        if isinstance(output, (list, tuple)):
            output_shape = [_get_output_shape(o) for o in output]
        else:
            output_shape = list(output.shape)
        return output_shape
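
    # (added comment) Layers may return nested lists/tuples of tensors; in
    # that case the recorded output_shape is itself a nested list of shapes.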

    def register_hook(layer):
        def hook(layer, input, output):
            class_name = str(layer.__class__).split(".")[-1].split("'")[0]

            try:
                layer_idx = int(layer._full_name.split('_')[-1])
            except Exception:
                layer_idx = len(summary)

            m_key = "%s-%i" % (class_name, layer_idx + 1)
            summary[m_key] = OrderedDict()

            try:
                summary[m_key]["input_shape"] = _get_shape_from_tensor(input)
            except Exception:
                warnings.warn(
                    'Get layer {} input shape failed!'.format(m_key))
                summary[m_key]["input_shape"] = []

            try:
                summary[m_key]["output_shape"] = _get_output_shape(output)
            except Exception:
                warnings.warn(
                    'Get layer {} output shape failed!'.format(m_key))
                summary[m_key]["output_shape"] = []

            params = 0

            if paddle.in_dynamic_mode():
                layer_state_dict = layer._parameters
            else:
                layer_state_dict = layer.state_dict()

            for k, v in layer_state_dict.items():
                params += np.prod(v.shape)

                try:
                    if (getattr(getattr(layer, k), 'trainable')) and (
                            not getattr(getattr(layer, k), 'stop_gradient')):
                        summary[m_key]["trainable"] = True
                    else:
                        summary[m_key]["trainable"] = False
                except Exception:
                    summary[m_key]["trainable"] = True

            summary[m_key]["nb_params"] = params

        if (not isinstance(layer, nn.Sequential) and
                not isinstance(layer, nn.LayerList) and
            (not (layer == model) or depth < 1)):

            hooks.append(layer.register_forward_post_hook(hook))
        # For rnn, gru and lstm layer
        elif hasattr(layer, 'could_use_cudnn') and layer.could_use_cudnn:
            hooks.append(layer.register_forward_post_hook(hook))
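
        # (added note) When the model itself is a fused RNN/GRU/LSTM running
        # through the cuDNN path, its sublayers' forwards never fire, so the
        # layer is hooked directly to make it appear in the summary table.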

    if isinstance(input_size, tuple):
        input_size = [input_size]

    def build_input(input_size, dtypes):
        if isinstance(input_size, (list, tuple)) and _all_is_number(input_size):
            if isinstance(dtypes, (list, tuple)):
                dtype = dtypes[0]
            else:
                dtype = dtypes
            return paddle.cast(paddle.rand(list(input_size)), dtype)
        else:
            return [
                build_input(i, dtype) for i, dtype in zip(input_size, dtypes)
            ]

    x = build_input(input_size, dtypes)

    # create properties
    summary = OrderedDict()
    hooks = []

    # register hook
    model.apply(register_hook)

    # make a forward pass
    model(*x)

    # remove these hooks
    for h in hooks:
        h.remove()

    def _get_str_length(summary):
        head_length = {
            'layer_width': 15,
            'input_shape_width': 20,
            'output_shape_width': 20,
            'params_width': 15,
            'table_width': 75
        }

        for layer in summary:
            if head_length['output_shape_width'] < len(
                    str(summary[layer]["output_shape"])):
                head_length['output_shape_width'] = len(
                    str(summary[layer]["output_shape"]))
            if head_length['input_shape_width'] < len(
                    str(summary[layer]["input_shape"])):
                head_length['input_shape_width'] = len(
                    str(summary[layer]["input_shape"]))
            if head_length['layer_width'] < len(str(layer)):
                head_length['layer_width'] = len(str(layer))
            if head_length['params_width'] < len(
                    str(summary[layer]["nb_params"])):
                head_length['params_width'] = len(
                    str(summary[layer]["nb_params"]))

        _temp_width = 0
        for k, v in head_length.items():
            if k != 'table_width':
                _temp_width += v

        if head_length['table_width'] < _temp_width + 5:
            head_length['table_width'] = _temp_width + 5

        return head_length

    table_width = _get_str_length(summary)

    summary_str += "-" * table_width['table_width'] + "\n"
    line_new = "{:^{}} {:^{}} {:^{}} {:^{}}".format(
        "Layer (type)", table_width['layer_width'], "Input Shape",
        table_width['input_shape_width'], "Output Shape",
        table_width['output_shape_width'], "Param #",
        table_width['params_width'])
    summary_str += line_new + "\n"
    summary_str += "=" * table_width['table_width'] + "\n"
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:^{}} {:^{}} {:^{}} {:^{}}".format(
            layer, table_width['layer_width'],
            str(summary[layer]["input_shape"]),
            table_width['input_shape_width'],
            str(summary[layer]["output_shape"]),
            table_width['output_shape_width'],
            "{0:,}".format(summary[layer]["nb_params"]),
            table_width['params_width'])
        total_params += summary[layer]["nb_params"]

        try:
            total_output += np.prod(summary[layer]["output_shape"])
        except Exception:
            for output_shape in summary[layer]["output_shape"]:
                total_output += np.prod(output_shape)

        if "trainable" in summary[layer]:
            if summary[layer]["trainable"]:
                trainable_params += summary[layer]["nb_params"]
        summary_str += line_new + "\n"

    def _get_input_size(input_size, size):
        if isinstance(input_size, (list, tuple)) and _all_is_number(input_size):
            size = abs(np.prod(input_size) * 4. / (1024**2.))
        else:
            size = sum([_get_input_size(i, size) for i in input_size])
        return size

    total_input_size = _get_input_size(input_size, 0)

    total_output_size = abs(2. * total_output * 4. /
                            (1024**2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024**2.))
    total_size = total_params_size + total_output_size + total_input_size
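
    # (added note) All sizes assume 4 bytes per element (float32): e.g. a
    # (1, 1, 28, 28) input occupies 784 * 4 / 1024**2 ≈ 0.003 MB, and the
    # activation total is doubled to account for gradients stored in the
    # backward pass.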

    summary_str += "=" * table_width['table_width'] + "\n"
    summary_str += "Total params: {0:,}".format(total_params) + "\n"
    summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n"
    summary_str += "Non-trainable params: {0:,}".format(total_params -
                                                        trainable_params) + "\n"
    summary_str += "-" * table_width['table_width'] + "\n"
    summary_str += "Input size (MB): %0.2f" % total_input_size + "\n"
    summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n"
    summary_str += "Params size (MB): %0.2f" % total_params_size + "\n"
    summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n"
    summary_str += "-" * table_width['table_width'] + "\n"

    # return the formatted summary string and the parameter counts
    return summary_str, {
        'total_params': total_params,
        'trainable_params': trainable_params
    }