From 4adac0e30978f3a28d10adf30f319e96ee3e03ed Mon Sep 17 00:00:00 2001 From: Dong Daxiang <35550832+guru4elephant@users.noreply.github.com> Date: Wed, 5 Aug 2020 14:55:55 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90paddle.fleet=E3=80=91Add=20fleet=20bas?= =?UTF-8?q?e=20context=20(#25954)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * generate context during compile --- python/paddle/fleet/base/fleet_base.py | 19 +++++++++++++++---- python/paddle/fleet/base/runtime_factory.py | 8 +++----- python/paddle/fleet/base/util_factory.py | 7 +++---- python/paddle/fleet/runtime/runtime_base.py | 7 ++----- .../fluid/tests/unittests/test_fleet_util.py | 6 ++++-- 5 files changed, 27 insertions(+), 20 deletions(-) diff --git a/python/paddle/fleet/base/fleet_base.py b/python/paddle/fleet/base/fleet_base.py index a9238df6292..979b878a3df 100644 --- a/python/paddle/fleet/base/fleet_base.py +++ b/python/paddle/fleet/base/fleet_base.py @@ -279,8 +279,11 @@ class Fleet(object): # for more examples, please reference https://github.com/PaddlePaddle/Fleet """ + context = {} # cache original feed forward program self.origin_main_program = loss.block.program + context["origin_main_program"] = self.origin_main_program + context["loss"] = loss if startup_program == None: self.origin_startup_program = \ paddle.default_startup_program().clone(for_test=False) @@ -288,6 +291,8 @@ class Fleet(object): else: self.origin_startup_program = \ startup_program.clone(for_test=False) + context["origin_startup_program"] = startup_program + context["role_maker"] = self._role_maker # compile time distributed_optimizer_list = \ @@ -317,6 +322,9 @@ class Fleet(object): valid_strategy = self.strategy_compiler._get_valid_strategy( self.user_defined_strategy, can_not_apply_optimizer_list) + + context["valid_strategy"] = valid_strategy + self.valid_strategy = valid_strategy optimize_ops = [] @@ -334,6 +342,8 @@ class Fleet(object): parameter_list=parameter_list, no_grad_set=no_grad_set) + context["program_optimize_ops"] = optimize_ops + context["program_params_grads"] = params_grads if graph_optimizer: optimize_ops, params_grads = graph_optimizer.minimize( loss, @@ -344,12 +354,13 @@ class Fleet(object): # if a graph optimizer takes effect, mostly # optimizers_ops and params_grads are None # i.e. users can not modify current computation graph anymore + context["graph_optimize_ops"] = optimize_ops + context["graph_optimize_grads"] = params_grads + if self._runtime_handle is None: - self._runtime_handle = RuntimeFactory()._create_runtime( - valid_strategy, self._role_maker, optimize_ops, params_grads) + self._runtime_handle = RuntimeFactory()._create_runtime(context) if self._util is None: - self._util = UtilFactory()._create_util( - valid_strategy, self._role_maker, optimize_ops, params_grads) + self._util = UtilFactory()._create_util(context) return optimize_ops, params_grads diff --git a/python/paddle/fleet/base/runtime_factory.py b/python/paddle/fleet/base/runtime_factory.py index c4d42db4ea9..45dca6dae4e 100644 --- a/python/paddle/fleet/base/runtime_factory.py +++ b/python/paddle/fleet/base/runtime_factory.py @@ -18,10 +18,8 @@ class RuntimeFactory(object): def __init__(self): pass - def _create_runtime(self, final_dist_strategy, role_maker, opt_ops, - params_grads): - if role_maker._is_collective: + def _create_runtime(self, context): + if context["role_maker"]._is_collective: collective_runtime = CollectiveRuntime() - collective_runtime._set_basic_info(final_dist_strategy, role_maker, - opt_ops, params_grads) + collective_runtime._set_basic_info(context) return collective_runtime diff --git a/python/paddle/fleet/base/util_factory.py b/python/paddle/fleet/base/util_factory.py index 74029f43d10..385500de8c0 100644 --- a/python/paddle/fleet/base/util_factory.py +++ b/python/paddle/fleet/base/util_factory.py @@ -20,11 +20,10 @@ __all__ = ['UtilBase'] class UtilFactory(object): - def _create_util(self, dist_strategy, role_maker, optimize_ops, - params_grads): + def _create_util(self, context): util = UtilBase() - util._set_strategy(dist_strategy) - util._set_role_maker(role_maker) + util._set_strategy(context["valid_strategy"]) + util._set_role_maker(context["role_maker"]) return util diff --git a/python/paddle/fleet/runtime/runtime_base.py b/python/paddle/fleet/runtime/runtime_base.py index 5610a5305a4..c7ce8b5a291 100644 --- a/python/paddle/fleet/runtime/runtime_base.py +++ b/python/paddle/fleet/runtime/runtime_base.py @@ -19,11 +19,8 @@ class RuntimeBase(object): def __init__(self): pass - def _set_basic_info(self, loss, role_maker, optimizer, strategy): - self.loss = loss - self.role_maker = role_maker - self.optimizer = optimizer - self.strategy = strategy + def _set_basic_info(self, context): + self.context = context def _run_worker(self): pass diff --git a/python/paddle/fluid/tests/unittests/test_fleet_util.py b/python/paddle/fluid/tests/unittests/test_fleet_util.py index 4825035d123..427e077416e 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_util.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_util.py @@ -33,8 +33,10 @@ class TestFleetUtil(unittest.TestCase): role_maker = None # should be fleet.PaddleCloudRoleMaker() optimize_ops = [] params_grads = [] - util = factory._create_util(strategy, role_maker, optimize_ops, - params_grads) + context = {} + context["role_maker"] = role_maker + context["valid_strategy"] = strategy + util = factory._create_util(context) self.assertEqual(util.role_maker, None) def test_get_util(self): -- GitLab