diff --git a/pdseg/export_model.py b/pdseg/export_model.py index 93c9bea98862572454960c75a6d7548249deebe9..27423bb705ba94d6418e376af04ad06d5d8ccb8e 100644 --- a/pdseg/export_model.py +++ b/pdseg/export_model.py @@ -49,6 +49,32 @@ def parse_args(): sys.exit(1) return parser.parse_args() +def export_inference_config(): + deploy_cfg = '''DEPLOY: + USE_GPU : 1 + MODEL_PATH : "%s" + MODEL_FILENAME : "%s" + PARAMS_FILENAME : "%s" + EVAL_CROP_SIZE : %s + MEAN : %s + STD : %s + IMAGE_TYPE : "%s" + NUM_CLASSES : %d + CHANNELS : %d + PRE_PROCESSOR : "SegPreProcessor" + PREDICTOR_MODE : "ANALYSIS" + BATCH_SIZE : 1 + ''' % (cfg.FREEZE.SAVE_DIR, cfg.FREEZE.MODEL_FILENAME, + cfg.FREEZE.PARAMS_FILENAME, cfg.EVAL_CROP_SIZE, + cfg.MEAN, cfg.STD, cfg.DATASET.IMAGE_TYPE, + cfg.DATASET.NUM_CLASSES, len(cfg.STD)) + if not os.path.exists(cfg.FREEZE.SAVE_DIR): + os.mkdir(cfg.FREEZE.SAVE_DIR) + yaml_path = os.path.join(cfg.FREEZE.SAVE_DIR, 'deploy.yaml') + with open(yaml_path, "w") as fp: + fp.write(deploy_cfg) + return yaml_path + def export_inference_model(args): """ @@ -81,6 +107,9 @@ def export_inference_model(args): model_filename=cfg.FREEZE.MODEL_FILENAME, params_filename=cfg.FREEZE.PARAMS_FILENAME) print("Inference model exported!") + print("Exporting inference model config...") + deploy_cfg_path = export_inference_config() + print("Inference model saved : [%s]" % (deploy_cfg_path)) def main():