pack_model_and_info.py 5.2 KB
Newer Older
1
# -*- coding: utf-8 -*-
2
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
3
#
4
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
5
#
6 7 8
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136

import argparse
import os
import shutil
import struct
import subprocess

import flatbuffers

def generate_flatbuffer():
    """Regenerate the Python flatbuffers bindings for ``pack_model.fbs``.

    Invokes ``flatc -p -b <fbs_file>``. No ``-o`` option is passed, so flatc
    writes its generated Python package relative to the current working
    directory.

    The schema path is computed relative to this script's location (its
    parent directory, then ``../../src/parse_model/pack_model.fbs``). It is
    independent of the caller's working directory.

    Raises:
        Exception: if ``flatc`` cannot be found on ``PATH``, or if it exits
            with a non-zero status (flatc's combined output is included in
            the message).
    """
    flatc = shutil.which('flatc')
    if flatc is None:
        raise Exception('no flatc in current environment, please build flatc '
                        'and put in the system PATH!')

    # Take abspath first: with a relative __file__ (e.g. "pack.py"),
    # dirname(dirname(...)) would collapse to '' and the schema path would
    # wrongly resolve against the caller's working directory.
    script_parent = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    fbs_file = os.path.abspath(
        os.path.join(script_parent, '../../src/parse_model/pack_model.fbs'))

    # Pass an argument list (no shell) so that paths containing spaces or
    # shell metacharacters reach flatc unchanged.
    proc = subprocess.run([flatc, '-p', '-b', fbs_file],
                          stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    if proc.returncode != 0:
        raise Exception('flatc generate error!\n'
                        + proc.stdout.decode(errors='replace'))

def _parse_args():
    """Build the command-line parser and return the validated arguments.

    ``--input-model`` and ``--output`` are mandatory. ``--append`` is
    rejected because appending to an existing packed model is not
    implemented. Invalid usage exits via argparse with status 2.
    """
    parser = argparse.ArgumentParser(
            description='load a encrypted or not encrypted model and a '
            'json format of the infomation of the model, pack them to a file '
            'which can be loaded by lite.')
    parser.add_argument('--input-model', required=True,
            help='input a encrypted or not encrypted model')
    parser.add_argument('--input-info', help='input a encrypted or not encrypted '
            'json format file.')
    parser.add_argument('--model-name', help='the model name, this must match '
            'with the model name in model info', default='NONE')
    parser.add_argument('--model-cryption', help='the model encryption method '
            'name, this is used to find the right decryption method. e.g. '
            '--model-cryption = "AES_default", default is NONE.', default='NONE')
    parser.add_argument('--info-cryption', help='the info encryption method '
            'name, this is used to find the right decryption method. e.g. '
            '--info-cryption = "AES_default", default is NONE.', default='NONE')
    parser.add_argument('--info-parser', help='The information parse method name '
            'default is "LITE_default". ', default='LITE_default')
    parser.add_argument('--append', '-a', help='append another model to a '
            'packed model.')
    parser.add_argument('--output', '-o', required=True,
            help='output file of packed model.')

    args = parser.parse_args()
    # parser.error (not assert) so the check survives `python -O`, and it runs
    # before any expensive work (flatc, reading the model) is done.
    if args.append:
        parser.error('--append is not supported yet')
    return args


def _build_packed_model(args, raw_model, raw_info):
    """Serialize one model (and optional info) into a ``PackModel`` buffer.

    Args:
        args: parsed CLI arguments; reads ``model_name``, ``model_cryption``,
            ``info_cryption`` and ``info_parser``.
        raw_model: bytes of the (possibly encrypted) model file.
        raw_info: bytes of the (possibly encrypted) info file, or None when
            no info was supplied. In that case the model's ``info`` field is
            left unset.

    Returns:
        The ASCII tag ``b"packed_model"`` followed by the finished flatbuffer.
    """
    # Generated by `flatc` (see generate_flatbuffer); only importable after
    # that has run.
    from model_parse import Model, ModelData, ModelHeader, ModelInfo, PackModel

    builder = flatbuffers.Builder(1024)

    # flatbuffers forbids creating strings/vectors while a table is being
    # built, so every leaf object is created up front.
    model_name = builder.CreateString(args.model_name)
    model_cryption = builder.CreateString(args.model_cryption)
    info_cryption = builder.CreateString(args.info_cryption)
    info_parser = builder.CreateString(args.info_parser)

    # CreateByteVector rejects None, so only build the info vector when an
    # info file was actually given.
    info_data = None
    if raw_info is not None:
        info_data = builder.CreateByteVector(raw_info)
    arr_data = builder.CreateByteVector(raw_model)

    # model header
    ModelHeader.ModelHeaderStart(builder)
    ModelHeader.ModelHeaderAddName(builder, model_name)
    ModelHeader.ModelHeaderAddModelDecryptionMethod(builder, model_cryption)
    ModelHeader.ModelHeaderAddInfoDecryptionMethod(builder, info_cryption)
    ModelHeader.ModelHeaderAddInfoParseMethod(builder, info_parser)
    model_header = ModelHeader.ModelHeaderEnd(builder)

    # model info (optional)
    # NOTE(review): confirm the lite loader treats an absent info table as
    # "no info" when --input-info is omitted.
    model_info = None
    if info_data is not None:
        ModelInfo.ModelInfoStart(builder)
        ModelInfo.ModelInfoAddData(builder, info_data)
        model_info = ModelInfo.ModelInfoEnd(builder)

    # model data
    ModelData.ModelDataStart(builder)
    ModelData.ModelDataAddData(builder, arr_data)
    model_data = ModelData.ModelDataEnd(builder)

    Model.ModelStart(builder)
    Model.ModelAddHeader(builder, model_header)
    Model.ModelAddData(builder, model_data)
    if model_info is not None:
        Model.ModelAddInfo(builder, model_info)
    model = Model.ModelEnd(builder)

    PackModel.PackModelStartModelsVector(builder, 1)
    builder.PrependUOffsetTRelative(model)
    try:
        models = builder.EndVector(1)
    except TypeError:
        # flatbuffers 2.x dropped the element-count argument of EndVector.
        models = builder.EndVector()

    PackModel.PackModelStart(builder)
    PackModel.PackModelAddModels(builder, models)
    packed_model = PackModel.PackModelEnd(builder)

    builder.Finish(packed_model)
    # Same 12 bytes the original produced via struct.pack('12s', ...):
    # a fixed ASCII tag written in front of the flatbuffer payload.
    return b'packed_model' + builder.Output()


def main():
    """Pack a model file and an optional info file into one output file.

    Steps:
    - Validate the CLI arguments.
    - Regenerate the flatbuffers bindings with flatc.
    - Read ``--input-model`` (and ``--input-info`` if given).
    - Serialize everything into a ``b"packed_model"``-prefixed ``PackModel``
      flatbuffer and write it to ``--output``.
    """
    args = _parse_args()

    generate_flatbuffer()
    # flatc writes the generated `model_parse` package relative to the current
    # working directory. That directory is not on sys.path when the script is
    # launched from elsewhere. The import system may also have cached a
    # directory listing taken before the files existed.
    import importlib
    import sys
    cwd = os.getcwd()
    if cwd not in sys.path:
        sys.path.insert(0, cwd)
    importlib.invalidate_caches()

    with open(args.input_model, 'rb') as fin:
        raw_model = fin.read()

    raw_info = None
    if args.input_info:
        with open(args.input_info, 'rb') as fin:
            raw_info = fin.read()

    result = _build_packed_model(args, raw_model, raw_info)

    with open(args.output, 'wb') as fout:
        fout.write(result)

    print("Model packaged successfully!!!")
    print("model name is: {}.".format(args.model_name))
    print("model encryption method is: {}. ".format(args.model_cryption))
    print("model json infomation encryption method is: {}. ".format(args.info_cryption))
    print("model json infomation parse method is: {}. ".format(args.info_parser))
    print("packed model is write to {} ".format(args.output))

if __name__ == "__main__":
    # Run the packer only when executed as a script, not when imported.
    main()