From b1e5183627e4aad56400b620342a55434321d544 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Thu, 10 May 2018 11:22:46 +0800
Subject: [PATCH] overlap send op and backward ops

---
 paddle/fluid/operators/recv_op.cc                  |  18 ++-
 python/paddle/fluid/transpiler/__init__.py         |   3 +-
 .../fluid/transpiler/distribute_transpiler.py      | 123 +++++++++++-------
 .../fluid/transpiler/distributed_splitter.py       |  57 --------
 .../paddle/fluid/transpiler/ps_dispatcher.py       |  78 +++++++++++
 5 files changed, 167 insertions(+), 112 deletions(-)
 delete mode 100644 python/paddle/fluid/transpiler/distributed_splitter.py
 create mode 100644 python/paddle/fluid/transpiler/ps_dispatcher.py

diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc
index a4dcf704a63..aeb93c99817 100644
--- a/paddle/fluid/operators/recv_op.cc
+++ b/paddle/fluid/operators/recv_op.cc
@@ -36,19 +36,22 @@ class RecvOp : public framework::OperatorBase {
                const platform::Place& place) const override {
     auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
+    auto client_var_name = Output("RPCClient");
+    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
+                            "Cannot find variable '%s' in the scope.",
+                            client_var_name);
+    auto* client_var = scope.FindVar(client_var_name);
+    detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();

     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);

     for (size_t i = 0; i < outs.size(); i++) {
-      VLOG(3) << "getting " << outs[i];
-      client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
+      VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
+      rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
     }
-    PADDLE_ENFORCE(client_.Wait());
+    PADDLE_ENFORCE(rpc_client->Wait());
   }
-
- private:
-  mutable detail::RPCClient client_;
 };

 class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -56,6 +59,9 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
   RecvOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable();
+    AddOutput("RPCClient",
+              "(RPCClient) The RPC client object which is "
+              "initialized at most once.");
     AddComment(R"DOC(
 Recv operator

diff --git a/python/paddle/fluid/transpiler/__init__.py b/python/paddle/fluid/transpiler/__init__.py
index 6d3c1b947f4..f21e4dc033e 100644
--- a/python/paddle/fluid/transpiler/__init__.py
+++ b/python/paddle/fluid/transpiler/__init__.py
@@ -15,8 +15,9 @@ from distribute_transpiler import DistributeTranspiler
 from inference_transpiler import InferenceTranspiler
 from memory_optimization_transpiler import memory_optimize, release_memory
 from distribute_transpiler_simple import SimpleDistributeTranspiler
+from ps_dispatcher import HashName, RoundRobin

 __all__ = [
     "DistributeTranspiler", "InferenceTranspiler", "SimpleDistributeTranspiler",
-    "memory_optimize", "release_memory"
+    "memory_optimize", "release_memory", "HashName", "RoundRobin"
 ]
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 640ac9f085e..05ffdefe05b 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -17,7 +17,8 @@ from __future__ import print_function
 import math

-import distributed_splitter as splitter
-from .. import core
+from ps_dispatcher import RoundRobin, HashName, PSDispatcher
+from .. import core, framework
 from ..framework import Program, default_main_program, Variable, Parameter

 LOOKUP_TABLE_TYPE = "lookup_table"
@@ -144,13 +145,27 @@ def delete_ops(block, ops):
     block.program.sync_with_cpp()


+def find_op_by_input_arg(block, arg_name):
+    for index, op in enumerate(block.ops):
+        if arg_name in op.input_arg_names:
+            return index
+    return -1
+
+
+def find_op_by_output_arg(block, arg_name):
+    for index, op in enumerate(block.ops):
+        if arg_name in op.output_arg_names:
+            return index
+    return -1
+
+
 class DistributeTranspiler:
     def transpile(self,
                   trainer_id,
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
-                  split_method=splitter.round_robin,
+                  split_method=RoundRobin,
                   sync_mode=True):
         """
         Transpile the program to distributed data-parallelism programs.
@@ -184,14 +199,14 @@ class DistributeTranspiler:
         :type pservers: string
         :param trainers: total number of workers/trainers in the job
         :type trainers: int
-        :param split_method: A function to determin how to split variables
-           to different servers equally.
-        :type split_method: function
+        :param split_method: a PSDispatcher subclass that determines how to
+           dispatch variable blocks to different servers evenly
+        :type split_method: class derived from PSDispatcher
         :type sync_mode: boolean default True
         :param sync_mode: if sync_mode is set True, it means that dist transpiler
         will transpile the program into sync_mode pserver and trainer program.
         """
-        assert (callable(split_method))
+        assert (split_method.__bases__[0] == PSDispatcher)
         if program is None:
             program = default_main_program()
         self.origin_program = program
@@ -204,6 +219,7 @@ class DistributeTranspiler:
         pserver_endpoints = pservers.split(",")
         self.pserver_endpoints = pserver_endpoints
         self.optimize_ops, params_grads = self._get_optimize_pass()
+        ps_dispatcher = split_method(pserver_endpoints)

         # process lookup_table_op
         # 1. check all lookup_table_op is distributed
@@ -268,56 +284,67 @@
         grad_var_mapping = self._append_split_op(program, grad_blocks)
         param_var_mapping = self._create_vars_from_blocklist(program,
                                                              param_blocks)
-        # step3: Add gradients as send op inputs and parameters as send
-        # op outputs.
-        send_inputs = []
-        send_outputs = []
-        for b in grad_blocks:  # append by order
-            varname, block_id, _ = b.split(":")
-            send_inputs.append(grad_var_mapping[varname][int(block_id)])
-        for b in param_blocks:
-            varname, block_id, _ = b.split(":")
-            send_outputs.append(param_var_mapping[varname][int(block_id)])
-        # let send_op know which endpoint to send which var to, eplist has the same
-        # order as send_inputs.
-        eplist = split_method(send_inputs, pserver_endpoints)
-        # create mapping of endpoint -> split var to create pserver side program
-        self.param_grad_ep_mapping = dict()
-        for i, ep in enumerate(eplist):
-            param = send_outputs[i]
-            grad = send_inputs[i]
-            if not self.param_grad_ep_mapping.has_key(ep):
-                self.param_grad_ep_mapping[ep] = {"params": [], "grads": []}
-            self.param_grad_ep_mapping[ep]["params"].append(param)
-            self.param_grad_ep_mapping[ep]["grads"].append(grad)
-
         rpc_client_var = program.global_block().create_var(
             name=RPC_CLIENT_VAR_NAME,
             persistable=True,
             type=core.VarDesc.VarType.RAW)

-        # create send_op
+        # step 3: transpile trainer side program, insert recv op and send op.
+
+        # create mapping of endpoint -> split var to create pserver side program
+        self.param_grad_ep_mapping = dict()
+        for ep in self.pserver_endpoints:
+            self.param_grad_ep_mapping[ep] = {"params": [], "grads": []}
+
+        # step 3.1: insert send op to send gradient vars to parameter servers
+        ps_dispatcher.reset()
+        for varname, send_vars in grad_var_mapping.items():
+            index = find_op_by_output_arg(program.global_block(), varname)
+            eplist = ps_dispatcher.dispatch(send_vars)
+            program.global_block().insert_op(
+                index=index,
+                type="send_vars",
+                inputs={"X": send_vars},
+                outputs={"RPCClient": rpc_client_var},
+                attrs={"epmap": eplist})
+
+        if self.sync_mode:
+            program.global_block().append_op(
+                type="send_barrier",
+                inputs={},
+                outputs={"RPCClient": rpc_client_var},
+                attrs={"endpoints": pserver_endpoints})
+
+        # step 3.2: insert recv op to receive parameters from parameter server
+        ps_dispatcher.reset()
+        recv_vars = []
+        for b in param_blocks:
+            varname, block_id, _ = b.split(":")
+            recv_vars.append(param_var_mapping[varname][int(block_id)])
+        send_vars = []
+        for b in grad_blocks:
+            varname, block_id, _ = b.split(":")
+            send_vars.append(grad_var_mapping[varname][int(block_id)])
+
+        eplist = ps_dispatcher.dispatch(recv_vars)
+
+        for i, ep in enumerate(eplist):
+            self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
+            self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
+
         program.global_block().append_op(
-            type="send",
-            inputs={"X": send_inputs},
-            outputs={"Out": send_outputs,
+            type="recv",
+            inputs={},
+            outputs={"Out": recv_vars,
                      "RPCClient": rpc_client_var},
-            attrs={
-                "endpoints": pserver_endpoints,
-                "epmap": eplist,
-                "sync_mode": self.sync_mode
-            })
-        # step4: Concat the parameters splits together after recv.
-        for varname, splited_var in param_var_mapping.iteritems():
-            if len(splited_var) <= 1:
-                continue
-            orig_param = program.global_block().vars[varname]
-            program.global_block().append_op(
-                type="concat",
-                inputs={"X": splited_var},
-                outputs={"Out": [orig_param]},
-                attrs={"axis": 0})
+            attrs={"epmap": eplist})
+
+        # TODO(Yancey1989): check dist lookup table
         if self.has_distributed_lookup_table:
             self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
                                                         eplist)
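For illustration (not part of the patch): the helper find_op_by_output_arg added above locates the op that produces a given gradient, so a send_vars op can be placed next to that producer and the RPC for one gradient overlaps with the backward ops that still have to run. A standalone sketch of that idea, where FakeOp, FakeBlock and the op/variable names are hypothetical stand-ins:

# Standalone sketch, no Paddle dependency; FakeOp/FakeBlock mimic just enough
# of a fluid Block for find_op_by_output_arg() to work.
class FakeOp(object):
    def __init__(self, op_type, output_arg_names):
        self.type = op_type
        self.output_arg_names = output_arg_names


class FakeBlock(object):
    def __init__(self, ops):
        self.ops = ops


def find_op_by_output_arg(block, arg_name):
    # same linear search as the helper added to distribute_transpiler.py
    for index, op in enumerate(block.ops):
        if arg_name in op.output_arg_names:
            return index
    return -1


block = FakeBlock([
    FakeOp("mul_grad", ["fc2.w@GRAD"]),  # backward op of the last layer
    FakeOp("mul_grad", ["fc1.w@GRAD"]),  # backward op of the first layer
])
# issue the send for fc2.w@GRAD right after its producer, so the RPC runs
# while fc1's backward op is still executing
index = find_op_by_output_arg(block, "fc2.w@GRAD")
block.ops.insert(index + 1, FakeOp("send_vars", []))
print([op.type for op in block.ops])  # ['mul_grad', 'send_vars', 'mul_grad']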
diff --git a/python/paddle/fluid/transpiler/distributed_splitter.py b/python/paddle/fluid/transpiler/distributed_splitter.py
deleted file mode 100644
index 060c1df8ad2..00000000000
--- a/python/paddle/fluid/transpiler/distributed_splitter.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-def hash_name(varlist, pserver_endpoints):
-    """
-    hash variable names to several endpoints.
-
-    Args:
-      varlist(list): a list of Variables
-
-    Returns(dict): a map of pserver endpoint -> varname
-    """
-
-    def _hash_block(block_str, total):
-        return hash(block_str) % total
-
-    eplist = []
-    for var in varlist:
-        server_id = _hash_block(var.name(), len(pserver_endpoints))
-        server_for_param = pserver_endpoints[server_id]
-        eplist.append(server_for_param)
-    return eplist
-
-
-def round_robin(varlist, pserver_endpoints):
-    """
-    Distribute variables to several endpoints.
-    Args:
-      varlist(list): a list of variables
-      pserver_endpoints(list): a list of pserver endpoints
-
-    Returns(list[int]): the endpoint for each variable
-    """
-    assert (len(varlist) >= len(pserver_endpoints))
-
-    eplist = []
-    pserver_idx = 0
-    for var in varlist:
-        server_for_param = pserver_endpoints[pserver_idx]
-        eplist.append(server_for_param)
-
-        pserver_idx += 1
-        if pserver_idx >= len(pserver_endpoints):
-            pserver_idx = 0
-    return eplist
diff --git a/python/paddle/fluid/transpiler/ps_dispatcher.py b/python/paddle/fluid/transpiler/ps_dispatcher.py
new file mode 100644
index 00000000000..dffe66998a4
--- /dev/null
+++ b/python/paddle/fluid/transpiler/ps_dispatcher.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class PSDispatcher(object):
+    """
+    PSDispatcher is the base class for dispatching variables
+    onto different pserver instances.
+    Subclasses need to implement the `dispatch` interface.
+    """
+
+    def __init__(self, pserver_endpoints):
+        self._eps = pserver_endpoints
+        self._step = 0
+
+    @property
+    def eps(self):
+        return self._eps
+
+    def reset(self):
+        self._step = 0
+
+    def dispatch(self, varlist):
+        """
+        :param varlist: a list of Variables
+        :return: a list of pserver endpoints, one per variable
+        """
+        raise NotImplementedError("Interface has not been implemented.")
+
+
+class HashName(PSDispatcher):
+    """
+    Hash variable names to several endpoints.
+    """
+
+    def __init__(self, pserver_endpoints):
+        super(self.__class__, self).__init__(pserver_endpoints)
+
+    def _hash_block(self, block_str, total):
+        return hash(block_str) % total
+
+    def dispatch(self, varlist):
+        eplist = []
+        for var in varlist:
+            server_id = self._hash_block(var.name(), len(self._eps))
+            server_for_param = self._eps[server_id]
+            eplist.append(server_for_param)
+        return eplist
+
+
+class RoundRobin(PSDispatcher):
+    """
+    Distribute variables to several endpoints in round-robin order.
+    """
+
+    def __init__(self, pserver_endpoints):
+        super(self.__class__, self).__init__(pserver_endpoints)
+
+    def dispatch(self, varlist):
+        eplist = []
+        for var in varlist:
+            server_for_param = self._eps[self._step]
+            eplist.append(server_for_param)
+            self._step += 1
+            if self._step >= len(self._eps):
+                self._step = 0
+        return eplist
--
GitLab
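For reference (not part of the patch): a minimal usage sketch of the new dispatcher API, assuming a Paddle build that already contains this change. _FakeVar is a hypothetical stand-in for a fluid Variable; RoundRobin only consumes the list order and HashName only calls var.name(), so nothing more is needed.

from paddle.fluid.transpiler import HashName, RoundRobin


class _FakeVar(object):
    """Hypothetical stand-in for a fluid Variable (only name() is used)."""

    def __init__(self, name):
        self._name = name

    def name(self):
        return self._name


eps = ["127.0.0.1:6174", "127.0.0.1:6175"]
grads = [
    _FakeVar("fc_0.w_0@GRAD.block0"),
    _FakeVar("fc_0.w_0@GRAD.block1"),
    _FakeVar("fc_0.b_0@GRAD"),
]

rr = RoundRobin(eps)
print(rr.dispatch(grads))  # endpoints assigned in order, wrapping around
rr.reset()                 # start again from the first endpoint

hn = HashName(eps)
print(hn.dispatch(grads))  # endpoint chosen by hash(var.name()) % len(eps)

Note that DistributeTranspiler.transpile() now takes the dispatcher class itself (for example split_method=HashName) rather than a plain function; the transpiler instantiates it with the pserver endpoint list and calls dispatch()/reset() while inserting the send_vars and recv ops.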