Unverified commit 4adac0e3, authored by Dong Daxiang, committed by GitHub

【paddle.fleet】Add fleet base context (#25954)

* generate context during compile
Parent 358bc06c
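
For orientation: the "fleet base context" introduced here is a plain dict, built inside `Fleet.minimize()` and threaded into the runtime and util factories in place of long positional argument lists. A minimal sketch of its shape, using only the keys visible in the diff below (all values are placeholders, not real objects):

```python
# Sketch of the compile-time context assembled in Fleet.minimize().
# Keys are taken from this diff; None values are placeholders.
context = {
    "origin_main_program": None,     # loss.block.program, cached before optimization
    "loss": None,                    # the loss variable passed to minimize()
    "origin_startup_program": None,  # default or user-provided startup program
    "role_maker": None,              # self._role_maker
    "valid_strategy": None,          # strategy validated by the strategy compiler
    "program_optimize_ops": None,    # ops returned by the distributed optimizer
    "program_params_grads": None,    # (param, grad) pairs from the distributed optimizer
    "graph_optimize_ops": None,      # set only when a graph optimizer takes effect
    "graph_optimize_grads": None,
}
```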
@@ -279,8 +279,11 @@ class Fleet(object):
             # for more examples, please reference https://github.com/PaddlePaddle/Fleet
         """
+        context = {}
         # cache original feed forward program
         self.origin_main_program = loss.block.program
+        context["origin_main_program"] = self.origin_main_program
+        context["loss"] = loss
         if startup_program == None:
             self.origin_startup_program = \
                 paddle.default_startup_program().clone(for_test=False)
@@ -288,6 +291,8 @@ class Fleet(object):
         else:
             self.origin_startup_program = \
                 startup_program.clone(for_test=False)
+        context["origin_startup_program"] = startup_program
+        context["role_maker"] = self._role_maker

         # compile time
         distributed_optimizer_list = \
@@ -317,6 +322,9 @@ class Fleet(object):
         valid_strategy = self.strategy_compiler._get_valid_strategy(
             self.user_defined_strategy, can_not_apply_optimizer_list)

+        context["valid_strategy"] = valid_strategy
+
         self.valid_strategy = valid_strategy

         optimize_ops = []
@@ -334,6 +342,8 @@ class Fleet(object):
                 parameter_list=parameter_list,
                 no_grad_set=no_grad_set)
+        context["program_optimize_ops"] = optimize_ops
+        context["program_params_grads"] = params_grads

         if graph_optimizer:
             optimize_ops, params_grads = graph_optimizer.minimize(
                 loss,
@@ -344,12 +354,13 @@ class Fleet(object):
             # if a graph optimizer takes effect, mostly
             # optimize_ops and params_grads are None
             # i.e. users can not modify current computation graph anymore
+            context["graph_optimize_ops"] = optimize_ops
+            context["graph_optimize_grads"] = params_grads

         if self._runtime_handle is None:
-            self._runtime_handle = RuntimeFactory()._create_runtime(
-                valid_strategy, self._role_maker, optimize_ops, params_grads)
+            self._runtime_handle = RuntimeFactory()._create_runtime(context)

         if self._util is None:
-            self._util = UtilFactory()._create_util(
-                valid_strategy, self._role_maker, optimize_ops, params_grads)
+            self._util = UtilFactory()._create_util(context)

         return optimize_ops, params_grads
@@ -18,10 +18,8 @@ class RuntimeFactory(object):
     def __init__(self):
         pass

-    def _create_runtime(self, final_dist_strategy, role_maker, opt_ops,
-                        params_grads):
-        if role_maker._is_collective:
+    def _create_runtime(self, context):
+        if context["role_maker"]._is_collective:
             collective_runtime = CollectiveRuntime()
-            collective_runtime._set_basic_info(final_dist_strategy, role_maker,
-                                               opt_ops, params_grads)
+            collective_runtime._set_basic_info(context)
             return collective_runtime
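
With the single `context` parameter, adding a new compile-time artifact no longer ripples through every factory signature. A hypothetical call site, for illustration only (`FakeRoleMaker` is a stand-in for a real role maker such as `fleet.PaddleCloudRoleMaker()` and is not part of the patch):

```python
# Hypothetical usage of the refactored factory.
class FakeRoleMaker(object):
    _is_collective = True  # minimal attribute _create_runtime inspects

context = {
    "role_maker": FakeRoleMaker(),
    "valid_strategy": None,  # placeholder for the compiled strategy
}
runtime = RuntimeFactory()._create_runtime(context)
```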
@@ -20,11 +20,10 @@ __all__ = ['UtilBase']

 class UtilFactory(object):
-    def _create_util(self, dist_strategy, role_maker, optimize_ops,
-                     params_grads):
+    def _create_util(self, context):
         util = UtilBase()
-        util._set_strategy(dist_strategy)
-        util._set_role_maker(role_maker)
+        util._set_strategy(context["valid_strategy"])
+        util._set_role_maker(context["role_maker"])
         return util
......
@@ -19,11 +19,8 @@ class RuntimeBase(object):
     def __init__(self):
         pass

-    def _set_basic_info(self, loss, role_maker, optimizer, strategy):
-        self.loss = loss
-        self.role_maker = role_maker
-        self.optimizer = optimizer
-        self.strategy = strategy
+    def _set_basic_info(self, context):
+        self.context = context

     def _run_worker(self):
         pass
......
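
Since `RuntimeBase` now stores the whole context rather than four named fields, concrete runtimes pull whatever they need from `self.context`. A hypothetical subclass sketch (`MyRuntime` and its body are illustrative, not from this patch):

```python
# Hypothetical runtime consuming the stored context after this change.
class MyRuntime(RuntimeBase):
    def _run_worker(self):
        strategy = self.context["valid_strategy"]
        role_maker = self.context["role_maker"]
        # ... launch the worker according to strategy and role_maker
```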
@@ -33,8 +33,10 @@ class TestFleetUtil(unittest.TestCase):
         role_maker = None  # should be fleet.PaddleCloudRoleMaker()
         optimize_ops = []
         params_grads = []
-        util = factory._create_util(strategy, role_maker, optimize_ops,
-                                    params_grads)
+        context = {}
+        context["role_maker"] = role_maker
+        context["valid_strategy"] = strategy
+        util = factory._create_util(context)
         self.assertEqual(util.role_maker, None)

     def test_get_util(self):
......