diff --git a/pdseg/train.py b/pdseg/train.py index 8c017f9d355cba5abdd470b5dfa3063b8bcb4d77..59349b33fadb980b2b1eab4b196b263ed1f2aba5 100644 --- a/pdseg/train.py +++ b/pdseg/train.py @@ -462,7 +462,7 @@ def main(args): cfg.TRAINER_ID = int(os.getenv("PADDLE_TRAINER_ID", 0)) cfg.NUM_TRAINERS = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) - cfg.check_and_infer(reset_dataset=True) + cfg.check_and_infer() print_info(pprint.pformat(cfg)) train(cfg)