From a82e615582d9903374f5c812d620919f3bec2279 Mon Sep 17 00:00:00 2001 From: malin10 Date: Wed, 3 Jun 2020 13:24:55 +0800 Subject: [PATCH] bug fix --- core/trainers/single_trainer.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py index 21890857..abff7492 100755 --- a/core/trainers/single_trainer.py +++ b/core/trainers/single_trainer.py @@ -248,8 +248,15 @@ class SingleTrainer(TranspileTrainer): _exe_strategy = fluid.ExecutionStrategy() # 0: kCoeffNumDevice; 1: One; 2: Customized - _build_strategy.gradient_scale_strategy = model_dict.get( - "gradient_scale_strategy", 0) + _gradient_scale_strategy = model_dict.get("gradient_scale_strategy", 0) + if _gradient_scale_strategy == 0: + gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.CoeffNumDevice + elif _gradient_scale_strategy == 1: + gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.One + elif _gradient_scale_strategy == 2: + gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized + _build_strategy.gradient_scale_strategy = gradient_scale_strategy + if "thread_num" in model_dict and model_dict["thread_num"] > 1: _build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce _exe_strategy.num_threads = model_dict["thread_num"] -- GitLab