diff --git a/python/paddle/fluid/incubate/fleet/collective/__init__.py b/python/paddle/fluid/incubate/fleet/collective/__init__.py index 26b8e2c3b12a0a2e888377ddad24bbeae7c6f3e1..9c7963a1ec2dcb3256aa90d1d10bf663045a7354 100644 --- a/python/paddle/fluid/incubate/fleet/collective/__init__.py +++ b/python/paddle/fluid/incubate/fleet/collective/__init__.py @@ -151,8 +151,10 @@ class CollectiveOptimizer(DistributedOptimizer): """ def __init__(self, optimizer, strategy=DistributedStrategy()): + if strategy is None: + strategy = DistributedStrategy() super(CollectiveOptimizer, self).__init__(optimizer, strategy) - if strategy is not None and strategy.forward_recompute: + if strategy.forward_recompute: self.forward_recompute = True self.recompute_checkpoints = strategy.recompute_checkpoints else: