From 3755564ae14b96a07412b0035bc22cdea52d43ac Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Fri, 7 Aug 2020 17:38:36 +0800
Subject: [PATCH] Fix/large scale fix (#25999)

* fix large scale KV

* fix single training using async ssa graph
---
 .../details/async_ssa_graph_executor.cc       | 37 +++++++++++++++----
 .../distribute_transpiler/__init__.py         |  2 +-
 .../fleet/parameter_server/ir/public.py       |  2 +-
 3 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
index 1cf4eb6c29..d42bd0b16d 100644
--- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
@@ -45,14 +45,35 @@ inline void InitVarsInScope(const std::vector<VarInfo> &var_infos, Scope *scope,
 // get CommContext and remote send and recv op
 void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
 #ifdef PADDLE_WITH_DISTRIBUTE
-  // init communicator here
-  auto *instance = operators::distributed::Communicator::GetInstance();
-  auto initialized = instance ? true : false;
-  PADDLE_ENFORCE_EQ(initialized, true,
-                    platform::errors::InvalidArgument(
-                        "Communicator is not Initialized, you may use "
-                        "FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
-                        "develop/markdown_doc/transpiler)"));
+
+  bool need_communicator = false;
+
+  for (auto &node : graphs[0]->Nodes()) {
+    VLOG(3) << "node name " << node->Name();
+    if (node && node->IsOp()) {
+      if (node->Name() == "send") {
+        auto send_varnames =
+            BOOST_GET_CONST(std::vector<std::string>,
+                            node->Op()->GetNullableAttr("send_varnames"));
+
+        if (send_varnames.size() > 0) {
+          need_communicator = true;
+          break;
+        }
+      }
+    }
+  }
+
+  if (need_communicator) {
+    // init communicator here
+    auto *instance = operators::distributed::Communicator::GetInstance();
+    auto initialized = instance ? true : false;
+    PADDLE_ENFORCE_EQ(initialized, true,
+                      platform::errors::InvalidArgument(
+                          "Communicator is not Initialized, you may use "
+                          "FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
+                          "develop/markdown_doc/transpiler)"));
+  }
 #endif
 }
 
diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
index a7d86411e2..d2c7397c85 100644
--- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
+++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
@@ -579,7 +579,7 @@ class FleetTranspiler(Fleet):
         block.append_op(
             type='recv_save',
             attrs={
-                "trainer_id": self._role_maker.worker_id(),
+                "trainer_id": self._role_maker.worker_index(),
                 "shape": var.shape,
                 "slice_shapes":
                 [",".join([str(i) for i in var.shape])],
diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py
index 2056e3deb1..b96eff19e9 100644
--- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py
+++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py
@@ -329,7 +329,7 @@ class CompileTimeStrategy(object):
 
             is_distributed = True if param_name in distibuted_varnames else False
 
-            ctx = self.build_ctx(grad, self.grad_var_mapping, True, False,
+            ctx = self.build_ctx(grad, self.grad_var_mapping, True, True,
                                  True, is_distributed)
 
             send_ctx[ctx.var_name()] = ctx
--
GitLab
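
A minimal standalone sketch of the guard this patch adds to ProcessGraph, for trying the logic in isolation: a Communicator is required only when some "send" op carries a non-empty "send_varnames" list, which is why single-process (non-distributed) training no longer trips the enforce. OpNode and its fields below are simplified stand-ins for illustration, not Paddle's real ir::Node API.

// build: g++ -std=c++11 need_communicator_sketch.cc && ./a.out
#include <iostream>
#include <string>
#include <vector>

// Stand-in for a graph op node; the real code walks ir::Graph nodes and
// reads the op's "send_varnames" attribute via BOOST_GET_CONST.
struct OpNode {
  std::string name;                        // op type, e.g. "send"
  std::vector<std::string> send_varnames;  // "send_varnames" attribute
};

// A communicator is needed only if some "send" op has variables to send.
bool NeedCommunicator(const std::vector<OpNode> &nodes) {
  for (const auto &node : nodes) {
    if (node.name == "send" && !node.send_varnames.empty()) {
      return true;
    }
  }
  return false;
}

int main() {
  std::vector<OpNode> single_training = {{"mul", {}}, {"sum", {}}};
  std::vector<OpNode> distributed = {{"send", {"embedding@GRAD"}}};
  // Prints 0: no send vars, so the Communicator check is skipped.
  std::cout << NeedCommunicator(single_training) << std::endl;
  // Prints 1: a send op with variables, Communicator must be initialized.
  std::cout << NeedCommunicator(distributed) << std::endl;
  return 0;
}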