diff --git a/python/paddle/fluid/incubate/fleet/collective/__init__.py b/python/paddle/fluid/incubate/fleet/collective/__init__.py
index 4f939deac66e88f9c0618e2f05918b138d2c574a..fa5dd3673dc749e32b04c4c2b8f076ae856d0002 100644
--- a/python/paddle/fluid/incubate/fleet/collective/__init__.py
+++ b/python/paddle/fluid/incubate/fleet/collective/__init__.py
@@ -152,7 +152,7 @@ class CollectiveOptimizer(DistributedOptimizer):
 
     def __init__(self, optimizer, strategy=DistributedStrategy()):
         super(CollectiveOptimizer, self).__init__(optimizer, strategy)
-        if strategy.forward_recompute:
+        if strategy is not None and strategy.forward_recompute:
             self.forward_recompute = True
             self.recompute_checkpoints = strategy.recompute_checkpoints
         else:
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_api_input.py b/python/paddle/fluid/tests/unittests/test_fleet_api_input.py
index 49cdeaa6a426e84907cf7bbcdcd05fed5e782522..eb54470623cafb24ea216da65cd8b8a9bae9d57f 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_api_input.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_api_input.py
@@ -22,6 +22,7 @@ from paddle.fluid.incubate.fleet.base.role_maker import UserDefinedCollectiveRol
 from paddle.fluid.incubate.fleet.base.role_maker import Role
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import TranspilerOptimizer
+from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer
 
 
 class DistributeTranspilerConfigTest(unittest.TestCase):
@@ -204,5 +205,11 @@ class UserDefinedCollectiveRoleMakerTest(unittest.TestCase):
         )  # current_id must be less than len(worker_endpoints)
 
 
+class CollectiveOptimizerTest(unittest.TestCase):
+    def test_ds_as_None(self):
+        optimizer = fluid.optimizer.AdamOptimizer()
+        dist_optimizer = CollectiveOptimizer(optimizer, strategy=None)
+
+
 if __name__ == '__main__':
     unittest.main()