未验证 提交 bc8f4360 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #9634 from reyoung/feature/fix_leaf_ops

Fix Leaf Ops in Graph
...@@ -21,6 +21,9 @@ ...@@ -21,6 +21,9 @@
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#endif #endif
#include <string>
#include <vector>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -168,6 +171,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -168,6 +171,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
*/ */
PolishGraphToSupportDataHazards(&result); PolishGraphToSupportDataHazards(&result);
/*
* Only variables should be the leaves of graph.
*/
AddOutputToLeafOps(&result);
if (VLOG_IS_ON(10)) { if (VLOG_IS_ON(10)) {
std::ostringstream sout; std::ostringstream sout;
PrintGraphviz(*graph, sout); PrintGraphviz(*graph, sout);
......
...@@ -136,6 +136,17 @@ void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) { ...@@ -136,6 +136,17 @@ void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) {
sout << "}\n"; sout << "}\n";
} }
void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
for (auto &op : graph->ops_) {
if (!op->outputs_.empty()) {
continue;
}
auto *dummy_leaf = new DummyVarHandle();
graph->dep_vars_.emplace(dummy_leaf);
op->AddOutput(dummy_leaf);
}
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
#pragma once #pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/details/ssa_graph.h" #include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include <memory>
#include <string>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -52,6 +52,8 @@ class SSAGraphBuilder { ...@@ -52,6 +52,8 @@ class SSAGraphBuilder {
const std::string &each_var_name, const std::string &each_var_name,
const platform::Place &place, size_t place_offset); const platform::Place &place, size_t place_offset);
static void AddOutputToLeafOps(SSAGraph *graph);
static void PrintGraphviz(const SSAGraph &graph, std::ostream &sout); static void PrintGraphviz(const SSAGraph &graph, std::ostream &sout);
}; };
} // namespace details } // namespace details
......
...@@ -87,7 +87,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -87,7 +87,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Step 2. Insert FetchOps // Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops; std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
std::vector<DummyVarHandle> dummy_vars;
FeedFetchList fetch_data(fetch_tensors.size()); FeedFetchList fetch_data(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
...@@ -101,13 +100,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -101,13 +100,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
} }
} }
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
for (size_t i = 0; i < fetch_tensors.size(); ++i) { for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i]; auto &var_name = fetch_tensors[i];
auto &vars = fetched_vars.at(var_name); auto &vars = fetched_vars.at(var_name);
auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_); auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_);
fetch_ops.emplace_back(op); fetch_ops.emplace_back(op);
// FIXME: Use new device context
for (auto &p : places_) { for (auto &p : places_) {
op->dev_ctxes_[p] = fetch_ctxs_.Get(p); op->dev_ctxes_[p] = fetch_ctxs_.Get(p);
} }
...@@ -115,6 +114,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -115,6 +114,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto *var : vars) { for (auto *var : vars) {
op->AddInput(var); op->AddInput(var);
} }
auto *fetch_dummy = new DummyVarHandle();
op->AddOutput(fetch_dummy);
fetch_dependencies.emplace(fetch_dummy);
InsertPendingVar(*fetch_dummy);
InsertPendingOp(*op); InsertPendingOp(*op);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册