diff --git a/test/local_test_cityscapes.py b/test/local_test_cityscapes.py index 0bf3c6aeb04e551da0cba454c6e9db7575efb1bb..6618695a60aae5f07230c546337b611d7c1cc78a 100644 --- a/test/local_test_cityscapes.py +++ b/test/local_test_cityscapes.py @@ -14,6 +14,7 @@ from test_utils import download_file_and_uncompress, train, eval, vis, export_model import os +import argparse LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) DATASET_PATH = os.path.join(LOCAL_PATH, "..", "dataset") @@ -43,7 +44,16 @@ if __name__ == "__main__": vis_dir = os.path.join(LOCAL_PATH, "visual", model_name) saved_model = os.path.join(LOCAL_PATH, "saved_model", model_name) - devices = ['0'] + parser = argparse.ArgumentParser(description="PaddleSeg loacl test") + parser.add_argument("--devices", + dest="devices", + help="GPU id of running. if more than one, use spacing to separate.", + nargs="+", + default=0, + type=int) + args = parser.parse_args() + + devices = [str(x) for x in args.devices] export_model( flags=["--cfg", cfg],