Commit b1e51836 authored by Yancey1989

Overlap send op and backward ops

Parent 2a22da6c
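The diff below touches the C++ RecvOp, the transpiler package exports, distribute_transpiler.py, and the new ps_dispatcher.py. A minimal usage sketch of the reworked transpile() interface follows (an illustration, not code from this commit): split_method now expects a PSDispatcher subclass such as RoundRobin or HashName rather than a splitter function, and the flat imports mirror the module names used in the hunks below; trainer_id, endpoints and trainer count are placeholders.

# Usage sketch only -- it assumes distribute_transpiler.py and ps_dispatcher.py are
# importable as in the __init__.py hunk below, and that the default main program
# already holds forward, backward and optimize ops; all values are placeholders.
from distribute_transpiler import DistributeTranspiler
from ps_dispatcher import RoundRobin

t = DistributeTranspiler()
t.transpile(
    trainer_id=0,                              # id of this trainer
    program=None,                              # None -> default_main_program()
    pservers="127.0.0.1:6174,127.0.0.1:6175",  # comma separated pserver endpoints
    trainers=2,
    split_method=RoundRobin,                   # a PSDispatcher subclass, not an instance
    sync_mode=True)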
@@ -36,19 +36,22 @@ class RecvOp : public framework::OperatorBase {
const platform::Place& place) const override {
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
PADDLE_ENFORCE(client_.Wait());
PADDLE_ENFORCE(rpc_client->Wait());
}
private:
mutable detail::RPCClient client_;
};
class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -56,6 +59,9 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
RecvOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC(
Recv operator
......
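The core of the C++ change above is that RecvOp no longer owns a mutable client_ member: it fetches a shared RPCClient from the scope via its "RPCClient" output name, issues one AsyncGetVariable per output, then waits. The standalone Python model below is only an illustration of that flow (Scope is a plain dict and FakeRPCClient a stand-in; the real op requires the client variable to already exist in the scope).

# Standalone model of RecvOp::RunImpl, not Paddle code.
class FakeRPCClient(object):
    """Stand-in for detail::RPCClient; it only records the requested gets."""

    def __init__(self):
        self.pending = []

    def async_get_variable(self, ep, var_name):
        # in Paddle this would issue an asynchronous gRPC request
        self.pending.append((ep, var_name))

    def wait(self):
        # in Paddle this blocks until every outstanding request has finished
        done, self.pending = self.pending, []
        return done


def run_recv(scope, client_var_name, outs, epmap):
    # mirrors RecvOp::RunImpl: fetch the shared client stored in the scope under
    # the "RPCClient" output name (created here for the model) ...
    rpc_client = scope.setdefault(client_var_name, FakeRPCClient())
    # ... issue one asynchronous get per output variable, paired with its endpoint ...
    for out, ep in zip(outs, epmap):
        rpc_client.async_get_variable(ep, out)
    # ... then wait for all of them.
    return rpc_client.wait()


scope = {}
print(run_recv(scope, "RPC_CLIENT_VAR", ["w.block0", "w.block1"],
               ["127.0.0.1:6174", "127.0.0.1:6175"]))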
@@ -15,8 +15,9 @@ from distribute_transpiler import DistributeTranspiler
from inference_transpiler import InferenceTranspiler
from memory_optimization_transpiler import memory_optimize, release_memory
from distribute_transpiler_simple import SimpleDistributeTranspiler
from ps_dispatcher import HashName, RoundRobin
__all__ = [
"DistributeTranspiler", "InferenceTranspiler", "SimpleDistributeTranspiler",
"memory_optimize", "release_memory"
"memory_optimize", "release_memory", "HashName", "RoundRobin"
]
@@ -17,7 +17,8 @@ from __future__ import print_function
import math
import distributed_splitter as splitter
from .. import core
from ps_dispatcher import RoundRobin, HashName, PSDispatcher
from .. import core, framework
from ..framework import Program, default_main_program, Variable, Parameter
LOOKUP_TABLE_TYPE = "lookup_table"
@@ -144,13 +145,27 @@ def delete_ops(block, ops):
block.program.sync_with_cpp()
def find_op_by_input_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.input_arg_names:
return index
return -1
def find_op_by_output_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.output_arg_names:
return index
return -1
class DistributeTranspiler:
def transpile(self,
trainer_id,
program=None,
pservers="127.0.0.1:6174",
trainers=1,
split_method=splitter.round_robin,
split_method=RoundRobin,
sync_mode=True):
"""
Transpile the program to distributed data-parallelism programs.
@@ -184,14 +199,14 @@ class DistributeTranspiler:
:type pservers: string
:param trainers: total number of workers/trainers in the job
:type trainers: int
:param split_method: A function to determine how to split variables
to different servers equally.
:type split_method: function
:param split_method: A PSDispatcher subclass that determines how to dispatch
variable blocks to different servers evenly.
:type split_method: a subclass of PSDispatcher
:type sync_mode: boolean, default True
:param sync_mode: if True, the transpiler generates synchronous (sync_mode)
pserver and trainer programs.
"""
assert (callable(split_method))
assert (split_method.__bases__[0] == PSDispatcher)
if program is None:
program = default_main_program()
self.origin_program = program
@@ -204,6 +219,7 @@ class DistributeTranspiler:
pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints
self.optimize_ops, params_grads = self._get_optimize_pass()
ps_dispatcher = split_method(pserver_endpoints)
# process lookup_table_op
# 1. check all lookup_table_op is distributed
@@ -268,56 +284,67 @@
grad_var_mapping = self._append_split_op(program, grad_blocks)
param_var_mapping = self._create_vars_from_blocklist(program,
param_blocks)
# step3: Add gradients as send op inputs and parameters as send
# op outputs.
send_inputs = []
send_outputs = []
for b in grad_blocks: # append by order
varname, block_id, _ = b.split(":")
send_inputs.append(grad_var_mapping[varname][int(block_id)])
for b in param_blocks:
varname, block_id, _ = b.split(":")
send_outputs.append(param_var_mapping[varname][int(block_id)])
# let send_op know which endpoint to send which var to, eplist has the same
# order as send_inputs.
eplist = split_method(send_inputs, pserver_endpoints)
# create mapping of endpoint -> split var to create pserver side program
self.param_grad_ep_mapping = dict()
for i, ep in enumerate(eplist):
param = send_outputs[i]
grad = send_inputs[i]
if not self.param_grad_ep_mapping.has_key(ep):
self.param_grad_ep_mapping[ep] = {"params": [], "grads": []}
self.param_grad_ep_mapping[ep]["params"].append(param)
self.param_grad_ep_mapping[ep]["grads"].append(grad)
rpc_client_var = program.global_block().create_var(
name=RPC_CLIENT_VAR_NAME,
persistable=True,
type=core.VarDesc.VarType.RAW)
# create send_op
# step 3: transpile trainer side program, insert recv op and send op.
# create mapping of endpoint -> split var to create pserver side program
self.param_grad_ep_mapping = dict()
[
self.param_grad_ep_mapping.update({
ep: {
"params": [],
"grads": []
}
}) for ep in self.pserver_endpoints
]
# step 3.1: insert send op to send gradient vars to parameter servers
ps_dispatcher.reset()
for varname, send_vars in grad_var_mapping.items():
index = find_op_by_output_arg(program.global_block(), varname)
eplist = ps_dispatcher.dispatch(send_vars)
program.global_block().insert_op(
index=index,
type="send_vars",
inputs={"X": send_vars},
outputs={"RPCClient": rpc_client_var},
attrs={"epmap": eplist})
if self.sync_mode:
program.global_block().append_op(
type="send",
inputs={"X": send_inputs},
outputs={"Out": send_outputs,
"RPCClient": rpc_client_var},
attrs={
"endpoints": pserver_endpoints,
"epmap": eplist,
"sync_mode": self.sync_mode
})
# step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1:
continue
orig_param = program.global_block().vars[varname]
type="send_barrier",
inputs={},
outputs={"RPCClient": rpc_client_var},
attrs={"endpoints": pserver_endpoints})
# step 3.2: insert recv op to receive parameters from parameter server
ps_dispatcher.reset()
recv_vars = []
for b in param_blocks:
varname, block_id, _ = b.split(":")
recv_vars.append(param_var_mapping[varname][int(block_id)])
for b in grad_blocks:
varname, block_id, _ = b.split(":")
send_vars.append(grad_var_mapping[varname][int(block_id)])
eplist = ps_dispatcher.dispatch(recv_vars)
for i, ep in enumerate(eplist):
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
program.global_block().append_op(
type="concat",
inputs={"X": splited_var},
outputs={"Out": [orig_param]},
attrs={"axis": 0})
type="recv",
inputs={},
outputs={"Out": recv_vars,
"RPCClient": rpc_client_var},
attrs={"epmap": eplist})
# TODO(Yancey1989): check dist lookup table
if self.has_distributed_lookup_table:
self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
eplist)
......
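The trainer-side rewrite above replaces the single blocking send op with a send_vars op placed where each gradient is produced (located with find_op_by_output_arg), a send_barrier when sync_mode is set, and a recv op that pulls the parameters back; this is what lets communication overlap with the remaining backward ops. The toy model below only illustrates that ordering: it has no Paddle dependency, ops are (type, output_names) tuples, and the op and variable names are made up.

# Toy model of the insertion order only; real program blocks, the send_vars/
# send_barrier/recv operators and the RPCClient wiring are all simplified away.
def find_op_by_output_arg(ops, arg_name):
    for index, (_, outputs) in enumerate(ops):
        if arg_name in outputs:
            return index
    return -1


def transpile_trainer(ops, grad_vars, sync_mode=True):
    ops = list(ops)
    # step 3.1: place a send_vars op where each gradient becomes available, so its
    # RPC call can overlap with the backward ops that still have to run (a gradient
    # can only be sent once the op producing it has executed).
    for grad in grad_vars:
        index = find_op_by_output_arg(ops, grad)
        ops.insert(index + 1, ("send_vars", [grad]))
    if sync_mode:
        # the barrier tells the pservers this trainer has sent all of its gradients
        ops.append(("send_barrier", []))
    # step 3.2: a single recv op fetches the updated parameters back
    params = [g.replace("@GRAD", "") for g in grad_vars]
    ops.append(("recv", params))
    return ops


backward = [("mul_grad", ["w@GRAD"]), ("elementwise_add_grad", ["b@GRAD"])]
for op in transpile_trainer(backward, ["w@GRAD", "b@GRAD"]):
    print(op)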
@@ -13,45 +13,66 @@
# limitations under the License.
def hash_name(varlist, pserver_endpoints):
class PSDispatcher(object):
"""
hash variable names to several endpoints.
PSDispatcher is the base class for dispatching vars
into different pserver instances.
You need to implement the `dispatch` interface.
"""
def __init__(self, pserver_endpoints):
self._eps = pserver_endpoints
self._step = 0
@property
def eps(self):
return self._eps
def reset(self):
self._step = 0
def dispatch(self, varlist):
"""
:param varlist: a list of Variables
:return: a map of pserver endpoint -> varname
"""
AssertionError("Interface has not been implemented.")
Args:
varlist(list): a list of Variables
Returns(dict): a map of pserver endpoint -> varname
class HashName(PSDispatcher):
"""
Hash variable names to several endpoints
"""
def _hash_block(block_str, total):
def __init__(self, pserver_endpoints):
super(self.__class__, self).__init__(pserver_endpoints)
def _hash_block(self, block_str, total):
return hash(block_str) % total
def dispatch(self, varlist):
eplist = []
for var in varlist:
server_id = _hash_block(var.name(), len(pserver_endpoints))
server_for_param = pserver_endpoints[server_id]
server_id = self._hash_block(var.name(), len(self._eps))
server_for_param = self._eps[server_id]
eplist.append(server_for_param)
return eplist
def round_robin(varlist, pserver_endpoints):
class RoundRobin(PSDispatcher):
"""
Distribute variables to several endpoints.
Args:
varlist(list): a list of variables
pserver_endpoints(list): a list of pserver endpoints
Returns(list[int]): the endpoint for each variable
Distribute variables to several endpoints.
"""
assert (len(varlist) >= len(pserver_endpoints))
def __init__(self, pserver_endpoints):
super(self.__class__, self).__init__(pserver_endpoints)
def dispatch(self, varlist):
eplist = []
pserver_idx = 0
for var in varlist:
server_for_param = pserver_endpoints[pserver_idx]
server_for_param = self._eps[self._step]
eplist.append(server_for_param)
pserver_idx += 1
if pserver_idx >= len(pserver_endpoints):
pserver_idx = 0
self._step += 1
if self._step >= len(self._eps):
self._step = 0
return eplist
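A small usage sketch of the two dispatchers (not from the commit): it assumes ps_dispatcher.py is on the import path and uses a stand-in variable class, since HashName.dispatch calls name() on each element the way the transpiler's variables do.

# Usage sketch only; FakeVar stands in for a fluid Variable.
from ps_dispatcher import HashName, RoundRobin


class FakeVar(object):
    def __init__(self, name):
        self._name = name

    def name(self):
        return self._name


eps = ["127.0.0.1:6174", "127.0.0.1:6175"]
grads = [FakeVar("w@GRAD"), FakeVar("b@GRAD"), FakeVar("emb@GRAD")]

rr = RoundRobin(eps)
print(rr.dispatch(grads))  # ['127.0.0.1:6174', '127.0.0.1:6175', '127.0.0.1:6174']

hn = HashName(eps)
print(hn.dispatch(grads))  # endpoint chosen by hash(var.name()) % len(eps)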