提交 8e5db1cb 编写于 作者: L liuyuhui

fix bug

上级 369d2592
...@@ -212,11 +212,9 @@ class RunnerBase(object): ...@@ -212,11 +212,9 @@ class RunnerBase(object):
if context["fleet_mode"].upper() == "PS": if context["fleet_mode"].upper() == "PS":
train_prog = context["model"][model_dict["name"]][ train_prog = context["model"][model_dict["name"]][
"main_program"] "main_program"]
print("condition 1")
else: else:
train_prog = context["model"][model_dict["name"]][ train_prog = context["model"][model_dict["name"]][
"default_main_program"] "default_main_program"]
print("condition 2")
startup_prog = context["model"][model_dict["name"]][ startup_prog = context["model"][model_dict["name"]][
"startup_program"] "startup_program"]
with fluid.program_guard(train_prog, startup_prog): with fluid.program_guard(train_prog, startup_prog):
......
...@@ -114,15 +114,13 @@ runner: ...@@ -114,15 +114,13 @@ runner:
print_interval: 1 print_interval: 1
phases: [phase1] phases: [phase1]
- name: local_ps_train - name: single_multi_gpu_train
class: local_cluster_train class: train
# num of epochs # num of epochs
epochs: 1 epochs: 1
# device to run training or infer # device to run training or infer
device: cpu device: gpu
selected_gpus: "0" # 选择多卡执行训练 selected_gpus: "0,1" # 选择多卡执行训练
work_num: 1
server_num: 1
save_checkpoint_interval: 1 # save model interval of epochs save_checkpoint_interval: 1 # save model interval of epochs
save_inference_interval: 4 # save inference save_inference_interval: 4 # save inference
save_step_interval: 1 save_step_interval: 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册