# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import numpy as np
import numbers

import paddle
import paddle.nn as nn
from paddle.static import InputSpec

from collections import OrderedDict

__all__ = ['summary']


def summary(net, input_size, dtypes=None):
    """Prints a string summary of the network.

    Args:
        net (Layer): the network, which must be an instance of a subclass of Layer.
        input_size (tuple|InputSpec|list[tuple|InputSpec]): size of the input tensor.
                    If the model has a single input, input_size can be a tuple or an
                    InputSpec. If the model has multiple inputs, input_size must be a
                    list containing every input's shape. Note that only the batch_size
                    dimension of input_size may be None or -1.
        dtypes (str, optional): data types of the inputs. If dtypes is None, 'float32'
                    will be used. Default: None.

    Returns:
        Dict: a summary of the network including total params and total trainable params.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            class LeNet(nn.Layer):
                def __init__(self, num_classes=10):
                    super(LeNet, self).__init__()
                    self.num_classes = num_classes
                    self.features = nn.Sequential(
                        nn.Conv2d(
                            1, 6, 3, stride=1, padding=1),
                        nn.ReLU(),
                        nn.MaxPool2d(2, 2),
                        nn.Conv2d(
                            6, 16, 5, stride=1, padding=0),
                        nn.ReLU(),
                        nn.MaxPool2d(2, 2))

                    if num_classes > 0:
                        self.fc = nn.Sequential(
                            nn.Linear(400, 120),
                            nn.Linear(120, 84),
                            nn.Linear(
                                84, 10))

                def forward(self, inputs):
                    x = self.features(inputs)

                    if self.num_classes > 0:
                        x = paddle.flatten(x, 1)
                        x = self.fc(x)
                    return x

            lenet = LeNet()

            params_info = paddle.summary(lenet, (1, 1, 28, 28))
            print(params_info)
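
            # Equivalently (a sketch), the input can be described with an
            # InputSpec whose batch dimension is None or -1:
            # input_spec = paddle.static.InputSpec([None, 1, 28, 28], 'float32')
            # params_info = paddle.summary(lenet, input_spec)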

    """
    if isinstance(input_size, InputSpec):
        _input_size = tuple(input_size.shape)
    elif isinstance(input_size, list):
        _input_size = []
        for item in input_size:
            if isinstance(item, int):
                item = (item, )
            assert isinstance(item, (tuple, InputSpec)), \
                'When input_size is a list, each item must be a tuple or InputSpec, but got {}'.format(
                    type(item))

            if isinstance(item, InputSpec):
                _input_size.append(tuple(item.shape))
            else:
                _input_size.append(item)
    elif isinstance(input_size, int):
        _input_size = (input_size, )
    else:
        _input_size = input_size

    if not paddle.in_dynamic_mode():
        warnings.warn(
            "Your model was created in static graph mode, so the summary may not be accurate!"
        )

    def _is_shape(shape):
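        # A flat sequence of numbers (no nested list/tuple) is treated as a
        # single shape.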
        for item in shape:
            if isinstance(item, (list, tuple)):
                return False
        return True

    def _check_shape(shape):
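        # Allow at most one None/-1 (the batch dimension), substitute it with
        # 1, and reject any other non-positive dimension.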
        num_unknown = 0
        new_shape = []
        for i in range(len(shape)):
            item = shape[i]
            if item is None or item == -1:
                num_unknown += 1
                if num_unknown > 1:
                    raise ValueError(
                        'In input_size, only the batch_size dimension can be None or -1.'
                    )
                item = 1
            elif isinstance(item, numbers.Number):
                if item <= 0:
                    raise ValueError(
                        "Expected every element in input_size to be greater than zero, but got {}".
                        format(item))
            new_shape.append(item)
        return tuple(new_shape)

    def _check_input(input_size):
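        # Recursively normalize a single shape or a nested list of shapes.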
        if isinstance(input_size, (list, tuple)) and _is_shape(input_size):
            return _check_shape(input_size)
        else:
            return [_check_input(i) for i in input_size]

    _input_size = _check_input(_input_size)
    result, params_info = summary_string(net, _input_size, dtypes)
    print(result)

    return params_info


def summary_string(model, input_size, dtypes=None):
    def _all_is_number(items):
        for item in items:
            if not isinstance(item, numbers.Number):
                return False
        return True

    def _build_dtypes(input_size, dtype):
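        # Build a dtype structure that mirrors the nesting of input_size,
        # defaulting to 'float32'.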
        if dtype is None:
            dtype = 'float32'

        if isinstance(input_size, (list, tuple)) and _all_is_number(input_size):
            return [dtype]
        else:
            return [_build_dtypes(i, dtype) for i in input_size]

    if not isinstance(dtypes, (list, tuple)):
        dtypes = _build_dtypes(input_size, dtypes)

    batch_size = 1

    summary_str = ''

    depth = len(list(model.sublayers()))

    def _get_shape_from_tensor(x):
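        # Read shapes from a Tensor/Variable or a nested list/tuple of them.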
        if isinstance(x, (paddle.fluid.Variable, paddle.fluid.core.VarBase)):
            return list(x.shape)
        elif isinstance(x, (list, tuple)):
            return [_get_shape_from_tensor(xx) for xx in x]

    def _get_output_shape(output):
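        # Layer outputs may be a single tensor or a nested structure of
        # tensors; collect their shapes recursively.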
        if isinstance(output, (list, tuple)):
            output_shape = [_get_output_shape(o) for o in output]
        else:
            output_shape = list(output.shape)
        return output_shape

    def register_hook(layer):
        def hook(layer, input, output):
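            # Forward post-hook: record the layer's class name, input/output
            # shapes and parameter count in the closed-over `summary` dict.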
            class_name = str(layer.__class__).split(".")[-1].split("'")[0]

            try:
                layer_idx = int(layer._full_name.split('_')[-1])
            except:
                layer_idx = len(summary)

            m_key = "%s-%i" % (class_name, layer_idx + 1)
            summary[m_key] = OrderedDict()

            try:
                summary[m_key]["input_shape"] = _get_shape_from_tensor(input)
            except:
                warnings.warn('Get layer {} input shape failed!'.format(
                    layer._full_name))
                summary[m_key]["input_shape"] = []

            try:
                summary[m_key]["output_shape"] = _get_output_shape(output)
            except:
                warnings.warn('Get layer {} output shape failed!'.format(
                    layer._full_name))
                summary[m_key]["output_shape"] = []

            params = 0

            if paddle.in_dynamic_mode():
                layer_state_dict = layer._parameters
            else:
                layer_state_dict = layer.state_dict()

            for k, v in layer_state_dict.items():
                params += np.prod(v.shape)

                try:
                    if (getattr(getattr(layer, k), 'trainable')) and (
                            not getattr(getattr(layer, k), 'stop_gradient')):
                        summary[m_key]["trainable"] = True
                    else:
                        summary[m_key]["trainable"] = False
                except:
                    summary[m_key]["trainable"] = True

            summary[m_key]["nb_params"] = params

        # Hook every concrete layer; skip pure containers (Sequential and
        # LayerList) and the top-level model itself unless it has no sublayers.
        if (not isinstance(layer, nn.Sequential) and
                not isinstance(layer, nn.LayerList) and
                (not (layer == model) or depth < 1)):

            hooks.append(layer.register_forward_post_hook(hook))

    if isinstance(input_size, tuple):
        input_size = [input_size]

    def build_input(input_size, dtypes):
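        # Recursively create random tensors matching the (possibly nested)
        # input_size and dtypes structures.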
        if isinstance(input_size, (list, tuple)) and _all_is_number(input_size):
            if isinstance(dtypes, (list, tuple)):
                dtype = dtypes[0]
            else:
                dtype = dtypes
            return paddle.rand(list(input_size), dtype)
        else:
            return [
                build_input(i, dtype) for i, dtype in zip(input_size, dtypes)
            ]

    x = build_input(input_size, dtypes)

    # create properties
    summary = OrderedDict()
    hooks = []

    # register hook
    model.apply(register_hook)

    # make a forward pass
    model(*x)

    # remove these hooks
    for h in hooks:
        h.remove()

    def _get_str_length(summary):
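        # Compute the width of each table column from its longest entry, then
        # widen the overall table if the columns need more room.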
        head_length = {
            'layer_width': 15,
            'input_shape_width': 20,
            'output_shape_width': 20,
            'params_width': 15,
            'table_width': 75
        }

        for layer in summary:
            if head_length['output_shape_width'] < len(
                    str(summary[layer]["output_shape"])):
                head_length['output_shape_width'] = len(
                    str(summary[layer]["output_shape"]))
            if head_length['input_shape_width'] < len(
                    str(summary[layer]["input_shape"])):
                head_length['input_shape_width'] = len(
                    str(summary[layer]["input_shape"]))
            if head_length['layer_width'] < len(str(layer)):
                head_length['layer_width'] = len(str(layer))
            if head_length['params_width'] < len(
                    str(summary[layer]["nb_params"])):
                head_length['params_width'] = len(
                    str(summary[layer]["nb_params"]))

        _temp_width = 0
        for k, v in head_length.items():
            if k != 'table_width':
                _temp_width += v

        if head_length['table_width'] < _temp_width + 5:
            head_length['table_width'] = _temp_width + 5

        return head_length

    table_width = _get_str_length(summary)

    summary_str += "-" * table_width['table_width'] + "\n"
    line_new = "{:^{}} {:^{}} {:^{}} {:^{}}".format(
        "Layer (type)", table_width['layer_width'], "Input Shape",
        table_width['input_shape_width'], "Output Shape",
        table_width['output_shape_width'], "Param #",
        table_width['params_width'])
    summary_str += line_new + "\n"
    summary_str += "=" * table_width['table_width'] + "\n"
    total_params = 0
    total_output = 0
    trainable_params = 0
    max_length = 0
    for layer in summary:
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:^{}} {:^{}} {:^{}} {:^{}}".format(
            layer, table_width['layer_width'],
            str(summary[layer]["input_shape"]),
            table_width['input_shape_width'],
            str(summary[layer]["output_shape"]),
            table_width['output_shape_width'],
            "{0:,}".format(summary[layer]["nb_params"]),
            table_width['params_width'])
        total_params += summary[layer]["nb_params"]

        try:
            total_output += np.prod(summary[layer]["output_shape"])
        except:
            for output_shape in summary[layer]["output_shape"]:
                total_output += np.prod(output_shape)

        if "trainable" in summary[layer]:
            if summary[layer]["trainable"] == True:
                trainable_params += summary[layer]["nb_params"]
        summary_str += line_new + "\n"

    def _get_input_size(input_size, size):
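        # Estimate the input memory footprint in MB, assuming 4 bytes
        # (float32) per element.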
        if isinstance(input_size, (list, tuple)) and _all_is_number(input_size):
            size = abs(np.prod(input_size) * 4. / (1024**2.))
        else:
            size = sum([_get_input_size(i, size) for i in input_size])
        return size

    total_input_size = _get_input_size(input_size, 0)

    total_output_size = abs(2. * total_output * 4. /
                            (1024**2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024**2.))
    total_size = total_params_size + total_output_size + total_input_size

    summary_str += "=" * table_width['table_width'] + "\n"
    summary_str += "Total params: {0:,}".format(total_params) + "\n"
    summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n"
    summary_str += "Non-trainable params: {0:,}".format(total_params -
                                                        trainable_params) + "\n"
    summary_str += "-" * table_width['table_width'] + "\n"
    summary_str += "Input size (MB): %0.2f" % total_input_size + "\n"
    summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n"
    summary_str += "Params size (MB): %0.2f" % total_params_size + "\n"
    summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n"
    summary_str += "-" * table_width['table_width'] + "\n"

    # return summary
    return summary_str, {
        'total_params': total_params,
        'trainable_params': trainable_params
    }