Unverified commit 4e8f18ab, authored by Dong Daxiang, committed by GitHub

Get final strategy (#27602)

* add a get-final-strategy interface so users can print the final strategy
Parent d01f6269
......@@ -30,6 +30,7 @@ __all__ = [
]
fleet = Fleet()
_final_strategy = fleet._final_strategy
init = fleet.init
is_first_worker = fleet.is_first_worker
worker_index = fleet.worker_index
......
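For context, `fleet/__init__.py` re-exports bound methods of a singleton `Fleet()` instance as module-level callables, which is how the new private helper becomes reachable as `paddle.distributed.fleet._final_strategy()`. Below is a minimal, generic sketch of that aliasing pattern; the class and method names are illustrative only, not the Paddle source.

```python
class Service(object):  # illustrative stand-in for the Fleet singleton
    def _final_strategy(self):
        return {"resolved": True}

_service = Service()
# module-level alias to a bound method of the singleton
_final_strategy = _service._final_strategy

print(_final_strategy())  # callable without referencing the singleton directly
```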
......@@ -1244,8 +1244,7 @@ class DistributedStrategy(object):
if getattr(self.strategy, f.name):
draws += border + "\n"
draws += h1_format.format(
"{} = True, please check {}_configs".format(
f.name, f.name))
"{}=True <-> {}_configs".format(f.name, f.name))
draws += line + "\n"
my_configs = getattr(self.strategy,
f.name + "_configs")
......
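The hunk above only changes the header text printed for an enabled sub-strategy when a `DistributedStrategy` is printed. A tiny sketch of the new wording follows; the field name is hypothetical, and the real code additionally centers the text via `h1_format`.

```python
# Hypothetical field name; the real code iterates the strategy's protobuf fields.
name = "a_sync"
header = "{}=True <-> {}_configs".format(name, name)
print(header)  # a_sync=True <-> a_sync_configs
```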
......@@ -119,6 +119,8 @@ class Fleet(object):
self.strategy_compiler = None
self._is_collective = False
self._runtime_handle = None
self._util = None
self._context = {}
def init(self, role_maker=None, is_collective=False):
"""
......@@ -569,8 +571,9 @@ class Fleet(object):
if strategy == None:
strategy = DistributedStrategy()
self.user_defined_strategy = strategy
self.valid_strategy = None
self._user_defined_strategy = copy.deepcopy(strategy)
self._context = {}
return self
@dygraph_only
......@@ -909,6 +912,15 @@ class Fleet(object):
# imitate target optimizer retrieval
return self.user_defined_optimizer.clear_grad()
def _final_strategy(self):
if "valid_strategy" not in self._context:
print(
"WARNING: You may need to call minimize function before this function is called"
)
return {}
else:
return self._context["valid_strategy"]
def minimize(self,
loss,
startup_program=None,
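Typical use of the new accessor, as exercised by the updated unit tests further below: call it after `minimize()` to inspect the strategy that was actually applied, including any adjustments made by meta optimizers. The sketch below is illustrative, assumes a properly launched distributed environment (e.g. via `fleetrun`), and the network and optimizer choices are placeholders.

```python
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)  # assumes a launched fleet environment

# Placeholder network: a single fc layer with a squared-error loss.
x = paddle.static.data(name="x", shape=[None, 10], dtype="float32")
y = paddle.static.data(name="y", shape=[None, 1], dtype="float32")
pred = paddle.static.nn.fc(x, size=1)
avg_cost = paddle.mean(paddle.nn.functional.square_error_cost(input=pred, label=y))

strategy = fleet.DistributedStrategy()
strategy.amp = True

optimizer = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)

# Before minimize() this only prints a warning and returns {}.
print(fleet._final_strategy())
```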
......@@ -958,12 +970,15 @@ class Fleet(object):
# for more examples, please reference https://github.com/PaddlePaddle/FleetX
"""
context = {}
context["user_defined_strategy"] = copy.deepcopy(
self._user_defined_strategy)
if paddle.fluid.framework.in_dygraph_mode():
# imitate target optimizer retrieval
target_opt = self.user_defined_optimizer
self._context = context
return target_opt.minimize(loss)
context = {}
# cache original feed forward program
self.origin_main_program = loss.block.program
context["origin_main_program"] = self.origin_main_program
......@@ -984,17 +999,19 @@ class Fleet(object):
MetaOptimizerFactory()._get_valid_meta_optimizers(
self.user_defined_optimizer)
context["user_defined_strategy"] = copy.copy(self.user_defined_strategy)
context["user_defined_strategy"] = copy.deepcopy(
self._user_defined_strategy)
copy_user_defined_strategy = copy.deepcopy(self._user_defined_strategy)
# trigger the auto-parallel in very strict condition
# strategy = DistributedStrategy()
# strategy.auto = True
# optimizer = paddle.optimizer.SGD(learning_rate=0.1)
# optimizer = fleet.distributed_optimizer(optimizer, strategy)
if self.user_defined_strategy._is_strict_auto():
if copy_user_defined_strategy._is_strict_auto():
# turn on all the strategy for each optimizer
for opt in distributed_optimizer_list:
opt._enable_strategy(self.user_defined_strategy, context)
opt._enable_strategy(copy_user_defined_strategy, context)
valid_optimizer_list = []
valid_graph_optimizer_list = []
......@@ -1003,7 +1020,7 @@ class Fleet(object):
for opt in distributed_optimizer_list:
opt._set_basic_info(loss, self._role_maker,
self.user_defined_optimizer,
self.user_defined_strategy)
copy_user_defined_strategy)
if opt._can_apply() and not opt._is_graph_out():
valid_optimizer_list.append(opt)
elif opt._can_apply() and opt._is_graph_out():
......@@ -1014,13 +1031,15 @@ class Fleet(object):
meta_optimizer, graph_optimizer = \
self.strategy_compiler.generate_optimizer(
loss, self._role_maker, self.user_defined_optimizer,
self.user_defined_strategy, valid_optimizer_list,
copy_user_defined_strategy, valid_optimizer_list,
valid_graph_optimizer_list)
valid_strategy = self.strategy_compiler._get_valid_strategy(
self.user_defined_strategy, can_not_apply_optimizer_list)
copy_user_defined_strategy, can_not_apply_optimizer_list)
context["valid_strategy"] = copy.deepcopy(valid_strategy)
context["valid_strategy"] = valid_strategy
self._context = context
self.valid_strategy = valid_strategy
self.valid_strategy._enable_env()
......
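To summarize the refactor above: `minimize()` now builds a fresh `context` dict, deep-copies the user-defined strategy into it, stores the resolved `valid_strategy` there, and keeps the whole dict on `self._context`, which is what `_final_strategy()` reads. A framework-agnostic sketch of that caching pattern follows; class and attribute names are illustrative only.

```python
import copy

class TinyFleet(object):  # illustrative stand-in, not the Paddle class
    def __init__(self):
        self._context = {}

    def minimize(self, user_strategy):
        context = {}
        context["user_defined_strategy"] = copy.deepcopy(user_strategy)
        # ... meta-optimizer selection would refine the copy here ...
        context["valid_strategy"] = copy.deepcopy(user_strategy)
        self._context = context  # cache results for later inspection

    def _final_strategy(self):
        if "valid_strategy" not in self._context:
            print("WARNING: call minimize() before requesting the final strategy")
            return {}
        return self._context["valid_strategy"]

f = TinyFleet()
print(f._final_strategy())    # {} plus a warning: minimize() not called yet
f.minimize({"a_sync": True})
print(f._final_strategy())    # {'a_sync': True}
```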
......@@ -60,8 +60,8 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
self.assertTrue(optimizer.user_defined_strategy.a_sync)
a_sync_configs = optimizer.user_defined_strategy.a_sync_configs
self.assertTrue(fleet._final_strategy().a_sync)
a_sync_configs = fleet._final_strategy().a_sync_configs
self.assertTrue(a_sync_configs['k_steps'] == 0)
......
......@@ -72,8 +72,8 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
self.assertTrue(optimizer.user_defined_strategy.a_sync)
a_sync_configs = optimizer.user_defined_strategy.a_sync_configs
self.assertTrue(fleet._final_strategy().a_sync)
a_sync_configs = fleet._final_strategy().a_sync_configs
self.assertTrue(a_sync_configs['k_steps'] == 0)
......
......@@ -60,8 +60,8 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
self.assertTrue(optimizer.user_defined_strategy.a_sync)
a_sync_configs = optimizer.user_defined_strategy.a_sync_configs
self.assertTrue(fleet._final_strategy().a_sync)
a_sync_configs = fleet._final_strategy().a_sync_configs
self.assertTrue(a_sync_configs['k_steps'] == 800)
......
......@@ -18,6 +18,8 @@ import unittest
import paddle
import os
paddle.enable_static()
class TestFleetAMPOptimizer(unittest.TestCase):
def setUp(self):
......@@ -55,6 +57,8 @@ class TestFleetAMPOptimizer(unittest.TestCase):
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
strategy = fleet._final_strategy()
ops = [op.type for op in avg_cost.block.ops]
self.assertIn('cast', ops)
self.assertIn('check_finite_and_unscale', ops)
......
......@@ -18,6 +18,8 @@ import os
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
paddle.enable_static()
class TestDistributedStrategyAuto(unittest.TestCase):
def setUp(self):
......
......@@ -167,6 +167,8 @@ class TestFleetDygraph(unittest.TestCase):
state_dict = adam.state_dict()
adam.set_state_dict(state_dict)
final_strategy = fleet._final_strategy()
class TestFleetBaseSingleRunCollective(unittest.TestCase):
def setUp(self):
......
......@@ -19,6 +19,8 @@ import os
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
paddle.enable_static()
class TestFleetLambMetaOptimizer(unittest.TestCase):
def setUp(self):
......
......@@ -19,6 +19,8 @@ import os
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
paddle.enable_static()
class TestFleetLarsMetaOptimizer(unittest.TestCase):
def setUp(self):
......