提交 b1e51836 编写于 作者: Y Yancey1989

overlap sendop and backward ops

上级 2a22da6c
...@@ -36,19 +36,22 @@ class RecvOp : public framework::OperatorBase { ...@@ -36,19 +36,22 @@ class RecvOp : public framework::OperatorBase {
const platform::Place& place) const override { const platform::Place& place) const override {
auto outs = Outputs("Out"); auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i]; VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
} }
PADDLE_ENFORCE(client_.Wait()); PADDLE_ENFORCE(rpc_client->Wait());
} }
private:
mutable detail::RPCClient client_;
}; };
class RecvOpMaker : public framework::OpProtoAndCheckerMaker { class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -56,6 +59,9 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -56,6 +59,9 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
RecvOpMaker(OpProto* proto, OpAttrChecker* op_checker) RecvOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable(); AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC( AddComment(R"DOC(
Recv operator Recv operator
......
...@@ -15,8 +15,9 @@ from distribute_transpiler import DistributeTranspiler ...@@ -15,8 +15,9 @@ from distribute_transpiler import DistributeTranspiler
from inference_transpiler import InferenceTranspiler from inference_transpiler import InferenceTranspiler
from memory_optimization_transpiler import memory_optimize, release_memory from memory_optimization_transpiler import memory_optimize, release_memory
from distribute_transpiler_simple import SimpleDistributeTranspiler from distribute_transpiler_simple import SimpleDistributeTranspiler
from ps_dispatcher import HashName, RoundRobin
__all__ = [ __all__ = [
"DistributeTranspiler", "InferenceTranspiler", "SimpleDistributeTranspiler", "DistributeTranspiler", "InferenceTranspiler", "SimpleDistributeTranspiler",
"memory_optimize", "release_memory" "memory_optimize", "release_memory", "HashName", "RoundRobin"
] ]
...@@ -17,7 +17,8 @@ from __future__ import print_function ...@@ -17,7 +17,8 @@ from __future__ import print_function
import math import math
import distributed_splitter as splitter import distributed_splitter as splitter
from .. import core from ps_dispatcher import RoundRobin, HashName, PSDispatcher
from .. import core, framework
from ..framework import Program, default_main_program, Variable, Parameter from ..framework import Program, default_main_program, Variable, Parameter
LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_TYPE = "lookup_table"
...@@ -144,13 +145,27 @@ def delete_ops(block, ops): ...@@ -144,13 +145,27 @@ def delete_ops(block, ops):
block.program.sync_with_cpp() block.program.sync_with_cpp()
def find_op_by_input_arg(block, arg_name):
    """
    Find the first op in `block` that takes `arg_name` as an input.

    :param block: an object whose `ops` attribute is an ordered sequence of
        ops, each exposing `input_arg_names`
    :param arg_name: the argument name to look for
    :return: the position in `block.ops` of the first matching op, or -1
        when no op lists `arg_name` among its inputs
    """
    matching_positions = (position
                          for position, candidate in enumerate(block.ops)
                          if arg_name in candidate.input_arg_names)
    # next() with a default yields the first hit, or -1 on exhaustion.
    return next(matching_positions, -1)
def find_op_by_output_arg(block, arg_name):
    """
    Find the first op in `block` that produces `arg_name` as an output.

    :param block: an object whose `ops` attribute is an ordered sequence of
        ops, each exposing `output_arg_names`
    :param arg_name: the argument name to look for
    :return: the position in `block.ops` of the first matching op, or -1
        when no op lists `arg_name` among its outputs
    """
    matching_positions = (position
                          for position, candidate in enumerate(block.ops)
                          if arg_name in candidate.output_arg_names)
    # next() with a default yields the first hit, or -1 on exhaustion.
    return next(matching_positions, -1)
class DistributeTranspiler: class DistributeTranspiler:
def transpile(self, def transpile(self,
trainer_id, trainer_id,
program=None, program=None,
pservers="127.0.0.1:6174", pservers="127.0.0.1:6174",
trainers=1, trainers=1,
split_method=splitter.round_robin, split_method=RoundRobin,
sync_mode=True): sync_mode=True):
""" """
Transpile the program to distributed data-parallelism programs. Transpile the program to distributed data-parallelism programs.
...@@ -184,14 +199,14 @@ class DistributeTranspiler: ...@@ -184,14 +199,14 @@ class DistributeTranspiler:
:type pservers: string :type pservers: string
:param trainers: total number of workers/trainers in the job :param trainers: total number of workers/trainers in the job
:type trainers: int :type trainers: int
:param split_method: A function to determine how to split variables :param split_method: A dispatcher class that determines how to dispatch variable
to different servers equally. blocks to different servers equally.
:type split_method: function :type split_method: A class derived from the PSDispatcher class.
:type sync_mode: boolean default True :type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler :param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program. will transpile the program into sync_mode pserver and trainer program.
""" """
assert (callable(split_method)) assert (split_method.__bases__[0] == PSDispatcher)
if program is None: if program is None:
program = default_main_program() program = default_main_program()
self.origin_program = program self.origin_program = program
...@@ -204,6 +219,7 @@ class DistributeTranspiler: ...@@ -204,6 +219,7 @@ class DistributeTranspiler:
pserver_endpoints = pservers.split(",") pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints self.pserver_endpoints = pserver_endpoints
self.optimize_ops, params_grads = self._get_optimize_pass() self.optimize_ops, params_grads = self._get_optimize_pass()
ps_dispatcher = split_method(pserver_endpoints)
# process lookup_table_op # process lookup_table_op
# 1. check all lookup_table_op is distributed # 1. check all lookup_table_op is distributed
...@@ -268,56 +284,67 @@ class DistributeTranspiler: ...@@ -268,56 +284,67 @@ class DistributeTranspiler:
grad_var_mapping = self._append_split_op(program, grad_blocks) grad_var_mapping = self._append_split_op(program, grad_blocks)
param_var_mapping = self._create_vars_from_blocklist(program, param_var_mapping = self._create_vars_from_blocklist(program,
param_blocks) param_blocks)
# step3: Add gradients as send op inputs and parameters as send
# op outputs.
send_inputs = []
send_outputs = []
for b in grad_blocks: # append by order
varname, block_id, _ = b.split(":")
send_inputs.append(grad_var_mapping[varname][int(block_id)])
for b in param_blocks:
varname, block_id, _ = b.split(":")
send_outputs.append(param_var_mapping[varname][int(block_id)])
# let send_op know which endpoint to send which var to, eplist has the same
# order as send_inputs.
eplist = split_method(send_inputs, pserver_endpoints)
# create mapping of endpoint -> split var to create pserver side program
self.param_grad_ep_mapping = dict()
for i, ep in enumerate(eplist):
param = send_outputs[i]
grad = send_inputs[i]
if not self.param_grad_ep_mapping.has_key(ep):
self.param_grad_ep_mapping[ep] = {"params": [], "grads": []}
self.param_grad_ep_mapping[ep]["params"].append(param)
self.param_grad_ep_mapping[ep]["grads"].append(grad)
rpc_client_var = program.global_block().create_var( rpc_client_var = program.global_block().create_var(
name=RPC_CLIENT_VAR_NAME, name=RPC_CLIENT_VAR_NAME,
persistable=True, persistable=True,
type=core.VarDesc.VarType.RAW) type=core.VarDesc.VarType.RAW)
# create send_op # step 3: transpile trainer side program, insert recv op and send op.
# create mapping of endpoint -> split var to create pserver side program
self.param_grad_ep_mapping = dict()
[
self.param_grad_ep_mapping.update({
ep: {
"params": [],
"grads": []
}
}) for ep in self.pserver_endpoints
]
# step 3.1: insert send op to send gradient vars to parameter servers
ps_dispatcher.reset()
for varname, send_vars in grad_var_mapping.items():
index = find_op_by_output_arg(program.global_block(), varname)
eplist = ps_dispatcher.dispatch(send_vars)
program.global_block().insert_op(
index=index,
type="send_vars",
inputs={"X": send_vars},
outputs={"RPCClient": rpc_client_var},
attrs={"epmap": eplist})
if self.sync_mode:
program.global_block().append_op(
type="send_barrier",
inputs={},
outputs={"RPCClient": rpc_client_var},
attrs={"endpoints": pserver_endpoints})
# step 3.2: insert recv op to receive parameters from parameter server
ps_dispatcher.reset()
recv_vars = []
for b in param_blocks:
varname, block_id, _ = b.split(":")
recv_vars.append(param_var_mapping[varname][int(block_id)])
for b in grad_blocks:
varname, block_id, _ = b.split(":")
send_vars.append(grad_var_mapping[varname][int(block_id)])
eplist = ps_dispatcher.dispatch(recv_vars)
for i, ep in enumerate(eplist):
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
program.global_block().append_op( program.global_block().append_op(
type="send", type="recv",
inputs={"X": send_inputs}, inputs={},
outputs={"Out": send_outputs, outputs={"Out": recv_vars,
"RPCClient": rpc_client_var}, "RPCClient": rpc_client_var},
attrs={ attrs={"epmap": eplist})
"endpoints": pserver_endpoints,
"epmap": eplist,
"sync_mode": self.sync_mode
})
# step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1:
continue
orig_param = program.global_block().vars[varname]
program.global_block().append_op(
type="concat",
inputs={"X": splited_var},
outputs={"Out": [orig_param]},
attrs={"axis": 0})
# TODO(Yancey1989): check dist lookup table
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
self._replace_lookup_table_op_with_prefetch(program, rpc_client_var, self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
eplist) eplist)
......
...@@ -13,45 +13,66 @@ ...@@ -13,45 +13,66 @@
# limitations under the License. # limitations under the License.
def hash_name(varlist, pserver_endpoints): class PSDispatcher(object):
""" """
hash variable names to several endpoints. PSDispatcher is the base class for dispatching vars
into different pserver instances.
You need to implement the `dispatch` interface.
"""
def __init__(self, pserver_endpoints):
self._eps = pserver_endpoints
self._step = 0
@property
def eps(self):
    """Return the stored pserver endpoints (`self._eps`)."""
    endpoints = self._eps
    return endpoints
def reset(self):
    """Set the step counter (`self._step`) back to zero."""
    initial_step = 0
    self._step = initial_step
def dispatch(self, varlist):
    """
    Assign each variable in `varlist` to a pserver endpoint.

    The base class has no dispatching policy of its own; subclasses are
    expected to override this method.

    :param varlist: a list of Variables
    :return: in the overriding implementations, a list of pserver
        endpoints with one entry per variable, in the order of `varlist`
    :raises NotImplementedError: always, when the method is not overridden
    """
    # The exception used to be constructed but never raised, so calling the
    # un-overridden method silently returned None.
    raise NotImplementedError("Interface has not been implemented.")
Args:
varlist(list): a list of Variables
Returns(dict): a map of pserver endpoint -> varname class HashName(PSDispatcher):
""" """
Hash variable names to several endpoints
"""
def __init__(self, pserver_endpoints):
    """
    Initialise the dispatcher through the base-class constructor.

    :param pserver_endpoints: the pserver endpoints to dispatch to
    """
    # Name the class explicitly: super(self.__class__, self) resolves to the
    # runtime type, so any subclass of HashName would re-enter this very
    # __init__ and recurse until the stack overflows.
    super(HashName, self).__init__(pserver_endpoints)
def _hash_block(block_str, total): def _hash_block(self, block_str, total):
return hash(block_str) % total return hash(block_str) % total
eplist = [] def dispatch(self, varlist):
for var in varlist: eplist = []
server_id = _hash_block(var.name(), len(pserver_endpoints)) for var in varlist:
server_for_param = pserver_endpoints[server_id] server_id = self._hash_block(var.name(), len(self._eps))
eplist.append(server_for_param) server_for_param = self._eps[server_id]
return eplist eplist.append(server_for_param)
return eplist
def round_robin(varlist, pserver_endpoints): class RoundRobin(PSDispatcher):
""" """
Distribute variables to several endpoints. Distribute variables to several endpoints.
Args:
varlist(list): a list of variables
pserver_endpoints(list): a list of pserver endpoints
Returns(list[int]): the endpoint for each variable
""" """
assert (len(varlist) >= len(pserver_endpoints))
def __init__(self, pserver_endpoints):
eplist = [] super(self.__class__, self).__init__(pserver_endpoints)
pserver_idx = 0
for var in varlist: def dispatch(self, varlist):
server_for_param = pserver_endpoints[pserver_idx] eplist = []
eplist.append(server_for_param) for var in varlist:
server_for_param = self._eps[self._step]
pserver_idx += 1 eplist.append(server_for_param)
if pserver_idx >= len(pserver_endpoints): self._step += 1
pserver_idx = 0 if self._step >= len(self._eps):
return eplist self._step = 0
return eplist
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册