From e9abc66910a9ee613c60c6ccfcba86f3eed8d429 Mon Sep 17 00:00:00 2001 From: Yancey1989 <yancey1989@gmail.com> Date: Tue, 22 May 2018 16:48:40 +0800 Subject: [PATCH] fix pe --- .../details/computation_op_handle.cc | 2 + .../details/multi_devices_graph_builder.cc | 84 +++++++++++++------ .../details/multi_devices_graph_builder.h | 14 +++- paddle/fluid/operators/detail/grpc_client.cc | 6 -- .../fluid/transpiler/distribute_transpiler.py | 10 +++ 5 files changed, 82 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index df05bb0633..f6e1208a01 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -29,7 +29,9 @@ void ComputationOpHandle::RunImpl() { WaitInputVarGenerated(place_); this->RunAndRecordEvent([this] { + VLOG(3) << "begin run op type is " << op_->Type(); op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_); + VLOG(3) << "end run op type is " << op_->Type(); }); } diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 50998fb8e0..fb5b8608b3 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" -#include <fstream> #include <utility> #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" @@ -79,9 +78,39 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, CreateOpOutput(result, op_handle, each_var_name, p, place_id); } } -bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, - OpDesc *send_op) const { - if (send_op == nullptr) { + +std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars( + const ProgramDesc &program) const { + std::vector<std::string> send_vars; + for (auto *op : program.Block(0).AllOps()) { + if (op->Type() == "send_vars" || op->Type() == "send") { + auto op_vars = op->InputArgumentNames(); + send_vars.reserve(send_vars.size() + + std::distance(op_vars.begin(), op_vars.end())); + send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end()); + } + } + return send_vars; +} + +std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars( + const ProgramDesc &program) const { + std::vector<std::string> recv_vars; + for (auto *op : program.Block(0).AllOps()) { + if (op->Type() == "recv" || op->Type() == "send") { + auto op_vars = op->OutputArgumentNames(); + recv_vars.reserve(recv_vars.size() + + std::distance(op_vars.begin(), op_vars.end())); + recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end()); + } + } + return recv_vars; +} + +bool MultiDevSSAGraphBuilder::IsDistTrainOp( + const OpDesc &op, const std::vector<std::string> &send_vars, + const std::vector<std::string> &recv_vars) const { + if (send_vars.size() == 0 || recv_vars.size() == 0) { return false; } @@ -89,21 +118,23 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, * Check any of opvars contains `.block` and in sendvars */ auto checker = [](const std::vector<std::string> &opvars, - const std::vector<std::string> &sendvars) -> bool { + const std::vector<std::string> &rpc_vars) -> bool { for (auto &var : opvars) { if (var.find(".block") != std::string::npos && - std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) { + std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) { return true; } } return false; }; - if (op.Type() == "split" || op.Type() == "split_byref") { - return checker(op.OutputArgumentNames(), send_op->InputArgumentNames()); + if (op.Type() == "split" || op.Type() == "split_byref" || + op.Type() == "split_selected_rows") { + return checker(op.OutputArgumentNames(), send_vars); } else if (op.Type() == "concat") { - return checker(op.InputArgumentNames(), send_op->OutputArgumentNames()); + return checker(op.InputArgumentNames(), recv_vars); } + return false; } @@ -132,8 +163,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>( places_.size()); - // Find "send" op first for split is in front of send. - OpDesc *send_op = GetSendOpDesc(program); + // find send/recv vars so that we can place the distributed training + // realted op in the place 0 + auto send_vars = FindDistTrainSendVars(program); + auto recv_vars = FindDistTrainRecvVars(program); size_t cur_device_id = 0; std::vector<std::unordered_set<std::string>> var_name_on_devices; @@ -147,8 +180,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( // append rpc op if program is distributed trainer main program. // always use the first device CreateRPCOp(&result, *op); - } else if (IsDistTrainOp(*op, send_op)) { - CreateComputationalOps(&result, *op, 1); + } else if (IsDistTrainOp(*op, send_vars, recv_vars)) { + // CreateComputationalOps(&result, *op, 1); + CreateComputationalOp(&result, *op, 0); } else if (IsScaleLossOp(*op)) { // user can customize loss@grad if not use_default_grad_scale_ if (strategy_.gradient_scale_ != @@ -213,9 +247,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( AddOutputToLeafOps(&result); if (VLOG_IS_ON(10)) { - std::string filename = "/tmp/graph"; - std::ofstream fout(filename); - PrintGraphviz(*graph, fout); + std::ostringstream sout; + PrintGraphviz(*graph, sout); + VLOG(10) << sout.str(); } return std::unique_ptr<SSAGraph>(graph); @@ -274,6 +308,7 @@ OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc( } return nullptr; } + void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( SSAGraph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA @@ -396,14 +431,14 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, return var; } -void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, - std::string op_name) const { +void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, + const std::string &prev_op_name) const { for (auto &prev_op : result->ops_) { - if (prev_op->Name() == op_name) { + if (prev_op->Name() == prev_op_name) { auto *dep_var = new DummyVarHandle(); prev_op->AddOutput(dep_var); result->dep_vars_.emplace(dep_var); - result->ops_.back().get()->AddInput(dep_var); + op->AddInput(dep_var); } } } @@ -412,14 +447,14 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op) const { auto &p = places_[0]; auto *s = local_scopes_[0]; - VLOG(3) << "create rpc op: " << op.Type(); result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type())); + if (op.Type() == "send_barrier") { - ConnectOp(result, "send_vars"); + ConnectOp(result, result->ops_.back().get(), "send_vars"); } else if (op.Type() == "recv") { - ConnectOp(result, "send_barrier"); + ConnectOp(result, result->ops_.back().get(), "send_barrier"); } else if (op.Type() == "fetch_barrier") { - ConnectOp(result, "recv"); + ConnectOp(result, result->ops_.back().get(), "recv"); } else if (op.Type() == "send" || op.Type() == "send_vars") { // do nothing } else { @@ -429,7 +464,6 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, } // FIXME(wuyi): send op always copy from GPU 0 - // result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type())); // Create inputs for output on original place and no ssa output // is created for send op. CreateOpHandleIOs(result, op, 0); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 45713b0c4f..1d0021c954 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -64,17 +64,25 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { bool IsScaleLossOp(const OpDesc &op) const; - void CreateSendOp(SSAGraph *result, const OpDesc &op) const; void CreateRPCOp(SSAGraph *result, const OpDesc &op) const; /** * Is this operator as the end-point operator before/after send operator. */ - bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const; + bool IsDistTrainOp(const OpDesc &op, + const std::vector<std::string> &send_vars, + const std::vector<std::string> &recv_vars) const; + + std::vector<std::string> FindDistTrainSendVars( + const ProgramDesc &program) const; + + std::vector<std::string> FindDistTrainRecvVars( + const ProgramDesc &program) const; bool IsRPCOp(const OpDesc &op) const; - void ConnectOp(SSAGraph *result, std::string op_name) const; + void ConnectOp(SSAGraph *result, OpHandleBase *op, + const std::string &prev_op_name) const; void CreateComputationalOps(SSAGraph *result, const OpDesc &op, size_t num_places) const; diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index ca0518d4dc..a758205938 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -245,17 +245,11 @@ bool RPCClient::Proceed() { } std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep, const std::string& key) { - VLOG(3) << "this addr: " << this; std::unique_lock<std::mutex> lock(mutex_); auto it = channels_.find(key); if (it != channels_.end()) { - VLOG(3) << "find ep: " << ep; return it->second; } - VLOG(3) << "can not find ep: " << ep; - for (auto it = channels_.begin(); it != channels_.end(); ++it) { - VLOG(3) << "ep: " << it->first; - } grpc::ChannelArguments args; args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE); diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 806cc2fcc1..cf7775e8ed 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -373,6 +373,16 @@ class DistributeTranspiler: for i, ep in enumerate(eplist): self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) + # step4: Concat the parameters splits together after recv. + for varname, splited_var in param_var_mapping.iteritems(): + if len(splited_var) <= 1: + continue + orig_param = program.global_block().vars[varname] + program.global_block().append_op( + type="concat", + inputs={"X": splited_var}, + outputs={"Out": [orig_param]}, + attrs={"axis": 0}) # TODO(Yancey1989): check dist lookup table if self.has_distributed_lookup_table: -- GitLab