未验证 提交 cb568212 编写于 作者: S shangliang Xu 提交者: GitHub

add find_unused_parameters in fleet init (#3570)

上级 af424a9a
......@@ -26,8 +26,10 @@ from paddle.distributed import fleet
__all__ = ['init_parallel_env', 'set_random_seed', 'init_fleet_env']
def init_fleet_env():
fleet.init(is_collective=True)
def init_fleet_env(find_unused_parameters=False):
strategy = fleet.DistributedStrategy()
strategy.find_unused_parameters = find_unused_parameters
fleet.init(is_collective=True, strategy=strategy)
def init_parallel_env():
......
......@@ -90,7 +90,7 @@ def parse_args():
def run(FLAGS, cfg):
# init fleet environment
if cfg.fleet:
init_fleet_env()
init_fleet_env(cfg.get('find_unused_parameters', False))
else:
# init parallel environment if nranks > 1
init_parallel_env()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册