diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index 12822c64e9f7ce20ebd9d1ac3c7479396cb7ea2f..5ca676ccdebb360d6e4ecc685766b8d41a43af6e 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -80,33 +80,6 @@ void ProcessGraph(std::vector graphs, Scope *scope) { } } } - /* - VLOG(3) << "delete all recv ops"; - for (auto *node : nodes_to_delete) { - // delete input edge - for (auto *in : node->inputs) { - auto &in_outs = in->outputs; - for (auto iter = in_outs.begin(); iter != in_outs.end();) { - if (*iter == node) { - VLOG(3) << "delete input edge from " << in->Name() << " for " - << node->Name(); - iter = in_outs.erase(iter); - } else { - ++iter; - } - } - } - // delete output edge - for (auto *out : node->outputs) { - PADDLE_ENFORCE_EQ(out->outputs.size(), 0, "%s should have no outputs", - out->Name()); - VLOG(3) << "delete output edge to " << out->Name(); - graphs[i]->RemoveNode(out); - } - VLOG(3) << "delete node " << node->Name(); - graphs[i]->RemoveNode(node); - } - */ } // init communicator here if (send_varname_to_ctx.size() > 0) { diff --git a/paddle/fluid/operators/distributed_ops/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc index 67de7b4185b574412292b98ee6ba182cf118a4e6..47688d0ad456873c93e9e7cdc1e550028347b052 100644 --- a/paddle/fluid/operators/distributed_ops/send_op.cc +++ b/paddle/fluid/operators/distributed_ops/send_op.cc @@ -48,15 +48,14 @@ class SendOp : public framework::OperatorBase { if (send_varnames.size() > 0) { PADDLE_ENFORCE_EQ(ins.size(), 1, ""); - /* - auto send_functor = distributed::ParameterSend(); - auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap, - height_sections); - send_functor(rpc_ctx, scope, static_cast(sync_send)); - */ - VLOG(3) << "send " << ins[0]; - distributed::Communicator::GetInstance()->Send(ins[0], scope); - VLOG(3) << "send " << ins[0] << " done"; + if (distributed::Communicator::GetInstance() == nullptr) { + auto send_functor = distributed::ParameterSend(); + auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap, + height_sections); + send_functor(rpc_ctx, scope, static_cast(sync_send)); + } else { + distributed::Communicator::GetInstance()->Send(ins[0], scope); + } } else { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();