未验证 提交 4adac0e3 编写于 作者: D Dong Daxiang 提交者: GitHub

【paddle.fleet】Add fleet base context (#25954)

* generate context during compile
上级 358bc06c
...@@ -279,8 +279,11 @@ class Fleet(object): ...@@ -279,8 +279,11 @@ class Fleet(object):
# for more examples, please reference https://github.com/PaddlePaddle/Fleet # for more examples, please reference https://github.com/PaddlePaddle/Fleet
""" """
context = {}
# cache original feed forward program # cache original feed forward program
self.origin_main_program = loss.block.program self.origin_main_program = loss.block.program
context["origin_main_program"] = self.origin_main_program
context["loss"] = loss
if startup_program == None: if startup_program == None:
self.origin_startup_program = \ self.origin_startup_program = \
paddle.default_startup_program().clone(for_test=False) paddle.default_startup_program().clone(for_test=False)
...@@ -288,6 +291,8 @@ class Fleet(object): ...@@ -288,6 +291,8 @@ class Fleet(object):
else: else:
self.origin_startup_program = \ self.origin_startup_program = \
startup_program.clone(for_test=False) startup_program.clone(for_test=False)
context["origin_startup_program"] = startup_program
context["role_maker"] = self._role_maker
# compile time # compile time
distributed_optimizer_list = \ distributed_optimizer_list = \
...@@ -317,6 +322,9 @@ class Fleet(object): ...@@ -317,6 +322,9 @@ class Fleet(object):
valid_strategy = self.strategy_compiler._get_valid_strategy( valid_strategy = self.strategy_compiler._get_valid_strategy(
self.user_defined_strategy, can_not_apply_optimizer_list) self.user_defined_strategy, can_not_apply_optimizer_list)
context["valid_strategy"] = valid_strategy
self.valid_strategy = valid_strategy self.valid_strategy = valid_strategy
optimize_ops = [] optimize_ops = []
...@@ -334,6 +342,8 @@ class Fleet(object): ...@@ -334,6 +342,8 @@ class Fleet(object):
parameter_list=parameter_list, parameter_list=parameter_list,
no_grad_set=no_grad_set) no_grad_set=no_grad_set)
context["program_optimize_ops"] = optimize_ops
context["program_params_grads"] = params_grads
if graph_optimizer: if graph_optimizer:
optimize_ops, params_grads = graph_optimizer.minimize( optimize_ops, params_grads = graph_optimizer.minimize(
loss, loss,
...@@ -344,12 +354,13 @@ class Fleet(object): ...@@ -344,12 +354,13 @@ class Fleet(object):
# if a graph optimizer takes effect, mostly # if a graph optimizer takes effect, mostly
# optimizers_ops and params_grads are None # optimizers_ops and params_grads are None
# i.e. users can not modify current computation graph anymore # i.e. users can not modify current computation graph anymore
context["graph_optimize_ops"] = optimize_ops
context["graph_optimize_grads"] = params_grads
if self._runtime_handle is None: if self._runtime_handle is None:
self._runtime_handle = RuntimeFactory()._create_runtime( self._runtime_handle = RuntimeFactory()._create_runtime(context)
valid_strategy, self._role_maker, optimize_ops, params_grads)
if self._util is None: if self._util is None:
self._util = UtilFactory()._create_util( self._util = UtilFactory()._create_util(context)
valid_strategy, self._role_maker, optimize_ops, params_grads)
return optimize_ops, params_grads return optimize_ops, params_grads
...@@ -18,10 +18,8 @@ class RuntimeFactory(object): ...@@ -18,10 +18,8 @@ class RuntimeFactory(object):
def __init__(self): def __init__(self):
pass pass
def _create_runtime(self, final_dist_strategy, role_maker, opt_ops, def _create_runtime(self, context):
params_grads): if context["role_maker"]._is_collective:
if role_maker._is_collective:
collective_runtime = CollectiveRuntime() collective_runtime = CollectiveRuntime()
collective_runtime._set_basic_info(final_dist_strategy, role_maker, collective_runtime._set_basic_info(context)
opt_ops, params_grads)
return collective_runtime return collective_runtime
...@@ -20,11 +20,10 @@ __all__ = ['UtilBase'] ...@@ -20,11 +20,10 @@ __all__ = ['UtilBase']
class UtilFactory(object): class UtilFactory(object):
def _create_util(self, dist_strategy, role_maker, optimize_ops, def _create_util(self, context):
params_grads):
util = UtilBase() util = UtilBase()
util._set_strategy(dist_strategy) util._set_strategy(context["valid_strategy"])
util._set_role_maker(role_maker) util._set_role_maker(context["role_maker"])
return util return util
......
...@@ -19,11 +19,8 @@ class RuntimeBase(object): ...@@ -19,11 +19,8 @@ class RuntimeBase(object):
def __init__(self): def __init__(self):
pass pass
def _set_basic_info(self, loss, role_maker, optimizer, strategy): def _set_basic_info(self, context):
self.loss = loss self.context = context
self.role_maker = role_maker
self.optimizer = optimizer
self.strategy = strategy
def _run_worker(self): def _run_worker(self):
pass pass
......
...@@ -33,8 +33,10 @@ class TestFleetUtil(unittest.TestCase): ...@@ -33,8 +33,10 @@ class TestFleetUtil(unittest.TestCase):
role_maker = None # should be fleet.PaddleCloudRoleMaker() role_maker = None # should be fleet.PaddleCloudRoleMaker()
optimize_ops = [] optimize_ops = []
params_grads = [] params_grads = []
util = factory._create_util(strategy, role_maker, optimize_ops, context = {}
params_grads) context["role_maker"] = role_maker
context["valid_strategy"] = strategy
util = factory._create_util(context)
self.assertEqual(util.role_maker, None) self.assertEqual(util.role_maker, None)
def test_get_util(self): def test_get_util(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册