提交 adf5615e 编写于 作者: X Xin Pan

clean kGraphOp

test=develop
上级 fb576cb5
......@@ -16,6 +16,7 @@
#include <vector>
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
......@@ -32,9 +33,7 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
pool_(strategy.num_threads_ +
1), // add one more thread for generate op_deps
fetch_ctxs_(places) {
auto &ops = graph_->Get<details::GraphOps>("ops");
for (auto &op : ops) {
for (auto &op : ir::GetFilteredNodes<OpHandleBase>(*graph_)) {
int dep = static_cast<int>(op->NotReadyInputSize());
op_deps_.emplace(op, dep);
if (dep == 0) {
......
......@@ -45,7 +45,9 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
insert_pending_var(var);
}
for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
for (ir::Node *node : graph->Nodes()) {
if (!node->IsWrappedBy<OpHandleBase>()) continue;
OpHandleBase *op = &node->Wrapper<OpHandleBase>();
if (op->Inputs().empty()) {
ready_ops.insert(op);
} else {
......@@ -89,5 +91,4 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
REGISTER_PASS(multi_devices_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker)
.RequireGraphAttr(paddle::framework::details::kGraphVars)
.RequireGraphAttr(paddle::framework::details::kGraphDepVars)
.RequireGraphAttr(paddle::framework::details::kGraphOps);
.RequireGraphAttr(paddle::framework::details::kGraphDepVars);
......@@ -36,6 +36,11 @@ namespace framework {
namespace details {
namespace {
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<OpHandleBase *> GraphOps;
const char kGraphOps[] = "ops";
void PolishGraphToSupportDataHazards(ir::Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) {
......@@ -458,7 +463,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
* Only variables should be the leaves of graph.
*/
AddOutputToLeafOps(&result);
PADDLE_ENFORCE(!ir::HasCircle(result));
result.Erase<GraphOps>(kGraphOps);
return graph;
}
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
......@@ -62,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
});
size_t op_id = 0;
for (auto &op : graph.Get<GraphOps>(kGraphOps)) {
for (auto &op : ir::GetFilteredNodes<OpHandleBase>(graph)) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
......
......@@ -43,11 +43,6 @@ const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<VarHandleBase*> GraphDepVars;
const char kGraphDepVars[] = "dep_vars";
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<OpHandleBase*> GraphOps;
const char kGraphOps[] = "ops";
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
......@@ -156,7 +157,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
}
};
auto &all_ops = graph->Get<GraphOps>(kGraphOps);
auto all_ops = ir::GetFilteredNodes<OpHandleBase>(*graph);
for (auto &op : all_ops) {
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
......@@ -59,7 +60,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
InsertPendingVar(&pending_vars, ready_vars.get(), var);
}
for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) {
for (auto &op : ir::GetFilteredNodes<OpHandleBase>(*graph_)) {
if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op);
} else {
......
......@@ -102,6 +102,15 @@ class Graph {
attr_dels_[attr_name] = []() {};
}
template <typename AttrType>
void Erase(const std::string &attr_name) {
PADDLE_ENFORCE(attrs_.count(attr_name) != 0, "%s not set in the graph",
attr_name);
attr_dels_[attr_name]();
attrs_.erase(attr_name);
attr_dels_.erase(attr_name);
}
const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; }
// Create a normal variable with non-null VarDesc.
......
......@@ -37,6 +37,15 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph);
template <typename T>
std::vector<T *> GetFilteredNodes(const Graph &graph) {
std::vector<T *> ret;
for (ir::Node *n : graph.Nodes()) {
if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>());
}
return ret;
}
} // namespace ir
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册