# command.py: PaddleX command-line tool.
# (Authors per repository blame: jiangjiajun, Channingss.)
from six import text_type as _text_type
import argparse
import sys


def arg_parser():
    """Build the argument parser for the ``paddlex`` command-line tool.

    Every option is optional at the parser level; ``main`` checks which
    combinations are required for each action (e.g. exporting needs both
    ``--model_dir`` and ``--save_dir``).

    Returns:
        argparse.ArgumentParser: parser defining model_dir, save_dir,
        version, export_inference, export_onnx and fixed_input_shape.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir",
        "-m",
        type=_text_type,
        default=None,
        help="define model directory path")
    parser.add_argument(
        "--save_dir",
        "-s",
        type=_text_type,
        default=None,
        help="path to save inference model")
    parser.add_argument(
        "--version",
        "-v",
        action="store_true",
        default=False,
        help="get version of PaddleX")
    parser.add_argument(
        "--export_inference",
        "-e",
        action="store_true",
        default=False,
        help="export inference model for C++/Python deployment")
    # main() reads args.export_onnx; without this flag that access raised
    # AttributeError on every run that got past --version.
    parser.add_argument(
        "--export_onnx",
        "-eo",
        action="store_true",
        default=False,
        help="export onnx model for deployment")
    # Kept as a raw string; main() parses it (e.g. "[224,224]") into a list.
    parser.add_argument(
        "--fixed_input_shape",
        "-fs",
        default=None,
        help="export inference model with fixed input shape:[w,h]")
    return parser


def _parse_fixed_input_shape(parser, value):
    """Parse the raw ``--fixed_input_shape`` string into a ``[w, h]`` pair.

    Args:
        parser: the ArgumentParser used to report invalid input.
        value: the raw option string (e.g. ``"[224,224]"``), or None when
            the option was not given.

    Returns:
        None when ``value`` is None; otherwise the parsed two-element
        list or tuple.

    Exits (via ``parser.error``) when ``value`` is not a literal list/tuple
    of exactly two elements.
    """
    if value is None:
        return None
    import ast
    # literal_eval only accepts Python literals, unlike eval(), which would
    # execute arbitrary expressions supplied on the command line.
    try:
        shape = ast.literal_eval(value)
    except (ValueError, SyntaxError, TypeError):
        shape = None
    if not isinstance(shape, (list, tuple)) or len(shape) != 2:
        parser.error(
            "len of fixed input shape must == 2, e.g. [w,h]; got {!r}".format(
                value))
    return shape


def _load_model_for_export(pdx, parser, args, target):
    """Validate export options and load the model found at ``args.model_dir``.

    Args:
        pdx: the imported ``paddlex`` module.
        parser: the ArgumentParser used to report missing/invalid options.
        args: the parsed command-line namespace.
        target: export kind used in error messages ("inference" or "onnx").

    Returns:
        The object returned by ``pdx.load_model``.

    Exits (via ``parser.error``) when --model_dir or --save_dir is missing
    or --fixed_input_shape is malformed.
    """
    if args.model_dir is None:
        parser.error("--model_dir should be defined while exporting {} model"
                     .format(target))
    if args.save_dir is None:
        parser.error("--save_dir should be defined to save {} model"
                     .format(target))
    fixed_input_shape = _parse_fixed_input_shape(parser,
                                                 args.fixed_input_shape)
    # NOTE(review): passes None when -fs is omitted; assumes load_model's
    # second parameter accepts None (unset shape) — confirm against paddlex.
    return pdx.load_model(args.model_dir, fixed_input_shape)


def main():
    """Entry point of the ``paddlex`` command-line tool.

    Prints version information, or exports the model in --model_dir to
    --save_dir as an inference model and/or an ONNX model. Both exports run
    if both flags are given. Invalid option combinations exit via
    ``parser.error`` instead of ``assert`` (asserts are stripped under -O).
    """
    import os
    # Hide all CUDA devices. This must happen before paddlex is imported
    # so the setting is in place when the framework initializes.
    os.environ['CUDA_VISIBLE_DEVICES'] = ""

    import paddlex as pdx

    if len(sys.argv) < 2:
        print("Use command 'paddlex -h' to print the help information\n")
        return
    parser = arg_parser()
    args = parser.parse_args()

    if args.version:
        print("PaddleX-{}".format(pdx.__version__))
        print("Repo: https://github.com/PaddlePaddle/PaddleX.git")
        print("Email: paddlex@baidu.com")
        return

    if args.export_inference:
        model = _load_model_for_export(pdx, parser, args, "inference")
        model.export_inference_model(args.save_dir)

    # getattr keeps this working even with a parser lacking --export_onnx.
    if getattr(args, "export_onnx", False):
        model = _load_model_for_export(pdx, parser, args, "onnx")
        model.export_onnx_model(args.save_dir)

# Allow running this module directly as a script.

if "__main__" == __name__:
    # Only runs when executed as a script, never on import.
    main()