diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index c277bd7cb69bba899296efe64107ee538c4aa847..128a5344fbb8c64c36ade24475bd0d99bdb3e0f5 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -21,6 +21,9 @@ #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #endif +#include +#include + namespace paddle { namespace framework { namespace details { @@ -168,6 +171,11 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( */ PolishGraphToSupportDataHazards(&result); + /* + * Only variables should be the leaves of graph. + */ + AddOutputToLeafOps(&result); + if (VLOG_IS_ON(10)) { std::ostringstream sout; PrintGraphviz(*graph, sout); diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 361ba6d39721eed406a30fea325b3b4508ec45d0..0a4febd22f3feefdcac99cafc2cb58269380d192 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -136,6 +136,17 @@ void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) { sout << "}\n"; } + +void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { /* For every op in graph->ops_ that has no outputs, allocates a DummyVarHandle, inserts it into graph->dep_vars_, and attaches it as the output of that op. Ops that already have at least one output are left untouched. */ + for (auto &op : graph->ops_) { + if (!op->outputs_.empty()) { /* op already has an output, so it is not a leaf op */ + continue; + } + auto *dummy_leaf = new DummyVarHandle(); + graph->dep_vars_.emplace(dummy_leaf); /* NOTE(review): presumably dep_vars_ takes ownership of this raw pointer -- confirm against the SSAGraph declaration */ + op->AddOutput(dummy_leaf); + } +} } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index bf20e7164a100718c1dcfe3ef971cfff60bbbaa2..be1f0460e45402806b18835f054a7195df1374cc 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -14,13 +14,13 @@ #pragma once +#include +#include + 
#include "paddle/fluid/framework/details/ssa_graph.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/place.h" -#include -#include - namespace paddle { namespace framework { namespace details { @@ -52,6 +52,8 @@ class SSAGraphBuilder { const std::string &each_var_name, const platform::Place &place, size_t place_offset); + static void AddOutputToLeafOps(SSAGraph *graph); + static void PrintGraphviz(const SSAGraph &graph, std::ostream &sout); }; } // namespace details diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 1f96b9dc6235a18f7566c98cca60baa964e6aa56..596e5731868630cebc3cf51b2e78d4deb39a9b33 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -87,7 +87,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // Step 2. Insert FetchOps std::vector> fetch_ops; - std::vector dummy_vars; FeedFetchList fetch_data(fetch_tensors.size()); std::unordered_map> fetched_vars; @@ -101,13 +100,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( } } + std::unordered_set> fetch_dependencies; for (size_t i = 0; i < fetch_tensors.size(); ++i) { auto &var_name = fetch_tensors[i]; auto &vars = fetched_vars.at(var_name); auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_); fetch_ops.emplace_back(op); - // FIXME: Use new device context for (auto &p : places_) { op->dev_ctxes_[p] = fetch_ctxs_.Get(p); } @@ -115,6 +114,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( for (auto *var : vars) { op->AddInput(var); } + + auto *fetch_dummy = new DummyVarHandle(); + op->AddOutput(fetch_dummy); + fetch_dependencies.emplace(fetch_dummy); + InsertPendingVar(*fetch_dummy); InsertPendingOp(*op); }