diff --git a/ppdet/engine/env.py b/ppdet/engine/env.py index 8d07e4b4acc42f2611f466c797eb183258bce404..0a896571db8bee03f3fdb172443af88622a912bd 100644 --- a/ppdet/engine/env.py +++ b/ppdet/engine/env.py @@ -26,8 +26,10 @@ from paddle.distributed import fleet __all__ = ['init_parallel_env', 'set_random_seed', 'init_fleet_env'] -def init_fleet_env(): - fleet.init(is_collective=True) +def init_fleet_env(find_unused_parameters=False): + strategy = fleet.DistributedStrategy() + strategy.find_unused_parameters = find_unused_parameters + fleet.init(is_collective=True, strategy=strategy) def init_parallel_env(): diff --git a/tools/train.py b/tools/train.py index d9ef6d6f24d89d6322b6ecd01d358049386af971..211322fd5389b817c52066b16168734c22608ba2 100755 --- a/tools/train.py +++ b/tools/train.py @@ -90,7 +90,7 @@ def parse_args(): def run(FLAGS, cfg): # init fleet environment if cfg.fleet: - init_fleet_env() + init_fleet_env(cfg.get('find_unused_parameters', False)) else: # init parallel environment if nranks > 1 init_parallel_env()