diff --git a/PaddleCV/image_classification/train.py b/PaddleCV/image_classification/train.py index 144155b907166a41bfd32707329eb39d44349598..d51420a805776a2738036b4cdf4393b679539c98 100755 --- a/PaddleCV/image_classification/train.py +++ b/PaddleCV/image_classification/train.py @@ -189,12 +189,12 @@ def train(args): #Create test_prog and set layers' is_test params to True test_prog = test_prog.clone(for_test=True) - if args.use_pure_fp16: - cast_parameters_to_fp16(startup_prog) gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0)) place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) exe.run(startup_prog) + if args.use_pure_fp16: + cast_parameters_to_fp16(exe, train_prog, fluid.global_scope()) trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))