Commit 10786a24 authored by Xin Pan

polish graph

Parent 2fa8df1c
@@ -96,7 +96,7 @@ struct TestBroadcastOpHandle {
 }
 param_scopes_[input_scope_idx]->Var("input");
-std::unique_ptr<ir::Node> n(new ir::Node());
+std::unique_ptr<ir::Node> n(new ir::Node("node0"));
 if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
 op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
@@ -114,7 +114,7 @@ struct TestBroadcastOpHandle {
 #endif
 }
-std::unique_ptr<ir::Node> v(new ir::Node());
+std::unique_ptr<ir::Node> v(new ir::Node("node1"));
 auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
 gpu_list_[input_scope_idx]);
 vars_.emplace_back(in_var_handle);
@@ -122,7 +122,7 @@ struct TestBroadcastOpHandle {
 // add dummy var
-std::unique_ptr<ir::Node> v2(new ir::Node());
+std::unique_ptr<ir::Node> v2(new ir::Node("node2"));
 vars_.emplace_back(new DummyVarHandle(v2.get()));
 DummyVarHandle* dummy_var_handle =
 static_cast<DummyVarHandle*>(vars_.back().get());
@@ -133,7 +133,7 @@ struct TestBroadcastOpHandle {
 if (!use_gpu_) {
 op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
 }
-std::unique_ptr<ir::Node> v3(new ir::Node());
+std::unique_ptr<ir::Node> v3(new ir::Node("node3"));
 VarHandle* out_var_handle =
 new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
 vars_.emplace_back(out_var_handle);
@@ -141,7 +141,7 @@ struct TestBroadcastOpHandle {
 }
 // add dummy var
-std::unique_ptr<ir::Node> v4(new ir::Node());
+std::unique_ptr<ir::Node> v4(new ir::Node("node4"));
 vars_.emplace_back(new DummyVarHandle(v4.get()));
 DummyVarHandle* out_dummy_var_handle =
 static_cast<DummyVarHandle*>(vars_.back().get());
......
@@ -19,10 +19,10 @@
 namespace paddle {
 namespace framework {
 namespace details {
-ComputationOpHandle::ComputationOpHandle(ir::Node *node, const OpDesc &op_desc,
-Scope *scope, platform::Place place)
+ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
+platform::Place place)
 : OpHandleBase(node),
-op_(framework::OpRegistry::CreateOp(op_desc)),
+op_(framework::OpRegistry::CreateOp(*node->Op())),
 scope_(scope),
 place_(place) {}
......
@@ -28,8 +28,7 @@ namespace framework {
 namespace details {
 struct ComputationOpHandle : public OpHandleBase {
 public:
-ComputationOpHandle(ir::Node *node, const OpDesc &op_desc, Scope *scope,
-platform::Place place);
+ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place);
 std::string Name() const override;
......
@@ -82,13 +82,13 @@ struct TestGatherOpHandle {
 }
 param_scopes_[input_scope_idx]->Var("out");
-nodes.emplace_back(new ir::Node());
+nodes.emplace_back(new ir::Node("node"));
 op_handle_.reset(
 new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
 // add input
 for (size_t j = 0; j < gpu_list_.size(); ++j) {
 op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
-nodes.emplace_back(new ir::Node());
+nodes.emplace_back(new ir::Node("node1"));
 auto* in_var_handle =
 new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
 vars_.emplace_back(in_var_handle);
@@ -96,7 +96,7 @@ struct TestGatherOpHandle {
 }
 // add dummy var
-nodes.emplace_back(new ir::Node());
+nodes.emplace_back(new ir::Node("node2"));
 vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
 DummyVarHandle* in_dummy_var_handle =
 static_cast<DummyVarHandle*>(vars_.back().get());
@@ -104,14 +104,14 @@ struct TestGatherOpHandle {
 op_handle_->AddInput(in_dummy_var_handle);
 // add output
-nodes.emplace_back(new ir::Node());
+nodes.emplace_back(new ir::Node("node3"));
 auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
 "out", gpu_list_[input_scope_idx]);
 vars_.emplace_back(out_var_handle);
 op_handle_->AddOutput(out_var_handle);
 // add dummy var
-nodes.emplace_back(new ir::Node());
+nodes.emplace_back(new ir::Node("node4"));
 vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
 DummyVarHandle* dummy_var_handle =
 static_cast<DummyVarHandle*>(vars_.back().get());
......
@@ -90,7 +90,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
 // since parameters are all in block 0,
 // it's enough to only scan send ops in block 0
 for (auto &node : nodes) {
-if (!node->Op()) continue;
+if (node->NodeType() != ir::Node::Type::kOperation) continue;
 OpDesc *op = node->Op();
 // TODO(Yancey1989): use a graceful method to find send op,
 // instead of the the hard code string
@@ -108,7 +108,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
 const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
 std::vector<std::string> recv_vars;
 for (auto &node : nodes) {
-if (!node->Op()) continue;
+if (node->NodeType() != ir::Node::Type::kOperation) continue;
 OpDesc *op = node->Op();
 // TODO(Yancey1989): use a graceful method to find recv op,
 // instead of the hard code string
@@ -149,10 +149,10 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
 std::vector<std::string> input_var_names;
 std::vector<std::string> output_var_names;
 for (ir::Node *input : node->inputs) {
-input_var_names.push_back(input->Var()->Name());
+input_var_names.push_back(input->Name());
 }
 for (ir::Node *output : node->outputs) {
-output_var_names.push_back(output->Var()->Name());
+output_var_names.push_back(output->Name());
 }
 return checker(output_var_names, send_vars) ||
@@ -181,13 +181,13 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
 std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
 std::unique_ptr<Graph> graph) const {
+// Rebuild the graph structure.
 auto nodes = std::move(graph->nodes);
 graph->nodes.clear();
-LOG(ERROR) << "origin nodes count " << nodes.size();
 for (auto &node : nodes) {
-if (node->Var()) {
-all_vars_.emplace(node->Var()->Name(), node->Var());
+if (node->NodeType() == ir::Node::Type::kVariable) {
+all_vars_.emplace(node->Name(), node->Var());
 }
 }
@@ -212,7 +212,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
 // TODO(panyx0718): FIXME: nodes should be sorted by "program" order.
 for (auto &node : nodes) {
-if (!node->Op()) continue;
+if (node->NodeType() != ir::Node::Type::kOperation) continue;
 if (boost::get<int>(
 node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
 static_cast<int>(OpRole::kRPC)) {
@@ -235,7 +235,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
 if (op_dev_id != -1) {  // This op only runs on one specific device.
 CreateComputationalOp(&result, node.get(), op_dev_id);
 for (ir::Node *n : node->outputs) {
-var_name_on_devices_.emplace(n->Var()->Name(), op_dev_id);
+var_name_on_devices_.emplace(n->Name(), op_dev_id);
 }
 } else {
 // This op runs on all devices, and its output may have parameter's
@@ -351,10 +351,10 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
 const std::string &p_name,
 size_t src_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
-auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr),
+auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"),
 local_scopes_, places_, nccl_ctxs_);
 #else
-auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr),
+auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"),
 local_scopes_, places_);
 #endif
 result->Get<GraphOps>("ops").emplace_back(op_handle);
@@ -367,8 +367,8 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
 auto &p = places_[i];
 SetCommunicationContext(op_handle, p);
 auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
-auto *out_var =
-new VarHandle(result->CreateVarNode(p_name), vars.size(), i, p_name, p);
+auto *out_var = new VarHandle(result->CreateEmptyNode(p_name), vars.size(),
+i, p_name, p);
 vars.emplace_back(out_var);
 op_handle->AddOutput(out_var);
 }
@@ -378,7 +378,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
 ir::Node *node,
 int dev_id) const {
 result->Get<GraphOps>("ops").emplace_back(
-new ComputationOpHandle(result->CreateOpNode(node->Op()), *node->Op(),
+new ComputationOpHandle(result->CreateOpNode(node->Op()),
 local_scopes_[dev_id], places_[dev_id]));
 CreateOpHandleIOs(result, node, dev_id);
 }
@@ -386,11 +386,12 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
 void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
 const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
-result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
-result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_));
+result->Get<GraphOps>("ops").emplace_back(
+new AllReduceOpHandle(result->CreateEmptyNode("allreduce"), local_scopes_,
+places_, nccl_ctxs_));
 #else
 result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
-result->CreateOpNode(nullptr), local_scopes_, places_));
+result->CreateEmptyNode("allreduce"), local_scopes_, places_));
 #endif
 auto *op_handle = result->Get<GraphOps>("ops").back().get();
@@ -402,7 +403,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
 auto &prev_grad = vars.back();
 op_handle->AddInput(prev_grad.get());
-auto var = new VarHandle(result->CreateVarNode(og), vars.size(), i, og, p);
+auto var =
+new VarHandle(result->CreateEmptyNode(og), vars.size(), i, og, p);
 vars.emplace_back(var);
 op_handle->AddOutput(var);
 }
@@ -411,11 +413,12 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
 void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
 Graph *result, const std::vector<std::string> &datas) const {
 #ifdef PADDLE_WITH_CUDA
-result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
-result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_));
+result->Get<GraphOps>("ops").emplace_back(
+new DataBalanceOpHandle(result->CreateEmptyNode("data_balance"),
+local_scopes_, places_, nccl_ctxs_));
 #else
 result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
-result->CreateOpNode(nullptr), local_scopes_, places_));
+result->CreateEmptyNode("data_balance"), local_scopes_, places_));
 #endif
 auto *op_handle = result->Get<GraphOps>("ops").back().get();
 for (size_t i = 0; i < places_.size(); ++i) {
@@ -425,7 +428,7 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
 auto &vars = result->Get<GraphVars>("vars")[i][d_name];
 PADDLE_ENFORCE(!vars.empty());
 op_handle->AddInput(vars.back().get());
-auto var = new VarHandle(result->CreateVarNode(d_name), vars.size(), i,
+auto var = new VarHandle(result->CreateEmptyNode(d_name), vars.size(), i,
 d_name, p);
 vars.emplace_back(var);
 op_handle->AddOutput(var);
@@ -455,12 +458,12 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
 return -1;
 }
 auto param_grad = boost::get<std::vector<std::string>>(
-node->Op()->.GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
 PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
 int dev_id = GetVarDeviceID(param_grad[1]);
-PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]", op.Type(),
-param_grad[0]);
+PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]",
+node->Op()->Type(), param_grad[0]);
 return dev_id;
 }
@@ -481,8 +484,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
 platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
 #endif
 auto *op_handle = new ScaleLossGradOpHandle(
-result->CreateOpNode(nullptr), local_scopes_.size(), local_scopes_[i],
-places_[i], communication_dev_ctx);
+result->CreateEmptyNode("scale_loss_grad"), local_scopes_.size(),
+local_scopes_[i], places_[i], communication_dev_ctx);
 result->Get<GraphOps>("ops").emplace_back(op_handle);
 // FIXME: Currently ScaleLossGradOp only use device_count as scale
@@ -495,7 +498,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
 const std::string grad_var_name = GradVarName(loss_var_name_);
 auto &vars = result->Get<GraphVars>("vars")[i][grad_var_name];
 size_t version = vars.size();
-auto var = new VarHandle(result->CreateVarNode(grad_var_name), version, i,
+auto var = new VarHandle(result->CreateEmptyNode(grad_var_name), version, i,
 grad_var_name, places_[i]);
 vars.emplace_back(var);
 op_handle->AddOutput(var);
@@ -508,8 +511,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
 for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
 auto p = places_[scope_idx];
 auto s = local_scopes_[scope_idx];
-result->Get<GraphOps>("ops").emplace_back(new ComputationOpHandle(
-result->CreateOpNode(node->Op()), *node->Op(), s, p));
+result->Get<GraphOps>("ops").emplace_back(
+new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
 CreateOpHandleIOs(result, node, scope_idx);
 }
 }
@@ -519,10 +522,10 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
 int dst_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
 result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
-result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_));
+result->CreateEmptyNode("reduce"), local_scopes_, places_, nccl_ctxs_));
 #else
 result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
-result->CreateOpNode(nullptr), local_scopes_, places_));
+result->CreateEmptyNode("reduce"), local_scopes_, places_));
 #endif
 auto *op_handle = result->Get<GraphOps>("ops").back().get();
@@ -535,7 +538,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
 op_handle->AddInput(prev_grad.get());
 }
 auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
-auto var = new VarHandle(result->CreateVarNode(og), vars.size(), dst_dev_id,
+auto var = new VarHandle(result->CreateEmptyNode(og), vars.size(), dst_dev_id,
 og, places_[dst_dev_id]);
 vars.emplace_back(var);
 op_handle->AddOutput(var);
@@ -548,7 +551,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
 const std::string &prev_op_name) const {
 for (auto &prev_op : result->Get<GraphOps>("ops")) {
 if (prev_op->Name() == prev_op_name) {
-auto *dep_var = new DummyVarHandle(result->CreateVarNode("dummy"));
+auto *dep_var = new DummyVarHandle(result->CreateEmptyNode("dummy"));
 prev_op->AddOutput(dep_var);
 result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
 op->AddInput(dep_var);
@@ -562,10 +565,10 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
 std::vector<std::string> input_var_names;
 std::vector<std::string> output_var_names;
 for (ir::Node *input : node->inputs) {
-input_var_names.push_back(input->Var()->Name());
+input_var_names.push_back(input->Name());
 }
 for (ir::Node *output : node->outputs) {
-output_var_names.push_back(output->Var()->Name());
+output_var_names.push_back(output->Name());
 }
 if (node->Op()->Type() == "split_byref" ||
@@ -606,16 +609,16 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
 void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const {
 int op_dev_id = -1;
 if (node->Op()->Type() == "send") {
-op_dev_id = GetVarDeviceID(node->inputs[0]->Var()->Name());
+op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
 // the variable name which contains .block means it was splited by
 // split_byref op
 // so that we can balance the variable blocks to all the pserver
 // instances.
 if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
-node->inputs[0]->Var()->Name().find(".block") == std::string::npos) {
+node->inputs[0]->Name().find(".block") == std::string::npos) {
 std::vector<std::string> input_var_names;
 for (ir::Node *n : node->inputs) {
-input_var_names.push_back(n->Var()->Name());
+input_var_names.push_back(n->Name());
 }
 op_dev_id = GetAppropriateDeviceID(input_var_names);
 for (auto &varname : input_var_names) {
@@ -625,7 +628,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const {
 } else if (node->Op()->Type() == "recv") {
 std::vector<std::string> output_var_names;
 for (ir::Node *n : node->outputs) {
-output_var_names.push_back(n->Var()->Name());
+output_var_names.push_back(n->Name());
 }
 op_dev_id = GetAppropriateDeviceID(output_var_names);
 for (auto &varname : output_var_names) {
......
@@ -97,7 +97,7 @@ struct TestReduceOpHandle {
 }
 param_scopes_[out_scope_idx]->Var("out");
-nodes.emplace_back(new ir::Node());
+nodes.emplace_back(new ir::Node("node"));
 if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
 op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_,
@@ -121,7 +121,7 @@ struct TestReduceOpHandle {
 if (!use_gpu_) {
 op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
 }
-nodes.emplace_back(new ir::Node());
+nodes.emplace_back(new ir::Node("node1"));
 auto *in_var_handle =
 new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
 in_var_handle->ClearGeneratedOp();
@@ -137,7 +137,7 @@ struct TestReduceOpHandle {
 op_handle_->AddInput(in_dummy_var_handle);
 // add output
-nodes.emplace_back(new ir::Node());
+nodes.emplace_back(new ir::Node("node2"));
 auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx,
 "out", gpu_list_[out_scope_idx]);
 vars_.emplace_back(out_var_handle);
......
@@ -37,7 +37,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
 continue;
 }
-auto *dep_var = new DummyVarHandle(graph->CreateVarNode("dummy"));
+auto *dep_var = new DummyVarHandle(graph->CreateEmptyNode("dummy"));
 read_op->AddOutput(dep_var);
 write_op->AddInput(dep_var);
 graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
@@ -51,11 +51,16 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
 Graph *graph, ir::Node *node, const platform::Place &place,
 size_t place_offset) {
 auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
-auto &var_holder = var_holders[node->Var()->Name()];
+auto &var_holder = var_holders[node->Name()];
 VarHandle *var = nullptr;
 if (var_holder.empty()) {
-var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset,
-node->Var()->Name(), place);
+if (node->NodeType() == ir::Node::Type::kVariable) {
+var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset,
+node->Name(), place);
+} else {
+var = new VarHandle(graph->CreateEmptyNode(node->Name()), 0, place_offset,
+node->Name(), place);
+}
 var_holder.emplace_back(var);
 } else {
 var = var_holder.rbegin()->get();
@@ -67,10 +72,10 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
 ir::Node *node,
 const platform::Place &place,
 size_t place_offset) {
-auto &vars = graph->Get<GraphVars>("vars")[place_offset][node->Var()->Name()];
+auto &vars = graph->Get<GraphVars>("vars")[place_offset][node->Name()];
 size_t version = vars.size();
 auto var = new VarHandle(graph->CreateVarNode(node->Var()), version,
-place_offset, node->Var()->Name(), place);
+place_offset, node->Name(), place);
 vars.emplace_back(var);
 op_handle->AddOutput(var);
 }
@@ -82,7 +87,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
 if (!op->Outputs().empty()) {
 continue;
 }
-auto *dummy_leaf = new DummyVarHandle(graph->CreateVarNode("dummy"));
+auto *dummy_leaf = new DummyVarHandle(graph->CreateEmptyNode("dummy"));
 graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
 op->AddOutput(dummy_leaf);
 }
......
@@ -173,7 +173,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
 auto &var_name = fetch_tensors[i];
 auto &vars = fetched_vars.at(var_name);
-ir::Node *fetch_n = new ir::Node(ir::Node::Type::kOperation);
+ir::Node *fetch_n = new ir::Node("fetch");
 auto *op = new FetchOpHandle(fetch_n, fetch_data, i, &local_scopes_);
 temp_nodes->emplace_back(fetch_n);
 fetch_ops->emplace_back(op);
@@ -186,7 +186,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
 op->AddInput(var);
 }
-ir::Node *dummy_n = new ir::Node(ir::Node::Type::kVariable);
+ir::Node *dummy_n = new ir::Node("fetch");
 auto *fetch_dummy = new DummyVarHandle(dummy_n);
 op->AddOutput(fetch_dummy);
 fetch_dependencies->emplace(fetch_dummy);
......
@@ -35,19 +35,15 @@ std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc &program) {
 if (all_vars.count(each_var_name) != 0) {
 var = graph->CreateVarNode(all_vars.at(each_var_name));
 } else {
-var = graph->CreateVarNode(each_var_name);
+LOG(ERROR) << "input var not in all_var list: " << each_var_name;
+var = graph->CreateEmptyNode(each_var_name);
 }
 node->inputs.push_back(var);
 var->outputs.push_back(node);
 }
 for (auto &each_var_name : op->OutputArgumentNames()) {
-ir::Node *var = nullptr;
-if (all_vars.count(each_var_name) != 0) {
-var = graph->CreateVarNode(all_vars.at(each_var_name));
-} else {
-var = graph->CreateVarNode(each_var_name);
-}
+ir::Node *var = graph->CreateVarNode(all_vars.at(each_var_name));
 node->outputs.push_back(var);
 var->inputs.push_back(node);
 }
......
@@ -72,16 +72,14 @@ class Graph {
 }
 // TODO(panyx0718): Need to handle CreateOpNode(nullptr).
-ir::Node* CreateVarNode(const std::string& var_name) {
-var_descs_.emplace_back(new VarDesc(var_name));
-nodes.emplace_back(new ir::Node(var_descs_.back().get()));
+ir::Node* CreateEmptyNode(const std::string& name) {
+nodes.emplace_back(new ir::Node(name));
 return nodes.back().get();
 }
 std::vector<ir::Node*> inputs;
 std::vector<ir::Node*> outputs;
 std::vector<std::unique_ptr<ir::Node>> nodes;
-std::vector<std::unique_ptr<VarDesc>> var_descs_;
 private:
 // NOTE: program_ shouldn't be exposed to user.
......
@@ -32,51 +32,43 @@ namespace ir {
 class Node {
 public:
-enum class Type { kNone = -1, kOperation, kVariable };
-Node() : type_(Type::kNone) {}
-explicit Node(Type type) : type_(type) {}
-virtual ~Node() {
-for (auto& attr : attrs_) {
-if (attr_dels_.find(attr.first) != attr_dels_.end()) {
-attr_dels_[attr.first]();
-}
-}
-attr_dels_.clear();
-attrs_.clear();
-}
+enum class Type { kNone, kOperation, kVariable };
+explicit Node(const std::string& name)
+: name_(name),
+var_desc_(nullptr),
+op_desc_(nullptr),
+type_(Type::kNone) {}
+explicit Node(VarDesc* var_desc)
+: name_(var_desc->Name()),
+var_desc_(var_desc),
+op_desc_(nullptr),
+type_(Type::kVariable) {}
+explicit Node(OpDesc* op_desc)
+: name_(op_desc->Type()),
+var_desc_(nullptr),
+op_desc_(op_desc),
+type_(Type::kOperation) {}
 Type NodeType() const { return type_; }
-template <typename AttrType>
-void Set(const std::string& name, AttrType attr) {
-attrs_[name] = attr;
-}
-template <typename AttrType>
-void Set(const std::string& name, AttrType* attr,
-std::function<void(void)> attr_del) {
-attrs_[name] = attr;
-attr_dels_[name] = attr_del;
+std::string Name() const { return name_; }
+VarDesc* Var() {
+PADDLE_ENFORCE(type_ == Type::kVariable);
+return var_desc_;
+}
+OpDesc* Op() {
+PADDLE_ENFORCE(type_ == Type::kOperation);
+return op_desc_;
 }
-VarDesc* Var() { return var_desc_; }
-OpDesc* Op() { return op_desc_; }
-explicit Node(VarDesc* var_desc)
-: var_desc_(var_desc), op_desc_(nullptr), type_(Type::kVariable) {}
-explicit Node(OpDesc* op_desc)
-: var_desc_(nullptr), op_desc_(op_desc), type_(Type::kOperation) {}
 std::vector<Node*> inputs;
 std::vector<Node*> outputs;
 protected:
-std::map<std::string, boost::any> attrs_;
-std::map<std::string, std::function<void(void)>> attr_dels_;
+const std::string name_;
 VarDesc* var_desc_;
 OpDesc* op_desc_;
 Type type_;
......
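Note (not part of the commit): a minimal, hypothetical usage sketch of the refactored ir::Node API shown in the hunk above. The three constructors tag a node as kNone, kVariable, or kOperation, and Var()/Op() now enforce the node type before returning the wrapped descriptor. The function and variable names below are illustrative only; the Paddle headers providing VarDesc, OpDesc, and ir::Node are assumed to be available.

// Illustrative sketch only, assuming the relevant Paddle framework headers
// (VarDesc, OpDesc, ir::Node) are included.
#include <memory>
#include <string>

void NodeApiSketch(paddle::framework::VarDesc* var_desc,
                   paddle::framework::OpDesc* op_desc) {
  using paddle::framework::ir::Node;

  // Name-only node: Type::kNone, e.g. for dummy/dependency handles.
  std::unique_ptr<Node> dummy(new Node("dummy"));

  // Variable node: wraps a VarDesc; Name() comes from the descriptor.
  std::unique_ptr<Node> var_node(new Node(var_desc));

  // Operation node: wraps an OpDesc; Name() is the op type.
  std::unique_ptr<Node> op_node(new Node(op_desc));

  // Var()/Op() PADDLE_ENFORCE the node type before returning the descriptor,
  // so callers check NodeType() (or otherwise know the type) first.
  if (op_node->NodeType() == Node::Type::kOperation) {
    std::string op_type = op_node->Op()->Type();
    (void)op_type;
  }
  if (var_node->NodeType() == Node::Type::kVariable) {
    std::string var_name = var_node->Var()->Name();
    (void)var_name;
  }
}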
@@ -148,7 +148,6 @@ class ParallelExecutor(object):
 lambda var: var.persistable and var.type != core.VarDesc.VarType.RAW,
 main.list_vars())
 ]
-sys.stderr.write('!!!!!!!!before\n')
 self.executor = core.ParallelExecutor(
 self._places,
@@ -159,7 +158,6 @@ class ParallelExecutor(object):
 set(self.persistable_vars), main.desc, loss_name
 if loss_name else '', scope, local_scopes, exec_strategy,
 build_strategy, num_trainers, trainer_id)
-sys.stderr.write('!!!!!!!!after\n')
 self.scope = scope
 def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
......