Commit 37e51443 authored by Xin Pan

op compose node and update nodes.

Parent 9605fcd1
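In short, this change makes every op handle and variable handle in the details namespace wrap an ir::Node: each handle is constructed from a node, and AddInput/AddOutput mirror the handle-level edges into the node-level graph. A minimal standalone sketch of the resulting pattern (simplified stand-in types, not the actual Paddle classes):

#include <memory>
#include <vector>

// Simplified stand-ins for ir::Node, OpHandleBase and VarHandleBase.
struct Node {
  enum class Type { kOperation, kVariable };
  explicit Node(Type t) : type(t) {}
  Type type;
  std::vector<Node*> inputs;
  std::vector<Node*> outputs;
};

struct VarHandle {
  explicit VarHandle(Node* n) : node(n) {}
  Node* node;  // non-owning; mirrors this handle in the IR graph
};

struct OpHandle {
  explicit OpHandle(Node* n) : node(n) {}
  // Handle-level edges are mirrored into node-level edges, as in
  // OpHandleBase::AddInput/AddOutput in the diff below.
  void AddInput(VarHandle* in) {
    inputs.push_back(in);
    node->inputs.push_back(in->node);
    in->node->outputs.push_back(node);
  }
  void AddOutput(VarHandle* out) {
    outputs.push_back(out);
    node->outputs.push_back(out->node);
    out->node->inputs.push_back(node);
  }
  Node* node;
  std::vector<VarHandle*> inputs;
  std::vector<VarHandle*> outputs;
};

int main() {
  // The graph (here: a plain vector) owns the nodes; handles borrow them.
  std::vector<std::unique_ptr<Node>> nodes;
  nodes.emplace_back(new Node(Node::Type::kOperation));
  OpHandle op(nodes.back().get());
  nodes.emplace_back(new Node(Node::Type::kVariable));
  VarHandle in(nodes.back().get());
  nodes.emplace_back(new Node(Node::Type::kVariable));
  VarHandle out(nodes.back().get());
  op.AddInput(&in);
  op.AddOutput(&out);
  return 0;
}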
@@ -23,10 +23,14 @@ namespace framework {
 namespace details {
 #ifdef PADDLE_WITH_CUDA
-AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
+                                     const std::vector<Scope *> &local_scopes,
                                      const std::vector<platform::Place> &places,
                                      const platform::NCCLContextMap *ctxs)
-    : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
+    : OpHandleBase(node),
+      local_scopes_(local_scopes),
+      places_(places),
+      nccl_ctxs_(ctxs) {
   if (nccl_ctxs_) {
     for (auto &p : places_) {
       this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
@@ -34,9 +38,10 @@ AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
   }
 }
 #else
-AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
+                                     const std::vector<Scope *> &local_scopes,
                                      const std::vector<platform::Place> &places)
-    : local_scopes_(local_scopes), places_(places) {}
+    : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
 #endif

 void AllReduceOpHandle::RunImpl() {
...
@@ -30,11 +30,11 @@ namespace details {
 struct AllReduceOpHandle : public OpHandleBase {
 #ifdef PADDLE_WITH_CUDA
-  AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+  AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places,
                     const platform::NCCLContextMap *ctxs);
 #else
-  AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+  AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places);
 #endif
   std::string Name() const override;
...
@@ -35,10 +35,13 @@ namespace details {
 struct BroadcastOpHandle : public OpHandleBase {
  public:
 #ifdef PADDLE_WITH_CUDA
-  BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
+  BroadcastOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places,
                     const platform::NCCLContextMap *nccl_ctxs)
-      : local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) {
+      : OpHandleBase(node),
+        local_scopes_(local_scopes),
+        places_(places),
+        nccl_ctxs_(nccl_ctxs) {
     if (nccl_ctxs_) {
       for (auto &p_ctx : nccl_ctxs_->contexts_) {
         dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
@@ -46,9 +49,9 @@ struct BroadcastOpHandle : public OpHandleBase {
     }
   }
 #else
-  BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
+  BroadcastOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places)
-      : local_scopes_(local_scopes), places_(places) {}
+      : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
 #endif

   std::string Name() const override;
...
@@ -96,48 +96,56 @@ struct TestBroadcastOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("input");

+    std::unique_ptr<ir::Node> n(new ir::Node(ir::Node::Type::kOperation));
     if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
-      op_handle_.reset(
-          new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
+      op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
+                                             nccl_ctxs_.get()));
 #else
       PADDLE_THROW("CUDA is not support.");
 #endif
     } else {
 #ifdef PADDLE_WITH_CUDA
-      op_handle_.reset(
-          new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
+      op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
+                                             nccl_ctxs_.get()));
 #else
-      op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
+      op_handle_.reset(
+          new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_));
 #endif
     }

-    auto* in_var_handle =
-        new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
+    std::unique_ptr<ir::Node> v(new ir::Node(ir::Node::Type::kVariable));
+    auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
+                                        gpu_list_[input_scope_idx]);
     vars_.emplace_back(in_var_handle);
     op_handle_->AddInput(in_var_handle);

     // add dummy var
-    vars_.emplace_back(new DummyVarHandle());
+    std::unique_ptr<ir::Node> v2(new ir::Node(ir::Node::Type::kVariable));
+    vars_.emplace_back(new DummyVarHandle(v2.get()));
     DummyVarHandle* dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
-    dummy_var_handle->generated_op_ = nullptr;
+    dummy_var_handle->ClearGeneratedOp();
     op_handle_->AddInput(dummy_var_handle);

     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       if (!use_gpu_) {
         op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
       }
-      VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
+      std::unique_ptr<ir::Node> v3(new ir::Node(ir::Node::Type::kVariable));
+      VarHandle* out_var_handle =
+          new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
       vars_.emplace_back(out_var_handle);
       op_handle_->AddOutput(out_var_handle);
     }

     // add dummy var
-    vars_.emplace_back(new DummyVarHandle());
+    std::unique_ptr<ir::Node> v4(new ir::Node(ir::Node::Type::kVariable));
+    vars_.emplace_back(new DummyVarHandle(v4.get()));
     DummyVarHandle* out_dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
-    out_dummy_var_handle->generated_op_ = nullptr;
+    out_dummy_var_handle->ClearGeneratedOp();
     op_handle_->AddOutput(out_dummy_var_handle);
   }
...
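Note the ownership convention the updated tests use: unique_ptrs own each ir::Node, while the handles only borrow raw pointers, so the nodes must outlive the handles that reference them. A tiny sketch of that convention (hypothetical stand-in types, not the Paddle classes):

#include <memory>
#include <vector>

struct Node {};  // stand-in for ir::Node
struct Handle {
  explicit Handle(Node* n) : node(n) {}
  Node* node;  // borrowed, never deleted by the handle
};

int main() {
  std::vector<std::unique_ptr<Node>> nodes;  // owner; must outlive handles
  nodes.emplace_back(new Node);
  Handle h(nodes.back().get());
  (void)h;
  return 0;
}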
@@ -19,9 +19,10 @@
 namespace paddle {
 namespace framework {
 namespace details {
-ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
-                                         platform::Place place)
-    : op_(framework::OpRegistry::CreateOp(op_desc)),
+ComputationOpHandle::ComputationOpHandle(ir::Node *node, const OpDesc &op_desc,
+                                         Scope *scope, platform::Place place)
+    : OpHandleBase(node),
+      op_(framework::OpRegistry::CreateOp(op_desc)),
       scope_(scope),
       place_(place) {}
@@ -35,8 +36,8 @@ void ComputationOpHandle::RunImpl() {
 bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
   bool need_wait =
-      in_var && in_var->generated_op_ &&
-      in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_];
+      in_var && in_var->GeneratedOp() &&
+      in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_[place_];
   return need_wait;
 }
...
@@ -28,7 +28,7 @@ namespace framework {
 namespace details {
 struct ComputationOpHandle : public OpHandleBase {
  public:
-  ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
+  ComputationOpHandle(ir::Node *node, const OpDesc &op_desc, Scope *scope,
                       platform::Place place);

   std::string Name() const override;
...
@@ -22,10 +22,10 @@ namespace details {
 #ifdef PADDLE_WITH_CUDA
 DataBalanceOpHandle::DataBalanceOpHandle(
-    const std::vector<Scope *> &local_scopes,
+    ir::Node *node, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
     const platform::NCCLContextMap *ctxs)
-    : local_scopes_(local_scopes), places_(places) {
+    : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {
   if (ctxs) {
     for (auto &p : places_) {
       this->dev_ctxes_[p] = ctxs->DevCtx(p);
@@ -34,9 +34,9 @@ DataBalanceOpHandle::DataBalanceOpHandle(
   }
 }
 #else
 DataBalanceOpHandle::DataBalanceOpHandle(
-    const std::vector<Scope *> &local_scopes,
+    ir::Node *node, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places)
-    : local_scopes_(local_scopes), places_(places) {}
+    : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
 #endif

 std::string DataBalanceOpHandle::Name() const { return "data balance"; }
...
@@ -30,11 +30,11 @@ namespace details {
 struct DataBalanceOpHandle : public OpHandleBase {
  public:
 #ifdef PADDLE_WITH_CUDA
-  DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
+  DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                       const std::vector<platform::Place> &places,
                       const platform::NCCLContextMap *ctxs);
 #else
-  DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
+  DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                       const std::vector<platform::Place> &places);
 #endif
...
@@ -21,13 +21,16 @@ namespace paddle {
 namespace framework {
 namespace details {
-FetchOpHandle::FetchOpHandle(FeedFetchList *data, size_t offset,
+FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
                              std::vector<Scope *> *local_scopes)
-    : data_(data), offset_(offset), local_scopes_(local_scopes) {}
+    : OpHandleBase(node),
+      data_(data),
+      offset_(offset),
+      local_scopes_(local_scopes) {}

 FetchOpHandle::~FetchOpHandle() {
   for (auto *input_var : inputs_) {
-    input_var->pending_ops_.erase(this);
+    input_var->RemoveOutput(this, this->Node());
   }
 }
@@ -77,8 +80,8 @@ void FetchOpHandle::RunImpl() {
 void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) {
   auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place);
   for (auto *input : inputs_) {
-    if (input->generated_op_) {
-      input->generated_op_->RecordWaitEventOnCtx(cpu_ctx);
+    if (input->GeneratedOp()) {
+      input->GeneratedOp()->RecordWaitEventOnCtx(cpu_ctx);
     }
   }
 }
...
@@ -28,7 +28,7 @@ namespace details {
 struct FetchOpHandle : public OpHandleBase {
  public:
-  FetchOpHandle(FeedFetchList *data, size_t offset,
+  FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
                 std::vector<Scope *> *local_scopes);

   ~FetchOpHandle();
...
@@ -30,10 +30,12 @@ namespace details {
 struct FuseVarsOpHandle : public OpHandleBase {
  public:
-  FuseVarsOpHandle(Scope *local_scope, const platform::Place &place,
+  FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
+                   const platform::Place &place,
                    const std::unordered_map<std::string, int64_t> &inputs_numel,
                    const std::type_index &var_type)
-      : local_scope_(local_scope),
+      : OpHandleBase(node),
+        local_scope_(local_scope),
         place_(place),
         inputs_numel_(inputs_numel),
         type_(var_type) {
...
@@ -20,9 +20,10 @@ namespace paddle {
 namespace framework {
 namespace details {
-GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
+GatherOpHandle::GatherOpHandle(ir::Node *node,
+                               const std::vector<Scope *> &local_scopes,
                                const std::vector<platform::Place> &places)
-    : local_scopes_(local_scopes), places_(places) {}
+    : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}

 void GatherOpHandle::RunImpl() {
   if (places_.size() == 1) return;
...
@@ -30,7 +30,7 @@ namespace details {
 struct GatherOpHandle : public OpHandleBase {
  public:
-  GatherOpHandle(const std::vector<Scope *> &local_scopes,
+  GatherOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                  const std::vector<platform::Place> &places);

   std::string Name() const override;
...
@@ -70,6 +70,7 @@ struct TestGatherOpHandle {
   }

   void InitGatherOp(size_t input_scope_idx) {
+    std::vector<std::unique_ptr<ir::Node>> nodes;
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       local_scopes_.push_back(&(g_scope_.NewScope()));
       Scope& local_scope = local_scopes_.back()->NewScope();
@@ -81,30 +82,37 @@ struct TestGatherOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("out");

-    op_handle_.reset(new GatherOpHandle(local_scopes_, gpu_list_));
+    nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
+    op_handle_.reset(
+        new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));

     // add input
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
-      auto* in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
+      nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+      auto* in_var_handle =
+          new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
       vars_.emplace_back(in_var_handle);
       op_handle_->AddInput(in_var_handle);
     }

     // add dummy var
-    vars_.emplace_back(new DummyVarHandle());
+    nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+    vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
     DummyVarHandle* in_dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
-    in_dummy_var_handle->generated_op_ = nullptr;
+    in_dummy_var_handle->ClearGeneratedOp();
     op_handle_->AddInput(in_dummy_var_handle);

     // add output
-    auto* out_var_handle =
-        new VarHandle(2, input_scope_idx, "out", gpu_list_[input_scope_idx]);
+    nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+    auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
+                                         "out", gpu_list_[input_scope_idx]);
     vars_.emplace_back(out_var_handle);
     op_handle_->AddOutput(out_var_handle);

     // add dummy var
-    vars_.emplace_back(new DummyVarHandle());
+    nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+    vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
     DummyVarHandle* dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
     op_handle_->AddOutput(dummy_var_handle);
...
@@ -328,12 +328,16 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
 void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
                                                 const std::string &p_name,
                                                 size_t src_dev_id) const {
+  result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
 #ifdef PADDLE_WITH_CUDA
-  auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_);
+  auto *op_handle = new BroadcastOpHandle(result->nodes.back().get(),
+                                          local_scopes_, places_, nccl_ctxs_);
 #else
-  auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
+  auto *op_handle =
+      new BroadcastOpHandle(result->nodes.back().get(), local_scopes_, places_);
 #endif
   result->Get<GraphOps>("ops").emplace_back(op_handle);

   auto *in =
       result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get();
   op_handle->AddInput(in);
@@ -341,8 +345,10 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
+    result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
     auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
-    auto *out_var = new VarHandle(vars.size(), i, p_name, p);
+    auto *out_var =
+        new VarHandle(result->nodes.back().get(), vars.size(), i, p_name, p);
     vars.emplace_back(out_var);
     op_handle->AddOutput(out_var);
   }
@@ -351,19 +357,21 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
 void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
                                                     const OpDesc &op,
                                                     int dev_id) const {
-  result->Get<GraphOps>("ops").emplace_back(
-      new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
+  result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
+  result->Get<GraphOps>("ops").emplace_back(new ComputationOpHandle(
+      result->nodes.back().get(), op, local_scopes_[dev_id], places_[dev_id]));
   CreateOpHandleIOs(result, op, dev_id);
 }

 void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
                                                 const std::string &og) const {
+  result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(
-      new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
+      result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_));
 #else
-  result->Get<GraphOps>("ops").emplace_back(
-      new AllReduceOpHandle(local_scopes_, places_));
+  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
+      result->nodes.back().get(), local_scopes_, places_));
 #endif
   auto *op_handle = result->Get<GraphOps>("ops").back().get();
@@ -375,7 +383,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());

-    auto var = new VarHandle(vars.size(), i, og, p);
+    result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+    auto var = new VarHandle(result->nodes.back().get(), vars.size(), i, og, p);
     vars.emplace_back(var);
     op_handle->AddOutput(var);
   }
@@ -383,12 +392,13 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
 void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
     Graph *result, const std::vector<std::string> &datas) const {
+  result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(
-      new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
+      result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_));
 #else
-  result->Get<GraphOps>("ops").emplace_back(
-      new DataBalanceOpHandle(local_scopes_, places_));
+  result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
+      result->nodes.back().get(), local_scopes_, places_));
 #endif
   auto *op_handle = result->Get<GraphOps>("ops").back().get();
   for (size_t i = 0; i < places_.size(); ++i) {
@@ -398,7 +408,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
       auto &vars = result->Get<GraphVars>("vars")[i][d_name];
       PADDLE_ENFORCE(!vars.empty());
       op_handle->AddInput(vars.back().get());
-      auto var = new VarHandle(vars.size(), i, d_name, p);
+      result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+      auto var =
+          new VarHandle(result->nodes.back().get(), vars.size(), i, d_name, p);
       vars.emplace_back(var);
       op_handle->AddOutput(var);
     }
@@ -452,10 +464,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
     auto *communication_dev_ctx =
         platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
 #endif
-    auto *op_handle =
-        new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i],
-                                  places_[i], communication_dev_ctx);
+    result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
+    auto *op_handle = new ScaleLossGradOpHandle(
+        result->nodes.back().get(), local_scopes_.size(), local_scopes_[i],
+        places_[i], communication_dev_ctx);
     result->Get<GraphOps>("ops").emplace_back(op_handle);

     // FIXME: Currently ScaleLossGradOp only use device_count as scale
@@ -475,8 +487,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
   for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
+    result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
     result->Get<GraphOps>("ops").emplace_back(
-        new ComputationOpHandle(op, s, p));
+        new ComputationOpHandle(result->nodes.back().get(), op, s, p));
     CreateOpHandleIOs(result, op, scope_idx);
   }
 }
@@ -484,12 +497,13 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
 VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
                                                    const std::string &og,
                                                    int dst_dev_id) const {
+  result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(
-      new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
+      result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_));
 #else
   result->Get<GraphOps>("ops").emplace_back(
-      new ReduceOpHandle(local_scopes_, places_));
+      new ReduceOpHandle(result->nodes.back().get(), local_scopes_, places_));
 #endif
   auto *op_handle = result->Get<GraphOps>("ops").back().get();
@@ -502,7 +516,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
     op_handle->AddInput(prev_grad.get());
   }
   auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
-  auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
+  result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+  auto var = new VarHandle(result->nodes.back().get(), vars.size(), dst_dev_id,
+                           og, places_[dst_dev_id]);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
   return var;
@@ -514,7 +530,8 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
                                         const std::string &prev_op_name) const {
   for (auto &prev_op : result->Get<GraphOps>("ops")) {
     if (prev_op->Name() == prev_op_name) {
-      auto *dep_var = new DummyVarHandle();
+      result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+      auto *dep_var = new DummyVarHandle(result->nodes.back().get());
       prev_op->AddOutput(dep_var);
       result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
       op->AddInput(dep_var);
@@ -587,8 +604,10 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result,
   PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
                  op.Type());

-  result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
-      op, local_scopes_[op_dev_id], op.Type(), places_[op_dev_id]));
+  result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
+  result->Get<GraphOps>("ops").emplace_back(
+      new RPCOpHandle(result->nodes.back().get(), op, local_scopes_[op_dev_id],
+                      op.Type(), places_[op_dev_id]));

   if (op.Type() == "send_barrier") {
     ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
...
@@ -80,19 +80,21 @@ void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
 void OpHandleBase::AddInput(VarHandleBase *in) {
   this->inputs_.emplace_back(in);
-  in->pending_ops_.insert(this);
+  node_->inputs.push_back(in->Node());
+  in->AddOutput(this, this->Node());
 }

 void OpHandleBase::AddOutput(VarHandleBase *out) {
   outputs_.emplace_back(out);
-  out->generated_op_ = this;
+  node_->outputs.push_back(out->Node());
+  out->AddInput(this, this->Node());
 }

 void OpHandleBase::WaitInputVarGenerated() {
   for (auto in_var : inputs_) {
     if (NeedWait(in_var)) {
       for (auto &pair : dev_ctxes_) {
-        in_var->generated_op_->RecordWaitEventOnCtx(pair.second);
+        in_var->GeneratedOp()->RecordWaitEventOnCtx(pair.second);
       }
     }
   }
@@ -101,7 +103,7 @@ void OpHandleBase::WaitInputVarGenerated() {
 void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
   for (auto *in : inputs_) {
     if (NeedWait(in)) {
-      in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]);
+      in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[place]);
     }
   }
 }
@@ -117,7 +119,7 @@ size_t OpHandleBase::NoDummyInputSize() const {
 }

 bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
-  return in_var && in_var->generated_op_;
+  return in_var && in_var->GeneratedOp();
 }

 void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
...
@@ -17,6 +17,7 @@
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/details/var_handle.h"
+#include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/macros.h"
@@ -28,7 +29,7 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
 class OpHandleBase {
  public:
-  OpHandleBase() {}
+  explicit OpHandleBase(ir::Node *node) : node_(node) {}

   virtual ~OpHandleBase();
@@ -82,6 +83,8 @@ class OpHandleBase {
   size_t NoDummyInputSize() const;

+  ir::Node *Node() { return node_; }
+
  protected:
   void RunAndRecordEvent(const std::function<void()> &callback);
@@ -90,6 +93,7 @@ class OpHandleBase {
   virtual void RunImpl() = 0;

+  ir::Node *node_;
   std::vector<VarHandleBase *> inputs_;
   std::vector<VarHandleBase *> outputs_;
   std::map<platform::Place, platform::DeviceContext *> dev_ctxes_;
...
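Because AddInput/AddOutput now record every edge on the underlying ir::Node as well, a graph pass can walk the dependency structure through the nodes alone, without touching the handles. A hedged sketch of such a walk (stand-in Node type, not the Paddle API):

#include <cstdio>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<Node*> inputs;
  std::vector<Node*> outputs;
};

// Depth-first visit over the mirrored node graph.
void Visit(Node* n, std::unordered_set<Node*>* seen) {
  if (!seen->insert(n).second) return;  // already visited
  for (Node* out : n->outputs) Visit(out, seen);
}

int main() {
  Node in, op, out;  // var -> op -> var
  in.outputs.push_back(&op);
  op.inputs.push_back(&in);
  op.outputs.push_back(&out);
  out.inputs.push_back(&op);

  std::unordered_set<Node*> seen;
  Visit(&in, &seen);
  std::printf("visited %zu nodes\n", seen.size());
  return 0;
}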
@@ -37,10 +37,13 @@ struct ReduceOpHandle : public OpHandleBase {
 #ifdef PADDLE_WITH_CUDA
   const platform::NCCLContextMap *nccl_ctxs_;
-  ReduceOpHandle(const std::vector<Scope *> &local_scopes,
+  ReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                  const std::vector<platform::Place> &places,
                  const platform::NCCLContextMap *nccl_ctxs)
-      : local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) {
+      : OpHandleBase(node),
+        local_scopes_(local_scopes),
+        places_(places),
+        nccl_ctxs_(nccl_ctxs) {
     if (nccl_ctxs_) {
       for (auto &p_ctx : nccl_ctxs_->contexts_) {
         dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
@@ -48,9 +51,9 @@ struct ReduceOpHandle : public OpHandleBase {
     }
   }
 #else
-  ReduceOpHandle(const std::vector<Scope *> &local_scopes,
+  ReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                  const std::vector<platform::Place> &places)
-      : local_scopes_(local_scopes), places_(places) {}
+      : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
 #endif

   std::string Name() const override;
...
@@ -84,6 +84,7 @@ struct TestReduceOpHandle {
   }

   void InitReduceOp(size_t out_scope_idx) {
+    std::vector<std::unique_ptr<ir::Node>> nodes;
     // init scope
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       local_scopes_.push_back(&(g_scope_.NewScope()));
@@ -96,19 +97,21 @@ struct TestReduceOpHandle {
     }
     param_scopes_[out_scope_idx]->Var("out");

+    nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
     if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
-      op_handle_.reset(
-          new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
+      op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_,
+                                          gpu_list_, nccl_ctxs_.get()));
 #else
       PADDLE_THROW("CUDA is not support.");
 #endif
     } else {
 #ifdef PADDLE_WITH_CUDA
-      op_handle_.reset(
-          new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
+      op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_,
+                                          gpu_list_, nccl_ctxs_.get()));
 #else
-      op_handle_.reset(new ReduceOpHandle(local_scopes_, gpu_list_));
+      op_handle_.reset(
+          new ReduceOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
 #endif
     }
@@ -118,8 +121,10 @@ struct TestReduceOpHandle {
       if (!use_gpu_) {
         op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
       }
-      auto *in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
-      in_var_handle->generated_op_ = nullptr;
+      nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+      auto *in_var_handle =
+          new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
+      in_var_handle->ClearGeneratedOp();
       vars_.emplace_back(in_var_handle);
       op_handle_->AddInput(in_var_handle);
     }
@@ -128,12 +133,13 @@ struct TestReduceOpHandle {
     vars_.emplace_back(new DummyVarHandle());
     DummyVarHandle *in_dummy_var_handle =
         static_cast<DummyVarHandle *>(vars_.back().get());
-    in_dummy_var_handle->generated_op_ = nullptr;
+    in_dummy_var_handle->ClearGeneratedOp();
     op_handle_->AddInput(in_dummy_var_handle);

     // add output
-    auto *out_var_handle =
-        new VarHandle(2, out_scope_idx, "out", gpu_list_[out_scope_idx]);
+    nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+    auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx,
+                                         "out", gpu_list_[out_scope_idx]);
     vars_.emplace_back(out_var_handle);
     op_handle_->AddOutput(out_var_handle);
...
@@ -18,10 +18,11 @@ namespace paddle {
 namespace framework {
 namespace details {
-RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
+RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
                          const Scope *local_scope, const std::string &name,
                          const platform::Place &place)
-    : op_(framework::OpRegistry::CreateOp(op_desc)),
+    : OpHandleBase(node),
+      op_(framework::OpRegistry::CreateOp(op_desc)),
       local_scope_(local_scope),
       name_(name),
       place_(place) {}
@@ -35,8 +36,8 @@ void RPCOpHandle::RunImpl() {
     if (in->DebugString() == "dummy") {  // HACK
       continue;
     }
-    if (in->generated_op_) {
-      in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]);
+    if (in->GeneratedOp()) {
+      in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[p]);
     }
   }
   auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
...
@@ -28,8 +28,9 @@ namespace framework {
 namespace details {
 struct RPCOpHandle : public OpHandleBase {
-  RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
-              const std::string& name, const platform::Place& place);
+  RPCOpHandle(ir::Node* node, const framework::OpDesc& op_desc,
+              const Scope* local_scope, const std::string& name,
+              const platform::Place& place);

   std::string Name() const override;
...
@@ -19,10 +19,14 @@
 namespace paddle {
 namespace framework {
 namespace details {
-ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
+ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev,
+                                             Scope *scope,
                                              platform::Place place,
                                              platform::DeviceContext *dev_ctx)
-    : coeff_(static_cast<float>(1.0 / num_dev)), scope_(scope), place_(place) {
+    : OpHandleBase(node),
+      coeff_(static_cast<float>(1.0 / num_dev)),
+      scope_(scope),
+      place_(place) {
   dev_ctxes_[place_] = dev_ctx;
 }
...
@@ -25,7 +25,8 @@ namespace framework {
 namespace details {
 struct ScaleLossGradOpHandle : public OpHandleBase {
-  ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place,
+  ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope,
+                        platform::Place place,
                         platform::DeviceContext *context);

   ~ScaleLossGradOpHandle() final;
...
@@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
       auto it_old = name_pair.second.rbegin();
       ++it_old;
       for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
-        auto *write_op = (*it_new)->generated_op_;
-        auto &read_ops = (*it_old)->pending_ops_;
+        OpHandleBase *write_op = (*it_new)->GeneratedOp();
+        const auto &read_ops = (*it_old)->PendingOps();

         for (auto *read_op : read_ops) {
           // Manually add a dependency var from read_op to write_op;
@@ -37,7 +37,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
             continue;
           }

-          auto *dep_var = new DummyVarHandle();
+          graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+          auto *dep_var = new DummyVarHandle(graph->nodes.back().get());
           read_op->AddOutput(dep_var);
           write_op->AddInput(dep_var);
           graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
@@ -54,7 +55,9 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
   auto &var_holder = var_holders[each_var_name];
   VarHandle *var = nullptr;
   if (var_holder.empty()) {
-    var = new VarHandle(0, place_offset, each_var_name, place);
+    graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+    var = new VarHandle(graph->nodes.back().get(), 0, place_offset,
+                        each_var_name, place);
     var_holder.emplace_back(var);
   } else {
     var = var_holder.rbegin()->get();
@@ -68,7 +71,9 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
                                      size_t place_offset) {
   auto &vars = graph->Get<GraphVars>("vars")[place_offset][each_var_name];
   size_t version = vars.size();
-  auto var = new VarHandle(version, place_offset, each_var_name, place);
+  graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+  auto var = new VarHandle(graph->nodes.back().get(), version, place_offset,
+                           each_var_name, place);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
 }
@@ -80,7 +85,8 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
     if (!op->Outputs().empty()) {
       continue;
     }
-    auto *dummy_leaf = new DummyVarHandle();
+    graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+    auto *dummy_leaf = new DummyVarHandle(graph->nodes.back().get());
     graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
     op->AddOutput(dummy_leaf);
   }
...
@@ -28,7 +28,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const {
   auto insert_pending_var = [&](VarHandleBase *var) {
     pending_vars.insert(var);
-    if (var->generated_op_ == nullptr) {
+    if (var->GeneratedOp() == nullptr) {
       ready_vars.emplace(var);
     }
   };
@@ -71,7 +71,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const {
   for (auto ready_var : ready_vars) {
     pending_vars.erase(ready_var);
-    for (auto *op : ready_var->pending_ops_) {
+    for (auto *op : ready_var->PendingOps()) {
       auto &deps = --pending_ops[op];
       if (deps == 0) {
         ready_ops.insert(op);
...
@@ -65,11 +65,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(

   // Step 2. Insert FetchOps
   std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
+  std::vector<std::unique_ptr<ir::Node>> tmp_nodes;
   std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
   FeedFetchList fetch_data(fetch_tensors.size());

-  InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
-                 &pending_vars, &ready_vars, &fetch_data);
+  InsertFetchOps(fetch_tensors, &fetch_ops, &tmp_nodes, &fetch_dependencies,
+                 &pending_ops, &pending_vars, &ready_vars, &fetch_data);

   auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
     for (auto *op : set) {
@@ -126,7 +127,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     // Find the ready_ops after the ready_var.
     for (auto ready_var : cur_ready_vars) {
       pending_vars.erase(ready_var);
-      for (auto *op : ready_var->pending_ops_) {
+      for (auto *op : ready_var->PendingOps()) {
         auto &deps = pending_ops[op];
         --deps;
         if (deps == 0) {
@@ -152,6 +153,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
 void ThreadedSSAGraphExecutor::InsertFetchOps(
     const std::vector<std::string> &fetch_tensors,
     std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
+    std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
     std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
     std::unordered_map<OpHandleBase *, size_t> *pending_ops,
     std::unordered_set<VarHandleBase *> *pending_vars,
@@ -170,7 +172,10 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
   for (size_t i = 0; i < fetch_tensors.size(); ++i) {
     auto &var_name = fetch_tensors[i];
     auto &vars = fetched_vars.at(var_name);
-    auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_);
+
+    ir::Node *fetch_n = new ir::Node(ir::Node::Type::kOperation);
+    auto *op = new FetchOpHandle(fetch_n, fetch_data, i, &local_scopes_);
+    temp_nodes->emplace_back(fetch_n);
     fetch_ops->emplace_back(op);

     for (auto &p : places_) {
@@ -181,9 +186,11 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
       op->AddInput(var);
     }

-    auto *fetch_dummy = new DummyVarHandle();
+    ir::Node *dummy_n = new ir::Node(ir::Node::Type::kVariable);
+    auto *fetch_dummy = new DummyVarHandle(dummy_n);
     op->AddOutput(fetch_dummy);
     fetch_dependencies->emplace(fetch_dummy);
+    temp_nodes->emplace_back(dummy_n);
     this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
     this->InsertPendingOp(pending_ops, op);
   }
@@ -199,7 +206,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar(
     std::unordered_set<VarHandleBase *> *pending_vars,
     BlockingQueue<VarHandleBase *> *ready_vars, VarHandleBase *var) const {
   pending_vars->insert(var);
-  if (var->generated_op_ == nullptr) {
+  if (var->GeneratedOp() == nullptr) {
     ready_vars->Push(var);
   }
 }
...
@@ -72,6 +72,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   void InsertFetchOps(
       const std::vector<std::string> &fetch_tensors,
       std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
+      std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
       std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
       std::unordered_map<OpHandleBase *, size_t> *pending_ops,
       std::unordered_set<VarHandleBase *> *pending_vars,
...
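One design point here: fetch ops are created per Run() call rather than being part of the graph, so their ir::Nodes are parked in a per-run temp_nodes vector that dies with the call instead of accumulating in the graph's node list. A simplified sketch of that lifetime (hypothetical stand-in types):

#include <memory>
#include <vector>

struct Node {};  // stand-in for ir::Node
struct FetchOp {
  explicit FetchOp(Node* n) : node(n) {}
  Node* node;  // borrowed from temp_nodes
};

void RunOnce() {
  std::vector<std::unique_ptr<Node>> temp_nodes;    // per-run node owner
  std::vector<std::unique_ptr<FetchOp>> fetch_ops;  // per-run fetch ops

  Node* n = new Node;
  temp_nodes.emplace_back(n);
  fetch_ops.emplace_back(new FetchOp(n));
  // ... execute the graph plus the fetch ops ...
}  // fetch ops and their temporary nodes are destroyed together here

int main() {
  RunOnce();
  return 0;
}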
@@ -13,6 +13,8 @@
 // limitations under the License.

 #pragma once
+
+#include <algorithm>
 #include <sstream>
 #include <string>
 #include <unordered_set>
@@ -30,15 +32,51 @@ class OpHandleBase;
 // A variable can only be generated by a single operator. i.e.
 // This is a single assignment graph.
 struct VarHandleBase {
+  explicit VarHandleBase(ir::Node* node) : node_(node) {}
+
   virtual ~VarHandleBase();
+
   virtual std::string DebugString() const = 0;

+  void AddInput(OpHandleBase* in, ir::Node* node) {
+    node_->inputs.clear();
+    node_->inputs.push_back(node);
+    generated_op_ = in;
+  }
+
+  void AddOutput(OpHandleBase* out, ir::Node* node) {
+    if (pending_ops_.find(out) == pending_ops_.end()) {
+      pending_ops_.insert(out);
+      node_->outputs.push_back(node);
+    }
+  }
+
+  void RemoveOutput(OpHandleBase* out, ir::Node* node) {
+    pending_ops_.erase(out);
+    std::remove(node_->outputs.begin(), node_->outputs.end(), node);
+  }
+
+  void ClearGeneratedOp() {
+    generated_op_ = nullptr;
+    node_->inputs.clear();
+  }
+
+  OpHandleBase* GeneratedOp() { return generated_op_; }
+
+  const std::unordered_set<OpHandleBase*>& PendingOps() const {
+    return pending_ops_;
+  }
+
+  ir::Node* Node() { return node_; }
+
+ protected:
   // The operator who generate this variable. nullptr if the variable
   // is a root node.
   OpHandleBase* generated_op_{nullptr};
+
   // Operators which depend on this variable ready.
   std::unordered_set<OpHandleBase*> pending_ops_;
+  ir::Node* node_;
 };

 // VarHandle is actually a single version of Runtime Variable.
@@ -47,11 +85,14 @@ struct VarHandleBase {
 //
 // NOTE: runtime variables have place.
 struct VarHandle : public VarHandleBase {
+  explicit VarHandle(ir::Node* node) : VarHandleBase(node) {}
+
   std::string DebugString() const override;

-  VarHandle(size_t version, size_t scope_index, std::string name,
-            platform::Place place)
-      : version_(version),
+  VarHandle(ir::Node* node, size_t version, size_t scope_index,
+            std::string name, platform::Place place)
+      : VarHandleBase(node),
+        version_(version),
         scope_idx_(scope_index),
         name_(std::move(name)),
         place_(std::move(place)) {}
@@ -71,6 +112,8 @@ struct VarHandle : public VarHandleBase {

 // Dummy Variable. It is used to represent dependencies between operators
 struct DummyVarHandle : public VarHandleBase {
+  explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {}
+
   std::string DebugString() const override;
 };
...
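One detail worth flagging in RemoveOutput above: std::remove only compacts the surviving elements to the front and returns the new logical end, so without a following erase the node_->outputs vector keeps its old size and its return value is discarded. The conventional erase-remove idiom, as a standalone sketch:

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  std::vector<int> outputs = {1, 2, 3, 2};
  // remove() shifts the kept values forward; erase() actually shrinks.
  outputs.erase(std::remove(outputs.begin(), outputs.end(), 2),
                outputs.end());
  assert(outputs.size() == 2);
  return 0;
}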
@@ -14,10 +14,12 @@ limitations under the License. */

 #pragma once

+#include <algorithm>
 #include <cstdint>
 #include <functional>
 #include <map>
 #include <string>
+#include <unordered_set>
 #include <vector>

 #include "paddle/fluid/platform/macros.h"
 #include "paddle/fluid/platform/variant.h"
@@ -30,10 +32,10 @@ class Node {
  public:
   enum class Type { kNone = -1, kOperation, kVariable };

-  Node(const std::string& name, Type type) : name_(name), type_(type) {}
+  explicit Node(Type type) : type_(type) {}

   virtual ~Node() {
-    for (auto& attr : attrs_) {
+    for (auto &attr : attrs_) {
       if (attr_dels_.find(attr.first) != attr_dels_.end()) {
         attr_dels_[attr.first]();
       }
@@ -42,54 +44,32 @@ class Node {
     attrs_.clear();
   }

-  int64_t ID() const { return id_; }
-
-  std::string Name() const { return name_; }
-
-  virtual std::string ToString() const {
-    return Name() + "(" + std::to_string(ID()) + ")";
-  }
-
-  virtual std::string DebugString() const = 0;
-
   Type NodeType() const { return type_; }

   template <typename AttrType>
-  void Set(const std::string& name, AttrType attr) {
+  void Set(const std::string &name, AttrType attr) {
     attrs_[name] = attr;
   }

   template <typename AttrType>
-  void Set(const std::string& name, AttrType* attr,
+  void Set(const std::string &name, AttrType *attr,
            std::function<void(void)> attr_del) {
     attrs_[name] = attr;
     attr_dels_[name] = attr_del;
   }

-  std::vector<Node*> inputs;
-  std::vector<Node*> outputs;
+  std::vector<Node *> inputs;
+  std::vector<Node *> outputs;

  protected:
   std::map<std::string, boost::any> attrs_;
   std::map<std::string, std::function<void(void)>> attr_dels_;
-  int64_t id_ = 0;
-  std::string name_;
   Type type_;

  private:
   DISABLE_COPY_AND_ASSIGN(Node);
 };

-class Variable : public Node {
- public:
-  explicit Variable(const std::string& name) : Node(name, Type::kVariable) {}
-};
-
-class Operation : public Node {
- public:
-  explicit Operation(const std::string& name) : Node(name, Type::kOperation) {}
-};
-
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
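After this simplification a node carries only its Type, generic attributes, and the inputs/outputs edge lists; the name/ID members and the Variable/Operation subclasses are gone. A usage sketch against the revised interface, assuming the header above is on the include path:

#include <memory>
#include <vector>

#include "paddle/fluid/framework/ir/node.h"  // the header revised above

int main() {
  using paddle::framework::ir::Node;
  std::vector<std::unique_ptr<Node>> nodes;  // a graph would own these
  nodes.emplace_back(new Node(Node::Type::kOperation));
  nodes.emplace_back(new Node(Node::Type::kVariable));
  // Wire the edge in both directions, as the op/var handles now do.
  nodes[0]->outputs.push_back(nodes[1].get());
  nodes[1]->inputs.push_back(nodes[0].get());
  return 0;
}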