From c0b04f0b9510b1de2f07c11892a60cf4a53e5963 Mon Sep 17 00:00:00 2001 From: mapingshuo Date: Sat, 12 Oct 2019 12:55:29 +0800 Subject: [PATCH] Fleet: special case: strategy is None (#20427) * special case: strategy is None --- python/paddle/fluid/incubate/fleet/collective/__init__.py | 2 +- .../paddle/fluid/tests/unittests/test_fleet_api_input.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/incubate/fleet/collective/__init__.py b/python/paddle/fluid/incubate/fleet/collective/__init__.py index 4f939deac66..fa5dd3673dc 100644 --- a/python/paddle/fluid/incubate/fleet/collective/__init__.py +++ b/python/paddle/fluid/incubate/fleet/collective/__init__.py @@ -152,7 +152,7 @@ class CollectiveOptimizer(DistributedOptimizer): def __init__(self, optimizer, strategy=DistributedStrategy()): super(CollectiveOptimizer, self).__init__(optimizer, strategy) - if strategy.forward_recompute: + if strategy is not None and strategy.forward_recompute: self.forward_recompute = True self.recompute_checkpoints = strategy.recompute_checkpoints else: diff --git a/python/paddle/fluid/tests/unittests/test_fleet_api_input.py b/python/paddle/fluid/tests/unittests/test_fleet_api_input.py index 49cdeaa6a42..eb54470623c 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_api_input.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_api_input.py @@ -22,6 +22,7 @@ from paddle.fluid.incubate.fleet.base.role_maker import UserDefinedCollectiveRol from paddle.fluid.incubate.fleet.base.role_maker import Role from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import TranspilerOptimizer +from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer class DistributeTranspilerConfigTest(unittest.TestCase): @@ -204,5 +205,11 @@ class UserDefinedCollectiveRoleMakerTest(unittest.TestCase): ) # current_id must be less than len(worker_endpoints) +class CollectiveOptimizerTest(unittest.TestCase): + def test_ds_as_None(self): + optimizer = fluid.optimizer.AdamOptimizer() + dist_optimizer = CollectiveOptimizer(optimizer, strategy=None) + + if __name__ == '__main__': unittest.main() -- GitLab