# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from megskull.graph import NodeFilter, FpropEnv
from megskull.opr.all import AssertEqual, DataProvider, BatchNormalization
from megskull.utils.logconf import get_logger
from meghair.utils import io
import megbrain as mgb

import argparse
import struct
import re
import os

import numpy as np
import cv2

# Module logger from megskull's logging helper; main() reports the dump
# sizes through it.
logger = get_logger(__name__)

def optimize_for_inference(args, outputs):
    """Optionally run MegBrain's inference optimizer over ``outputs``.

    Every truthy ``enable_*`` attribute of ``args`` listed in ``args_map`` is
    turned into the matching ``mgb.optimize_for_inference`` keyword set to
    ``True``. Those transformations are only applied as part of the
    inference-optimization pass, so requesting any of them without
    ``args.optimize_for_inference`` is rejected.

    :param args: parsed command-line namespace; must provide
        ``optimize_for_inference`` and every key of ``args_map``.
    :param outputs: output vars handed to ``mgb.optimize_for_inference``.
    :return: the optimized vars when ``args.optimize_for_inference`` is set,
        otherwise ``outputs`` unchanged.
    :raises ValueError: if an ``enable_*`` option is set while
        ``args.optimize_for_inference`` is not.
    """
    # argparse dest name -> keyword argument of mgb.optimize_for_inference.
    args_map = {
        'enable_io16xc32': 'f16_io_f32_comp',
        'enable_ioc16': 'f16_io_comp',
        'enable_hwcd4': 'use_nhwcd4',
        'enable_nchw4': 'use_nchw4',
        'enable_nchw88': 'use_nchw88',
        'enable_nchw44': 'use_nchw44',
        'enable_nchw44_dot': 'use_nchw44_dot',
        'enable_nchw32': 'use_nchw32',
        'enable_chwn4': 'use_chwn4',
        'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
        'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z',
    }
    enabled = [opt for opt in args_map if getattr(args, opt)]

    if not args.optimize_for_inference:
        # An explicit raise rather than `assert`: assertions are stripped
        # under `python -O`, which would silently drop the requested options.
        if enabled:
            flags = ', '.join('--' + opt.replace('_', '-') for opt in enabled)
            raise ValueError(
                '--optimize-for-inference should be set when {} is '
                'given'.format(flags))
        return outputs

    kwargs = {args_map[opt]: True for opt in enabled}
    return mgb.optimize_for_inference(outputs, **kwargs)

def _build_arg_parser():
    """Build the command-line parser for this dump script.

    Every ``--enable-*`` flag becomes an ``enable_*`` attribute that
    ``optimize_for_inference`` maps onto an ``mgb.optimize_for_inference``
    keyword.
    """
    parser = argparse.ArgumentParser(
        description='Dump the Python Megbrain model to C++ model, optionally '
        'optimizing it for inference',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('input', help='input pkl model file')
    parser.add_argument('-o', '--output', help='output file', required=True)
    parser.add_argument('--init-bn', action='store_true',
                        help='initialize untrained batch-normalization, to '
                        'avoid NaN or Inf results')
    parser.add_argument('--silent', action='store_true',
                        help='set verbose to False in AssertEqual opr')
    parser.add_argument('--optimize-for-inference', action='store_true',
                        help='enable optimization for inference')
    parser.add_argument('--discard-var-name', action='store_true',
                        help='discard variable and param names in the '
                        'generated output')
    parser.add_argument('--output-strip-info', action='store_true',
                        help='output code strip information')
    parser.add_argument('--enable-io16xc32', action='store_true',
                        help='transform the model to float16 io float32 compute')
    parser.add_argument('--enable-ioc16', action='store_true',
                        help='transform the dtype of the model to float16 io '
                        'and compute')
    parser.add_argument('--enable-fuse-conv-bias-nonlinearity',
                        action='store_true',
                        help='fuse convolution bias and nonlinearity opr to a '
                        'conv_bias opr and compute')
    parser.add_argument('--enable-hwcd4', action='store_true',
                        help='transform the model format from NCHW to NHWCD4 '
                        'for inference; you may need to disable CUDA and set '
                        'MGB_USE_MEGDNN_DBG=2')
    parser.add_argument('--enable-nchw4', action='store_true',
                        help='transform the model format from NCHW to NCHW4 '
                        'for inference')
    parser.add_argument('--enable-nchw88', action='store_true',
                        help='transform the model format from NCHW to NCHW88 '
                        'for inference')
    parser.add_argument('--enable-nchw44', action='store_true',
                        help='transform the model format from NCHW to NCHW44 '
                        'for inference')
    parser.add_argument('--enable-nchw44-dot', action='store_true',
                        help='transform the model format from NCHW to NCHW44_DOT '
                        'for optimizing armv8.2 dot in inference')
    parser.add_argument('--enable-chwn4', action='store_true',
                        help='transform the model format to CHWN4 '
                        'for inference, mainly used for nvidia tensorcore')
    parser.add_argument('--enable-nchw32', action='store_true',
                        help='transform the model format from NCHW4 to NCHW32 '
                        'for inference on nvidia TensorCore')
    parser.add_argument('--enable-fuse-conv-bias-with-z', action='store_true',
                        help='fuse conv_bias with z input for inference on '
                        'nvidia GPU (this optimization pass will result in mismatch '
                        'of the precision of output of training and inference)')
    return parser


def main():
    """Dump a pickled network's output graph to a serialized MegBrain file.

    Loads the network from ``args.input``, converts its outputs to MegBrain
    vars, optionally runs the inference optimizer, and serializes the graph
    to ``args.output``. Finally, logs the total size and the non-tensor
    overhead.
    """
    args = _build_arg_parser().parse_args()

    # These flags are accepted on the command line, but nothing below acts on
    # them. Warn instead of letting the user believe they took effect.
    if args.init_bn:
        logger.warning('--init-bn is not implemented by this script and '
                       'will be ignored')
    if args.silent:
        logger.warning('--silent is not implemented by this script and '
                       'will be ignored')

    env = FpropEnv(verbose_fprop=False)

    outputs = io.load_network(args.input).outputs
    output_mgbvars = list(map(env.get_mgbvar, outputs))
    output_mgbvars = optimize_for_inference(args, output_mgbvars)

    # keep_var_name=0 drops var names and 2 keeps them; the meaning of
    # intermediate levels is not visible here (TODO confirm in mgb docs).
    if args.discard_var_name:
        sereg_kwargs = dict(keep_var_name=0, keep_param_name=False)
    else:
        sereg_kwargs = dict(keep_var_name=2, keep_param_name=True)

    stat = mgb.serialize_comp_graph_to_file(
        args.output, output_mgbvars, append=False,
        output_strip_info=args.output_strip_info,
        **sereg_kwargs)
    logger.info('graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB'.
                format(stat.tot_bytes / 1024,
                       (stat.tot_bytes - stat.tensor_value_bytes) / 1024))

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()