提交 c4cee259 编写于 作者: M malin10

bug fix

上级 a82e6155
...@@ -248,10 +248,9 @@ class SingleTrainer(TranspileTrainer): ...@@ -248,10 +248,9 @@ class SingleTrainer(TranspileTrainer):
_exe_strategy = fluid.ExecutionStrategy() _exe_strategy = fluid.ExecutionStrategy()
# 0: kCoeffNumDevice; 1: One; 2: Customized # 0: kCoeffNumDevice; 1: One; 2: Customized
gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.CoeffNumDevice
_gradient_scale_strategy = model_dict.get("gradient_scale_strategy", 0) _gradient_scale_strategy = model_dict.get("gradient_scale_strategy", 0)
if _gradient_scale_strategy == 0: if _gradient_scale_strategy == 1:
gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.CoeffNumDevice
elif _gradient_scale_strategy == 1:
gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.One gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.One
elif _gradient_scale_strategy == 2: elif _gradient_scale_strategy == 2:
gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册