import os import yaml import argparse def str2bool(v): if v.lower() == 'true': return True else: return False def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--type', required=True, choices=["cls", "shitu"]) parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--mkldnn', type=str2bool, default=True) parser.add_argument('--gpu', type=str2bool, default=False) parser.add_argument('--cpu_thread', type=int, default=1) parser.add_argument('--tensorrt', type=str2bool, default=False) parser.add_argument('--precision', type=str, choices=["fp32", "fp16"]) parser.add_argument('--benchmark', type=str2bool, default=True) parser.add_argument( '--cls_yaml_path', type=str, default="deploy/configs/inference_cls.yaml") parser.add_argument( '--shitu_yaml_path', type=str, default="deploy/configs/inference_drink.yaml") parser.add_argument('--data_dir', type=str, required=True) parser.add_argument('--save_path', type=str, default='./') parser.add_argument('--cls_model_dir', type=str) parser.add_argument('--det_model_dir', type=str) args = parser.parse_args() return args def main(): args = parse_args() if args.type == "cls": save_path = os.path.join(args.save_path, os.path.basename(args.cls_yaml_path)) fd = open(args.cls_yaml_path) else: save_path = os.path.join(args.save_path, os.path.basename(args.shitu_yaml_path)) fd = open(args.shitu_yaml_path) config = yaml.load(fd, yaml.FullLoader) fd.close() config["Global"]["batch_size"] = args.batch_size config["Global"]["use_gpu"] = args.gpu config["Global"]["enable_mkldnn"] = args.mkldnn config["Global"]["benchmark"] = args.benchmark config["Global"]["use_tensorrt"] = args.tensorrt config["Global"]["use_fp16"] = True if args.precision == "fp16" else False if args.type == "cls": config["Global"]["infer_imgs"] = args.data_dir assert args.cls_model_dir config["Global"]["inference_model_dir"] = args.cls_model_dir else: config["Global"]["infer_imgs"] = os.path.join(args.data_dir, "test_images") config["IndexProcess"]["index_dir"] = os.path.join(args.data_dir, "index") assert args.cls_model_dir assert args.det_model_dir config["Global"]["det_inference_model_dir"] = args.det_model_dir config["Global"]["rec_inference_model_dir"] = args.cls_model_dir with open(save_path, 'w') as fd: yaml.dump(config, fd) print("Generate new yaml done") if __name__ == "__main__": main()