未验证 提交 7268760f 编写于 作者: Y yuyang18

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into...

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/combine_open_files_and_double_buffer
......@@ -18,7 +18,21 @@ learning to many products at Baidu.
Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
### Lastest PaddlePaddle Version: [Fluid](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid)
### Latest PaddlePaddle Release: [Fluid 0.14.0](https://github.com/PaddlePaddle/Paddle/tree/v0.14.0)
### Install Latest Stable Release:
```
# Linux CPU
pip install paddlepaddle
# Linux GPU cuda9cudnn7
pip install paddlepaddle-gpu
# Linux GPU cuda8cudnn7
pip install paddlepaddle-gpu==0.14.0.post87
# Linux GPU cuda8cudnn5
pip install paddlepaddle-gpu==0.14.0.post85
# For installation on other platform, refer to http://paddlepaddle.org/
```
## Features
......
## Motivation
There is a ```gap``` between the ```Program``` defined by
user and the ```Executable``` that can be scheduled
efficiently on heterogeneous hardware, either locally
or distributedly.
Usually, the ```gap``` is bridged by
* A serious transformations with defined order.
* These transformations usually involve
```insert, delete, clustering, split, dependency analysis```.
* Has a simple way to verify and debug each transformation.
* Flexible to add, remove or customize transformations to fit
the requirements of various algorithms (models) and hardware secenarios.
Some other events also push us to a better unified pattern.
* The deep learning framework is built around the concepts of graphs.
To leverage tools such as compilation (e.g. TVM and nGraph) or
cross-framework conversion (e.g. ONNX), we also need a intermediate
representation that can be connected to the rest of the ecosystem.
We need a unified pattern to naturally support the requirements
described above. The pattern should fit both training, inference
and other offline serielized model transformations.
Learned from LLVM and other deep learning framework, we draft the
design below.
## Design
### Major Concepts
#### Node
```Node``` represents an operation that performs some computation or
a variable that is input or output of operation.
```Node```s are connected to other ```Node```s via inputs and outputs.
Other properties (maybe device placement information) can be added
to ```Node``` in the future if it's a
common requirement of many other ```Pass```es. Otherwise, it should live
in a ```Node``` wrapper class that is private to some ```Pass``` or be
a local member of a ```Pass```.
#### Graph
```Graph``` contains a list of ```Node```s, which are connected to
each other via inputs and outputs.
TODO: Better definitions for the graph.
```Graph``` can also contain ```Attribute```s. ```Attribute```s
can be ``any`` thing. For example, it can be a list of "wraper"
nodes. The ```wrapper``` nodes compose ```Node```s and provide
helper method for execution or transformation. ```Attribute```
can also contain other things that describe some properties of
the ```Graph``` or ```Graph``` nodes. ```Attribute``` can be passed
across ```Pass```. However, it should be used with care.
#### Pass
```Pass``` represents a transformation of ```Graph```. Its input
is a ```Graph``` and its output is also a ```Graph```. For example,
a ```Pass``` can simply print out the ```Graph```. A ```Pass```
can also fuse some ```Graph```'s ```Node```s.
#### Optimize
```Optimize``` contains a series of ```Pass``` with defined order.
```Optimize``` transforms a ```Graph``` that only contains raw
modeling logic to a ```Graph``` that can be run efficiently while
maintaining the original modeling logic.
### Optimize Process
* Program is first converted to Graph.
* Graph goes through a series of Pass
* Graph is transformed from raw model logic to a
form that is efficient to execute.
Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
......@@ -341,6 +341,26 @@ paddle.fluid.layers.polynomial_decay ArgSpec(args=['learning_rate', 'decay_steps
paddle.fluid.layers.piecewise_decay ArgSpec(args=['boundaries', 'values'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.noam_decay ArgSpec(args=['d_model', 'warmup_steps'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.append_LARS ArgSpec(args=['params_grads', 'learning_rate', 'weight_decay'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.InitState.__init__ ArgSpec(args=['self', 'init', 'shape', 'value', 'init_boot', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, None, False, 'float32'))
paddle.fluid.contrib.StateCell.__init__ ArgSpec(args=['self', 'inputs', 'states', 'out_state', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.StateCell.compute_state ArgSpec(args=['self', 'inputs'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.StateCell.get_input ArgSpec(args=['self', 'input_name'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.StateCell.get_state ArgSpec(args=['self', 'state_name'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.StateCell.out_state ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.StateCell.set_state ArgSpec(args=['self', 'state_name', 'state_value'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.StateCell.state_updater ArgSpec(args=['self', 'updater'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.StateCell.update_states ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.TrainingDecoder.__init__ ArgSpec(args=['self', 'state_cell', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.TrainingDecoder.block ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
paddle.fluid.contrib.TrainingDecoder.output ArgSpec(args=['self'], varargs='outputs', keywords=None, defaults=None)
paddle.fluid.contrib.TrainingDecoder.static_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.TrainingDecoder.step_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.BeamSearchDecoder.__init__ ArgSpec(args=['self', 'state_cell', 'init_ids', 'init_scores', 'target_dict_dim', 'word_dim', 'input_var_dict', 'topk_size', 'sparse_emb', 'max_len', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=({}, 50, True, 100, 1, 1, None))
paddle.fluid.contrib.BeamSearchDecoder.block ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
paddle.fluid.contrib.BeamSearchDecoder.decode ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.BeamSearchDecoder.early_stop ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.BeamSearchDecoder.read_array ArgSpec(args=['self', 'init', 'is_ids', 'is_scores'], varargs=None, keywords=None, defaults=(False, False))
paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array', 'value'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.transpiler.DistributeTranspiler.create_splited_vars ArgSpec(args=['self', 'source_var', 'block', 'tag'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
......
add_subdirectory(details)
add_subdirectory(ir)
# ddim lib
proto_library(framework_proto SRCS framework.proto)
......@@ -93,7 +94,7 @@ else()
endif()
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
......@@ -5,8 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph)
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
......@@ -35,7 +34,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS
cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)
......
......@@ -23,10 +23,14 @@ namespace framework {
namespace details {
#ifdef PADDLE_WITH_CUDA
AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(ctxs) {
if (nccl_ctxs_) {
for (auto &p : places_) {
this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
......@@ -34,9 +38,10 @@ AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
}
}
#else
AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
void AllReduceOpHandle::RunImpl() {
......
......@@ -30,11 +30,11 @@ namespace details {
struct AllReduceOpHandle : public OpHandleBase {
#ifdef PADDLE_WITH_CUDA
AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs);
#else
AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
#endif
std::string Name() const override;
......
......@@ -35,10 +35,13 @@ namespace details {
struct BroadcastOpHandle : public OpHandleBase {
public:
#ifdef PADDLE_WITH_CUDA
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
BroadcastOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *nccl_ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) {
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(nccl_ctxs) {
if (nccl_ctxs_) {
for (auto &p_ctx : nccl_ctxs_->contexts_) {
dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
......@@ -46,9 +49,9 @@ struct BroadcastOpHandle : public OpHandleBase {
}
}
#else
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
BroadcastOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
std::string Name() const override;
......
......@@ -96,48 +96,61 @@ struct TestBroadcastOpHandle {
}
param_scopes_[input_scope_idx]->Var("input");
std::unique_ptr<ir::Node> n(
new ir::Node("node0", ir::Node::Type::kOperation));
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(
new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
nccl_ctxs_.get()));
#else
PADDLE_THROW("CUDA is not support.");
#endif
} else {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(
new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
nccl_ctxs_.get()));
#else
op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
op_handle_.reset(
new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_));
#endif
}
auto* in_var_handle =
new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
std::unique_ptr<ir::Node> v(
new ir::Node("node1", ir::Node::Type::kVariable));
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
gpu_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
// add dummy var
vars_.emplace_back(new DummyVarHandle());
std::unique_ptr<ir::Node> v2(
new ir::Node("node2", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(v2.get()));
DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
dummy_var_handle->generated_op_ = nullptr;
dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(dummy_var_handle);
for (size_t j = 0; j < gpu_list_.size(); ++j) {
if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
}
VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
std::unique_ptr<ir::Node> v3(
new ir::Node("node3", ir::Node::Type::kVariable));
VarHandle* out_var_handle =
new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
}
// add dummy var
vars_.emplace_back(new DummyVarHandle());
std::unique_ptr<ir::Node> v4(
new ir::Node("node4", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(v4.get()));
DummyVarHandle* out_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
out_dummy_var_handle->generated_op_ = nullptr;
out_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddOutput(out_dummy_var_handle);
}
......
......@@ -19,9 +19,10 @@
namespace paddle {
namespace framework {
namespace details {
ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
platform::Place place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
: OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(*node->Op())),
scope_(scope),
place_(place) {}
......@@ -35,8 +36,8 @@ void ComputationOpHandle::RunImpl() {
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
bool need_wait =
in_var && in_var->generated_op_ &&
in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_];
in_var && in_var->GeneratedOp() &&
in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_[place_];
return need_wait;
}
......
......@@ -28,8 +28,7 @@ namespace framework {
namespace details {
struct ComputationOpHandle : public OpHandleBase {
public:
ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
platform::Place place);
ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place);
std::string Name() const override;
......
......@@ -22,10 +22,10 @@ namespace details {
#ifdef PADDLE_WITH_CUDA
DataBalanceOpHandle::DataBalanceOpHandle(
const std::vector<Scope *> &local_scopes,
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs)
: local_scopes_(local_scopes), places_(places) {
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {
if (ctxs) {
for (auto &p : places_) {
this->dev_ctxes_[p] = ctxs->DevCtx(p);
......@@ -34,9 +34,9 @@ DataBalanceOpHandle::DataBalanceOpHandle(
}
#else
DataBalanceOpHandle::DataBalanceOpHandle(
const std::vector<Scope *> &local_scopes,
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
std::string DataBalanceOpHandle::Name() const { return "data balance"; }
......
......@@ -30,11 +30,11 @@ namespace details {
struct DataBalanceOpHandle : public OpHandleBase {
public:
#ifdef PADDLE_WITH_CUDA
DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs);
#else
DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
#endif
......
......@@ -21,13 +21,16 @@ namespace paddle {
namespace framework {
namespace details {
FetchOpHandle::FetchOpHandle(FeedFetchList *data, size_t offset,
FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes)
: data_(data), offset_(offset), local_scopes_(local_scopes) {}
: OpHandleBase(node),
data_(data),
offset_(offset),
local_scopes_(local_scopes) {}
FetchOpHandle::~FetchOpHandle() {
for (auto *input_var : inputs_) {
input_var->pending_ops_.erase(this);
input_var->RemoveOutput(this, this->Node());
}
}
......@@ -77,8 +80,8 @@ void FetchOpHandle::RunImpl() {
void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) {
auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place);
for (auto *input : inputs_) {
if (input->generated_op_) {
input->generated_op_->RecordWaitEventOnCtx(cpu_ctx);
if (input->GeneratedOp()) {
input->GeneratedOp()->RecordWaitEventOnCtx(cpu_ctx);
}
}
}
......
......@@ -28,7 +28,7 @@ namespace details {
struct FetchOpHandle : public OpHandleBase {
public:
FetchOpHandle(FeedFetchList *data, size_t offset,
FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes);
~FetchOpHandle();
......
......@@ -30,10 +30,12 @@ namespace details {
struct FuseVarsOpHandle : public OpHandleBase {
public:
FuseVarsOpHandle(Scope *local_scope, const platform::Place &place,
FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
const platform::Place &place,
const std::unordered_map<std::string, int64_t> &inputs_numel,
const std::type_index &var_type)
: local_scope_(local_scope),
: OpHandleBase(node),
local_scope_(local_scope),
place_(place),
inputs_numel_(inputs_numel),
type_(var_type) {
......
......@@ -20,9 +20,10 @@ namespace paddle {
namespace framework {
namespace details {
GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
GatherOpHandle::GatherOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
void GatherOpHandle::RunImpl() {
if (places_.size() == 1) return;
......
......@@ -30,7 +30,7 @@ namespace details {
struct GatherOpHandle : public OpHandleBase {
public:
GatherOpHandle(const std::vector<Scope *> &local_scopes,
GatherOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
std::string Name() const override;
......
......@@ -70,6 +70,7 @@ struct TestGatherOpHandle {
}
void InitGatherOp(size_t input_scope_idx) {
std::vector<std::unique_ptr<ir::Node>> nodes;
for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope();
......@@ -81,30 +82,37 @@ struct TestGatherOpHandle {
}
param_scopes_[input_scope_idx]->Var("out");
op_handle_.reset(new GatherOpHandle(local_scopes_, gpu_list_));
nodes.emplace_back(new ir::Node("node", ir::Node::Type::kOperation));
op_handle_.reset(
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
// add input
for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
auto* in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
nodes.emplace_back(new ir::Node("node1", ir::Node::Type::kVariable));
auto* in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
}
// add dummy var
vars_.emplace_back(new DummyVarHandle());
nodes.emplace_back(new ir::Node("node2", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* in_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
in_dummy_var_handle->generated_op_ = nullptr;
in_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(in_dummy_var_handle);
// add output
auto* out_var_handle =
new VarHandle(2, input_scope_idx, "out", gpu_list_[input_scope_idx]);
nodes.emplace_back(new ir::Node("node3", ir::Node::Type::kVariable));
auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
"out", gpu_list_[input_scope_idx]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
// add dummy var
vars_.emplace_back(new DummyVarHandle());
nodes.emplace_back(new ir::Node("node4", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
op_handle_->AddOutput(dummy_var_handle);
......
......@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace platform {
......@@ -45,13 +46,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<Scope *> &local_scopes,
const BuildStrategy &strategy);
#endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override;
int GetVarDeviceID(const std::string &varname) const override;
private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
size_t device_id) const;
void CreateOpHandleIOs(Graph *result, ir::Node *node, size_t device_id) const;
private:
std::string loss_var_name_;
......@@ -63,48 +62,46 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
platform::NCCLContextMap *nccl_ctxs_;
#endif
bool IsScaleLossOp(const OpDesc &op) const;
bool IsScaleLossOp(ir::Node *node) const;
void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
void CreateRPCOp(Graph *result, ir::Node *node) const;
void CreateDistTrainOp(Graph *result, ir::Node *node) const;
/**
* Is this operator as the end-point operator before/after send operator.
*/
bool IsDistTrainOp(const OpDesc &op,
const std::vector<std::string> &send_vars,
bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const;
std::vector<std::string> FindDistTrainSendVars(
const ProgramDesc &program) const;
const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
std::vector<std::string> FindDistTrainRecvVars(
const ProgramDesc &program) const;
const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
void ConnectOp(SSAGraph *result, OpHandleBase *op,
void ConnectOp(Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const;
void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
void CreateComputationalOps(Graph *result, ir::Node *node,
size_t num_places) const;
void CreateScaleLossGradOp(SSAGraph *result) const;
VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og,
void CreateScaleLossGradOp(Graph *result) const;
VarHandle *CreateReduceOp(Graph *result, const std::string &og,
int dst_dev_id) const;
void CreateComputationalOp(SSAGraph *result, const OpDesc &op,
int dev_id) const;
void CreateComputationalOp(Graph *result, ir::Node *node, int dev_id) const;
bool IsParameterGradientOnce(
const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const;
int GetOpDeviceID(const OpDesc &op) const;
int GetOpDeviceID(ir::Node *node) const;
void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
void InsertAllReduceOp(Graph *result, const std::string &og) const;
void InsertDataBalanceOp(SSAGraph *result,
void InsertDataBalanceOp(Graph *result,
const std::vector<std::string> &datas) const;
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
void CreateBroadcastOp(Graph *result, const std::string &p_name,
size_t src_dev_id) const;
bool IsSparseGradient(const std::string &og) const;
......
......@@ -80,19 +80,21 @@ void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
void OpHandleBase::AddInput(VarHandleBase *in) {
this->inputs_.emplace_back(in);
in->pending_ops_.insert(this);
node_->inputs.push_back(in->Node());
in->AddOutput(this, this->Node());
}
void OpHandleBase::AddOutput(VarHandleBase *out) {
outputs_.emplace_back(out);
out->generated_op_ = this;
node_->outputs.push_back(out->Node());
out->AddInput(this, this->Node());
}
void OpHandleBase::WaitInputVarGenerated() {
for (auto in_var : inputs_) {
if (NeedWait(in_var)) {
for (auto &pair : dev_ctxes_) {
in_var->generated_op_->RecordWaitEventOnCtx(pair.second);
in_var->GeneratedOp()->RecordWaitEventOnCtx(pair.second);
}
}
}
......@@ -101,7 +103,7 @@ void OpHandleBase::WaitInputVarGenerated() {
void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
for (auto *in : inputs_) {
if (NeedWait(in)) {
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]);
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[place]);
}
}
}
......@@ -117,7 +119,7 @@ size_t OpHandleBase::NoDummyInputSize() const {
}
bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
return in_var && in_var->generated_op_;
return in_var && in_var->GeneratedOp();
}
void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
......
......@@ -17,6 +17,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/macros.h"
......@@ -26,9 +27,11 @@ namespace details {
constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
// Wraps ir::Node and provide helper utilities.
// It's responsible for populating necessary fields of ir::Node.
class OpHandleBase {
public:
OpHandleBase() {}
explicit OpHandleBase(ir::Node *node) : node_(node) {}
virtual ~OpHandleBase();
......@@ -82,6 +85,8 @@ class OpHandleBase {
size_t NoDummyInputSize() const;
ir::Node *Node() { return node_; }
protected:
void RunAndRecordEvent(const std::function<void()> &callback);
......@@ -90,6 +95,7 @@ class OpHandleBase {
virtual void RunImpl() = 0;
ir::Node *node_;
std::vector<VarHandleBase *> inputs_;
std::vector<VarHandleBase *> outputs_;
std::map<platform::Place, platform::DeviceContext *> dev_ctxes_;
......
......@@ -37,10 +37,13 @@ struct ReduceOpHandle : public OpHandleBase {
#ifdef PADDLE_WITH_CUDA
const platform::NCCLContextMap *nccl_ctxs_;
ReduceOpHandle(const std::vector<Scope *> &local_scopes,
ReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *nccl_ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) {
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(nccl_ctxs) {
if (nccl_ctxs_) {
for (auto &p_ctx : nccl_ctxs_->contexts_) {
dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
......@@ -48,9 +51,9 @@ struct ReduceOpHandle : public OpHandleBase {
}
}
#else
ReduceOpHandle(const std::vector<Scope *> &local_scopes,
ReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
std::string Name() const override;
......
......@@ -84,6 +84,7 @@ struct TestReduceOpHandle {
}
void InitReduceOp(size_t out_scope_idx) {
std::vector<std::unique_ptr<ir::Node>> nodes;
// init scope
for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
......@@ -96,19 +97,21 @@ struct TestReduceOpHandle {
}
param_scopes_[out_scope_idx]->Var("out");
nodes.emplace_back(new ir::Node("node"));
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(
new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_,
gpu_list_, nccl_ctxs_.get()));
#else
PADDLE_THROW("CUDA is not support.");
#endif
} else {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(
new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_,
gpu_list_, nccl_ctxs_.get()));
#else
op_handle_.reset(new ReduceOpHandle(local_scopes_, gpu_list_));
op_handle_.reset(
new ReduceOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
#endif
}
......@@ -118,8 +121,10 @@ struct TestReduceOpHandle {
if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
}
auto *in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
in_var_handle->generated_op_ = nullptr;
nodes.emplace_back(new ir::Node("node1"));
auto *in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
in_var_handle->ClearGeneratedOp();
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
}
......@@ -128,12 +133,13 @@ struct TestReduceOpHandle {
vars_.emplace_back(new DummyVarHandle());
DummyVarHandle *in_dummy_var_handle =
static_cast<DummyVarHandle *>(vars_.back().get());
in_dummy_var_handle->generated_op_ = nullptr;
in_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(in_dummy_var_handle);
// add output
auto *out_var_handle =
new VarHandle(2, out_scope_idx, "out", gpu_list_[out_scope_idx]);
nodes.emplace_back(new ir::Node("node2"));
auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx,
"out", gpu_list_[out_scope_idx]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
......
......@@ -18,10 +18,11 @@ namespace paddle {
namespace framework {
namespace details {
RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
const Scope *local_scope, const std::string &name,
const platform::Place &place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
: OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope),
name_(name),
place_(place) {}
......@@ -35,8 +36,8 @@ void RPCOpHandle::RunImpl() {
if (in->DebugString() == "dummy") { // HACK
continue;
}
if (in->generated_op_) {
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]);
if (in->GeneratedOp()) {
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[p]);
}
}
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
......
......@@ -28,8 +28,9 @@ namespace framework {
namespace details {
struct RPCOpHandle : public OpHandleBase {
RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const std::string& name, const platform::Place& place);
RPCOpHandle(ir::Node* node, const framework::OpDesc& op_desc,
const Scope* local_scope, const std::string& name,
const platform::Place& place);
std::string Name() const override;
......
......@@ -19,10 +19,14 @@
namespace paddle {
namespace framework {
namespace details {
ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev,
Scope *scope,
platform::Place place,
platform::DeviceContext *dev_ctx)
: coeff_(static_cast<float>(1.0 / num_dev)), scope_(scope), place_(place) {
: OpHandleBase(node),
coeff_(static_cast<float>(1.0 / num_dev)),
scope_(scope),
place_(place) {
dev_ctxes_[place_] = dev_ctx;
}
......
......@@ -25,7 +25,8 @@ namespace framework {
namespace details {
struct ScaleLossGradOpHandle : public OpHandleBase {
ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place,
ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope,
platform::Place place,
platform::DeviceContext *context);
~ScaleLossGradOpHandle() final;
......
......@@ -17,6 +17,9 @@
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/scope.h"
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph.h"
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
namespace paddle {
namespace framework {
namespace details {
// A SSA graph used by parallel executor.
struct SSAGraph {
// all variable in each devices.
// The outside vector is the device vector. Each element of this vector is a
// map from variable name to variables. The variables, who have the same name,
// will have a different version. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles.
std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
vars_;
// aux variables to represent dependency. Useful to resolve data hazard.
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
std::vector<std::unique_ptr<OpHandleBase>> ops_;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -17,8 +17,8 @@
namespace paddle {
namespace framework {
namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
for (auto &var_map : graph->vars_) {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) {
continue;
......@@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
auto it_old = name_pair.second.rbegin();
++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
auto *write_op = (*it_new)->generated_op_;
auto &read_ops = (*it_old)->pending_ops_;
OpHandleBase *write_op = (*it_new)->GeneratedOp();
const auto &read_ops = (*it_old)->PendingOps();
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
......@@ -37,10 +37,11 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
continue;
}
auto *dep_var = new DummyVarHandle();
auto *dep_var = new DummyVarHandle(
graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
graph->dep_vars_.emplace(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
}
}
}
......@@ -48,13 +49,20 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
}
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
SSAGraph *graph, const std::string &each_var_name,
const platform::Place &place, size_t place_offset) {
auto &var_holders = graph->vars_[place_offset];
auto &var_holder = var_holders[each_var_name];
Graph *graph, ir::Node *node, const platform::Place &place,
size_t place_offset) {
auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
auto &var_holder = var_holders[node->Name()];
VarHandle *var = nullptr;
if (var_holder.empty()) {
var = new VarHandle(0, place_offset, each_var_name, place);
if (node->Var()) {
var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset,
node->Name(), place);
} else {
var = new VarHandle(
graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable), 0,
place_offset, node->Name(), place);
}
var_holder.emplace_back(var);
} else {
var = var_holder.rbegin()->get();
......@@ -62,24 +70,26 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
return var;
}
void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
const std::string &each_var_name,
void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
ir::Node *new_node,
const platform::Place &place,
size_t place_offset) {
auto &vars = graph->vars_[place_offset][each_var_name];
auto &vars = graph->Get<GraphVars>("vars")[place_offset][new_node->Name()];
size_t version = vars.size();
auto var = new VarHandle(version, place_offset, each_var_name, place);
auto var =
new VarHandle(new_node, version, place_offset, new_node->Name(), place);
vars.emplace_back(var);
op_handle->AddOutput(var);
}
void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
for (auto &op : graph->ops_) {
void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
for (auto &op : graph->Get<GraphOps>("ops")) {
if (!op->Outputs().empty()) {
continue;
}
auto *dummy_leaf = new DummyVarHandle();
graph->dep_vars_.emplace(dummy_leaf);
auto *dummy_leaf = new DummyVarHandle(
graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
op->AddOutput(dummy_leaf);
}
}
......
......@@ -16,20 +16,42 @@
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class SSAGraphBuilder {
// all variable in each devices.
// The outside vector is the device vector. Each element of this vector is a
// map from variable name to variables. The variables, who have the same name,
// will have a differsent version. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles.
typedef std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars;
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
class SSAGraphBuilder : public ir::Pass {
public:
SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......@@ -42,20 +64,19 @@ class SSAGraphBuilder {
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
static void PolishGraphToSupportDataHazards(SSAGraph *graph);
static void PolishGraphToSupportDataHazards(Graph *graph);
static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph,
const std::string &each_var_name,
static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, ir::Node *node,
const platform::Place &place,
size_t place_offset);
// Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
const std::string &each_var_name,
const platform::Place &place, size_t place_offset);
static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
ir::Node *new_node, const platform::Place &place,
size_t place_offset);
static void AddOutputToLeafOps(SSAGraph *graph);
static void AddOutputToLeafOps(Graph *graph);
};
} // namespace details
} // namespace framework
......
......@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph.h"
#include <string>
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars;
......@@ -28,12 +28,12 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
auto insert_pending_var = [&](VarHandleBase *var) {
pending_vars.insert(var);
if (var->generated_op_ == nullptr) {
if (var->GeneratedOp() == nullptr) {
ready_vars.emplace(var);
}
};
for (auto &var_map : graph->vars_) {
for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get());
......@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
}
}
for (auto &var : graph->dep_vars_) {
for (auto &var : graph->Get<GraphDepVars>("dep_vars")) {
insert_pending_var(var.get());
}
for (auto &op : graph->ops_) {
for (auto &op : graph->Get<GraphOps>("ops")) {
if (op->Inputs().empty()) {
ready_ops.insert(op.get());
} else {
......@@ -71,7 +71,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
for (auto ready_var : ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->pending_ops_) {
for (auto *op : ready_var->PendingOps()) {
auto &deps = --pending_ops[op];
if (deps == 0) {
ready_ops.insert(op);
......
......@@ -21,7 +21,6 @@
namespace paddle {
namespace framework {
namespace details {
struct SSAGraph;
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public:
......@@ -29,17 +28,17 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {}
std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override {
auto graph = builder_->Build(program);
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph));
PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
return std::move(new_graph);
}
int GetVarDeviceID(const std::string& var_name) const override {
return builder_->GetVarDeviceID(var_name);
}
bool IsValidGraph(const SSAGraph* graph) const;
bool IsValidGraph(const Graph* graph) const;
private:
std::unique_ptr<SSAGraphBuilder> builder_;
......
......@@ -18,8 +18,8 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
......
......@@ -14,15 +14,15 @@
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include <string>
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
template <typename Callback>
static inline void IterAllVar(const SSAGraph &graph, Callback callback) {
for (auto &each : graph.vars_) {
static inline void IterAllVar(const Graph &graph, Callback callback) {
for (auto &each : graph.Get<GraphVars>("vars")) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(*pair2);
......@@ -30,12 +30,12 @@ static inline void IterAllVar(const SSAGraph &graph, Callback callback) {
}
}
for (auto &var : graph.dep_vars_) {
for (auto &var : graph.Get<GraphDepVars>("dep_vars")) {
callback(*var);
}
}
void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph,
void GraphvizSSAGraphPrinter::Print(const Graph &graph,
std::ostream &sout) const {
size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars;
......@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph,
});
size_t op_id = 0;
for (auto &op : graph.ops_) {
for (auto &op : graph.Get<GraphOps>("ops")) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
......
......@@ -21,16 +21,16 @@
namespace paddle {
namespace framework {
namespace details {
struct SSAGraph;
class SSAGraphPrinter {
public:
virtual ~SSAGraphPrinter() {}
virtual void Print(const SSAGraph& graph, std::ostream& sout) const = 0;
virtual void Print(const Graph& graph, std::ostream& sout) const = 0;
};
class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
public:
void Print(const SSAGraph& graph, std::ostream& sout) const override;
void Print(const Graph& graph, std::ostream& sout) const override;
};
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
......@@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {}
std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override {
auto graph = builder_->Build(program);
printer_->Print(*graph, stream_ref_);
return graph;
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph));
printer_->Print(*new_graph, stream_ref_);
return std::move(new_graph);
}
int GetVarDeviceID(const std::string& var_name) const override {
......
......@@ -14,13 +14,14 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle {
namespace framework {
namespace details {
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::unique_ptr<SSAGraph> &&graph)
const std::vector<platform::Place> &places, std::unique_ptr<Graph> &&graph)
: graph_(std::move(graph)),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr),
......@@ -43,18 +44,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
std::unordered_set<OpHandleBase *> delayed_ops;
// Transform SSAGraph to pending_ops & pending_vars
for (auto &var_map : graph_->vars_) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
}
}
}
for (auto &var : graph_->dep_vars_) {
for (auto &var : graph_->Get<details::GraphDepVars>("dep_vars")) {
InsertPendingVar(&pending_vars, &ready_vars, var.get());
}
for (auto &op : graph_->ops_) {
for (auto &op : graph_->Get<details::GraphOps>("ops")) {
if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get());
} else {
......@@ -64,11 +65,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
std::vector<std::unique_ptr<ir::Node>> tmp_nodes;
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
&pending_vars, &ready_vars, &fetch_data);
InsertFetchOps(fetch_tensors, &fetch_ops, &tmp_nodes, &fetch_dependencies,
&pending_ops, &pending_vars, &ready_vars, &fetch_data);
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) {
......@@ -125,7 +127,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Find the ready_ops after the ready_var.
for (auto ready_var : cur_ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->pending_ops_) {
for (auto *op : ready_var->PendingOps()) {
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
......@@ -151,6 +153,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars,
......@@ -158,7 +161,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->vars_) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
......@@ -168,14 +171,16 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i];
auto fetched_var_it = fetched_vars.find(var_name);
PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
"Cannot find fetched variable.(Perhaps the main_program "
"is not set to ParallelExecutor)");
auto &vars = fetched_var_it->second;
auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_);
temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i,
&local_scopes_);
fetch_ops->emplace_back(op);
for (auto &p : places_) {
......@@ -186,7 +191,8 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
op->AddInput(var);
}
auto *fetch_dummy = new DummyVarHandle();
temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get());
op->AddOutput(fetch_dummy);
fetch_dependencies->emplace(fetch_dummy);
this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
......@@ -204,7 +210,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar(
std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, VarHandleBase *var) const {
pending_vars->insert(var);
if (var->generated_op_ == nullptr) {
if (var->GeneratedOp() == nullptr) {
ready_vars->Push(var);
}
}
......
......@@ -27,6 +27,7 @@
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
......@@ -39,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::unique_ptr<SSAGraph> &&graph);
std::unique_ptr<Graph> &&graph);
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
......@@ -52,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details::OpHandleBase *op);
private:
std::unique_ptr<SSAGraph> graph_;
std::unique_ptr<Graph> graph_;
std::unique_ptr<::ThreadPool> pool_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
......@@ -71,6 +72,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
void InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars,
......
......@@ -13,11 +13,14 @@
// limitations under the License.
#pragma once
#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
......@@ -25,19 +28,60 @@ namespace framework {
namespace details {
class OpHandleBase;
// Wraps ir::Node and provide helper utilities.
// It's responsible for populating necessary fields of ir::Node.
//
// VarHandleBase is the var node in the dependency graph.
// A variable can only be generated by a single operator. i.e.
// This is a single assignment graph.
struct VarHandleBase {
explicit VarHandleBase(ir::Node* node) : node_(node) {}
virtual ~VarHandleBase();
virtual std::string DebugString() const = 0;
void AddInput(OpHandleBase* in, ir::Node* node) {
node_->inputs.clear();
node_->inputs.push_back(node);
generated_op_ = in;
}
void AddOutput(OpHandleBase* out, ir::Node* node) {
if (pending_ops_.find(out) == pending_ops_.end()) {
pending_ops_.insert(out);
node_->outputs.push_back(node);
}
}
void RemoveOutput(OpHandleBase* out, ir::Node* node) {
pending_ops_.erase(out);
node_->outputs.erase(
std::remove(node_->outputs.begin(), node_->outputs.end(), node),
node_->outputs.end());
}
void ClearGeneratedOp() {
generated_op_ = nullptr;
node_->inputs.clear();
}
OpHandleBase* GeneratedOp() { return generated_op_; }
const std::unordered_set<OpHandleBase*>& PendingOps() const {
return pending_ops_;
}
ir::Node* Node() { return node_; }
protected:
// The operator who generate this variable. nullptr if the variable
// is a root node.
OpHandleBase* generated_op_{nullptr};
// Operators which depend on this variable ready.
std::unordered_set<OpHandleBase*> pending_ops_;
ir::Node* node_;
};
// VarHandle is actually a single version of Runtime Variable.
......@@ -46,11 +90,14 @@ struct VarHandleBase {
//
// NOTE: runtime variables have place.
struct VarHandle : public VarHandleBase {
explicit VarHandle(ir::Node* node) : VarHandleBase(node) {}
std::string DebugString() const override;
VarHandle(size_t version, size_t scope_index, std::string name,
platform::Place place)
: version_(version),
VarHandle(ir::Node* node, size_t version, size_t scope_index,
std::string name, platform::Place place)
: VarHandleBase(node),
version_(version),
scope_idx_(scope_index),
name_(std::move(name)),
place_(std::move(place)) {}
......@@ -70,6 +117,8 @@ struct VarHandle : public VarHandleBase {
// Dummy Variable. It is used to represent dependencies between operators
struct DummyVarHandle : public VarHandleBase {
explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {}
std::string DebugString() const override;
};
......
cc_library(graph SRCS graph.cc node)
cc_library(node SRCS node.cc)
cc_library(pass SRCS pass.cc graph node)
cc_test(graph_test SRCS graph_test.cc DEPS graph proto_desc op_registry)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle {
namespace framework {
// NOTE(paddle-dev): This graph contains circle.
Graph::Graph(const ProgramDesc &program) : program_(program) {
std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) {
all_vars.emplace(var->Name(), var);
}
std::map<std::string, ir::Node *> var_nodes;
for (auto *op : program.Block(0).AllOps()) {
ir::Node *node = CreateOpNode(op);
for (auto &each_var_name : op->InputArgumentNames()) {
ir::Node *var = nullptr;
if (var_nodes.find(each_var_name) != var_nodes.end()) {
var = var_nodes.at(each_var_name);
} else if (all_vars.count(each_var_name) != 0) {
var = CreateVarNode(all_vars.at(each_var_name));
var_nodes[each_var_name] = var;
} else {
// TODO(paddle-dev): Seems some assumption doesn't hold?
LOG(ERROR) << op->Type()
<< " input var not in all_var list: " << each_var_name;
var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable);
var_nodes[each_var_name] = var;
}
node->inputs.push_back(var);
var->outputs.push_back(node);
}
for (auto &each_var_name : op->OutputArgumentNames()) {
ir::Node *var = nullptr;
if (var_nodes.find(each_var_name) != var_nodes.end()) {
var = var_nodes.at(each_var_name);
} else {
var = CreateVarNode(all_vars.at(each_var_name));
var_nodes[each_var_name] = var;
}
node->outputs.push_back(var);
var->inputs.push_back(node);
}
}
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace framework {
class Graph {
public:
explicit Graph(const ProgramDesc& program);
virtual ~Graph() {
for (auto& attr : attrs_) {
attr_dels_[attr.first]();
}
attrs_.clear();
attr_dels_.clear();
}
template <typename AttrType>
AttrType& Get(const std::string& attr_name) const {
return *boost::any_cast<AttrType*>(attrs_.at(attr_name));
}
template <typename AttrType>
void Set(const std::string& attr_name, AttrType* attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
attrs_[attr_name] = attr;
attr_dels_[attr_name] = [attr, attr_name]() {
VLOG(3) << "deleting " << attr_name;
delete attr;
};
}
ir::Node* CreateVarNode(VarDesc* var_desc) {
nodes.emplace_back(new ir::Node(var_desc));
return nodes.back().get();
}
ir::Node* CreateOpNode(OpDesc* op_desc) {
nodes.emplace_back(new ir::Node(op_desc));
return nodes.back().get();
}
ir::Node* CreateEmptyNode(const std::string& name, ir::Node::Type type) {
nodes.emplace_back(new ir::Node(name, type));
return nodes.back().get();
}
std::vector<std::unique_ptr<ir::Node>> nodes;
private:
// NOTE: program_ shouldn't be exposed to user.
const ProgramDesc& program_;
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
class NOP : public OperatorBase {
public:
NOP(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope &scope,
const platform::Place &place) const override {}
};
class SumOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class SumOpVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
auto &inputs = op_desc.Input("X");
auto default_var_type = proto::VarType::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string &name) {
return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
default_var_type = proto::VarType::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(default_var_type);
}
};
} // namespace framework
} // namespace paddle
REGISTER_OPERATOR(sum, paddle::framework::NOP, paddle::framework::SumOpMaker,
paddle::framework::SumOpVarTypeInference);
REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP,
paddle::framework::SumOpMaker);
namespace paddle {
namespace framework {
TEST(GraphTest, Basic) {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"});
prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_out");
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarType::SELECTED_ROWS,
prog.MutableBlock(0)->Var("test_out")->GetType());
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR);
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarType::LOD_TENSOR,
prog.MutableBlock(0)->Var("test_out")->GetType());
std::unique_ptr<Graph> g(new Graph(prog));
ASSERT_EQ(g->nodes[0]->Name(), "sum");
ASSERT_EQ(g->nodes[0]->inputs[0]->Name(), "test_a");
ASSERT_EQ(g->nodes[0]->inputs[1]->Name(), "test_b");
ASSERT_EQ(g->nodes[0]->inputs[2]->Name(), "test_c");
ASSERT_EQ(g->nodes[0]->outputs[0]->Name(), "test_out");
ASSERT_EQ(g->nodes[1]->Name(), "test_a");
ASSERT_EQ(g->nodes[1]->outputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes[2]->Name(), "test_b");
ASSERT_EQ(g->nodes[2]->outputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes[3]->Name(), "test_c");
ASSERT_EQ(g->nodes[3]->outputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes[4]->Name(), "test_out");
ASSERT_EQ(g->nodes[4]->inputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes.size(), 5);
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/node.h"
namespace paddle {
namespace framework {} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace framework {
namespace ir {
class Node {
public:
enum class Type { kOperation, kVariable };
explicit Node(const std::string& name, Type type)
: name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
explicit Node(VarDesc* var_desc)
: name_(var_desc->Name()),
var_desc_(var_desc),
op_desc_(nullptr),
type_(Type::kVariable) {}
explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()),
var_desc_(nullptr),
op_desc_(op_desc),
type_(Type::kOperation) {}
Type NodeType() const { return type_; }
std::string Name() const { return name_; }
VarDesc* Var() {
PADDLE_ENFORCE(type_ == Type::kVariable);
return var_desc_;
}
OpDesc* Op() {
PADDLE_ENFORCE(type_ == Type::kOperation);
return op_desc_;
}
std::vector<Node*> inputs;
std::vector<Node*> outputs;
protected:
const std::string name_;
VarDesc* var_desc_;
OpDesc* op_desc_;
Type type_;
private:
DISABLE_COPY_AND_ASSIGN(Node);
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
class Pass {
public:
Pass() = default;
virtual ~Pass() {}
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -18,6 +18,8 @@ limitations under the License. */
#include <tuple>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
......@@ -129,12 +131,11 @@ ParallelExecutor::ParallelExecutor(
PADDLE_THROW("Not compiled with CUDA.");
#endif
}
builder_ = builder_factory.Create();
std::unique_ptr<Graph> graph(new Graph(main_program));
graph = builder_->Apply(std::move(graph));
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places,
builder_->Build(main_program)));
exec_strategy, member_->local_scopes_, places, std::move(graph)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos),
member_->places_, std::move(member_->executor_)));
......
......@@ -19,10 +19,14 @@ function (inference_analysis_test TARGET)
set(multiValueArgs SRCS)
cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(mem_opt "")
if(WITH_GPU)
set(mem_opt "--fraction_of_gpu_memory_to_use=0.5")
endif()
cc_test(${TARGET}
SRCS "${analysis_test_SRCS}"
DEPS analysis
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model --fraction_of_gpu_memory_to_use=0.5)
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt})
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
endif(WITH_TESTING)
endfunction(inference_analysis_test)
......
......@@ -66,6 +66,7 @@ bool NativePaddlePredictor::Init(
if (parent_scope) {
scope_ = parent_scope;
sub_scope_ = &(parent_scope->NewScope());
PADDLE_ENFORCE_NOT_NULL(sub_scope_, "create sub scope fail");
} else {
paddle::framework::InitDevices(false);
scope_.reset(new paddle::framework::Scope());
......@@ -102,7 +103,6 @@ bool NativePaddlePredictor::Init(
NativePaddlePredictor::~NativePaddlePredictor() {
if (sub_scope_) {
PADDLE_ENFORCE_NOT_NULL(scope_, "Should have parent scope!");
scope_->DeleteScope(sub_scope_);
}
}
......
......@@ -57,4 +57,4 @@ By specifying the engine kind and config, one can get a specific implementation.
## Reference
- [paddle_inference_api.h](./paddle_inference_api.h)
- [some demos](./demo)
- [some demos](./demo_ci)
......@@ -83,5 +83,5 @@ CHECK(predictor->Run(slots, &outputs));
## 详细代码参考
- [inference demos](./demo)
- [复杂单线程/多线程例子](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/contrib/inference/test_paddle_inference_api_impl.cc)
- [inference demos](./demo_ci)
- [复杂单线程/多线程例子](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/inference/api/test_api_impl.cc)
......@@ -38,6 +38,7 @@ limitations under the License. */
#endif
#endif
#include <boost/any.hpp>
#include <boost/mpl/comparison.hpp>
#include <boost/mpl/less_equal.hpp>
#include <boost/variant.hpp>
......@@ -248,15 +248,11 @@ PYBIND11_PLUGIN(core) {
#endif
})
.def("rows", [](SelectedRows &self) {
#ifndef PADDLE_WITH_CUDA
return self.rows();
#else
auto rows = self.rows();
std::vector<int64_t> new_rows;
new_rows.reserve(rows.size());
std::copy(rows.begin(), rows.end(), std::back_inserter(new_rows));
return new_rows;
#endif
auto rows = self.rows();
std::vector<int64_t> new_rows;
new_rows.reserve(rows.size());
std::copy(rows.begin(), rows.end(), std::back_inserter(new_rows));
return new_rows;
});
py::class_<Variable>(m, "Variable", R"DOC(Variable Class.
......
......@@ -30,7 +30,9 @@ class RecordIOWriter {
public:
RecordIOWriter(const std::string& filename, recordio::Compressor compressor,
size_t max_num_record)
: stream_(filename), writer_(&stream_, compressor, max_num_record) {}
: closed_(false),
stream_(filename),
writer_(&stream_, compressor, max_num_record) {}
void AppendTensor(const framework::LoDTensor& tensor) {
tensors_.push_back(tensor);
......@@ -47,9 +49,17 @@ class RecordIOWriter {
PADDLE_ENFORCE(tensors_.empty());
writer_.Flush();
stream_.close();
closed_ = true;
}
~RecordIOWriter() {
if (!closed_) {
Close();
}
}
private:
bool closed_;
std::vector<framework::LoDTensor> tensors_;
std::ofstream stream_;
recordio::Writer writer_;
......
......@@ -68,8 +68,14 @@ def reader_creator(image_filename, label_filename, buffer_size):
for i in xrange(buffer_size):
yield images[i, :], int(labels[i])
finally:
m.terminate()
l.terminate()
try:
m.terminate()
except:
pass
try:
l.terminate()
except:
pass
return reader
......
......@@ -35,6 +35,7 @@ import io
import evaluator
import initializer
import layers
import contrib
import nets
import optimizer
import backward
......@@ -66,6 +67,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + \
'io',
'initializer',
'layers',
'contrib',
'transpiler',
'nets',
'optimizer',
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import decoder
from decoder import *
__all__ = decoder.__all__
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import beam_search_decoder
from beam_search_decoder import *
__all__ = beam_search_decoder.__all__
......@@ -14,6 +14,7 @@
from __future__ import print_function
import argparse
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle
import sys
import numpy
......@@ -134,4 +135,4 @@ def main(use_cuda):
if __name__ == '__main__':
# for use_cuda in (False, True):
main(use_cuda=True)
main(use_cuda=core.is_compiled_with_cuda())
......@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import print_function
import paddle.fluid.core as core
import math
import os
import sys
......@@ -257,6 +258,8 @@ def inject_test_method(use_cuda, parallel, nn_type, combine):
def inject_all_tests():
for use_cuda in (False, True):
if use_cuda and not core.is_compiled_with_cuda():
continue
for parallel in (False, True):
for nn_type in ('mlp', 'conv'):
inject_test_method(use_cuda, parallel, nn_type, True)
......
......@@ -245,7 +245,7 @@ def inject_test_method(use_cuda, is_sparse, is_parallel):
is_sparse=is_sparse,
is_parallel=is_parallel)
if use_cuda and is_sparse:
if (not fluid.core.is_compiled_with_cuda() or use_cuda) and is_sparse:
fn = __impl__
else:
# skip the other test when on CI server
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A simple machine translation demo using beam search decoder.
"""
import contextlib
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.framework as framework
import paddle.fluid.layers as layers
from paddle.fluid.executor import Executor
from paddle.fluid.contrib.decoder.beam_search_decoder import *
import unittest
import os
dict_size = 30000
source_dict_dim = target_dict_dim = dict_size
src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size)
hidden_dim = 32
word_dim = 32
decoder_size = hidden_dim
IS_SPARSE = True
batch_size = 2
max_length = 8
topk_size = 50
trg_dic_size = 10000
beam_size = 2
def encoder():
# encoder
src_word = layers.data(
name="src_word", shape=[1], dtype='int64', lod_level=1)
src_embedding = layers.embedding(
input=src_word,
size=[dict_size, word_dim],
dtype='float32',
is_sparse=IS_SPARSE)
fc1 = layers.fc(input=src_embedding, size=hidden_dim * 4, act='tanh')
lstm_hidden0, lstm_0 = layers.dynamic_lstm(input=fc1, size=hidden_dim * 4)
encoder_out = layers.sequence_last_step(input=lstm_hidden0)
return encoder_out
def decoder_state_cell(context):
h = InitState(init=context, need_reorder=True)
state_cell = StateCell(inputs={'x': None}, states={'h': h}, out_state='h')
@state_cell.state_updater
def updater(state_cell):
current_word = state_cell.get_input('x')
prev_h = state_cell.get_state('h')
# make sure lod of h heritted from prev_h
h = layers.fc(input=[prev_h, current_word],
size=decoder_size,
act='tanh')
state_cell.set_state('h', h)
return state_cell
def decoder_train(state_cell):
# decoder
trg_language_word = layers.data(
name="target_word", shape=[1], dtype='int64', lod_level=1)
trg_embedding = layers.embedding(
input=trg_language_word,
size=[dict_size, word_dim],
dtype='float32',
is_sparse=IS_SPARSE)
decoder = TrainingDecoder(state_cell)
with decoder.block():
current_word = decoder.step_input(trg_embedding)
decoder.state_cell.compute_state(inputs={'x': current_word})
current_score = layers.fc(input=decoder.state_cell.get_state('h'),
size=target_dict_dim,
act='softmax')
decoder.state_cell.update_states()
decoder.output(current_score)
return decoder()
def decoder_decode(state_cell):
init_ids = layers.data(
name="init_ids", shape=[1], dtype="int64", lod_level=2)
init_scores = layers.data(
name="init_scores", shape=[1], dtype="float32", lod_level=2)
decoder = BeamSearchDecoder(
state_cell=state_cell,
init_ids=init_ids,
init_scores=init_scores,
target_dict_dim=target_dict_dim,
word_dim=word_dim,
input_var_dict={},
topk_size=topk_size,
sparse_emb=IS_SPARSE,
max_len=max_length,
beam_size=beam_size,
end_id=1,
name=None)
decoder.decode()
translation_ids, translation_scores = decoder()
return translation_ids, translation_scores
def train_main(use_cuda):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
context = encoder()
state_cell = decoder_state_cell(context)
rnn_out = decoder_train(state_cell)
label = layers.data(
name="target_next_word", shape=[1], dtype='int64', lod_level=1)
cost = layers.cross_entropy(input=rnn_out, label=label)
avg_cost = layers.mean(x=cost)
optimizer = fluid.optimizer.Adagrad(learning_rate=1e-3)
optimizer.minimize(avg_cost)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.wmt14.train(dict_size), buf_size=1000),
batch_size=batch_size)
feed_order = ['src_word', 'target_word', 'target_next_word']
exe = Executor(place)
def train_loop(main_program):
exe.run(framework.default_startup_program())
feed_list = [
main_program.global_block().var(var_name) for var_name in feed_order
]
feeder = fluid.DataFeeder(feed_list, place)
for pass_id in xrange(1):
for batch_id, data in enumerate(train_reader()):
outs = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[avg_cost])
avg_cost_val = np.array(outs[0])
print('pass_id=' + str(pass_id) + ' batch=' + str(batch_id) +
" avg_cost=" + str(avg_cost_val))
if batch_id > 3:
break
train_loop(framework.default_main_program())
def decode_main(use_cuda):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
context = encoder()
state_cell = decoder_state_cell(context)
translation_ids, translation_scores = decoder_decode(state_cell)
exe = Executor(place)
exe.run(framework.default_startup_program())
init_ids_data = np.array([0 for _ in range(batch_size)], dtype='int64')
init_scores_data = np.array(
[1. for _ in range(batch_size)], dtype='float32')
init_ids_data = init_ids_data.reshape((batch_size, 1))
init_scores_data = init_scores_data.reshape((batch_size, 1))
init_lod = [1] * batch_size
init_lod = [init_lod, init_lod]
init_ids = fluid.create_lod_tensor(init_ids_data, init_lod, place)
init_scores = fluid.create_lod_tensor(init_scores_data, init_lod, place)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.wmt14.train(dict_size), buf_size=1000),
batch_size=batch_size)
feed_order = ['src_word']
feed_list = [
framework.default_main_program().global_block().var(var_name)
for var_name in feed_order
]
feeder = fluid.DataFeeder(feed_list, place)
data = train_reader().next()
feed_dict = feeder.feed(map(lambda x: [x[0]], data))
feed_dict['init_ids'] = init_ids
feed_dict['init_scores'] = init_scores
result_ids, result_scores = exe.run(
framework.default_main_program(),
feed=feed_dict,
fetch_list=[translation_ids, translation_scores],
return_numpy=False)
print result_ids.lod()
class TestBeamSearchDecoder(unittest.TestCase):
pass
@contextlib.contextmanager
def scope_prog_guard():
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
yield
def inject_test_train(use_cuda):
f_name = 'test_{0}_train'.format('cuda' if use_cuda else 'cpu')
def f(*args):
with scope_prog_guard():
train_main(use_cuda)
setattr(TestBeamSearchDecoder, f_name, f)
def inject_test_decode(use_cuda, decorator=None):
f_name = 'test_{0}_decode'.format('cuda' if use_cuda else 'cpu', 'sparse')
def f(*args):
with scope_prog_guard():
decode_main(use_cuda)
if decorator is not None:
f = decorator(f)
setattr(TestBeamSearchDecoder, f_name, f)
for _use_cuda_ in (False, True):
inject_test_train(_use_cuda_)
for _use_cuda_ in (False, True):
_decorator_ = None
inject_test_decode(use_cuda=_use_cuda_, decorator=_decorator_)
if __name__ == '__main__':
unittest.main()
......@@ -12,6 +12,11 @@ endif(NOT WITH_MKLDNN)
if(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_OPS test_recv_op)
list(REMOVE_ITEM TEST_OPS test_dist_transpiler)
list(REMOVE_ITEM TEST_OPS test_simple_dist_transpiler)
list(REMOVE_ITEM TEST_OPS test_listen_and_serv_op)
LIST(REMOVE_ITEM TEST_OPS test_dist_mnist)
LIST(REMOVE_ITEM TEST_OPS test_dist_word2vec)
endif(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
......@@ -47,9 +52,11 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL)
py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
if(WITH_DISTRIBUTE)
py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 180)
set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 180)
endif()
py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL)
py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL)
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 180)
set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 180)
......@@ -100,6 +100,8 @@ class TestBeamSearchDecodeOp(unittest.TestCase):
np.array_equal(np.array(sentence_scores), expected_data))
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestBeamSearchDecodeOpGPU(TestBeamSearchDecodeOp):
def setUp(self):
self.scope = core.Scope()
......
......@@ -191,12 +191,16 @@ class TestWithDilation(TestConv2dTransposeOp):
# ------------ test_cudnn ------------
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNN(TestConv2dTransposeOp):
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithPad(TestWithPad):
def init_test_case(self):
self.pad = [1, 1]
......@@ -212,6 +216,8 @@ class TestCUDNNWithPad(TestWithPad):
self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithStride(TestWithStride):
def init_test_case(self):
self.pad = [1, 1]
......@@ -227,6 +233,8 @@ class TestCUDNNWithStride(TestWithStride):
self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithGroups(TestWithGroups):
def init_test_case(self):
self.pad = [1, 1]
......
......@@ -197,12 +197,16 @@ class TestWithDilation(TestConv3dTransposeOp):
# ------------ test_cudnn ------------
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNN(TestConv3dTransposeOp):
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv3d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithPad(TestWithPad):
def init_test_case(self):
self.pad = [1, 1, 1]
......@@ -218,6 +222,8 @@ class TestCUDNNWithPad(TestWithPad):
self.op_type = "conv3d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithStride(TestWithStride):
def init_test_case(self):
self.pad = [1, 1, 1]
......@@ -233,6 +239,8 @@ class TestCUDNNWithStride(TestWithStride):
self.op_type = "conv3d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithGroups(TestWithGroups):
def init_test_case(self):
self.pad = [1, 1, 1]
......
......@@ -15,6 +15,7 @@
import paddle.dataset.flowers as flowers
import math
import paddle.fluid as fluid
import paddle.fluid.core as core
import unittest
import numpy as np
import paddle
......@@ -92,7 +93,8 @@ class TestFetchOp(unittest.TestCase):
train_inputs.append(tst_reader_iter.next())
os.environ['CPU_NUM'] = str(4)
self.parallel_exe(train_inputs, seed=1, use_cuda=True)
if core.is_compiled_with_cuda():
self.parallel_exe(train_inputs, seed=1, use_cuda=True)
self.parallel_exe(train_inputs, seed=1, use_cuda=False)
......@@ -137,7 +139,8 @@ class TestFeedParallel(unittest.TestCase):
def test_feed_op(self):
os.environ['CPU_NUM'] = str(4)
self.parallel_exe(use_cuda=True, seed=1)
if core.is_compiled_with_cuda():
self.parallel_exe(use_cuda=True, seed=1)
self.parallel_exe(use_cuda=False, seed=1)
......
......@@ -14,6 +14,7 @@
from parallel_executor_test_base import TestParallelExecutorBase
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
import paddle
import paddle.dataset.mnist as mnist
......@@ -98,6 +99,8 @@ class TestMNIST(TestParallelExecutorBase):
MNIST_RECORDIO_FILE, reader, feeder)
def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
if use_cuda and not core.is_compiled_with_cuda():
return
self.check_network_convergence(simple_fc_net, use_cuda=use_cuda)
self.check_network_convergence(
simple_fc_net, use_cuda=use_cuda, allow_op_delay=True)
......@@ -122,6 +125,8 @@ class TestMNIST(TestParallelExecutorBase):
self.check_simple_fc_convergence(False, True)
def check_simple_fc_parallel_accuracy(self, use_cuda, use_reduce=False):
if use_cuda and not core.is_compiled_with_cuda():
return
img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64')
single_first_loss, single_last_loss = self.check_network_convergence(
......@@ -155,6 +160,8 @@ class TestMNIST(TestParallelExecutorBase):
self.check_simple_fc_parallel_accuracy(False, True)
def check_batchnorm_fc_convergence(self, use_cuda, use_reduce=False):
if use_cuda and not core.is_compiled_with_cuda():
return
self.check_network_convergence(fc_with_batchnorm, use_cuda=use_cuda)
img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64')
......
......@@ -16,6 +16,7 @@ import paddle.fluid as fluid
import paddle.fluid.layers.ops as ops
from paddle.fluid.initializer import init_on_cpu
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
import paddle.fluid.core as core
from parallel_executor_test_base import TestParallelExecutorBase
import unittest
import math
......@@ -140,6 +141,9 @@ class TestResnet(TestParallelExecutorBase):
use_reduce=False,
iter=20):
if use_cuda and not core.is_compiled_with_cuda():
return
os.environ['CPU_NUM'] = str(4)
def _cosine_decay(learning_rate, step_each_epoch, epochs=120):
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
import unittest
import os
......@@ -92,16 +93,18 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
def test_parallel_testing(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
self.check_network_convergence(
use_cuda=True, build_strategy=build_strategy)
if core.is_compiled_with_cuda():
self.check_network_convergence(
use_cuda=True, build_strategy=build_strategy)
self.check_network_convergence(
use_cuda=False, build_strategy=build_strategy)
def test_parallel_testing_with_new_strategy(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
self.check_network_convergence(
use_cuda=True, build_strategy=build_strategy)
if core.is_compiled_with_cuda():
self.check_network_convergence(
use_cuda=True, build_strategy=build_strategy)
self.check_network_convergence(
use_cuda=False, build_strategy=build_strategy)
......
......@@ -56,6 +56,8 @@ class TestPrintOpCPU(unittest.TestCase):
return_numpy=False)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestPrintOpGPU(TestPrintOpCPU):
def setUp(self):
self.place = core.CUDAPlace(0)
......
......@@ -79,12 +79,18 @@ class TestProfiler(unittest.TestCase):
pass_acc_calculator.add(value=acc, weight=b_size)
pass_acc = pass_acc_calculator.eval()
@unittest.skipIf(not core.is_compiled_with_cuda(),
"profiler is enabled only with GPU")
def test_cpu_profiler(self):
self.net_profiler('CPU')
@unittest.skipIf(not core.is_compiled_with_cuda(),
"profiler is enabled only with GPU")
def test_cuda_profiler(self):
self.net_profiler('GPU')
@unittest.skipIf(not core.is_compiled_with_cuda(),
"profiler is enabled only with GPU")
def test_all_profiler(self):
self.net_profiler('All', '/tmp/profile_out')
with open('/tmp/profile_out', 'r') as f:
......
......@@ -61,6 +61,8 @@ class TestSequenceSoftmaxOp(OpTest):
# ----------------cudnn Sequencesoftmax----------------
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSequenceSoftmaxCUDNNOp(TestSequenceSoftmaxOp):
def init_op_type(self):
self.use_cudnn = True
......
......@@ -63,11 +63,15 @@ class TestSoftmaxOp(OpTest):
self.check_grad(["X"], "Out", max_relative_error=0.01)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxCUDNNOp(TestSoftmaxOp):
def init_kernel_type(self):
self.use_cudnn = True
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxFP16Op(TestSoftmaxOp):
def init_kernel_type(self):
self.dtype = np.float16
......@@ -79,6 +83,8 @@ class TestSoftmaxFP16Op(TestSoftmaxOp):
self.check_output_with_place(place, atol=1e-3)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxFP16CUDNNOp(TestSoftmaxOp):
def init_kernel_type(self):
self.use_cudnn = True
......
......@@ -68,8 +68,14 @@ def reader_creator(image_filename, label_filename, buffer_size):
for i in xrange(buffer_size):
yield images[i, :], int(labels[i])
finally:
m.terminate()
l.terminate()
try:
m.terminate()
except:
pass
try:
l.terminate()
except:
pass
return reader
......
......@@ -104,6 +104,8 @@ packages=['paddle',
'paddle.fluid.proto',
'paddle.fluid.proto.profiler',
'paddle.fluid.layers',
'paddle.fluid.contrib',
'paddle.fluid.contrib.decoder',
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册