Commit 3755564a (unverified) — authored by tangwei12, committed via GitHub

Fix/large scale fix (#25999)

* fix large scale KV
* fix single training using async ssa graph

Parent commit: 751305ec
@@ -45,14 +45,35 @@ inline void InitVarsInScope(const std::vector<VarInfo> &var_infos, Scope *scope,
 // get CommContext and remote send and recv op
 void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
 #ifdef PADDLE_WITH_DISTRIBUTE
-  // init communicator here
-  auto *instance = operators::distributed::Communicator::GetInstance();
-  auto initialized = instance ? true : false;
-  PADDLE_ENFORCE_EQ(initialized, true,
-                    platform::errors::InvalidArgument(
-                        "Communicator is not Initialized, you may use "
-                        "FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
-                        "develop/markdown_doc/transpiler)"));
+  bool need_communicator = false;
+
+  for (auto &node : graphs[0]->Nodes()) {
+    VLOG(3) << "node name " << node->Name();
+    if (node && node->IsOp()) {
+      if (node->Name() == "send") {
+        auto send_varnames =
+            BOOST_GET_CONST(std::vector<std::string>,
+                            node->Op()->GetNullableAttr("send_varnames"));
+        if (send_varnames.size() > 0) {
+          need_communicator = true;
+          break;
+        }
+      }
+    }
+  }
+
+  if (need_communicator) {
+    // init communicator here
+    auto *instance = operators::distributed::Communicator::GetInstance();
+    auto initialized = instance ? true : false;
+    PADDLE_ENFORCE_EQ(initialized, true,
+                      platform::errors::InvalidArgument(
+                          "Communicator is not Initialized, you may use "
+                          "FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
+                          "develop/markdown_doc/transpiler)"));
+  }
 #endif
 }
...
@@ -579,7 +579,7 @@ class FleetTranspiler(Fleet):
         block.append_op(
             type='recv_save',
             attrs={
-                "trainer_id": self._role_maker.worker_id(),
+                "trainer_id": self._role_maker.worker_index(),
                 "shape": var.shape,
                 "slice_shapes":
                 [",".join([str(i) for i in var.shape])],
...
@@ -329,7 +329,7 @@ class CompileTimeStrategy(object):
         is_distributed = True if param_name in distibuted_varnames else False
-        ctx = self.build_ctx(grad, self.grad_var_mapping, True, False,
+        ctx = self.build_ctx(grad, self.grad_var_mapping, True, True,
                              True, is_distributed)
         send_ctx[ctx.var_name()] = ctx
...
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册