diff --git a/test/local_test_pet.py b/test/local_test_pet.py index 1596f4b2b7be92c32dd8cad1d8be1d79e794d142..7d0cf58cd1235575fc960769d5142865993b5763 100644 --- a/test/local_test_pet.py +++ b/test/local_test_pet.py @@ -14,6 +14,7 @@ from test_utils import download_file_and_uncompress, train, eval, vis, export_model import os +import argparse LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) DATASET_PATH = os.path.join(LOCAL_PATH, "..", "dataset") @@ -44,7 +45,16 @@ if __name__ == "__main__": vis_dir = os.path.join(LOCAL_PATH, "visual", model_name) saved_model = os.path.join(LOCAL_PATH, "saved_model", model_name) - devices = ['0'] + parser = argparse.ArgumentParser(description="PaddleSeg loacl test") + parser.add_argument("--devices", + dest="devices", + help="GPU id of running. if more than one, use spacing to separate.", + nargs="+", + default=0, + type=int) + args = parser.parse_args() + + devices = [str(x) for x in args.devices] train( flags=["--cfg", cfg, "--use_gpu", "--log_steps", "10"],