Unverified commit 7268760f, authored by yuyang18

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/combine_open_files_and_double_buffer
...@@ -18,7 +18,21 @@ learning to many products at Baidu. ...@@ -18,7 +18,21 @@ learning to many products at Baidu.
Our vision is to enable deep learning for everyone via PaddlePaddle. Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle. Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
### Lastest PaddlePaddle Version: [Fluid](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid)
### Latest PaddlePaddle Release: [Fluid 0.14.0](https://github.com/PaddlePaddle/Paddle/tree/v0.14.0)
### Install Latest Stable Release:
```
# Linux CPU
pip install paddlepaddle
# Linux GPU cuda9cudnn7
pip install paddlepaddle-gpu
# Linux GPU cuda8cudnn7
pip install paddlepaddle-gpu==0.14.0.post87
# Linux GPU cuda8cudnn5
pip install paddlepaddle-gpu==0.14.0.post85
# For installation on other platform, refer to http://paddlepaddle.org/
```
## Features ## Features
......
## Motivation
There is a ```gap``` between the ```Program``` defined by
the user and the ```Executable``` that can be scheduled
efficiently on heterogeneous hardware, either locally
or in a distributed setting.
Usually, the ```gap``` is bridged by
* A series of transformations applied in a defined order.
* Transformations that usually involve
```insert, delete, clustering, split, dependency analysis```.
* A simple way to verify and debug each transformation.
* The flexibility to add, remove or customize transformations to fit
the requirements of various algorithms (models) and hardware scenarios.
Other developments also push us toward a better unified pattern.
* The deep learning framework is built around the concept of graphs.
To leverage tools such as compilation (e.g. TVM and nGraph) or
cross-framework conversion (e.g. ONNX), we also need an intermediate
representation that can be connected to the rest of the ecosystem.
We need a unified pattern to naturally support the requirements
described above. The pattern should fit training, inference,
and other offline transformations of serialized models.
Drawing on LLVM and other deep learning frameworks, we draft the
design below.
## Design
### Major Concepts
#### Node
```Node``` represents an operation that performs some computation or
a variable that is an input or output of an operation.
```Node```s are connected to other ```Node```s via inputs and outputs.
Other properties (possibly device placement information) can be added
to ```Node``` in the future if they become a
common requirement of many ```Pass```es. Otherwise, they should live
in a ```Node``` wrapper class that is private to some ```Pass``` or be
a local member of a ```Pass```.
#### Graph
```Graph``` contains a list of ```Node```s, which are connected to
each other via inputs and outputs.
TODO: Better definitions for the graph.
```Graph``` can also contain ```Attribute```s. An ```Attribute```
can be anything. For example, it can be a list of ```wrapper```
nodes. The ```wrapper``` nodes compose ```Node```s and provide
helper methods for execution or transformation. An ```Attribute```
can also contain other things that describe some properties of
the ```Graph``` or its nodes. ```Attribute```s can be passed
across ```Pass```es. However, they should be used with care.
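As a rough illustration of how a ```Graph``` could carry arbitrary ```Attribute```s between ```Pass```es, here is a minimal sketch. The names `SketchGraph`, `Set`, `Get` and `CountRecordedOps` are hypothetical and are not taken from the Paddle sources; the sketch only assumes that an attribute is looked up by name and that the caller states the expected type.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for a Graph that owns named attributes.
class SketchGraph {
 public:
  // Stores `attr` under `name`; the graph takes ownership of the object.
  template <typename T>
  void Set(const std::string &name, std::unique_ptr<T> attr) {
    if (attrs_.count(name) != 0) {
      throw std::invalid_argument("attribute already set: " + name);
    }
    attrs_.emplace(name, Entry{std::type_index(typeid(T)),
                               std::shared_ptr<void>(std::move(attr))});
  }

  // Returns the attribute stored under `name`, checking the requested type.
  template <typename T>
  T &Get(const std::string &name) {
    auto it = attrs_.find(name);
    if (it == attrs_.end()) {
      throw std::out_of_range("no such attribute: " + name);
    }
    if (it->second.type != std::type_index(typeid(T))) {
      throw std::invalid_argument("attribute type mismatch: " + name);
    }
    return *static_cast<T *>(it->second.ptr.get());
  }

 private:
  struct Entry {
    std::type_index type;       // type the attribute was stored as
    std::shared_ptr<void> ptr;  // type-erased owner; keeps the right deleter
  };
  std::map<std::string, Entry> attrs_;
};

// Usage: one step records op names, a later step reads and extends them.
std::size_t CountRecordedOps() {
  SketchGraph graph;
  graph.Set("ops", std::unique_ptr<std::vector<std::string>>(
                       new std::vector<std::string>{"mul", "add"}));
  graph.Get<std::vector<std::string>>("ops").push_back("relu");
  return graph.Get<std::vector<std::string>>("ops").size();
}
```

Recording the stored type alongside the pointer turns a mismatched `Get<T>` into an exception rather than undefined behavior, which is one way to honor the "use with care" warning above.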
#### Pass
```Pass``` represents a transformation of ```Graph```. Its input
is a ```Graph``` and its output is also a ```Graph```. For example,
a ```Pass``` can simply print out the ```Graph```. A ```Pass```
can also fuse some of the ```Graph```'s ```Node```s.
#### Optimize
```Optimize``` contains a series of ```Pass```es in a defined order.
```Optimize``` transforms a ```Graph``` that only contains raw
modeling logic to a ```Graph``` that can be run efficiently while
maintaining the original modeling logic.
### Optimize Process
* Program is first converted to Graph.
* Graph goes through a series of Passes.
* Graph is transformed from raw model logic to a
form that is efficient to execute.
Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
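The pipeline above can be sketched in a few lines of C++. This is only an illustration under simplifying assumptions: `SketchNode`, `SketchGraph`, `Pass`, `Optimize`, `CountOps`, `RemoveIsolatedVars` and `DemoPipeline` are hypothetical names, and the real ```Node```, ```Graph``` and ```Pass``` types are likely to carry more state.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical simplified node: an operation or a variable.
struct SketchNode {
  enum class Type { kOperation, kVariable };
  std::string name;
  Type type;
  std::vector<SketchNode *> inputs;
  std::vector<SketchNode *> outputs;
};

// Hypothetical simplified graph that owns its nodes.
struct SketchGraph {
  std::vector<std::unique_ptr<SketchNode>> nodes;
  SketchNode *Create(const std::string &name, SketchNode::Type type) {
    nodes.emplace_back(new SketchNode{name, type, {}, {}});
    return nodes.back().get();
  }
};

// Adds an edge from `from` to `to`.
void Connect(SketchNode *from, SketchNode *to) {
  from->outputs.push_back(to);
  to->inputs.push_back(from);
}

// A pass takes a graph and returns a graph.
using GraphPtr = std::unique_ptr<SketchGraph>;
using Pass = std::function<GraphPtr(GraphPtr)>;

// Read-only pass: counts operation nodes and returns the graph unchanged.
Pass CountOps(std::size_t *count) {
  return [count](GraphPtr graph) -> GraphPtr {
    for (const auto &n : graph->nodes) {
      if (n->type == SketchNode::Type::kOperation) ++*count;
    }
    return graph;
  };
}

// Transforming pass: drops variables that no operation reads or writes.
GraphPtr RemoveIsolatedVars(GraphPtr graph) {
  auto &nodes = graph->nodes;
  nodes.erase(std::remove_if(nodes.begin(), nodes.end(),
                             [](const std::unique_ptr<SketchNode> &n) {
                               return n->type == SketchNode::Type::kVariable &&
                                      n->inputs.empty() && n->outputs.empty();
                             }),
              nodes.end());
  return graph;
}

// Runs the passes in the given order, feeding each output to the next pass.
GraphPtr Optimize(GraphPtr graph, const std::vector<Pass> &passes) {
  for (const Pass &pass : passes) graph = pass(std::move(graph));
  return graph;
}

// Builds x -> mul -> y plus one unused variable, then runs two passes.
std::size_t DemoPipeline() {
  GraphPtr g(new SketchGraph);
  SketchNode *x = g->Create("x", SketchNode::Type::kVariable);
  SketchNode *mul = g->Create("mul", SketchNode::Type::kOperation);
  SketchNode *y = g->Create("y", SketchNode::Type::kVariable);
  g->Create("unused", SketchNode::Type::kVariable);
  Connect(x, mul);
  Connect(mul, y);

  std::size_t ops_seen = 0;
  std::vector<Pass> passes;
  passes.push_back(CountOps(&ops_seen));
  passes.push_back(RemoveIsolatedVars);
  g = Optimize(std::move(g), passes);
  assert(ops_seen == 1);
  return g->nodes.size();  // x, mul and y remain
}
```

Because every pass has the same graph-in, graph-out shape, passes can be reordered, removed, or inspected one at a time, which matches the verification and flexibility goals listed in the motivation.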
...@@ -341,6 +341,26 @@ paddle.fluid.layers.polynomial_decay ArgSpec(args=['learning_rate', 'decay_steps ...@@ -341,6 +341,26 @@ paddle.fluid.layers.polynomial_decay ArgSpec(args=['learning_rate', 'decay_steps
paddle.fluid.layers.piecewise_decay ArgSpec(args=['boundaries', 'values'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.piecewise_decay ArgSpec(args=['boundaries', 'values'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.noam_decay ArgSpec(args=['d_model', 'warmup_steps'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.noam_decay ArgSpec(args=['d_model', 'warmup_steps'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.append_LARS ArgSpec(args=['params_grads', 'learning_rate', 'weight_decay'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.append_LARS ArgSpec(args=['params_grads', 'learning_rate', 'weight_decay'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.InitState.__init__ ArgSpec(args=['self', 'init', 'shape', 'value', 'init_boot', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, None, False, 'float32'))
paddle.fluid.contrib.StateCell.__init__ ArgSpec(args=['self', 'inputs', 'states', 'out_state', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.StateCell.compute_state ArgSpec(args=['self', 'inputs'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.StateCell.get_input ArgSpec(args=['self', 'input_name'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.StateCell.get_state ArgSpec(args=['self', 'state_name'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.StateCell.out_state ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.StateCell.set_state ArgSpec(args=['self', 'state_name', 'state_value'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.StateCell.state_updater ArgSpec(args=['self', 'updater'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.StateCell.update_states ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.TrainingDecoder.__init__ ArgSpec(args=['self', 'state_cell', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.TrainingDecoder.block ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
paddle.fluid.contrib.TrainingDecoder.output ArgSpec(args=['self'], varargs='outputs', keywords=None, defaults=None)
paddle.fluid.contrib.TrainingDecoder.static_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.TrainingDecoder.step_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.BeamSearchDecoder.__init__ ArgSpec(args=['self', 'state_cell', 'init_ids', 'init_scores', 'target_dict_dim', 'word_dim', 'input_var_dict', 'topk_size', 'sparse_emb', 'max_len', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=({}, 50, True, 100, 1, 1, None))
paddle.fluid.contrib.BeamSearchDecoder.block ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
paddle.fluid.contrib.BeamSearchDecoder.decode ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.BeamSearchDecoder.early_stop ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.BeamSearchDecoder.read_array ArgSpec(args=['self', 'init', 'is_ids', 'is_scores'], varargs=None, keywords=None, defaults=(False, False))
paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array', 'value'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.transpiler.DistributeTranspiler.create_splited_vars ArgSpec(args=['self', 'source_var', 'block', 'tag'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.create_splited_vars ArgSpec(args=['self', 'source_var', 'block', 'tag'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
......
add_subdirectory(details) add_subdirectory(details)
add_subdirectory(ir)
# ddim lib # ddim lib
proto_library(framework_proto SRCS framework.proto) proto_library(framework_proto SRCS framework.proto)
...@@ -93,7 +94,7 @@ else() ...@@ -93,7 +94,7 @@ else()
endif() endif()
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph)
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
...@@ -5,8 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod ...@@ -5,8 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder) cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
...@@ -35,7 +34,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ...@@ -35,7 +34,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS
cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context) simple_threadpool device_context)
......
...@@ -23,10 +23,14 @@ namespace framework { ...@@ -23,10 +23,14 @@ namespace framework {
namespace details { namespace details {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes, AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs) const platform::NCCLContextMap *ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) { : OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(ctxs) {
if (nccl_ctxs_) { if (nccl_ctxs_) {
for (auto &p : places_) { for (auto &p : places_) {
this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p); this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
...@@ -34,9 +38,10 @@ AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes, ...@@ -34,9 +38,10 @@ AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
} }
} }
#else #else
AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes, AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places) const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {} : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif #endif
void AllReduceOpHandle::RunImpl() { void AllReduceOpHandle::RunImpl() {
......
...@@ -30,11 +30,11 @@ namespace details { ...@@ -30,11 +30,11 @@ namespace details {
struct AllReduceOpHandle : public OpHandleBase { struct AllReduceOpHandle : public OpHandleBase {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
AllReduceOpHandle(const std::vector<Scope *> &local_scopes, AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs); const platform::NCCLContextMap *ctxs);
#else #else
AllReduceOpHandle(const std::vector<Scope *> &local_scopes, AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places); const std::vector<platform::Place> &places);
#endif #endif
std::string Name() const override; std::string Name() const override;
......
...@@ -35,10 +35,13 @@ namespace details { ...@@ -35,10 +35,13 @@ namespace details {
struct BroadcastOpHandle : public OpHandleBase { struct BroadcastOpHandle : public OpHandleBase {
public: public:
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
BroadcastOpHandle(const std::vector<Scope *> &local_scopes, BroadcastOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLContextMap *nccl_ctxs) const platform::NCCLContextMap *nccl_ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) { : OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(nccl_ctxs) {
if (nccl_ctxs_) { if (nccl_ctxs_) {
for (auto &p_ctx : nccl_ctxs_->contexts_) { for (auto &p_ctx : nccl_ctxs_->contexts_) {
dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get(); dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
...@@ -46,9 +49,9 @@ struct BroadcastOpHandle : public OpHandleBase { ...@@ -46,9 +49,9 @@ struct BroadcastOpHandle : public OpHandleBase {
} }
} }
#else #else
BroadcastOpHandle(const std::vector<Scope *> &local_scopes, BroadcastOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places) const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {} : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif #endif
std::string Name() const override; std::string Name() const override;
......
...@@ -96,48 +96,61 @@ struct TestBroadcastOpHandle { ...@@ -96,48 +96,61 @@ struct TestBroadcastOpHandle {
} }
param_scopes_[input_scope_idx]->Var("input"); param_scopes_[input_scope_idx]->Var("input");
std::unique_ptr<ir::Node> n(
new ir::Node("node0", ir::Node::Type::kOperation));
if (use_gpu_) { if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset( op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); nccl_ctxs_.get()));
#else #else
PADDLE_THROW("CUDA is not support."); PADDLE_THROW("CUDA is not support.");
#endif #endif
} else { } else {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset( op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); nccl_ctxs_.get()));
#else #else
op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_)); op_handle_.reset(
new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_));
#endif #endif
} }
auto* in_var_handle = std::unique_ptr<ir::Node> v(
new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]); new ir::Node("node1", ir::Node::Type::kVariable));
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
gpu_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle); op_handle_->AddInput(in_var_handle);
// add dummy var // add dummy var
vars_.emplace_back(new DummyVarHandle());
std::unique_ptr<ir::Node> v2(
new ir::Node("node2", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(v2.get()));
DummyVarHandle* dummy_var_handle = DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
dummy_var_handle->generated_op_ = nullptr; dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(dummy_var_handle); op_handle_->AddInput(dummy_var_handle);
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
if (!use_gpu_) { if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
} }
VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]); std::unique_ptr<ir::Node> v3(
new ir::Node("node3", ir::Node::Type::kVariable));
VarHandle* out_var_handle =
new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
} }
// add dummy var // add dummy var
vars_.emplace_back(new DummyVarHandle()); std::unique_ptr<ir::Node> v4(
new ir::Node("node4", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(v4.get()));
DummyVarHandle* out_dummy_var_handle = DummyVarHandle* out_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
out_dummy_var_handle->generated_op_ = nullptr; out_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddOutput(out_dummy_var_handle); op_handle_->AddOutput(out_dummy_var_handle);
} }
......
...@@ -19,9 +19,10 @@ ...@@ -19,9 +19,10 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope, ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
platform::Place place) platform::Place place)
: op_(framework::OpRegistry::CreateOp(op_desc)), : OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(*node->Op())),
scope_(scope), scope_(scope),
place_(place) {} place_(place) {}
...@@ -35,8 +36,8 @@ void ComputationOpHandle::RunImpl() { ...@@ -35,8 +36,8 @@ void ComputationOpHandle::RunImpl() {
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) { bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
bool need_wait = bool need_wait =
in_var && in_var->generated_op_ && in_var && in_var->GeneratedOp() &&
in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_]; in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_[place_];
return need_wait; return need_wait;
} }
......
...@@ -28,8 +28,7 @@ namespace framework { ...@@ -28,8 +28,7 @@ namespace framework {
namespace details { namespace details {
struct ComputationOpHandle : public OpHandleBase { struct ComputationOpHandle : public OpHandleBase {
public: public:
ComputationOpHandle(const OpDesc &op_desc, Scope *scope, ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place);
platform::Place place);
std::string Name() const override; std::string Name() const override;
......
...@@ -22,10 +22,10 @@ namespace details { ...@@ -22,10 +22,10 @@ namespace details {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
DataBalanceOpHandle::DataBalanceOpHandle( DataBalanceOpHandle::DataBalanceOpHandle(
const std::vector<Scope *> &local_scopes, ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs) const platform::NCCLContextMap *ctxs)
: local_scopes_(local_scopes), places_(places) { : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {
if (ctxs) { if (ctxs) {
for (auto &p : places_) { for (auto &p : places_) {
this->dev_ctxes_[p] = ctxs->DevCtx(p); this->dev_ctxes_[p] = ctxs->DevCtx(p);
...@@ -34,9 +34,9 @@ DataBalanceOpHandle::DataBalanceOpHandle( ...@@ -34,9 +34,9 @@ DataBalanceOpHandle::DataBalanceOpHandle(
} }
#else #else
DataBalanceOpHandle::DataBalanceOpHandle( DataBalanceOpHandle::DataBalanceOpHandle(
const std::vector<Scope *> &local_scopes, ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places) const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {} : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif #endif
std::string DataBalanceOpHandle::Name() const { return "data balance"; } std::string DataBalanceOpHandle::Name() const { return "data balance"; }
......
...@@ -30,11 +30,11 @@ namespace details { ...@@ -30,11 +30,11 @@ namespace details {
struct DataBalanceOpHandle : public OpHandleBase { struct DataBalanceOpHandle : public OpHandleBase {
public: public:
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
DataBalanceOpHandle(const std::vector<Scope *> &local_scopes, DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs); const platform::NCCLContextMap *ctxs);
#else #else
DataBalanceOpHandle(const std::vector<Scope *> &local_scopes, DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places); const std::vector<platform::Place> &places);
#endif #endif
......
...@@ -21,13 +21,16 @@ namespace paddle { ...@@ -21,13 +21,16 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
FetchOpHandle::FetchOpHandle(FeedFetchList *data, size_t offset, FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes) std::vector<Scope *> *local_scopes)
: data_(data), offset_(offset), local_scopes_(local_scopes) {} : OpHandleBase(node),
data_(data),
offset_(offset),
local_scopes_(local_scopes) {}
FetchOpHandle::~FetchOpHandle() { FetchOpHandle::~FetchOpHandle() {
for (auto *input_var : inputs_) { for (auto *input_var : inputs_) {
input_var->pending_ops_.erase(this); input_var->RemoveOutput(this, this->Node());
} }
} }
...@@ -77,8 +80,8 @@ void FetchOpHandle::RunImpl() { ...@@ -77,8 +80,8 @@ void FetchOpHandle::RunImpl() {
void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) { void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) {
auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place); auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place);
for (auto *input : inputs_) { for (auto *input : inputs_) {
if (input->generated_op_) { if (input->GeneratedOp()) {
input->generated_op_->RecordWaitEventOnCtx(cpu_ctx); input->GeneratedOp()->RecordWaitEventOnCtx(cpu_ctx);
} }
} }
} }
......
...@@ -28,7 +28,7 @@ namespace details { ...@@ -28,7 +28,7 @@ namespace details {
struct FetchOpHandle : public OpHandleBase { struct FetchOpHandle : public OpHandleBase {
public: public:
FetchOpHandle(FeedFetchList *data, size_t offset, FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes); std::vector<Scope *> *local_scopes);
~FetchOpHandle(); ~FetchOpHandle();
......
...@@ -30,10 +30,12 @@ namespace details { ...@@ -30,10 +30,12 @@ namespace details {
struct FuseVarsOpHandle : public OpHandleBase { struct FuseVarsOpHandle : public OpHandleBase {
public: public:
FuseVarsOpHandle(Scope *local_scope, const platform::Place &place, FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
const platform::Place &place,
const std::unordered_map<std::string, int64_t> &inputs_numel, const std::unordered_map<std::string, int64_t> &inputs_numel,
const std::type_index &var_type) const std::type_index &var_type)
: local_scope_(local_scope), : OpHandleBase(node),
local_scope_(local_scope),
place_(place), place_(place),
inputs_numel_(inputs_numel), inputs_numel_(inputs_numel),
type_(var_type) { type_(var_type) {
......
...@@ -20,9 +20,10 @@ namespace paddle { ...@@ -20,9 +20,10 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes, GatherOpHandle::GatherOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places) const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {} : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
void GatherOpHandle::RunImpl() { void GatherOpHandle::RunImpl() {
if (places_.size() == 1) return; if (places_.size() == 1) return;
......
...@@ -30,7 +30,7 @@ namespace details { ...@@ -30,7 +30,7 @@ namespace details {
struct GatherOpHandle : public OpHandleBase { struct GatherOpHandle : public OpHandleBase {
public: public:
GatherOpHandle(const std::vector<Scope *> &local_scopes, GatherOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places); const std::vector<platform::Place> &places);
std::string Name() const override; std::string Name() const override;
......
...@@ -70,6 +70,7 @@ struct TestGatherOpHandle { ...@@ -70,6 +70,7 @@ struct TestGatherOpHandle {
} }
void InitGatherOp(size_t input_scope_idx) { void InitGatherOp(size_t input_scope_idx) {
std::vector<std::unique_ptr<ir::Node>> nodes;
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope())); local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope(); Scope& local_scope = local_scopes_.back()->NewScope();
...@@ -81,30 +82,37 @@ struct TestGatherOpHandle { ...@@ -81,30 +82,37 @@ struct TestGatherOpHandle {
} }
param_scopes_[input_scope_idx]->Var("out"); param_scopes_[input_scope_idx]->Var("out");
op_handle_.reset(new GatherOpHandle(local_scopes_, gpu_list_)); nodes.emplace_back(new ir::Node("node", ir::Node::Type::kOperation));
op_handle_.reset(
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
// add input // add input
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
auto* in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]); nodes.emplace_back(new ir::Node("node1", ir::Node::Type::kVariable));
auto* in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle); op_handle_->AddInput(in_var_handle);
} }
// add dummy var // add dummy var
vars_.emplace_back(new DummyVarHandle()); nodes.emplace_back(new ir::Node("node2", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* in_dummy_var_handle = DummyVarHandle* in_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
in_dummy_var_handle->generated_op_ = nullptr; in_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(in_dummy_var_handle); op_handle_->AddInput(in_dummy_var_handle);
// add output // add output
auto* out_var_handle = nodes.emplace_back(new ir::Node("node3", ir::Node::Type::kVariable));
new VarHandle(2, input_scope_idx, "out", gpu_list_[input_scope_idx]); auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
"out", gpu_list_[input_scope_idx]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
// add dummy var // add dummy var
vars_.emplace_back(new DummyVarHandle()); nodes.emplace_back(new ir::Node("node4", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* dummy_var_handle = DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
op_handle_->AddOutput(dummy_var_handle); op_handle_->AddOutput(dummy_var_handle);
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -66,31 +67,38 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -66,31 +67,38 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
} }
} }
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
const OpDesc &op,
size_t place_id) const { size_t place_id) const {
auto p = places_[place_id]; auto p = places_[place_id];
auto *op_handle = result->ops_.back().get(); auto *op_handle = result->Get<GraphOps>("ops").back().get();
op_handle->SetDeviceContext(p, op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p)); platform::DeviceContextPool::Instance().Get(p));
for (auto &each_var_name : op.InputArgumentNames()) { for (ir::Node *input : node->inputs) {
VarHandle *var = VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id);
CreateOrGetLatestVarHandle(result, each_var_name, p, place_id);
op_handle->AddInput(var); op_handle->AddInput(var);
} }
for (auto &each_var_name : op.OutputArgumentNames()) { for (ir::Node *output : node->outputs) {
CreateOpOutput(result, op_handle, each_var_name, p, place_id); ir::Node *new_node = nullptr;
if (output->Var()) {
new_node = result->CreateVarNode(output->Var());
} else {
new_node =
result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
}
CreateOpOutput(result, op_handle, new_node, p, place_id);
} }
} }
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars( std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
const ProgramDesc &program) const { const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
std::vector<std::string> send_vars; std::vector<std::string> send_vars;
// since parameters are all in block 0, // since parameters are all in block 0,
// it's enough to only scan send ops in block 0 // it's enough to only scan send ops in block 0
for (auto *op : program.Block(0).AllOps()) { for (auto &node : nodes) {
if (node->NodeType() != ir::Node::Type::kOperation) continue;
OpDesc *op = node->Op();
// TODO(Yancey1989): use a graceful method to find send op, // TODO(Yancey1989): use a graceful method to find send op,
// instead of the hard code string // instead of the hard code string
if (op->Type() == "send") { if (op->Type() == "send") {
...@@ -104,9 +112,11 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars( ...@@ -104,9 +112,11 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
} }
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars( std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
const ProgramDesc &program) const { const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
std::vector<std::string> recv_vars; std::vector<std::string> recv_vars;
for (auto *op : program.Block(0).AllOps()) { for (auto &node : nodes) {
if (node->NodeType() != ir::Node::Type::kOperation) continue;
OpDesc *op = node->Op();
// TODO(Yancey1989): use a graceful method to find recv op, // TODO(Yancey1989): use a graceful method to find recv op,
// instead of the hard code string // instead of the hard code string
if (op->Type() == "recv") { if (op->Type() == "recv") {
...@@ -120,7 +130,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars( ...@@ -120,7 +130,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
} }
bool MultiDevSSAGraphBuilder::IsDistTrainOp( bool MultiDevSSAGraphBuilder::IsDistTrainOp(
const OpDesc &op, const std::vector<std::string> &send_vars, ir::Node *node, const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const { const std::vector<std::string> &recv_vars) const {
if (send_vars.size() == 0 || recv_vars.size() == 0) { if (send_vars.size() == 0 || recv_vars.size() == 0) {
return false; return false;
...@@ -143,8 +153,17 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( ...@@ -143,8 +153,17 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
return false; return false;
}; };
return checker(op.OutputArgumentNames(), send_vars) || std::vector<std::string> input_var_names;
checker(op.InputArgumentNames(), recv_vars); std::vector<std::string> output_var_names;
for (ir::Node *input : node->inputs) {
input_var_names.push_back(input->Name());
}
for (ir::Node *output : node->outputs) {
output_var_names.push_back(output->Name());
}
return checker(output_var_names, send_vars) ||
checker(input_var_names, recv_vars);
} }
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
...@@ -167,25 +186,30 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( ...@@ -167,25 +186,30 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
return dev_id; return dev_id;
} }
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
const ProgramDesc &program) const { std::unique_ptr<Graph> graph) const {
for (auto *var : program.Block(0).AllVars()) { // Rebuild the graph structure.
all_vars_.emplace(var->Name(), var); auto nodes = std::move(graph->nodes);
graph->nodes.clear();
for (auto &node : nodes) {
if (node->NodeType() == ir::Node::Type::kVariable) {
all_vars_.emplace(node->Name(), node->Var());
}
} }
auto graph = new SSAGraph(); Graph &result = *graph;
SSAGraph &result = *graph;
std::unordered_set<std::string> og_has_been_broadcast; std::unordered_set<std::string> og_has_been_broadcast;
// We cannot invoke resize. It is a bug of GCC 4.8 // We cannot invoke resize. It is a bug of GCC 4.8
result.vars_ = std::vector< result.Set("vars", new GraphVars(places_.size()));
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>( result.Set("dep_vars", new GraphDepVars);
places_.size()); result.Set("ops", new GraphOps);
// find send/recv vars so that we can place the distributed training // find send/recv vars so that we can place the distributed training
// related op in the place 0 // related op in the place 0
auto send_vars = FindDistTrainSendVars(program); auto send_vars = FindDistTrainSendVars(nodes);
auto recv_vars = FindDistTrainRecvVars(program); auto recv_vars = FindDistTrainRecvVars(nodes);
std::vector<std::unordered_set<std::string>> bcast_var_name_set; std::vector<std::unordered_set<std::string>> bcast_var_name_set;
bcast_var_name_set.resize(places_.size()); bcast_var_name_set.resize(places_.size());
...@@ -193,14 +217,19 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -193,14 +217,19 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
size_t cur_device_id = 0; size_t cur_device_id = 0;
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) { // NOTE: Currently, passes before SSAGraphBuilder cannot reorder
// forward, backward nodes. E.g. you can't append a forward node
// at the end of the node list.
// TODO(panyx0718): FIXME: Needs to sort by forward->backward order.
for (auto &node : nodes) {
if (node->NodeType() != ir::Node::Type::kOperation) continue;
if (boost::get<int>( if (boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) { static_cast<int>(OpRole::kRPC)) {
CreateRPCOp(&result, *op); CreateRPCOp(&result, node.get());
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) { } else if (IsDistTrainOp(node.get(), send_vars, recv_vars)) {
CreateDistTrainOp(&result, *op); CreateDistTrainOp(&result, node.get());
} else if (IsScaleLossOp(*op)) { } else if (IsScaleLossOp(node.get())) {
// user can customize loss@grad if not use_default_grad_scale_ // user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ != if (strategy_.gradient_scale_ !=
BuildStrategy::GradientScaleStrategy::kCustomized) { BuildStrategy::GradientScaleStrategy::kCustomized) {
...@@ -212,33 +241,35 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -212,33 +241,35 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
// the block. // the block.
is_forwarding = false; is_forwarding = false;
} else { } else {
int op_dev_id = GetOpDeviceID(*op); int op_dev_id = GetOpDeviceID(node.get());
if (op_dev_id != -1) { // This op only runs on one specific device. if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(&result, *op, op_dev_id); CreateComputationalOp(&result, node.get(), op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) { for (ir::Node *n : node->outputs) {
var_name_on_devices_.emplace(var_name, op_dev_id); var_name_on_devices_.emplace(n->Name(), op_dev_id);
} }
} else { } else {
// This op runs on all devices, and its output may have parameter's // This op runs on all devices, and its output may have parameter's
// gradients. // gradients.
if (op->Type() == "read" && strategy_.enable_data_balance_) { if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
op->SetAttr("throw_eof_exp", false); node->Op()->SetAttr("throw_eof_exp", false);
CreateComputationalOps(&result, *op, places_.size()); CreateComputationalOps(&result, node.get(), places_.size());
const auto &data_var_names = op->Output("Out"); // TODO(paddle-dev): builder shouldn't depend on the out logic of
// a specific op.
const auto &data_var_names = node->Op()->Output("Out");
InsertDataBalanceOp(&result, data_var_names); InsertDataBalanceOp(&result, data_var_names);
} else { } else {
CreateComputationalOps(&result, *op, places_.size()); CreateComputationalOps(&result, node.get(), places_.size());
} }
if (!is_forwarding && places_.size() > 1) { if (!is_forwarding && places_.size() > 1) {
// Currently, we assume that once gradient is generated, it can be // Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. // broadcast, and each gradient is only broadcast once.
if (static_cast<bool>(boost::get<int>(op->GetAttr( if (static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) & OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward))) { static_cast<int>(OpRole::kBackward))) {
try { try {
auto backward_vars = auto backward_vars = boost::get<std::vector<std::string>>(
boost::get<std::vector<std::string>>(op->GetNullableAttr( node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName())); OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
...@@ -302,8 +333,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -302,8 +333,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
* Only variables should be the leaves of graph. * Only variables should be the leaves of graph.
*/ */
AddOutputToLeafOps(&result); AddOutputToLeafOps(&result);
return std::move(graph);
return std::unique_ptr<SSAGraph>(graph);
} }
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
...@@ -327,78 +357,96 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext( ...@@ -327,78 +357,96 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
#endif #endif
} }
void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
const std::string &p_name, const std::string &p_name,
size_t src_dev_id) const { size_t src_dev_id) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_); auto *op_handle = new BroadcastOpHandle(
result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_);
#else #else
auto *op_handle = new BroadcastOpHandle(local_scopes_, places_); auto *op_handle = new BroadcastOpHandle(
result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
local_scopes_, places_);
#endif #endif
result->Get<GraphOps>("ops").emplace_back(op_handle);
result->ops_.emplace_back(op_handle); auto *in =
auto *in = result->vars_.at(src_dev_id).at(p_name).back().get(); result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get();
op_handle->AddInput(in); op_handle->AddInput(in);
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
auto &vars = result->vars_.at(i).at(p_name); auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
auto *out_var = new VarHandle(vars.size(), i, p_name, p); auto *out_var = new VarHandle(
result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
i, p_name, p);
vars.emplace_back(out_var); vars.emplace_back(out_var);
op_handle->AddOutput(out_var); op_handle->AddOutput(out_var);
} }
} }
void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
const OpDesc &op, ir::Node *node,
int dev_id) const { int dev_id) const {
result->ops_.emplace_back( result->Get<GraphOps>("ops").emplace_back(
new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id])); new ComputationOpHandle(result->CreateOpNode(node->Op()),
CreateOpHandleIOs(result, op, dev_id); local_scopes_[dev_id], places_[dev_id]));
CreateOpHandleIOs(result, node, dev_id);
} }
void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
const std::string &og) const { const std::string &og) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->ops_.emplace_back( result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_));
#else #else
result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_)); result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif #endif
auto *op_handle = result->ops_.back().get(); auto *op_handle = result->Get<GraphOps>("ops").back().get();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
auto &vars = result->vars_[i][og]; auto &vars = result->Get<GraphVars>("vars")[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
auto var = new VarHandle(vars.size(), i, og, p); auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
vars.size(), i, og, p);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
} }
void MultiDevSSAGraphBuilder::InsertDataBalanceOp( void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
SSAGraph *result, const std::vector<std::string> &datas) const { Graph *result, const std::vector<std::string> &datas) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->ops_.emplace_back( result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_)); result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_));
#else #else
result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_)); result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif #endif
auto *op_handle = result->ops_.back().get(); auto *op_handle = result->Get<GraphOps>("ops").back().get();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
for (const std::string &d_name : datas) { for (const std::string &d_name : datas) {
auto &vars = result->vars_[i][d_name]; auto &vars = result->Get<GraphVars>("vars")[i][d_name];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
op_handle->AddInput(vars.back().get()); op_handle->AddInput(vars.back().get());
auto var = new VarHandle(vars.size(), i, d_name, p); auto var = new VarHandle(
result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
vars.size(), i, d_name, p);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
...@@ -417,22 +465,22 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( ...@@ -417,22 +465,22 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return is_pg_once; return is_pg_once;
} }
int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const { int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1; return -1;
} }
int op_role = boost::get<int>( int op_role = boost::get<int>(
op.GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName())); node->Op()->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
if (op_role != static_cast<int>(framework::OpRole::kOptimize)) { if (op_role != static_cast<int>(framework::OpRole::kOptimize)) {
return -1; return -1;
} }
auto param_grad = boost::get<std::vector<std::string>>( auto param_grad = boost::get<std::vector<std::string>>(
op.GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U); PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(param_grad[1]); int dev_id = GetVarDeviceID(param_grad[1]);
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]", op.Type(), PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]",
param_grad[0]); node->Op()->Type(), param_grad[0]);
return dev_id; return dev_id;
} }
...@@ -441,7 +489,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const { ...@@ -441,7 +489,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
return got == var_name_on_devices_.end() ? -1 : got->second; return got == var_name_on_devices_.end() ? -1 : got->second;
} }
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle // Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -452,11 +500,11 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { ...@@ -452,11 +500,11 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
auto *communication_dev_ctx = auto *communication_dev_ctx =
platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
#endif #endif
auto *op_handle = new ScaleLossGradOpHandle(
auto *op_handle = result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i], local_scopes_.size(), local_scopes_[i], places_[i],
places_[i], communication_dev_ctx); communication_dev_ctx);
result->ops_.emplace_back(op_handle); result->Get<GraphOps>("ops").emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale // FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators. // factor. So it does not depend on any other operators.
...@@ -464,43 +512,51 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { ...@@ -464,43 +512,51 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
// loss->pending_ops_.emplace_back(op_handle); // loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss); // op_handle->inputs_.emplace_back(loss);
CreateOpOutput(result, op_handle, GradVarName(loss_var_name_), places_[i], CreateOpOutput(result, op_handle,
i); result->CreateEmptyNode(GradVarName(loss_var_name_),
ir::Node::Type::kVariable),
places_[i], i);
} }
} }
void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
const OpDesc &op, ir::Node *node,
size_t num_places) const { size_t num_places) const {
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx]; auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx]; auto s = local_scopes_[scope_idx];
result->ops_.emplace_back(new ComputationOpHandle(op, s, p)); result->Get<GraphOps>("ops").emplace_back(
CreateOpHandleIOs(result, op, scope_idx); new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
CreateOpHandleIOs(result, node, scope_idx);
} }
} }
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
const std::string &og, const std::string &og,
int dst_dev_id) const { int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->ops_.emplace_back( result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_));
#else #else
result->ops_.emplace_back(new ReduceOpHandle(local_scopes_, places_)); result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif #endif
auto *op_handle = result->ops_.back().get(); auto *op_handle = result->Get<GraphOps>("ops").back().get();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
auto &vars = result->vars_[i][og]; auto &vars = result->Get<GraphVars>("vars")[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
} }
auto &vars = result->vars_[dst_dev_id][og]; auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]); auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
vars.size(), dst_dev_id, og, places_[dst_dev_id]);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
return var; return var;
...@@ -508,35 +564,46 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, ...@@ -508,35 +564,46 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
// Find the first occurrence of `prev_op_name` and make current `op` depend // Find the first occurrence of `prev_op_name` and make current `op` depend
// on it. // on it.
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const { const std::string &prev_op_name) const {
for (auto &prev_op : result->ops_) { for (auto &prev_op : result->Get<GraphOps>("ops")) {
if (prev_op->Name() == prev_op_name) { if (prev_op->Name() == prev_op_name) {
auto *dep_var = new DummyVarHandle(); auto *dep_var = new DummyVarHandle(
result->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
prev_op->AddOutput(dep_var); prev_op->AddOutput(dep_var);
result->dep_vars_.emplace(dep_var); result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
op->AddInput(dep_var); op->AddInput(dep_var);
} }
} }
} }
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
const OpDesc &op) const { ir::Node *node) const {
int op_dev_id = -1; int op_dev_id = -1;
if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") { std::vector<std::string> input_var_names;
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); std::vector<std::string> output_var_names;
for (ir::Node *input : node->inputs) {
input_var_names.push_back(input->Name());
}
for (ir::Node *output : node->outputs) {
output_var_names.push_back(output->Name());
}
if (node->Op()->Type() == "split_byref" ||
node->Op()->Type() == "split_selected_rows") {
op_dev_id = GetVarDeviceID(input_var_names[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames()); op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : op.InputArgumentNames()) { for (auto &varname : input_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id); var_name_on_devices_.emplace(varname, op_dev_id);
} }
} }
for (auto &varname : op.OutputArgumentNames()) { for (auto &varname : output_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id); var_name_on_devices_.emplace(varname, op_dev_id);
} }
} else if (op.Type() == "concat") { } else if (node->Op()->Type() == "concat") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); op_dev_id = GetVarDeviceID(input_var_names[0]);
for (auto &varname : op.OutputArgumentNames()) { for (auto &varname : output_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id); var_name_on_devices_.emplace(varname, op_dev_id);
} }
} else { } else {
...@@ -546,34 +613,43 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, ...@@ -546,34 +613,43 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
} }
PADDLE_ENFORCE(op_dev_id != -1, PADDLE_ENFORCE(op_dev_id != -1,
"can not find right place for distributed op: %s", op.Type()); "can not find right place for distributed op: %s",
node->Op()->Type());
CreateComputationalOp(result, op, op_dev_id); CreateComputationalOp(result, node, op_dev_id);
if (op.Type() == "concat") { if (node->Op()->Type() == "concat") {
ConnectOp(result, result->ops_.back().get(), "fetch_barrier"); ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
"fetch_barrier");
} }
} }
// Create RPC related op handles that connect its in ops and out ops. // Create RPC related op handles that connect its in ops and out ops.
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const {
const OpDesc &op) const {
int op_dev_id = -1; int op_dev_id = -1;
if (op.Type() == "send") { if (node->Op()->Type() == "send") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
// the variable name which contains .block means it was split by // the variable name which contains .block means it was split by
// split_byref op // split_byref op
// so that we can balance the variable blocks to all the pserver // so that we can balance the variable blocks to all the pserver
// instances. // instances.
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce && if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
op.InputArgumentNames()[0].find(".block") == std::string::npos) { node->inputs[0]->Name().find(".block") == std::string::npos) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames()); std::vector<std::string> input_var_names;
for (auto &varname : op.InputArgumentNames()) { for (ir::Node *n : node->inputs) {
input_var_names.push_back(n->Name());
}
op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id); var_name_on_devices_.emplace(varname, op_dev_id);
} }
} }
} else if (op.Type() == "recv") { } else if (node->Op()->Type() == "recv") {
op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames()); std::vector<std::string> output_var_names;
for (auto &varname : op.OutputArgumentNames()) { for (ir::Node *n : node->outputs) {
output_var_names.push_back(n->Name());
}
op_dev_id = GetAppropriateDeviceID(output_var_names);
for (auto &varname : output_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id); var_name_on_devices_.emplace(varname, op_dev_id);
} }
} else { } else {
...@@ -582,18 +658,20 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, ...@@ -582,18 +658,20 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
} }
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
op.Type()); node->Op()->Type());
result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id], result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
op.Type(), places_[op_dev_id])); result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
node->Op()->Type(), places_[op_dev_id]));
if (op.Type() == "send_barrier") {
ConnectOp(result, result->ops_.back().get(), "send"); if (node->Op()->Type() == "send_barrier") {
} else if (op.Type() == "recv") { ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
ConnectOp(result, result->ops_.back().get(), "send_barrier"); } else if (node->Op()->Type() == "recv") {
} else if (op.Type() == "fetch_barrier") { ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
ConnectOp(result, result->ops_.back().get(), "recv"); "send_barrier");
} else if (op.Type() == "send") { } else if (node->Op()->Type() == "fetch_barrier") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv");
} else if (node->Op()->Type() == "send") {
// do nothing // do nothing
} else { } else {
PADDLE_THROW( PADDLE_THROW(
...@@ -601,12 +679,12 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, ...@@ -601,12 +679,12 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
"send, send_barrier. recv, fetch_barrier]"); "send, send_barrier. recv, fetch_barrier]");
} }
CreateOpHandleIOs(result, op, op_dev_id); CreateOpHandleIOs(result, node, op_dev_id);
} }
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
return boost::get<int>( return boost::get<int>(
op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
(static_cast<int>(OpRole::kBackward) | (static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss)) && static_cast<int>(OpRole::kLoss)) &&
!loss_var_name_.empty(); // If loss_var is empty. This is test mode !loss_var_name_.empty(); // If loss_var is empty. This is test mode
......
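Note on the `Get`/`Set` calls in the diff above: the builder no longer writes to `result->ops_` and `result->vars_` directly. It stores them as named attributes on the graph (`result.Set("ops", new GraphOps)`) and reads them back with `result->Get<GraphOps>("ops")`. The following is a minimal sketch of how such a name-keyed, type-erased attribute store could work. `AttrGraph`, `OpNames` and `AppendOp` are hypothetical names used only for illustration; this is not Paddle's `ir::Graph` implementation, whose internals are not shown in this diff.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for an attribute-carrying graph (NOT ir::Graph).
class AttrGraph {
 public:
  AttrGraph() = default;
  AttrGraph(const AttrGraph &) = delete;
  AttrGraph &operator=(const AttrGraph &) = delete;

  // Run the type-aware deleter recorded for every stored attribute.
  ~AttrGraph() {
    for (auto &kv : deleters_) kv.second();
  }

  // Takes ownership of `attr` and stores it under `name`.
  // Assumption of this sketch: setting the same name twice is an error.
  template <typename T>
  void Set(const std::string &name, T *attr) {
    if (attrs_.count(name) != 0) {
      delete attr;
      throw std::logic_error("attribute already set: " + name);
    }
    attrs_[name] = attr;
    deleters_[name] = [attr]() { delete attr; };
  }

  // Returns a reference to the attribute. The caller must name the same
  // type that was passed to Set; this sketch does not check it.
  template <typename T>
  T &Get(const std::string &name) const {
    auto it = attrs_.find(name);
    if (it == attrs_.end()) throw std::out_of_range("no attribute: " + name);
    return *static_cast<T *>(it->second);
  }

 private:
  std::map<std::string, void *> attrs_;
  std::map<std::string, std::function<void()>> deleters_;
};

// Stand-in for GraphOps: here just a list of op names.
using OpNames = std::vector<std::string>;

// Mirrors the shape of result->Get<GraphOps>("ops").emplace_back(...).
inline std::size_t AppendOp(AttrGraph *graph, const std::string &op) {
  auto &ops = graph->Get<OpNames>("ops");
  ops.emplace_back(op);
  return ops.size();
}
```

A caller would first register the container once, for example `g.Set("ops", new OpNames);`, and then append through `AppendOp(&g, "broadcast")`. Keeping the containers behind string keys lets a pass attach extra data to a graph without changing the graph class, at the cost of type checking that a direct member such as `ops_` would provide.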
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/details/build_strategy.h" #include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h" #include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -45,13 +46,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -45,13 +46,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const BuildStrategy &strategy); const BuildStrategy &strategy);
#endif #endif
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override;
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
int GetVarDeviceID(const std::string &varname) const override; int GetVarDeviceID(const std::string &varname) const override;
private: private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, void CreateOpHandleIOs(Graph *result, ir::Node *node, size_t device_id) const;
size_t device_id) const;
private: private:
std::string loss_var_name_; std::string loss_var_name_;
...@@ -63,48 +62,46 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -63,48 +62,46 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
platform::NCCLContextMap *nccl_ctxs_; platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
bool IsScaleLossOp(const OpDesc &op) const; bool IsScaleLossOp(ir::Node *node) const;
void CreateRPCOp(SSAGraph *result, const OpDesc &op) const; void CreateRPCOp(Graph *result, ir::Node *node) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const; void CreateDistTrainOp(Graph *result, ir::Node *node) const;
/** /**
* Is this operator as the end-point operator before/after send operator. * Is this operator as the end-point operator before/after send operator.
*/ */
bool IsDistTrainOp(const OpDesc &op, bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const; const std::vector<std::string> &recv_vars) const;
std::vector<std::string> FindDistTrainSendVars( std::vector<std::string> FindDistTrainSendVars(
const ProgramDesc &program) const; const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
std::vector<std::string> FindDistTrainRecvVars( std::vector<std::string> FindDistTrainRecvVars(
const ProgramDesc &program) const; const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
void ConnectOp(SSAGraph *result, OpHandleBase *op, void ConnectOp(Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const; const std::string &prev_op_name) const;
void CreateComputationalOps(SSAGraph *result, const OpDesc &op, void CreateComputationalOps(Graph *result, ir::Node *node,
size_t num_places) const; size_t num_places) const;
void CreateScaleLossGradOp(SSAGraph *result) const; void CreateScaleLossGradOp(Graph *result) const;
VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og, VarHandle *CreateReduceOp(Graph *result, const std::string &og,
int dst_dev_id) const; int dst_dev_id) const;
void CreateComputationalOp(SSAGraph *result, const OpDesc &op, void CreateComputationalOp(Graph *result, ir::Node *node, int dev_id) const;
int dev_id) const;
bool IsParameterGradientOnce( bool IsParameterGradientOnce(
const std::string &og, const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const; std::unordered_set<std::string> *og_has_been_broadcast) const;
int GetOpDeviceID(const OpDesc &op) const; int GetOpDeviceID(ir::Node *node) const;
void InsertAllReduceOp(SSAGraph *result, const std::string &og) const; void InsertAllReduceOp(Graph *result, const std::string &og) const;
void InsertDataBalanceOp(SSAGraph *result, void InsertDataBalanceOp(Graph *result,
const std::vector<std::string> &datas) const; const std::vector<std::string> &datas) const;
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, void CreateBroadcastOp(Graph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
bool IsSparseGradient(const std::string &og) const; bool IsSparseGradient(const std::string &og) const;
......
...@@ -80,19 +80,21 @@ void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { ...@@ -80,19 +80,21 @@ void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
void OpHandleBase::AddInput(VarHandleBase *in) { void OpHandleBase::AddInput(VarHandleBase *in) {
this->inputs_.emplace_back(in); this->inputs_.emplace_back(in);
in->pending_ops_.insert(this); node_->inputs.push_back(in->Node());
in->AddOutput(this, this->Node());
} }
void OpHandleBase::AddOutput(VarHandleBase *out) { void OpHandleBase::AddOutput(VarHandleBase *out) {
outputs_.emplace_back(out); outputs_.emplace_back(out);
out->generated_op_ = this; node_->outputs.push_back(out->Node());
out->AddInput(this, this->Node());
} }
void OpHandleBase::WaitInputVarGenerated() { void OpHandleBase::WaitInputVarGenerated() {
for (auto in_var : inputs_) { for (auto in_var : inputs_) {
if (NeedWait(in_var)) { if (NeedWait(in_var)) {
for (auto &pair : dev_ctxes_) { for (auto &pair : dev_ctxes_) {
in_var->generated_op_->RecordWaitEventOnCtx(pair.second); in_var->GeneratedOp()->RecordWaitEventOnCtx(pair.second);
} }
} }
} }
...@@ -101,7 +103,7 @@ void OpHandleBase::WaitInputVarGenerated() { ...@@ -101,7 +103,7 @@ void OpHandleBase::WaitInputVarGenerated() {
void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) { void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
for (auto *in : inputs_) { for (auto *in : inputs_) {
if (NeedWait(in)) { if (NeedWait(in)) {
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]); in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[place]);
} }
} }
} }
...@@ -117,7 +119,7 @@ size_t OpHandleBase::NoDummyInputSize() const { ...@@ -117,7 +119,7 @@ size_t OpHandleBase::NoDummyInputSize() const {
} }
bool OpHandleBase::NeedWait(VarHandleBase *in_var) { bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
return in_var && in_var->generated_op_; return in_var && in_var->GeneratedOp();
} }
void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) { void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
......
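The `AddInput`/`AddOutput` changes above keep two views of the graph in sync: the handle-level lists (`inputs_`, `outputs_`) and the adjacency lists on the wrapped `ir::Node`. The sketch below illustrates that bookkeeping in isolation. `Node`, `VarHandle`, and `OpHandle` are simplified, hypothetical stand-ins rather than the Paddle classes, and the bodies of `VarHandle::AddInput`/`AddOutput` are an assumption, since `var_handle.h` is not part of this excerpt.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for ir::Node: a name plus adjacency lists.
struct Node {
  explicit Node(std::string n) : name(std::move(n)) {}
  std::string name;
  std::vector<Node *> inputs;
  std::vector<Node *> outputs;
};

struct OpHandle;

// Hypothetical stand-in for VarHandleBase. The two Add* bodies are assumed.
struct VarHandle {
  explicit VarHandle(Node *node) : node_(node) {}

  // Called when `op` produces this variable.
  void AddInput(OpHandle *op, Node *op_node) {
    generated_op_ = op;
    node_->inputs.push_back(op_node);
  }

  // Called when `op` consumes this variable.
  void AddOutput(OpHandle *op, Node *op_node) {
    pending_ops_.push_back(op);
    node_->outputs.push_back(op_node);
  }

  Node *node_;
  OpHandle *generated_op_ = nullptr;
  std::vector<OpHandle *> pending_ops_;
};

// Hypothetical stand-in for OpHandleBase.
struct OpHandle {
  explicit OpHandle(Node *node) : node_(node) {}

  // Records the edge var -> op on the handle, on the op's node,
  // and (through the variable) on the variable's node.
  void AddInput(VarHandle *in) {
    inputs_.push_back(in);
    node_->inputs.push_back(in->node_);
    in->AddOutput(this, node_);
  }

  // Records the edge op -> var in the same three places.
  void AddOutput(VarHandle *out) {
    outputs_.push_back(out);
    node_->outputs.push_back(out->node_);
    out->AddInput(this, node_);
  }

  Node *node_;
  std::vector<VarHandle *> inputs_;
  std::vector<VarHandle *> outputs_;
};
```

With this layout, a caller that only sees `Node` objects can walk the same edges that the executor sees through the handles.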
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
...@@ -26,9 +27,11 @@ namespace details { ...@@ -26,9 +27,11 @@ namespace details {
constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@"; constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
// Wraps ir::Node and provides helper utilities.
// It's responsible for populating necessary fields of ir::Node.
class OpHandleBase { class OpHandleBase {
public: public:
OpHandleBase() {} explicit OpHandleBase(ir::Node *node) : node_(node) {}
virtual ~OpHandleBase(); virtual ~OpHandleBase();
...@@ -82,6 +85,8 @@ class OpHandleBase { ...@@ -82,6 +85,8 @@ class OpHandleBase {
size_t NoDummyInputSize() const; size_t NoDummyInputSize() const;
ir::Node *Node() { return node_; }
protected: protected:
void RunAndRecordEvent(const std::function<void()> &callback); void RunAndRecordEvent(const std::function<void()> &callback);
...@@ -90,6 +95,7 @@ class OpHandleBase { ...@@ -90,6 +95,7 @@ class OpHandleBase {
virtual void RunImpl() = 0; virtual void RunImpl() = 0;
ir::Node *node_;
std::vector<VarHandleBase *> inputs_; std::vector<VarHandleBase *> inputs_;
std::vector<VarHandleBase *> outputs_; std::vector<VarHandleBase *> outputs_;
std::map<platform::Place, platform::DeviceContext *> dev_ctxes_; std::map<platform::Place, platform::DeviceContext *> dev_ctxes_;
......
...@@ -37,10 +37,13 @@ struct ReduceOpHandle : public OpHandleBase { ...@@ -37,10 +37,13 @@ struct ReduceOpHandle : public OpHandleBase {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
const platform::NCCLContextMap *nccl_ctxs_; const platform::NCCLContextMap *nccl_ctxs_;
ReduceOpHandle(const std::vector<Scope *> &local_scopes, ReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLContextMap *nccl_ctxs) const platform::NCCLContextMap *nccl_ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) { : OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(nccl_ctxs) {
if (nccl_ctxs_) { if (nccl_ctxs_) {
for (auto &p_ctx : nccl_ctxs_->contexts_) { for (auto &p_ctx : nccl_ctxs_->contexts_) {
dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get(); dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
...@@ -48,9 +51,9 @@ struct ReduceOpHandle : public OpHandleBase { ...@@ -48,9 +51,9 @@ struct ReduceOpHandle : public OpHandleBase {
} }
} }
#else #else
ReduceOpHandle(const std::vector<Scope *> &local_scopes, ReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places) const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {} : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif #endif
std::string Name() const override; std::string Name() const override;
......
...@@ -84,6 +84,7 @@ struct TestReduceOpHandle { ...@@ -84,6 +84,7 @@ struct TestReduceOpHandle {
} }
void InitReduceOp(size_t out_scope_idx) { void InitReduceOp(size_t out_scope_idx) {
std::vector<std::unique_ptr<ir::Node>> nodes;
// init scope // init scope
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope())); local_scopes_.push_back(&(g_scope_.NewScope()));
...@@ -96,19 +97,21 @@ struct TestReduceOpHandle { ...@@ -96,19 +97,21 @@ struct TestReduceOpHandle {
} }
param_scopes_[out_scope_idx]->Var("out"); param_scopes_[out_scope_idx]->Var("out");
nodes.emplace_back(new ir::Node("node"));
if (use_gpu_) { if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset( op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_,
new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); gpu_list_, nccl_ctxs_.get()));
#else #else
PADDLE_THROW("CUDA is not supported."); PADDLE_THROW("CUDA is not supported.");
#endif #endif
} else { } else {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset( op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_,
new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); gpu_list_, nccl_ctxs_.get()));
#else #else
op_handle_.reset(new ReduceOpHandle(local_scopes_, gpu_list_)); op_handle_.reset(
new ReduceOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
#endif #endif
} }
...@@ -118,8 +121,10 @@ struct TestReduceOpHandle { ...@@ -118,8 +121,10 @@ struct TestReduceOpHandle {
if (!use_gpu_) { if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
} }
auto *in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]); nodes.emplace_back(new ir::Node("node1"));
in_var_handle->generated_op_ = nullptr; auto *in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
in_var_handle->ClearGeneratedOp();
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle); op_handle_->AddInput(in_var_handle);
} }
...@@ -128,12 +133,13 @@ struct TestReduceOpHandle { ...@@ -128,12 +133,13 @@ struct TestReduceOpHandle {
vars_.emplace_back(new DummyVarHandle()); vars_.emplace_back(new DummyVarHandle());
DummyVarHandle *in_dummy_var_handle = DummyVarHandle *in_dummy_var_handle =
static_cast<DummyVarHandle *>(vars_.back().get()); static_cast<DummyVarHandle *>(vars_.back().get());
in_dummy_var_handle->generated_op_ = nullptr; in_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(in_dummy_var_handle); op_handle_->AddInput(in_dummy_var_handle);
// add output // add output
auto *out_var_handle = nodes.emplace_back(new ir::Node("node2"));
new VarHandle(2, out_scope_idx, "out", gpu_list_[out_scope_idx]); auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx,
"out", gpu_list_[out_scope_idx]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
......
...@@ -18,10 +18,11 @@ namespace paddle { ...@@ -18,10 +18,11 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc, RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
const Scope *local_scope, const std::string &name, const Scope *local_scope, const std::string &name,
const platform::Place &place) const platform::Place &place)
: op_(framework::OpRegistry::CreateOp(op_desc)), : OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope), local_scope_(local_scope),
name_(name), name_(name),
place_(place) {} place_(place) {}
...@@ -35,8 +36,8 @@ void RPCOpHandle::RunImpl() { ...@@ -35,8 +36,8 @@ void RPCOpHandle::RunImpl() {
if (in->DebugString() == "dummy") { // HACK if (in->DebugString() == "dummy") { // HACK
continue; continue;
} }
if (in->generated_op_) { if (in->GeneratedOp()) {
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]); in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[p]);
} }
} }
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
......
...@@ -28,8 +28,9 @@ namespace framework { ...@@ -28,8 +28,9 @@ namespace framework {
namespace details { namespace details {
struct RPCOpHandle : public OpHandleBase { struct RPCOpHandle : public OpHandleBase {
RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, RPCOpHandle(ir::Node* node, const framework::OpDesc& op_desc,
const std::string& name, const platform::Place& place); const Scope* local_scope, const std::string& name,
const platform::Place& place);
std::string Name() const override; std::string Name() const override;
......
...@@ -19,10 +19,14 @@ ...@@ -19,10 +19,14 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope, ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev,
Scope *scope,
platform::Place place, platform::Place place,
platform::DeviceContext *dev_ctx) platform::DeviceContext *dev_ctx)
: coeff_(static_cast<float>(1.0 / num_dev)), scope_(scope), place_(place) { : OpHandleBase(node),
coeff_(static_cast<float>(1.0 / num_dev)),
scope_(scope),
place_(place) {
dev_ctxes_[place_] = dev_ctx; dev_ctxes_[place_] = dev_ctx;
} }
......
...@@ -25,7 +25,8 @@ namespace framework { ...@@ -25,7 +25,8 @@ namespace framework {
namespace details { namespace details {
struct ScaleLossGradOpHandle : public OpHandleBase { struct ScaleLossGradOpHandle : public OpHandleBase {
ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place, ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope,
platform::Place place,
platform::DeviceContext *context); platform::DeviceContext *context);
~ScaleLossGradOpHandle() final; ~ScaleLossGradOpHandle() final;
......
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph.h"
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
namespace paddle {
namespace framework {
namespace details {
// A SSA graph used by parallel executor.
struct SSAGraph {
// All variables on each device.
// The outer vector is indexed by device. Each element of this vector is a
// map from variable name to variables. Variables that share the same name
// have different versions. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of the variable.
std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
vars_;
// aux variables to represent dependency. Useful to resolve data hazard.
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
// All operators. NOTE that even though we use a vector here, the operators are
// unordered.
std::vector<std::unique_ptr<OpHandleBase>> ops_;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
for (auto &var_map : graph->vars_) { for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) { if (name_pair.second.size() <= 1) {
continue; continue;
...@@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { ...@@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
auto it_old = name_pair.second.rbegin(); auto it_old = name_pair.second.rbegin();
++it_old; ++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
auto *write_op = (*it_new)->generated_op_; OpHandleBase *write_op = (*it_new)->GeneratedOp();
auto &read_ops = (*it_old)->pending_ops_; const auto &read_ops = (*it_old)->PendingOps();
for (auto *read_op : read_ops) { for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op; // Manually add a dependency var from read_op to write_op;
...@@ -37,10 +37,11 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { ...@@ -37,10 +37,11 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
continue; continue;
} }
auto *dep_var = new DummyVarHandle(); auto *dep_var = new DummyVarHandle(
graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
read_op->AddOutput(dep_var); read_op->AddOutput(dep_var);
write_op->AddInput(dep_var); write_op->AddInput(dep_var);
graph->dep_vars_.emplace(dep_var); graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
} }
} }
} }
...@@ -48,13 +49,20 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { ...@@ -48,13 +49,20 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
} }
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
SSAGraph *graph, const std::string &each_var_name, Graph *graph, ir::Node *node, const platform::Place &place,
const platform::Place &place, size_t place_offset) { size_t place_offset) {
auto &var_holders = graph->vars_[place_offset]; auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
auto &var_holder = var_holders[each_var_name]; auto &var_holder = var_holders[node->Name()];
VarHandle *var = nullptr; VarHandle *var = nullptr;
if (var_holder.empty()) { if (var_holder.empty()) {
var = new VarHandle(0, place_offset, each_var_name, place); if (node->Var()) {
var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset,
node->Name(), place);
} else {
var = new VarHandle(
graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable), 0,
place_offset, node->Name(), place);
}
var_holder.emplace_back(var); var_holder.emplace_back(var);
} else { } else {
var = var_holder.rbegin()->get(); var = var_holder.rbegin()->get();
...@@ -62,24 +70,26 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( ...@@ -62,24 +70,26 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
return var; return var;
} }
void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
const std::string &each_var_name, ir::Node *new_node,
const platform::Place &place, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
auto &vars = graph->vars_[place_offset][each_var_name]; auto &vars = graph->Get<GraphVars>("vars")[place_offset][new_node->Name()];
size_t version = vars.size(); size_t version = vars.size();
auto var = new VarHandle(version, place_offset, each_var_name, place); auto var =
new VarHandle(new_node, version, place_offset, new_node->Name(), place);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
for (auto &op : graph->ops_) { for (auto &op : graph->Get<GraphOps>("ops")) {
if (!op->Outputs().empty()) { if (!op->Outputs().empty()) {
continue; continue;
} }
auto *dummy_leaf = new DummyVarHandle(); auto *dummy_leaf = new DummyVarHandle(
graph->dep_vars_.emplace(dummy_leaf); graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
op->AddOutput(dummy_leaf); op->AddOutput(dummy_leaf);
} }
} }
......
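`PolishGraphToSupportDataHazards` above walks the versions of each variable from newest to oldest and, for every op that reads an older version, adds a dummy dependency variable so that the reader finishes before the op that writes the next version starts (a write-after-read hazard). The sketch below computes the same set of extra reader-to-writer edges on a simplified model. `Version` and `WarDependencies` are hypothetical names, ops are plain integer ids, and skipping the case where the reader and the writer are the same op is an assumption about the lines hidden between the hunks.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <utility>
#include <vector>

// One SSA version of a variable: the op that wrote it (-1 if none)
// and the ops that read it.
struct Version {
  int write_op;
  std::vector<int> read_ops;
};

// Returns the extra (read_op, write_op) edges needed so that every reader
// of version i-1 completes before the writer of version i runs.
std::set<std::pair<int, int>> WarDependencies(
    const std::vector<Version> &versions) {
  std::set<std::pair<int, int>> deps;
  for (std::size_t i = 1; i < versions.size(); ++i) {
    const int write_op = versions[i].write_op;
    for (int read_op : versions[i - 1].read_ops) {
      if (read_op == write_op) {
        continue;  // An op that reads and then writes needs no extra edge.
      }
      deps.emplace(read_op, write_op);
    }
  }
  return deps;
}
```

In the real builder each returned pair corresponds to one `DummyVarHandle` that is an output of the reader and an input of the writer.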
...@@ -16,20 +16,42 @@ ...@@ -16,20 +16,42 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
class SSAGraphBuilder { // All variables on each device.
// The outer vector is indexed by device. Each element of this vector is a
// map from variable name to variables. Variables that share the same name
// have different versions. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of the variable.
typedef std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars;
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
// All operators. NOTE that even though we use a vector here, the operators are
// unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
class SSAGraphBuilder : public ir::Pass {
public: public:
SSAGraphBuilder() {} SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const = 0; virtual int GetVarDeviceID(const std::string &var_name) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
...@@ -42,20 +64,19 @@ class SSAGraphBuilder { ...@@ -42,20 +64,19 @@ class SSAGraphBuilder {
* *
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/ */
static void PolishGraphToSupportDataHazards(SSAGraph *graph); static void PolishGraphToSupportDataHazards(Graph *graph);
static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph, static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, ir::Node *node,
const std::string &each_var_name,
const platform::Place &place, const platform::Place &place,
size_t place_offset); size_t place_offset);
// Add an output variable (each_var_name, place, place_offset) to op_handle, // Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph // which belongs to graph
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
const std::string &each_var_name, ir::Node *new_node, const platform::Place &place,
const platform::Place &place, size_t place_offset); size_t place_offset);
static void AddOutputToLeafOps(SSAGraph *graph); static void AddOutputToLeafOps(Graph *graph);
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
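In the hunks above, the former `SSAGraph` fields (`vars_`, `dep_vars_`, `ops_`) become named attributes that callers fetch with calls such as `graph->Get<GraphOps>("ops")`. How `ir::Graph` stores those attributes is not shown in this excerpt, so the following is only a sketch of one way a typed, name-keyed attribute store could work. `AttrGraph`, `Set`, and `DemoOps` are hypothetical names, and the type check via `std::type_index` is an assumption.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <vector>

// Hypothetical stand-in for GraphOps, used only to exercise the store.
using DemoOps = std::vector<std::unique_ptr<int>>;

// A graph-like object that owns named attributes of arbitrary type.
class AttrGraph {
 public:
  // Takes ownership of `attr` and registers it under `name`.
  template <typename T>
  void Set(const std::string &name, T *attr) {
    std::shared_ptr<void> owned(attr,
                                [](void *p) { delete static_cast<T *>(p); });
    auto inserted =
        attrs_.emplace(name, Entry{owned, std::type_index(typeid(T))});
    if (!inserted.second) {
      throw std::logic_error("attribute already set: " + name);
    }
  }

  // Returns the attribute stored under `name`; throws if it is missing
  // or was stored with a different type.
  template <typename T>
  T &Get(const std::string &name) const {
    auto it = attrs_.find(name);
    if (it == attrs_.end()) {
      throw std::out_of_range("no attribute: " + name);
    }
    if (it->second.type != std::type_index(typeid(T))) {
      throw std::logic_error("attribute type mismatch: " + name);
    }
    return *static_cast<T *>(it->second.ptr.get());
  }

 private:
  struct Entry {
    std::shared_ptr<void> ptr;
    std::type_index type;
  };
  std::map<std::string, Entry> attrs_;
};
```

Storing a pointer rather than a value lets the attribute hold move-only members such as `std::unique_ptr`, which the `GraphVars`, `GraphDepVars`, and `GraphOps` typedefs all contain.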
...@@ -12,15 +12,15 @@ ...@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph.h"
#include <string>
#include "paddle/fluid/framework/details/ssa_graph_checker.h" #include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops; std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars; std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars; std::unordered_set<VarHandleBase *> ready_vars;
...@@ -28,12 +28,12 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { ...@@ -28,12 +28,12 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
auto insert_pending_var = [&](VarHandleBase *var) { auto insert_pending_var = [&](VarHandleBase *var) {
pending_vars.insert(var); pending_vars.insert(var);
if (var->generated_op_ == nullptr) { if (var->GeneratedOp() == nullptr) {
ready_vars.emplace(var); ready_vars.emplace(var);
} }
}; };
for (auto &var_map : graph->vars_) { for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get()); insert_pending_var(version_pair.get());
...@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { ...@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
} }
} }
for (auto &var : graph->dep_vars_) { for (auto &var : graph->Get<GraphDepVars>("dep_vars")) {
insert_pending_var(var.get()); insert_pending_var(var.get());
} }
for (auto &op : graph->ops_) { for (auto &op : graph->Get<GraphOps>("ops")) {
if (op->Inputs().empty()) { if (op->Inputs().empty()) {
ready_ops.insert(op.get()); ready_ops.insert(op.get());
} else { } else {
...@@ -71,7 +71,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { ...@@ -71,7 +71,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
for (auto ready_var : ready_vars) { for (auto ready_var : ready_vars) {
pending_vars.erase(ready_var); pending_vars.erase(ready_var);
for (auto *op : ready_var->pending_ops_) { for (auto *op : ready_var->PendingOps()) {
auto &deps = --pending_ops[op]; auto &deps = --pending_ops[op];
if (deps == 0) { if (deps == 0) {
ready_ops.insert(op); ready_ops.insert(op);
......
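`IsValidGraph` above simulates execution: variables with no generating op start out ready, each ready variable decrements the pending-input count of the ops that read it, and an op whose count reaches zero becomes ready and releases its outputs. The parts of the function hidden between the hunks are not reproduced here; the sketch below is a generic version of the same idea on integer ids, assuming SSA form (each variable has at most one writer). `Op` and `IsSchedulable` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <vector>

// An op described by the indices of the variables it reads and writes.
struct Op {
  std::vector<std::size_t> inputs;
  std::vector<std::size_t> outputs;
};

// Returns true if every variable can eventually be produced, i.e. the
// op/variable graph has no cycle and no op waits on a missing input.
bool IsSchedulable(std::size_t num_vars, const std::vector<Op> &ops) {
  std::vector<bool> has_writer(num_vars, false);
  std::vector<std::vector<std::size_t>> readers(num_vars);
  std::vector<std::size_t> missing(ops.size(), 0);
  std::queue<std::size_t> ready_ops;
  std::queue<std::size_t> ready_vars;

  for (std::size_t i = 0; i < ops.size(); ++i) {
    for (std::size_t v : ops[i].outputs) has_writer[v] = true;
    for (std::size_t v : ops[i].inputs) readers[v].push_back(i);
    missing[i] = ops[i].inputs.size();
    if (missing[i] == 0) ready_ops.push(i);
  }
  // Variables that no op generates are ready from the start.
  for (std::size_t v = 0; v < num_vars; ++v) {
    if (!has_writer[v]) ready_vars.push(v);
  }

  std::size_t produced = 0;
  while (!ready_ops.empty() || !ready_vars.empty()) {
    while (!ready_ops.empty()) {
      const std::size_t i = ready_ops.front();
      ready_ops.pop();
      for (std::size_t v : ops[i].outputs) ready_vars.push(v);
    }
    while (!ready_vars.empty()) {
      const std::size_t v = ready_vars.front();
      ready_vars.pop();
      ++produced;
      for (std::size_t i : readers[v]) {
        if (--missing[i] == 0) ready_ops.push(i);
      }
    }
  }
  return produced == num_vars;
}
```

A chain such as `v0 -> op0 -> v1 -> op1 -> v2` passes, while two ops that feed each other's inputs never become ready and the check fails.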
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
struct SSAGraph;
class SSAGraghBuilderWithChecker : public SSAGraphBuilder { class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public: public:
...@@ -29,17 +28,17 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { ...@@ -29,17 +28,17 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
std::unique_ptr<SSAGraphBuilder>&& builder) std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {} : builder_(std::move(builder)) {}
std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override { std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
auto graph = builder_->Build(program); auto new_graph = builder_->Apply(std::move(graph));
PADDLE_ENFORCE(IsValidGraph(graph.get())); PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
return graph; return std::move(new_graph);
} }
int GetVarDeviceID(const std::string& var_name) const override { int GetVarDeviceID(const std::string& var_name) const override {
return builder_->GetVarDeviceID(var_name); return builder_->GetVarDeviceID(var_name);
} }
bool IsValidGraph(const SSAGraph* graph) const; bool IsValidGraph(const Graph* graph) const;
private: private:
std::unique_ptr<SSAGraphBuilder> builder_; std::unique_ptr<SSAGraphBuilder> builder_;
......
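`SSAGraghBuilderWithChecker` above is a decorator: it owns another builder, forwards `Apply` to it, and validates the graph that comes back. The sketch below shows that shape with stand-in types. `Graph`, `Pass`, `AddOpPass`, and `CheckedPass` are hypothetical simplifications, the validity rule is invented for the demo, and a plain exception stands in for `PADDLE_ENFORCE`.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <utility>

// Hypothetical stand-in for ir::Graph.
struct Graph {
  int num_ops = 0;
};

// Hypothetical stand-in for ir::Pass: consumes a graph, returns a graph.
class Pass {
 public:
  virtual ~Pass() {}
  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
};

// A trivial pass that "adds" one op.
class AddOpPass : public Pass {
 public:
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
    ++graph->num_ops;
    return graph;
  }
};

// Decorator: runs the wrapped pass, then rejects invalid results.
class CheckedPass : public Pass {
 public:
  explicit CheckedPass(std::unique_ptr<Pass> inner)
      : inner_(std::move(inner)) {}

  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
    auto new_graph = inner_->Apply(std::move(graph));
    if (!IsValid(*new_graph)) {
      throw std::runtime_error("invalid graph");
    }
    return new_graph;
  }

  // Demo-only rule: a graph must contain at least one op.
  static bool IsValid(const Graph &g) { return g.num_ops > 0; }

 private:
  std::unique_ptr<Pass> inner_;
};
```

Because ownership of the graph moves in and out of `Apply`, decorators of this kind can be stacked, for example a printer around a checker around the real builder.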
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -14,15 +14,15 @@ ...@@ -14,15 +14,15 @@
#include "paddle/fluid/framework/details/ssa_graph_printer.h" #include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include <string> #include <string>
#include "paddle/fluid/framework/details/ssa_graph.h" #include "paddle/fluid/framework/ir/graph.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
template <typename Callback> template <typename Callback>
static inline void IterAllVar(const SSAGraph &graph, Callback callback) { static inline void IterAllVar(const Graph &graph, Callback callback) {
for (auto &each : graph.vars_) { for (auto &each : graph.Get<GraphVars>("vars")) {
for (auto &pair1 : each) { for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) { for (auto &pair2 : pair1.second) {
callback(*pair2); callback(*pair2);
...@@ -30,12 +30,12 @@ static inline void IterAllVar(const SSAGraph &graph, Callback callback) { ...@@ -30,12 +30,12 @@ static inline void IterAllVar(const SSAGraph &graph, Callback callback) {
} }
} }
for (auto &var : graph.dep_vars_) { for (auto &var : graph.Get<GraphDepVars>("dep_vars")) {
callback(*var); callback(*var);
} }
} }
void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph, void GraphvizSSAGraphPrinter::Print(const Graph &graph,
std::ostream &sout) const { std::ostream &sout) const {
size_t var_id = 0; size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars; std::unordered_map<const VarHandleBase *, size_t> vars;
...@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph, ...@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph,
}); });
size_t op_id = 0; size_t op_id = 0;
for (auto &op : graph.ops_) { for (auto &op : graph.Get<GraphOps>("ops")) {
std::string op_name = "op_" + std::to_string(op_id++); std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl; << std::endl;
......
...@@ -21,16 +21,16 @@ ...@@ -21,16 +21,16 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
struct SSAGraph;
class SSAGraphPrinter { class SSAGraphPrinter {
public: public:
virtual ~SSAGraphPrinter() {} virtual ~SSAGraphPrinter() {}
virtual void Print(const SSAGraph& graph, std::ostream& sout) const = 0; virtual void Print(const Graph& graph, std::ostream& sout) const = 0;
}; };
class GraphvizSSAGraphPrinter : public SSAGraphPrinter { class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
public: public:
void Print(const SSAGraph& graph, std::ostream& sout) const override; void Print(const Graph& graph, std::ostream& sout) const override;
}; };
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
...@@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { ...@@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
stream_ptr_(std::move(sout)), stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {} stream_ref_(*stream_ptr_) {}
std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override { std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
auto graph = builder_->Build(program); auto new_graph = builder_->Apply(std::move(graph));
printer_->Print(*graph, stream_ref_); printer_->Print(*new_graph, stream_ref_);
return graph; return std::move(new_graph);
} }
int GetVarDeviceID(const std::string& var_name) const override { int GetVarDeviceID(const std::string& var_name) const override {
......
...@@ -14,13 +14,14 @@ ...@@ -14,13 +14,14 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places, std::unique_ptr<Graph> &&graph)
std::unique_ptr<SSAGraph> &&graph)
: graph_(std::move(graph)), : graph_(std::move(graph)),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_) pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr), : nullptr),
...@@ -43,18 +44,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -43,18 +44,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
std::unordered_set<OpHandleBase *> delayed_ops; std::unordered_set<OpHandleBase *> delayed_ops;
// Transform SSAGraph to pending_ops & pending_vars // Transform SSAGraph to pending_ops & pending_vars
for (auto &var_map : graph_->vars_) { for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get()); InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
} }
} }
} }
for (auto &var : graph_->dep_vars_) { for (auto &var : graph_->Get<details::GraphDepVars>("dep_vars")) {
InsertPendingVar(&pending_vars, &ready_vars, var.get()); InsertPendingVar(&pending_vars, &ready_vars, var.get());
} }
for (auto &op : graph_->ops_) { for (auto &op : graph_->Get<details::GraphOps>("ops")) {
if (op->Inputs().empty()) { // Special case, Op has no input. if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get()); ready_ops.insert(op.get());
} else { } else {
...@@ -64,11 +65,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -64,11 +65,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Step 2. Insert FetchOps // Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops; std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
std::vector<std::unique_ptr<ir::Node>> tmp_nodes;
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies; std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size()); FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops, InsertFetchOps(fetch_tensors, &fetch_ops, &tmp_nodes, &fetch_dependencies,
&pending_vars, &ready_vars, &fetch_data); &pending_ops, &pending_vars, &ready_vars, &fetch_data);
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) { auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) { for (auto *op : set) {
...@@ -125,7 +127,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -125,7 +127,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Find the ready_ops after the ready_var. // Find the ready_ops after the ready_var.
for (auto ready_var : cur_ready_vars) { for (auto ready_var : cur_ready_vars) {
pending_vars.erase(ready_var); pending_vars.erase(ready_var);
for (auto *op : ready_var->pending_ops_) { for (auto *op : ready_var->PendingOps()) {
auto &deps = pending_ops[op]; auto &deps = pending_ops[op];
--deps; --deps;
if (deps == 0) { if (deps == 0) {
...@@ -151,6 +153,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -151,6 +153,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps( void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies, std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars, std::unordered_set<VarHandleBase *> *pending_vars,
...@@ -158,7 +161,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -158,7 +161,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) { for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->vars_) { for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
auto it = var_map.find(fetch_var_name); auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) { if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
...@@ -168,14 +171,16 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -168,14 +171,16 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
for (size_t i = 0; i < fetch_tensors.size(); ++i) { for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i]; auto &var_name = fetch_tensors[i];
auto fetched_var_it = fetched_vars.find(var_name); auto fetched_var_it = fetched_vars.find(var_name);
PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(), PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
"Cannot find fetched variable.(Perhaps the main_program " "Cannot find fetched variable.(Perhaps the main_program "
"is not set to ParallelExecutor)"); "is not set to ParallelExecutor)");
auto &vars = fetched_var_it->second; auto &vars = fetched_var_it->second;
auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_);
temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i,
&local_scopes_);
fetch_ops->emplace_back(op); fetch_ops->emplace_back(op);
for (auto &p : places_) { for (auto &p : places_) {
...@@ -186,7 +191,8 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -186,7 +191,8 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
op->AddInput(var); op->AddInput(var);
} }
auto *fetch_dummy = new DummyVarHandle(); temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get());
op->AddOutput(fetch_dummy); op->AddOutput(fetch_dummy);
fetch_dependencies->emplace(fetch_dummy); fetch_dependencies->emplace(fetch_dummy);
this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy); this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
...@@ -204,7 +210,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar( ...@@ -204,7 +210,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar(
std::unordered_set<VarHandleBase *> *pending_vars, std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, VarHandleBase *var) const { BlockingQueue<VarHandleBase *> *ready_vars, VarHandleBase *var) const {
pending_vars->insert(var); pending_vars->insert(var);
if (var->generated_op_ == nullptr) { if (var->GeneratedOp() == nullptr) {
ready_vars->Push(var); ready_vars->Push(var);
} }
} }
......
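The scheduling loop in `ThreadedSSAGraphExecutor::Run` boils down to dependency counting: an op becomes ready once every variable it reads is ready, and a variable becomes ready once the op that generates it has run (or immediately, if no op generates it). Below is a minimal single-threaded sketch of that idea. `ToyOp` and `ScheduleOrder` are hypothetical names used only for illustration; they are not part of Paddle, and the sketch assumes each variable is written by at most one op (single assignment).

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

// Hypothetical op record: a name plus the variables it reads and writes.
struct ToyOp {
  std::string name;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Returns op names in an order that respects data dependencies.
// Assumes single assignment: each variable is produced by at most one op.
inline std::vector<std::string> ScheduleOrder(const std::vector<ToyOp>& ops) {
  std::map<std::string, std::vector<std::size_t>> waiting_ops;  // var -> readers
  std::set<std::string> generated;                              // vars with a producer
  std::vector<std::size_t> pending_deps(ops.size(), 0);

  for (std::size_t i = 0; i < ops.size(); ++i) {
    pending_deps[i] = ops[i].inputs.size();
    for (const auto& in : ops[i].inputs) waiting_ops[in].push_back(i);
    for (const auto& out : ops[i].outputs) generated.insert(out);
  }

  std::queue<std::size_t> ready_ops;
  std::queue<std::string> ready_vars;
  // Special case: an op without inputs is ready immediately.
  for (std::size_t i = 0; i < ops.size(); ++i) {
    if (pending_deps[i] == 0) ready_ops.push(i);
  }
  // A variable with no generating op is a root and is ready from the start.
  for (const auto& entry : waiting_ops) {
    if (generated.count(entry.first) == 0) ready_vars.push(entry.first);
  }

  std::vector<std::string> order;
  while (!ready_ops.empty() || !ready_vars.empty()) {
    while (!ready_ops.empty()) {
      std::size_t i = ready_ops.front();
      ready_ops.pop();
      order.push_back(ops[i].name);  // "run" the op
      for (const auto& out : ops[i].outputs) ready_vars.push(out);
    }
    while (!ready_vars.empty()) {
      std::string var = ready_vars.front();
      ready_vars.pop();
      auto it = waiting_ops.find(var);
      if (it == waiting_ops.end()) continue;  // nobody reads this variable
      for (std::size_t idx : it->second) {
        if (--pending_deps[idx] == 0) ready_ops.push(idx);
      }
    }
  }
  return order;
}

// Usage: "c" reads the outputs of "a" and "b", so it must come last.
inline std::vector<std::string> DemoOrder() {
  std::vector<ToyOp> ops;
  ops.push_back(ToyOp{"c", {"x", "y"}, {"z"}});
  ops.push_back(ToyOp{"a", {"in"}, {"x"}});
  ops.push_back(ToyOp{"b", {"in"}, {"y"}});
  return ScheduleOrder(ops);
}
```

The real executor differs in that ready ops are dispatched to a thread pool and ready variables arrive through a blocking queue; the counting logic is the part this sketch tries to isolate.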
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -39,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -39,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
std::unique_ptr<SSAGraph> &&graph); std::unique_ptr<Graph> &&graph);
// Run a SSAGraph by a thread pool // Run a SSAGraph by a thread pool
// Use topological sort algorithm // Use topological sort algorithm
...@@ -52,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -52,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details::OpHandleBase *op); details::OpHandleBase *op);
private: private:
std::unique_ptr<SSAGraph> graph_; std::unique_ptr<Graph> graph_;
std::unique_ptr<::ThreadPool> pool_; std::unique_ptr<::ThreadPool> pool_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
...@@ -71,6 +72,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -71,6 +72,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
void InsertFetchOps( void InsertFetchOps(
const std::vector<std::string> &fetch_tensors, const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies, std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars, std::unordered_set<VarHandleBase *> *pending_vars,
......
...@@ -13,11 +13,14 @@ ...@@ -13,11 +13,14 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <algorithm>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
...@@ -25,19 +28,60 @@ namespace framework { ...@@ -25,19 +28,60 @@ namespace framework {
namespace details { namespace details {
class OpHandleBase; class OpHandleBase;
// Wraps ir::Node and provides helper utilities.
// It's responsible for populating necessary fields of ir::Node.
//
// VarHandleBase is the var node in the dependency graph. // VarHandleBase is the var node in the dependency graph.
// A variable can only be generated by a single operator. i.e. // A variable can only be generated by a single operator. i.e.
// This is a single assignment graph. // This is a single assignment graph.
struct VarHandleBase { struct VarHandleBase {
explicit VarHandleBase(ir::Node* node) : node_(node) {}
virtual ~VarHandleBase(); virtual ~VarHandleBase();
virtual std::string DebugString() const = 0; virtual std::string DebugString() const = 0;
void AddInput(OpHandleBase* in, ir::Node* node) {
node_->inputs.clear();
node_->inputs.push_back(node);
generated_op_ = in;
}
void AddOutput(OpHandleBase* out, ir::Node* node) {
if (pending_ops_.find(out) == pending_ops_.end()) {
pending_ops_.insert(out);
node_->outputs.push_back(node);
}
}
void RemoveOutput(OpHandleBase* out, ir::Node* node) {
pending_ops_.erase(out);
node_->outputs.erase(
std::remove(node_->outputs.begin(), node_->outputs.end(), node),
node_->outputs.end());
}
void ClearGeneratedOp() {
generated_op_ = nullptr;
node_->inputs.clear();
}
OpHandleBase* GeneratedOp() { return generated_op_; }
const std::unordered_set<OpHandleBase*>& PendingOps() const {
return pending_ops_;
}
ir::Node* Node() { return node_; }
protected:
// The operator that generates this variable. nullptr if the variable // The operator that generates this variable. nullptr if the variable
// is a root node. // is a root node.
OpHandleBase* generated_op_{nullptr}; OpHandleBase* generated_op_{nullptr};
// Operators that wait for this variable to become ready. // Operators that wait for this variable to become ready.
std::unordered_set<OpHandleBase*> pending_ops_; std::unordered_set<OpHandleBase*> pending_ops_;
ir::Node* node_;
}; };
// VarHandle is actually a single version of Runtime Variable. // VarHandle is actually a single version of Runtime Variable.
...@@ -46,11 +90,14 @@ struct VarHandleBase { ...@@ -46,11 +90,14 @@ struct VarHandleBase {
// //
// NOTE: runtime variables have place. // NOTE: runtime variables have place.
struct VarHandle : public VarHandleBase { struct VarHandle : public VarHandleBase {
explicit VarHandle(ir::Node* node) : VarHandleBase(node) {}
std::string DebugString() const override; std::string DebugString() const override;
VarHandle(size_t version, size_t scope_index, std::string name, VarHandle(ir::Node* node, size_t version, size_t scope_index,
platform::Place place) std::string name, platform::Place place)
: version_(version), : VarHandleBase(node),
version_(version),
scope_idx_(scope_index), scope_idx_(scope_index),
name_(std::move(name)), name_(std::move(name)),
place_(std::move(place)) {} place_(std::move(place)) {}
...@@ -70,6 +117,8 @@ struct VarHandle : public VarHandleBase { ...@@ -70,6 +117,8 @@ struct VarHandle : public VarHandleBase {
// Dummy Variable. It is used to represent dependencies between operators // Dummy Variable. It is used to represent dependencies between operators
struct DummyVarHandle : public VarHandleBase { struct DummyVarHandle : public VarHandleBase {
explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {}
std::string DebugString() const override; std::string DebugString() const override;
}; };
......
cc_library(graph SRCS graph.cc DEPS node)
cc_library(node SRCS node.cc)
cc_library(pass SRCS pass.cc DEPS graph node)
cc_test(graph_test SRCS graph_test.cc DEPS graph proto_desc op_registry)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle {
namespace framework {
// NOTE(paddle-dev): This graph contains cycles.
Graph::Graph(const ProgramDesc &program) : program_(program) {
std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) {
all_vars.emplace(var->Name(), var);
}
std::map<std::string, ir::Node *> var_nodes;
for (auto *op : program.Block(0).AllOps()) {
ir::Node *node = CreateOpNode(op);
for (auto &each_var_name : op->InputArgumentNames()) {
ir::Node *var = nullptr;
if (var_nodes.find(each_var_name) != var_nodes.end()) {
var = var_nodes.at(each_var_name);
} else if (all_vars.count(each_var_name) != 0) {
var = CreateVarNode(all_vars.at(each_var_name));
var_nodes[each_var_name] = var;
} else {
// TODO(paddle-dev): Seems some assumption doesn't hold?
LOG(ERROR) << op->Type()
<< " input var not in all_var list: " << each_var_name;
var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable);
var_nodes[each_var_name] = var;
}
node->inputs.push_back(var);
var->outputs.push_back(node);
}
for (auto &each_var_name : op->OutputArgumentNames()) {
ir::Node *var = nullptr;
if (var_nodes.find(each_var_name) != var_nodes.end()) {
var = var_nodes.at(each_var_name);
} else {
var = CreateVarNode(all_vars.at(each_var_name));
var_nodes[each_var_name] = var;
}
node->outputs.push_back(var);
var->inputs.push_back(node);
}
}
}
} // namespace framework
} // namespace paddle
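The constructor above creates one node per op and one node per variable name, then links them in both directions. A self-contained sketch of the same wiring follows. `MiniNode`, `OpSpec`, and `MiniGraph` are hypothetical simplifications introduced here; Paddle's `ir::Node` additionally holds `OpDesc`/`VarDesc` pointers, and this sketch skips the "input var not in all_var list" fallback.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical simplified node: name, kind, and bidirectional edges.
struct MiniNode {
  enum class Type { kOperation, kVariable };
  MiniNode(std::string n, Type t) : name(std::move(n)), type(t) {}
  std::string name;
  Type type;
  std::vector<MiniNode*> inputs;
  std::vector<MiniNode*> outputs;
};

// Hypothetical stand-in for OpDesc: op type plus argument names.
struct OpSpec {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

class MiniGraph {
 public:
  explicit MiniGraph(const std::vector<OpSpec>& ops) {
    std::map<std::string, MiniNode*> var_nodes;  // one node per variable name
    for (const auto& op : ops) {
      MiniNode* op_node = Create(op.type, MiniNode::Type::kOperation);
      for (const auto& name : op.inputs) {
        MiniNode* var = FindOrCreateVar(name, &var_nodes);
        op_node->inputs.push_back(var);
        var->outputs.push_back(op_node);
      }
      for (const auto& name : op.outputs) {
        MiniNode* var = FindOrCreateVar(name, &var_nodes);
        op_node->outputs.push_back(var);
        var->inputs.push_back(op_node);
      }
    }
  }

  // The graph owns every node; edges are non-owning raw pointers.
  std::vector<std::unique_ptr<MiniNode>> nodes;

 private:
  MiniNode* Create(const std::string& name, MiniNode::Type type) {
    nodes.emplace_back(new MiniNode(name, type));
    return nodes.back().get();
  }

  MiniNode* FindOrCreateVar(const std::string& name,
                            std::map<std::string, MiniNode*>* var_nodes) {
    auto it = var_nodes->find(name);
    if (it != var_nodes->end()) return it->second;
    MiniNode* var = Create(name, MiniNode::Type::kVariable);
    (*var_nodes)[name] = var;
    return var;
  }
};

// Usage: the same "sum" program that graph_test.cc builds.
inline MiniGraph BuildSumGraph() {
  std::vector<OpSpec> ops;
  ops.push_back(OpSpec{"sum", {"test_a", "test_b", "test_c"}, {"test_out"}});
  return MiniGraph(ops);
}
```

Because a variable name always maps to the same node, an op that reads and writes the same name produces an edge in each direction between the op node and the variable node, which is one way the cycles mentioned in the NOTE can arise.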
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace framework {
class Graph {
public:
explicit Graph(const ProgramDesc& program);
virtual ~Graph() {
for (auto& attr : attrs_) {
attr_dels_[attr.first]();
}
attrs_.clear();
attr_dels_.clear();
}
template <typename AttrType>
AttrType& Get(const std::string& attr_name) const {
return *boost::any_cast<AttrType*>(attrs_.at(attr_name));
}
template <typename AttrType>
void Set(const std::string& attr_name, AttrType* attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
attrs_[attr_name] = attr;
attr_dels_[attr_name] = [attr, attr_name]() {
VLOG(3) << "deleting " << attr_name;
delete attr;
};
}
ir::Node* CreateVarNode(VarDesc* var_desc) {
nodes.emplace_back(new ir::Node(var_desc));
return nodes.back().get();
}
ir::Node* CreateOpNode(OpDesc* op_desc) {
nodes.emplace_back(new ir::Node(op_desc));
return nodes.back().get();
}
ir::Node* CreateEmptyNode(const std::string& name, ir::Node::Type type) {
nodes.emplace_back(new ir::Node(name, type));
return nodes.back().get();
}
std::vector<std::unique_ptr<ir::Node>> nodes;
private:
// NOTE: program_ shouldn't be exposed to user.
const ProgramDesc& program_;
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
};
} // namespace framework
} // namespace paddle
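`Graph::Set` and `Graph::Get` implement a type-erased attribute store: the pointer goes into an `any`, and a captured lambda remembers how to delete it. The sketch below reproduces that pattern in isolation. It swaps in C++17 `std::any` for `boost::any` and uses a hypothetical class name, `AttrStore`, so it is an illustration of the idea rather than Paddle's code.

```cpp
#include <any>
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for Graph's attribute storage; not Paddle code.
class AttrStore {
 public:
  AttrStore() = default;
  AttrStore(const AttrStore&) = delete;
  AttrStore& operator=(const AttrStore&) = delete;

  ~AttrStore() {
    // Run every registered deleter, mirroring what Graph::~Graph() does.
    for (auto& entry : deleters_) entry.second();
  }

  // Takes ownership of `attr`. A given name may be set only once.
  template <typename AttrType>
  void Set(const std::string& name, AttrType* attr) {
    assert(attrs_.count(name) == 0);
    attrs_[name] = attr;
    deleters_[name] = [attr]() { delete attr; };
  }

  // The requested type must match the type passed to Set exactly;
  // otherwise std::any_cast throws std::bad_any_cast.
  template <typename AttrType>
  AttrType& Get(const std::string& name) const {
    return *std::any_cast<AttrType*>(attrs_.at(name));
  }

 private:
  std::map<std::string, std::any> attrs_;
  std::map<std::string, std::function<void()>> deleters_;
};

// Usage: store a vector under "ops", mutate it through Get, read it back.
inline std::size_t CountOps() {
  AttrStore store;
  store.Set("ops", new std::vector<std::string>{"sum", "fetch"});
  auto& ops = store.Get<std::vector<std::string>>("ops");
  ops.push_back("scale");
  return store.Get<std::vector<std::string>>("ops").size();
}
```

Storing a deleter per attribute is what lets the container own objects of unrelated types without a common base class; the cost is that a mismatched `Get` type is only caught at run time.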
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
class NOP : public OperatorBase {
public:
NOP(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope &scope,
const platform::Place &place) const override {}
};
class SumOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class SumOpVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
auto &inputs = op_desc.Input("X");
auto default_var_type = proto::VarType::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string &name) {
return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
default_var_type = proto::VarType::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(default_var_type);
}
};
} // namespace framework
} // namespace paddle
REGISTER_OPERATOR(sum, paddle::framework::NOP, paddle::framework::SumOpMaker,
paddle::framework::SumOpVarTypeInference);
REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP,
paddle::framework::SumOpMaker);
namespace paddle {
namespace framework {
TEST(GraphTest, Basic) {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"});
prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_out");
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarType::SELECTED_ROWS,
prog.MutableBlock(0)->Var("test_out")->GetType());
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR);
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarType::LOD_TENSOR,
prog.MutableBlock(0)->Var("test_out")->GetType());
std::unique_ptr<Graph> g(new Graph(prog));
ASSERT_EQ(g->nodes[0]->Name(), "sum");
ASSERT_EQ(g->nodes[0]->inputs[0]->Name(), "test_a");
ASSERT_EQ(g->nodes[0]->inputs[1]->Name(), "test_b");
ASSERT_EQ(g->nodes[0]->inputs[2]->Name(), "test_c");
ASSERT_EQ(g->nodes[0]->outputs[0]->Name(), "test_out");
ASSERT_EQ(g->nodes[1]->Name(), "test_a");
ASSERT_EQ(g->nodes[1]->outputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes[2]->Name(), "test_b");
ASSERT_EQ(g->nodes[2]->outputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes[3]->Name(), "test_c");
ASSERT_EQ(g->nodes[3]->outputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes[4]->Name(), "test_out");
ASSERT_EQ(g->nodes[4]->inputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes.size(), 5);
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/node.h"
namespace paddle {
namespace framework {} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace framework {
namespace ir {
class Node {
public:
enum class Type { kOperation, kVariable };
explicit Node(const std::string& name, Type type)
: name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
explicit Node(VarDesc* var_desc)
: name_(var_desc->Name()),
var_desc_(var_desc),
op_desc_(nullptr),
type_(Type::kVariable) {}
explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()),
var_desc_(nullptr),
op_desc_(op_desc),
type_(Type::kOperation) {}
Type NodeType() const { return type_; }
std::string Name() const { return name_; }
VarDesc* Var() {
PADDLE_ENFORCE(type_ == Type::kVariable);
return var_desc_;
}
OpDesc* Op() {
PADDLE_ENFORCE(type_ == Type::kOperation);
return op_desc_;
}
std::vector<Node*> inputs;
std::vector<Node*> outputs;
protected:
const std::string name_;
VarDesc* var_desc_;
OpDesc* op_desc_;
Type type_;
private:
DISABLE_COPY_AND_ASSIGN(Node);
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
class Pass {
public:
Pass() = default;
virtual ~Pass() {}
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
};
} // namespace ir
} // namespace framework
} // namespace paddle
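The `Pass` interface takes ownership of a graph and hands back a graph, so passes can be chained by moving the same `unique_ptr` through each one. A minimal sketch of that chaining follows; `ToyGraph`, `ToyPass`, the two concrete passes, and `RunPasses` are all hypothetical names made up for this example.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical graph: just a list of op names.
struct ToyGraph {
  std::vector<std::string> ops;
};

// Same shape as ir::Pass: consume a graph, return a (possibly new) graph.
class ToyPass {
 public:
  virtual ~ToyPass() {}
  virtual std::unique_ptr<ToyGraph> Apply(
      std::unique_ptr<ToyGraph> graph) const = 0;
};

// Appends a "fetch" op to the graph.
class AppendFetchPass : public ToyPass {
 public:
  std::unique_ptr<ToyGraph> Apply(
      std::unique_ptr<ToyGraph> graph) const override {
    graph->ops.push_back("fetch");
    return graph;
  }
};

// Removes every op with the given name.
class DropOpPass : public ToyPass {
 public:
  explicit DropOpPass(std::string name) : name_(std::move(name)) {}
  std::unique_ptr<ToyGraph> Apply(
      std::unique_ptr<ToyGraph> graph) const override {
    auto& ops = graph->ops;
    ops.erase(std::remove(ops.begin(), ops.end(), name_), ops.end());
    return graph;
  }

 private:
  std::string name_;
};

// Runs the passes in the given order, moving ownership through each one.
inline std::unique_ptr<ToyGraph> RunPasses(
    const std::vector<const ToyPass*>& passes,
    std::unique_ptr<ToyGraph> graph) {
  for (const ToyPass* pass : passes) {
    graph = pass->Apply(std::move(graph));
  }
  return graph;
}

// Usage: drop "nop", then append "fetch".
inline std::vector<std::string> DemoPasses() {
  DropOpPass drop("nop");
  AppendFetchPass append;
  std::unique_ptr<ToyGraph> graph(new ToyGraph());
  graph->ops.push_back("sum");
  graph->ops.push_back("nop");
  graph = RunPasses({&drop, &append}, std::move(graph));
  return graph->ops;
}
```

Passing the graph by `unique_ptr` value makes the ownership transfer explicit: after `Apply` the caller's old pointer is empty, so a pass is free to mutate in place or return a brand-new graph.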
...@@ -18,6 +18,8 @@ limitations under the License. */ ...@@ -18,6 +18,8 @@ limitations under the License. */
#include <tuple> #include <tuple>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#endif #endif
...@@ -129,12 +131,11 @@ ParallelExecutor::ParallelExecutor( ...@@ -129,12 +131,11 @@ ParallelExecutor::ParallelExecutor(
PADDLE_THROW("Not compiled with CUDA."); PADDLE_THROW("Not compiled with CUDA.");
#endif #endif
} }
builder_ = builder_factory.Create(); builder_ = builder_factory.Create();
std::unique_ptr<Graph> graph(new Graph(main_program));
graph = builder_->Apply(std::move(graph));
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, exec_strategy, member_->local_scopes_, places, std::move(graph)));
builder_->Build(main_program)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos), exec_strategy, member_->local_scopes_, std::move(var_infos),
member_->places_, std::move(member_->executor_))); member_->places_, std::move(member_->executor_)));
......
...@@ -19,10 +19,14 @@ function (inference_analysis_test TARGET) ...@@ -19,10 +19,14 @@ function (inference_analysis_test TARGET)
set(multiValueArgs SRCS) set(multiValueArgs SRCS)
cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(mem_opt "")
if(WITH_GPU)
set(mem_opt "--fraction_of_gpu_memory_to_use=0.5")
endif()
cc_test(${TARGET} cc_test(${TARGET}
SRCS "${analysis_test_SRCS}" SRCS "${analysis_test_SRCS}"
DEPS analysis DEPS analysis
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model --fraction_of_gpu_memory_to_use=0.5) ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt})
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec) set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
endif(WITH_TESTING) endif(WITH_TESTING)
endfunction(inference_analysis_test) endfunction(inference_analysis_test)
......
...@@ -66,6 +66,7 @@ bool NativePaddlePredictor::Init( ...@@ -66,6 +66,7 @@ bool NativePaddlePredictor::Init(
if (parent_scope) { if (parent_scope) {
scope_ = parent_scope; scope_ = parent_scope;
sub_scope_ = &(parent_scope->NewScope()); sub_scope_ = &(parent_scope->NewScope());
PADDLE_ENFORCE_NOT_NULL(sub_scope_, "create sub scope fail");
} else { } else {
paddle::framework::InitDevices(false); paddle::framework::InitDevices(false);
scope_.reset(new paddle::framework::Scope()); scope_.reset(new paddle::framework::Scope());
...@@ -102,7 +103,6 @@ bool NativePaddlePredictor::Init( ...@@ -102,7 +103,6 @@ bool NativePaddlePredictor::Init(
NativePaddlePredictor::~NativePaddlePredictor() { NativePaddlePredictor::~NativePaddlePredictor() {
if (sub_scope_) { if (sub_scope_) {
PADDLE_ENFORCE_NOT_NULL(scope_, "Should have parent scope!");
scope_->DeleteScope(sub_scope_); scope_->DeleteScope(sub_scope_);
} }
} }
......
...@@ -57,4 +57,4 @@ By specifying the engine kind and config, one can get a specific implementation. ...@@ -57,4 +57,4 @@ By specifying the engine kind and config, one can get a specific implementation.
## Reference ## Reference
- [paddle_inference_api.h](./paddle_inference_api.h) - [paddle_inference_api.h](./paddle_inference_api.h)
- [some demos](./demo) - [some demos](./demo_ci)
...@@ -83,5 +83,5 @@ CHECK(predictor->Run(slots, &outputs)); ...@@ -83,5 +83,5 @@ CHECK(predictor->Run(slots, &outputs));
## Detailed code references ## Detailed code references
- [inference demos](./demo) - [inference demos](./demo_ci)
- [Complex single-threaded/multi-threaded examples](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/contrib/inference/test_paddle_inference_api_impl.cc) - [Complex single-threaded/multi-threaded examples](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/inference/api/test_api_impl.cc)
...@@ -38,6 +38,7 @@ limitations under the License. */ ...@@ -38,6 +38,7 @@ limitations under the License. */
#endif #endif
#endif #endif
#include <boost/any.hpp>
#include <boost/mpl/comparison.hpp> #include <boost/mpl/comparison.hpp>
#include <boost/mpl/less_equal.hpp> #include <boost/mpl/less_equal.hpp>
#include <boost/variant.hpp> #include <boost/variant.hpp>
...@@ -248,15 +248,11 @@ PYBIND11_PLUGIN(core) { ...@@ -248,15 +248,11 @@ PYBIND11_PLUGIN(core) {
#endif #endif
}) })
.def("rows", [](SelectedRows &self) { .def("rows", [](SelectedRows &self) {
#ifndef PADDLE_WITH_CUDA
return self.rows();
#else
auto rows = self.rows(); auto rows = self.rows();
std::vector<int64_t> new_rows; std::vector<int64_t> new_rows;
new_rows.reserve(rows.size()); new_rows.reserve(rows.size());
std::copy(rows.begin(), rows.end(), std::back_inserter(new_rows)); std::copy(rows.begin(), rows.end(), std::back_inserter(new_rows));
return new_rows; return new_rows;
#endif
}); });
py::class_<Variable>(m, "Variable", R"DOC(Variable Class. py::class_<Variable>(m, "Variable", R"DOC(Variable Class.
......
...@@ -30,7 +30,9 @@ class RecordIOWriter { ...@@ -30,7 +30,9 @@ class RecordIOWriter {
public: public:
RecordIOWriter(const std::string& filename, recordio::Compressor compressor, RecordIOWriter(const std::string& filename, recordio::Compressor compressor,
size_t max_num_record) size_t max_num_record)
: stream_(filename), writer_(&stream_, compressor, max_num_record) {} : closed_(false),
stream_(filename),
writer_(&stream_, compressor, max_num_record) {}
void AppendTensor(const framework::LoDTensor& tensor) { void AppendTensor(const framework::LoDTensor& tensor) {
tensors_.push_back(tensor); tensors_.push_back(tensor);
...@@ -47,9 +49,17 @@ class RecordIOWriter { ...@@ -47,9 +49,17 @@ class RecordIOWriter {
PADDLE_ENFORCE(tensors_.empty()); PADDLE_ENFORCE(tensors_.empty());
writer_.Flush(); writer_.Flush();
stream_.close(); stream_.close();
closed_ = true;
}
~RecordIOWriter() {
if (!closed_) {
Close();
}
} }
private: private:
bool closed_;
std::vector<framework::LoDTensor> tensors_; std::vector<framework::LoDTensor> tensors_;
std::ofstream stream_; std::ofstream stream_;
recordio::Writer writer_; recordio::Writer writer_;
......
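The `closed_` flag added to `RecordIOWriter` above lets the destructor flush pending data only when the caller has not already called `Close()`. A minimal Python sketch of the same guard follows. `GuardedWriter` and `close_calls` are hypothetical names used only for illustration, and unlike the C++ class the sketch makes `close()` itself idempotent and only flushes the underlying stream rather than closing it.

```python
import io


class GuardedWriter(object):
    """Hypothetical Python analogue of the closed_ guard in RecordIOWriter."""

    def __init__(self, stream):
        self._stream = stream
        self._closed = False
        self.close_calls = 0  # counts effective closes, for illustration only

    def write(self, data):
        if self._closed:
            raise ValueError('write on a closed writer')
        self._stream.write(data)

    def close(self):
        # Second and later calls are no-ops, so cleanup cannot run twice.
        if self._closed:
            return
        self._stream.flush()
        self._closed = True
        self.close_calls += 1

    def __del__(self):
        # Mirrors ~RecordIOWriter(): close only if the caller did not.
        self.close()


buf = io.StringIO()
writer = GuardedWriter(buf)
writer.write(u'record')
writer.close()
writer.close()  # harmless: the flag short-circuits the second call
```

Keeping the flag inside the writer means callers may close explicitly or rely on cleanup at destruction without risking a double flush.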
...@@ -68,8 +68,14 @@ def reader_creator(image_filename, label_filename, buffer_size): ...@@ -68,8 +68,14 @@ def reader_creator(image_filename, label_filename, buffer_size):
for i in xrange(buffer_size): for i in xrange(buffer_size):
yield images[i, :], int(labels[i]) yield images[i, :], int(labels[i])
finally: finally:
try:
m.terminate()
except Exception:
pass
try:
l.terminate()
except Exception:
pass
return reader return reader
......
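The `finally` block in the hunk above wraps each `terminate()` call in its own `try`, so that a failure on the first process does not prevent cleanup of the second. The idea can be sketched as a small helper. `terminate_quietly` is a hypothetical name, and catching only `OSError` is an assumption made for this sketch (the hunk itself uses a bare `except`).

```python
def terminate_quietly(*procs):
    """Terminate each process, ignoring errors from ones that already exited."""
    for proc in procs:
        try:
            proc.terminate()
        except OSError:
            # The process is already gone; nothing is left to clean up.
            pass
```

Each process gets its own `try`, which is the property the patch relies on: one failing `terminate()` never skips the remaining ones.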
...@@ -35,6 +35,7 @@ import io ...@@ -35,6 +35,7 @@ import io
import evaluator import evaluator
import initializer import initializer
import layers import layers
import contrib
import nets import nets
import optimizer import optimizer
import backward import backward
...@@ -66,6 +67,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + \ ...@@ -66,6 +67,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + \
'io', 'io',
'initializer', 'initializer',
'layers', 'layers',
'contrib',
'transpiler', 'transpiler',
'nets', 'nets',
'optimizer', 'optimizer',
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import decoder
from decoder import *
__all__ = decoder.__all__
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import beam_search_decoder
from beam_search_decoder import *
__all__ = beam_search_decoder.__all__
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provides a general beam search decoder API for RNN based decoders.
The purpose of this API is to allow users to highly customize the behavior
within their RNN decoder (vanilla RNN, LSTM, attention + LSTM, and future variants),
without using the low level API such as while ops.
This API is still under active development and may change drastically.
"""
import contextlib
import numpy as np
from ... import layers
from ...framework import Variable
from ... import core
from ... import framework, unique_name
from ...layer_helper import LayerHelper
__all__ = ['InitState', 'StateCell', 'TrainingDecoder', 'BeamSearchDecoder']
class _DecoderType:
TRAINING = 1
BEAM_SEARCH = 2
class InitState(object):
"""
The initial hidden state object. The state object holds a variable, and may
use it to initialize the hidden state cell of RNN. Usually used as input to
`StateCell` class.
Args:
init (Variable): The initial variable of the hidden state. If set None,
the variable will be created as a tensor with constant value based
on `shape` and `value` param.
shape (tuple|list): If `init` is None, new Variable's shape. Default
None.
value (float): If `init` is None, new Variable's value. Default None.
init_boot (Variable): If provided, the initial variable will be created
with the same shape as this variable.
need_reorder (bool): If set true, the init will be sorted by its lod
rank within its batches. This should be used if `batch_size > 1`.
dtype (np.dtype|core.VarDesc.VarType|str): Data type of the initial
variable.
Returns:
An initialized state object.
Examples:
See `StateCell`.
"""
def __init__(self,
init=None,
shape=None,
value=0.0,
init_boot=None,
need_reorder=False,
dtype='float32'):
if init is not None:
self._init = init
elif init_boot is None:
raise ValueError(
'init_boot must be provided to infer the shape of InitState.\n')
else:
self._init = layers.fill_constant_batch_size_like(
input=init_boot, value=value, shape=shape, dtype=dtype)
self._shape = shape
self._value = value
self._need_reorder = need_reorder
self._dtype = dtype
@property
def value(self):
return self._init
@property
def need_reorder(self):
return self._need_reorder
class _MemoryState(object):
def __init__(self, state_name, rnn_obj, init_state):
self._state_name = state_name # each is a rnn.memory
self._rnn_obj = rnn_obj
self._state_mem = self._rnn_obj.memory(
init=init_state.value, need_reorder=init_state.need_reorder)
def get_state(self):
return self._state_mem
def update_state(self, state):
self._rnn_obj.update_memory(self._state_mem, state)
class _ArrayState(object):
def __init__(self, state_name, block, init_state):
self._state_name = state_name
self._block = block
self._state_array = self._block.create_var(
name=unique_name.generate('array_state_array'),
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
dtype=init_state.value.dtype)
self._counter = self._block.create_var(
name=unique_name.generate('array_state_counter'),
type=core.VarDesc.VarType.LOD_TENSOR,
dtype='int64')
# initialize counter
self._block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [self._counter]},
attrs={
'shape': [1],
'dtype': self._counter.dtype,
'value': float(0.0),
'force_cpu': True
})
self._counter.stop_gradient = True
# write initial state
block.append_op(
type='write_to_array',
inputs={'X': init_state.value,
'I': self._counter},
outputs={'Out': self._state_array})
def get_state(self):
state = layers.array_read(array=self._state_array, i=self._counter)
return state
def update_state(self, state):
layers.increment(x=self._counter, value=1, in_place=True)
layers.array_write(state, array=self._state_array, i=self._counter)
class StateCell(object):
"""
The state cell class stores the hidden state of the RNN cell. A typical RNN
cell has one or more hidden states, and one or more step inputs. This class
allows you to define the names of hidden states as well as step inputs, and
their associated variables.
Args:
inputs (dict): A feeding dict of {name(str) : Variable}. It specifies
the names of step inputs for RNN cell, and the associated variables.
The variable could initially be None and set manually during each
RNN step.
states (dict): A feeding dict of {name(str) : InitState object}. It
specifies the names of hidden states and their initialized state.
out_state (str): A string that specifies the name of hidden state that
will be used to compute the score in beam search process.
name (str): The name of the RNN cell. Default None.
Raises:
`ValueError`: If the initial state is not an instance of InitState, or
the out_state is not in the dict of states.
Returns:
StateCell: The initialized StateCell object.
Examples:
.. code-block:: python
hidden_state = InitState(init=encoder_out, need_reorder=True)
state_cell = StateCell(
inputs={'current_word': None},
states={'h': hidden_state},
out_state='h')
"""
def __init__(self, inputs, states, out_state, name=None):
self._helper = LayerHelper('state_cell', name=name)
self._cur_states = {}
self._state_names = []
for state_name, state in states.items():
if not isinstance(state, InitState):
raise ValueError('state must be an InitState object.')
self._cur_states[state_name] = state
self._state_names.append(state_name)
self._inputs = inputs # inputs is place holder here
self._cur_decoder_obj = None
self._in_decoder = False
self._states_holder = {}
self._switched_decoder = False
self._state_updater = None
self._out_state = out_state
if self._out_state not in self._cur_states:
raise ValueError('out_state must be one state in states')
def _enter_decoder(self, decoder_obj):
if self._in_decoder or self._cur_decoder_obj is not None:
raise ValueError('StateCell has already entered a decoder.')
self._in_decoder = True
self._cur_decoder_obj = decoder_obj
self._switched_decoder = False
def _leave_decoder(self, decoder_obj):
if not self._in_decoder:
raise ValueError('StateCell not in decoder, '
'invalid leaving operation.')
if self._cur_decoder_obj != decoder_obj:
raise ValueError('Inconsistent decoder object in StateCell.')
self._in_decoder = False
self._cur_decoder_obj = None
self._switched_decoder = False
def _switch_decoder(self): # lazy switch
if not self._in_decoder:
raise ValueError('StateCell must enter a decoder first.')
if self._switched_decoder:
raise ValueError('StateCell already done switching.')
for state_name in self._state_names:
if state_name not in self._states_holder:
state = self._cur_states[state_name]
if not isinstance(state, InitState):
raise ValueError('Current type of state is %s, should be '
'an InitState object.' % type(state))
self._states_holder[state_name] = {}
if self._cur_decoder_obj.type == _DecoderType.TRAINING:
self._states_holder[state_name][id(self._cur_decoder_obj)] \
= _MemoryState(state_name,
self._cur_decoder_obj.dynamic_rnn,
state)
elif self._cur_decoder_obj.type == _DecoderType.BEAM_SEARCH:
self._states_holder[state_name][id(self._cur_decoder_obj)] \
= _ArrayState(state_name,
self._cur_decoder_obj._parent_block(),
state)
else:
raise ValueError('Unknown decoder type, only support '
'[TRAINING, BEAM_SEARCH]')
# Read back, since current state should be LoDTensor
self._cur_states[state_name] = \
self._states_holder[state_name][
id(self._cur_decoder_obj)].get_state()
self._switched_decoder = True
def get_state(self, state_name):
"""
The getter of state object. Find the state variable by its name.
Args:
state_name (str): A string of the state's name.
Returns:
The associated state object.
"""
if self._in_decoder and not self._switched_decoder:
self._switch_decoder()
if state_name not in self._cur_states:
raise ValueError(
'Unknown state %s. Please make sure _switch_decoder() '
'has been invoked.' % state_name)
return self._cur_states[state_name]
def get_input(self, input_name):
"""
The getter of input variable. Find the input variable by its name.
Args:
input_name (str): The string of the input's name.
Returns:
The associated input variable.
"""
if input_name not in self._inputs or self._inputs[input_name] is None:
raise ValueError('Invalid input %s.' % input_name)
return self._inputs[input_name]
def set_state(self, state_name, state_value):
"""
The setter of the state variable. Change the variable of the given
`state_name`.
Args:
state_name (str): The name of the state to change.
state_value (Var): The variable of the new state.
"""
self._cur_states[state_name] = state_value
def state_updater(self, updater):
"""
Set up the updater to update the hidden state every RNN step. The
behavior of updater could be customized by users. The updater should be
a function that takes a `StateCell` object as input and updates the
hidden state within it. The hidden state could be accessed through
`get_state` method.
Args:
updater (func): the updater to update the state cell.
"""
self._state_updater = updater
def _decorator(state_cell):
if not isinstance(state_cell, StateCell):
raise TypeError('Updater should only accept a StateCell object '
'as argument.')
updater(state_cell)
return _decorator
def compute_state(self, inputs):
"""
Provide the step input of RNN cell, and compute the new hidden state
with the updater and the given step input.
Args:
inputs (dict): A feed dict, {name(str): Variable}. name should be
the names of step inputs for this RNN cell, and Variable should be
the associated variables.
Examples:
.. code-block:: python
state_cell.compute_state(inputs={'x': current_word})
"""
if self._in_decoder and not self._switched_decoder:
self._switch_decoder()
for input_name, input_value in inputs.items():
if input_name not in self._inputs:
raise ValueError('Unknown input %s. '
'Please make sure %s is in the input '
'placeholder.' % (input_name, input_name))
self._inputs[input_name] = input_value
self._state_updater(self)
def update_states(self):
"""
Update and record state information after each RNN step.
"""
if self._in_decoder and not self._switched_decoder:
self._switch_decoder()
for state_name, decoder_state in self._states_holder.items():
if id(self._cur_decoder_obj) not in decoder_state:
raise ValueError('Unknown decoder object, please make sure '
'switch_decoder has been invoked.')
decoder_state[id(self._cur_decoder_obj)].update_state(
self._cur_states[state_name])
def out_state(self):
"""
Get the output state variable. This must be called after update_states.
Returns:
The output variable of the RNN cell.
"""
return self._cur_states[self._out_state]
class TrainingDecoder(object):
"""
A decoder that can only be used for training. The decoder could be
initialized with a `StateCell` object. The computation within the RNN cell
could be defined with decoder's block.
Args:
state_cell (StateCell): A StateCell object that handles the input and
state variables.
name (str): The name of this decoder. Default None.
Returns:
TrainingDecoder: The initialized TrainingDecoder object.
Examples:
.. code-block:: python
decoder = TrainingDecoder(state_cell)
with decoder.block():
current_word = decoder.step_input(trg_embedding)
decoder.state_cell.compute_state(inputs={'x': current_word})
current_score = layers.fc(input=decoder.state_cell.get_state('h'),
size=32,
act='softmax')
decoder.state_cell.update_states()
decoder.output(current_score)
"""
BEFORE_DECODER = 0
IN_DECODER = 1
AFTER_DECODER = 2
def __init__(self, state_cell, name=None):
self._helper = LayerHelper('training_decoder', name=name)
self._status = TrainingDecoder.BEFORE_DECODER
self._dynamic_rnn = layers.DynamicRNN()
self._type = _DecoderType.TRAINING
self._state_cell = state_cell
self._state_cell._enter_decoder(self)
@contextlib.contextmanager
def block(self):
"""
Define the behavior of the decoder for each RNN time step.
"""
if self._status != TrainingDecoder.BEFORE_DECODER:
raise ValueError('decoder.block() can only be invoked once')
self._status = TrainingDecoder.IN_DECODER
with self._dynamic_rnn.block():
yield
self._status = TrainingDecoder.AFTER_DECODER
self._state_cell._leave_decoder(self)
@property
def state_cell(self):
self._assert_in_decoder_block('state_cell')
return self._state_cell
@property
def dynamic_rnn(self):
return self._dynamic_rnn
@property
def type(self):
return self._type
def step_input(self, x):
"""
Set the input variable as a step input to the RNN cell. For example,
in machine translation, at each time step we read one word from the target
sentence, so the target sentence is a step input to the RNN cell.
Args:
x (Variable): the variable to be used as step input.
Returns:
Variable: The variable as input of current step.
Examples:
.. code-block:: python
current_word = decoder.step_input(trg_embedding)
"""
self._assert_in_decoder_block('step_input')
return self._dynamic_rnn.step_input(x)
def static_input(self, x):
"""
Set the input variable as a static input of RNN cell. In contrast to
step input, this variable will be used as a whole within the RNN decode
loop and will not be scattered into time steps.
Args:
x (Variable): the variable to be used as static input.
Returns:
Variable: The variable as input of current step.
Examples:
.. code-block:: python
encoder_vec = decoder.static_input(encoded_vector)
"""
self._assert_in_decoder_block('static_input')
return self._dynamic_rnn.static_input(x)
def __call__(self, *args, **kwargs):
"""
Get the output of RNN. This API should only be invoked after RNN.block()
Returns:
Variable: The specified output of the RNN cell.
"""
if self._status != TrainingDecoder.AFTER_DECODER:
raise ValueError('Output of training decoder can only be visited '
'outside the block.')
return self._dynamic_rnn(*args, **kwargs)
def output(self, *outputs):
"""
Set the output variable of the RNN cell.
Args:
*outputs (Variables): a series of variables that are treated as output
of the RNN cell.
Examples:
.. code-block:: python
out = fluid.layers.fc(input=h,
size=32,
bias_attr=True,
act='softmax')
decoder.output(out)
"""
self._assert_in_decoder_block('output')
self._dynamic_rnn.output(*outputs)
def _assert_in_decoder_block(self, method):
if self._status != TrainingDecoder.IN_DECODER:
raise ValueError('%s should be invoked inside block of '
'TrainingDecoder object.' % method)
class BeamSearchDecoder(object):
"""
A beam search decoder that can be used for inference. The decoder should be
initialized with a `StateCell` object. The decode process can be defined
within its block.
Args:
state_cell (StateCell): A StateCell object that handles the input and
state variables.
init_ids (Variable): The init beam search token ids.
init_scores (Variable): The associated score of each id.
target_dict_dim (int): Size of dictionary.
word_dim (int): Word embedding dimension.
input_var_dict (dict): A feeding dict to feed the required input
variables to the state cell. It will be used by state_cell's
compute method. Default empty.
topk_size (int): The topk size used for beam search. Default 50.
max_len (int): The maximum allowed length of the generated sentence.
Default 100.
beam_size (int): The beam width of beam search decode. Default 1.
end_id (int): The id of end token within beam search.
name (str): The name of this decoder. Default None.
Returns:
BeamSearchDecoder: An initialized BeamSearchDecoder object.
Examples:
.. code-block:: python
decoder = BeamSearchDecoder(
state_cell=state_cell,
init_ids=init_ids,
init_scores=init_scores,
target_dict_dim=target_dict_dim,
word_dim=word_dim,
input_var_dict={},
topk_size=topk_size,
sparse_emb=IS_SPARSE,
max_len=max_length,
beam_size=beam_size,
end_id=1,
name=None
)
decoder.decode()
translation_ids, translation_scores = decoder()
"""
BEFORE_BEAM_SEARCH_DECODER = 0
IN_BEAM_SEARCH_DECODER = 1
AFTER_BEAM_SEARCH_DECODER = 2
def __init__(self,
state_cell,
init_ids,
init_scores,
target_dict_dim,
word_dim,
input_var_dict={},
topk_size=50,
sparse_emb=True,
max_len=100,
beam_size=1,
end_id=1,
name=None):
self._helper = LayerHelper('beam_search_decoder', name=name)
self._counter = layers.zeros(shape=[1], dtype='int64')
self._counter.stop_gradient = True
self._type = _DecoderType.BEAM_SEARCH
self._max_len = layers.fill_constant(
shape=[1], dtype='int64', value=max_len)
self._cond = layers.less_than(
x=self._counter,
y=layers.fill_constant(
shape=[1], dtype='int64', value=max_len))
self._while_op = layers.While(self._cond)
self._state_cell = state_cell
self._state_cell._enter_decoder(self)
self._status = BeamSearchDecoder.BEFORE_BEAM_SEARCH_DECODER
self._zero_idx = layers.fill_constant(
shape=[1], value=0, dtype='int64', force_cpu=True)
self._array_dict = {}
self._array_link = []
self._ids_array = None
self._scores_array = None
self._beam_size = beam_size
self._end_id = end_id
self._init_ids = init_ids
self._init_scores = init_scores
self._target_dict_dim = target_dict_dim
self._topk_size = topk_size
self._sparse_emb = sparse_emb
self._word_dim = word_dim
self._input_var_dict = input_var_dict
@contextlib.contextmanager
def block(self):
"""
Define the behavior of the decoder for each RNN time step.
"""
if self._status != BeamSearchDecoder.BEFORE_BEAM_SEARCH_DECODER:
raise ValueError('block() can only be invoked once.')
self._status = BeamSearchDecoder.IN_BEAM_SEARCH_DECODER
with self._while_op.block():
yield
with layers.Switch() as switch:
with switch.case(self._cond):
layers.increment(x=self._counter, value=1.0, in_place=True)
for value, array in self._array_link:
layers.array_write(
x=value, i=self._counter, array=array)
layers.less_than(
x=self._counter, y=self._max_len, cond=self._cond)
self._status = BeamSearchDecoder.AFTER_BEAM_SEARCH_DECODER
self._state_cell._leave_decoder(self)
@property
def type(self):
return self._type
def early_stop(self):
"""
Stop the generation process in advance. Could be used as "break".
"""
layers.fill_constant(
shape=[1], value=0, dtype='bool', force_cpu=True, out=self._cond)
def decode(self):
"""
Set up the computation within the decoder. Then you could call the
decoder to get the result of beam search decode. If you want to define
a more specific decoder, you could override this function.
Examples:
.. code-block:: python
decoder.decode()
translation_ids, translation_scores = decoder()
"""
with self.block():
prev_ids = self.read_array(init=self._init_ids, is_ids=True)
prev_scores = self.read_array(
init=self._init_scores, is_scores=True)
prev_ids_embedding = layers.embedding(
input=prev_ids,
size=[self._target_dict_dim, self._word_dim],
dtype='float32',
is_sparse=self._sparse_emb)
feed_dict = {}
update_dict = {}
for init_var_name, init_var in self._input_var_dict.items():
if init_var_name not in self.state_cell._inputs:
raise ValueError('Variable ' + init_var_name +
' not found in StateCell!\n')
read_var = self.read_array(init=init_var)
update_dict[init_var_name] = read_var
feed_var_expanded = layers.sequence_expand(read_var,
prev_scores)
feed_dict[init_var_name] = feed_var_expanded
for state_str in self._state_cell._state_names:
prev_state = self.state_cell.get_state(state_str)
prev_state_expanded = layers.sequence_expand(prev_state,
prev_scores)
self.state_cell.set_state(state_str, prev_state_expanded)
for i, input_name in enumerate(self._state_cell._inputs):
if input_name not in feed_dict:
feed_dict[input_name] = prev_ids_embedding
self.state_cell.compute_state(inputs=feed_dict)
current_state = self.state_cell.out_state()
current_state_with_lod = layers.lod_reset(
x=current_state, y=prev_scores)
scores = layers.fc(input=current_state_with_lod,
size=self._target_dict_dim,
act='softmax')
topk_scores, topk_indices = layers.topk(scores, k=self._topk_size)
accu_scores = layers.elementwise_add(
x=layers.log(x=topk_scores),
y=layers.reshape(
prev_scores, shape=[-1]),
axis=0)
selected_ids, selected_scores = layers.beam_search(
prev_ids,
prev_scores,
topk_indices,
accu_scores,
self._beam_size,
end_id=self._end_id,
level=0)
with layers.Switch() as switch:
with switch.case(layers.is_empty(selected_ids)):
self.early_stop()
with switch.default():
self.state_cell.update_states()
self.update_array(prev_ids, selected_ids)
self.update_array(prev_scores, selected_scores)
for update_name, var_to_update in update_dict.items():
self.update_array(var_to_update, feed_dict[update_name])
def read_array(self, init, is_ids=False, is_scores=False):
"""
Read an array to get the decoded ids and scores generated by previous
RNN step. At the first step of RNN, the init variable must be used to
initialize the array.
Args:
init (Variable): The initial variable for first step usage. init
must be provided.
is_ids (bool): Specify whether the variable is an id.
is_scores (bool): Specify whether the variable is a score.
Returns:
The associated variable generated during previous RNN steps.
Examples:
.. code-block:: python
prev_ids = decoder.read_array(init=init_ids, is_ids=True)
prev_scores = decoder.read_array(init=init_scores, is_scores=True)
"""
self._assert_in_decoder_block('read_array')
if is_ids and is_scores:
raise ValueError('An array cannot be marked as both the ids array and '
'the scores array at the same time.')
if not isinstance(init, Variable):
raise TypeError('The input argument `init` must be a Variable.')
parent_block = self._parent_block()
array = parent_block.create_var(
name=unique_name.generate('beam_search_decoder_array'),
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
dtype=init.dtype)
parent_block.append_op(
type='write_to_array',
inputs={'X': init,
'I': self._zero_idx},
outputs={'Out': array})
if is_ids:
self._ids_array = array
elif is_scores:
self._scores_array = array
read_value = layers.array_read(array=array, i=self._counter)
self._array_dict[read_value.name] = array
return read_value
def update_array(self, array, value):
"""
Store the value generated in current step in an array for each RNN step.
This array could be accessed by read_array method.
Args:
array (Variable): The array to append the new variable to.
value (Variable): The newly generated value to be stored.
"""
self._assert_in_decoder_block('update_array')
if not isinstance(array, Variable):
raise TypeError(
'The input argument `array` must be a Variable.')
if not isinstance(value, Variable):
raise TypeError('The input argument `value` must be a Variable.')
array = self._array_dict.get(array.name, None)
if array is None:
raise ValueError('Please invoke read_array before update_array.')
self._array_link.append((value, array))
def __call__(self):
"""
Run the decode process and return the final decode result.
Returns:
A tuple of decoded (id, score) pairs. id is a Variable that holds
the generated tokens, and score is a Variable with the same shape
as id, holds the score for each generated token.
"""
if self._status != BeamSearchDecoder.AFTER_BEAM_SEARCH_DECODER:
raise ValueError('Output of BeamSearchDecoder object can '
'only be visited outside the block.')
return layers.beam_search_decode(
ids=self._ids_array,
scores=self._scores_array,
beam_size=self._beam_size,
end_id=self._end_id)
@property
def state_cell(self):
self._assert_in_decoder_block('state_cell')
return self._state_cell
def _parent_block(self):
"""
Getter of parent block.
Returns:
The parent block of decoder.
"""
program = self._helper.main_program
parent_block_idx = program.current_block().parent_idx
if parent_block_idx < 0:
raise ValueError('Invalid block with index %d.' % parent_block_idx)
parent_block = program.block(parent_block_idx)
return parent_block
def _assert_in_decoder_block(self, method):
if self._status != BeamSearchDecoder.IN_BEAM_SEARCH_DECODER:
raise ValueError('%s should be invoked inside block of '
'BeamSearchDecoder object.' % method)
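The `decode()` method above expands each hypothesis with candidate tokens, adds log probabilities to the accumulated scores, keeps the `beam_size` best hypotheses, and stops early once nothing is left to extend. A plain-Python sketch of one such step follows. It is a simplified stand-in, not the behavior of the `layers.beam_search` op: `beam_search_step`, `all_finished` and the dict-based `step_probs` input are hypothetical names and structures chosen for illustration, and finished hypotheses are carried over rather than pruned.

```python
import math


def beam_search_step(beams, step_probs, beam_size, end_id):
    """Advance every unfinished beam by one token and keep the best ones.

    beams:      list of (token_ids, accumulated_log_prob) tuples.
    step_probs: one dict {token_id: probability} per beam, standing in for
                the softmax output of the decoder at this step.
    """
    candidates = []
    for (tokens, score), probs in zip(beams, step_probs):
        if tokens and tokens[-1] == end_id:
            # A finished hypothesis is carried over unchanged.
            candidates.append((tokens, score))
            continue
        for token_id, prob in probs.items():
            # Accumulate in log space, as decode() does with layers.log.
            candidates.append((tokens + [token_id], score + math.log(prob)))
    # Keep the beam_size highest-scoring hypotheses.
    candidates.sort(key=lambda item: item[1], reverse=True)
    return candidates[:beam_size]


def all_finished(beams, end_id):
    """True when every beam ends with end_id (the early-stop condition)."""
    return all(tokens and tokens[-1] == end_id for tokens, _ in beams)


# Start from a single hypothesis holding the start token 0.
beams = [([0], 0.0)]
beams = beam_search_step(
    beams, [{1: 0.6, 2: 0.3, 3: 0.1}], beam_size=2, end_id=1)
```

Working in log space turns the product of per-step probabilities into a sum, which avoids underflow on long sequences.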
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle import paddle
import sys import sys
import numpy import numpy
...@@ -134,4 +135,4 @@ def main(use_cuda): ...@@ -134,4 +135,4 @@ def main(use_cuda):
if __name__ == '__main__': if __name__ == '__main__':
# for use_cuda in (False, True): # for use_cuda in (False, True):
main(use_cuda=True) main(use_cuda=core.is_compiled_with_cuda())
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import paddle.fluid.core as core
import math import math
import os import os
import sys import sys
...@@ -257,6 +258,8 @@ def inject_test_method(use_cuda, parallel, nn_type, combine): ...@@ -257,6 +258,8 @@ def inject_test_method(use_cuda, parallel, nn_type, combine):
def inject_all_tests(): def inject_all_tests():
for use_cuda in (False, True): for use_cuda in (False, True):
if use_cuda and not core.is_compiled_with_cuda():
continue
for parallel in (False, True): for parallel in (False, True):
for nn_type in ('mlp', 'conv'): for nn_type in ('mlp', 'conv'):
inject_test_method(use_cuda, parallel, nn_type, True) inject_test_method(use_cuda, parallel, nn_type, True)
......
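The hunk above skips every CUDA variant when the build has no CUDA support, instead of generating tests that are bound to fail. The enumeration can be sketched without Paddle as follows. `enumerate_configs` is a hypothetical helper, and its `cuda_compiled` argument stands in for the result of `core.is_compiled_with_cuda()`.

```python
def enumerate_configs(cuda_compiled):
    """List (use_cuda, parallel, nn_type) combinations that can run here."""
    configs = []
    for use_cuda in (False, True):
        if use_cuda and not cuda_compiled:
            # No CUDA in this build: skip every GPU variant.
            continue
        for parallel in (False, True):
            for nn_type in ('mlp', 'conv'):
                configs.append((use_cuda, parallel, nn_type))
    return configs


cpu_only = enumerate_configs(cuda_compiled=False)
with_gpu = enumerate_configs(cuda_compiled=True)
```

On a CPU-only build this yields four configurations, and eight when CUDA is available.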
...@@ -245,7 +245,7 @@ def inject_test_method(use_cuda, is_sparse, is_parallel): ...@@ -245,7 +245,7 @@ def inject_test_method(use_cuda, is_sparse, is_parallel):
is_sparse=is_sparse, is_sparse=is_sparse,
is_parallel=is_parallel) is_parallel=is_parallel)
if use_cuda and is_sparse: if (not fluid.core.is_compiled_with_cuda() or use_cuda) and is_sparse:
fn = __impl__ fn = __impl__
else: else:
# skip the other test when on CI server # skip the other test when on CI server
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A simple machine translation demo using beam search decoder.
"""
import contextlib
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.framework as framework
import paddle.fluid.layers as layers
from paddle.fluid.executor import Executor
from paddle.fluid.contrib.decoder.beam_search_decoder import *
import unittest
import os
dict_size = 30000
source_dict_dim = target_dict_dim = dict_size
src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size)
hidden_dim = 32
word_dim = 32
decoder_size = hidden_dim
IS_SPARSE = True
batch_size = 2
max_length = 8
topk_size = 50
trg_dic_size = 10000
beam_size = 2

def encoder():
    # encoder
    src_word = layers.data(
        name="src_word", shape=[1], dtype='int64', lod_level=1)
    src_embedding = layers.embedding(
        input=src_word,
        size=[dict_size, word_dim],
        dtype='float32',
        is_sparse=IS_SPARSE)

    fc1 = layers.fc(input=src_embedding, size=hidden_dim * 4, act='tanh')
    lstm_hidden0, lstm_0 = layers.dynamic_lstm(input=fc1, size=hidden_dim * 4)
    encoder_out = layers.sequence_last_step(input=lstm_hidden0)
    return encoder_out

def decoder_state_cell(context):
    h = InitState(init=context, need_reorder=True)
    state_cell = StateCell(inputs={'x': None}, states={'h': h}, out_state='h')

    @state_cell.state_updater
    def updater(state_cell):
        current_word = state_cell.get_input('x')
        prev_h = state_cell.get_state('h')
        # make sure the LoD of h is inherited from prev_h
        h = layers.fc(input=[prev_h, current_word],
                      size=decoder_size,
                      act='tanh')
        state_cell.set_state('h', h)

    return state_cell

def decoder_train(state_cell):
    # decoder
    trg_language_word = layers.data(
        name="target_word", shape=[1], dtype='int64', lod_level=1)
    trg_embedding = layers.embedding(
        input=trg_language_word,
        size=[dict_size, word_dim],
        dtype='float32',
        is_sparse=IS_SPARSE)

    decoder = TrainingDecoder(state_cell)

    with decoder.block():
        current_word = decoder.step_input(trg_embedding)
        decoder.state_cell.compute_state(inputs={'x': current_word})
        current_score = layers.fc(input=decoder.state_cell.get_state('h'),
                                  size=target_dict_dim,
                                  act='softmax')
        decoder.state_cell.update_states()
        decoder.output(current_score)

    return decoder()

def decoder_decode(state_cell):
    init_ids = layers.data(
        name="init_ids", shape=[1], dtype="int64", lod_level=2)
    init_scores = layers.data(
        name="init_scores", shape=[1], dtype="float32", lod_level=2)

    decoder = BeamSearchDecoder(
        state_cell=state_cell,
        init_ids=init_ids,
        init_scores=init_scores,
        target_dict_dim=target_dict_dim,
        word_dim=word_dim,
        input_var_dict={},
        topk_size=topk_size,
        sparse_emb=IS_SPARSE,
        max_len=max_length,
        beam_size=beam_size,
        end_id=1,
        name=None)
    decoder.decode()
    translation_ids, translation_scores = decoder()

    return translation_ids, translation_scores

def train_main(use_cuda):
    if use_cuda and not fluid.core.is_compiled_with_cuda():
        return
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

    context = encoder()
    state_cell = decoder_state_cell(context)
    rnn_out = decoder_train(state_cell)
    label = layers.data(
        name="target_next_word", shape=[1], dtype='int64', lod_level=1)
    cost = layers.cross_entropy(input=rnn_out, label=label)
    avg_cost = layers.mean(x=cost)

    optimizer = fluid.optimizer.Adagrad(learning_rate=1e-3)
    optimizer.minimize(avg_cost)

    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.train(dict_size), buf_size=1000),
        batch_size=batch_size)
    feed_order = ['src_word', 'target_word', 'target_next_word']

    exe = Executor(place)

    def train_loop(main_program):
        exe.run(framework.default_startup_program())

        feed_list = [
            main_program.global_block().var(var_name) for var_name in feed_order
        ]
        feeder = fluid.DataFeeder(feed_list, place)

        for pass_id in xrange(1):
            for batch_id, data in enumerate(train_reader()):
                outs = exe.run(main_program,
                               feed=feeder.feed(data),
                               fetch_list=[avg_cost])
                avg_cost_val = np.array(outs[0])
                print('pass_id=' + str(pass_id) + ' batch=' + str(batch_id) +
                      " avg_cost=" + str(avg_cost_val))
                if batch_id > 3:
                    break

    train_loop(framework.default_main_program())

def decode_main(use_cuda):
    if use_cuda and not fluid.core.is_compiled_with_cuda():
        return
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

    context = encoder()
    state_cell = decoder_state_cell(context)
    translation_ids, translation_scores = decoder_decode(state_cell)

    exe = Executor(place)
    exe.run(framework.default_startup_program())

    init_ids_data = np.array([0 for _ in range(batch_size)], dtype='int64')
    init_scores_data = np.array(
        [1. for _ in range(batch_size)], dtype='float32')
    init_ids_data = init_ids_data.reshape((batch_size, 1))
    init_scores_data = init_scores_data.reshape((batch_size, 1))
    init_lod = [1] * batch_size
    init_lod = [init_lod, init_lod]

    init_ids = fluid.create_lod_tensor(init_ids_data, init_lod, place)
    init_scores = fluid.create_lod_tensor(init_scores_data, init_lod, place)

    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.train(dict_size), buf_size=1000),
        batch_size=batch_size)
    feed_order = ['src_word']
    feed_list = [
        framework.default_main_program().global_block().var(var_name)
        for var_name in feed_order
    ]
    feeder = fluid.DataFeeder(feed_list, place)

    data = train_reader().next()
    feed_dict = feeder.feed(map(lambda x: [x[0]], data))
    feed_dict['init_ids'] = init_ids
    feed_dict['init_scores'] = init_scores

    result_ids, result_scores = exe.run(
        framework.default_main_program(),
        feed=feed_dict,
        fetch_list=[translation_ids, translation_scores],
        return_numpy=False)
    print result_ids.lod()

class TestBeamSearchDecoder(unittest.TestCase):
    pass


@contextlib.contextmanager
def scope_prog_guard():
    prog = fluid.Program()
    startup_prog = fluid.Program()
    scope = fluid.core.Scope()
    with fluid.scope_guard(scope):
        with fluid.program_guard(prog, startup_prog):
            yield


def inject_test_train(use_cuda):
    f_name = 'test_{0}_train'.format('cuda' if use_cuda else 'cpu')

    def f(*args):
        with scope_prog_guard():
            train_main(use_cuda)

    setattr(TestBeamSearchDecoder, f_name, f)


def inject_test_decode(use_cuda, decorator=None):
    f_name = 'test_{0}_decode'.format('cuda' if use_cuda else 'cpu', 'sparse')

    def f(*args):
        with scope_prog_guard():
            decode_main(use_cuda)

    if decorator is not None:
        f = decorator(f)

    setattr(TestBeamSearchDecoder, f_name, f)


for _use_cuda_ in (False, True):
    inject_test_train(_use_cuda_)

for _use_cuda_ in (False, True):
    _decorator_ = None
    inject_test_decode(use_cuda=_use_cuda_, decorator=_decorator_)

if __name__ == '__main__':
    unittest.main()
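The test file above attaches its test methods to an empty `TestCase` subclass with `setattr`. The following is a minimal sketch of that pattern using only the standard library; `TestInjected`, `inject_test`, and the `_demo` method names are hypothetical and chosen for illustration, and the body of each generated test is a placeholder rather than anything PaddlePaddle runs.

```python
import unittest


class TestInjected(unittest.TestCase):
    """Empty container; test methods are attached below."""
    pass


def inject_test(device):
    # Names must start with "test" so the default loader discovers them.
    name = 'test_{0}_demo'.format(device)

    def f(self):
        # `device` is bound per call to inject_test, so each generated
        # method sees its own value rather than the last loop value.
        self.assertIn(device, ('cpu', 'cuda'))

    setattr(TestInjected, name, f)


for _device in ('cpu', 'cuda'):
    inject_test(_device)

# Load and run the generated methods explicitly.
suite = unittest.TestLoader().loadTestsFromTestCase(TestInjected)
result = unittest.TextTestRunner(verbosity=2).run(suite)
```

Building the method inside a helper function, as the commit does, matters: defining `f` directly in the `for` loop would make every generated test close over the same loop variable.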
...@@ -12,6 +12,11 @@ endif(NOT WITH_MKLDNN) ...@@ -12,6 +12,11 @@ endif(NOT WITH_MKLDNN)
if(NOT WITH_DISTRIBUTE) if(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_OPS test_recv_op) list(REMOVE_ITEM TEST_OPS test_recv_op)
list(REMOVE_ITEM TEST_OPS test_dist_transpiler)
list(REMOVE_ITEM TEST_OPS test_simple_dist_transpiler)
list(REMOVE_ITEM TEST_OPS test_listen_and_serv_op)
LIST(REMOVE_ITEM TEST_OPS test_dist_mnist)
LIST(REMOVE_ITEM TEST_OPS test_dist_word2vec)
endif(NOT WITH_DISTRIBUTE) endif(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290 list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
...@@ -47,9 +52,11 @@ foreach(TEST_OP ${TEST_OPS}) ...@@ -47,9 +52,11 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP}) py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP) endforeach(TEST_OP)
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL) py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL)
py_test_modules(test_dist_train MODULES test_dist_train SERIAL) if(WITH_DISTRIBUTE)
py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 180)
set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 180)
endif()
py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL) py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL)
py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL) py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL)
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 180)
set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 180)
...@@ -100,6 +100,8 @@ class TestBeamSearchDecodeOp(unittest.TestCase): ...@@ -100,6 +100,8 @@ class TestBeamSearchDecodeOp(unittest.TestCase):
np.array_equal(np.array(sentence_scores), expected_data)) np.array_equal(np.array(sentence_scores), expected_data))
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestBeamSearchDecodeOpGPU(TestBeamSearchDecodeOp): class TestBeamSearchDecodeOpGPU(TestBeamSearchDecodeOp):
def setUp(self): def setUp(self):
self.scope = core.Scope() self.scope = core.Scope()
......
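This commit guards GPU-only tests in two ways: a class-level `unittest.skipIf` decorator, as in the hunk above, and an early `return` inside the test body, as in other hunks. The sketch below contrasts how the standard `unittest` runner reports the two. It assumes a CPU-only build, and `is_compiled_with_cuda` here is a local stand-in for `core.is_compiled_with_cuda()`, not the real PaddlePaddle function.

```python
import unittest


def is_compiled_with_cuda():
    # Hypothetical stand-in for core.is_compiled_with_cuda();
    # pretend this is a CPU-only build.
    return False


class EarlyReturnCase(unittest.TestCase):
    def test_gpu_path(self):
        if not is_compiled_with_cuda():
            return  # reported as a pass although nothing was checked
        self.fail("GPU-only code would run here")


@unittest.skipIf(not is_compiled_with_cuda(), "core is not compiled with CUDA")
class SkipIfCase(unittest.TestCase):
    def test_gpu_path(self):
        self.fail("GPU-only code would run here")


def run(case):
    suite = unittest.TestLoader().loadTestsFromTestCase(case)
    return unittest.TextTestRunner(verbosity=0).run(suite)


early = run(EarlyReturnCase)
skipped = run(SkipIfCase)
print("early return: skipped =", len(early.skipped))
print("skipIf:       skipped =", len(skipped.skipped))
```

Both runs succeed, but only the decorated class shows up in the runner's list of skipped tests, together with the reason string.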
...@@ -191,12 +191,16 @@ class TestWithDilation(TestConv2dTransposeOp): ...@@ -191,12 +191,16 @@ class TestWithDilation(TestConv2dTransposeOp):
# ------------ test_cudnn ------------ # ------------ test_cudnn ------------
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNN(TestConv2dTransposeOp): class TestCUDNN(TestConv2dTransposeOp):
def init_op_type(self): def init_op_type(self):
self.use_cudnn = True self.use_cudnn = True
self.op_type = "conv2d_transpose" self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithPad(TestWithPad): class TestCUDNNWithPad(TestWithPad):
def init_test_case(self): def init_test_case(self):
self.pad = [1, 1] self.pad = [1, 1]
...@@ -212,6 +216,8 @@ class TestCUDNNWithPad(TestWithPad): ...@@ -212,6 +216,8 @@ class TestCUDNNWithPad(TestWithPad):
self.op_type = "conv2d_transpose" self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithStride(TestWithStride): class TestCUDNNWithStride(TestWithStride):
def init_test_case(self): def init_test_case(self):
self.pad = [1, 1] self.pad = [1, 1]
...@@ -227,6 +233,8 @@ class TestCUDNNWithStride(TestWithStride): ...@@ -227,6 +233,8 @@ class TestCUDNNWithStride(TestWithStride):
self.op_type = "conv2d_transpose" self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithGroups(TestWithGroups): class TestCUDNNWithGroups(TestWithGroups):
def init_test_case(self): def init_test_case(self):
self.pad = [1, 1] self.pad = [1, 1]
......
...@@ -197,12 +197,16 @@ class TestWithDilation(TestConv3dTransposeOp): ...@@ -197,12 +197,16 @@ class TestWithDilation(TestConv3dTransposeOp):
# ------------ test_cudnn ------------ # ------------ test_cudnn ------------
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNN(TestConv3dTransposeOp): class TestCUDNN(TestConv3dTransposeOp):
def init_op_type(self): def init_op_type(self):
self.use_cudnn = True self.use_cudnn = True
self.op_type = "conv3d_transpose" self.op_type = "conv3d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithPad(TestWithPad): class TestCUDNNWithPad(TestWithPad):
def init_test_case(self): def init_test_case(self):
self.pad = [1, 1, 1] self.pad = [1, 1, 1]
...@@ -218,6 +222,8 @@ class TestCUDNNWithPad(TestWithPad): ...@@ -218,6 +222,8 @@ class TestCUDNNWithPad(TestWithPad):
self.op_type = "conv3d_transpose" self.op_type = "conv3d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithStride(TestWithStride): class TestCUDNNWithStride(TestWithStride):
def init_test_case(self): def init_test_case(self):
self.pad = [1, 1, 1] self.pad = [1, 1, 1]
...@@ -233,6 +239,8 @@ class TestCUDNNWithStride(TestWithStride): ...@@ -233,6 +239,8 @@ class TestCUDNNWithStride(TestWithStride):
self.op_type = "conv3d_transpose" self.op_type = "conv3d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithGroups(TestWithGroups): class TestCUDNNWithGroups(TestWithGroups):
def init_test_case(self): def init_test_case(self):
self.pad = [1, 1, 1] self.pad = [1, 1, 1]
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import paddle.dataset.flowers as flowers import paddle.dataset.flowers as flowers
import math import math
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core
import unittest import unittest
import numpy as np import numpy as np
import paddle import paddle
...@@ -92,6 +93,7 @@ class TestFetchOp(unittest.TestCase): ...@@ -92,6 +93,7 @@ class TestFetchOp(unittest.TestCase):
train_inputs.append(tst_reader_iter.next()) train_inputs.append(tst_reader_iter.next())
os.environ['CPU_NUM'] = str(4) os.environ['CPU_NUM'] = str(4)
if core.is_compiled_with_cuda():
self.parallel_exe(train_inputs, seed=1, use_cuda=True) self.parallel_exe(train_inputs, seed=1, use_cuda=True)
self.parallel_exe(train_inputs, seed=1, use_cuda=False) self.parallel_exe(train_inputs, seed=1, use_cuda=False)
...@@ -137,6 +139,7 @@ class TestFeedParallel(unittest.TestCase): ...@@ -137,6 +139,7 @@ class TestFeedParallel(unittest.TestCase):
def test_feed_op(self): def test_feed_op(self):
os.environ['CPU_NUM'] = str(4) os.environ['CPU_NUM'] = str(4)
if core.is_compiled_with_cuda():
self.parallel_exe(use_cuda=True, seed=1) self.parallel_exe(use_cuda=True, seed=1)
self.parallel_exe(use_cuda=False, seed=1) self.parallel_exe(use_cuda=False, seed=1)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from parallel_executor_test_base import TestParallelExecutorBase from parallel_executor_test_base import TestParallelExecutorBase
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np import numpy as np
import paddle import paddle
import paddle.dataset.mnist as mnist import paddle.dataset.mnist as mnist
...@@ -98,6 +99,8 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -98,6 +99,8 @@ class TestMNIST(TestParallelExecutorBase):
MNIST_RECORDIO_FILE, reader, feeder) MNIST_RECORDIO_FILE, reader, feeder)
def check_simple_fc_convergence(self, use_cuda, use_reduce=False): def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
if use_cuda and not core.is_compiled_with_cuda():
return
self.check_network_convergence(simple_fc_net, use_cuda=use_cuda) self.check_network_convergence(simple_fc_net, use_cuda=use_cuda)
self.check_network_convergence( self.check_network_convergence(
simple_fc_net, use_cuda=use_cuda, allow_op_delay=True) simple_fc_net, use_cuda=use_cuda, allow_op_delay=True)
...@@ -122,6 +125,8 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -122,6 +125,8 @@ class TestMNIST(TestParallelExecutorBase):
self.check_simple_fc_convergence(False, True) self.check_simple_fc_convergence(False, True)
def check_simple_fc_parallel_accuracy(self, use_cuda, use_reduce=False): def check_simple_fc_parallel_accuracy(self, use_cuda, use_reduce=False):
if use_cuda and not core.is_compiled_with_cuda():
return
img = np.zeros(shape=[32, 784], dtype='float32') img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64') label = np.ones(shape=[32, 1], dtype='int64')
single_first_loss, single_last_loss = self.check_network_convergence( single_first_loss, single_last_loss = self.check_network_convergence(
...@@ -155,6 +160,8 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -155,6 +160,8 @@ class TestMNIST(TestParallelExecutorBase):
self.check_simple_fc_parallel_accuracy(False, True) self.check_simple_fc_parallel_accuracy(False, True)
def check_batchnorm_fc_convergence(self, use_cuda, use_reduce=False): def check_batchnorm_fc_convergence(self, use_cuda, use_reduce=False):
if use_cuda and not core.is_compiled_with_cuda():
return
self.check_network_convergence(fc_with_batchnorm, use_cuda=use_cuda) self.check_network_convergence(fc_with_batchnorm, use_cuda=use_cuda)
img = np.zeros(shape=[32, 784], dtype='float32') img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64') label = np.ones(shape=[32, 1], dtype='int64')
......
...@@ -16,6 +16,7 @@ import paddle.fluid as fluid ...@@ -16,6 +16,7 @@ import paddle.fluid as fluid
import paddle.fluid.layers.ops as ops import paddle.fluid.layers.ops as ops
from paddle.fluid.initializer import init_on_cpu from paddle.fluid.initializer import init_on_cpu
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
import paddle.fluid.core as core
from parallel_executor_test_base import TestParallelExecutorBase from parallel_executor_test_base import TestParallelExecutorBase
import unittest import unittest
import math import math
...@@ -140,6 +141,9 @@ class TestResnet(TestParallelExecutorBase): ...@@ -140,6 +141,9 @@ class TestResnet(TestParallelExecutorBase):
use_reduce=False, use_reduce=False,
iter=20): iter=20):
if use_cuda and not core.is_compiled_with_cuda():
return
os.environ['CPU_NUM'] = str(4) os.environ['CPU_NUM'] = str(4)
def _cosine_decay(learning_rate, step_each_epoch, epochs=120): def _cosine_decay(learning_rate, step_each_epoch, epochs=120):
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np import numpy as np
import unittest import unittest
import os import os
...@@ -92,6 +93,7 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase): ...@@ -92,6 +93,7 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
def test_parallel_testing(self): def test_parallel_testing(self):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
if core.is_compiled_with_cuda():
self.check_network_convergence( self.check_network_convergence(
use_cuda=True, build_strategy=build_strategy) use_cuda=True, build_strategy=build_strategy)
self.check_network_convergence( self.check_network_convergence(
...@@ -100,6 +102,7 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase): ...@@ -100,6 +102,7 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
def test_parallel_testing_with_new_strategy(self): def test_parallel_testing_with_new_strategy(self):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
if core.is_compiled_with_cuda():
self.check_network_convergence( self.check_network_convergence(
use_cuda=True, build_strategy=build_strategy) use_cuda=True, build_strategy=build_strategy)
self.check_network_convergence( self.check_network_convergence(
......
...@@ -56,6 +56,8 @@ class TestPrintOpCPU(unittest.TestCase): ...@@ -56,6 +56,8 @@ class TestPrintOpCPU(unittest.TestCase):
return_numpy=False) return_numpy=False)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestPrintOpGPU(TestPrintOpCPU): class TestPrintOpGPU(TestPrintOpCPU):
def setUp(self): def setUp(self):
self.place = core.CUDAPlace(0) self.place = core.CUDAPlace(0)
......
...@@ -79,12 +79,18 @@ class TestProfiler(unittest.TestCase): ...@@ -79,12 +79,18 @@ class TestProfiler(unittest.TestCase):
pass_acc_calculator.add(value=acc, weight=b_size) pass_acc_calculator.add(value=acc, weight=b_size)
pass_acc = pass_acc_calculator.eval() pass_acc = pass_acc_calculator.eval()
@unittest.skipIf(not core.is_compiled_with_cuda(),
"profiler is enabled only with GPU")
def test_cpu_profiler(self): def test_cpu_profiler(self):
self.net_profiler('CPU') self.net_profiler('CPU')
@unittest.skipIf(not core.is_compiled_with_cuda(),
"profiler is enabled only with GPU")
def test_cuda_profiler(self): def test_cuda_profiler(self):
self.net_profiler('GPU') self.net_profiler('GPU')
@unittest.skipIf(not core.is_compiled_with_cuda(),
"profiler is enabled only with GPU")
def test_all_profiler(self): def test_all_profiler(self):
self.net_profiler('All', '/tmp/profile_out') self.net_profiler('All', '/tmp/profile_out')
with open('/tmp/profile_out', 'r') as f: with open('/tmp/profile_out', 'r') as f:
......
...@@ -61,6 +61,8 @@ class TestSequenceSoftmaxOp(OpTest): ...@@ -61,6 +61,8 @@ class TestSequenceSoftmaxOp(OpTest):
# ----------------cudnn Sequencesoftmax---------------- # ----------------cudnn Sequencesoftmax----------------
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSequenceSoftmaxCUDNNOp(TestSequenceSoftmaxOp): class TestSequenceSoftmaxCUDNNOp(TestSequenceSoftmaxOp):
def init_op_type(self): def init_op_type(self):
self.use_cudnn = True self.use_cudnn = True
......
...@@ -63,11 +63,15 @@ class TestSoftmaxOp(OpTest): ...@@ -63,11 +63,15 @@ class TestSoftmaxOp(OpTest):
self.check_grad(["X"], "Out", max_relative_error=0.01) self.check_grad(["X"], "Out", max_relative_error=0.01)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxCUDNNOp(TestSoftmaxOp): class TestSoftmaxCUDNNOp(TestSoftmaxOp):
def init_kernel_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxFP16Op(TestSoftmaxOp): class TestSoftmaxFP16Op(TestSoftmaxOp):
def init_kernel_type(self): def init_kernel_type(self):
self.dtype = np.float16 self.dtype = np.float16
...@@ -79,6 +83,8 @@ class TestSoftmaxFP16Op(TestSoftmaxOp): ...@@ -79,6 +83,8 @@ class TestSoftmaxFP16Op(TestSoftmaxOp):
self.check_output_with_place(place, atol=1e-3) self.check_output_with_place(place, atol=1e-3)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxFP16CUDNNOp(TestSoftmaxOp): class TestSoftmaxFP16CUDNNOp(TestSoftmaxOp):
def init_kernel_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
......
...@@ -68,8 +68,14 @@ def reader_creator(image_filename, label_filename, buffer_size): ...@@ -68,8 +68,14 @@ def reader_creator(image_filename, label_filename, buffer_size):
for i in xrange(buffer_size): for i in xrange(buffer_size):
yield images[i, :], int(labels[i]) yield images[i, :], int(labels[i])
finally: finally:
try:
m.terminate() m.terminate()
except:
pass
try:
l.terminate() l.terminate()
except:
pass
return reader return reader
......
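The hunk above wraps each `terminate()` call in a bare `try`/`except`. A narrower variant is sketched below, assuming `m` and `l` are `subprocess.Popen` objects (the hunk does not show where they are created). `terminate_quietly` is a hypothetical helper name, not something defined in the commit.

```python
import subprocess
import sys


def terminate_quietly(proc):
    """Terminate a child process, ignoring the error if it is already gone."""
    try:
        proc.terminate()
    except OSError:
        # Signalling a process that has already exited can raise OSError
        # (for example on Python 2); other exceptions still propagate.
        pass


# Start a short-lived child and let it finish before terminating it.
child = subprocess.Popen([sys.executable, '-c', 'pass'])
child.wait()
terminate_quietly(child)  # must not raise
print(child.returncode)
```

Catching only `OSError` keeps the cleanup tolerant of an already-finished child without also swallowing unrelated errors such as `KeyboardInterrupt`, which a bare `except:` would hide.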
...@@ -104,6 +104,8 @@ packages=['paddle', ...@@ -104,6 +104,8 @@ packages=['paddle',
'paddle.fluid.proto', 'paddle.fluid.proto',
'paddle.fluid.proto.profiler', 'paddle.fluid.proto.profiler',
'paddle.fluid.layers', 'paddle.fluid.layers',
'paddle.fluid.contrib',
'paddle.fluid.contrib.decoder',
'paddle.fluid.transpiler', 'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details'] 'paddle.fluid.transpiler.details']
......