未验证 提交 bc8f4360 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #9634 from reyoung/feature/fix_leaf_ops

Fix Leaf Ops in Graph
......@@ -21,6 +21,9 @@
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#endif
#include <string>
#include <vector>
namespace paddle {
namespace framework {
namespace details {
......@@ -168,6 +171,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
*/
PolishGraphToSupportDataHazards(&result);
/*
* Only variables should be the leaves of graph.
*/
AddOutputToLeafOps(&result);
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
PrintGraphviz(*graph, sout);
......
......@@ -136,6 +136,17 @@ void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) {
sout << "}\n";
}
void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
for (auto &op : graph->ops_) {
if (!op->outputs_.empty()) {
continue;
}
auto *dummy_leaf = new DummyVarHandle();
graph->dep_vars_.emplace(dummy_leaf);
op->AddOutput(dummy_leaf);
}
}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -14,13 +14,13 @@
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/place.h"
#include <memory>
#include <string>
namespace paddle {
namespace framework {
namespace details {
......@@ -52,6 +52,8 @@ class SSAGraphBuilder {
const std::string &each_var_name,
const platform::Place &place, size_t place_offset);
static void AddOutputToLeafOps(SSAGraph *graph);
static void PrintGraphviz(const SSAGraph &graph, std::ostream &sout);
};
} // namespace details
......
......@@ -87,7 +87,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
std::vector<DummyVarHandle> dummy_vars;
FeedFetchList fetch_data(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
......@@ -101,13 +100,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}
}
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i];
auto &vars = fetched_vars.at(var_name);
auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_);
fetch_ops.emplace_back(op);
// FIXME: Use new device context
for (auto &p : places_) {
op->dev_ctxes_[p] = fetch_ctxs_.Get(p);
}
......@@ -115,6 +114,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto *var : vars) {
op->AddInput(var);
}
auto *fetch_dummy = new DummyVarHandle();
op->AddOutput(fetch_dummy);
fetch_dependencies.emplace(fetch_dummy);
InsertPendingVar(*fetch_dummy);
InsertPendingOp(*op);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册