未验证 提交 0dce81b8 编写于 作者: T tangwei12 提交者: GitHub

Merge branch 'master' into fix_customer_reader

......@@ -209,12 +209,10 @@ class RunnerBase(object):
if save_step_interval >= 1 and batch_id % save_step_interval == 0 and context[
"is_infer"] == False:
if context["fleet_mode"]:
if context["fleet_mode"].upper() == "PS":
train_prog = context["model"][model_dict[
"name"]]["main_program"]
elif not context["is_fleet"] or context[
"fleet_mode"].upper() == "COLLECTIVE":
train_prog = context["model"][model_dict["name"]][
"main_program"]
else:
train_prog = context["model"][model_dict["name"]][
"default_main_program"]
startup_prog = context["model"][model_dict["name"]][
......
......@@ -114,15 +114,13 @@ runner:
print_interval: 1
phases: [phase1]
- name: local_ps_train
class: local_cluster_train
- name: single_multi_gpu_train
class: train
# num of epochs
epochs: 1
# device to run training or infer
device: cpu
selected_gpus: "0" # 选择多卡执行训练
work_num: 1
server_num: 1
device: gpu
selected_gpus: "0,1" # 选择多卡执行训练
save_checkpoint_interval: 1 # save model interval of epochs
save_inference_interval: 4 # save inference
save_step_interval: 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册