未验证 提交 f55d1c68 编写于 作者: M mapingshuo 提交者: GitHub

Fleet: deal with special case: strategy is None (#20359)

* special case: strategy is None
上级 1d82025e
...@@ -152,7 +152,7 @@ class CollectiveOptimizer(DistributedOptimizer): ...@@ -152,7 +152,7 @@ class CollectiveOptimizer(DistributedOptimizer):
def __init__(self, optimizer, strategy=DistributedStrategy()): def __init__(self, optimizer, strategy=DistributedStrategy()):
super(CollectiveOptimizer, self).__init__(optimizer, strategy) super(CollectiveOptimizer, self).__init__(optimizer, strategy)
if strategy.forward_recompute: if strategy is not None and strategy.forward_recompute:
self.forward_recompute = True self.forward_recompute = True
self.recompute_checkpoints = strategy.recompute_checkpoints self.recompute_checkpoints = strategy.recompute_checkpoints
else: else:
......
...@@ -22,6 +22,7 @@ from paddle.fluid.incubate.fleet.base.role_maker import UserDefinedCollectiveRol ...@@ -22,6 +22,7 @@ from paddle.fluid.incubate.fleet.base.role_maker import UserDefinedCollectiveRol
from paddle.fluid.incubate.fleet.base.role_maker import Role from paddle.fluid.incubate.fleet.base.role_maker import Role
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import TranspilerOptimizer from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import TranspilerOptimizer
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer
class DistributeTranspilerConfigTest(unittest.TestCase): class DistributeTranspilerConfigTest(unittest.TestCase):
...@@ -204,5 +205,11 @@ class UserDefinedCollectiveRoleMakerTest(unittest.TestCase): ...@@ -204,5 +205,11 @@ class UserDefinedCollectiveRoleMakerTest(unittest.TestCase):
) # current_id must be less than len(worker_endpoints) ) # current_id must be less than len(worker_endpoints)
class CollectiveOptimizerTest(unittest.TestCase):
def test_ds_as_None(self):
optimizer = fluid.optimizer.AdamOptimizer()
dist_optimizer = CollectiveOptimizer(optimizer, strategy=None)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册