未验证 提交 74625fc4 编写于 作者: W wuzhihua 提交者: GitHub

Merge pull request #33 from 123malin/bug_fix

bug fix for gradient_scale_strategy
...@@ -248,8 +248,19 @@ class SingleTrainer(TranspileTrainer): ...@@ -248,8 +248,19 @@ class SingleTrainer(TranspileTrainer):
_exe_strategy = fluid.ExecutionStrategy() _exe_strategy = fluid.ExecutionStrategy()
# 0: kCoeffNumDevice; 1: One; 2: Customized # 0: kCoeffNumDevice; 1: One; 2: Customized
_build_strategy.gradient_scale_strategy = model_dict.get( _gradient_scale_strategy = model_dict.get("gradient_scale_strategy", 0)
"gradient_scale_strategy", 0) if _gradient_scale_strategy == 0:
gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.CoeffNumDevice
elif _gradient_scale_strategy == 1:
gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.One
elif _gradient_scale_strategy == 2:
gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
else:
raise ValueError(
"Unsurpported config. gradient_scale_strategy must be one of [0, 1, 2]."
)
_build_strategy.gradient_scale_strategy = gradient_scale_strategy
if "thread_num" in model_dict and model_dict["thread_num"] > 1: if "thread_num" in model_dict and model_dict["thread_num"] > 1:
_build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce _build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
_exe_strategy.num_threads = model_dict["thread_num"] _exe_strategy.num_threads = model_dict["thread_num"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册