未验证提交 777be5a0 · 作者: wuzhihua · 提交者: GitHub

Merge pull request #220 from vslyu/fix_save_step

Fix save-step bug (selection of the training program used when saving at a step interval)
......@@ -209,12 +209,10 @@ class RunnerBase(object):
if save_step_interval >= 1 and batch_id % save_step_interval == 0 and context[
"is_infer"] == False:
if context["fleet_mode"]:
if context["fleet_mode"].upper() == "PS":
train_prog = context["model"][model_dict[
"name"]]["main_program"]
elif not context["is_fleet"] or context[
"fleet_mode"].upper() == "COLLECTIVE":
if context["fleet_mode"].upper() == "PS":
train_prog = context["model"][model_dict["name"]][
"main_program"]
else:
train_prog = context["model"][model_dict["name"]][
"default_main_program"]
startup_prog = context["model"][model_dict["name"]][
......
......@@ -114,15 +114,13 @@ runner:
print_interval: 1
phases: [phase1]
- name: local_ps_train
class: local_cluster_train
- name: single_multi_gpu_train
class: train
# num of epochs
epochs: 1
# device to run training or infer
device: cpu
selected_gpus: "0" # 选择多卡执行训练
work_num: 1
server_num: 1
device: gpu
selected_gpus: "0,1" # 选择多卡执行训练
save_checkpoint_interval: 1 # save model interval of epochs
save_inference_interval: 4 # save inference
save_step_interval: 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册