提交 1fbb0875 作者: ruri 提交者: GitHub

fix multi-cards multi-process bug (#4251)

* fix multi-card multi-process bug
上级 6345e607
......
@@ -102,13 +102,16 @@ def validate(args,
 test_batch_time_record = []
 test_batch_metrics_record = []
 test_batch_id = 0
-compiled_program = best_strategy_compiled(
-args,
-test_prog,
-test_fetch_list[0],
-exe,
-mode="val",
-share_prog=train_prog)
+if int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) > 1:
+compiled_program = test_prog
+else:
+compiled_program = best_strategy_compiled(
+args,
+test_prog,
+test_fetch_list[0],
+exe,
+mode="val",
+share_prog=train_prog)
 for batch in test_iter:
 t1 = time.time()
 test_batch_metrics = exe.run(program=compiled_program,
......
......
@@ -85,8 +85,8 @@ def prepare_for_multi_process(exe, build_strategy, train_prog):
 trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0))
 num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
 if num_trainers < 2: return
-logger.info("PADDLE_TRAINERS_NUM", num_trainers)
-logger.info("PADDLE_TRAINER_ID", trainer_id)
+logger.info("PADDLE_TRAINERS_NUM %s" % num_trainers)
+logger.info("PADDLE_TRAINER_ID %s" % trainer_id)
 build_strategy.num_trainers = num_trainers
 build_strategy.trainer_id = trainer_id
 # NOTE(zcd): use multi processes to train the model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册