未验证 提交 66bb7343 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #12101 from panyx0718/ir

Initial IR change
## Motivation
There is a ```gap``` between the ```Program``` defined by
user and the ```Executable``` that can be scheduled
efficiently on heterogeneous hardware, either locally
or distributedly.
Usually, the ```gap``` is bridged by
* A serious transformations with defined order.
* These transformations usually involve
```insert, delete, clustering, split, dependency analysis```.
* Has a simple way to verify and debug each transformation.
* Flexible to add, remove or customize transformations to fit
the requirements of various algorithms (models) and hardware secenarios.
Some other events also push us to a better unified pattern.
* The deep learning framework is built around the concepts of graphs.
To leverage tools such as compilation (e.g. TVM and nGraph) or
cross-framework conversion (e.g. ONNX), we also need a intermediate
representation that can be connected to the rest of the ecosystem.
We need a unified pattern to naturally support the requirements
described above. The pattern should fit both training, inference
and other offline serielized model transformations.
Learned from LLVM and other deep learning framework, we draft the
design below.
## Design
### Major Concepts
#### Node
```Node``` represents an operation that performs some computation or
a variable that is input or output of operation.
```Node```s are connected to other ```Node```s via inputs and outputs.
Other properties (maybe device placement information) can be added
to ```Node``` in the future if it's a
common requirement of many other ```Pass```es. Otherwise, it should live
in a ```Node``` wrapper class that is private to some ```Pass``` or be
a local member of a ```Pass```.
#### Graph
```Graph``` contains a list of ```Node```s, which are connected to
each other via inputs and outputs.
TODO: Better definitions for the graph.
```Graph``` can also contain ```Attribute```s. ```Attribute```s
can be ``any`` thing. For example, it can be a list of "wraper"
nodes. The ```wrapper``` nodes compose ```Node```s and provide
helper method for execution or transformation. ```Attribute```
can also contain other things that describe some properties of
the ```Graph``` or ```Graph``` nodes. ```Attribute``` can be passed
across ```Pass```. However, it should be used with care.
#### Pass
```Pass``` represents a transformation of ```Graph```. Its input
is a ```Graph``` and its output is also a ```Graph```. For example,
a ```Pass``` can simply print out the ```Graph```. A ```Pass```
can also fuse some ```Graph```'s ```Node```s.
#### Optimize
```Optimize``` contains a series of ```Pass``` with defined order.
```Optimize``` transforms a ```Graph``` that only contains raw
modeling logic to a ```Graph``` that can be run efficiently while
maintaining the original modeling logic.
### Optimize Process
* Program is first converted to Graph.
* Graph goes through a series of Pass
* Graph is transformed from raw model logic to a
form that is efficient to execute.
Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
add_subdirectory(details)
add_subdirectory(ir)
# ddim lib
proto_library(framework_proto SRCS framework.proto)
......@@ -93,7 +94,7 @@ else()
endif()
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
......@@ -5,8 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph)
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
......@@ -35,7 +34,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS
cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)
......
......@@ -23,10 +23,14 @@ namespace framework {
namespace details {
#ifdef PADDLE_WITH_CUDA
AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(ctxs) {
if (nccl_ctxs_) {
for (auto &p : places_) {
this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
......@@ -34,9 +38,10 @@ AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
}
}
#else
AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
void AllReduceOpHandle::RunImpl() {
......
......@@ -30,11 +30,11 @@ namespace details {
struct AllReduceOpHandle : public OpHandleBase {
#ifdef PADDLE_WITH_CUDA
AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs);
#else
AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
#endif
std::string Name() const override;
......
......@@ -35,10 +35,13 @@ namespace details {
struct BroadcastOpHandle : public OpHandleBase {
public:
#ifdef PADDLE_WITH_CUDA
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
BroadcastOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *nccl_ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) {
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(nccl_ctxs) {
if (nccl_ctxs_) {
for (auto &p_ctx : nccl_ctxs_->contexts_) {
dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
......@@ -46,9 +49,9 @@ struct BroadcastOpHandle : public OpHandleBase {
}
}
#else
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
BroadcastOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
std::string Name() const override;
......
......@@ -96,48 +96,61 @@ struct TestBroadcastOpHandle {
}
param_scopes_[input_scope_idx]->Var("input");
std::unique_ptr<ir::Node> n(
new ir::Node("node0", ir::Node::Type::kOperation));
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(
new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
nccl_ctxs_.get()));
#else
PADDLE_THROW("CUDA is not support.");
#endif
} else {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(
new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
nccl_ctxs_.get()));
#else
op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
op_handle_.reset(
new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_));
#endif
}
auto* in_var_handle =
new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
std::unique_ptr<ir::Node> v(
new ir::Node("node1", ir::Node::Type::kVariable));
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
gpu_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
// add dummy var
vars_.emplace_back(new DummyVarHandle());
std::unique_ptr<ir::Node> v2(
new ir::Node("node2", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(v2.get()));
DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
dummy_var_handle->generated_op_ = nullptr;
dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(dummy_var_handle);
for (size_t j = 0; j < gpu_list_.size(); ++j) {
if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
}
VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
std::unique_ptr<ir::Node> v3(
new ir::Node("node3", ir::Node::Type::kVariable));
VarHandle* out_var_handle =
new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
}
// add dummy var
vars_.emplace_back(new DummyVarHandle());
std::unique_ptr<ir::Node> v4(
new ir::Node("node4", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(v4.get()));
DummyVarHandle* out_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
out_dummy_var_handle->generated_op_ = nullptr;
out_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddOutput(out_dummy_var_handle);
}
......
......@@ -19,9 +19,10 @@
namespace paddle {
namespace framework {
namespace details {
ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
platform::Place place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
: OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(*node->Op())),
scope_(scope),
place_(place) {}
......@@ -35,8 +36,8 @@ void ComputationOpHandle::RunImpl() {
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
bool need_wait =
in_var && in_var->generated_op_ &&
in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_];
in_var && in_var->GeneratedOp() &&
in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_[place_];
return need_wait;
}
......
......@@ -28,8 +28,7 @@ namespace framework {
namespace details {
struct ComputationOpHandle : public OpHandleBase {
public:
ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
platform::Place place);
ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place);
std::string Name() const override;
......
......@@ -22,10 +22,10 @@ namespace details {
#ifdef PADDLE_WITH_CUDA
DataBalanceOpHandle::DataBalanceOpHandle(
const std::vector<Scope *> &local_scopes,
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs)
: local_scopes_(local_scopes), places_(places) {
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {
if (ctxs) {
for (auto &p : places_) {
this->dev_ctxes_[p] = ctxs->DevCtx(p);
......@@ -34,9 +34,9 @@ DataBalanceOpHandle::DataBalanceOpHandle(
}
#else
DataBalanceOpHandle::DataBalanceOpHandle(
const std::vector<Scope *> &local_scopes,
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
std::string DataBalanceOpHandle::Name() const { return "data balance"; }
......
......@@ -30,11 +30,11 @@ namespace details {
struct DataBalanceOpHandle : public OpHandleBase {
public:
#ifdef PADDLE_WITH_CUDA
DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs);
#else
DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
#endif
......
......@@ -21,13 +21,16 @@ namespace paddle {
namespace framework {
namespace details {
FetchOpHandle::FetchOpHandle(FeedFetchList *data, size_t offset,
FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes)
: data_(data), offset_(offset), local_scopes_(local_scopes) {}
: OpHandleBase(node),
data_(data),
offset_(offset),
local_scopes_(local_scopes) {}
FetchOpHandle::~FetchOpHandle() {
for (auto *input_var : inputs_) {
input_var->pending_ops_.erase(this);
input_var->RemoveOutput(this, this->Node());
}
}
......@@ -77,8 +80,8 @@ void FetchOpHandle::RunImpl() {
void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) {
auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place);
for (auto *input : inputs_) {
if (input->generated_op_) {
input->generated_op_->RecordWaitEventOnCtx(cpu_ctx);
if (input->GeneratedOp()) {
input->GeneratedOp()->RecordWaitEventOnCtx(cpu_ctx);
}
}
}
......
......@@ -28,7 +28,7 @@ namespace details {
struct FetchOpHandle : public OpHandleBase {
public:
FetchOpHandle(FeedFetchList *data, size_t offset,
FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes);
~FetchOpHandle();
......
......@@ -30,10 +30,12 @@ namespace details {
struct FuseVarsOpHandle : public OpHandleBase {
public:
FuseVarsOpHandle(Scope *local_scope, const platform::Place &place,
FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
const platform::Place &place,
const std::unordered_map<std::string, int64_t> &inputs_numel,
const std::type_index &var_type)
: local_scope_(local_scope),
: OpHandleBase(node),
local_scope_(local_scope),
place_(place),
inputs_numel_(inputs_numel),
type_(var_type) {
......
......@@ -20,9 +20,10 @@ namespace paddle {
namespace framework {
namespace details {
GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
GatherOpHandle::GatherOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
void GatherOpHandle::RunImpl() {
if (places_.size() == 1) return;
......
......@@ -30,7 +30,7 @@ namespace details {
struct GatherOpHandle : public OpHandleBase {
public:
GatherOpHandle(const std::vector<Scope *> &local_scopes,
GatherOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
std::string Name() const override;
......
......@@ -70,6 +70,7 @@ struct TestGatherOpHandle {
}
void InitGatherOp(size_t input_scope_idx) {
std::vector<std::unique_ptr<ir::Node>> nodes;
for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope();
......@@ -81,30 +82,37 @@ struct TestGatherOpHandle {
}
param_scopes_[input_scope_idx]->Var("out");
op_handle_.reset(new GatherOpHandle(local_scopes_, gpu_list_));
nodes.emplace_back(new ir::Node("node", ir::Node::Type::kOperation));
op_handle_.reset(
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
// add input
for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
auto* in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
nodes.emplace_back(new ir::Node("node1", ir::Node::Type::kVariable));
auto* in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
}
// add dummy var
vars_.emplace_back(new DummyVarHandle());
nodes.emplace_back(new ir::Node("node2", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* in_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
in_dummy_var_handle->generated_op_ = nullptr;
in_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(in_dummy_var_handle);
// add output
auto* out_var_handle =
new VarHandle(2, input_scope_idx, "out", gpu_list_[input_scope_idx]);
nodes.emplace_back(new ir::Node("node3", ir::Node::Type::kVariable));
auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
"out", gpu_list_[input_scope_idx]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
// add dummy var
vars_.emplace_back(new DummyVarHandle());
nodes.emplace_back(new ir::Node("node4", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
op_handle_->AddOutput(dummy_var_handle);
......
......@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"
......@@ -66,31 +67,38 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
}
}
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
const OpDesc &op,
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
size_t place_id) const {
auto p = places_[place_id];
auto *op_handle = result->ops_.back().get();
auto *op_handle = result->Get<GraphOps>("ops").back().get();
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
for (auto &each_var_name : op.InputArgumentNames()) {
VarHandle *var =
CreateOrGetLatestVarHandle(result, each_var_name, p, place_id);
for (ir::Node *input : node->inputs) {
VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id);
op_handle->AddInput(var);
}
for (auto &each_var_name : op.OutputArgumentNames()) {
CreateOpOutput(result, op_handle, each_var_name, p, place_id);
for (ir::Node *output : node->outputs) {
ir::Node *new_node = nullptr;
if (output->Var()) {
new_node = result->CreateVarNode(output->Var());
} else {
new_node =
result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
}
CreateOpOutput(result, op_handle, new_node, p, place_id);
}
}
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
const ProgramDesc &program) const {
const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
std::vector<std::string> send_vars;
// since parameters are all in block 0,
// it's enough to only scan send ops in block 0
for (auto *op : program.Block(0).AllOps()) {
for (auto &node : nodes) {
if (node->NodeType() != ir::Node::Type::kOperation) continue;
OpDesc *op = node->Op();
// TODO(Yancey1989): use a graceful method to find send op,
// instead of the the hard code string
if (op->Type() == "send") {
......@@ -104,9 +112,11 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
}
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
const ProgramDesc &program) const {
const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
std::vector<std::string> recv_vars;
for (auto *op : program.Block(0).AllOps()) {
for (auto &node : nodes) {
if (node->NodeType() != ir::Node::Type::kOperation) continue;
OpDesc *op = node->Op();
// TODO(Yancey1989): use a graceful method to find recv op,
// instead of the hard code string
if (op->Type() == "recv") {
......@@ -120,7 +130,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
}
bool MultiDevSSAGraphBuilder::IsDistTrainOp(
const OpDesc &op, const std::vector<std::string> &send_vars,
ir::Node *node, const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const {
if (send_vars.size() == 0 || recv_vars.size() == 0) {
return false;
......@@ -143,8 +153,17 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
return false;
};
return checker(op.OutputArgumentNames(), send_vars) ||
checker(op.InputArgumentNames(), recv_vars);
std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names;
for (ir::Node *input : node->inputs) {
input_var_names.push_back(input->Name());
}
for (ir::Node *output : node->outputs) {
output_var_names.push_back(output->Name());
}
return checker(output_var_names, send_vars) ||
checker(input_var_names, recv_vars);
}
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
......@@ -167,25 +186,30 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
return dev_id;
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const {
for (auto *var : program.Block(0).AllVars()) {
all_vars_.emplace(var->Name(), var);
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
std::unique_ptr<Graph> graph) const {
// Rebuild the graph structure.
auto nodes = std::move(graph->nodes);
graph->nodes.clear();
for (auto &node : nodes) {
if (node->NodeType() == ir::Node::Type::kVariable) {
all_vars_.emplace(node->Name(), node->Var());
}
}
auto graph = new SSAGraph();
SSAGraph &result = *graph;
Graph &result = *graph;
std::unordered_set<std::string> og_has_been_broadcast;
// We cannot invoke resize. It is a bug of GCC 4.8
result.vars_ = std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size());
result.Set("vars", new GraphVars(places_.size()));
result.Set("dep_vars", new GraphDepVars);
result.Set("ops", new GraphOps);
// find send/recv vars so that we can place the distributed training
// realted op in the place 0
auto send_vars = FindDistTrainSendVars(program);
auto recv_vars = FindDistTrainRecvVars(program);
auto send_vars = FindDistTrainSendVars(nodes);
auto recv_vars = FindDistTrainRecvVars(nodes);
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
bcast_var_name_set.resize(places_.size());
......@@ -193,14 +217,19 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
size_t cur_device_id = 0;
bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
// NOTE: Currently, passes before SSAGraphBuilder cannot reorder
// forward, backward nodes. E.g. you can't append an forward node
// at the end of the node list.
// TODO(panyx0718): FIXME: Needs to sort by forward->backward order.
for (auto &node : nodes) {
if (node->NodeType() != ir::Node::Type::kOperation) continue;
if (boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) {
CreateRPCOp(&result, *op);
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
CreateDistTrainOp(&result, *op);
} else if (IsScaleLossOp(*op)) {
CreateRPCOp(&result, node.get());
} else if (IsDistTrainOp(node.get(), send_vars, recv_vars)) {
CreateDistTrainOp(&result, node.get());
} else if (IsScaleLossOp(node.get())) {
// user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ !=
BuildStrategy::GradientScaleStrategy::kCustomized) {
......@@ -212,33 +241,35 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
// the block.
is_forwarding = false;
} else {
int op_dev_id = GetOpDeviceID(*op);
int op_dev_id = GetOpDeviceID(node.get());
if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) {
var_name_on_devices_.emplace(var_name, op_dev_id);
CreateComputationalOp(&result, node.get(), op_dev_id);
for (ir::Node *n : node->outputs) {
var_name_on_devices_.emplace(n->Name(), op_dev_id);
}
} else {
// This op runs on all devices, and its output may have parameter's
// gradients.
if (op->Type() == "read" && strategy_.enable_data_balance_) {
op->SetAttr("throw_eof_exp", false);
CreateComputationalOps(&result, *op, places_.size());
const auto &data_var_names = op->Output("Out");
if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
node->Op()->SetAttr("throw_eof_exp", false);
CreateComputationalOps(&result, node.get(), places_.size());
// TODO(paddle-dev): builder shouldn't depend on the out logic of
// a specific op.
const auto &data_var_names = node->Op()->Output("Out");
InsertDataBalanceOp(&result, data_var_names);
} else {
CreateComputationalOps(&result, *op, places_.size());
CreateComputationalOps(&result, node.get(), places_.size());
}
if (!is_forwarding && places_.size() > 1) {
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
if (static_cast<bool>(boost::get<int>(op->GetAttr(
if (static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward))) {
try {
auto backward_vars =
boost::get<std::vector<std::string>>(op->GetNullableAttr(
auto backward_vars = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
......@@ -302,8 +333,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
* Only variables should be the leaves of graph.
*/
AddOutputToLeafOps(&result);
return std::unique_ptr<SSAGraph>(graph);
return std::move(graph);
}
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
......@@ -327,78 +357,96 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
#endif
}
void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
const std::string &p_name,
size_t src_dev_id) const {
#ifdef PADDLE_WITH_CUDA
auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_);
auto *op_handle = new BroadcastOpHandle(
result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_);
#else
auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
auto *op_handle = new BroadcastOpHandle(
result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
local_scopes_, places_);
#endif
result->Get<GraphOps>("ops").emplace_back(op_handle);
result->ops_.emplace_back(op_handle);
auto *in = result->vars_.at(src_dev_id).at(p_name).back().get();
auto *in =
result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get();
op_handle->AddInput(in);
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
auto &vars = result->vars_.at(i).at(p_name);
auto *out_var = new VarHandle(vars.size(), i, p_name, p);
auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
auto *out_var = new VarHandle(
result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
i, p_name, p);
vars.emplace_back(out_var);
op_handle->AddOutput(out_var);
}
}
void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
const OpDesc &op,
void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
ir::Node *node,
int dev_id) const {
result->ops_.emplace_back(
new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
CreateOpHandleIOs(result, op, dev_id);
result->Get<GraphOps>("ops").emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()),
local_scopes_[dev_id], places_[dev_id]));
CreateOpHandleIOs(result, node, dev_id);
}
void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
const std::string &og) const {
#ifdef PADDLE_WITH_CUDA
result->ops_.emplace_back(
new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_));
#else
result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_));
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif
auto *op_handle = result->ops_.back().get();
auto *op_handle = result->Get<GraphOps>("ops").back().get();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
auto &vars = result->vars_[i][og];
auto &vars = result->Get<GraphVars>("vars")[i][og];
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get());
auto var = new VarHandle(vars.size(), i, og, p);
auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
vars.size(), i, og, p);
vars.emplace_back(var);
op_handle->AddOutput(var);
}
}
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
SSAGraph *result, const std::vector<std::string> &datas) const {
Graph *result, const std::vector<std::string> &datas) const {
#ifdef PADDLE_WITH_CUDA
result->ops_.emplace_back(
new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_));
#else
result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_));
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif
auto *op_handle = result->ops_.back().get();
auto *op_handle = result->Get<GraphOps>("ops").back().get();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
for (const std::string &d_name : datas) {
auto &vars = result->vars_[i][d_name];
auto &vars = result->Get<GraphVars>("vars")[i][d_name];
PADDLE_ENFORCE(!vars.empty());
op_handle->AddInput(vars.back().get());
auto var = new VarHandle(vars.size(), i, d_name, p);
auto var = new VarHandle(
result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
vars.size(), i, d_name, p);
vars.emplace_back(var);
op_handle->AddOutput(var);
}
......@@ -417,22 +465,22 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return is_pg_once;
}
int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1;
}
int op_role = boost::get<int>(
op.GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
node->Op()->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
if (op_role != static_cast<int>(framework::OpRole::kOptimize)) {
return -1;
}
auto param_grad = boost::get<std::vector<std::string>>(
op.GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(param_grad[1]);
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]", op.Type(),
param_grad[0]);
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]",
node->Op()->Type(), param_grad[0]);
return dev_id;
}
......@@ -441,7 +489,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
return got == var_name_on_devices_.end() ? -1 : got->second;
}
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA
......@@ -452,11 +500,11 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
auto *communication_dev_ctx =
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
#endif
auto *op_handle =
new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i],
places_[i], communication_dev_ctx);
result->ops_.emplace_back(op_handle);
auto *op_handle = new ScaleLossGradOpHandle(
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
local_scopes_.size(), local_scopes_[i], places_[i],
communication_dev_ctx);
result->Get<GraphOps>("ops").emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
......@@ -464,43 +512,51 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
CreateOpOutput(result, op_handle, GradVarName(loss_var_name_), places_[i],
i);
CreateOpOutput(result, op_handle,
result->CreateEmptyNode(GradVarName(loss_var_name_),
ir::Node::Type::kVariable),
places_[i], i);
}
}
void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
const OpDesc &op,
void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
ir::Node *node,
size_t num_places) const {
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx];
result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
CreateOpHandleIOs(result, op, scope_idx);
result->Get<GraphOps>("ops").emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
CreateOpHandleIOs(result, node, scope_idx);
}
}
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
const std::string &og,
int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA
result->ops_.emplace_back(
new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_));
#else
result->ops_.emplace_back(new ReduceOpHandle(local_scopes_, places_));
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif
auto *op_handle = result->ops_.back().get();
auto *op_handle = result->Get<GraphOps>("ops").back().get();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
auto &vars = result->vars_[i][og];
auto &vars = result->Get<GraphVars>("vars")[i][og];
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get());
}
auto &vars = result->vars_[dst_dev_id][og];
auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
vars.size(), dst_dev_id, og, places_[dst_dev_id]);
vars.emplace_back(var);
op_handle->AddOutput(var);
return var;
......@@ -508,35 +564,46 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
// Find the first occurence of `prev_op_name` and make current `op` depend
// on it.
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const {
for (auto &prev_op : result->ops_) {
for (auto &prev_op : result->Get<GraphOps>("ops")) {
if (prev_op->Name() == prev_op_name) {
auto *dep_var = new DummyVarHandle();
auto *dep_var = new DummyVarHandle(
result->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
prev_op->AddOutput(dep_var);
result->dep_vars_.emplace(dep_var);
result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
op->AddInput(dep_var);
}
}
}
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op) const {
void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
ir::Node *node) const {
int op_dev_id = -1;
if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names;
for (ir::Node *input : node->inputs) {
input_var_names.push_back(input->Name());
}
for (ir::Node *output : node->outputs) {
output_var_names.push_back(output->Name());
}
if (node->Op()->Type() == "split_byref" ||
node->Op()->Type() == "split_selected_rows") {
op_dev_id = GetVarDeviceID(input_var_names[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
for (auto &varname : op.InputArgumentNames()) {
op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
}
for (auto &varname : op.OutputArgumentNames()) {
for (auto &varname : output_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
} else if (op.Type() == "concat") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
for (auto &varname : op.OutputArgumentNames()) {
} else if (node->Op()->Type() == "concat") {
op_dev_id = GetVarDeviceID(input_var_names[0]);
for (auto &varname : output_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
} else {
......@@ -546,34 +613,43 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
}
PADDLE_ENFORCE(op_dev_id != -1,
"can not find right place for distributed op: %s", op.Type());
"can not find right place for distributed op: %s",
node->Op()->Type());
CreateComputationalOp(result, op, op_dev_id);
if (op.Type() == "concat") {
ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
CreateComputationalOp(result, node, op_dev_id);
if (node->Op()->Type() == "concat") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
"fetch_barrier");
}
}
// Create RPC related op handles that connects its in ops and out ops.
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const {
void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const {
int op_dev_id = -1;
if (op.Type() == "send") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
if (node->Op()->Type() == "send") {
op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
// the variable name which contains .block means it was splited by
// split_byref op
// so that we can balance the variable blocks to all the pserver
// instances.
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
op.InputArgumentNames()[0].find(".block") == std::string::npos) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
for (auto &varname : op.InputArgumentNames()) {
node->inputs[0]->Name().find(".block") == std::string::npos) {
std::vector<std::string> input_var_names;
for (ir::Node *n : node->inputs) {
input_var_names.push_back(n->Name());
}
op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
}
} else if (op.Type() == "recv") {
op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames());
for (auto &varname : op.OutputArgumentNames()) {
} else if (node->Op()->Type() == "recv") {
std::vector<std::string> output_var_names;
for (ir::Node *n : node->outputs) {
output_var_names.push_back(n->Name());
}
op_dev_id = GetAppropriateDeviceID(output_var_names);
for (auto &varname : output_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
} else {
......@@ -582,18 +658,20 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
}
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
op.Type());
result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id],
op.Type(), places_[op_dev_id]));
if (op.Type() == "send_barrier") {
ConnectOp(result, result->ops_.back().get(), "send");
} else if (op.Type() == "recv") {
ConnectOp(result, result->ops_.back().get(), "send_barrier");
} else if (op.Type() == "fetch_barrier") {
ConnectOp(result, result->ops_.back().get(), "recv");
} else if (op.Type() == "send") {
node->Op()->Type());
result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
node->Op()->Type(), places_[op_dev_id]));
if (node->Op()->Type() == "send_barrier") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
} else if (node->Op()->Type() == "recv") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
"send_barrier");
} else if (node->Op()->Type() == "fetch_barrier") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv");
} else if (node->Op()->Type() == "send") {
// do nothing
} else {
PADDLE_THROW(
......@@ -601,12 +679,12 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
"send, send_barrier. recv, fetch_barrier]");
}
CreateOpHandleIOs(result, op, op_dev_id);
CreateOpHandleIOs(result, node, op_dev_id);
}
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
return boost::get<int>(
op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
(static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss)) &&
!loss_var_name_.empty(); // If loss_var is empty. This is test mode
......
......@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace platform {
......@@ -45,13 +46,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<Scope *> &local_scopes,
const BuildStrategy &strategy);
#endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override;
int GetVarDeviceID(const std::string &varname) const override;
private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
size_t device_id) const;
void CreateOpHandleIOs(Graph *result, ir::Node *node, size_t device_id) const;
private:
std::string loss_var_name_;
......@@ -63,48 +62,46 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
platform::NCCLContextMap *nccl_ctxs_;
#endif
bool IsScaleLossOp(const OpDesc &op) const;
bool IsScaleLossOp(ir::Node *node) const;
void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
void CreateRPCOp(Graph *result, ir::Node *node) const;
void CreateDistTrainOp(Graph *result, ir::Node *node) const;
/**
* Is this operator as the end-point operator before/after send operator.
*/
bool IsDistTrainOp(const OpDesc &op,
const std::vector<std::string> &send_vars,
bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const;
std::vector<std::string> FindDistTrainSendVars(
const ProgramDesc &program) const;
const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
std::vector<std::string> FindDistTrainRecvVars(
const ProgramDesc &program) const;
const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
void ConnectOp(SSAGraph *result, OpHandleBase *op,
void ConnectOp(Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const;
void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
void CreateComputationalOps(Graph *result, ir::Node *node,
size_t num_places) const;
void CreateScaleLossGradOp(SSAGraph *result) const;
VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og,
void CreateScaleLossGradOp(Graph *result) const;
VarHandle *CreateReduceOp(Graph *result, const std::string &og,
int dst_dev_id) const;
void CreateComputationalOp(SSAGraph *result, const OpDesc &op,
int dev_id) const;
void CreateComputationalOp(Graph *result, ir::Node *node, int dev_id) const;
bool IsParameterGradientOnce(
const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const;
int GetOpDeviceID(const OpDesc &op) const;
int GetOpDeviceID(ir::Node *node) const;
void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
void InsertAllReduceOp(Graph *result, const std::string &og) const;
void InsertDataBalanceOp(SSAGraph *result,
void InsertDataBalanceOp(Graph *result,
const std::vector<std::string> &datas) const;
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
void CreateBroadcastOp(Graph *result, const std::string &p_name,
size_t src_dev_id) const;
bool IsSparseGradient(const std::string &og) const;
......
......@@ -80,19 +80,21 @@ void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
void OpHandleBase::AddInput(VarHandleBase *in) {
this->inputs_.emplace_back(in);
in->pending_ops_.insert(this);
node_->inputs.push_back(in->Node());
in->AddOutput(this, this->Node());
}
void OpHandleBase::AddOutput(VarHandleBase *out) {
outputs_.emplace_back(out);
out->generated_op_ = this;
node_->outputs.push_back(out->Node());
out->AddInput(this, this->Node());
}
void OpHandleBase::WaitInputVarGenerated() {
for (auto in_var : inputs_) {
if (NeedWait(in_var)) {
for (auto &pair : dev_ctxes_) {
in_var->generated_op_->RecordWaitEventOnCtx(pair.second);
in_var->GeneratedOp()->RecordWaitEventOnCtx(pair.second);
}
}
}
......@@ -101,7 +103,7 @@ void OpHandleBase::WaitInputVarGenerated() {
void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
for (auto *in : inputs_) {
if (NeedWait(in)) {
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]);
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[place]);
}
}
}
......@@ -117,7 +119,7 @@ size_t OpHandleBase::NoDummyInputSize() const {
}
bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
return in_var && in_var->generated_op_;
return in_var && in_var->GeneratedOp();
}
void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
......
......@@ -17,6 +17,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/macros.h"
......@@ -26,9 +27,11 @@ namespace details {
constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
// Wraps ir::Node and provide helper utilities.
// It's responsible for populating necessary fields of ir::Node.
class OpHandleBase {
public:
OpHandleBase() {}
explicit OpHandleBase(ir::Node *node) : node_(node) {}
virtual ~OpHandleBase();
......@@ -82,6 +85,8 @@ class OpHandleBase {
size_t NoDummyInputSize() const;
ir::Node *Node() { return node_; }
protected:
void RunAndRecordEvent(const std::function<void()> &callback);
......@@ -90,6 +95,7 @@ class OpHandleBase {
virtual void RunImpl() = 0;
ir::Node *node_;
std::vector<VarHandleBase *> inputs_;
std::vector<VarHandleBase *> outputs_;
std::map<platform::Place, platform::DeviceContext *> dev_ctxes_;
......
......@@ -37,10 +37,13 @@ struct ReduceOpHandle : public OpHandleBase {
#ifdef PADDLE_WITH_CUDA
const platform::NCCLContextMap *nccl_ctxs_;
ReduceOpHandle(const std::vector<Scope *> &local_scopes,
ReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *nccl_ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) {
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(nccl_ctxs) {
if (nccl_ctxs_) {
for (auto &p_ctx : nccl_ctxs_->contexts_) {
dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
......@@ -48,9 +51,9 @@ struct ReduceOpHandle : public OpHandleBase {
}
}
#else
ReduceOpHandle(const std::vector<Scope *> &local_scopes,
ReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
std::string Name() const override;
......
......@@ -84,6 +84,7 @@ struct TestReduceOpHandle {
}
void InitReduceOp(size_t out_scope_idx) {
std::vector<std::unique_ptr<ir::Node>> nodes;
// init scope
for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
......@@ -96,19 +97,21 @@ struct TestReduceOpHandle {
}
param_scopes_[out_scope_idx]->Var("out");
nodes.emplace_back(new ir::Node("node"));
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(
new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_,
gpu_list_, nccl_ctxs_.get()));
#else
PADDLE_THROW("CUDA is not support.");
#endif
} else {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(
new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_,
gpu_list_, nccl_ctxs_.get()));
#else
op_handle_.reset(new ReduceOpHandle(local_scopes_, gpu_list_));
op_handle_.reset(
new ReduceOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
#endif
}
......@@ -118,8 +121,10 @@ struct TestReduceOpHandle {
if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
}
auto *in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
in_var_handle->generated_op_ = nullptr;
nodes.emplace_back(new ir::Node("node1"));
auto *in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
in_var_handle->ClearGeneratedOp();
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
}
......@@ -128,12 +133,13 @@ struct TestReduceOpHandle {
vars_.emplace_back(new DummyVarHandle());
DummyVarHandle *in_dummy_var_handle =
static_cast<DummyVarHandle *>(vars_.back().get());
in_dummy_var_handle->generated_op_ = nullptr;
in_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(in_dummy_var_handle);
// add output
auto *out_var_handle =
new VarHandle(2, out_scope_idx, "out", gpu_list_[out_scope_idx]);
nodes.emplace_back(new ir::Node("node2"));
auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx,
"out", gpu_list_[out_scope_idx]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
......
......@@ -18,10 +18,11 @@ namespace paddle {
namespace framework {
namespace details {
RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
const Scope *local_scope, const std::string &name,
const platform::Place &place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
: OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope),
name_(name),
place_(place) {}
......@@ -35,8 +36,8 @@ void RPCOpHandle::RunImpl() {
if (in->DebugString() == "dummy") { // HACK
continue;
}
if (in->generated_op_) {
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]);
if (in->GeneratedOp()) {
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[p]);
}
}
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
......
......@@ -28,8 +28,9 @@ namespace framework {
namespace details {
struct RPCOpHandle : public OpHandleBase {
RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const std::string& name, const platform::Place& place);
RPCOpHandle(ir::Node* node, const framework::OpDesc& op_desc,
const Scope* local_scope, const std::string& name,
const platform::Place& place);
std::string Name() const override;
......
......@@ -19,10 +19,14 @@
namespace paddle {
namespace framework {
namespace details {
ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev,
Scope *scope,
platform::Place place,
platform::DeviceContext *dev_ctx)
: coeff_(static_cast<float>(1.0 / num_dev)), scope_(scope), place_(place) {
: OpHandleBase(node),
coeff_(static_cast<float>(1.0 / num_dev)),
scope_(scope),
place_(place) {
dev_ctxes_[place_] = dev_ctx;
}
......
......@@ -25,7 +25,8 @@ namespace framework {
namespace details {
struct ScaleLossGradOpHandle : public OpHandleBase {
ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place,
ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope,
platform::Place place,
platform::DeviceContext *context);
~ScaleLossGradOpHandle() final;
......
......@@ -17,6 +17,9 @@
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/scope.h"
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph.h"
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
namespace paddle {
namespace framework {
namespace details {
// A SSA graph used by parallel executor.
struct SSAGraph {
// all variable in each devices.
// The outside vector is the device vector. Each element of this vector is a
// map from variable name to variables. The variables, who have the same name,
// will have a different version. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles.
std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
vars_;
// aux variables to represent dependency. Useful to resolve data hazard.
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
std::vector<std::unique_ptr<OpHandleBase>> ops_;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -17,8 +17,8 @@
namespace paddle {
namespace framework {
namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
for (auto &var_map : graph->vars_) {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) {
continue;
......@@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
auto it_old = name_pair.second.rbegin();
++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
auto *write_op = (*it_new)->generated_op_;
auto &read_ops = (*it_old)->pending_ops_;
OpHandleBase *write_op = (*it_new)->GeneratedOp();
const auto &read_ops = (*it_old)->PendingOps();
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
......@@ -37,10 +37,11 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
continue;
}
auto *dep_var = new DummyVarHandle();
auto *dep_var = new DummyVarHandle(
graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
graph->dep_vars_.emplace(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
}
}
}
......@@ -48,13 +49,20 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
}
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
SSAGraph *graph, const std::string &each_var_name,
const platform::Place &place, size_t place_offset) {
auto &var_holders = graph->vars_[place_offset];
auto &var_holder = var_holders[each_var_name];
Graph *graph, ir::Node *node, const platform::Place &place,
size_t place_offset) {
auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
auto &var_holder = var_holders[node->Name()];
VarHandle *var = nullptr;
if (var_holder.empty()) {
var = new VarHandle(0, place_offset, each_var_name, place);
if (node->Var()) {
var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset,
node->Name(), place);
} else {
var = new VarHandle(
graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable), 0,
place_offset, node->Name(), place);
}
var_holder.emplace_back(var);
} else {
var = var_holder.rbegin()->get();
......@@ -62,24 +70,26 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
return var;
}
void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
const std::string &each_var_name,
void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
ir::Node *new_node,
const platform::Place &place,
size_t place_offset) {
auto &vars = graph->vars_[place_offset][each_var_name];
auto &vars = graph->Get<GraphVars>("vars")[place_offset][new_node->Name()];
size_t version = vars.size();
auto var = new VarHandle(version, place_offset, each_var_name, place);
auto var =
new VarHandle(new_node, version, place_offset, new_node->Name(), place);
vars.emplace_back(var);
op_handle->AddOutput(var);
}
void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
for (auto &op : graph->ops_) {
void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
for (auto &op : graph->Get<GraphOps>("ops")) {
if (!op->Outputs().empty()) {
continue;
}
auto *dummy_leaf = new DummyVarHandle();
graph->dep_vars_.emplace(dummy_leaf);
auto *dummy_leaf = new DummyVarHandle(
graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
op->AddOutput(dummy_leaf);
}
}
......
......@@ -16,20 +16,42 @@
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class SSAGraphBuilder {
// all variable in each devices.
// The outside vector is the device vector. Each element of this vector is a
// map from variable name to variables. The variables, who have the same name,
// will have a differsent version. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles.
typedef std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars;
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
class SSAGraphBuilder : public ir::Pass {
public:
SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......@@ -42,20 +64,19 @@ class SSAGraphBuilder {
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
static void PolishGraphToSupportDataHazards(SSAGraph *graph);
static void PolishGraphToSupportDataHazards(Graph *graph);
static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph,
const std::string &each_var_name,
static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, ir::Node *node,
const platform::Place &place,
size_t place_offset);
// Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
const std::string &each_var_name,
const platform::Place &place, size_t place_offset);
static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
ir::Node *new_node, const platform::Place &place,
size_t place_offset);
static void AddOutputToLeafOps(SSAGraph *graph);
static void AddOutputToLeafOps(Graph *graph);
};
} // namespace details
} // namespace framework
......
......@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph.h"
#include <string>
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars;
......@@ -28,12 +28,12 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
auto insert_pending_var = [&](VarHandleBase *var) {
pending_vars.insert(var);
if (var->generated_op_ == nullptr) {
if (var->GeneratedOp() == nullptr) {
ready_vars.emplace(var);
}
};
for (auto &var_map : graph->vars_) {
for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get());
......@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
}
}
for (auto &var : graph->dep_vars_) {
for (auto &var : graph->Get<GraphDepVars>("dep_vars")) {
insert_pending_var(var.get());
}
for (auto &op : graph->ops_) {
for (auto &op : graph->Get<GraphOps>("ops")) {
if (op->Inputs().empty()) {
ready_ops.insert(op.get());
} else {
......@@ -71,7 +71,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
for (auto ready_var : ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->pending_ops_) {
for (auto *op : ready_var->PendingOps()) {
auto &deps = --pending_ops[op];
if (deps == 0) {
ready_ops.insert(op);
......
......@@ -21,7 +21,6 @@
namespace paddle {
namespace framework {
namespace details {
struct SSAGraph;
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public:
......@@ -29,17 +28,17 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {}
std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override {
auto graph = builder_->Build(program);
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph));
PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
return std::move(new_graph);
}
int GetVarDeviceID(const std::string& var_name) const override {
return builder_->GetVarDeviceID(var_name);
}
bool IsValidGraph(const SSAGraph* graph) const;
bool IsValidGraph(const Graph* graph) const;
private:
std::unique_ptr<SSAGraphBuilder> builder_;
......
......@@ -18,8 +18,8 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
......
......@@ -14,15 +14,15 @@
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include <string>
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
template <typename Callback>
static inline void IterAllVar(const SSAGraph &graph, Callback callback) {
for (auto &each : graph.vars_) {
static inline void IterAllVar(const Graph &graph, Callback callback) {
for (auto &each : graph.Get<GraphVars>("vars")) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(*pair2);
......@@ -30,12 +30,12 @@ static inline void IterAllVar(const SSAGraph &graph, Callback callback) {
}
}
for (auto &var : graph.dep_vars_) {
for (auto &var : graph.Get<GraphDepVars>("dep_vars")) {
callback(*var);
}
}
void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph,
void GraphvizSSAGraphPrinter::Print(const Graph &graph,
std::ostream &sout) const {
size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars;
......@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph,
});
size_t op_id = 0;
for (auto &op : graph.ops_) {
for (auto &op : graph.Get<GraphOps>("ops")) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
......
......@@ -21,16 +21,16 @@
namespace paddle {
namespace framework {
namespace details {
struct SSAGraph;
class SSAGraphPrinter {
public:
virtual ~SSAGraphPrinter() {}
virtual void Print(const SSAGraph& graph, std::ostream& sout) const = 0;
virtual void Print(const Graph& graph, std::ostream& sout) const = 0;
};
class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
public:
void Print(const SSAGraph& graph, std::ostream& sout) const override;
void Print(const Graph& graph, std::ostream& sout) const override;
};
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
......@@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {}
std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override {
auto graph = builder_->Build(program);
printer_->Print(*graph, stream_ref_);
return graph;
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph));
printer_->Print(*new_graph, stream_ref_);
return std::move(new_graph);
}
int GetVarDeviceID(const std::string& var_name) const override {
......
......@@ -14,13 +14,14 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle {
namespace framework {
namespace details {
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::unique_ptr<SSAGraph> &&graph)
const std::vector<platform::Place> &places, std::unique_ptr<Graph> &&graph)
: graph_(std::move(graph)),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr),
......@@ -43,18 +44,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
std::unordered_set<OpHandleBase *> delayed_ops;
// Transform SSAGraph to pending_ops & pending_vars
for (auto &var_map : graph_->vars_) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
}
}
}
for (auto &var : graph_->dep_vars_) {
for (auto &var : graph_->Get<details::GraphDepVars>("dep_vars")) {
InsertPendingVar(&pending_vars, &ready_vars, var.get());
}
for (auto &op : graph_->ops_) {
for (auto &op : graph_->Get<details::GraphOps>("ops")) {
if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get());
} else {
......@@ -64,11 +65,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
std::vector<std::unique_ptr<ir::Node>> tmp_nodes;
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
&pending_vars, &ready_vars, &fetch_data);
InsertFetchOps(fetch_tensors, &fetch_ops, &tmp_nodes, &fetch_dependencies,
&pending_ops, &pending_vars, &ready_vars, &fetch_data);
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) {
......@@ -125,7 +127,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Find the ready_ops after the ready_var.
for (auto ready_var : cur_ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->pending_ops_) {
for (auto *op : ready_var->PendingOps()) {
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
......@@ -151,6 +153,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars,
......@@ -158,7 +161,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->vars_) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
......@@ -169,7 +172,10 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i];
auto &vars = fetched_vars.at(var_name);
auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_);
temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i,
&local_scopes_);
fetch_ops->emplace_back(op);
for (auto &p : places_) {
......@@ -180,7 +186,8 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
op->AddInput(var);
}
auto *fetch_dummy = new DummyVarHandle();
temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get());
op->AddOutput(fetch_dummy);
fetch_dependencies->emplace(fetch_dummy);
this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
......@@ -198,7 +205,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar(
std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, VarHandleBase *var) const {
pending_vars->insert(var);
if (var->generated_op_ == nullptr) {
if (var->GeneratedOp() == nullptr) {
ready_vars->Push(var);
}
}
......
......@@ -27,6 +27,7 @@
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
......@@ -39,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::unique_ptr<SSAGraph> &&graph);
std::unique_ptr<Graph> &&graph);
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
......@@ -52,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details::OpHandleBase *op);
private:
std::unique_ptr<SSAGraph> graph_;
std::unique_ptr<Graph> graph_;
std::unique_ptr<::ThreadPool> pool_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
......@@ -71,6 +72,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
void InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars,
......
......@@ -13,11 +13,14 @@
// limitations under the License.
#pragma once
#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
......@@ -25,19 +28,60 @@ namespace framework {
namespace details {
class OpHandleBase;
// Wraps ir::Node and provide helper utilities.
// It's responsible for populating necessary fields of ir::Node.
//
// VarHandleBase is the var node in the dependency graph.
// A variable can only be generated by a single operator. i.e.
// This is a single assignment graph.
struct VarHandleBase {
explicit VarHandleBase(ir::Node* node) : node_(node) {}
virtual ~VarHandleBase();
virtual std::string DebugString() const = 0;
void AddInput(OpHandleBase* in, ir::Node* node) {
node_->inputs.clear();
node_->inputs.push_back(node);
generated_op_ = in;
}
void AddOutput(OpHandleBase* out, ir::Node* node) {
if (pending_ops_.find(out) == pending_ops_.end()) {
pending_ops_.insert(out);
node_->outputs.push_back(node);
}
}
void RemoveOutput(OpHandleBase* out, ir::Node* node) {
pending_ops_.erase(out);
node_->outputs.erase(
std::remove(node_->outputs.begin(), node_->outputs.end(), node),
node_->outputs.end());
}
void ClearGeneratedOp() {
generated_op_ = nullptr;
node_->inputs.clear();
}
OpHandleBase* GeneratedOp() { return generated_op_; }
const std::unordered_set<OpHandleBase*>& PendingOps() const {
return pending_ops_;
}
ir::Node* Node() { return node_; }
protected:
// The operator who generate this variable. nullptr if the variable
// is a root node.
OpHandleBase* generated_op_{nullptr};
// Operators which depend on this variable ready.
std::unordered_set<OpHandleBase*> pending_ops_;
ir::Node* node_;
};
// VarHandle is actually a single version of Runtime Variable.
......@@ -46,11 +90,14 @@ struct VarHandleBase {
//
// NOTE: runtime variables have place.
struct VarHandle : public VarHandleBase {
explicit VarHandle(ir::Node* node) : VarHandleBase(node) {}
std::string DebugString() const override;
VarHandle(size_t version, size_t scope_index, std::string name,
platform::Place place)
: version_(version),
VarHandle(ir::Node* node, size_t version, size_t scope_index,
std::string name, platform::Place place)
: VarHandleBase(node),
version_(version),
scope_idx_(scope_index),
name_(std::move(name)),
place_(std::move(place)) {}
......@@ -70,6 +117,8 @@ struct VarHandle : public VarHandleBase {
// Dummy Variable. It is used to represent dependencies between operators
struct DummyVarHandle : public VarHandleBase {
explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {}
std::string DebugString() const override;
};
......
cc_library(graph SRCS graph.cc node)
cc_library(node SRCS node.cc)
cc_library(pass SRCS pass.cc graph node)
cc_test(graph_test SRCS graph_test.cc DEPS graph proto_desc op_registry)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle {
namespace framework {
// NOTE(paddle-dev): This graph contains circle.
Graph::Graph(const ProgramDesc &program) : program_(program) {
std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) {
all_vars.emplace(var->Name(), var);
}
std::map<std::string, ir::Node *> var_nodes;
for (auto *op : program.Block(0).AllOps()) {
ir::Node *node = CreateOpNode(op);
for (auto &each_var_name : op->InputArgumentNames()) {
ir::Node *var = nullptr;
if (var_nodes.find(each_var_name) != var_nodes.end()) {
var = var_nodes.at(each_var_name);
} else if (all_vars.count(each_var_name) != 0) {
var = CreateVarNode(all_vars.at(each_var_name));
var_nodes[each_var_name] = var;
} else {
// TODO(paddle-dev): Seems some assumption doesn't hold?
LOG(ERROR) << op->Type()
<< " input var not in all_var list: " << each_var_name;
var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable);
var_nodes[each_var_name] = var;
}
node->inputs.push_back(var);
var->outputs.push_back(node);
}
for (auto &each_var_name : op->OutputArgumentNames()) {
ir::Node *var = nullptr;
if (var_nodes.find(each_var_name) != var_nodes.end()) {
var = var_nodes.at(each_var_name);
} else {
var = CreateVarNode(all_vars.at(each_var_name));
var_nodes[each_var_name] = var;
}
node->outputs.push_back(var);
var->inputs.push_back(node);
}
}
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace framework {
class Graph {
public:
explicit Graph(const ProgramDesc& program);
virtual ~Graph() {
for (auto& attr : attrs_) {
attr_dels_[attr.first]();
}
attrs_.clear();
attr_dels_.clear();
}
template <typename AttrType>
AttrType& Get(const std::string& attr_name) const {
return *boost::any_cast<AttrType*>(attrs_.at(attr_name));
}
template <typename AttrType>
void Set(const std::string& attr_name, AttrType* attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
attrs_[attr_name] = attr;
attr_dels_[attr_name] = [attr, attr_name]() {
VLOG(3) << "deleting " << attr_name;
delete attr;
};
}
ir::Node* CreateVarNode(VarDesc* var_desc) {
nodes.emplace_back(new ir::Node(var_desc));
return nodes.back().get();
}
ir::Node* CreateOpNode(OpDesc* op_desc) {
nodes.emplace_back(new ir::Node(op_desc));
return nodes.back().get();
}
ir::Node* CreateEmptyNode(const std::string& name, ir::Node::Type type) {
nodes.emplace_back(new ir::Node(name, type));
return nodes.back().get();
}
std::vector<std::unique_ptr<ir::Node>> nodes;
private:
// NOTE: program_ shouldn't be exposed to user.
const ProgramDesc& program_;
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
class NOP : public OperatorBase {
public:
NOP(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope &scope,
const platform::Place &place) const override {}
};
class SumOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class SumOpVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
auto &inputs = op_desc.Input("X");
auto default_var_type = proto::VarType::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string &name) {
return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
default_var_type = proto::VarType::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(default_var_type);
}
};
} // namespace framework
} // namespace paddle
REGISTER_OPERATOR(sum, paddle::framework::NOP, paddle::framework::SumOpMaker,
paddle::framework::SumOpVarTypeInference);
REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP,
paddle::framework::SumOpMaker);
namespace paddle {
namespace framework {
TEST(GraphTest, Basic) {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"});
prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_out");
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarType::SELECTED_ROWS,
prog.MutableBlock(0)->Var("test_out")->GetType());
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR);
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarType::LOD_TENSOR,
prog.MutableBlock(0)->Var("test_out")->GetType());
std::unique_ptr<Graph> g(new Graph(prog));
ASSERT_EQ(g->nodes[0]->Name(), "sum");
ASSERT_EQ(g->nodes[0]->inputs[0]->Name(), "test_a");
ASSERT_EQ(g->nodes[0]->inputs[1]->Name(), "test_b");
ASSERT_EQ(g->nodes[0]->inputs[2]->Name(), "test_c");
ASSERT_EQ(g->nodes[0]->outputs[0]->Name(), "test_out");
ASSERT_EQ(g->nodes[1]->Name(), "test_a");
ASSERT_EQ(g->nodes[1]->outputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes[2]->Name(), "test_b");
ASSERT_EQ(g->nodes[2]->outputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes[3]->Name(), "test_c");
ASSERT_EQ(g->nodes[3]->outputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes[4]->Name(), "test_out");
ASSERT_EQ(g->nodes[4]->inputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes.size(), 5);
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/node.h"
namespace paddle {
namespace framework {} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace framework {
namespace ir {
class Node {
public:
enum class Type { kOperation, kVariable };
explicit Node(const std::string& name, Type type)
: name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
explicit Node(VarDesc* var_desc)
: name_(var_desc->Name()),
var_desc_(var_desc),
op_desc_(nullptr),
type_(Type::kVariable) {}
explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()),
var_desc_(nullptr),
op_desc_(op_desc),
type_(Type::kOperation) {}
Type NodeType() const { return type_; }
std::string Name() const { return name_; }
VarDesc* Var() {
PADDLE_ENFORCE(type_ == Type::kVariable);
return var_desc_;
}
OpDesc* Op() {
PADDLE_ENFORCE(type_ == Type::kOperation);
return op_desc_;
}
std::vector<Node*> inputs;
std::vector<Node*> outputs;
protected:
const std::string name_;
VarDesc* var_desc_;
OpDesc* op_desc_;
Type type_;
private:
DISABLE_COPY_AND_ASSIGN(Node);
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
class Pass {
public:
Pass() = default;
virtual ~Pass() {}
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -18,6 +18,8 @@ limitations under the License. */
#include <tuple>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
......@@ -129,12 +131,11 @@ ParallelExecutor::ParallelExecutor(
PADDLE_THROW("Not compiled with CUDA.");
#endif
}
builder_ = builder_factory.Create();
std::unique_ptr<Graph> graph(new Graph(main_program));
graph = builder_->Apply(std::move(graph));
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places,
builder_->Build(main_program)));
exec_strategy, member_->local_scopes_, places, std::move(graph)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos),
member_->places_, std::move(member_->executor_)));
......
......@@ -38,6 +38,7 @@ limitations under the License. */
#endif
#endif
#include <boost/any.hpp>
#include <boost/mpl/comparison.hpp>
#include <boost/mpl/less_equal.hpp>
#include <boost/variant.hpp>
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册