提交 ff5a7b67 编写于 作者: X Xin Pan

polish

上级 a891708d
...@@ -96,7 +96,8 @@ struct TestBroadcastOpHandle { ...@@ -96,7 +96,8 @@ struct TestBroadcastOpHandle {
} }
param_scopes_[input_scope_idx]->Var("input"); param_scopes_[input_scope_idx]->Var("input");
std::unique_ptr<ir::Node> n(new ir::Node("node0")); std::unique_ptr<ir::Node> n(
new ir::Node("node0", ir::Node::Type::kOperation));
if (use_gpu_) { if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
...@@ -114,7 +115,8 @@ struct TestBroadcastOpHandle { ...@@ -114,7 +115,8 @@ struct TestBroadcastOpHandle {
#endif #endif
} }
std::unique_ptr<ir::Node> v(new ir::Node("node1")); std::unique_ptr<ir::Node> v(
new ir::Node("node1", ir::Node::Type::kVariable));
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input", auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
gpu_list_[input_scope_idx]); gpu_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
...@@ -122,7 +124,8 @@ struct TestBroadcastOpHandle { ...@@ -122,7 +124,8 @@ struct TestBroadcastOpHandle {
// add dummy var // add dummy var
std::unique_ptr<ir::Node> v2(new ir::Node("node2")); std::unique_ptr<ir::Node> v2(
new ir::Node("node2", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(v2.get())); vars_.emplace_back(new DummyVarHandle(v2.get()));
DummyVarHandle* dummy_var_handle = DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
...@@ -133,7 +136,8 @@ struct TestBroadcastOpHandle { ...@@ -133,7 +136,8 @@ struct TestBroadcastOpHandle {
if (!use_gpu_) { if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
} }
std::unique_ptr<ir::Node> v3(new ir::Node("node3")); std::unique_ptr<ir::Node> v3(
new ir::Node("node3", ir::Node::Type::kVariable));
VarHandle* out_var_handle = VarHandle* out_var_handle =
new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]); new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
...@@ -141,7 +145,8 @@ struct TestBroadcastOpHandle { ...@@ -141,7 +145,8 @@ struct TestBroadcastOpHandle {
} }
// add dummy var // add dummy var
std::unique_ptr<ir::Node> v4(new ir::Node("node4")); std::unique_ptr<ir::Node> v4(
new ir::Node("node4", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(v4.get())); vars_.emplace_back(new DummyVarHandle(v4.get()));
DummyVarHandle* out_dummy_var_handle = DummyVarHandle* out_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
......
...@@ -82,13 +82,13 @@ struct TestGatherOpHandle { ...@@ -82,13 +82,13 @@ struct TestGatherOpHandle {
} }
param_scopes_[input_scope_idx]->Var("out"); param_scopes_[input_scope_idx]->Var("out");
nodes.emplace_back(new ir::Node("node")); nodes.emplace_back(new ir::Node("node", ir::Node::Type::kOperation));
op_handle_.reset( op_handle_.reset(
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
// add input // add input
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
nodes.emplace_back(new ir::Node("node1")); nodes.emplace_back(new ir::Node("node1", ir::Node::Type::kVariable));
auto* in_var_handle = auto* in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
...@@ -96,7 +96,7 @@ struct TestGatherOpHandle { ...@@ -96,7 +96,7 @@ struct TestGatherOpHandle {
} }
// add dummy var // add dummy var
nodes.emplace_back(new ir::Node("node2")); nodes.emplace_back(new ir::Node("node2", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(nodes.back().get())); vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* in_dummy_var_handle = DummyVarHandle* in_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
...@@ -104,14 +104,14 @@ struct TestGatherOpHandle { ...@@ -104,14 +104,14 @@ struct TestGatherOpHandle {
op_handle_->AddInput(in_dummy_var_handle); op_handle_->AddInput(in_dummy_var_handle);
// add output // add output
nodes.emplace_back(new ir::Node("node3")); nodes.emplace_back(new ir::Node("node3", ir::Node::Type::kVariable));
auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx, auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
"out", gpu_list_[input_scope_idx]); "out", gpu_list_[input_scope_idx]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
// add dummy var // add dummy var
nodes.emplace_back(new ir::Node("node4")); nodes.emplace_back(new ir::Node("node4", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(nodes.back().get())); vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* dummy_var_handle = DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
......
...@@ -80,7 +80,14 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node, ...@@ -80,7 +80,14 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
} }
for (ir::Node *output : node->outputs) { for (ir::Node *output : node->outputs) {
CreateOpOutput(result, op_handle, output, p, place_id); ir::Node *new_node = nullptr;
if (output->Var()) {
new_node = result->CreateVarNode(output->Var());
} else {
new_node =
result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
}
CreateOpOutput(result, op_handle, new_node, p, place_id);
} }
} }
...@@ -246,7 +253,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -246,7 +253,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) { if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
node->Op()->SetAttr("throw_eof_exp", false); node->Op()->SetAttr("throw_eof_exp", false);
CreateComputationalOps(&result, node.get(), places_.size()); CreateComputationalOps(&result, node.get(), places_.size());
// TODO(panyx0718): builder shouldn't depend on the out logic of // TODO(paddle-dev): builder shouldn't depend on the out logic of
// a specific op. // a specific op.
const auto &data_var_names = node->Op()->Output("Out"); const auto &data_var_names = node->Op()->Output("Out");
InsertDataBalanceOp(&result, data_var_names); InsertDataBalanceOp(&result, data_var_names);
...@@ -354,10 +361,12 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, ...@@ -354,10 +361,12 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
const std::string &p_name, const std::string &p_name,
size_t src_dev_id) const { size_t src_dev_id) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"), auto *op_handle = new BroadcastOpHandle(
result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_); local_scopes_, places_, nccl_ctxs_);
#else #else
auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"), auto *op_handle = new BroadcastOpHandle(
result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
local_scopes_, places_); local_scopes_, places_);
#endif #endif
result->Get<GraphOps>("ops").emplace_back(op_handle); result->Get<GraphOps>("ops").emplace_back(op_handle);
...@@ -370,7 +379,8 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, ...@@ -370,7 +379,8 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name); auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
auto *out_var = new VarHandle(result->CreateEmptyNode(p_name), vars.size(), auto *out_var = new VarHandle(
result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
i, p_name, p); i, p_name, p);
vars.emplace_back(out_var); vars.emplace_back(out_var);
op_handle->AddOutput(out_var); op_handle->AddOutput(out_var);
...@@ -389,12 +399,13 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, ...@@ -389,12 +399,13 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
const std::string &og) const { const std::string &og) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back( result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
new AllReduceOpHandle(result->CreateEmptyNode("allreduce"), local_scopes_, result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
places_, nccl_ctxs_)); local_scopes_, places_, nccl_ctxs_));
#else #else
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle( result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce"), local_scopes_, places_)); result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>("ops").back().get();
...@@ -407,7 +418,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, ...@@ -407,7 +418,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
auto var = auto var =
new VarHandle(result->CreateEmptyNode(og), vars.size(), i, og, p); new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
vars.size(), i, og, p);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
...@@ -416,12 +428,13 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, ...@@ -416,12 +428,13 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
void MultiDevSSAGraphBuilder::InsertDataBalanceOp( void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
Graph *result, const std::vector<std::string> &datas) const { Graph *result, const std::vector<std::string> &datas) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back( result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
new DataBalanceOpHandle(result->CreateEmptyNode("data_balance"), result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_)); local_scopes_, places_, nccl_ctxs_));
#else #else
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle( result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance"), local_scopes_, places_)); result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>("ops").back().get();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
...@@ -431,8 +444,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp( ...@@ -431,8 +444,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
auto &vars = result->Get<GraphVars>("vars")[i][d_name]; auto &vars = result->Get<GraphVars>("vars")[i][d_name];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
op_handle->AddInput(vars.back().get()); op_handle->AddInput(vars.back().get());
auto var = new VarHandle(result->CreateEmptyNode(d_name), vars.size(), i, auto var = new VarHandle(
d_name, p); result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
vars.size(), i, d_name, p);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
...@@ -487,8 +501,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { ...@@ -487,8 +501,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
#endif #endif
auto *op_handle = new ScaleLossGradOpHandle( auto *op_handle = new ScaleLossGradOpHandle(
result->CreateEmptyNode("scale_loss_grad"), local_scopes_.size(), result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
local_scopes_[i], places_[i], communication_dev_ctx); local_scopes_.size(), local_scopes_[i], places_[i],
communication_dev_ctx);
result->Get<GraphOps>("ops").emplace_back(op_handle); result->Get<GraphOps>("ops").emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale // FIXME: Currently ScaleLossGradOp only use device_count as scale
...@@ -497,14 +512,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { ...@@ -497,14 +512,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
// loss->pending_ops_.emplace_back(op_handle); // loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss); // op_handle->inputs_.emplace_back(loss);
// TODO(panyx0718): GradVarName(loss_var_name_) CreateOpOutput(result, op_handle,
const std::string grad_var_name = GradVarName(loss_var_name_); result->CreateEmptyNode(GradVarName(loss_var_name_),
auto &vars = result->Get<GraphVars>("vars")[i][grad_var_name]; ir::Node::Type::kVariable),
size_t version = vars.size(); places_[i], i);
auto var = new VarHandle(result->CreateEmptyNode(grad_var_name), version, i,
grad_var_name, places_[i]);
vars.emplace_back(var);
op_handle->AddOutput(var);
} }
} }
...@@ -525,10 +536,12 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, ...@@ -525,10 +536,12 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
int dst_dev_id) const { int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle( result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce"), local_scopes_, places_, nccl_ctxs_)); result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_));
#else #else
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle( result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce"), local_scopes_, places_)); result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>("ops").back().get();
...@@ -541,8 +554,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, ...@@ -541,8 +554,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
} }
auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og]; auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
auto var = new VarHandle(result->CreateEmptyNode(og), vars.size(), dst_dev_id, auto var =
og, places_[dst_dev_id]); new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
vars.size(), dst_dev_id, og, places_[dst_dev_id]);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
return var; return var;
...@@ -554,7 +568,8 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, ...@@ -554,7 +568,8 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const { const std::string &prev_op_name) const {
for (auto &prev_op : result->Get<GraphOps>("ops")) { for (auto &prev_op : result->Get<GraphOps>("ops")) {
if (prev_op->Name() == prev_op_name) { if (prev_op->Name() == prev_op_name) {
auto *dep_var = new DummyVarHandle(result->CreateEmptyNode("dummy")); auto *dep_var = new DummyVarHandle(
result->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
prev_op->AddOutput(dep_var); prev_op->AddOutput(dep_var);
result->Get<GraphDepVars>("dep_vars").emplace(dep_var); result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
op->AddInput(dep_var); op->AddInput(dep_var);
......
...@@ -37,7 +37,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { ...@@ -37,7 +37,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
continue; continue;
} }
auto *dep_var = new DummyVarHandle(graph->CreateEmptyNode("dummy")); auto *dep_var = new DummyVarHandle(
graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
read_op->AddOutput(dep_var); read_op->AddOutput(dep_var);
write_op->AddInput(dep_var); write_op->AddInput(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var); graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
...@@ -54,12 +55,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( ...@@ -54,12 +55,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
auto &var_holder = var_holders[node->Name()]; auto &var_holder = var_holders[node->Name()];
VarHandle *var = nullptr; VarHandle *var = nullptr;
if (var_holder.empty()) { if (var_holder.empty()) {
if (node->NodeType() == ir::Node::Type::kVariable) { if (node->Var()) {
var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset, var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset,
node->Name(), place); node->Name(), place);
} else { } else {
var = new VarHandle(graph->CreateEmptyNode(node->Name()), 0, place_offset, var = new VarHandle(
node->Name(), place); graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable), 0,
place_offset, node->Name(), place);
} }
var_holder.emplace_back(var); var_holder.emplace_back(var);
} else { } else {
...@@ -69,13 +71,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( ...@@ -69,13 +71,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
} }
void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
ir::Node *node, ir::Node *new_node,
const platform::Place &place, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
auto &vars = graph->Get<GraphVars>("vars")[place_offset][node->Name()]; auto &vars = graph->Get<GraphVars>("vars")[place_offset][new_node->Name()];
size_t version = vars.size(); size_t version = vars.size();
auto var = new VarHandle(graph->CreateVarNode(node->Var()), version, auto var =
place_offset, node->Name(), place); new VarHandle(new_node, version, place_offset, new_node->Name(), place);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
...@@ -85,7 +87,8 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { ...@@ -85,7 +87,8 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
if (!op->Outputs().empty()) { if (!op->Outputs().empty()) {
continue; continue;
} }
auto *dummy_leaf = new DummyVarHandle(graph->CreateEmptyNode("dummy")); auto *dummy_leaf = new DummyVarHandle(
graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf); graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
op->AddOutput(dummy_leaf); op->AddOutput(dummy_leaf);
} }
......
...@@ -73,7 +73,7 @@ class SSAGraphBuilder : public ir::Pass { ...@@ -73,7 +73,7 @@ class SSAGraphBuilder : public ir::Pass {
// Add an output variable (each_var_name, place, place_offset) to op_handle, // Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph // which belongs to graph
static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle, static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
ir::Node *node, const platform::Place &place, ir::Node *new_node, const platform::Place &place,
size_t place_offset); size_t place_offset);
static void AddOutputToLeafOps(Graph *graph); static void AddOutputToLeafOps(Graph *graph);
......
...@@ -173,7 +173,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -173,7 +173,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
auto &var_name = fetch_tensors[i]; auto &var_name = fetch_tensors[i];
auto &vars = fetched_vars.at(var_name); auto &vars = fetched_vars.at(var_name);
temp_nodes->emplace_back(new ir::Node("fetch")); temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i, auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i,
&local_scopes_); &local_scopes_);
fetch_ops->emplace_back(op); fetch_ops->emplace_back(op);
...@@ -186,7 +186,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -186,7 +186,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
op->AddInput(var); op->AddInput(var);
} }
temp_nodes->emplace_back(new ir::Node("fetch")); temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get()); auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get());
op->AddOutput(fetch_dummy); op->AddOutput(fetch_dummy);
fetch_dependencies->emplace(fetch_dummy); fetch_dependencies->emplace(fetch_dummy);
......
...@@ -41,7 +41,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { ...@@ -41,7 +41,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
// TODO(paddle-dev): Seems some assumption doesn't hold? // TODO(paddle-dev): Seems some assumption doesn't hold?
LOG(ERROR) << op->Type() LOG(ERROR) << op->Type()
<< " input var not in all_var list: " << each_var_name; << " input var not in all_var list: " << each_var_name;
var = CreateEmptyNode(each_var_name); var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable);
var_nodes[each_var_name] = var; var_nodes[each_var_name] = var;
} }
node->inputs.push_back(var); node->inputs.push_back(var);
......
...@@ -67,8 +67,8 @@ class Graph { ...@@ -67,8 +67,8 @@ class Graph {
// TODO(paddle-dev): There shouldn't be kNone nodes in the ir::Graph. // TODO(paddle-dev): There shouldn't be kNone nodes in the ir::Graph.
// node should either be a executable kOperation or a kVariable. kNone // node should either be a executable kOperation or a kVariable. kNone
// node is a temporary solution. // node is a temporary solution.
ir::Node* CreateEmptyNode(const std::string& name) { ir::Node* CreateEmptyNode(const std::string& name, ir::Node::Type type) {
nodes.emplace_back(new ir::Node(name)); nodes.emplace_back(new ir::Node(name, type));
return nodes.back().get(); return nodes.back().get();
} }
......
...@@ -26,12 +26,9 @@ namespace ir { ...@@ -26,12 +26,9 @@ namespace ir {
class Node { class Node {
public: public:
enum class Type { kNone, kOperation, kVariable }; enum class Type { kOperation, kVariable };
explicit Node(const std::string& name) explicit Node(const std::string& name, Type type)
: name_(name), : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
var_desc_(nullptr),
op_desc_(nullptr),
type_(Type::kNone) {}
explicit Node(VarDesc* var_desc) explicit Node(VarDesc* var_desc)
: name_(var_desc->Name()), : name_(var_desc->Name()),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册