提交 37e51443 编写于 作者: X Xin Pan

op compose node and update nodes.

上级 9605fcd1
......@@ -23,10 +23,14 @@ namespace framework {
namespace details {
#ifdef PADDLE_WITH_CUDA
AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(ctxs) {
if (nccl_ctxs_) {
for (auto &p : places_) {
this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
......@@ -34,9 +38,10 @@ AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
}
}
#else
AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
void AllReduceOpHandle::RunImpl() {
......
......@@ -30,11 +30,11 @@ namespace details {
struct AllReduceOpHandle : public OpHandleBase {
#ifdef PADDLE_WITH_CUDA
AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs);
#else
AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
#endif
std::string Name() const override;
......
......@@ -35,10 +35,13 @@ namespace details {
struct BroadcastOpHandle : public OpHandleBase {
public:
#ifdef PADDLE_WITH_CUDA
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
BroadcastOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *nccl_ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) {
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(nccl_ctxs) {
if (nccl_ctxs_) {
for (auto &p_ctx : nccl_ctxs_->contexts_) {
dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
......@@ -46,9 +49,9 @@ struct BroadcastOpHandle : public OpHandleBase {
}
}
#else
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
BroadcastOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
std::string Name() const override;
......
......@@ -96,48 +96,56 @@ struct TestBroadcastOpHandle {
}
param_scopes_[input_scope_idx]->Var("input");
std::unique_ptr<ir::Node> n(new ir::Node(ir::Node::Type::kOperation));
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(
new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
nccl_ctxs_.get()));
#else
PADDLE_THROW("CUDA is not support.");
#endif
} else {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(
new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
nccl_ctxs_.get()));
#else
op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
op_handle_.reset(
new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_));
#endif
}
auto* in_var_handle =
new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
std::unique_ptr<ir::Node> v(new ir::Node(ir::Node::Type::kVariable));
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
gpu_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
// add dummy var
vars_.emplace_back(new DummyVarHandle());
std::unique_ptr<ir::Node> v2(new ir::Node(ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(v2.get()));
DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
dummy_var_handle->generated_op_ = nullptr;
dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(dummy_var_handle);
for (size_t j = 0; j < gpu_list_.size(); ++j) {
if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
}
VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
std::unique_ptr<ir::Node> v3(new ir::Node(ir::Node::Type::kVariable));
VarHandle* out_var_handle =
new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
}
// add dummy var
vars_.emplace_back(new DummyVarHandle());
std::unique_ptr<ir::Node> v4(new ir::Node(ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(v4.get()));
DummyVarHandle* out_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
out_dummy_var_handle->generated_op_ = nullptr;
out_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddOutput(out_dummy_var_handle);
}
......
......@@ -19,9 +19,10 @@
namespace paddle {
namespace framework {
namespace details {
ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
platform::Place place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
ComputationOpHandle::ComputationOpHandle(ir::Node *node, const OpDesc &op_desc,
Scope *scope, platform::Place place)
: OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(op_desc)),
scope_(scope),
place_(place) {}
......@@ -35,8 +36,8 @@ void ComputationOpHandle::RunImpl() {
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
bool need_wait =
in_var && in_var->generated_op_ &&
in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_];
in_var && in_var->GeneratedOp() &&
in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_[place_];
return need_wait;
}
......
......@@ -28,7 +28,7 @@ namespace framework {
namespace details {
struct ComputationOpHandle : public OpHandleBase {
public:
ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
ComputationOpHandle(ir::Node *node, const OpDesc &op_desc, Scope *scope,
platform::Place place);
std::string Name() const override;
......
......@@ -22,10 +22,10 @@ namespace details {
#ifdef PADDLE_WITH_CUDA
DataBalanceOpHandle::DataBalanceOpHandle(
const std::vector<Scope *> &local_scopes,
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs)
: local_scopes_(local_scopes), places_(places) {
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {
if (ctxs) {
for (auto &p : places_) {
this->dev_ctxes_[p] = ctxs->DevCtx(p);
......@@ -34,9 +34,9 @@ DataBalanceOpHandle::DataBalanceOpHandle(
}
#else
DataBalanceOpHandle::DataBalanceOpHandle(
const std::vector<Scope *> &local_scopes,
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
std::string DataBalanceOpHandle::Name() const { return "data balance"; }
......
......@@ -30,11 +30,11 @@ namespace details {
struct DataBalanceOpHandle : public OpHandleBase {
public:
#ifdef PADDLE_WITH_CUDA
DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs);
#else
DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
#endif
......
......@@ -21,13 +21,16 @@ namespace paddle {
namespace framework {
namespace details {
FetchOpHandle::FetchOpHandle(FeedFetchList *data, size_t offset,
FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes)
: data_(data), offset_(offset), local_scopes_(local_scopes) {}
: OpHandleBase(node),
data_(data),
offset_(offset),
local_scopes_(local_scopes) {}
FetchOpHandle::~FetchOpHandle() {
for (auto *input_var : inputs_) {
input_var->pending_ops_.erase(this);
input_var->RemoveOutput(this, this->Node());
}
}
......@@ -77,8 +80,8 @@ void FetchOpHandle::RunImpl() {
void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) {
auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place);
for (auto *input : inputs_) {
if (input->generated_op_) {
input->generated_op_->RecordWaitEventOnCtx(cpu_ctx);
if (input->GeneratedOp()) {
input->GeneratedOp()->RecordWaitEventOnCtx(cpu_ctx);
}
}
}
......
......@@ -28,7 +28,7 @@ namespace details {
struct FetchOpHandle : public OpHandleBase {
public:
FetchOpHandle(FeedFetchList *data, size_t offset,
FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes);
~FetchOpHandle();
......
......@@ -30,10 +30,12 @@ namespace details {
struct FuseVarsOpHandle : public OpHandleBase {
public:
FuseVarsOpHandle(Scope *local_scope, const platform::Place &place,
FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
const platform::Place &place,
const std::unordered_map<std::string, int64_t> &inputs_numel,
const std::type_index &var_type)
: local_scope_(local_scope),
: OpHandleBase(node),
local_scope_(local_scope),
place_(place),
inputs_numel_(inputs_numel),
type_(var_type) {
......
......@@ -20,9 +20,10 @@ namespace paddle {
namespace framework {
namespace details {
GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
GatherOpHandle::GatherOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
void GatherOpHandle::RunImpl() {
if (places_.size() == 1) return;
......
......@@ -30,7 +30,7 @@ namespace details {
struct GatherOpHandle : public OpHandleBase {
public:
GatherOpHandle(const std::vector<Scope *> &local_scopes,
GatherOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
std::string Name() const override;
......
......@@ -70,6 +70,7 @@ struct TestGatherOpHandle {
}
void InitGatherOp(size_t input_scope_idx) {
std::vector<std::unique_ptr<ir::Node>> nodes;
for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope();
......@@ -81,30 +82,37 @@ struct TestGatherOpHandle {
}
param_scopes_[input_scope_idx]->Var("out");
op_handle_.reset(new GatherOpHandle(local_scopes_, gpu_list_));
nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
op_handle_.reset(
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
// add input
for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
auto* in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto* in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
}
// add dummy var
vars_.emplace_back(new DummyVarHandle());
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* in_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
in_dummy_var_handle->generated_op_ = nullptr;
in_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(in_dummy_var_handle);
// add output
auto* out_var_handle =
new VarHandle(2, input_scope_idx, "out", gpu_list_[input_scope_idx]);
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
"out", gpu_list_[input_scope_idx]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
// add dummy var
vars_.emplace_back(new DummyVarHandle());
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
op_handle_->AddOutput(dummy_var_handle);
......
......@@ -328,12 +328,16 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
const std::string &p_name,
size_t src_dev_id) const {
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
#ifdef PADDLE_WITH_CUDA
auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_);
auto *op_handle = new BroadcastOpHandle(result->nodes.back().get(),
local_scopes_, places_, nccl_ctxs_);
#else
auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
auto *op_handle =
new BroadcastOpHandle(result->nodes.back().get(), local_scopes_, places_);
#endif
result->Get<GraphOps>("ops").emplace_back(op_handle);
auto *in =
result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get();
op_handle->AddInput(in);
......@@ -341,8 +345,10 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
auto *out_var = new VarHandle(vars.size(), i, p_name, p);
auto *out_var =
new VarHandle(result->nodes.back().get(), vars.size(), i, p_name, p);
vars.emplace_back(out_var);
op_handle->AddOutput(out_var);
}
......@@ -351,19 +357,21 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
const OpDesc &op,
int dev_id) const {
result->Get<GraphOps>("ops").emplace_back(
new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
result->Get<GraphOps>("ops").emplace_back(new ComputationOpHandle(
result->nodes.back().get(), op, local_scopes_[dev_id], places_[dev_id]));
CreateOpHandleIOs(result, op, dev_id);
}
void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
const std::string &og) const {
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
#ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(
new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_));
#else
result->Get<GraphOps>("ops").emplace_back(
new AllReduceOpHandle(local_scopes_, places_));
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
result->nodes.back().get(), local_scopes_, places_));
#endif
auto *op_handle = result->Get<GraphOps>("ops").back().get();
......@@ -375,7 +383,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get());
auto var = new VarHandle(vars.size(), i, og, p);
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto var = new VarHandle(result->nodes.back().get(), vars.size(), i, og, p);
vars.emplace_back(var);
op_handle->AddOutput(var);
}
......@@ -383,12 +392,13 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
Graph *result, const std::vector<std::string> &datas) const {
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
#ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(
new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_));
#else
result->Get<GraphOps>("ops").emplace_back(
new DataBalanceOpHandle(local_scopes_, places_));
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
result->nodes.back().get(), local_scopes_, places_));
#endif
auto *op_handle = result->Get<GraphOps>("ops").back().get();
for (size_t i = 0; i < places_.size(); ++i) {
......@@ -398,7 +408,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
auto &vars = result->Get<GraphVars>("vars")[i][d_name];
PADDLE_ENFORCE(!vars.empty());
op_handle->AddInput(vars.back().get());
auto var = new VarHandle(vars.size(), i, d_name, p);
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto var =
new VarHandle(result->nodes.back().get(), vars.size(), i, d_name, p);
vars.emplace_back(var);
op_handle->AddOutput(var);
}
......@@ -452,10 +464,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
auto *communication_dev_ctx =
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
#endif
auto *op_handle =
new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i],
places_[i], communication_dev_ctx);
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
auto *op_handle = new ScaleLossGradOpHandle(
result->nodes.back().get(), local_scopes_.size(), local_scopes_[i],
places_[i], communication_dev_ctx);
result->Get<GraphOps>("ops").emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
......@@ -475,8 +487,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx];
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
result->Get<GraphOps>("ops").emplace_back(
new ComputationOpHandle(op, s, p));
new ComputationOpHandle(result->nodes.back().get(), op, s, p));
CreateOpHandleIOs(result, op, scope_idx);
}
}
......@@ -484,12 +497,13 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
const std::string &og,
int dst_dev_id) const {
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
#ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(
new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_));
#else
result->Get<GraphOps>("ops").emplace_back(
new ReduceOpHandle(local_scopes_, places_));
new ReduceOpHandle(result->nodes.back().get(), local_scopes_, places_));
#endif
auto *op_handle = result->Get<GraphOps>("ops").back().get();
......@@ -502,7 +516,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
op_handle->AddInput(prev_grad.get());
}
auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto var = new VarHandle(result->nodes.back().get(), vars.size(), dst_dev_id,
og, places_[dst_dev_id]);
vars.emplace_back(var);
op_handle->AddOutput(var);
return var;
......@@ -514,7 +530,8 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const {
for (auto &prev_op : result->Get<GraphOps>("ops")) {
if (prev_op->Name() == prev_op_name) {
auto *dep_var = new DummyVarHandle();
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto *dep_var = new DummyVarHandle(result->nodes.back().get());
prev_op->AddOutput(dep_var);
result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
op->AddInput(dep_var);
......@@ -587,8 +604,10 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result,
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
op.Type());
result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
op, local_scopes_[op_dev_id], op.Type(), places_[op_dev_id]));
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
result->Get<GraphOps>("ops").emplace_back(
new RPCOpHandle(result->nodes.back().get(), op, local_scopes_[op_dev_id],
op.Type(), places_[op_dev_id]));
if (op.Type() == "send_barrier") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
......
......@@ -80,19 +80,21 @@ void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
void OpHandleBase::AddInput(VarHandleBase *in) {
this->inputs_.emplace_back(in);
in->pending_ops_.insert(this);
node_->inputs.push_back(in->Node());
in->AddOutput(this, this->Node());
}
void OpHandleBase::AddOutput(VarHandleBase *out) {
outputs_.emplace_back(out);
out->generated_op_ = this;
node_->outputs.push_back(out->Node());
out->AddInput(this, this->Node());
}
void OpHandleBase::WaitInputVarGenerated() {
for (auto in_var : inputs_) {
if (NeedWait(in_var)) {
for (auto &pair : dev_ctxes_) {
in_var->generated_op_->RecordWaitEventOnCtx(pair.second);
in_var->GeneratedOp()->RecordWaitEventOnCtx(pair.second);
}
}
}
......@@ -101,7 +103,7 @@ void OpHandleBase::WaitInputVarGenerated() {
void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
for (auto *in : inputs_) {
if (NeedWait(in)) {
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]);
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[place]);
}
}
}
......@@ -117,7 +119,7 @@ size_t OpHandleBase::NoDummyInputSize() const {
}
bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
return in_var && in_var->generated_op_;
return in_var && in_var->GeneratedOp();
}
void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
......
......@@ -17,6 +17,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/macros.h"
......@@ -28,7 +29,7 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
class OpHandleBase {
public:
OpHandleBase() {}
explicit OpHandleBase(ir::Node *node) : node_(node) {}
virtual ~OpHandleBase();
......@@ -82,6 +83,8 @@ class OpHandleBase {
size_t NoDummyInputSize() const;
ir::Node *Node() { return node_; }
protected:
void RunAndRecordEvent(const std::function<void()> &callback);
......@@ -90,6 +93,7 @@ class OpHandleBase {
virtual void RunImpl() = 0;
ir::Node *node_;
std::vector<VarHandleBase *> inputs_;
std::vector<VarHandleBase *> outputs_;
std::map<platform::Place, platform::DeviceContext *> dev_ctxes_;
......
......@@ -37,10 +37,13 @@ struct ReduceOpHandle : public OpHandleBase {
#ifdef PADDLE_WITH_CUDA
const platform::NCCLContextMap *nccl_ctxs_;
ReduceOpHandle(const std::vector<Scope *> &local_scopes,
ReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *nccl_ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) {
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(nccl_ctxs) {
if (nccl_ctxs_) {
for (auto &p_ctx : nccl_ctxs_->contexts_) {
dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
......@@ -48,9 +51,9 @@ struct ReduceOpHandle : public OpHandleBase {
}
}
#else
ReduceOpHandle(const std::vector<Scope *> &local_scopes,
ReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
std::string Name() const override;
......
......@@ -84,6 +84,7 @@ struct TestReduceOpHandle {
}
void InitReduceOp(size_t out_scope_idx) {
std::vector<std::unique_ptr<ir::Node>> nodes;
// init scope
for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
......@@ -96,19 +97,21 @@ struct TestReduceOpHandle {
}
param_scopes_[out_scope_idx]->Var("out");
nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(
new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_,
gpu_list_, nccl_ctxs_.get()));
#else
PADDLE_THROW("CUDA is not support.");
#endif
} else {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(
new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_,
gpu_list_, nccl_ctxs_.get()));
#else
op_handle_.reset(new ReduceOpHandle(local_scopes_, gpu_list_));
op_handle_.reset(
new ReduceOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
#endif
}
......@@ -118,8 +121,10 @@ struct TestReduceOpHandle {
if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
}
auto *in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
in_var_handle->generated_op_ = nullptr;
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto *in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
in_var_handle->ClearGeneratedOp();
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
}
......@@ -128,12 +133,13 @@ struct TestReduceOpHandle {
vars_.emplace_back(new DummyVarHandle());
DummyVarHandle *in_dummy_var_handle =
static_cast<DummyVarHandle *>(vars_.back().get());
in_dummy_var_handle->generated_op_ = nullptr;
in_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(in_dummy_var_handle);
// add output
auto *out_var_handle =
new VarHandle(2, out_scope_idx, "out", gpu_list_[out_scope_idx]);
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx,
"out", gpu_list_[out_scope_idx]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
......
......@@ -18,10 +18,11 @@ namespace paddle {
namespace framework {
namespace details {
RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
const Scope *local_scope, const std::string &name,
const platform::Place &place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
: OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope),
name_(name),
place_(place) {}
......@@ -35,8 +36,8 @@ void RPCOpHandle::RunImpl() {
if (in->DebugString() == "dummy") { // HACK
continue;
}
if (in->generated_op_) {
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]);
if (in->GeneratedOp()) {
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[p]);
}
}
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
......
......@@ -28,8 +28,9 @@ namespace framework {
namespace details {
struct RPCOpHandle : public OpHandleBase {
RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const std::string& name, const platform::Place& place);
RPCOpHandle(ir::Node* node, const framework::OpDesc& op_desc,
const Scope* local_scope, const std::string& name,
const platform::Place& place);
std::string Name() const override;
......
......@@ -19,10 +19,14 @@
namespace paddle {
namespace framework {
namespace details {
ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev,
Scope *scope,
platform::Place place,
platform::DeviceContext *dev_ctx)
: coeff_(static_cast<float>(1.0 / num_dev)), scope_(scope), place_(place) {
: OpHandleBase(node),
coeff_(static_cast<float>(1.0 / num_dev)),
scope_(scope),
place_(place) {
dev_ctxes_[place_] = dev_ctx;
}
......
......@@ -25,7 +25,8 @@ namespace framework {
namespace details {
struct ScaleLossGradOpHandle : public OpHandleBase {
ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place,
ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope,
platform::Place place,
platform::DeviceContext *context);
~ScaleLossGradOpHandle() final;
......
......@@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
auto it_old = name_pair.second.rbegin();
++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
auto *write_op = (*it_new)->generated_op_;
auto &read_ops = (*it_old)->pending_ops_;
OpHandleBase *write_op = (*it_new)->GeneratedOp();
const auto &read_ops = (*it_old)->PendingOps();
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
......@@ -37,7 +37,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
continue;
}
auto *dep_var = new DummyVarHandle();
graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto *dep_var = new DummyVarHandle(graph->nodes.back().get());
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
......@@ -54,7 +55,9 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr;
if (var_holder.empty()) {
var = new VarHandle(0, place_offset, each_var_name, place);
graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
var = new VarHandle(graph->nodes.back().get(), 0, place_offset,
each_var_name, place);
var_holder.emplace_back(var);
} else {
var = var_holder.rbegin()->get();
......@@ -68,7 +71,9 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
size_t place_offset) {
auto &vars = graph->Get<GraphVars>("vars")[place_offset][each_var_name];
size_t version = vars.size();
auto var = new VarHandle(version, place_offset, each_var_name, place);
graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto var = new VarHandle(graph->nodes.back().get(), version, place_offset,
each_var_name, place);
vars.emplace_back(var);
op_handle->AddOutput(var);
}
......@@ -80,7 +85,8 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
if (!op->Outputs().empty()) {
continue;
}
auto *dummy_leaf = new DummyVarHandle();
graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto *dummy_leaf = new DummyVarHandle(graph->nodes.back().get());
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
op->AddOutput(dummy_leaf);
}
......
......@@ -28,7 +28,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const {
auto insert_pending_var = [&](VarHandleBase *var) {
pending_vars.insert(var);
if (var->generated_op_ == nullptr) {
if (var->GeneratedOp() == nullptr) {
ready_vars.emplace(var);
}
};
......@@ -71,7 +71,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const {
for (auto ready_var : ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->pending_ops_) {
for (auto *op : ready_var->PendingOps()) {
auto &deps = --pending_ops[op];
if (deps == 0) {
ready_ops.insert(op);
......
......@@ -65,11 +65,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
std::vector<std::unique_ptr<ir::Node>> tmp_nodes;
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
&pending_vars, &ready_vars, &fetch_data);
InsertFetchOps(fetch_tensors, &fetch_ops, &tmp_nodes, &fetch_dependencies,
&pending_ops, &pending_vars, &ready_vars, &fetch_data);
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) {
......@@ -126,7 +127,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Find the ready_ops after the ready_var.
for (auto ready_var : cur_ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->pending_ops_) {
for (auto *op : ready_var->PendingOps()) {
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
......@@ -152,6 +153,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars,
......@@ -170,7 +172,10 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i];
auto &vars = fetched_vars.at(var_name);
auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_);
ir::Node *fetch_n = new ir::Node(ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_n, fetch_data, i, &local_scopes_);
temp_nodes->emplace_back(fetch_n);
fetch_ops->emplace_back(op);
for (auto &p : places_) {
......@@ -181,9 +186,11 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
op->AddInput(var);
}
auto *fetch_dummy = new DummyVarHandle();
ir::Node *dummy_n = new ir::Node(ir::Node::Type::kVariable);
auto *fetch_dummy = new DummyVarHandle(dummy_n);
op->AddOutput(fetch_dummy);
fetch_dependencies->emplace(fetch_dummy);
temp_nodes->emplace_back(dummy_n);
this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
this->InsertPendingOp(pending_ops, op);
}
......@@ -199,7 +206,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar(
std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, VarHandleBase *var) const {
pending_vars->insert(var);
if (var->generated_op_ == nullptr) {
if (var->GeneratedOp() == nullptr) {
ready_vars->Push(var);
}
}
......
......@@ -72,6 +72,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
void InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars,
......
......@@ -13,6 +13,8 @@
// limitations under the License.
#pragma once
#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_set>
......@@ -30,15 +32,51 @@ class OpHandleBase;
// A variable can only be generated by a single operator. i.e.
// This is a single assignment graph.
struct VarHandleBase {
explicit VarHandleBase(ir::Node* node) : node_(node) {}
virtual ~VarHandleBase();
virtual std::string DebugString() const = 0;
void AddInput(OpHandleBase* in, ir::Node* node) {
node_->inputs.clear();
node_->inputs.push_back(node);
generated_op_ = in;
}
void AddOutput(OpHandleBase* out, ir::Node* node) {
if (pending_ops_.find(out) == pending_ops_.end()) {
pending_ops_.insert(out);
node_->outputs.push_back(node);
}
}
void RemoveOutput(OpHandleBase* out, ir::Node* node) {
pending_ops_.erase(out);
std::remove(node_->outputs.begin(), node_->outputs.end(), node);
}
void ClearGeneratedOp() {
generated_op_ = nullptr;
node_->inputs.clear();
}
OpHandleBase* GeneratedOp() { return generated_op_; }
const std::unordered_set<OpHandleBase*>& PendingOps() const {
return pending_ops_;
}
ir::Node* Node() { return node_; }
protected:
// The operator who generate this variable. nullptr if the variable
// is a root node.
OpHandleBase* generated_op_{nullptr};
// Operators which depend on this variable ready.
std::unordered_set<OpHandleBase*> pending_ops_;
ir::Node* node_;
};
// VarHandle is actually a single version of Runtime Variable.
......@@ -47,11 +85,14 @@ struct VarHandleBase {
//
// NOTE: runtime variables have place.
struct VarHandle : public VarHandleBase {
explicit VarHandle(ir::Node* node) : VarHandleBase(node) {}
std::string DebugString() const override;
VarHandle(size_t version, size_t scope_index, std::string name,
platform::Place place)
: version_(version),
VarHandle(ir::Node* node, size_t version, size_t scope_index,
std::string name, platform::Place place)
: VarHandleBase(node),
version_(version),
scope_idx_(scope_index),
name_(std::move(name)),
place_(std::move(place)) {}
......@@ -71,6 +112,8 @@ struct VarHandle : public VarHandleBase {
// Dummy Variable. It is used to represent dependencies between operators
struct DummyVarHandle : public VarHandleBase {
explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {}
std::string DebugString() const override;
};
......
......@@ -14,10 +14,12 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/variant.h"
......@@ -30,10 +32,10 @@ class Node {
public:
enum class Type { kNone = -1, kOperation, kVariable };
Node(const std::string& name, Type type) : name_(name), type_(type) {}
explicit Node(Type type) : type_(type) {}
virtual ~Node() {
for (auto& attr : attrs_) {
for (auto &attr : attrs_) {
if (attr_dels_.find(attr.first) != attr_dels_.end()) {
attr_dels_[attr.first]();
}
......@@ -42,54 +44,32 @@ class Node {
attrs_.clear();
}
int64_t ID() const { return id_; }
std::string Name() const { return name_; }
virtual std::string ToString() const {
return Name() + "(" + std::to_string(ID()) + ")";
}
virtual std::string DebugString() const = 0;
Type NodeType() const { return type_; }
template <typename AttrType>
void Set(const std::string& name, AttrType attr) {
void Set(const std::string &name, AttrType attr) {
attrs_[name] = attr;
}
template <typename AttrType>
void Set(const std::string& name, AttrType* attr,
void Set(const std::string &name, AttrType *attr,
std::function<void(void)> attr_del) {
attrs_[name] = attr;
attr_dels_[name] = attr_del;
}
std::vector<Node*> inputs;
std::vector<Node*> outputs;
std::vector<Node *> inputs;
std::vector<Node *> outputs;
protected:
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
int64_t id_ = 0;
std::string name_;
Type type_;
private:
DISABLE_COPY_AND_ASSIGN(Node);
};
class Variable : public Node {
public:
explicit Variable(const std::string& name) : Node(name, Type::kVariable) {}
};
class Operation : public Node {
public:
explicit Operation(const std::string& name) : Node(name, Type::kOperation) {}
};
} // namespace ir
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册