未验证 提交 0ee6fed0 编写于 作者: W Wu Yi 提交者: GitHub

Refine dist rpc deps (#12899)

* refine dist train RPC deps

* clean up

* clean up

* fix ut

* remove input for fetch_barrier

* follow comments
上级 3a0b6f97
...@@ -754,17 +754,26 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -754,17 +754,26 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
node->Op()->Type()); node->Op()->Type());
CreateComputationalOp(result, node, op_dev_id); CreateComputationalOp(result, node, op_dev_id);
if (node->Op()->Type() == "concat") { }
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
"fetch_barrier"); void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
for (ir::Node *input : node->inputs) {
VarHandle *var = nullptr;
for (int place_offset = 0; place_offset < num_places; ++place_offset) {
auto &var_holders = result->Get<GraphVars>(kGraphVars)[place_offset];
auto &var_holder = var_holders[input->Name()];
if (!var_holder.empty()) {
var = var_holder.rbegin()->get();
op_handle->AddInput(var);
}
}
} }
} }
// Create RPC related op handles that connects its in ops and out ops. // Create RPC related op handles that connects its in ops and out ops.
void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
ir::Node *node) const { ir::Node *node) const {
// FIXME(typhoonzero): Cleanup this deps for both sync mode and async mode
// put them into transpiler.
int op_dev_id = -1; int op_dev_id = -1;
if (node->Op()->Type() == "send") { if (node->Op()->Type() == "send") {
// TODO(paddle-dev): getting the first var is not safe. // TODO(paddle-dev): getting the first var is not safe.
...@@ -799,8 +808,6 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -799,8 +808,6 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
} }
auto recv_param_grad = boost::get<std::vector<std::string>>( auto recv_param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
// FIXME(typhoonzero): assume each recv op output one param
// Use the same place as send.
if (recv_param_grad.size() == 2U) { if (recv_param_grad.size() == 2U) {
op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]); op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]);
VLOG(10) << "recv param " << recv_param_grad[0] VLOG(10) << "recv param " << recv_param_grad[0]
...@@ -814,34 +821,44 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -814,34 +821,44 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
.emplace(varname, op_dev_id); .emplace(varname, op_dev_id);
} }
} else { } else {
// send_barrier and fetch_barrier op can be scheduled on device 0 // send_barrier, fetch_barrier will run on place 0;
op_dev_id = 0; op_dev_id = 0;
} }
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
node->Op()->Type()); node->Op()->Type());
result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id], result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
node->Op()->Type(), places_[op_dev_id])); node->Op()->Type(), places_[op_dev_id]));
// TODO(panyx0718): This might not be needed anymore. if (node->Op()->Type() == "send") {
if (node->Op()->Type() == "send_barrier") { CreateOpHandleIOs(result, node, op_dev_id);
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "send");
} else if (node->Op()->Type() == "recv") {
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
"send_barrier");
} else if (node->Op()->Type() == "fetch_barrier") {
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "recv");
} else if (node->Op()->Type() == "send") {
// do nothing
} else { } else {
PADDLE_THROW( // send_barrier, recv, fetch_barrier's inputs are deps var, get them from
"rpc op should be in [" // all places
"send, send_barrier. recv, fetch_barrier]"); auto p = places_[op_dev_id];
} auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
CreateOpHandleIOs(result, node, op_dev_id); SetOpInputsAllPlaces(result, node, places_.size());
for (ir::Node *output : node->outputs) {
int outvar_dev_id = op_dev_id;
if (node->Op()->Type() == "fetch_barrier") {
outvar_dev_id = GetVarDeviceID(*result, output->Name());
PADDLE_ENFORCE_NE(outvar_dev_id, -1);
}
p = places_[outvar_dev_id];
ir::Node *new_node = nullptr;
if (output->Var()) {
new_node = result->CreateVarNode(output->Var());
} else {
new_node =
result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
}
CreateOpOutput(result, op_handle, new_node, p, outvar_dev_id);
}
}
} }
bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
......
...@@ -132,63 +132,6 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { ...@@ -132,63 +132,6 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
} }
} }
std::vector<ir::Node *> send_ops;
ir::Node *send_bar = nullptr;
std::vector<ir::Node *> recv_ops;
ir::Node *fetch_bar = nullptr;
for (ir::Node *node : Nodes()) {
if (node->Name() == "send") {
send_ops.push_back(node);
} else if (node->Name() == "send_barrier") {
PADDLE_ENFORCE(!send_bar, "only has one send barrier");
send_bar = node;
} else if (node->Name() == "recv") {
recv_ops.push_back(node);
} else if (node->Name() == "fetch_barrier") {
PADDLE_ENFORCE(!fetch_bar, "only has one fetch barrier");
fetch_bar = node;
}
}
if (send_bar) {
for (ir::Node *send : send_ops) {
ir::Node *dep_var = CreateControlDepVar();
send->outputs.push_back(dep_var);
dep_var->inputs.push_back(send);
send_bar->inputs.push_back(dep_var);
dep_var->outputs.push_back(send_bar);
}
for (ir::Node *recv : recv_ops) {
ir::Node *dep_var = CreateControlDepVar();
recv->inputs.push_back(dep_var);
dep_var->outputs.push_back(recv);
send_bar->outputs.push_back(dep_var);
dep_var->inputs.push_back(send_bar);
}
}
if (fetch_bar) {
for (ir::Node *recv : recv_ops) {
ir::Node *dep_var = CreateControlDepVar();
recv->outputs.push_back(dep_var);
dep_var->inputs.push_back(recv);
fetch_bar->inputs.push_back(dep_var);
dep_var->outputs.push_back(fetch_bar);
}
}
std::vector<std::string> send_vars = FindDistTrainSendVars(send_ops);
std::vector<std::string> recv_vars = FindDistTrainRecvVars(recv_ops);
for (ir::Node *node : Nodes()) {
if (IsDistTrainOp(node, send_vars, recv_vars)) {
if (fetch_bar && node->Name() == "concat") {
ir::Node *dep_var = CreateControlDepVar();
fetch_bar->outputs.push_back(dep_var);
dep_var->inputs.push_back(fetch_bar);
node->inputs.push_back(dep_var);
dep_var->outputs.push_back(node);
}
}
}
/** /**
* We should handle write after read(WAR) and write after write(WAW) here. * We should handle write after read(WAR) and write after write(WAW) here.
* Because some of the operators of the program can be executed parallelly. * Because some of the operators of the program can be executed parallelly.
......
...@@ -52,6 +52,8 @@ class FetchBarrierOp : public framework::OperatorBase { ...@@ -52,6 +52,8 @@ class FetchBarrierOp : public framework::OperatorBase {
class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker { class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddOutput("Out", "(Any) Dummy outputs, used for control dependency")
.AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
SendBarrier operator SendBarrier operator
......
...@@ -56,6 +56,10 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -56,6 +56,10 @@ class SendBarrierOp : public framework::OperatorBase {
class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker { class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddInput("X", "(Any) Dummy inputs, used for control dependency")
.AsDuplicable();
AddOutput("Out", "(Any) Dummy outputs, used for control dependency")
.AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
SendBarrier operator SendBarrier operator
......
...@@ -246,7 +246,11 @@ def Send(endpoints, send_vars, dummy_output=None, sync=True): ...@@ -246,7 +246,11 @@ def Send(endpoints, send_vars, dummy_output=None, sync=True):
rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC
}) })
if sync: if sync:
helper.append_op(type="send_barrier", attrs={"endpoints": endpoints}) helper.append_op(
type="send_barrier",
inputs={"X": dummy_output},
outputs={"Out": []},
attrs={"endpoints": endpoints})
def Recv(endpoints, get_vars, dummy_input=None, sync=True): def Recv(endpoints, get_vars, dummy_input=None, sync=True):
...@@ -282,7 +286,10 @@ def Recv(endpoints, get_vars, dummy_input=None, sync=True): ...@@ -282,7 +286,10 @@ def Recv(endpoints, get_vars, dummy_input=None, sync=True):
attrs={"endpoints": endpoints, attrs={"endpoints": endpoints,
"epmap": epmap}) "epmap": epmap})
if sync: if sync:
helper.append_op(type="fetch_barrier", attrs={"endpoints": endpoints}) helper.append_op(
type="fetch_barrier",
outputs={"Out": get_vars},
attrs={"endpoints": endpoints})
return get_vars return get_vars
......
...@@ -134,7 +134,7 @@ class SE_ResNeXt(): ...@@ -134,7 +134,7 @@ class SE_ResNeXt():
size=class_dim, size=class_dim,
act='softmax', act='softmax',
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.2))) initializer=fluid.initializer.Constant(value=0.05)))
return out return out
def shortcut(self, input, ch_out, stride): def shortcut(self, input, ch_out, stride):
...@@ -184,7 +184,7 @@ class SE_ResNeXt(): ...@@ -184,7 +184,7 @@ class SE_ResNeXt():
act=None, act=None,
# avoid pserver CPU init differs from GPU # avoid pserver CPU init differs from GPU
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.2)), initializer=fluid.initializer.Constant(value=0.05)),
bias_attr=False) bias_attr=False)
return fluid.layers.batch_norm(input=conv, act=act) return fluid.layers.batch_norm(input=conv, act=act)
...@@ -192,12 +192,18 @@ class SE_ResNeXt(): ...@@ -192,12 +192,18 @@ class SE_ResNeXt():
pool = fluid.layers.pool2d( pool = fluid.layers.pool2d(
input=input, pool_size=0, pool_type='avg', global_pooling=True) input=input, pool_size=0, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
squeeze = fluid.layers.fc(input=pool, squeeze = fluid.layers.fc(
input=pool,
size=num_channels // reduction_ratio, size=num_channels // reduction_ratio,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.05)),
act='relu') act='relu')
stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0) stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
excitation = fluid.layers.fc(input=squeeze, excitation = fluid.layers.fc(
input=squeeze,
size=num_channels, size=num_channels,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.05)),
act='sigmoid') act='sigmoid')
scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
return scale return scale
......
...@@ -49,28 +49,32 @@ class TestDistWord2vec2x2(TestDistRunnerBase): ...@@ -49,28 +49,32 @@ class TestDistWord2vec2x2(TestDistRunnerBase):
dtype='float32', dtype='float32',
is_sparse=IS_SPARSE, is_sparse=IS_SPARSE,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name='shared_w', initializer=fluid.initializer.Constant())) name='shared_w',
initializer=fluid.initializer.Constant(value=0.1)))
embed_second = fluid.layers.embedding( embed_second = fluid.layers.embedding(
input=words[1], input=words[1],
size=[dict_size, EMBED_SIZE], size=[dict_size, EMBED_SIZE],
dtype='float32', dtype='float32',
is_sparse=IS_SPARSE, is_sparse=IS_SPARSE,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name='shared_w', initializer=fluid.initializer.Constant())) name='shared_w',
initializer=fluid.initializer.Constant(value=0.1)))
embed_third = fluid.layers.embedding( embed_third = fluid.layers.embedding(
input=words[2], input=words[2],
size=[dict_size, EMBED_SIZE], size=[dict_size, EMBED_SIZE],
dtype='float32', dtype='float32',
is_sparse=IS_SPARSE, is_sparse=IS_SPARSE,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name='shared_w', initializer=fluid.initializer.Constant())) name='shared_w',
initializer=fluid.initializer.Constant(value=0.1)))
embed_forth = fluid.layers.embedding( embed_forth = fluid.layers.embedding(
input=words[3], input=words[3],
size=[dict_size, EMBED_SIZE], size=[dict_size, EMBED_SIZE],
dtype='float32', dtype='float32',
is_sparse=IS_SPARSE, is_sparse=IS_SPARSE,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name='shared_w', initializer=fluid.initializer.Constant())) name='shared_w',
initializer=fluid.initializer.Constant(value=0.1)))
concat_embed = fluid.layers.concat( concat_embed = fluid.layers.concat(
input=[embed_first, embed_second, embed_third, embed_forth], input=[embed_first, embed_second, embed_third, embed_forth],
...@@ -80,13 +84,13 @@ class TestDistWord2vec2x2(TestDistRunnerBase): ...@@ -80,13 +84,13 @@ class TestDistWord2vec2x2(TestDistRunnerBase):
size=HIDDEN_SIZE, size=HIDDEN_SIZE,
act='sigmoid', act='sigmoid',
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant())) initializer=fluid.initializer.Constant(value=0.1)))
predict_word = fluid.layers.fc( predict_word = fluid.layers.fc(
input=hidden1, input=hidden1,
size=dict_size, size=dict_size,
act='softmax', act='softmax',
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant())) initializer=fluid.initializer.Constant(value=0.1)))
cost = fluid.layers.cross_entropy( cost = fluid.layers.cross_entropy(
input=predict_word, label=words[4]) input=predict_word, label=words[4])
avg_cost = fluid.layers.mean(cost) avg_cost = fluid.layers.mean(cost)
......
...@@ -100,7 +100,7 @@ class TestSendOp(unittest.TestCase): ...@@ -100,7 +100,7 @@ class TestSendOp(unittest.TestCase):
main.global_block().append_op( main.global_block().append_op(
type="fetch_barrier", type="fetch_barrier",
inputs={}, inputs={},
outputs={}, outputs={"Out": []},
attrs={ attrs={
"endpoints": ["127.0.0.1:{0}".format(port)], "endpoints": ["127.0.0.1:{0}".format(port)],
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
......
...@@ -22,7 +22,7 @@ class TestDistSeResneXt2x2(TestDistBase): ...@@ -22,7 +22,7 @@ class TestDistSeResneXt2x2(TestDistBase):
self._sync_mode = True self._sync_mode = True
def test_se_resnext(self): def test_se_resnext(self):
self.check_with_place("dist_word2vec.py", delta=1e-7) self.check_with_place("dist_word2vec.py", delta=1e-4)
class TestDistSeResneXt2x2Async(TestDistBase): class TestDistSeResneXt2x2Async(TestDistBase):
......
...@@ -283,10 +283,13 @@ class DistributeTranspiler(object): ...@@ -283,10 +283,13 @@ class DistributeTranspiler(object):
send_vars.append(var) send_vars.append(var)
if self.sync_mode: if self.sync_mode:
send_barrier_out = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
input_deps = grad_name_to_send_dummy_out.values()
program.global_block().append_op( program.global_block().append_op(
type="send_barrier", type="send_barrier",
inputs={}, inputs={"X": input_deps},
outputs={}, outputs={"Out": send_barrier_out},
attrs={ attrs={
"endpoints": pserver_endpoints, "endpoints": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
...@@ -304,16 +307,22 @@ class DistributeTranspiler(object): ...@@ -304,16 +307,22 @@ class DistributeTranspiler(object):
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv. # step4: Concat the parameters splits together after recv.
all_recv_outputs = []
for param_varname, splited_var in six.iteritems(self.param_var_mapping): for param_varname, splited_var in six.iteritems(self.param_var_mapping):
eps = [] eps = []
for var in splited_var: for var in splited_var:
index = [v.name for v in recv_vars].index(var.name) index = [v.name for v in recv_vars].index(var.name)
eps.append(eplist[index]) eps.append(eplist[index])
grad_send_dummy_out = grad_name_to_send_dummy_out[ if self.sync_mode:
recv_dep_in = send_barrier_out
else:
# connect deps to send op in async mode
recv_dep_in = grad_name_to_send_dummy_out[
self.param_name_to_grad_name[param_varname]] self.param_name_to_grad_name[param_varname]]
all_recv_outputs.extend(splited_var)
program.global_block().append_op( program.global_block().append_op(
type="recv", type="recv",
inputs={"X": [grad_send_dummy_out]}, inputs={"X": [recv_dep_in]},
outputs={"Out": splited_var}, outputs={"Out": splited_var},
attrs={ attrs={
"epmap": eps, "epmap": eps,
...@@ -326,10 +335,11 @@ class DistributeTranspiler(object): ...@@ -326,10 +335,11 @@ class DistributeTranspiler(object):
}) })
if self.sync_mode: if self.sync_mode:
# form a WAW dependency
program.global_block().append_op( program.global_block().append_op(
type="fetch_barrier", type="fetch_barrier",
inputs={}, inputs={},
outputs={}, outputs={"Out": all_recv_outputs},
attrs={ attrs={
"endpoints": pserver_endpoints, "endpoints": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
...@@ -413,10 +423,12 @@ class DistributeTranspiler(object): ...@@ -413,10 +423,12 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}) })
fetch_barrier_out = startup_program.global_block().create_var(
name=framework.generate_control_dev_var_name())
startup_program.global_block().append_op( startup_program.global_block().append_op(
type="fetch_barrier", type="fetch_barrier",
inputs={}, inputs={},
outputs={}, outputs={"Out": fetch_barrier_out},
attrs={ attrs={
"endpoints": self.pserver_endpoints, "endpoints": self.pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册