diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py index 21890857a007925c7e759a6165934b9ad838bcae..d6461827792cb875fa99da400428eab21bc9f22a 100755 --- a/core/trainers/single_trainer.py +++ b/core/trainers/single_trainer.py @@ -248,8 +248,19 @@ class SingleTrainer(TranspileTrainer): _exe_strategy = fluid.ExecutionStrategy() # 0: kCoeffNumDevice; 1: One; 2: Customized - _build_strategy.gradient_scale_strategy = model_dict.get( - "gradient_scale_strategy", 0) + _gradient_scale_strategy = model_dict.get("gradient_scale_strategy", 0) + if _gradient_scale_strategy == 0: + gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.CoeffNumDevice + elif _gradient_scale_strategy == 1: + gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.One + elif _gradient_scale_strategy == 2: + gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized + else: + raise ValueError( + "Unsurpported config. gradient_scale_strategy must be one of [0, 1, 2]." + ) + _build_strategy.gradient_scale_strategy = gradient_scale_strategy + if "thread_num" in model_dict and model_dict["thread_num"] > 1: _build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce _exe_strategy.num_threads = model_dict["thread_num"]