提交 1fbb0875 编写于 作者: R ruri 提交者: GitHub

fix multi-cards multi-process bug (#4251)

* fix multi-card multi-process bug
上级 6345e607
...@@ -102,13 +102,16 @@ def validate(args, ...@@ -102,13 +102,16 @@ def validate(args,
test_batch_time_record = [] test_batch_time_record = []
test_batch_metrics_record = [] test_batch_metrics_record = []
test_batch_id = 0 test_batch_id = 0
compiled_program = best_strategy_compiled( if int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) > 1:
args, compiled_program = test_prog
test_prog, else:
test_fetch_list[0], compiled_program = best_strategy_compiled(
exe, args,
mode="val", test_prog,
share_prog=train_prog) test_fetch_list[0],
exe,
mode="val",
share_prog=train_prog)
for batch in test_iter: for batch in test_iter:
t1 = time.time() t1 = time.time()
test_batch_metrics = exe.run(program=compiled_program, test_batch_metrics = exe.run(program=compiled_program,
......
...@@ -85,8 +85,8 @@ def prepare_for_multi_process(exe, build_strategy, train_prog): ...@@ -85,8 +85,8 @@ def prepare_for_multi_process(exe, build_strategy, train_prog):
trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0)) trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0))
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if num_trainers < 2: return if num_trainers < 2: return
logger.info("PADDLE_TRAINERS_NUM", num_trainers) logger.info("PADDLE_TRAINERS_NUM %s" % num_trainers)
logger.info("PADDLE_TRAINER_ID", trainer_id) logger.info("PADDLE_TRAINER_ID %s" % trainer_id)
build_strategy.num_trainers = num_trainers build_strategy.num_trainers = num_trainers
build_strategy.trainer_id = trainer_id build_strategy.trainer_id = trainer_id
# NOTE(zcd): use multi processes to train the model, # NOTE(zcd): use multi processes to train the model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册