diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index fb5b8608b312ddd78da122dfd6cfba1ea486f91c..52e691a617d0e4ab0e428e9aff2a143d23b87544 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include #include #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" @@ -181,8 +182,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( // always use the first device CreateRPCOp(&result, *op); } else if (IsDistTrainOp(*op, send_vars, recv_vars)) { - // CreateComputationalOps(&result, *op, 1); - CreateComputationalOp(&result, *op, 0); + CreateDistTrainOp(&result, *op); } else if (IsScaleLossOp(*op)) { // user can customize loss@grad if not use_default_grad_scale_ if (strategy_.gradient_scale_ != @@ -247,9 +247,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( AddOutputToLeafOps(&result); if (VLOG_IS_ON(10)) { - std::ostringstream sout; - PrintGraphviz(*graph, sout); - VLOG(10) << sout.str(); + std::ofstream fout("/tmp/graph.dot"); + PrintGraphviz(*graph, fout); } return std::unique_ptr(graph); @@ -443,6 +442,14 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, } } +void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, + const OpDesc &op) const { + CreateComputationalOp(result, op, 0); + if (op.Type() == "concat") { + ConnectOp(result, result->ops_.back().get(), "fetch_barrier"); + } +} + void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op) const { auto &p = places_[0]; diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 1d0021c9542d222eb075c6c843bd3268e89636f5..cef21e4650f7fed869baad6582aca06ee36d68b6 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -65,6 +65,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { bool IsScaleLossOp(const OpDesc &op) const; void CreateRPCOp(SSAGraph *result, const OpDesc &op) const; + void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const; /** * Is this operator as the end-point operator before/after send operator. diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index 7ca3c20c7d2b7c00dd1ac66432f9b79dd666987b..1255ed4c49bbbd8c743d18c4fc1fedd6fc34ae0b 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -38,6 +38,7 @@ class RecvOp : public framework::OperatorBase { auto outs = Outputs("Out"); std::vector epmap = Attr>("epmap"); auto client_var_name = Output("RPCClient"); + int sync_recv = Attr("sync_recv"); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); @@ -54,7 +55,9 @@ class RecvOp : public framework::OperatorBase { VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } - PADDLE_ENFORCE(rpc_client->Wait()); + if (sync_recv) { + PADDLE_ENFORCE(rpc_client->Wait()); + } } }; @@ -75,6 +78,10 @@ This operator can get variables from server side. "Server endpoints in the order of input " "variables for mapping") .SetDefault({}); + AddAttr("sync_recv", + "(int, default 0)" + "sync recv or async recv.") + .SetDefault(0); } }; diff --git a/paddle/fluid/operators/send_vars_op.cc b/paddle/fluid/operators/send_vars_op.cc index 3caceba4e9c68912f05de66fe9139cad1aad6d3c..8d5b5f4292a73407ea55c2811d8020b6e89cd262 100644 --- a/paddle/fluid/operators/send_vars_op.cc +++ b/paddle/fluid/operators/send_vars_op.cc @@ -50,9 +50,7 @@ class SendVarsOp : public framework::OperatorBase { "Can not find variable '%s' in the scope.", client_var_name); auto* client_var = scope.FindVar(client_var_name); - VLOG(3) << "client var addr: " << client_var; detail::RPCClient* rpc_client = client_var->GetMutable(); - VLOG(3) << "rpc_client addr: " << rpc_client; for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index cf7775e8ed50db16237027f7735d1ce55f2db51e..e6a4e64e7f0b89667d962d1df37f8b5a0006d1e7 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -357,12 +357,35 @@ class DistributeTranspiler: ps_dispatcher.reset() eplist = ps_dispatcher.dispatch(recv_vars) - program.global_block().append_op( - type="recv", - inputs={}, - outputs={"Out": recv_vars, - "RPCClient": rpc_client_var}, - attrs={"epmap": eplist}) + #program.global_block().append_op( + # type="recv", + # inputs={}, + # outputs={"Out": recv_vars, + # "RPCClient": rpc_client_var}, + # attrs={"epmap": eplist}) + + #program.global_block().append_op( + # type="fetch_barrier", + # inputs={}, + # outputs={"RPCClient": rpc_client_var}, + # attrs={"endpoints": pserver_endpoints}) + + for i, ep in enumerate(eplist): + self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) + self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) + # step4: Concat the parameters splits together after recv. + for varname, splited_var in param_var_mapping.iteritems(): + eps = [] + for var in splited_var: + index = [v.name for v in recv_vars].index(var.name) + eps.append(eplist[index]) + + program.global_block().append_op( + type="recv", + inputs={}, + outputs={"Out": splited_var, + "RPCClient": rpc_client_var}, + attrs={"epmap": eps}) program.global_block().append_op( type="fetch_barrier", @@ -370,10 +393,6 @@ class DistributeTranspiler: outputs={"RPCClient": rpc_client_var}, attrs={"endpoints": pserver_endpoints}) - for i, ep in enumerate(eplist): - self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) - self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) - # step4: Concat the parameters splits together after recv. for varname, splited_var in param_var_mapping.iteritems(): if len(splited_var) <= 1: continue