# command.py
from six import text_type as _text_type
import argparse
import sys


def arg_parser():
    """Build the argument parser used by the ``paddlex`` command.

    Returns:
        argparse.ArgumentParser: a parser that understands
        ``--model_dir``/``-m``, ``--save_dir``/``-s``, ``--version``/``-v``,
        ``--export_inference``/``-e``, ``--export_onnx``/``-eo`` and
        ``--fixed_input_shape``/``-fs``. The path and shape options default
        to ``None``; the three flags default to ``False``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir",
        "-m",
        type=_text_type,
        default=None,
        help="define model directory path")
    parser.add_argument(
        "--save_dir",
        "-s",
        type=_text_type,
        default=None,
        help="path to save inference model")
    parser.add_argument(
        "--version",
        "-v",
        action="store_true",
        default=False,
        help="get version of PaddleX")
    parser.add_argument(
        "--export_inference",
        "-e",
        action="store_true",
        default=False,
        help="export inference model for C++/Python deployment")
    parser.add_argument(
        "--export_onnx",
        "-eo",
        action="store_true",
        default=False,
        help="export onnx model for deployment")
    # No ``type`` is given, so the value stays the raw command-line string
    # (for example "[224,224]"); main() is responsible for interpreting it.
    parser.add_argument(
        "--fixed_input_shape",
        "-fs",
        default=None,
        help="export inference model with fixed input shape:[w,h]")
    return parser


def _parse_fixed_input_shape(raw_shape, default):
    """Interpret the ``--fixed_input_shape`` command-line value.

    Args:
        raw_shape: the option value exactly as typed (for example
            ``"[224,224]"``), or ``None`` when the option was omitted.
        default: returned unchanged when ``raw_shape`` is ``None``.

    Returns:
        The evaluated shape object, or ``default`` if no shape was given.

    Raises:
        AssertionError: if the evaluated shape does not have length 2.
    """
    if raw_shape is None:
        return default
    # NOTE(review): eval() runs arbitrary Python taken from the command line.
    # ast.literal_eval would be safer but rejects expressions that eval
    # accepts today -- confirm no caller relies on that before switching.
    shape = eval(raw_shape)
    assert len(
        shape) == 2, "len of fixed input shape must == 2, such as [224,224]"
    return shape


def main():
    """Run the ``paddlex`` command line.

    With no arguments a hint about ``-h`` is printed. ``--version`` prints
    version and contact information and returns. ``--export_inference`` and
    ``--export_onnx`` each load the model from ``--model_dir`` and write the
    exported model to ``--save_dir``; both may be requested in one call.

    Raises:
        AssertionError: if a required option for the requested export is
            missing, or the fixed input shape does not have length 2.
    """
    import os
    # CUDA_VISIBLE_DEVICES is cleared before paddlex is imported --
    # presumably so the export runs without touching any GPU (TODO confirm).
    os.environ['CUDA_VISIBLE_DEVICES'] = ""

    import paddlex as pdx

    if len(sys.argv) < 2:
        print("Use command 'paddlex -h` to print the help information\n")
        return
    args = arg_parser().parse_args()

    if args.version:
        print("PaddleX-{}".format(pdx.__version__))
        print("Repo: https://github.com/PaddlePaddle/PaddleX.git")
        print("Email: paddlex@baidu.com")
        return

    if args.export_inference:
        assert args.model_dir is not None, "--model_dir should be defined while exporting inference model"
        assert args.save_dir is not None, "--save_dir should be defined to save inference model"

        # The fixed shape is optional here; None is passed through when absent.
        fixed_input_shape = _parse_fixed_input_shape(args.fixed_input_shape,
                                                     None)
        model = pdx.load_model(args.model_dir, fixed_input_shape)
        model.export_inference_model(args.save_dir)

    if args.export_onnx:
        assert args.model_dir is not None, "--model_dir should be defined while exporting onnx model"
        assert args.save_dir is not None, "--save_dir should be defined to create onnx model"
        assert args.fixed_input_shape is not None, "--fixed_input_shape should be defined [w,h] to create onnx model, such as [224,224]"

        # The [] default only matters if the assert above is stripped (-O);
        # it keeps the value the original code passed in that case.
        fixed_input_shape = _parse_fixed_input_shape(args.fixed_input_shape,
                                                     [])
        model = pdx.load_model(args.model_dir, fixed_input_shape)
        pdx.convertor.export_onnx_model(model, args.save_dir)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()