未验证 提交 b67db1d9 编写于 作者: Z zhang wenhui 提交者: GitHub

Merge pull request #1486 from frankwhzhang/fix_bug

fix parallel bugs
......@@ -69,10 +69,8 @@ def train():
# Initialize executor
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
training_role = os.getenv("TRAINING_ROLE", "TRAINER")
if training_role == "PSERVER":
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if parallel:
train_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
......@@ -84,7 +82,6 @@ def train():
model_dir = args.model_dir
fetch_list = [avg_cost.name]
exe.run(fluid.default_startup_program())
total_time = 0.0
for pass_idx in six.moves.xrange(pass_num):
epoch_idx = pass_idx + 1
......
......@@ -73,6 +73,7 @@ def train():
# Initialize executor
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if parallel:
train_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
......@@ -83,7 +84,6 @@ def train():
pass_num = args.pass_num
model_dir = args.model_dir
fetch_list = [avg_cost.name]
exe.run(fluid.default_startup_program())
total_time = 0.0
for pass_idx in range(pass_num):
epoch_idx = pass_idx + 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册