未验证 提交 fd10669e 编写于 作者: Q Qiao Longfei 提交者: GitHub

Add dependency to send recv (#12760)

Add dependency to send recv
上级 7c5f08e5
......@@ -57,6 +57,8 @@ class RecvOp : public framework::OperatorBase {
class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Any) Dummy inputs, used for control dependency")
.AsDuplicable();
AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable();
AddComment(R"DOC(
Recv operator
......
......@@ -37,23 +37,20 @@ class SendBarrierOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
bool sync_mode = Attr<bool>("sync_mode");
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
VLOG(3) << "SendBarrierOp sync";
// need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
if (sync_mode) {
for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
}
}
};
class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -70,7 +67,6 @@ the Parameter Server would knew all variables have been sent.
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.")
.SetDefault({"127.0.0.1:6164"});
AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
}
};
......
......@@ -66,6 +66,8 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() {
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
.AsDuplicable();
AddOutput("Out", "(Any) Dummy outputs, used for control dependency")
.AsDuplicable();
AddComment(R"DOC(
Send operator
......
......@@ -24,7 +24,7 @@ from .layer_function_generator import templatedoc
from .. import core
from ..executor import global_scope
from ..framework import convert_np_dtype_to_dtype_, default_main_program, \
default_startup_program, program_guard, Program
default_startup_program, program_guard, Program, Variable
from ..layer_helper import LayerHelper
from ..unique_name import generate as unique_name
......@@ -209,7 +209,7 @@ class ListenAndServ(object):
})
def Send(endpoints, send_vars, sync=True):
def Send(endpoints, send_vars, dummy_output=None, sync=True):
"""
Send variables to the server side, and get vars from server
side when server have finished running server side program.
......@@ -223,6 +223,13 @@ def Send(endpoints, send_vars, sync=True):
"""
assert (type(send_vars) == list)
if dummy_output is None:
dummy_output = []
elif isinstance(dummy_output, Variable):
dummy_output = [dummy_output]
assert (type(dummy_output) == list)
epmap = endpoints.split(",")
endpoints = list(set(epmap))
......@@ -232,6 +239,7 @@ def Send(endpoints, send_vars, sync=True):
helper.append_op(
type="send",
inputs={"X": send_vars},
outputs={"Out": dummy_output},
attrs={
"endpoints": endpoints,
"epmap": epmap,
......@@ -241,7 +249,7 @@ def Send(endpoints, send_vars, sync=True):
helper.append_op(type="send_barrier", attrs={"endpoints": endpoints})
def Recv(endpoints, get_vars, sync=True):
def Recv(endpoints, get_vars, dummy_input=None, sync=True):
"""
Receive variables from server side
......@@ -256,13 +264,20 @@ def Recv(endpoints, get_vars, sync=True):
"""
assert (type(get_vars) == list)
if dummy_input is None:
dummy_input = []
elif isinstance(dummy_input, Variable):
dummy_input = [dummy_input]
assert (type(dummy_input) == list)
epmap = endpoints.split(",")
endpoints = list(set(epmap))
helper = LayerHelper("Recv", **locals())
helper.append_op(
type="recv",
inputs={"X": get_vars},
inputs={"X": dummy_input},
outputs={"Out": get_vars},
attrs={"endpoints": endpoints,
"epmap": epmap})
......
......@@ -211,6 +211,9 @@ class DistributeTranspiler(object):
ps_dispatcher = self.config.split_method(self.pserver_endpoints)
self.has_distributed_lookup_table = self._has_distributed_lookup_table()
self.param_name_to_grad_name = dict()
for param_var, grad_var in self.params_grads:
self.param_name_to_grad_name[param_var.name] = grad_var.name
# step 1: split and create vars, then put splited vars in dicts for later use.
self._init_splited_vars()
......@@ -230,34 +233,39 @@ class DistributeTranspiler(object):
random.seed(self.origin_program.random_seed)
random.shuffle(grad_var_mapping_items)
for orig_varname, splited_vars in grad_var_mapping_items:
grad_name_to_send_dummy_out = dict()
for grad_varname, splited_vars in grad_var_mapping_items:
eplist = ps_dispatcher.dispatch(splited_vars)
if not self.config.slice_var_up:
assert (len(splited_vars) == 1)
splited_grad_varname = grad_varname
if len(splited_vars) == 1:
orig_varname = splited_vars[0].name
splited_grad_varname = splited_vars[0].name
index = find_op_by_output_arg(program.global_block(),
orig_varname)
splited_grad_varname)
elif len(splited_vars) > 1:
orig_var = program.global_block().vars[orig_varname]
orig_var = program.global_block().vars[splited_grad_varname]
index = find_op_by_output_arg(program.global_block(),
orig_varname)
splited_grad_varname)
self._insert_split_op(program, orig_var, index, splited_vars)
index += 1
else:
AssertionError("Can not insert the send op by original "
"variable name :", orig_varname)
"variable name :", splited_grad_varname)
dummy_output = program.global_block().create_var()
grad_name_to_send_dummy_out[grad_varname] = dummy_output
program.global_block()._insert_op(
index=index + 1,
type="send",
inputs={"X": splited_vars},
outputs={},
outputs={"Out": dummy_output},
attrs={
"epmap": eplist,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
"sync_mode": not self.sync_mode,
})
for _, var in enumerate(splited_vars):
send_vars.append(var)
......@@ -269,7 +277,6 @@ class DistributeTranspiler(object):
outputs={},
attrs={
"endpoints": pserver_endpoints,
"sync_mode": self.sync_mode,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
......@@ -285,19 +292,21 @@ class DistributeTranspiler(object):
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv.
for varname, splited_var in six.iteritems(self.param_var_mapping):
for param_varname, splited_var in six.iteritems(self.param_var_mapping):
eps = []
for var in splited_var:
index = [v.name for v in recv_vars].index(var.name)
eps.append(eplist[index])
grad_send_dummy_out = grad_name_to_send_dummy_out[
self.param_name_to_grad_name[param_varname]]
program.global_block().append_op(
type="recv",
inputs={},
inputs={"X": [grad_send_dummy_out]},
outputs={"Out": splited_var},
attrs={
"epmap": eps,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
"sync_mode": not self.sync_mode
})
if self.sync_mode:
......@@ -310,10 +319,10 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
for varname, splited_var in six.iteritems(self.param_var_mapping):
for param_varname, splited_var in six.iteritems(self.param_var_mapping):
if len(splited_var) <= 1:
continue
orig_param = program.global_block().vars[varname]
orig_param = program.global_block().vars[param_varname]
program.global_block().append_op(
type="concat",
inputs={"X": splited_var},
......@@ -381,7 +390,7 @@ class DistributeTranspiler(object):
op = startup_program.global_block().append_op(
type="recv",
inputs={},
inputs={"X": []},
outputs={"Out": splited_var},
attrs={
"epmap": eps,
......@@ -787,13 +796,15 @@ class DistributeTranspiler(object):
self.config.min_block_size)
assert (len(grad_blocks) == len(param_blocks))
# origin_varname -> [splited_var]
# origin_param_name -> [splited_param_vars]
self.param_var_mapping = self._create_vars_from_blocklist(
self.origin_program, param_blocks)
# origin_grad_name -> [splited_grad_vars]
self.grad_var_mapping = self._create_vars_from_blocklist(
self.origin_program,
grad_blocks,
add_trainer_suffix=self.trainer_num > 1)
# dict(grad_splited_var -> param_splited_var)
self.grad_param_mapping = collections.OrderedDict()
for g, p in zip(grad_blocks, param_blocks):
g_name, g_bid, _ = g.split(":")
......@@ -920,7 +931,7 @@ class DistributeTranspiler(object):
index=op_index + 2,
type="send",
inputs={'X': self.trainer_side_table_grad_list},
outputs={},
outputs={'Out': []},
attrs={
"sync_mode": True,
"epmap": pserver_endpoints,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册