提交 0bd7a67e 编写于 作者: X Xin Pan

avoid creating dangling ir::Node.

    Node should be created by Graph::CreateXXX so that
    they are managed by graph.
上级 d9297c12
......@@ -96,8 +96,8 @@ struct TestBroadcastOpHandle {
}
param_scopes_[input_scope_idx]->Var("input");
std::unique_ptr<ir::Node> n(
new ir::Node("node0", ir::Node::Type::kOperation));
std::unique_ptr<ir::Node> n =
ir::CreateNodeForTest("node0", ir::Node::Type::kOperation);
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
......@@ -115,8 +115,8 @@ struct TestBroadcastOpHandle {
#endif
}
std::unique_ptr<ir::Node> v(
new ir::Node("node1", ir::Node::Type::kVariable));
std::unique_ptr<ir::Node> v =
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable);
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
gpu_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle);
......@@ -124,8 +124,8 @@ struct TestBroadcastOpHandle {
// add dummy var
std::unique_ptr<ir::Node> v2(
new ir::Node("node2", ir::Node::Type::kVariable));
std::unique_ptr<ir::Node> v2 =
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable);
vars_.emplace_back(new DummyVarHandle(v2.get()));
DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
......@@ -136,8 +136,8 @@ struct TestBroadcastOpHandle {
if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
}
std::unique_ptr<ir::Node> v3(
new ir::Node("node3", ir::Node::Type::kVariable));
std::unique_ptr<ir::Node> v3 =
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable);
VarHandle* out_var_handle =
new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
vars_.emplace_back(out_var_handle);
......@@ -145,8 +145,8 @@ struct TestBroadcastOpHandle {
}
// add dummy var
std::unique_ptr<ir::Node> v4(
new ir::Node("node4", ir::Node::Type::kVariable));
std::unique_ptr<ir::Node> v4 =
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable);
vars_.emplace_back(new DummyVarHandle(v4.get()));
DummyVarHandle* out_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
......
......@@ -54,7 +54,6 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
paddle::framework::FeedFetchList fetches;
fetches.resize(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<std::unique_ptr<ir::Node>> fetch_nodes;
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
for (auto &fetch_var_name : fetch_tensors) {
......@@ -75,9 +74,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
auto &vars = fetched_var_it->second;
fetch_nodes.emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
auto *op = new FetchOpHandle(fetch_nodes.back().get(), &fetches, i,
&local_scopes_);
ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, &fetches, i, &local_scopes_);
fetch_ops.emplace_back(op);
for (auto &p : places_) {
......@@ -116,9 +115,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
num_complete += num_comp;
}
// Wait FetchOps.
if (!fetch_ops.empty()) {
fetch_ops.clear();
}
ClearFetchOp(graph_.get(), &fetch_ops);
return fetches;
}
void FastThreadedSSAGraphExecutor::RunOpAsync(
......
......@@ -82,13 +82,15 @@ struct TestGatherOpHandle {
}
param_scopes_[input_scope_idx]->Var("out");
nodes.emplace_back(new ir::Node("node", ir::Node::Type::kOperation));
nodes.emplace_back(
ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release());
op_handle_.reset(
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
// add input
for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
nodes.emplace_back(new ir::Node("node1", ir::Node::Type::kVariable));
nodes.emplace_back(
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable).release());
auto* in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
vars_.emplace_back(in_var_handle);
......@@ -96,7 +98,8 @@ struct TestGatherOpHandle {
}
// add dummy var
nodes.emplace_back(new ir::Node("node2", ir::Node::Type::kVariable));
nodes.emplace_back(
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release());
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* in_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
......@@ -104,14 +107,16 @@ struct TestGatherOpHandle {
op_handle_->AddInput(in_dummy_var_handle);
// add output
nodes.emplace_back(new ir::Node("node3", ir::Node::Type::kVariable));
nodes.emplace_back(
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable).release());
auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
"out", gpu_list_[input_scope_idx]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
// add dummy var
nodes.emplace_back(new ir::Node("node4", ir::Node::Type::kVariable));
nodes.emplace_back(
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release());
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
......
......@@ -19,6 +19,19 @@ namespace framework {
namespace details {
SSAGraphExecutor::~SSAGraphExecutor() {}
void ClearFetchOp(ir::Graph* graph,
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops) {
if (fetch_ops->empty()) return;
for (auto& op : *fetch_ops) {
for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var);
}
graph->RemoveNode(op->Node());
}
fetch_ops->clear();
}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ir/graph.h"
......@@ -36,6 +37,9 @@ class SSAGraphExecutor {
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
};
void ClearFetchOp(ir::Graph* graph,
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops);
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -69,12 +69,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
std::vector<std::unique_ptr<ir::Node>> tmp_nodes;
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &tmp_nodes, &fetch_dependencies,
&pending_ops, &pending_vars, &ready_vars, &fetch_data);
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
&pending_vars, &ready_vars, &fetch_data);
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) {
......@@ -136,9 +135,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
PADDLE_ENFORCE(ready_ops.empty());
// Wait FetchOps.
if (!fetch_ops.empty()) {
fetch_ops.clear();
}
ClearFetchOp(graph_.get(), &fetch_ops);
return fetch_data;
}
......@@ -146,7 +143,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars,
......@@ -171,9 +167,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
auto &vars = fetched_var_it->second;
temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i,
&local_scopes_);
ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, fetch_data, i, &local_scopes_);
fetch_ops->emplace_back(op);
for (auto &p : places_) {
......@@ -184,8 +180,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
op->AddInput(var);
}
temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get());
ir::Node *fetch_var =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kVariable);
auto *fetch_dummy = new DummyVarHandle(fetch_var);
op->AddOutput(fetch_dummy);
fetch_dependencies->emplace(fetch_dummy);
this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
......
......@@ -73,7 +73,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
void InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars,
......
......@@ -19,6 +19,11 @@ namespace framework {
namespace ir {
constexpr char Node::kControlDepVarName[];
int Node::count_ = 0;
std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
Node::Type type) {
return std::unique_ptr<Node>(new Node(name, type));
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -24,32 +24,12 @@ namespace paddle {
namespace framework {
namespace ir {
// Node should normally created by Graph::CreateXXXNode().
class Node {
public:
enum class Type { kOperation, kVariable };
static constexpr char kControlDepVarName[] = "__control_var";
explicit Node(const std::string& name, Type type)
: name_(name),
var_desc_(nullptr),
op_desc_(nullptr),
type_(type),
id_(count_++) {}
explicit Node(VarDesc* var_desc)
: name_(var_desc->Name()),
var_desc_(new VarDesc(*var_desc)),
op_desc_(nullptr),
type_(Type::kVariable),
id_(count_++) {}
explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()),
var_desc_(nullptr),
op_desc_(new OpDesc(*op_desc, op_desc->Block())),
type_(Type::kOperation),
id_(count_++) {}
Type NodeType() const { return type_; }
std::string Name() const { return name_; }
......@@ -81,11 +61,40 @@ class Node {
private:
friend class Graph;
friend std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
Node::Type type);
explicit Node(const std::string& name, Type type)
: name_(name),
var_desc_(nullptr),
op_desc_(nullptr),
type_(type),
id_(count_++) {}
explicit Node(VarDesc* var_desc)
: name_(var_desc->Name()),
var_desc_(new VarDesc(*var_desc)),
op_desc_(nullptr),
type_(Type::kVariable),
id_(count_++) {}
explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()),
var_desc_(nullptr),
op_desc_(new OpDesc(*op_desc, op_desc->Block())),
type_(Type::kOperation),
id_(count_++) {}
Node() = delete;
static int count_;
static void ResetId() { count_ = 0; }
DISABLE_COPY_AND_ASSIGN(Node);
};
std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
Node::Type type);
} // namespace ir
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册