Unverified commit 4e8f18ab authored by Dong Daxiang, committed by GitHub

Get final strategy (#27602)

* add get final strategy for user to print final strategy
Parent d01f6269
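
For orientation, a minimal usage sketch of the `_final_strategy` entry point this commit exposes. It assumes a static-graph program and a placeholder `build_net()` helper that returns a loss; only `fleet._final_strategy()` itself comes from this change, and it is meaningful only after `minimize()` has run (before that it prints a warning and returns an empty dict):

    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    fleet.init(is_collective=True)

    avg_cost = build_net()  # placeholder: any static-graph model returning a loss

    strategy = fleet.DistributedStrategy()
    strategy.amp = True

    optimizer = paddle.optimizer.SGD(learning_rate=0.01)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer.minimize(avg_cost)

    # the strategy that was actually applied, after meta-optimizer selection
    print(fleet._final_strategy())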
@@ -30,6 +30,7 @@ __all__ = [
 ]
 fleet = Fleet()
+_final_strategy = fleet._final_strategy
 init = fleet.init
 is_first_worker = fleet.is_first_worker
 worker_index = fleet.worker_index
......
@@ -1244,8 +1244,7 @@ class DistributedStrategy(object):
             if getattr(self.strategy, f.name):
                 draws += border + "\n"
                 draws += h1_format.format(
-                    "{} = True, please check {}_configs".format(
-                        f.name, f.name))
+                    "{}=True <-> {}_configs".format(f.name, f.name))
                 draws += line + "\n"
                 my_configs = getattr(self.strategy,
                                      f.name + "_configs")
......
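
A hedged sketch of how the reworked summary header reads; the surrounding table borders come from helpers outside this hunk, so only the "{}=True <-> {}_configs" header format is taken from the change itself:

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.amp = True
    # printing the strategy now shows, for each enabled option, a header like
    #   amp=True <-> amp_configs
    # followed by the corresponding *_configs table
    print(strategy)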
@@ -119,6 +119,8 @@ class Fleet(object):
         self.strategy_compiler = None
         self._is_collective = False
         self._runtime_handle = None
+        self._util = None
+        self._context = {}

     def init(self, role_maker=None, is_collective=False):
         """
@@ -233,7 +235,7 @@ class Fleet(object):
         Returns:
             int: worker numbers
         Examples:
             .. code-block:: python
@@ -569,8 +571,9 @@ class Fleet(object):
         if strategy == None:
             strategy = DistributedStrategy()
-        self.user_defined_strategy = strategy
-        self.valid_strategy = None
+        self._user_defined_strategy = copy.deepcopy(strategy)
+        self._context = {}
         return self

     @dygraph_only
@@ -909,6 +912,15 @@ class Fleet(object):
         # imitate target optimizer retrieval
         return self.user_defined_optimizer.clear_grad()

+    def _final_strategy(self):
+        if "valid_strategy" not in self._context:
+            print(
+                "WARNING: You may need to call minimize function before this function is called"
+            )
+            return {}
+        else:
+            return self._context["valid_strategy"]
+
     def minimize(self,
                  loss,
                  startup_program=None,
@@ -958,12 +970,15 @@ class Fleet(object):
             # for more examples, please reference https://github.com/PaddlePaddle/FleetX
         """
+        context = {}
+        context["user_defined_strategy"] = copy.deepcopy(
+            self._user_defined_strategy)
         if paddle.fluid.framework.in_dygraph_mode():
             # imitate target optimizer retrieval
             target_opt = self.user_defined_optimizer
+            self._context = context
             return target_opt.minimize(loss)
-        context = {}
         # cache original feed forward program
         self.origin_main_program = loss.block.program
         context["origin_main_program"] = self.origin_main_program
@@ -984,17 +999,19 @@ class Fleet(object):
             MetaOptimizerFactory()._get_valid_meta_optimizers(
                 self.user_defined_optimizer)
-        context["user_defined_strategy"] = copy.copy(self.user_defined_strategy)
+        context["user_defined_strategy"] = copy.deepcopy(
+            self._user_defined_strategy)
+        copy_user_defined_strategy = copy.deepcopy(self._user_defined_strategy)

         # trigger the auto-parallel in very strict condition
         # strategy = DistributedStrategy()
         # strategy.auto = True
         # optimizer = paddle.optimizer.SGD(learning_rate=0.1)
         # optimizer = fleet.distributed_optimizer(optimizer, strategy)
-        if self.user_defined_strategy._is_strict_auto():
+        if copy_user_defined_strategy._is_strict_auto():
             # turn on all the strategy for each optimizer
             for opt in distributed_optimizer_list:
-                opt._enable_strategy(self.user_defined_strategy, context)
+                opt._enable_strategy(copy_user_defined_strategy, context)

         valid_optimizer_list = []
         valid_graph_optimizer_list = []
@@ -1003,7 +1020,7 @@ class Fleet(object):
         for opt in distributed_optimizer_list:
             opt._set_basic_info(loss, self._role_maker,
                                 self.user_defined_optimizer,
-                                self.user_defined_strategy)
+                                copy_user_defined_strategy)
             if opt._can_apply() and not opt._is_graph_out():
                 valid_optimizer_list.append(opt)
             elif opt._can_apply() and opt._is_graph_out():
@@ -1014,13 +1031,15 @@ class Fleet(object):
         meta_optimizer, graph_optimizer = \
             self.strategy_compiler.generate_optimizer(
                 loss, self._role_maker, self.user_defined_optimizer,
-                self.user_defined_strategy, valid_optimizer_list,
+                copy_user_defined_strategy, valid_optimizer_list,
                 valid_graph_optimizer_list)

         valid_strategy = self.strategy_compiler._get_valid_strategy(
-            self.user_defined_strategy, can_not_apply_optimizer_list)
+            copy_user_defined_strategy, can_not_apply_optimizer_list)
-        context["valid_strategy"] = valid_strategy
+        context["valid_strategy"] = copy.deepcopy(valid_strategy)
+
+        self._context = context

         self.valid_strategy = valid_strategy
         self.valid_strategy._enable_env()
......
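
Because `minimize` now hands the meta-optimizers a deep copy (`copy_user_defined_strategy`) and stashes the resulting valid strategy in `self._context`, the strategy returned by `_final_strategy()` can differ from the one the user passed in. A sketch of how that shows up with strict auto mode, continuing the placeholder setup from the earlier example:

    strategy = fleet.DistributedStrategy()
    strategy.auto = True  # let the meta-optimizers decide what to enable

    optimizer = paddle.optimizer.SGD(learning_rate=0.1)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer.minimize(avg_cost)

    print(strategy)                 # still only auto=True, left untouched
    print(fleet._final_strategy())  # reflects whatever was actually enabled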
@@ -60,8 +60,8 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
         optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
         optimizer.minimize(avg_cost)
-        self.assertTrue(optimizer.user_defined_strategy.a_sync)
-        a_sync_configs = optimizer.user_defined_strategy.a_sync_configs
+        self.assertTrue(fleet._final_strategy().a_sync)
+        a_sync_configs = fleet._final_strategy().a_sync_configs
         self.assertTrue(a_sync_configs['k_steps'] == 0)
......
@@ -72,8 +72,8 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
         optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
         optimizer.minimize(avg_cost)
-        self.assertTrue(optimizer.user_defined_strategy.a_sync)
-        a_sync_configs = optimizer.user_defined_strategy.a_sync_configs
+        self.assertTrue(fleet._final_strategy().a_sync)
+        a_sync_configs = fleet._final_strategy().a_sync_configs
         self.assertTrue(a_sync_configs['k_steps'] == 0)
......
@@ -60,8 +60,8 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
         optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
         optimizer.minimize(avg_cost)
-        self.assertTrue(optimizer.user_defined_strategy.a_sync)
-        a_sync_configs = optimizer.user_defined_strategy.a_sync_configs
+        self.assertTrue(fleet._final_strategy().a_sync)
+        a_sync_configs = fleet._final_strategy().a_sync_configs
         self.assertTrue(a_sync_configs['k_steps'] == 800)
......
@@ -18,6 +18,8 @@ import unittest
 import paddle
 import os
+paddle.enable_static()
+
 class TestFleetAMPOptimizer(unittest.TestCase):
     def setUp(self):
@@ -55,6 +57,8 @@ class TestFleetAMPOptimizer(unittest.TestCase):
         optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
         optimizer.minimize(avg_cost)
+        strategy = fleet._final_strategy()
+
         ops = [op.type for op in avg_cost.block.ops]
         self.assertIn('cast', ops)
         self.assertIn('check_finite_and_unscale', ops)
......
@@ -18,6 +18,8 @@ import os
 import paddle.distributed.fleet as fleet
 import paddle.distributed.fleet.base.role_maker as role_maker
+paddle.enable_static()
+
 class TestDistributedStrategyAuto(unittest.TestCase):
     def setUp(self):
......
@@ -167,6 +167,8 @@ class TestFleetDygraph(unittest.TestCase):
         state_dict = adam.state_dict()
         adam.set_state_dict(state_dict)
+        final_strategy = fleet._final_strategy()
+
 class TestFleetBaseSingleRunCollective(unittest.TestCase):
     def setUp(self):
......
@@ -19,6 +19,8 @@ import os
 import paddle.distributed.fleet as fleet
 import paddle.distributed.fleet.base.role_maker as role_maker
+paddle.enable_static()
+
 class TestFleetLambMetaOptimizer(unittest.TestCase):
     def setUp(self):
......
@@ -19,6 +19,8 @@ import os
 import paddle.distributed.fleet as fleet
 import paddle.distributed.fleet.base.role_maker as role_maker
+paddle.enable_static()
+
 class TestFleetLarsMetaOptimizer(unittest.TestCase):
     def setUp(self):
......