model_summary.py 13.1 KB
Newer Older
L
LielinJiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

L
LielinJiang 已提交
15
import warnings
L
LielinJiang 已提交
16
import numpy as np
L
LielinJiang 已提交
17
import numbers
L
LielinJiang 已提交
18 19 20 21 22 23 24 25 26 27

import paddle
import paddle.nn as nn
from paddle.static import InputSpec

from collections import OrderedDict

__all__ = ['summary']


L
LielinJiang 已提交
28
def summary(net, input_size, dtypes=None):
    """Prints a string summary of the network.

    The input shape(s) are validated and normalized, then ``summary_string``
    runs one forward pass on random inputs and the resulting table is
    printed. If ``net`` is in training mode (dynamic mode only) it is
    switched to eval mode for the forward pass and switched back afterwards,
    even if the forward pass raises.

    Args:
        net (Layer): the network which must be a subinstance of Layer.
        input_size (tuple|InputSpec|int|list[tuple|InputSpec|int]): size of
                    the input tensor. If the model has only one input,
                    input_size can be a tuple, an int or an InputSpec. If the
                    model has multiple inputs, input_size must be a list that
                    contains every input's shape (an int item ``n`` is
                    treated as the shape ``(n, )``). At most one dimension of
                    each shape (intended to be batch_size) may be None or -1;
                    it is replaced by 1.
        dtypes (str|list|tuple, optional): dtype(s) of the generated inputs.
                    A single str is used for every input; a list/tuple is
                    matched to the inputs in order. If dtypes is None,
                    'float32' will be used. Default: None.

    Returns:
        Dict: a summary of the network including total params and total trainable params.

    Raises:
        AssertionError: if input_size is a list containing an item that is
            not an int, tuple or InputSpec.
        ValueError: if a shape has more than one None/-1 dimension or a
            dimension that is not greater than zero.
        TypeError: if input_size (or a nested element) is neither a shape
            nor a list/tuple of shapes.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            class LeNet(nn.Layer):
                def __init__(self, num_classes=10):
                    super(LeNet, self).__init__()
                    self.num_classes = num_classes
                    self.features = nn.Sequential(
                        nn.Conv2d(
                            1, 6, 3, stride=1, padding=1),
                        nn.ReLU(),
                        nn.MaxPool2d(2, 2),
                        nn.Conv2d(
                            6, 16, 5, stride=1, padding=0),
                        nn.ReLU(),
                        nn.MaxPool2d(2, 2))

                    if num_classes > 0:
                        self.fc = nn.Sequential(
                            nn.Linear(400, 120),
                            nn.Linear(120, 84),
                            nn.Linear(
                                84, 10))

                def forward(self, inputs):
                    x = self.features(inputs)

                    if self.num_classes > 0:
                        x = paddle.flatten(x, 1)
                        x = self.fc(x)
                    return x

            lenet = LeNet()

            params_info = paddle.summary(lenet, (1, 1, 28, 28))
            print(params_info)

    """
    # Normalize the public input forms into tuples / lists of tuples.
    if isinstance(input_size, InputSpec):
        _input_size = tuple(input_size.shape)
    elif isinstance(input_size, list):
        _input_size = []
        for item in input_size:
            if isinstance(item, int):
                item = (item, )
            assert isinstance(item, (tuple, InputSpec)), (
                'When input_size is list, expect item in input_size is a '
                'tuple or InputSpec, but got {}'.format(type(item)))

            if isinstance(item, InputSpec):
                _input_size.append(tuple(item.shape))
            else:
                _input_size.append(item)
    elif isinstance(input_size, int):
        _input_size = (input_size, )
    else:
        _input_size = input_size

    def _is_shape(shape):
        # A shape is a flat sequence: none of its items is itself a sequence.
        return not any(isinstance(item, (list, tuple)) for item in shape)

    def _check_shape(shape):
        # Replace the (single) unknown dim by 1 and reject non-positive dims.
        num_unknown = 0
        new_shape = []
        for item in shape:
            if item is None or item == -1:
                num_unknown += 1
                if num_unknown > 1:
                    raise ValueError(
                        'Option input_size only the dim of batch_size can be None or -1.'
                    )
                item = 1
            elif isinstance(item, numbers.Number) and item <= 0:
                raise ValueError(
                    "Expected element in input size greater than zero, but got {}".
                    format(item))
            new_shape.append(item)
        return tuple(new_shape)

    def _check_input(input_size):
        if not isinstance(input_size, (list, tuple)):
            # Without this guard a str would recurse forever.
            raise TypeError(
                "Expected input_size to be a shape or a list of shapes, "
                "but got {}".format(type(input_size)))
        if _is_shape(input_size):
            return _check_shape(input_size)
        return [_check_input(i) for i in input_size]

    # Validate before touching the network's mode so that invalid input
    # cannot leave ``net`` stuck in eval mode.
    _input_size = _check_input(_input_size)

    if not paddle.in_dynamic_mode():
        warnings.warn(
            "Your model was created in static mode, this may not get correct summary information!"
        )
        in_train_mode = False
    else:
        in_train_mode = net.training

    if in_train_mode:
        net.eval()

    try:
        result, params_info = summary_string(net, _input_size, dtypes)
        print(result)
    finally:
        # Restore the original mode even if the forward pass failed.
        if in_train_mode:
            net.train()

    return params_info


158
@paddle.no_grad()
def summary_string(model, input_size, dtypes=None):
    """Run one forward pass of ``model`` and build a per-layer summary table.

    Forward post-hooks are attached to every sublayer except ``Sequential``
    / ``LayerList`` containers (and except ``model`` itself unless it has no
    sublayers). They record each layer's input shape, output shape and
    parameter counts. The hooks are always removed afterwards, even if the
    forward pass raises.

    Args:
        model (Layer): the network to summarize.
        input_size (tuple|list): a flat shape for a single input, or a
            (possibly nested) list of flat shapes, one per positional input
            of ``model.forward``. Every shape must contain only numbers.
        dtypes (str|list|tuple, optional): dtype(s) of the random inputs. A
            str (or None, meaning 'float32') is used for every input; a
            list/tuple is matched to the inputs positionally. Default: None.

    Returns:
        tuple: ``(summary_str, params_info)``. ``summary_str`` is the
        formatted table. ``params_info`` is a dict with the keys
        'total_params' and 'trainable_params'.
    """

    def _all_is_number(items):
        # True for a flat shape such as (1, 3, 224, 224).
        return all(isinstance(item, numbers.Number) for item in items)

    def _build_dtypes(input_size, dtype):
        # Mirror the nesting of input_size with one dtype per flat shape.
        if dtype is None:
            dtype = 'float32'
        if isinstance(input_size, (list, tuple)) and _all_is_number(input_size):
            return [dtype]
        return [_build_dtypes(i, dtype) for i in input_size]

    def _get_shape_from_tensor(x):
        if isinstance(x, (paddle.fluid.Variable, paddle.fluid.core.VarBase)):
            return list(x.shape)
        elif isinstance(x, (list, tuple)):
            return [_get_shape_from_tensor(xx) for xx in x]
        # Any other object (e.g. None) yields None.
        return None

    def _get_output_shape(output):
        if isinstance(output, (list, tuple)):
            return [_get_output_shape(o) for o in output]
        return list(output.shape)

    def _num_elements(shape):
        # Element count of a flat shape, or the SUM over a nested list of
        # shapes. np.prod on a list of equal-rank shapes would multiply them
        # all together, which is wrong for multi-output layers.
        if isinstance(shape, (list, tuple)) and _all_is_number(shape):
            return np.prod(shape)
        return sum(_num_elements(s) for s in shape)

    def _get_input_size(input_size):
        # Size in MB of the generated inputs, assuming 4 bytes per element.
        if isinstance(input_size, (list, tuple)) and _all_is_number(input_size):
            return abs(np.prod(input_size) * 4. / (1024**2.))
        return sum(_get_input_size(i) for i in input_size)

    def _get_str_length(layers_info):
        # Column widths: the defaults, widened to fit the longest cell.
        head_length = {
            'layer_width': 15,
            'input_shape_width': 20,
            'output_shape_width': 20,
            'params_width': 15,
            'table_width': 75
        }

        for layer_name, info in layers_info.items():
            head_length['layer_width'] = max(head_length['layer_width'],
                                             len(str(layer_name)))
            head_length['input_shape_width'] = max(
                head_length['input_shape_width'],
                len(str(info["input_shape"])))
            head_length['output_shape_width'] = max(
                head_length['output_shape_width'],
                len(str(info["output_shape"])))
            # Measure the comma-grouped text that is actually printed.
            head_length['params_width'] = max(
                head_length['params_width'],
                len("{0:,}".format(info["nb_params"])))

        columns_width = sum(v for k, v in head_length.items()
                            if k != 'table_width')
        head_length['table_width'] = max(head_length['table_width'],
                                         columns_width + 5)
        return head_length

    def register_hook(layer):
        def hook(layer, inputs, outputs):
            class_name = type(layer).__name__

            try:
                layer_idx = int(layer._full_name.split('_')[-1])
            except (AttributeError, ValueError):
                layer_idx = len(summary)

            m_key = "%s-%i" % (class_name, layer_idx + 1)
            summary[m_key] = OrderedDict()

            try:
                summary[m_key]["input_shape"] = _get_shape_from_tensor(inputs)
            except Exception:
                warnings.warn('Get layer {} input shape failed!'.format(m_key))
                summary[m_key]["input_shape"] = []

            try:
                summary[m_key]["output_shape"] = _get_output_shape(outputs)
            except Exception:
                warnings.warn('Get layer {} output shape failed!'.format(
                    m_key))
                summary[m_key]["output_shape"] = []

            if paddle.in_dynamic_mode():
                layer_state_dict = layer._parameters
            else:
                layer_state_dict = layer.state_dict()

            params = 0
            layer_trainable_params = 0
            for k, v in layer_state_dict.items():
                if v is None:
                    continue
                num = np.prod(v.shape)
                params += num

                # Count trainability per parameter, so that a layer with a
                # frozen weight and a trainable bias is split correctly.
                try:
                    param = getattr(layer, k)
                    is_trainable = bool(param.trainable) and (
                        not param.stop_gradient)
                except AttributeError:
                    # The key is not an attribute of the layer (e.g. a
                    # structured name); treat the parameter as trainable.
                    is_trainable = True
                if is_trainable:
                    layer_trainable_params += num

            summary[m_key]["nb_params"] = params
            summary[m_key]["trainable_params"] = layer_trainable_params

        if (not isinstance(layer, (nn.Sequential, nn.LayerList)) and
                (layer is not model or depth < 1)):
            hooks.append(layer.register_forward_post_hook(hook))

    def build_input(input_size, dtypes):
        if isinstance(input_size, (list, tuple)) and _all_is_number(input_size):
            if isinstance(dtypes, (list, tuple)):
                dtype = dtypes[0]
            else:
                dtype = dtypes
            return paddle.cast(paddle.rand(list(input_size)), dtype)
        return [build_input(i, dtype) for i, dtype in zip(input_size, dtypes)]

    if not isinstance(dtypes, (list, tuple)):
        dtypes = _build_dtypes(input_size, dtypes)

    depth = len(list(model.sublayers()))

    # A single flat shape (tuple or list) means one positional input.
    if isinstance(input_size, (list, tuple)) and _all_is_number(input_size):
        input_size = [input_size]

    x = build_input(input_size, dtypes)

    # create properties
    summary = OrderedDict()
    hooks = []

    # register hook and make a forward pass; always remove the hooks
    model.apply(register_hook)
    try:
        model(*x)
    finally:
        for h in hooks:
            h.remove()

    widths = _get_str_length(summary)
    row_fmt = "{:^{}} {:^{}} {:^{}} {:^{}}"
    dash_line = "-" * widths['table_width']
    equal_line = "=" * widths['table_width']

    lines = [
        dash_line,
        row_fmt.format("Layer (type)", widths['layer_width'], "Input Shape",
                       widths['input_shape_width'], "Output Shape",
                       widths['output_shape_width'], "Param #",
                       widths['params_width']),
        equal_line,
    ]

    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer_name, info in summary.items():
        lines.append(
            row_fmt.format(layer_name, widths['layer_width'],
                           str(info["input_shape"]),
                           widths['input_shape_width'],
                           str(info["output_shape"]),
                           widths['output_shape_width'],
                           "{0:,}".format(info["nb_params"]),
                           widths['params_width']))
        total_params += info["nb_params"]
        total_output += _num_elements(info["output_shape"])
        trainable_params += info["trainable_params"]

    # Sizes assume 4 bytes per element regardless of dtype.
    total_input_size = _get_input_size(input_size)
    total_output_size = abs(2. * total_output * 4. /
                            (1024**2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024**2.))
    total_size = total_params_size + total_output_size + total_input_size

    lines.extend([
        equal_line,
        "Total params: {0:,}".format(total_params),
        "Trainable params: {0:,}".format(trainable_params),
        "Non-trainable params: {0:,}".format(total_params - trainable_params),
        dash_line,
        "Input size (MB): %0.2f" % total_input_size,
        "Forward/backward pass size (MB): %0.2f" % total_output_size,
        "Params size (MB): %0.2f" % total_params_size,
        "Estimated Total Size (MB): %0.2f" % total_size,
        dash_line,
    ])
    summary_str = "\n".join(lines) + "\n"

    return summary_str, {
        'total_params': total_params,
        'trainable_params': trainable_params
    }