未验证 提交 cb568212 编写于 作者: S shangliang Xu 提交者: GitHub

add find_unused_parameters in fleet init (#3570)

上级 af424a9a
...@@ -26,8 +26,10 @@ from paddle.distributed import fleet ...@@ -26,8 +26,10 @@ from paddle.distributed import fleet
__all__ = ['init_parallel_env', 'set_random_seed', 'init_fleet_env'] __all__ = ['init_parallel_env', 'set_random_seed', 'init_fleet_env']
def init_fleet_env(): def init_fleet_env(find_unused_parameters=False):
fleet.init(is_collective=True) strategy = fleet.DistributedStrategy()
strategy.find_unused_parameters = find_unused_parameters
fleet.init(is_collective=True, strategy=strategy)
def init_parallel_env(): def init_parallel_env():
......
...@@ -90,7 +90,7 @@ def parse_args(): ...@@ -90,7 +90,7 @@ def parse_args():
def run(FLAGS, cfg): def run(FLAGS, cfg):
# init fleet environment # init fleet environment
if cfg.fleet: if cfg.fleet:
init_fleet_env() init_fleet_env(cfg.get('find_unused_parameters', False))
else: else:
# init parallel environment if nranks > 1 # init parallel environment if nranks > 1
init_parallel_env() init_parallel_env()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册