From 27533b64237528e0de0166b45a322d4ab6fee276 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 4 Apr 2018 12:56:32 +0800 Subject: [PATCH] Fix Leaf Ops in Graph All leaves must be variables. When all variables are ready, the execution will be completed. If a operator has no output, the `Op::Run` might not be started when the execution of graph has been complete. --- .../framework/details/multi_devices_graph_builder.cc | 8 ++++++++ paddle/fluid/framework/details/ssa_graph_builder.cc | 11 +++++++++++ paddle/fluid/framework/details/ssa_graph_builder.h | 8 +++++--- .../framework/details/threaded_ssa_graph_executor.cc | 8 ++++++-- 4 files changed, 30 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index c277bd7cb..128a5344f 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -21,6 +21,9 @@ #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #endif +#include +#include + namespace paddle { namespace framework { namespace details { @@ -168,6 +171,11 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( */ PolishGraphToSupportDataHazards(&result); + /* + * Only variables should be the leaves of graph. + */ + AddOutputToLeafOps(&result); + if (VLOG_IS_ON(10)) { std::ostringstream sout; PrintGraphviz(*graph, sout); diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 361ba6d39..0a4febd22 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -136,6 +136,17 @@ void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) { sout << "}\n"; } + +void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { + for (auto &op : graph->ops_) { + if (!op->outputs_.empty()) { + continue; + } + auto *dummy_leaf = new DummyVarHandle(); + graph->dep_vars_.emplace(dummy_leaf); + op->AddOutput(dummy_leaf); + } +} } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index bf20e7164..be1f0460e 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -14,13 +14,13 @@ #pragma once +#include +#include + #include "paddle/fluid/framework/details/ssa_graph.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/place.h" -#include -#include - namespace paddle { namespace framework { namespace details { @@ -52,6 +52,8 @@ class SSAGraphBuilder { const std::string &each_var_name, const platform::Place &place, size_t place_offset); + static void AddOutputToLeafOps(SSAGraph *graph); + static void PrintGraphviz(const SSAGraph &graph, std::ostream &sout); }; } // namespace details diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 1f96b9dc6..596e57318 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -87,7 +87,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // Step 2. Insert FetchOps std::vector> fetch_ops; - std::vector dummy_vars; FeedFetchList fetch_data(fetch_tensors.size()); std::unordered_map> fetched_vars; @@ -101,13 +100,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( } } + std::unordered_set> fetch_dependencies; for (size_t i = 0; i < fetch_tensors.size(); ++i) { auto &var_name = fetch_tensors[i]; auto &vars = fetched_vars.at(var_name); auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_); fetch_ops.emplace_back(op); - // FIXME: Use new device context for (auto &p : places_) { op->dev_ctxes_[p] = fetch_ctxs_.Get(p); } @@ -115,6 +114,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( for (auto *var : vars) { op->AddInput(var); } + + auto *fetch_dummy = new DummyVarHandle(); + op->AddOutput(fetch_dummy); + fetch_dependencies.emplace(fetch_dummy); + InsertPendingVar(*fetch_dummy); InsertPendingOp(*op); } -- GitLab