未验证 提交 23ba7662 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #13475 from panyx0718/ir5

avoid creating dangling ir::Node.
...@@ -96,8 +96,8 @@ struct TestBroadcastOpHandle { ...@@ -96,8 +96,8 @@ struct TestBroadcastOpHandle {
} }
param_scopes_[input_scope_idx]->Var("input"); param_scopes_[input_scope_idx]->Var("input");
std::unique_ptr<ir::Node> n( std::unique_ptr<ir::Node> n =
new ir::Node("node0", ir::Node::Type::kOperation)); ir::CreateNodeForTest("node0", ir::Node::Type::kOperation);
if (use_gpu_) { if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
...@@ -115,8 +115,8 @@ struct TestBroadcastOpHandle { ...@@ -115,8 +115,8 @@ struct TestBroadcastOpHandle {
#endif #endif
} }
std::unique_ptr<ir::Node> v( std::unique_ptr<ir::Node> v =
new ir::Node("node1", ir::Node::Type::kVariable)); ir::CreateNodeForTest("node1", ir::Node::Type::kVariable);
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input", auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
gpu_list_[input_scope_idx]); gpu_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
...@@ -124,8 +124,8 @@ struct TestBroadcastOpHandle { ...@@ -124,8 +124,8 @@ struct TestBroadcastOpHandle {
// add dummy var // add dummy var
std::unique_ptr<ir::Node> v2( std::unique_ptr<ir::Node> v2 =
new ir::Node("node2", ir::Node::Type::kVariable)); ir::CreateNodeForTest("node2", ir::Node::Type::kVariable);
vars_.emplace_back(new DummyVarHandle(v2.get())); vars_.emplace_back(new DummyVarHandle(v2.get()));
DummyVarHandle* dummy_var_handle = DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
...@@ -136,8 +136,8 @@ struct TestBroadcastOpHandle { ...@@ -136,8 +136,8 @@ struct TestBroadcastOpHandle {
if (!use_gpu_) { if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
} }
std::unique_ptr<ir::Node> v3( std::unique_ptr<ir::Node> v3 =
new ir::Node("node3", ir::Node::Type::kVariable)); ir::CreateNodeForTest("node3", ir::Node::Type::kVariable);
VarHandle* out_var_handle = VarHandle* out_var_handle =
new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]); new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
...@@ -145,8 +145,8 @@ struct TestBroadcastOpHandle { ...@@ -145,8 +145,8 @@ struct TestBroadcastOpHandle {
} }
// add dummy var // add dummy var
std::unique_ptr<ir::Node> v4( std::unique_ptr<ir::Node> v4 =
new ir::Node("node4", ir::Node::Type::kVariable)); ir::CreateNodeForTest("node4", ir::Node::Type::kVariable);
vars_.emplace_back(new DummyVarHandle(v4.get())); vars_.emplace_back(new DummyVarHandle(v4.get()));
DummyVarHandle* out_dummy_var_handle = DummyVarHandle* out_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
......
...@@ -54,7 +54,6 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -54,7 +54,6 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
paddle::framework::FeedFetchList fetches; paddle::framework::FeedFetchList fetches;
fetches.resize(fetch_tensors.size()); fetches.resize(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<std::unique_ptr<ir::Node>> fetch_nodes;
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops; std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
for (auto &fetch_var_name : fetch_tensors) { for (auto &fetch_var_name : fetch_tensors) {
...@@ -75,9 +74,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -75,9 +74,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
auto &vars = fetched_var_it->second; auto &vars = fetched_var_it->second;
fetch_nodes.emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); ir::Node *fetch_node =
auto *op = new FetchOpHandle(fetch_nodes.back().get(), &fetches, i, graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
&local_scopes_); auto *op = new FetchOpHandle(fetch_node, &fetches, i, &local_scopes_);
fetch_ops.emplace_back(op); fetch_ops.emplace_back(op);
for (auto &p : places_) { for (auto &p : places_) {
...@@ -116,9 +115,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -116,9 +115,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
num_complete += num_comp; num_complete += num_comp;
} }
// Wait FetchOps. // Wait FetchOps.
if (!fetch_ops.empty()) { ClearFetchOp(graph_.get(), &fetch_ops);
fetch_ops.clear();
}
return fetches; return fetches;
} }
void FastThreadedSSAGraphExecutor::RunOpAsync( void FastThreadedSSAGraphExecutor::RunOpAsync(
......
...@@ -82,13 +82,15 @@ struct TestGatherOpHandle { ...@@ -82,13 +82,15 @@ struct TestGatherOpHandle {
} }
param_scopes_[input_scope_idx]->Var("out"); param_scopes_[input_scope_idx]->Var("out");
nodes.emplace_back(new ir::Node("node", ir::Node::Type::kOperation)); nodes.emplace_back(
ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release());
op_handle_.reset( op_handle_.reset(
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
// add input // add input
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
nodes.emplace_back(new ir::Node("node1", ir::Node::Type::kVariable)); nodes.emplace_back(
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable).release());
auto* in_var_handle = auto* in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
...@@ -96,7 +98,8 @@ struct TestGatherOpHandle { ...@@ -96,7 +98,8 @@ struct TestGatherOpHandle {
} }
// add dummy var // add dummy var
nodes.emplace_back(new ir::Node("node2", ir::Node::Type::kVariable)); nodes.emplace_back(
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release());
vars_.emplace_back(new DummyVarHandle(nodes.back().get())); vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* in_dummy_var_handle = DummyVarHandle* in_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
...@@ -104,14 +107,16 @@ struct TestGatherOpHandle { ...@@ -104,14 +107,16 @@ struct TestGatherOpHandle {
op_handle_->AddInput(in_dummy_var_handle); op_handle_->AddInput(in_dummy_var_handle);
// add output // add output
nodes.emplace_back(new ir::Node("node3", ir::Node::Type::kVariable)); nodes.emplace_back(
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable).release());
auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx, auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
"out", gpu_list_[input_scope_idx]); "out", gpu_list_[input_scope_idx]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
// add dummy var // add dummy var
nodes.emplace_back(new ir::Node("node4", ir::Node::Type::kVariable)); nodes.emplace_back(
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release());
vars_.emplace_back(new DummyVarHandle(nodes.back().get())); vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* dummy_var_handle = DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
......
...@@ -19,6 +19,19 @@ namespace framework { ...@@ -19,6 +19,19 @@ namespace framework {
namespace details { namespace details {
SSAGraphExecutor::~SSAGraphExecutor() {} SSAGraphExecutor::~SSAGraphExecutor() {}
void ClearFetchOp(ir::Graph* graph,
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops) {
if (fetch_ops->empty()) return;
for (auto& op : *fetch_ops) {
for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var);
}
graph->RemoveNode(op->Node());
}
fetch_ops->clear();
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
...@@ -36,6 +37,9 @@ class SSAGraphExecutor { ...@@ -36,6 +37,9 @@ class SSAGraphExecutor {
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0; virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
}; };
void ClearFetchOp(ir::Graph* graph,
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops);
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -69,12 +69,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -69,12 +69,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Step 2. Insert FetchOps // Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops; std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
std::vector<std::unique_ptr<ir::Node>> tmp_nodes;
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies; std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size()); FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &tmp_nodes, &fetch_dependencies, InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
&pending_ops, &pending_vars, &ready_vars, &fetch_data); &pending_vars, &ready_vars, &fetch_data);
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) { auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) { for (auto *op : set) {
...@@ -136,9 +135,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -136,9 +135,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
PADDLE_ENFORCE(ready_ops.empty()); PADDLE_ENFORCE(ready_ops.empty());
// Wait FetchOps. // Wait FetchOps.
if (!fetch_ops.empty()) { ClearFetchOp(graph_.get(), &fetch_ops);
fetch_ops.clear();
}
return fetch_data; return fetch_data;
} }
...@@ -146,7 +143,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -146,7 +143,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps( void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies, std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars, std::unordered_set<VarHandleBase *> *pending_vars,
...@@ -171,9 +167,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -171,9 +167,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
auto &vars = fetched_var_it->second; auto &vars = fetched_var_it->second;
temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); ir::Node *fetch_node =
auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i, graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
&local_scopes_); auto *op = new FetchOpHandle(fetch_node, fetch_data, i, &local_scopes_);
fetch_ops->emplace_back(op); fetch_ops->emplace_back(op);
for (auto &p : places_) { for (auto &p : places_) {
...@@ -184,8 +180,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -184,8 +180,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
op->AddInput(var); op->AddInput(var);
} }
temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); ir::Node *fetch_var =
auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get()); graph_->CreateEmptyNode("fetch", ir::Node::Type::kVariable);
auto *fetch_dummy = new DummyVarHandle(fetch_var);
op->AddOutput(fetch_dummy); op->AddOutput(fetch_dummy);
fetch_dependencies->emplace(fetch_dummy); fetch_dependencies->emplace(fetch_dummy);
this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy); this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
......
...@@ -73,7 +73,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -73,7 +73,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
void InsertFetchOps( void InsertFetchOps(
const std::vector<std::string> &fetch_tensors, const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies, std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars, std::unordered_set<VarHandleBase *> *pending_vars,
......
...@@ -19,6 +19,11 @@ namespace framework { ...@@ -19,6 +19,11 @@ namespace framework {
namespace ir { namespace ir {
constexpr char Node::kControlDepVarName[]; constexpr char Node::kControlDepVarName[];
int Node::count_ = 0; int Node::count_ = 0;
std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
Node::Type type) {
return std::unique_ptr<Node>(new Node(name, type));
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -24,32 +24,12 @@ namespace paddle { ...@@ -24,32 +24,12 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// Node should normally created by Graph::CreateXXXNode().
class Node { class Node {
public: public:
enum class Type { kOperation, kVariable }; enum class Type { kOperation, kVariable };
static constexpr char kControlDepVarName[] = "__control_var"; static constexpr char kControlDepVarName[] = "__control_var";
explicit Node(const std::string& name, Type type)
: name_(name),
var_desc_(nullptr),
op_desc_(nullptr),
type_(type),
id_(count_++) {}
explicit Node(VarDesc* var_desc)
: name_(var_desc->Name()),
var_desc_(new VarDesc(*var_desc)),
op_desc_(nullptr),
type_(Type::kVariable),
id_(count_++) {}
explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()),
var_desc_(nullptr),
op_desc_(new OpDesc(*op_desc, op_desc->Block())),
type_(Type::kOperation),
id_(count_++) {}
Type NodeType() const { return type_; } Type NodeType() const { return type_; }
std::string Name() const { return name_; } std::string Name() const { return name_; }
...@@ -81,11 +61,40 @@ class Node { ...@@ -81,11 +61,40 @@ class Node {
private: private:
friend class Graph; friend class Graph;
friend std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
Node::Type type);
explicit Node(const std::string& name, Type type)
: name_(name),
var_desc_(nullptr),
op_desc_(nullptr),
type_(type),
id_(count_++) {}
explicit Node(VarDesc* var_desc)
: name_(var_desc->Name()),
var_desc_(new VarDesc(*var_desc)),
op_desc_(nullptr),
type_(Type::kVariable),
id_(count_++) {}
explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()),
var_desc_(nullptr),
op_desc_(new OpDesc(*op_desc, op_desc->Block())),
type_(Type::kOperation),
id_(count_++) {}
Node() = delete;
static int count_; static int count_;
static void ResetId() { count_ = 0; } static void ResetId() { count_ = 0; }
DISABLE_COPY_AND_ASSIGN(Node); DISABLE_COPY_AND_ASSIGN(Node);
}; };
std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
Node::Type type);
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册