未验证 提交 1a4c2938 编写于 作者: Y Yibing Liu 提交者: GitHub

Update train.py

上级 8a08737a
......@@ -322,9 +322,10 @@ def train(args):
exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope
build_strategy = fluid.BuildStrategy()
if sys.platform == "win32" and nccl2_num_trainers > 1:
if not sys.platform == "win32":
build_strategy.num_trainers = nccl2_num_trainers
elif nccl2_num_trainers > 1:
raise ValueError("Windows platform doesn't support distributed training!")
build_strategy.num_trainers = nccl2_num_trainers
build_strategy.trainer_id = nccl2_trainer_id
# use_ngraph is for CPU only, please refer to README_ngraph.md for details
use_ngraph = os.getenv('FLAGS_use_ngraph')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册