generate_cpp_yaml.py 2.8 KB
Newer Older
D
dongshuilong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
import os
import yaml
import argparse


def str2bool(v):
    """Convert a command-line string to a bool, for use as argparse ``type=``.

    Matching is case-insensitive and ignores surrounding whitespace:
    "true", "t", "yes", "y" and "1" map to True. Every other string maps to
    False, which keeps the original lenient behaviour. A value that is already
    a bool is returned unchanged, so calling this on a parsed value is safe.

    Args:
        v: The raw argument string (or a bool).

    Returns:
        bool: The parsed flag value.
    """
    if isinstance(v, bool):
        return v
    # The original accepted only "true"; "1"/"yes" silently became False.
    return v.strip().lower() in ("true", "t", "yes", "y", "1")


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace: The parsed options. Boolean options go through
        ``str2bool`` and take an explicit value, e.g. ``--gpu True``.
    """
    parser = argparse.ArgumentParser(
        description="Generate an inference YAML config from a template "
        "config and command-line overrides.")
    parser.add_argument(
        '--type',
        required=True,
        choices=["cls", "shitu"],
        help="'cls' (classification) or 'shitu' (detection + recognition)")
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--mkldnn', type=str2bool, default=True)
    parser.add_argument('--gpu', type=str2bool, default=False)
    parser.add_argument('--cpu_thread', type=int, default=1)
    parser.add_argument('--tensorrt', type=str2bool, default=False)
    # Default to "fp32" so args.precision is never None when omitted.
    parser.add_argument(
        '--precision', type=str, default="fp32", choices=["fp32", "fp16"])
    parser.add_argument('--benchmark', type=str2bool, default=True)
    parser.add_argument(
        '--cls_yaml_path',
        type=str,
        default="deploy/configs/inference_cls.yaml",
        help="template config used when --type is 'cls'")
    parser.add_argument(
        '--shitu_yaml_path',
        type=str,
        default="deploy/configs/inference_drink.yaml",
        help="template config used when --type is 'shitu'")
    parser.add_argument(
        '--data_dir',
        type=str,
        required=True,
        help="image directory ('cls') or dataset root containing "
        "test_images/ and index/ ('shitu')")
    parser.add_argument(
        '--save_path',
        type=str,
        default='./',
        help="directory the generated YAML is written to")
    parser.add_argument(
        '--cls_model_dir',
        type=str,
        help="classification ('cls') or recognition ('shitu') model dir")
    parser.add_argument(
        '--det_model_dir',
        type=str,
        help="detection model dir (needed for 'shitu')")
    return parser.parse_args(argv)


def main():
    """Generate an inference YAML config from a template and CLI options.

    The template is chosen by ``--type``: ``--cls_yaml_path`` for "cls" and
    ``--shitu_yaml_path`` for "shitu". Entries under ``Global`` are overridden
    from the command line (for "shitu", also ``IndexProcess.index_dir``). The
    result is written into ``--save_path`` under the template's file name.

    Raises:
        ValueError: If a model directory required by ``--type`` is missing.
    """
    args = parse_args()
    is_cls = args.type == "cls"

    # Validate before any file I/O. Explicit checks (not ``assert``) so they
    # still run under ``python -O``.
    if not args.cls_model_dir:
        raise ValueError("--cls_model_dir is required")
    if not is_cls and not args.det_model_dir:
        raise ValueError("--det_model_dir is required for --type shitu")

    yaml_path = args.cls_yaml_path if is_cls else args.shitu_yaml_path
    save_path = os.path.join(args.save_path, os.path.basename(yaml_path))

    # NOTE: FullLoader can construct some Python objects; only point this at
    # trusted config files (yaml.safe_load would be stricter).
    with open(yaml_path, encoding="utf-8") as fd:
        config = yaml.load(fd, Loader=yaml.FullLoader)

    global_cfg = config["Global"]
    global_cfg["batch_size"] = args.batch_size
    global_cfg["use_gpu"] = args.gpu
    global_cfg["enable_mkldnn"] = args.mkldnn
    global_cfg["benchmark"] = args.benchmark
    global_cfg["use_tensorrt"] = args.tensorrt
    global_cfg["use_fp16"] = args.precision == "fp16"
    # NOTE(review): --cpu_thread is parsed but never written to the config;
    # confirm which template key (if any) it is meant to set.
    if is_cls:
        global_cfg["infer_imgs"] = args.data_dir
        global_cfg["inference_model_dir"] = args.cls_model_dir
    else:
        global_cfg["infer_imgs"] = os.path.join(args.data_dir, "test_images")
        config["IndexProcess"]["index_dir"] = os.path.join(args.data_dir,
                                                           "index")
        global_cfg["det_inference_model_dir"] = args.det_model_dir
        global_cfg["rec_inference_model_dir"] = args.cls_model_dir

    # Create the output directory so a not-yet-existing --save_path works.
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)
    with open(save_path, 'w', encoding="utf-8") as fd:
        yaml.dump(config, fd)
    print("Generate new yaml done")


# Script entry point: parse CLI options and write the generated YAML.
if __name__ == "__main__":
    main()