diff --git a/python/paddle/fleet/base/fleet_base.py b/python/paddle/fleet/base/fleet_base.py index a9238df629245d9ccae8e71226bac2a1c1c74af3..979b878a3df966a3af59cee126b884361f5b6ac7 100644 --- a/python/paddle/fleet/base/fleet_base.py +++ b/python/paddle/fleet/base/fleet_base.py @@ -279,8 +279,11 @@ class Fleet(object): # for more examples, please reference https://github.com/PaddlePaddle/Fleet """ + context = {} # cache original feed forward program self.origin_main_program = loss.block.program + context["origin_main_program"] = self.origin_main_program + context["loss"] = loss if startup_program == None: self.origin_startup_program = \ paddle.default_startup_program().clone(for_test=False) @@ -288,6 +291,8 @@ class Fleet(object): else: self.origin_startup_program = \ startup_program.clone(for_test=False) + context["origin_startup_program"] = startup_program + context["role_maker"] = self._role_maker # compile time distributed_optimizer_list = \ @@ -317,6 +322,9 @@ class Fleet(object): valid_strategy = self.strategy_compiler._get_valid_strategy( self.user_defined_strategy, can_not_apply_optimizer_list) + + context["valid_strategy"] = valid_strategy + self.valid_strategy = valid_strategy optimize_ops = [] @@ -334,6 +342,8 @@ class Fleet(object): parameter_list=parameter_list, no_grad_set=no_grad_set) + context["program_optimize_ops"] = optimize_ops + context["program_params_grads"] = params_grads if graph_optimizer: optimize_ops, params_grads = graph_optimizer.minimize( loss, @@ -344,12 +354,13 @@ class Fleet(object): # if a graph optimizer takes effect, mostly # optimizers_ops and params_grads are None # i.e. users can not modify current computation graph anymore + context["graph_optimize_ops"] = optimize_ops + context["graph_optimize_grads"] = params_grads + if self._runtime_handle is None: - self._runtime_handle = RuntimeFactory()._create_runtime( - valid_strategy, self._role_maker, optimize_ops, params_grads) + self._runtime_handle = RuntimeFactory()._create_runtime(context) if self._util is None: - self._util = UtilFactory()._create_util( - valid_strategy, self._role_maker, optimize_ops, params_grads) + self._util = UtilFactory()._create_util(context) return optimize_ops, params_grads diff --git a/python/paddle/fleet/base/runtime_factory.py b/python/paddle/fleet/base/runtime_factory.py index c4d42db4ea993d9241222d42595e2c0d6af0a2d7..45dca6dae4e065ba6f2a9f09ac8cf298222b2d15 100644 --- a/python/paddle/fleet/base/runtime_factory.py +++ b/python/paddle/fleet/base/runtime_factory.py @@ -18,10 +18,8 @@ class RuntimeFactory(object): def __init__(self): pass - def _create_runtime(self, final_dist_strategy, role_maker, opt_ops, - params_grads): - if role_maker._is_collective: + def _create_runtime(self, context): + if context["role_maker"]._is_collective: collective_runtime = CollectiveRuntime() - collective_runtime._set_basic_info(final_dist_strategy, role_maker, - opt_ops, params_grads) + collective_runtime._set_basic_info(context) return collective_runtime diff --git a/python/paddle/fleet/base/util_factory.py b/python/paddle/fleet/base/util_factory.py index 74029f43d10c86dadb052000884fa9df7a667f72..385500de8c018853fe46205fc3d5bc6aac1aa22d 100644 --- a/python/paddle/fleet/base/util_factory.py +++ b/python/paddle/fleet/base/util_factory.py @@ -20,11 +20,10 @@ __all__ = ['UtilBase'] class UtilFactory(object): - def _create_util(self, dist_strategy, role_maker, optimize_ops, - params_grads): + def _create_util(self, context): util = UtilBase() - util._set_strategy(dist_strategy) - util._set_role_maker(role_maker) + util._set_strategy(context["valid_strategy"]) + util._set_role_maker(context["role_maker"]) return util diff --git a/python/paddle/fleet/runtime/runtime_base.py b/python/paddle/fleet/runtime/runtime_base.py index 5610a5305a464e39e9ab5a6bb7594e5e225a12ba..c7ce8b5a2914bf30f346cbd0777d1d233ddf5e1b 100644 --- a/python/paddle/fleet/runtime/runtime_base.py +++ b/python/paddle/fleet/runtime/runtime_base.py @@ -19,11 +19,8 @@ class RuntimeBase(object): def __init__(self): pass - def _set_basic_info(self, loss, role_maker, optimizer, strategy): - self.loss = loss - self.role_maker = role_maker - self.optimizer = optimizer - self.strategy = strategy + def _set_basic_info(self, context): + self.context = context def _run_worker(self): pass diff --git a/python/paddle/fluid/tests/unittests/test_fleet_util.py b/python/paddle/fluid/tests/unittests/test_fleet_util.py index 4825035d123df1767fe7845b2515f7d42253446c..427e077416e979ad5a77f4744ba6ffdb5064fdff 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_util.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_util.py @@ -33,8 +33,10 @@ class TestFleetUtil(unittest.TestCase): role_maker = None # should be fleet.PaddleCloudRoleMaker() optimize_ops = [] params_grads = [] - util = factory._create_util(strategy, role_maker, optimize_ops, - params_grads) + context = {} + context["role_maker"] = role_maker + context["valid_strategy"] = strategy + util = factory._create_util(context) self.assertEqual(util.role_maker, None) def test_get_util(self):