diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index b36ed8af9ad151d379a175f6d2046f2f98f0259f..12822c64e9f7ce20ebd9d1ac3c7479396cb7ea2f 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -80,6 +80,7 @@ void ProcessGraph(std::vector graphs, Scope *scope) { } } } + /* VLOG(3) << "delete all recv ops"; for (auto *node : nodes_to_delete) { // delete input edge @@ -105,6 +106,7 @@ void ProcessGraph(std::vector graphs, Scope *scope) { VLOG(3) << "delete node " << node->Name(); graphs[i]->RemoveNode(node); } + */ } // init communicator here if (send_varname_to_ctx.size() > 0) { diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index f7ec9d28de91e5f3ffd4ed3268c1640c0e8991e6..0b9061ad603d5e8656ab750579e700675a56b578 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -127,6 +127,10 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { bool NeedCollectiveOps() const override { return false; } bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override { + if (node->Op()->Type() == "recv") { + node->Op()->SetAttr("do_not_run", true); + node->Op()->Flush(); + } return false; } diff --git a/paddle/fluid/operators/distributed_ops/recv_op.cc b/paddle/fluid/operators/distributed_ops/recv_op.cc index 680b484d413207595369dc665a1c64f115e4ae4b..afbf7a4a234a8c84912142b748907dc011623ebd 100644 --- a/paddle/fluid/operators/distributed_ops/recv_op.cc +++ b/paddle/fluid/operators/distributed_ops/recv_op.cc @@ -36,6 +36,11 @@ class RecvOp : public framework::OperatorBase { void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { + bool do_not_run = Attr("do_not_run"); + if (do_not_run) { + VLOG(3) << "recv do not run!"; + return; + } std::vector epmap = Attr>("epmap"); std::vector varnames = Attr>("varnames"); @@ -126,6 +131,7 @@ This operator can get variables from server side. "(vector) " "the splited parameter varnames to be recved from pserver") .SetDefault(std::vector{}); + AddAttr("do_not_run", "").SetDefault(false); } };