未验证 提交 ff6c809b 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #14251 from panyx0718/fix

Make OpHandle/VarHandle and ir::Node works cleaner
......@@ -37,8 +37,9 @@ struct TestBroadcastOpHandle {
std::vector<Scope*> local_scopes_;
std::vector<Scope*> param_scopes_;
Scope g_scope_;
std::unique_ptr<OpHandleBase> op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_;
OpHandleBase* op_handle_;
std::vector<VarHandleBase*> vars_;
std::vector<std::unique_ptr<ir::Node>> nodes_;
std::vector<p::Place> place_list_;
bool use_gpu_;
#ifdef PADDLE_WITH_CUDA
......@@ -90,6 +91,7 @@ struct TestBroadcastOpHandle {
}
void InitBroadcastOp(size_t input_scope_idx) {
nodes_.clear();
for (size_t j = 0; j < place_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope();
......@@ -101,39 +103,39 @@ struct TestBroadcastOpHandle {
}
param_scopes_[input_scope_idx]->Var("input");
std::unique_ptr<ir::Node> n =
ir::CreateNodeForTest("node0", ir::Node::Type::kOperation);
nodes_.emplace_back(
ir::CreateNodeForTest("node0", ir::Node::Type::kOperation));
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_,
place_list_, nccl_ctxs_.get()));
op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
place_list_, nccl_ctxs_.get());
#else
PADDLE_THROW("CUDA is not support.");
#endif
} else {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_,
place_list_, nccl_ctxs_.get()));
op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
place_list_, nccl_ctxs_.get());
#else
op_handle_.reset(
new BroadcastOpHandle(n.get(), local_scopes_, place_list_));
op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
place_list_);
#endif
}
std::unique_ptr<ir::Node> v =
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable);
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
place_list_[input_scope_idx]);
nodes_.emplace_back(
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable));
auto* in_var_handle = new VarHandle(nodes_.back().get(), 1, input_scope_idx,
"input", place_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
// add dummy var
std::unique_ptr<ir::Node> v2 =
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable);
vars_.emplace_back(new DummyVarHandle(v2.get()));
nodes_.emplace_back(
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
static_cast<DummyVarHandle*>(vars_.back());
dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(dummy_var_handle);
......@@ -141,20 +143,20 @@ struct TestBroadcastOpHandle {
if (!use_gpu_) {
op_handle_->SetDeviceContext(place_list_[j], ctxs_[j].get());
}
std::unique_ptr<ir::Node> v3 =
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable);
nodes_.emplace_back(
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable));
VarHandle* out_var_handle =
new VarHandle(v3.get(), 2, j, "out", place_list_[j]);
new VarHandle(nodes_.back().get(), 2, j, "out", place_list_[j]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
}
// add dummy var
std::unique_ptr<ir::Node> v4 =
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable);
vars_.emplace_back(new DummyVarHandle(v4.get()));
nodes_.emplace_back(
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* out_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
static_cast<DummyVarHandle*>(vars_.back());
out_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddOutput(out_dummy_var_handle);
}
......
......@@ -16,6 +16,7 @@
#include <vector>
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
......@@ -32,13 +33,11 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
pool_(strategy.num_threads_ +
1), // add one more thread for generate op_deps
fetch_ctxs_(places) {
auto &ops = graph_->Get<details::GraphOps>("ops");
for (auto &op : ops) {
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
int dep = static_cast<int>(op->NotReadyInputSize());
op_deps_.emplace(op.get(), dep);
op_deps_.emplace(op, dep);
if (dep == 0) {
bootstrap_ops_.emplace_back(op.get());
bootstrap_ops_.emplace_back(op);
}
}
......@@ -54,13 +53,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
paddle::framework::FeedFetchList fetches;
fetches.resize(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
std::vector<FetchOpHandle *> fetch_ops;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
}
}
}
......@@ -110,8 +109,11 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
complete_q->Pop();
}
}
if (exception_.IsCaught()) {
ClearFetchOp(graph_.get(), &fetch_ops);
exception_.ReThrow();
}
}
num_complete += num_comp;
}
// Wait FetchOps.
......
......@@ -28,11 +28,7 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
offset_(offset),
local_scopes_(local_scopes) {}
FetchOpHandle::~FetchOpHandle() {
for (auto *input_var : inputs_) {
input_var->RemoveOutput(this, this->Node());
}
}
FetchOpHandle::~FetchOpHandle() {}
void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
......
......@@ -22,8 +22,10 @@ namespace details {
struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
std::vector<std::string> out_varnames_;
std::vector<std::unique_ptr<ir::Node>> nodes_;
void InitFusedBroadcastOp(std::vector<size_t> input_scope_idxes) {
nodes_.clear();
// initialize scope and var
for (size_t i = 0; i < place_list_.size(); ++i) {
local_scopes_.push_back(&(g_scope_.NewScope()));
......@@ -39,41 +41,41 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
}
// create op handle node
std::unique_ptr<ir::Node> n =
ir::CreateNodeForTest("fused_broadcast", ir::Node::Type::kOperation);
nodes_.emplace_back(
ir::CreateNodeForTest("fused_broadcast", ir::Node::Type::kOperation));
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(new FusedBroadcastOpHandle(
n.get(), local_scopes_, place_list_, nccl_ctxs_.get()));
op_handle_ = new FusedBroadcastOpHandle(
nodes_.back().get(), local_scopes_, place_list_, nccl_ctxs_.get());
#else
PADDLE_THROW("CUDA is not supported.");
#endif
} else {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(new FusedBroadcastOpHandle(
n.get(), local_scopes_, place_list_, nccl_ctxs_.get()));
op_handle_ = new FusedBroadcastOpHandle(
nodes_.back().get(), local_scopes_, place_list_, nccl_ctxs_.get());
#else
op_handle_.reset(
new FusedBroadcastOpHandle(n.get(), local_scopes_, place_list_));
op_handle_ = new FusedBroadcastOpHandle(nodes_.back().get(),
local_scopes_, place_list_);
#endif
}
for (size_t i = 0; i < input_scope_idxes.size(); ++i) {
// add input var handle
std::unique_ptr<ir::Node> in_node =
ir::CreateNodeForTest("in_node" + i, ir::Node::Type::kVariable);
nodes_.emplace_back(
ir::CreateNodeForTest("in_node" + i, ir::Node::Type::kVariable));
VarHandle* in_var_handle =
new VarHandle(in_node.get(), 1, input_scope_idxes[i], "in_var" + i,
place_list_[input_scope_idxes[i]]);
new VarHandle(nodes_.back().get(), 1, input_scope_idxes[i],
"in_var" + i, place_list_[input_scope_idxes[i]]);
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
// add output var handle
for (size_t j = 0; j < place_list_.size(); ++j) {
std::unique_ptr<ir::Node> out_node =
ir::CreateNodeForTest("out_node" + i, ir::Node::Type::kVariable);
VarHandle* out_var_handle =
new VarHandle(out_node.get(), 2, j, "out_var" + i, place_list_[j]);
nodes_.emplace_back(
ir::CreateNodeForTest("out_node" + i, ir::Node::Type::kVariable));
VarHandle* out_var_handle = new VarHandle(
nodes_.back().get(), 2, j, "out_var" + i, place_list_[j]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
}
......
......@@ -31,9 +31,10 @@ struct TestGatherOpHandle {
std::vector<Scope*> local_scopes_;
std::vector<Scope*> param_scopes_;
Scope g_scope_;
std::unique_ptr<OpHandleBase> op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_;
OpHandleBase* op_handle_;
std::vector<VarHandleBase*> vars_;
std::vector<p::Place> gpu_list_;
std::vector<std::unique_ptr<ir::Node>> nodes_;
void WaitAll() {
for (size_t j = 0; j < ctxs_.size(); ++j) {
......@@ -70,7 +71,7 @@ struct TestGatherOpHandle {
}
void InitGatherOp(size_t input_scope_idx) {
std::vector<std::unique_ptr<ir::Node>> nodes;
nodes_.clear();
for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope();
......@@ -82,44 +83,45 @@ struct TestGatherOpHandle {
}
param_scopes_[input_scope_idx]->Var("out");
nodes.emplace_back(
nodes_.emplace_back(
ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release());
op_handle_.reset(
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
op_handle_ =
new GatherOpHandle(nodes_.back().get(), local_scopes_, gpu_list_);
// add input
for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
nodes.emplace_back(
nodes_.emplace_back(
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable).release());
auto* in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
new VarHandle(nodes_.back().get(), 1, j, "input", gpu_list_[j]);
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
}
// add dummy var
nodes.emplace_back(
nodes_.emplace_back(
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release());
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* in_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
static_cast<DummyVarHandle*>(vars_.back());
in_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(in_dummy_var_handle);
// add output
nodes.emplace_back(
nodes_.emplace_back(
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable).release());
auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
"out", gpu_list_[input_scope_idx]);
auto* out_var_handle =
new VarHandle(nodes_.back().get(), 2, input_scope_idx, "out",
gpu_list_[input_scope_idx]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
// add dummy var
nodes.emplace_back(
nodes_.emplace_back(
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release());
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
static_cast<DummyVarHandle*>(vars_.back());
op_handle_->AddOutput(dummy_var_handle);
}
......
......@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
......@@ -35,10 +36,10 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
std::unique_ptr<ir::Graph> ir_graph) const {
auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps);
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
OpGraphView graph_view(all_ops);
for (auto &op : all_ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
......@@ -36,20 +37,20 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get());
insert_pending_var(version_pair);
}
}
}
for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
insert_pending_var(var.get());
insert_pending_var(var);
}
for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
if (op->Inputs().empty()) {
ready_ops.insert(op.get());
ready_ops.insert(op);
} else {
pending_ops.insert({op.get(), op.get()->NoDupInputSize()});
pending_ops.insert({op, op->NoDupInputSize()});
}
}
......@@ -89,6 +90,4 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
REGISTER_PASS(multi_devices_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker)
.RequireGraphAttr(paddle::framework::details::kGraphVars)
.RequireGraphAttr(paddle::framework::details::kGraphDepVars)
.RequireGraphAttr(paddle::framework::details::kGraphOps)
.RequireGraphAttr(paddle::framework::details::kShardedVarDevice);
.RequireGraphAttr(paddle::framework::details::kGraphDepVars);
......@@ -34,7 +34,14 @@
namespace paddle {
namespace framework {
namespace details {
namespace {
// TODO(panyx0718): Clean this up as well.
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<OpHandleBase *> GraphOps;
const char kGraphOps[] = "ops";
void PolishGraphToSupportDataHazards(ir::Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) {
......@@ -92,7 +99,7 @@ VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
}
var_holder.emplace_back(var);
} else {
var = var_holder.rbegin()->get();
var = *var_holder.rbegin();
}
return var;
}
......@@ -154,7 +161,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
ir::Node *node,
size_t place_id) const {
auto p = places_[place_id];
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
......@@ -303,7 +310,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
result.Set(kGraphVars, new GraphVars(places_.size()));
result.Set(kGraphDepVars, new GraphDepVars);
result.Set(kGraphOps, new GraphOps);
result.Set(kShardedVarDevice, new ShardedVarDevice);
// find send/recv vars so that we can place the distributed training
// related op in the place 0
......@@ -317,11 +323,13 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
bool is_forwarding = true;
bool is_dist_train = false;
std::unordered_map<std::string, int> sharded_var_device;
for (ir::Node *node : sorted_ops) {
if (boost::get<int>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) {
int op_dev_id = CreateRPCOp(&result, node);
int op_dev_id = CreateRPCOp(&result, node, &sharded_var_device);
PADDLE_ENFORCE(op_dev_id != -1,
"Can not schedule the RPC operator to the right place.");
if (node->Op()->Type() == "recv") {
......@@ -337,7 +345,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
} else if (boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(&result, node);
int op_dev_id = CreateDistTrainOp(&result, node, &sharded_var_device);
if (node->Op()->Type() == "concat") {
auto origin_param_name = node->Op()->OutputArgumentNames()[0];
bcast_var_name_set[op_dev_id].emplace(origin_param_name);
......@@ -356,12 +364,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
// the block.
is_forwarding = false;
} else {
int op_dev_id = GetOpDeviceID(result, node);
int op_dev_id = GetOpDeviceID(result, node, sharded_var_device);
if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(&result, node, op_dev_id);
for (ir::Node *n : node->outputs) {
graph->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(n->Name(), op_dev_id);
sharded_var_device.emplace(n->Name(), op_dev_id);
}
} else {
// This op runs on all devices, and its output may have parameter's
......@@ -398,8 +405,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id);
graph->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(g_name, cur_device_id);
sharded_var_device.emplace(g_name, cur_device_id);
if (!is_dist_train) {
bcast_var_name_set[cur_device_id].emplace(p_name);
}
......@@ -458,7 +464,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
* Only variables should be the leaves of graph.
*/
AddOutputToLeafOps(&result);
PADDLE_ENFORCE(!ir::HasCircle(result));
result.Erase<GraphOps>(kGraphOps);
return graph;
}
......@@ -498,7 +504,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
auto *in =
result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back().get();
result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back();
op_handle->AddInput(in);
for (size_t i = 0; i < places_.size(); ++i) {
......@@ -535,7 +541,7 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) {
for (auto &p_name : bcast_varnames[dev_id]) {
auto *in =
result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back().get();
result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back();
op_handle->AddInput(in);
for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) {
auto &p = places_[out_dev_id];
......@@ -571,7 +577,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
......@@ -579,7 +585,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get());
op_handle->AddInput(prev_grad);
auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
......@@ -600,14 +606,14 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
for (const std::string &d_name : datas) {
auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
PADDLE_ENFORCE(!vars.empty());
op_handle->AddInput(vars.back().get());
op_handle->AddInput(vars.back());
auto var = new VarHandle(
result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
vars.size(), i, d_name, p);
......@@ -617,8 +623,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
}
}
int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
ir::Node *node) const {
int MultiDevSSAGraphBuilder::GetOpDeviceID(
const ir::Graph &graph, ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1;
}
......@@ -631,15 +638,15 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(graph, param_grad[1]);
int dev_id = GetVarDeviceID(graph, param_grad[1], sharded_var_device);
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
node->Op()->Type(), param_grad[0], param_grad[1]);
return dev_id;
}
int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
const std::string &varname) const {
auto &sharded_var_device = graph.Get<ShardedVarDevice>(kShardedVarDevice);
int MultiDevSSAGraphBuilder::GetVarDeviceID(
const ir::Graph &graph, const std::string &varname,
const std::unordered_map<std::string, int> &sharded_var_device) const {
auto got = sharded_var_device.find(varname);
return got == sharded_var_device.end() ? -1 : got->second;
}
......@@ -690,7 +697,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
......@@ -698,7 +705,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get());
op_handle->AddInput(prev_grad);
}
auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og];
auto var =
......@@ -709,8 +716,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
return var;
}
int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
ir::Node *node) const {
int MultiDevSSAGraphBuilder::CreateDistTrainOp(
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const {
int op_dev_id = -1;
std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names;
......@@ -725,23 +733,22 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
node->Op()->Type() == "split_selected_rows" ||
node->Op()->Type() == "split_ids") {
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
op_dev_id =
GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id);
sharded_var_device->emplace(varname, op_dev_id);
}
}
for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id);
sharded_var_device->emplace(varname, op_dev_id);
}
} else if (node->Op()->Type() == "concat") {
op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
op_dev_id =
GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id);
sharded_var_device->emplace(varname, op_dev_id);
}
} else {
LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
......@@ -759,14 +766,14 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
}
void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (ir::Node *input : node->inputs) {
VarHandle *var = nullptr;
for (int place_offset = 0; place_offset < num_places; ++place_offset) {
auto &var_holders = result->Get<GraphVars>(kGraphVars)[place_offset];
auto &var_holder = var_holders[input->Name()];
if (!var_holder.empty()) {
var = var_holder.rbegin()->get();
var = *var_holder.rbegin();
op_handle->AddInput(var);
}
}
......@@ -774,12 +781,14 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
}
// Create RPC related op handles that connects its in ops and out ops.
int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
ir::Node *node) const {
int MultiDevSSAGraphBuilder::CreateRPCOp(
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const {
int op_dev_id = -1;
if (node->Op()->Type() == "send") {
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name());
op_dev_id =
GetVarDeviceID(*result, node->inputs[0]->Name(), *sharded_var_device);
PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
"This hack no longer holds, please fix.");
// the variable name which contains .block means it was splited by
......@@ -797,11 +806,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
VLOG(10) << "send grad " << input_var_names[0] << " origin "
<< send_param_grad[1] << " place: " << op_dev_id;
for (auto &varname : input_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id);
sharded_var_device->emplace(varname, op_dev_id);
}
result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(send_param_grad[1], op_dev_id);
sharded_var_device->emplace(send_param_grad[1], op_dev_id);
}
} else if (node->Op()->Type() == "recv") {
std::vector<std::string> output_var_names;
......@@ -811,7 +818,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
auto recv_param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
if (recv_param_grad.size() == 2U) {
op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]);
op_dev_id =
GetVarDeviceID(*result, recv_param_grad[1], *sharded_var_device);
VLOG(10) << "recv param " << recv_param_grad[0]
<< " get grad place: " << recv_param_grad[1]
<< " place: " << op_dev_id;
......@@ -819,8 +827,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
op_dev_id = GetAppropriateDeviceID(output_var_names);
}
for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id);
sharded_var_device->emplace(varname, op_dev_id);
}
} else {
// send_barrier, fetch_barrier will run on place 0;
......@@ -839,7 +846,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
// send_barrier, recv, fetch_barrier's inputs are deps var, get them from
// all places
auto p = places_[op_dev_id];
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
......@@ -847,7 +854,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
for (ir::Node *output : node->outputs) {
int outvar_dev_id = op_dev_id;
if (node->Op()->Type() == "fetch_barrier") {
outvar_dev_id = GetVarDeviceID(*result, output->Name());
outvar_dev_id =
GetVarDeviceID(*result, output->Name(), *sharded_var_device);
PADDLE_ENFORCE_NE(outvar_dev_id, -1);
}
p = places_[outvar_dev_id];
......
......@@ -44,12 +44,18 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
mutable platform::NCCLContextMap *nccl_ctxs_;
#endif
int GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const;
int GetVarDeviceID(
const ir::Graph &graph, const std::string &varname,
const std::unordered_map<std::string, int> &sharded_var_device) const;
bool IsScaleLossOp(ir::Node *node) const;
int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
int CreateRPCOp(
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const;
int CreateDistTrainOp(
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const;
std::vector<std::string> FindDistTrainSendVars(
const std::vector<ir::Node *> &nodes) const;
......@@ -69,7 +75,9 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void CreateComputationalOp(ir::Graph *result, ir::Node *node,
int dev_id) const;
int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const;
int GetOpDeviceID(
const ir::Graph &graph, ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device) const;
void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
......@@ -62,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
});
size_t op_id = 0;
for (auto &op : graph.Get<GraphOps>(kGraphOps)) {
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(graph)) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
......
......@@ -35,23 +35,14 @@ namespace details {
// The outside vector is the device vector. Each element of this vector is a
// map from variable name to variables. The variables, who have the same name,
// will have a differsent version. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles.
typedef std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
// `std::vector<VarHandle*>` is the version of varaibles.
typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle*>>>
GraphVars;
const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
typedef std::unordered_set<VarHandleBase*> GraphDepVars;
const char kGraphDepVars[] = "dep_vars";
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
const char kGraphOps[] = "ops";
typedef std::unordered_map<std::string, int> ShardedVarDevice;
const char kShardedVarDevice[] = "sharded_var_device";
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -20,19 +20,16 @@ namespace paddle {
namespace framework {
namespace details {
OpGraphView::OpGraphView(
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
Build(ops);
}
OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
for (auto &op : ops) {
preceding_ops_[op.get()];
pending_ops_[op.get()];
preceding_ops_[op];
pending_ops_[op];
for (auto &var : op->Outputs()) {
for (auto &pending_op : var->PendingOps()) {
preceding_ops_[pending_op].insert(op.get());
pending_ops_[op.get()].insert(pending_op);
preceding_ops_[pending_op].insert(op);
pending_ops_[op].insert(pending_op);
}
}
}
......@@ -41,8 +38,6 @@ void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
"There are duplicate ops in graph.");
}
size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); }
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) {
......@@ -60,12 +55,6 @@ void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
op == nullptr ? "nullptr" : op->DebugString());
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PrecedingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return preceding_ops_.at(op);
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
......
......@@ -26,21 +26,16 @@ namespace details {
class OpGraphView {
public:
explicit OpGraphView(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
size_t OpNumber() const;
explicit OpGraphView(const std::vector<OpHandleBase *> &ops);
std::unordered_set<OpHandleBase *> AllOps() const;
const std::unordered_set<OpHandleBase *> &PrecedingOps(
OpHandleBase *op) const;
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
bool HasOp(OpHandleBase *op) const;
private:
void Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
void Build(const std::vector<OpHandleBase *> &ops);
void EnforceHasOp(OpHandleBase *op) const;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
......
......@@ -31,7 +31,10 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
// It's responsible for populating necessary fields of ir::Node.
class OpHandleBase {
public:
explicit OpHandleBase(ir::Node *node) : node_(node) {}
// Owned by `node`. No need to be deleted explicitly.
explicit OpHandleBase(ir::Node *node) : node_(node) {
node_->WrappedBy(this);
}
virtual ~OpHandleBase();
......
......@@ -30,8 +30,8 @@ struct TestReduceOpHandle {
Scope g_scope_;
std::vector<Scope *> local_scopes_;
std::vector<Scope *> param_scopes_;
std::unique_ptr<OpHandleBase> op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_;
OpHandleBase *op_handle_;
std::vector<VarHandleBase *> vars_;
std::vector<p::Place> gpu_list_;
std::vector<std::unique_ptr<p::DeviceContext>> ctxs_;
......
......@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
......@@ -71,14 +72,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
// Step 2: Find all variables in non-computation ops which refers to variables
// in computation ops
std::unordered_set<std::string> names;
std::unordered_map<OpHandleBase *, std::unique_ptr<ReferenceCountOpHandle>>
std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
compute_ref_cnt_map;
auto get_ref_cnts_from_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op,
const std::vector<VarHandleBase *> &vars) {
OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
std::vector<std::string> var_names_in_op;
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr ||
!platform::is_gpu_place(compute_op->GetPlace()))
return var_names_in_op;
......@@ -121,9 +121,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
};
auto update_ref_cnts_from_non_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op,
const std::vector<VarHandleBase *> &vars) {
if (dynamic_cast<ComputationOpHandle *>(op.get()) != nullptr) return;
OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
if (dynamic_cast<ComputationOpHandle *>(op) != nullptr) return;
for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
......@@ -151,21 +150,21 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
gcs[place.device].get(), cur_ref_cnts[place.device].get());
AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle);
compute_ref_cnt_map[next_compute_op] = ref_cnt_handle;
}
}
}
}
};
auto &all_ops = graph->Get<GraphOps>(kGraphOps);
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
for (auto &op : all_ops) {
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
if (in_var_names.empty() && out_var_names.empty()) continue;
in_var_names.insert(in_var_names.end(), out_var_names.begin(),
out_var_names.end());
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
ir::Node *ref_cnt_node =
graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation);
......@@ -173,7 +172,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get());
AddDependencyBetween(compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[compute_op].reset(ref_cnt_handle);
compute_ref_cnt_map[compute_op] = ref_cnt_handle;
}
for (auto &op : all_ops) {
......@@ -181,11 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
update_ref_cnts_from_non_compute_op(op, op->Outputs());
}
std::vector<std::unique_ptr<OpHandleBase>> new_all_ops;
std::vector<OpHandleBase *> new_all_ops;
new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size());
for (auto &op : all_ops) {
new_all_ops.emplace_back(std::move(op));
auto it = compute_ref_cnt_map.find(new_all_ops.back().get());
auto it = compute_ref_cnt_map.find(new_all_ops.back());
if (it != compute_ref_cnt_map.end()) {
// Add LeafNode to ReferenceCountOpHandle
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
......
......@@ -19,14 +19,16 @@ namespace framework {
namespace details {
SSAGraphExecutor::~SSAGraphExecutor() {}
void ClearFetchOp(ir::Graph* graph,
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops) {
void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops) {
if (fetch_ops->empty()) return;
for (auto& op : *fetch_ops) {
for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var);
}
for (auto& in_var : op->Inputs()) {
in_var->RemoveOutput(op, op->Node());
}
graph->RemoveNode(op->Node());
}
fetch_ops->clear();
......
......@@ -38,8 +38,7 @@ class SSAGraphExecutor {
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
};
void ClearFetchOp(ir::Graph* graph,
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops);
void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops);
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
......@@ -51,25 +52,25 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, ready_vars.get(), version_pair.get());
InsertPendingVar(&pending_vars, ready_vars.get(), version_pair);
}
}
}
for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, ready_vars.get(), var.get());
InsertPendingVar(&pending_vars, ready_vars.get(), var);
}
for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) {
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get());
ready_ops.insert(op);
} else {
InsertPendingOp(&pending_ops, op.get());
InsertPendingOp(&pending_ops, op);
}
}
// Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
std::vector<FetchOpHandle *> fetch_ops;
std::unordered_set<VarHandleBase *> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
......@@ -109,6 +110,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &run_op_future : run_op_futures_) {
run_op_future.wait();
}
ClearFetchOp(graph_.get(), &fetch_ops);
exception_holder_.ReThrow();
} else {
continue;
......@@ -140,8 +142,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::vector<FetchOpHandle *> *fetch_ops,
std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) {
......@@ -151,7 +153,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
}
}
}
......
......@@ -70,13 +70,13 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
BlockingQueue<VarHandleBase *> *ready_vars,
VarHandleBase *var) const;
void InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
void InsertFetchOps(const std::vector<std::string> &fetch_tensors,
std::vector<FetchOpHandle *> *fetch_ops,
std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data);
BlockingQueue<VarHandleBase *> *ready_vars,
FeedFetchList *fetch_data);
private:
ExecutionStrategy strategy_;
......
......@@ -20,6 +20,8 @@ namespace details {
VarHandleBase::~VarHandleBase() {}
VarHandle::~VarHandle() { VLOG(4) << "deleting var handle " << DebugString(); }
std::string VarHandle::DebugString() const {
std::stringstream ss;
ss << name_ << ":" << place_;
......@@ -27,6 +29,10 @@ std::string VarHandle::DebugString() const {
}
std::string DummyVarHandle::DebugString() const { return node_->Name(); }
DummyVarHandle::~DummyVarHandle() {
VLOG(4) << "deleting dummy var handle " << DebugString();
}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -35,7 +35,10 @@ class OpHandleBase;
// A variable can only be generated by a single operator. i.e.
// This is a single assignment graph.
struct VarHandleBase {
explicit VarHandleBase(ir::Node* node) : node_(node) {}
// Owned by `node`. No need to be deleted explicitly.
explicit VarHandleBase(ir::Node* node) : node_(node) {
node_->WrappedBy(this);
}
virtual ~VarHandleBase();
......@@ -94,6 +97,8 @@ struct VarHandleBase {
struct VarHandle : public VarHandleBase {
explicit VarHandle(ir::Node* node) : VarHandleBase(node) {}
virtual ~VarHandle();
std::string DebugString() const override;
VarHandle(ir::Node* node, size_t version, size_t scope_index,
......@@ -121,6 +126,8 @@ struct VarHandle : public VarHandleBase {
struct DummyVarHandle : public VarHandleBase {
explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {}
virtual ~DummyVarHandle();
std::string DebugString() const override;
};
......
......@@ -53,6 +53,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
cc_library(pass_builder SRCS pass_builder.cc DEPS pass)
cc_test(node_test SRCS node_test.cc DEPS node)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
......
......@@ -102,6 +102,15 @@ class Graph {
attr_dels_[attr_name] = []() {};
}
template <typename AttrType>
void Erase(const std::string &attr_name) {
PADDLE_ENFORCE(attrs_.count(attr_name) != 0, "%s not set in the graph",
attr_name);
attr_dels_[attr_name]();
attrs_.erase(attr_name);
attr_dels_.erase(attr_name);
}
const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; }
// Create a normal variable with non-null VarDesc.
......
......@@ -37,6 +37,15 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph);
template <typename T>
std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
std::vector<T *> ret;
for (ir::Node *n : graph.Nodes()) {
if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>());
}
return ret;
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -15,7 +15,10 @@ limitations under the License. */
#pragma once
#include <string>
#include <typeindex>
#include <typeinfo>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/macros.h"
......@@ -24,9 +27,33 @@ namespace paddle {
namespace framework {
namespace ir {
// Node should normally created by Graph::CreateXXXNode().
// Node should only created by Graph::CreateXXXNode().
// 1. Every Node should be part of a graph. No dangling Node exists.
// 2. Node only contains members necessary for building graph structure.
// It doesn't contain other unrelated members, such as device, etc.
//
// Sometimes, for specific usages, Node needs to have additional members,
// such as device_placement, version in order to be executed. It is suggested
// to use composition pattern.
//
// class RunnableOp {
// RunnableOp(ir::Node* n) : n_(n) { n_.WrappedBy(this); }
//
// int any_thing_;
// }
//
// RunnableOp is owned by the ir::Node that composes it. In other words.
// ir::Node will be responsible for deleting RunnableOp, say, when ir::Node
// is deleted from the graph.
class Node {
public:
virtual ~Node() {
if (!wrapper_.empty()) {
VLOG(4) << "ir::Node deleting a wrapper node " << Name();
wrapper_deleter_();
}
}
enum class Type { kOperation, kVariable };
#if !defined(_WIN32) // msvc not support constexpr correctly.
static constexpr char kControlDepVarName[] = "__control_var";
......@@ -48,6 +75,29 @@ class Node {
return op_desc_.get();
}
// Set the `wrapper` that wraps the Node. `wrapper` is owned by Node.
template <typename T>
void WrappedBy(T* wrapper) {
if (!wrapper_.empty()) {
wrapper_deleter_();
}
wrapper_ = wrapper;
wrapper_deleter_ = [wrapper]() { delete wrapper; };
wrapper_type_ = std::type_index(typeid(T));
}
// Return a reference to the `wrapper`.
template <typename T>
T& Wrapper() {
return *boost::any_cast<T*>(wrapper_);
}
// Test if the Node is wrapped by type T.
template <typename T>
bool IsWrappedBy() {
return std::type_index(typeid(T)) == wrapper_type_;
}
// Please don't use this API!
int id() const { return id_; }
......@@ -99,6 +149,11 @@ class Node {
static int count_;
// Please don't use this API or make this public.
static void ResetId() { count_ = 0; }
boost::any wrapper_;
std::function<void(void)> wrapper_deleter_;
std::type_index wrapper_type_ = std::type_index(typeid(void));
DISABLE_COPY_AND_ASSIGN(Node);
};
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class RunnableOp {
public:
RunnableOp(Node* node, bool* alive) : node_(node), alive_(alive) {
node_->WrappedBy(this);
}
virtual ~RunnableOp() { *alive_ = false; }
private:
Node* node_;
bool* alive_;
};
class RunnableOp2 {
public:
RunnableOp2(Node* node, bool* alive) : node_(node), alive_(alive) {
node_->WrappedBy(this);
}
virtual ~RunnableOp2() { *alive_ = false; }
private:
Node* node_;
bool* alive_;
};
TEST(NodeTest, Basic) {
bool alive1 = true;
bool alive2 = true;
std::unique_ptr<Node> n1(CreateNodeForTest("n1", Node::Type::kVariable));
std::unique_ptr<Node> n2(CreateNodeForTest("n2", Node::Type::kVariable));
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp>());
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp2>());
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp>());
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp2>());
new RunnableOp(n1.get(), &alive1);
new RunnableOp2(n2.get(), &alive2);
EXPECT_TRUE(n1->IsWrappedBy<RunnableOp>());
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp2>());
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp>());
EXPECT_TRUE(n2->IsWrappedBy<RunnableOp2>());
EXPECT_TRUE(alive1);
EXPECT_TRUE(alive2);
n1.reset(nullptr);
n2.reset(nullptr);
EXPECT_FALSE(alive1);
EXPECT_FALSE(alive2);
}
} // namespace ir
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册