提交 1fbb0875 编写于 作者: R ruri 提交者: GitHub

fix multi-cards multi-process bug (#4251)

* fix multi-card multi-process bug
上级 6345e607
...@@ -102,6 +102,9 @@ def validate(args, ...@@ -102,6 +102,9 @@ def validate(args,
test_batch_time_record = [] test_batch_time_record = []
test_batch_metrics_record = [] test_batch_metrics_record = []
test_batch_id = 0 test_batch_id = 0
if int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) > 1:
compiled_program = test_prog
else:
compiled_program = best_strategy_compiled( compiled_program = best_strategy_compiled(
args, args,
test_prog, test_prog,
......
...@@ -85,8 +85,8 @@ def prepare_for_multi_process(exe, build_strategy, train_prog): ...@@ -85,8 +85,8 @@ def prepare_for_multi_process(exe, build_strategy, train_prog):
trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0)) trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0))
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if num_trainers < 2: return if num_trainers < 2: return
logger.info("PADDLE_TRAINERS_NUM", num_trainers) logger.info("PADDLE_TRAINERS_NUM %s" % num_trainers)
logger.info("PADDLE_TRAINER_ID", trainer_id) logger.info("PADDLE_TRAINER_ID %s" % trainer_id)
build_strategy.num_trainers = num_trainers build_strategy.num_trainers = num_trainers
build_strategy.trainer_id = trainer_id build_strategy.trainer_id = trainer_id
# NOTE(zcd): use multi processes to train the model, # NOTE(zcd): use multi processes to train the model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册