# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import argparse
import struct
import os
import subprocess

import flatbuffers

def generate_flatbuffer():
    """Run ``flatc`` to generate the Python bindings for ``pack_model.fbs``.

    The schema is looked up at ``../../src/parse_model/pack_model.fbs``
    relative to the directory containing this script. The generated code is
    written into that same script directory, so ``from model_parse import ...``
    in ``main`` can find it when the script is executed directly.

    Raises:
        RuntimeError: if ``flatc`` is not on ``PATH``, or if it exits with a
            non-zero status (flatc's own output is appended to the message).
    """
    # Local import keeps this helper self-contained.
    import shutil

    # shutil.which is portable; shelling out to `which` is not.
    flatc = shutil.which('flatc')
    if flatc is None:
        raise RuntimeError('no flatc in current environment, please build '
                           'flatc and put in the system PATH!')

    # Resolve __file__ before taking dirname(). A bare relative filename has
    # dirname '', which would make the schema path depend on the CWD rather
    # than on where this script lives.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    fbs_file = os.path.abspath(os.path.join(
        script_dir, '../../src/parse_model/pack_model.fbs'))

    # Pass an argument list (no shell), so paths containing spaces or shell
    # metacharacters are handled correctly. `-o script_dir` puts the generated
    # package next to this script, i.e. on sys.path[0] when run directly.
    proc = subprocess.run(
        [flatc, '-p', '-b', '-o', script_dir, fbs_file],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True)
    if proc.returncode != 0:
        raise RuntimeError('flatc generate error!\n' + proc.stdout)

def _parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        tuple: ``(parser, args)``. The parser is returned too, so the caller
        can report validation failures through ``parser.error``.
    """
    parser = argparse.ArgumentParser(
        description='load a encrypted or not encrypted model and a '
        'json format of the infomation of the model, pack them to a file '
        'which can be loaded by lite.')
    parser.add_argument('--input-model',
                        help='input a encrypted or not encrypted model')
    parser.add_argument('--input-info',
                        help='input a encrypted or not encrypted '
                        'json format file.')
    parser.add_argument('--model-name', default='NONE',
                        help='the model name, this must match '
                        'with the model name in model info')
    parser.add_argument('--model-cryption', default='NONE',
                        help='the model encryption method '
                        'name, this is used to find the right decryption '
                        'method. e.g. --model-cryption AES_default, '
                        'default is NONE.')
    parser.add_argument('--info-cryption', default='NONE',
                        help='the info encryption method '
                        'name, this is used to find the right decryption '
                        'method. e.g. --info-cryption AES_default, '
                        'default is NONE.')
    parser.add_argument('--info-parser', default='LITE_default',
                        help='The information parse method name '
                        'default is "LITE_default". ')
    parser.add_argument('--append', '-a', help='append another model to a '
                        'packed model.')
    parser.add_argument('--output', '-o', help='output file of packed model.')
    return parser, parser.parse_args()


def _build_packed_model(args, raw_model, raw_info):
    """Serialize one model into a finished ``PackModel`` flatbuffer.

    Args:
        args: parsed CLI namespace. ``model_name``, ``model_cryption``,
            ``info_cryption`` and ``info_parser`` are stored in the header.
        raw_model: contents of the model file (bytes).
        raw_info: contents of the info file (bytes), or None when no info
            file was given. In that case the model's ``info`` field is left
            unset rather than serialized.

    Returns:
        The serialized buffer as returned by ``flatbuffers.Builder.Output()``.
    """
    # Generated by `flatc` in generate_flatbuffer(). It only exists after
    # that has run, hence the function-scope import.
    from model_parse import Model, ModelData, ModelHeader, ModelInfo, PackModel

    builder = flatbuffers.Builder(1024)

    # Strings and byte vectors must be serialized before the tables that
    # reference them; flatbuffers does not allow nested object construction.
    model_name = builder.CreateString(args.model_name)
    model_cryption = builder.CreateString(args.model_cryption)
    info_cryption = builder.CreateString(args.info_cryption)
    info_parser = builder.CreateString(args.info_parser)
    # CreateByteVector raises TypeError for None, so only serialize the info
    # payload when one was actually supplied.
    info_data = (builder.CreateByteVector(raw_info)
                 if raw_info is not None else None)
    arr_data = builder.CreateByteVector(raw_model)

    # model header
    ModelHeader.ModelHeaderStart(builder)
    ModelHeader.ModelHeaderAddName(builder, model_name)
    ModelHeader.ModelHeaderAddModelDecryptionMethod(builder, model_cryption)
    ModelHeader.ModelHeaderAddInfoDecryptionMethod(builder, info_cryption)
    ModelHeader.ModelHeaderAddInfoParseMethod(builder, info_parser)
    model_header = ModelHeader.ModelHeaderEnd(builder)

    # model info (optional)
    model_info = None
    if info_data is not None:
        ModelInfo.ModelInfoStart(builder)
        ModelInfo.ModelInfoAddData(builder, info_data)
        model_info = ModelInfo.ModelInfoEnd(builder)

    # model data
    ModelData.ModelDataStart(builder)
    ModelData.ModelDataAddData(builder, arr_data)
    model_data = ModelData.ModelDataEnd(builder)

    Model.ModelStart(builder)
    Model.ModelAddHeader(builder, model_header)
    Model.ModelAddData(builder, model_data)
    if model_info is not None:
        Model.ModelAddInfo(builder, model_info)
    model = Model.ModelEnd(builder)

    # A models vector holding exactly one model (--append is unsupported).
    PackModel.PackModelStartModelsVector(builder, 1)
    builder.PrependUOffsetTRelative(model)
    # NOTE(review): argument-less EndVector() is the flatbuffers 2.x API;
    # 1.x required the element count -- confirm the pinned version.
    models = builder.EndVector()

    PackModel.PackModelStart(builder)
    PackModel.PackModelAddModels(builder, models)
    packed_model = PackModel.PackModelEnd(builder)

    builder.Finish(packed_model)
    return builder.Output()


def main():
    """Pack a model file (and an optional info file) into a lite-loadable file.

    The output file is the 12 ASCII bytes ``packed_model`` immediately
    followed by a ``PackModel`` flatbuffer containing a single model.
    """
    parser, args = _parse_args()

    # parser.error() prints the usage text and exits with status 2. Unlike
    # `assert`, it is not stripped under `python -O`. Validating before
    # running flatc or reading any file also avoids wasted work.
    if args.append:
        parser.error('--append is not supported yet')
    if not args.input_model:
        parser.error('--input-model must be given')
    if not args.output:
        parser.error('--output must be given')

    generate_flatbuffer()

    with open(args.input_model, 'rb') as fin:
        raw_model = fin.read()

    raw_info = None
    if args.input_info:
        with open(args.input_info, 'rb') as fin:
            raw_info = fin.read()

    buff = _build_packed_model(args, raw_model, raw_info)

    # Raw ASCII tag (no length prefix, no terminator), then the flatbuffer.
    result = b'packed_model' + buff

    with open(args.output, 'wb') as fout:
        fout.write(result)

    print("Model packaged successfully!!!")
    print("model name is: {}.".format(args.model_name))
    print("model encryption method is: {}. ".format(args.model_cryption))
    print("model json infomation encryption method is: {}. ".format(args.info_cryption))
    print("model json infomation parse method is: {}. ".format(args.info_parser))
    print("packed model is write to {} ".format(args.output))

if __name__ == "__main__":
    # Only pack when executed as a script; importing this module is inert.
    main()