From ff8054c5a7f4ea34f6f112c318c03a16adf37e64 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 8 Mar 2019 10:23:54 +0800 Subject: [PATCH] can run --- paddle/fluid/framework/details/async_ssa_graph_executor.cc | 2 ++ paddle/fluid/framework/details/multi_devices_graph_pass.h | 4 ++++ paddle/fluid/operators/distributed_ops/recv_op.cc | 6 ++++++ 3 files changed, 12 insertions(+) diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index b36ed8af9ad..12822c64e9f 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -80,6 +80,7 @@ void ProcessGraph(std::vector graphs, Scope *scope) { } } } + /* VLOG(3) << "delete all recv ops"; for (auto *node : nodes_to_delete) { // delete input edge @@ -105,6 +106,7 @@ void ProcessGraph(std::vector graphs, Scope *scope) { VLOG(3) << "delete node " << node->Name(); graphs[i]->RemoveNode(node); } + */ } // init communicator here if (send_varname_to_ctx.size() > 0) { diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index f7ec9d28de9..0b9061ad603 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -127,6 +127,10 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { bool NeedCollectiveOps() const override { return false; } bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override { + if (node->Op()->Type() == "recv") { + node->Op()->SetAttr("do_not_run", true); + node->Op()->Flush(); + } return false; } diff --git a/paddle/fluid/operators/distributed_ops/recv_op.cc b/paddle/fluid/operators/distributed_ops/recv_op.cc index 680b484d413..afbf7a4a234 100644 --- a/paddle/fluid/operators/distributed_ops/recv_op.cc +++ b/paddle/fluid/operators/distributed_ops/recv_op.cc @@ -36,6 +36,11 @@ class RecvOp : public framework::OperatorBase { void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { + bool do_not_run = Attr("do_not_run"); + if (do_not_run) { + VLOG(3) << "recv do not run!"; + return; + } std::vector epmap = Attr>("epmap"); std::vector varnames = Attr>("varnames"); @@ -126,6 +131,7 @@ This operator can get variables from server side. "(vector) " "the splited parameter varnames to be recved from pserver") .SetDefault(std::vector{}); + AddAttr("do_not_run", "").SetDefault(false); } }; -- GitLab