# dump_model.py
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse

import numpy as np
import yaml

from megengine import jit, tensor
from megengine.module.external import ExternOprSubgraph


# "1,3,224,224" -> (1,3,224,224)
def str2tuple(x):
    x = x.split(",")
    x = [int(a) for a in x]
    x = tuple(x)
    return x


def main():
    """Convert a mace .pb model + param file into a MegEngine dump.

    Reads the mace model/param binaries and a yaml config describing the
    sub-model input/output tensors, packs everything into the raw-content
    layout expected by the "mace" ExternOprSubgraph, traces one inference,
    and dumps the result so dump_with_testcase_mge.py can consume it.
    """
    parser = argparse.ArgumentParser(
        description="load a .pb model and convert to corresponding "
        "load-and-run model"
    )
    parser.add_argument("--input", help="mace model file")
    parser.add_argument("--param", help="mace param file")
    parser.add_argument(
        "--output", help="converted model that can be fed to dump_with_testcase_mge.py"
    )
    parser.add_argument("--config", help="config file with yaml format")
    args = parser.parse_args()

    with open(args.config, "r") as f:
        # safe_load: plain yaml.load without an explicit Loader is
        # deprecated/unsafe and can construct arbitrary Python objects.
        configs = yaml.safe_load(f)

    # The model/param binaries do not depend on which sub-model is being
    # processed, so read them once instead of on every loop iteration.
    with open(args.input, "rb") as fin:
        raw_model = fin.read()
    with open(args.param, "rb") as fin:
        raw_param = fin.read()

    for model_name in configs["models"]:
        # ignore several sub models currently
        sub_model = configs["models"][model_name]["subgraphs"][0]

        # input shapes, e.g. "1,3,224,224" -> (1, 3, 224, 224)
        isizes = [str2tuple(x) for x in sub_model["input_shapes"]]

        # input/output names; prefer explicit check tensors when present
        input_names = sub_model["input_tensors"]
        if "check_tensors" in sub_model:
            output_names = sub_model["check_tensors"]
            osizes = [str2tuple(x) for x in sub_model["check_shapes"]]
        else:
            output_names = sub_model["output_tensors"]
            osizes = [str2tuple(x) for x in sub_model["output_shapes"]]

        # all sizes/counts are serialized as 4-byte little-endian ints
        model_size = (len(raw_model)).to_bytes(4, byteorder="little")
        param_size = (len(raw_param)).to_bytes(4, byteorder="little")

        n_inputs = (len(input_names)).to_bytes(4, byteorder="little")
        n_outputs = (len(output_names)).to_bytes(4, byteorder="little")

        # names buffer: input count, output count, then (length, bytes)
        # for each input name followed by each output name
        names_buffer = n_inputs + n_outputs
        for iname in input_names:
            names_buffer += (len(iname)).to_bytes(4, byteorder="little")
            names_buffer += str.encode(iname)
        for oname in output_names:
            names_buffer += (len(oname)).to_bytes(4, byteorder="little")
            names_buffer += str.encode(oname)

        # shapes buffer: output count, then (ndim, dim0, dim1, ...) per shape
        shapes_buffer = n_outputs
        for oshape in osizes:
            shapes_buffer += (len(oshape)).to_bytes(4, byteorder="little")
            for oi in oshape:
                shapes_buffer += oi.to_bytes(4, byteorder="little")

        # raw content contains:
        # input/output names + output shapes + model buffer + param buffer
        wk_raw_content = (
            names_buffer
            + shapes_buffer
            + model_size
            + raw_model
            + param_size
            + raw_param
        )

        net = ExternOprSubgraph(osizes, "mace", wk_raw_content)
        net.eval()

        @jit.trace(record_only=True)
        def inference(inputs):
            return net(inputs)

        # feed random float32 tensors of the declared input shapes so the
        # trace records a complete graph, then dump it for load-and-run
        inputs = [
            tensor(np.random.random(isize).astype(np.float32)) for isize in isizes
        ]
        inference(*inputs)
        inference.dump(args.output)


if __name__ == "__main__":
    main()