提交 97cc5f45 编写于 作者: M malin10

bug fix

上级 7bef765c
...@@ -253,6 +253,9 @@ class SingleTrainer(TranspileTrainer): ...@@ -253,6 +253,9 @@ class SingleTrainer(TranspileTrainer):
_build_strategy = fluid.BuildStrategy() _build_strategy = fluid.BuildStrategy()
_exe_strategy = fluid.ExecutionStrategy() _exe_strategy = fluid.ExecutionStrategy()
# 0: kCoeffNumDevice; 1: One; 2: Customized
_build_strategy.gradient_scale_strategy = model_dict.get(
"gradient_scale_strategy", 0)
if "thread_num" in model_dict and model_dict["thread_num"] > 1: if "thread_num" in model_dict and model_dict["thread_num"] > 1:
_build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce _build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
_exe_strategy.num_threads = model_dict["thread_num"] _exe_strategy.num_threads = model_dict["thread_num"]
......
...@@ -77,6 +77,7 @@ phase: ...@@ -77,6 +77,7 @@ phase:
model: "{workspace}/model.py" # user-defined model model: "{workspace}/model.py" # user-defined model
dataset_name: dataset_train # select dataset by name dataset_name: dataset_train # select dataset by name
thread_num: 1 thread_num: 1
gradient_scale_strategy: 1
#- name: phase2 #- name: phase2
# model: "{workspace}/model.py" # user-defined model # model: "{workspace}/model.py" # user-defined model
# dataset_name: dataset_infer # select dataset by name # dataset_name: dataset_infer # select dataset by name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册