Commit ce725863 authored by W wangguibao

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into async_executor

@@ -45,7 +45,7 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
 ELSE()
   MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN")
 ENDIF()
-SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result")
+SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-error=array-bounds")
 SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value")
 SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}")
 SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}")
@@ -54,7 +54,7 @@ ExternalProject_Add(
   ${EXTERNAL_PROJECT_LOG_ARGS}
   DEPENDS ${MKLDNN_DEPENDS}
   GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
-  GIT_TAG "64e03a1939e0d526aa8e9f2e3f7dc0ad8d372944"
+  GIT_TAG "21fb5f2af1dd14e132af4f1b79160977ee487818"
   PREFIX ${MKLDNN_SOURCES_DIR}
   UPDATE_COMMAND ""
   CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......
@@ -174,6 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
 paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
 paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
@@ -189,6 +190,7 @@ paddle.fluid.layers.batch ArgSpec(args=['reader', 'batch_size'], varargs=None, k
 paddle.fluid.layers.double_buffer ArgSpec(args=['reader', 'place', 'name'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.random_data_generator ArgSpec(args=['low', 'high', 'shapes', 'lod_levels', 'for_parallel'], varargs=None, keywords=None, defaults=(True,))
 paddle.fluid.layers.py_reader ArgSpec(args=['capacity', 'shapes', 'dtypes', 'lod_levels', 'name', 'use_double_buffer'], varargs=None, keywords=None, defaults=(None, None, True))
+paddle.fluid.layers.create_py_reader_by_data ArgSpec(args=['capacity', 'feed_list', 'name', 'use_double_buffer'], varargs=None, keywords=None, defaults=(None, True))
 paddle.fluid.layers.Preprocessor.__init__ ArgSpec(args=['self', 'reader', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.Preprocessor.block ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
 paddle.fluid.layers.Preprocessor.inputs ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
......
@@ -37,8 +37,9 @@ struct TestBroadcastOpHandle {
   std::vector<Scope*> local_scopes_;
   std::vector<Scope*> param_scopes_;
   Scope g_scope_;
-  std::unique_ptr<OpHandleBase> op_handle_;
-  std::vector<std::unique_ptr<VarHandleBase>> vars_;
+  OpHandleBase* op_handle_;
+  std::vector<VarHandleBase*> vars_;
+  std::vector<std::unique_ptr<ir::Node>> nodes_;
   std::vector<p::Place> place_list_;
   bool use_gpu_;
 #ifdef PADDLE_WITH_CUDA
@@ -90,6 +91,7 @@ struct TestBroadcastOpHandle {
   }

   void InitBroadcastOp(size_t input_scope_idx) {
+    nodes_.clear();
     for (size_t j = 0; j < place_list_.size(); ++j) {
       local_scopes_.push_back(&(g_scope_.NewScope()));
       Scope& local_scope = local_scopes_.back()->NewScope();
@@ -101,39 +103,39 @@ struct TestBroadcastOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("input");

-    std::unique_ptr<ir::Node> n =
-        ir::CreateNodeForTest("node0", ir::Node::Type::kOperation);
+    nodes_.emplace_back(
+        ir::CreateNodeForTest("node0", ir::Node::Type::kOperation));
     if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
-      op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_,
-                                             place_list_, nccl_ctxs_.get()));
+      op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
+                                         place_list_, nccl_ctxs_.get());
 #else
       PADDLE_THROW("CUDA is not support.");
 #endif
     } else {
 #ifdef PADDLE_WITH_CUDA
-      op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_,
-                                             place_list_, nccl_ctxs_.get()));
+      op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
+                                         place_list_, nccl_ctxs_.get());
 #else
-      op_handle_.reset(
-          new BroadcastOpHandle(n.get(), local_scopes_, place_list_));
+      op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
+                                         place_list_);
 #endif
     }

-    std::unique_ptr<ir::Node> v =
-        ir::CreateNodeForTest("node1", ir::Node::Type::kVariable);
-    auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
-                                        place_list_[input_scope_idx]);
+    nodes_.emplace_back(
+        ir::CreateNodeForTest("node1", ir::Node::Type::kVariable));
+    auto* in_var_handle = new VarHandle(nodes_.back().get(), 1, input_scope_idx,
+                                        "input", place_list_[input_scope_idx]);
     vars_.emplace_back(in_var_handle);
     op_handle_->AddInput(in_var_handle);

     // add dummy var
-    std::unique_ptr<ir::Node> v2 =
-        ir::CreateNodeForTest("node2", ir::Node::Type::kVariable);
-    vars_.emplace_back(new DummyVarHandle(v2.get()));
+    nodes_.emplace_back(
+        ir::CreateNodeForTest("node2", ir::Node::Type::kVariable));
+    vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
     DummyVarHandle* dummy_var_handle =
-        static_cast<DummyVarHandle*>(vars_.back().get());
+        static_cast<DummyVarHandle*>(vars_.back());
     dummy_var_handle->ClearGeneratedOp();
     op_handle_->AddInput(dummy_var_handle);
@@ -141,20 +143,20 @@ struct TestBroadcastOpHandle {
       if (!use_gpu_) {
         op_handle_->SetDeviceContext(place_list_[j], ctxs_[j].get());
       }
-      std::unique_ptr<ir::Node> v3 =
-          ir::CreateNodeForTest("node3", ir::Node::Type::kVariable);
+      nodes_.emplace_back(
+          ir::CreateNodeForTest("node3", ir::Node::Type::kVariable));
       VarHandle* out_var_handle =
-          new VarHandle(v3.get(), 2, j, "out", place_list_[j]);
+          new VarHandle(nodes_.back().get(), 2, j, "out", place_list_[j]);
       vars_.emplace_back(out_var_handle);
       op_handle_->AddOutput(out_var_handle);
     }

     // add dummy var
-    std::unique_ptr<ir::Node> v4 =
-        ir::CreateNodeForTest("node4", ir::Node::Type::kVariable);
-    vars_.emplace_back(new DummyVarHandle(v4.get()));
+    nodes_.emplace_back(
+        ir::CreateNodeForTest("node4", ir::Node::Type::kVariable));
+    vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
     DummyVarHandle* out_dummy_var_handle =
-        static_cast<DummyVarHandle*>(vars_.back().get());
+        static_cast<DummyVarHandle*>(vars_.back());
     out_dummy_var_handle->ClearGeneratedOp();
     op_handle_->AddOutput(out_dummy_var_handle);
   }
......
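The fixtures in this test switch from std::unique_ptr-owned handles to raw pointers: only the created ir::Node objects stay in an owning nodes_ vector, while op_handle_ and vars_ become plain pointers whose lifetime is tied to the nodes that wrap them (see the OpHandleBase change near the end of this diff). A minimal, self-contained sketch of that ownership pattern, using stand-in types rather than the real Paddle classes:

    // Stand-ins for ir::Node / OpHandleBase, for illustration only.
    #include <memory>
    #include <vector>

    struct FakeHandle {};
    struct FakeNode {
      std::unique_ptr<FakeHandle> wrapped;  // the node owns its wrapper handle
    };

    struct FixtureSketch {
      std::vector<std::unique_ptr<FakeNode>> nodes_;  // fixture owns the nodes only
      FakeHandle* op_handle_ = nullptr;               // raw, non-owning
      std::vector<FakeHandle*> vars_;                 // raw, non-owning

      void Init() {
        nodes_.clear();  // mirrors nodes_.clear() at the top of InitBroadcastOp
        nodes_.emplace_back(new FakeNode());
        nodes_.back()->wrapped.reset(new FakeHandle());
        op_handle_ = nodes_.back()->wrapped.get();  // freed when the node is destroyed
      }
    };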
@@ -16,6 +16,7 @@
 #include <vector>
 #include "paddle/fluid/framework/details/fetch_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"

 namespace paddle {
 namespace framework {
@@ -32,13 +33,11 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
       pool_(strategy.num_threads_ +
             1),  // add one more thread for generate op_deps
       fetch_ctxs_(places) {
-  auto &ops = graph_->Get<details::GraphOps>("ops");
-  for (auto &op : ops) {
+  for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
     int dep = static_cast<int>(op->NotReadyInputSize());
-    op_deps_.emplace(op.get(), dep);
+    op_deps_.emplace(op, dep);
     if (dep == 0) {
-      bootstrap_ops_.emplace_back(op.get());
+      bootstrap_ops_.emplace_back(op);
     }
   }
@@ -54,13 +53,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
   paddle::framework::FeedFetchList fetches;
   fetches.resize(fetch_tensors.size());
   std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
-  std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
+  std::vector<FetchOpHandle *> fetch_ops;
   for (auto &fetch_var_name : fetch_tensors) {
     for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
       auto it = var_map.find(fetch_var_name);
       if (it != var_map.end()) {
-        fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
+        fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
       }
     }
   }
@@ -110,7 +109,10 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
           complete_q->Pop();
         }
       }
-      exception_.ReThrow();
+      if (exception_.IsCaught()) {
+        ClearFetchOp(graph_.get(), &fetch_ops);
+        exception_.ReThrow();
+      }
     }
     num_complete += num_comp;
   }
......
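Several files in this commit stop reading the graph-level "ops" attribute and instead collect op handles with ir::FilterByNodeWrapper<OpHandleBase>(*graph_), declared in the newly included paddle/fluid/framework/ir/graph_helper.h. As a rough sketch of what such a helper could do (the node API used here, Nodes()/IsWrappedBy<T>()/Wrapper<T>(), is an assumption for illustration, not a quote of the real implementation):

    #include <vector>

    // Hypothetical sketch: walk every node in the graph and collect the
    // wrapper objects of type T (e.g. OpHandleBase) attached to those nodes.
    template <typename T, typename GraphT>
    std::vector<T*> FilterByNodeWrapperSketch(const GraphT& graph) {
      std::vector<T*> ret;
      for (auto* node : graph.Nodes()) {
        if (node->template IsWrappedBy<T>()) {
          ret.push_back(&node->template Wrapper<T>());
        }
      }
      return ret;
    }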
@@ -28,11 +28,7 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
       offset_(offset),
       local_scopes_(local_scopes) {}

-FetchOpHandle::~FetchOpHandle() {
-  for (auto *input_var : inputs_) {
-    input_var->RemoveOutput(this, this->Node());
-  }
-}
+FetchOpHandle::~FetchOpHandle() {}

 void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
   PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
......
@@ -22,8 +22,10 @@ namespace details {

 struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
   std::vector<std::string> out_varnames_;
+  std::vector<std::unique_ptr<ir::Node>> nodes_;

   void InitFusedBroadcastOp(std::vector<size_t> input_scope_idxes) {
+    nodes_.clear();
     // initialize scope and var
     for (size_t i = 0; i < place_list_.size(); ++i) {
       local_scopes_.push_back(&(g_scope_.NewScope()));
@@ -39,41 +41,41 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
     }

     // create op handle node
-    std::unique_ptr<ir::Node> n =
-        ir::CreateNodeForTest("fused_broadcast", ir::Node::Type::kOperation);
+    nodes_.emplace_back(
+        ir::CreateNodeForTest("fused_broadcast", ir::Node::Type::kOperation));
     if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
-      op_handle_.reset(new FusedBroadcastOpHandle(
-          n.get(), local_scopes_, place_list_, nccl_ctxs_.get()));
+      op_handle_ = new FusedBroadcastOpHandle(
+          nodes_.back().get(), local_scopes_, place_list_, nccl_ctxs_.get());
 #else
       PADDLE_THROW("CUDA is not supported.");
 #endif
     } else {
 #ifdef PADDLE_WITH_CUDA
-      op_handle_.reset(new FusedBroadcastOpHandle(
-          n.get(), local_scopes_, place_list_, nccl_ctxs_.get()));
+      op_handle_ = new FusedBroadcastOpHandle(
+          nodes_.back().get(), local_scopes_, place_list_, nccl_ctxs_.get());
 #else
-      op_handle_.reset(
-          new FusedBroadcastOpHandle(n.get(), local_scopes_, place_list_));
+      op_handle_ = new FusedBroadcastOpHandle(nodes_.back().get(),
+                                              local_scopes_, place_list_);
 #endif
     }

     for (size_t i = 0; i < input_scope_idxes.size(); ++i) {
       // add input var handle
-      std::unique_ptr<ir::Node> in_node =
-          ir::CreateNodeForTest("in_node" + i, ir::Node::Type::kVariable);
+      nodes_.emplace_back(
+          ir::CreateNodeForTest("in_node" + i, ir::Node::Type::kVariable));
       VarHandle* in_var_handle =
-          new VarHandle(in_node.get(), 1, input_scope_idxes[i], "in_var" + i,
-                        place_list_[input_scope_idxes[i]]);
+          new VarHandle(nodes_.back().get(), 1, input_scope_idxes[i],
+                        "in_var" + i, place_list_[input_scope_idxes[i]]);
       vars_.emplace_back(in_var_handle);
       op_handle_->AddInput(in_var_handle);

       // add output var handle
       for (size_t j = 0; j < place_list_.size(); ++j) {
-        std::unique_ptr<ir::Node> out_node =
-            ir::CreateNodeForTest("out_node" + i, ir::Node::Type::kVariable);
-        VarHandle* out_var_handle =
-            new VarHandle(out_node.get(), 2, j, "out_var" + i, place_list_[j]);
+        nodes_.emplace_back(
+            ir::CreateNodeForTest("out_node" + i, ir::Node::Type::kVariable));
+        VarHandle* out_var_handle = new VarHandle(
+            nodes_.back().get(), 2, j, "out_var" + i, place_list_[j]);
         vars_.emplace_back(out_var_handle);
         op_handle_->AddOutput(out_var_handle);
       }
......
@@ -31,9 +31,10 @@ struct TestGatherOpHandle {
   std::vector<Scope*> local_scopes_;
   std::vector<Scope*> param_scopes_;
   Scope g_scope_;
-  std::unique_ptr<OpHandleBase> op_handle_;
-  std::vector<std::unique_ptr<VarHandleBase>> vars_;
+  OpHandleBase* op_handle_;
+  std::vector<VarHandleBase*> vars_;
   std::vector<p::Place> gpu_list_;
+  std::vector<std::unique_ptr<ir::Node>> nodes_;

   void WaitAll() {
     for (size_t j = 0; j < ctxs_.size(); ++j) {
@@ -70,7 +71,7 @@ struct TestGatherOpHandle {
   }

   void InitGatherOp(size_t input_scope_idx) {
-    std::vector<std::unique_ptr<ir::Node>> nodes;
+    nodes_.clear();
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       local_scopes_.push_back(&(g_scope_.NewScope()));
       Scope& local_scope = local_scopes_.back()->NewScope();
@@ -82,44 +83,45 @@ struct TestGatherOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("out");

-    nodes.emplace_back(
+    nodes_.emplace_back(
         ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release());
-    op_handle_.reset(
-        new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
+    op_handle_ =
+        new GatherOpHandle(nodes_.back().get(), local_scopes_, gpu_list_);

     // add input
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
-      nodes.emplace_back(
+      nodes_.emplace_back(
          ir::CreateNodeForTest("node1", ir::Node::Type::kVariable).release());
       auto* in_var_handle =
-          new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
+          new VarHandle(nodes_.back().get(), 1, j, "input", gpu_list_[j]);
       vars_.emplace_back(in_var_handle);
       op_handle_->AddInput(in_var_handle);
     }

     // add dummy var
-    nodes.emplace_back(
+    nodes_.emplace_back(
         ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release());
-    vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
+    vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
     DummyVarHandle* in_dummy_var_handle =
-        static_cast<DummyVarHandle*>(vars_.back().get());
+        static_cast<DummyVarHandle*>(vars_.back());
     in_dummy_var_handle->ClearGeneratedOp();
     op_handle_->AddInput(in_dummy_var_handle);

     // add output
-    nodes.emplace_back(
+    nodes_.emplace_back(
         ir::CreateNodeForTest("node3", ir::Node::Type::kVariable).release());
-    auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
-                                         "out", gpu_list_[input_scope_idx]);
+    auto* out_var_handle =
+        new VarHandle(nodes_.back().get(), 2, input_scope_idx, "out",
+                      gpu_list_[input_scope_idx]);
     vars_.emplace_back(out_var_handle);
     op_handle_->AddOutput(out_var_handle);

     // add dummy var
-    nodes.emplace_back(
+    nodes_.emplace_back(
         ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release());
-    vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
+    vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
     DummyVarHandle* dummy_var_handle =
-        static_cast<DummyVarHandle*>(vars_.back().get());
+        static_cast<DummyVarHandle*>(vars_.back());
     op_handle_->AddOutput(dummy_var_handle);
   }
......
@@ -16,6 +16,7 @@
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/op_graph_view.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"

 namespace paddle {
 namespace framework {
@@ -35,10 +36,10 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(

 std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
     std::unique_ptr<ir::Graph> ir_graph) const {
-  auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps);
+  auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
   OpGraphView graph_view(all_ops);
   for (auto &op : all_ops) {
-    auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
+    auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
     if (compute_op == nullptr) continue;
     bool is_lock_and_record_event_free =
         IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
......
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
 #include <string>
 #include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"

 namespace paddle {
 namespace framework {
@@ -36,20 +37,20 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
   for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
     for (auto &name_pair : var_map) {
       for (auto &version_pair : name_pair.second) {
-        insert_pending_var(version_pair.get());
+        insert_pending_var(version_pair);
       }
     }
   }

   for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
-    insert_pending_var(var.get());
+    insert_pending_var(var);
   }

-  for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
+  for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
     if (op->Inputs().empty()) {
-      ready_ops.insert(op.get());
+      ready_ops.insert(op);
     } else {
-      pending_ops.insert({op.get(), op.get()->NoDupInputSize()});
+      pending_ops.insert({op, op->NoDupInputSize()});
     }
   }
@@ -89,6 +90,4 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
 REGISTER_PASS(multi_devices_check_pass,
               paddle::framework::details::SSAGraghBuilderWithChecker)
     .RequireGraphAttr(paddle::framework::details::kGraphVars)
-    .RequireGraphAttr(paddle::framework::details::kGraphDepVars)
-    .RequireGraphAttr(paddle::framework::details::kGraphOps)
-    .RequireGraphAttr(paddle::framework::details::kShardedVarDevice);
+    .RequireGraphAttr(paddle::framework::details::kGraphDepVars);
@@ -34,7 +34,14 @@
 namespace paddle {
 namespace framework {
 namespace details {
 namespace {
+// TODO(panyx0718): Clean this up as well.
+// all operators. NOTE that even we use a vector here, the operators is
+// unordered.
+typedef std::vector<OpHandleBase *> GraphOps;
+const char kGraphOps[] = "ops";
+
 void PolishGraphToSupportDataHazards(ir::Graph *graph) {
   for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
     for (auto &name_pair : var_map) {
@@ -92,7 +99,7 @@ VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
     }
     var_holder.emplace_back(var);
   } else {
-    var = var_holder.rbegin()->get();
+    var = *var_holder.rbegin();
   }
   return var;
 }
@@ -154,7 +161,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
                                                 ir::Node *node,
                                                 size_t place_id) const {
   auto p = places_[place_id];
-  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
   op_handle->SetDeviceContext(p,
                               platform::DeviceContextPool::Instance().Get(p));
@@ -303,7 +310,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   result.Set(kGraphVars, new GraphVars(places_.size()));
   result.Set(kGraphDepVars, new GraphDepVars);
   result.Set(kGraphOps, new GraphOps);
-  result.Set(kShardedVarDevice, new ShardedVarDevice);

   // find send/recv vars so that we can place the distributed training
   // related op in the place 0
@@ -317,11 +323,13 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   bool is_forwarding = true;
   bool is_dist_train = false;

+  std::unordered_map<std::string, int> sharded_var_device;
+
   for (ir::Node *node : sorted_ops) {
     if (boost::get<int>(
             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
-      int op_dev_id = CreateRPCOp(&result, node);
+      int op_dev_id = CreateRPCOp(&result, node, &sharded_var_device);
       PADDLE_ENFORCE(op_dev_id != -1,
                      "Can not schedule the RPC operator to the right place.");
       if (node->Op()->Type() == "recv") {
@@ -337,7 +345,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
     } else if (boost::get<int>(node->Op()->GetAttr(
                    OpProtoAndCheckerMaker::OpRoleAttrName())) ==
                static_cast<int>(OpRole::kDist)) {
-      int op_dev_id = CreateDistTrainOp(&result, node);
+      int op_dev_id = CreateDistTrainOp(&result, node, &sharded_var_device);
       if (node->Op()->Type() == "concat") {
         auto origin_param_name = node->Op()->OutputArgumentNames()[0];
         bcast_var_name_set[op_dev_id].emplace(origin_param_name);
@@ -356,12 +364,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
       // the block.
       is_forwarding = false;
     } else {
-      int op_dev_id = GetOpDeviceID(result, node);
+      int op_dev_id = GetOpDeviceID(result, node, sharded_var_device);
       if (op_dev_id != -1) {  // This op only runs on one specific device.
         CreateComputationalOp(&result, node, op_dev_id);
         for (ir::Node *n : node->outputs) {
-          graph->Get<ShardedVarDevice>(kShardedVarDevice)
-              .emplace(n->Name(), op_dev_id);
+          sharded_var_device.emplace(n->Name(), op_dev_id);
         }
       } else {
         // This op runs on all devices, and its output may have parameter's
@@ -398,8 +405,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
           case BuildStrategy::ReduceStrategy::kReduce:
             cur_device_id = GetAppropriateDeviceID({g_name});
             CreateReduceOp(&result, g_name, cur_device_id);
-            graph->Get<ShardedVarDevice>(kShardedVarDevice)
-                .emplace(g_name, cur_device_id);
+            sharded_var_device.emplace(g_name, cur_device_id);
             if (!is_dist_train) {
               bcast_var_name_set[cur_device_id].emplace(p_name);
             }
@@ -458,7 +464,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
    * Only variables should be the leaves of graph.
    */
   AddOutputToLeafOps(&result);
-  PADDLE_ENFORCE(!ir::HasCircle(result));
+  result.Erase<GraphOps>(kGraphOps);
   return graph;
 }
@@ -498,7 +504,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
   result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);

   auto *in =
-      result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back().get();
+      result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back();
   op_handle->AddInput(in);

   for (size_t i = 0; i < places_.size(); ++i) {
@@ -535,7 +541,7 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
   for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) {
     for (auto &p_name : bcast_varnames[dev_id]) {
       auto *in =
-          result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back().get();
+          result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back();
       op_handle->AddInput(in);
       for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) {
         auto &p = places_[out_dev_id];
@@ -571,7 +577,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
       result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
       local_scopes_, places_));
 #endif
-  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back();

   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
@@ -579,7 +585,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
     auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
-    op_handle->AddInput(prev_grad.get());
+    op_handle->AddInput(prev_grad);

     auto var =
         new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
@@ -600,14 +606,14 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
       result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
       local_scopes_, places_));
 #endif
-  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
     for (const std::string &d_name : datas) {
       auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
       PADDLE_ENFORCE(!vars.empty());
-      op_handle->AddInput(vars.back().get());
+      op_handle->AddInput(vars.back());
       auto var = new VarHandle(
           result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
           vars.size(), i, d_name, p);
@@ -617,8 +623,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
   }
 }

-int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
-                                           ir::Node *node) const {
+int MultiDevSSAGraphBuilder::GetOpDeviceID(
+    const ir::Graph &graph, ir::Node *node,
+    const std::unordered_map<std::string, int> &sharded_var_device) const {
   if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
     return -1;
   }
@@ -631,16 +638,22 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
       node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
   PADDLE_ENFORCE_EQ(param_grad.size(), 2U);

-  int dev_id = GetVarDeviceID(graph, param_grad[1]);
+  int dev_id = GetVarDeviceID(graph, param_grad[1], sharded_var_device);
   PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
                     node->Op()->Type(), param_grad[0], param_grad[1]);
   return dev_id;
 }

-int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
-                                            const std::string &varname) const {
-  auto &sharded_var_device = graph.Get<ShardedVarDevice>(kShardedVarDevice);
+int MultiDevSSAGraphBuilder::GetVarDeviceID(
+    const ir::Graph &graph, const std::string &varname,
+    const std::unordered_map<std::string, int> &sharded_var_device) const {
   auto got = sharded_var_device.find(varname);
+  if (got == sharded_var_device.end()) {
+    auto pos = varname.find(framework::kNewGradSuffix);
+    if (pos != std::string::npos) {
+      got = sharded_var_device.find(varname.substr(0, pos));
+    }
+  }
   return got == sharded_var_device.end() ? -1 : got->second;
 }
@@ -690,7 +703,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
       result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
       local_scopes_, places_));
 #endif
-  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back();

   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
@@ -698,7 +711,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
     auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
-    op_handle->AddInput(prev_grad.get());
+    op_handle->AddInput(prev_grad);
   }
   auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og];
   auto var =
@@ -709,8 +722,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
   return var;
 }

-int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
-                                               ir::Node *node) const {
+int MultiDevSSAGraphBuilder::CreateDistTrainOp(
+    ir::Graph *result, ir::Node *node,
+    std::unordered_map<std::string, int> *sharded_var_device) const {
   int op_dev_id = -1;
   std::vector<std::string> input_var_names;
   std::vector<std::string> output_var_names;
@@ -725,23 +739,22 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(
       node->Op()->Type() == "split_selected_rows" ||
       node->Op()->Type() == "split_ids") {
     // TODO(paddle-dev): getting the first var is not safe.
-    op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
+    op_dev_id =
+        GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
     if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
       op_dev_id = GetAppropriateDeviceID(input_var_names);
       for (auto &varname : input_var_names) {
-        result->Get<ShardedVarDevice>(kShardedVarDevice)
-            .emplace(varname, op_dev_id);
+        sharded_var_device->emplace(varname, op_dev_id);
       }
     }
     for (auto &varname : output_var_names) {
-      result->Get<ShardedVarDevice>(kShardedVarDevice)
-          .emplace(varname, op_dev_id);
+      sharded_var_device->emplace(varname, op_dev_id);
     }
   } else if (node->Op()->Type() == "concat") {
-    op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
+    op_dev_id =
+        GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
     for (auto &varname : output_var_names) {
-      result->Get<ShardedVarDevice>(kShardedVarDevice)
-          .emplace(varname, op_dev_id);
+      sharded_var_device->emplace(varname, op_dev_id);
     }
   } else {
     LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
@@ -759,14 +772,14 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(
 }

 void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
-  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
   for (ir::Node *input : node->inputs) {
     VarHandle *var = nullptr;
     for (int place_offset = 0; place_offset < num_places; ++place_offset) {
       auto &var_holders = result->Get<GraphVars>(kGraphVars)[place_offset];
       auto &var_holder = var_holders[input->Name()];
       if (!var_holder.empty()) {
-        var = var_holder.rbegin()->get();
+        var = *var_holder.rbegin();
         op_handle->AddInput(var);
       }
     }
@@ -774,12 +787,14 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
 }

 // Create RPC related op handles that connects its in ops and out ops.
-int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
-                                         ir::Node *node) const {
+int MultiDevSSAGraphBuilder::CreateRPCOp(
+    ir::Graph *result, ir::Node *node,
+    std::unordered_map<std::string, int> *sharded_var_device) const {
   int op_dev_id = -1;
   if (node->Op()->Type() == "send") {
     // TODO(paddle-dev): getting the first var is not safe.
-    op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name());
+    op_dev_id =
+        GetVarDeviceID(*result, node->inputs[0]->Name(), *sharded_var_device);
     PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
                    "This hack no longer holds, please fix.");
     // the variable name which contains .block means it was splited by
@@ -797,11 +812,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
       VLOG(10) << "send grad " << input_var_names[0] << " origin "
                << send_param_grad[1] << " place: " << op_dev_id;
       for (auto &varname : input_var_names) {
-        result->Get<ShardedVarDevice>(kShardedVarDevice)
-            .emplace(varname, op_dev_id);
+        sharded_var_device->emplace(varname, op_dev_id);
       }
-      result->Get<ShardedVarDevice>(kShardedVarDevice)
-          .emplace(send_param_grad[1], op_dev_id);
+      sharded_var_device->emplace(send_param_grad[1], op_dev_id);
     }
   } else if (node->Op()->Type() == "recv") {
     std::vector<std::string> output_var_names;
@@ -811,7 +824,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
     auto recv_param_grad = boost::get<std::vector<std::string>>(
         node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
     if (recv_param_grad.size() == 2U) {
-      op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]);
+      op_dev_id =
+          GetVarDeviceID(*result, recv_param_grad[1], *sharded_var_device);
       VLOG(10) << "recv param " << recv_param_grad[0]
                << " get grad place: " << recv_param_grad[1]
                << " place: " << op_dev_id;
@@ -819,8 +833,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
       op_dev_id = GetAppropriateDeviceID(output_var_names);
     }
     for (auto &varname : output_var_names) {
-      result->Get<ShardedVarDevice>(kShardedVarDevice)
-          .emplace(varname, op_dev_id);
+      sharded_var_device->emplace(varname, op_dev_id);
     }
   } else {
     // send_barrier, fetch_barrier will run on place 0;
@@ -839,7 +852,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
   // send_barrier, recv, fetch_barrier's inputs are deps var, get them from
   // all places
   auto p = places_[op_dev_id];
-  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
   op_handle->SetDeviceContext(p,
                               platform::DeviceContextPool::Instance().Get(p));
@@ -847,7 +860,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
   for (ir::Node *output : node->outputs) {
     int outvar_dev_id = op_dev_id;
     if (node->Op()->Type() == "fetch_barrier") {
-      outvar_dev_id = GetVarDeviceID(*result, output->Name());
+      outvar_dev_id =
+          GetVarDeviceID(*result, output->Name(), *sharded_var_device);
       PADDLE_ENFORCE_NE(outvar_dev_id, -1);
     }
     p = places_[outvar_dev_id];
......
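Besides plumbing sharded_var_device through CreateRPCOp and CreateDistTrainOp, GetVarDeviceID in the pass above gains a fallback: if the variable name is not found directly, it strips a gradient suffix (framework::kNewGradSuffix) and retries. A small self-contained sketch of that lookup, with a hypothetical "@GRAD" standing in for the real suffix constant:

    #include <string>
    #include <unordered_map>

    // Mirrors the new lookup logic: direct hit first, then retry with the
    // gradient suffix stripped; -1 means "not sharded to any device".
    int GetVarDeviceIDSketch(
        const std::unordered_map<std::string, int>& sharded_var_device,
        const std::string& varname, const std::string& grad_suffix = "@GRAD") {
      auto got = sharded_var_device.find(varname);
      if (got == sharded_var_device.end()) {
        auto pos = varname.find(grad_suffix);
        if (pos != std::string::npos) {
          got = sharded_var_device.find(varname.substr(0, pos));
        }
      }
      return got == sharded_var_device.end() ? -1 : got->second;
    }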
@@ -44,12 +44,18 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
   mutable platform::NCCLContextMap *nccl_ctxs_;
 #endif

-  int GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const;
+  int GetVarDeviceID(
+      const ir::Graph &graph, const std::string &varname,
+      const std::unordered_map<std::string, int> &sharded_var_device) const;

   bool IsScaleLossOp(ir::Node *node) const;

-  int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
-  int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
+  int CreateRPCOp(
+      ir::Graph *result, ir::Node *node,
+      std::unordered_map<std::string, int> *sharded_var_device) const;
+  int CreateDistTrainOp(
+      ir::Graph *result, ir::Node *node,
+      std::unordered_map<std::string, int> *sharded_var_device) const;

   std::vector<std::string> FindDistTrainSendVars(
       const std::vector<ir::Node *> &nodes) const;
@@ -69,7 +75,9 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
   void CreateComputationalOp(ir::Graph *result, ir::Node *node,
                              int dev_id) const;

-  int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const;
+  int GetOpDeviceID(
+      const ir::Graph &graph, ir::Node *node,
+      const std::unordered_map<std::string, int> &sharded_var_device) const;

   void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
......
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
 #include <string>
 #include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"

 namespace paddle {
 namespace framework {
@@ -62,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
       });

   size_t op_id = 0;
-  for (auto &op : graph.Get<GraphOps>(kGraphOps)) {
+  for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(graph)) {
     std::string op_name = "op_" + std::to_string(op_id++);
     sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
          << std::endl;
......
@@ -35,23 +35,14 @@ namespace details {
 // The outside vector is the device vector. Each element of this vector is a
 // map from variable name to variables. The variables, who have the same name,
 // will have a differsent version. The offset in the
-// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles.
-typedef std::vector<
-    std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
+// `std::vector<VarHandle*>` is the version of varaibles.
+typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle*>>>
     GraphVars;
 const char kGraphVars[] = "vars";

 // aux variables to represent dependency. Useful to resolve data hazard.
-typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
+typedef std::unordered_set<VarHandleBase*> GraphDepVars;
 const char kGraphDepVars[] = "dep_vars";

-// all operators. NOTE that even we use a vector here, the operators is
-// unordered.
-typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
-const char kGraphOps[] = "ops";
-
-typedef std::unordered_map<std::string, int> ShardedVarDevice;
-const char kShardedVarDevice[] = "sharded_var_device";
-
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
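With these typedefs the graph attributes hold raw pointers, the kGraphOps and kShardedVarDevice attributes disappear from the helper header, and the wrapped ir::Node objects (owned by the graph) carry ownership of the handles. A tiny illustration of how the reshaped GraphVars container is indexed after the change (device index, then variable name, then version list), using a simplified stand-in for VarHandle:

    #include <string>
    #include <unordered_map>
    #include <vector>

    struct FakeVarHandle { int version; };  // stand-in for VarHandle

    using GraphVarsSketch =
        std::vector<std::unordered_map<std::string, std::vector<FakeVarHandle*>>>;

    // Latest version of a variable on one device; callers previously wrote
    // holder.back().get(), now simply holder.back().
    FakeVarHandle* LatestVersion(const GraphVarsSketch& vars, size_t dev_id,
                                 const std::string& name) {
      const auto& holder = vars.at(dev_id).at(name);
      return holder.empty() ? nullptr : holder.back();
    }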
@@ -20,19 +20,16 @@ namespace paddle {
 namespace framework {
 namespace details {

-OpGraphView::OpGraphView(
-    const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
-  Build(ops);
-}
+OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }

-void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
+void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
   for (auto &op : ops) {
-    preceding_ops_[op.get()];
-    pending_ops_[op.get()];
+    preceding_ops_[op];
+    pending_ops_[op];
     for (auto &var : op->Outputs()) {
       for (auto &pending_op : var->PendingOps()) {
-        preceding_ops_[pending_op].insert(op.get());
-        pending_ops_[op.get()].insert(pending_op);
+        preceding_ops_[pending_op].insert(op);
+        pending_ops_[op].insert(pending_op);
       }
     }
   }
@@ -41,8 +38,6 @@ void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
                  "There are duplicate ops in graph.");
 }

-size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); }
-
 std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
   std::unordered_set<OpHandleBase *> ret;
   for (auto &pair : preceding_ops_) {
@@ -60,12 +55,6 @@ void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
                  op == nullptr ? "nullptr" : op->DebugString());
 }

-const std::unordered_set<OpHandleBase *> &OpGraphView::PrecedingOps(
-    OpHandleBase *op) const {
-  EnforceHasOp(op);
-  return preceding_ops_.at(op);
-}
-
 const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
     OpHandleBase *op) const {
   EnforceHasOp(op);
......
@@ -26,21 +26,16 @@ namespace details {
 class OpGraphView {
  public:
-  explicit OpGraphView(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
+  explicit OpGraphView(const std::vector<OpHandleBase *> &ops);

-  size_t OpNumber() const;
-
   std::unordered_set<OpHandleBase *> AllOps() const;

-  const std::unordered_set<OpHandleBase *> &PrecedingOps(
-      OpHandleBase *op) const;
-
   const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;

   bool HasOp(OpHandleBase *op) const;

  private:
-  void Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
+  void Build(const std::vector<OpHandleBase *> &ops);
   void EnforceHasOp(OpHandleBase *op) const;

   std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
......
@@ -31,7 +31,10 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
 // It's responsible for populating necessary fields of ir::Node.
 class OpHandleBase {
  public:
-  explicit OpHandleBase(ir::Node *node) : node_(node) {}
+  // Owned by `node`. No need to be deleted explicitly.
+  explicit OpHandleBase(ir::Node *node) : node_(node) {
+    node_->WrappedBy(this);
+  }

   virtual ~OpHandleBase();
......
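This constructor change is what makes the raw-pointer conversions above safe: each OpHandleBase registers itself with its ir::Node via WrappedBy, and the node (owned by the graph) becomes responsible for deleting the handle, so nothing else needs a unique_ptr to it. A hedged, self-contained sketch of that idea with stand-in types (the real ir::Node::WrappedBy machinery differs in detail):

    #include <functional>

    struct NodeSketch {
      std::function<void()> wrapper_deleter_;  // set when a handle wraps this node

      template <typename T>
      void WrappedBy(T* wrapper) {
        wrapper_deleter_ = [wrapper] { delete wrapper; };
      }
      ~NodeSketch() {
        if (wrapper_deleter_) wrapper_deleter_();  // node deletes its wrapper
      }
    };

    struct OpHandleSketch {
      explicit OpHandleSketch(NodeSketch* node) : node_(node) {
        node_->WrappedBy(this);  // hand ownership to the node
      }
      virtual ~OpHandleSketch() = default;
      NodeSketch* node_;
    };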
@@ -30,8 +30,8 @@ struct TestReduceOpHandle {
   Scope g_scope_;
   std::vector<Scope *> local_scopes_;
   std::vector<Scope *> param_scopes_;
-  std::unique_ptr<OpHandleBase> op_handle_;
-  std::vector<std::unique_ptr<VarHandleBase>> vars_;
+  OpHandleBase *op_handle_;
+  std::vector<VarHandleBase *> vars_;
   std::vector<p::Place> gpu_list_;
   std::vector<std::unique_ptr<p::DeviceContext>> ctxs_;
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass.h" #include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -71,14 +72,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -71,14 +72,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
// Step 2: Find all variables in non-computation ops which refer to variables // Step 2: Find all variables in non-computation ops which refer to variables
// in computation ops // in computation ops
std::unordered_set<std::string> names; std::unordered_set<std::string> names;
std::unordered_map<OpHandleBase *, std::unique_ptr<ReferenceCountOpHandle>> std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
compute_ref_cnt_map; compute_ref_cnt_map;
auto get_ref_cnts_from_compute_op = [&]( auto get_ref_cnts_from_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op, OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
const std::vector<VarHandleBase *> &vars) {
std::vector<std::string> var_names_in_op; std::vector<std::string> var_names_in_op;
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get()); auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr || if (compute_op == nullptr ||
!platform::is_gpu_place(compute_op->GetPlace())) !platform::is_gpu_place(compute_op->GetPlace()))
return var_names_in_op; return var_names_in_op;
...@@ -121,9 +121,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -121,9 +121,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
}; };
auto update_ref_cnts_from_non_compute_op = [&]( auto update_ref_cnts_from_non_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op, OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
const std::vector<VarHandleBase *> &vars) { if (dynamic_cast<ComputationOpHandle *>(op) != nullptr) return;
if (dynamic_cast<ComputationOpHandle *>(op.get()) != nullptr) return;
for (VarHandleBase *var_handle_base : vars) { for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base); auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue; if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
...@@ -151,21 +150,21 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -151,21 +150,21 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
ref_cnt_node, next_compute_op->GetScope(), place, {var_name}, ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
gcs[place.device].get(), cur_ref_cnts[place.device].get()); gcs[place.device].get(), cur_ref_cnts[place.device].get());
AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get()); AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle); compute_ref_cnt_map[next_compute_op] = ref_cnt_handle;
} }
} }
} }
} }
}; };
auto &all_ops = graph->Get<GraphOps>(kGraphOps); auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
for (auto &op : all_ops) { for (auto &op : all_ops) {
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs()); auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
if (in_var_names.empty() && out_var_names.empty()) continue; if (in_var_names.empty() && out_var_names.empty()) continue;
in_var_names.insert(in_var_names.end(), out_var_names.begin(), in_var_names.insert(in_var_names.end(), out_var_names.begin(),
out_var_names.end()); out_var_names.end());
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get()); auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace()); auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
ir::Node *ref_cnt_node = ir::Node *ref_cnt_node =
graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation); graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation);
...@@ -173,7 +172,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -173,7 +172,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
ref_cnt_node, compute_op->GetScope(), place, in_var_names, ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get()); gcs[place.device].get(), cur_ref_cnts[place.device].get());
AddDependencyBetween(compute_op, ref_cnt_handle, graph.get()); AddDependencyBetween(compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[compute_op].reset(ref_cnt_handle); compute_ref_cnt_map[compute_op] = ref_cnt_handle;
} }
for (auto &op : all_ops) { for (auto &op : all_ops) {
...@@ -181,11 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -181,11 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
update_ref_cnts_from_non_compute_op(op, op->Outputs()); update_ref_cnts_from_non_compute_op(op, op->Outputs());
} }
std::vector<std::unique_ptr<OpHandleBase>> new_all_ops; std::vector<OpHandleBase *> new_all_ops;
new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size()); new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size());
for (auto &op : all_ops) { for (auto &op : all_ops) {
new_all_ops.emplace_back(std::move(op)); new_all_ops.emplace_back(std::move(op));
auto it = compute_ref_cnt_map.find(new_all_ops.back().get()); auto it = compute_ref_cnt_map.find(new_all_ops.back());
if (it != compute_ref_cnt_map.end()) { if (it != compute_ref_cnt_map.end()) {
// Add LeafNode to ReferenceCountOpHandle // Add LeafNode to ReferenceCountOpHandle
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
......
...@@ -19,14 +19,16 @@ namespace framework { ...@@ -19,14 +19,16 @@ namespace framework {
namespace details { namespace details {
SSAGraphExecutor::~SSAGraphExecutor() {} SSAGraphExecutor::~SSAGraphExecutor() {}
void ClearFetchOp(ir::Graph* graph, void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops) {
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops) {
if (fetch_ops->empty()) return; if (fetch_ops->empty()) return;
for (auto& op : *fetch_ops) { for (auto& op : *fetch_ops) {
for (auto& out_var : op->Node()->outputs) { for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var); graph->RemoveNode(out_var);
} }
for (auto& in_var : op->Inputs()) {
in_var->RemoveOutput(op, op->Node());
}
graph->RemoveNode(op->Node()); graph->RemoveNode(op->Node());
} }
fetch_ops->clear(); fetch_ops->clear();
......
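The extra loop over op->Inputs() is what keeps the raw-pointer scheme safe: each input VarHandleBase still holds a plain pointer to the fetch op in its pending-op list, so the op has to be unlinked from its inputs before its node (and with it the handle itself) is removed from the graph. A hedged fragment of the caller side, mirroring the threaded executor further down:

// Hedged caller-side fragment (see threaded_ssa_graph_executor.cc below for the real call site).
std::vector<FetchOpHandle *> fetch_ops;
// ... temporary fetch ops and their dummy output vars are created and wired up here ...
ClearFetchOp(graph_.get(), &fetch_ops);  // unlink each op from its input var handles, remove its
                                         // nodes from the graph, then clear the non-owning vector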
...@@ -38,8 +38,7 @@ class SSAGraphExecutor { ...@@ -38,8 +38,7 @@ class SSAGraphExecutor {
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0; virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
}; };
void ClearFetchOp(ir::Graph* graph, void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops);
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops);
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -51,25 +52,25 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -51,25 +52,25 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, ready_vars.get(), version_pair.get()); InsertPendingVar(&pending_vars, ready_vars.get(), version_pair);
} }
} }
} }
for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) { for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, ready_vars.get(), var.get()); InsertPendingVar(&pending_vars, ready_vars.get(), var);
} }
for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) { for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
if (op->Inputs().empty()) { // Special case, Op has no input. if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get()); ready_ops.insert(op);
} else { } else {
InsertPendingOp(&pending_ops, op.get()); InsertPendingOp(&pending_ops, op);
} }
} }
// Step 2. Insert FetchOps // Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops; std::vector<FetchOpHandle *> fetch_ops;
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies; std::unordered_set<VarHandleBase *> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size()); FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops, InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
...@@ -109,6 +110,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -109,6 +110,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &run_op_future : run_op_futures_) { for (auto &run_op_future : run_op_futures_) {
run_op_future.wait(); run_op_future.wait();
} }
ClearFetchOp(graph_.get(), &fetch_ops);
exception_holder_.ReThrow(); exception_holder_.ReThrow();
} else { } else {
continue; continue;
...@@ -140,8 +142,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -140,8 +142,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps( void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, std::vector<FetchOpHandle *> *fetch_ops,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies, std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars, std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) { BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) {
...@@ -151,7 +153,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -151,7 +153,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
auto it = var_map.find(fetch_var_name); auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) { if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
} }
} }
} }
......
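Besides the pointer-type changes, Run() gains a cleanup call on the error path: the temporary fetch ops are now torn down before the stored exception is rethrown, so a failed run no longer leaves fetch nodes behind in the graph. A hedged fragment of that path, assembled from the calls shown in this hunk:

// Hedged fragment of the error path, assembled from calls that appear in this hunk.
for (auto &run_op_future : run_op_futures_) {
  run_op_future.wait();                    // let in-flight ops finish first
}
ClearFetchOp(graph_.get(), &fetch_ops);    // new: drop the temporary fetch ops and their nodes
exception_holder_.ReThrow();               // then propagate the stored exception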
...@@ -70,13 +70,13 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -70,13 +70,13 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
BlockingQueue<VarHandleBase *> *ready_vars, BlockingQueue<VarHandleBase *> *ready_vars,
VarHandleBase *var) const; VarHandleBase *var) const;
void InsertFetchOps( void InsertFetchOps(const std::vector<std::string> &fetch_tensors,
const std::vector<std::string> &fetch_tensors, std::vector<FetchOpHandle *> *fetch_ops,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_set<VarHandleBase *> *pending_vars,
std::unordered_set<VarHandleBase *> *pending_vars, BlockingQueue<VarHandleBase *> *ready_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data); FeedFetchList *fetch_data);
private: private:
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
......
...@@ -20,6 +20,8 @@ namespace details { ...@@ -20,6 +20,8 @@ namespace details {
VarHandleBase::~VarHandleBase() {} VarHandleBase::~VarHandleBase() {}
VarHandle::~VarHandle() { VLOG(4) << "deleting var handle " << DebugString(); }
std::string VarHandle::DebugString() const { std::string VarHandle::DebugString() const {
std::stringstream ss; std::stringstream ss;
ss << name_ << ":" << place_; ss << name_ << ":" << place_;
...@@ -27,6 +29,10 @@ std::string VarHandle::DebugString() const { ...@@ -27,6 +29,10 @@ std::string VarHandle::DebugString() const {
} }
std::string DummyVarHandle::DebugString() const { return node_->Name(); } std::string DummyVarHandle::DebugString() const { return node_->Name(); }
DummyVarHandle::~DummyVarHandle() {
VLOG(4) << "deleting dummy var handle " << DebugString();
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -35,7 +35,10 @@ class OpHandleBase; ...@@ -35,7 +35,10 @@ class OpHandleBase;
// A variable can only be generated by a single operator. i.e. // A variable can only be generated by a single operator. i.e.
// This is a single assignment graph. // This is a single assignment graph.
struct VarHandleBase { struct VarHandleBase {
explicit VarHandleBase(ir::Node* node) : node_(node) {} // Owned by `node`. No need to be deleted explicitly.
explicit VarHandleBase(ir::Node* node) : node_(node) {
node_->WrappedBy(this);
}
virtual ~VarHandleBase(); virtual ~VarHandleBase();
...@@ -94,6 +97,8 @@ struct VarHandleBase { ...@@ -94,6 +97,8 @@ struct VarHandleBase {
struct VarHandle : public VarHandleBase { struct VarHandle : public VarHandleBase {
explicit VarHandle(ir::Node* node) : VarHandleBase(node) {} explicit VarHandle(ir::Node* node) : VarHandleBase(node) {}
virtual ~VarHandle();
std::string DebugString() const override; std::string DebugString() const override;
VarHandle(ir::Node* node, size_t version, size_t scope_index, VarHandle(ir::Node* node, size_t version, size_t scope_index,
...@@ -121,6 +126,8 @@ struct VarHandle : public VarHandleBase { ...@@ -121,6 +126,8 @@ struct VarHandle : public VarHandleBase {
struct DummyVarHandle : public VarHandleBase { struct DummyVarHandle : public VarHandleBase {
explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {} explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {}
virtual ~DummyVarHandle();
std::string DebugString() const override; std::string DebugString() const override;
}; };
......
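Variable handles get the same ownership inversion as op handles: VarHandleBase registers itself with its node, and the new virtual destructors only log. The practical consequence is that a handle's lifetime now ends when its node leaves the graph; a hedged lifetime sketch, reusing the DummyVarHandle construction that appears verbatim in the reference-count pass above:

// Hedged lifetime sketch. Creation: the control-dep node created here owns the handle.
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
fetch_dependencies->insert(dummy_leaf);   // raw-pointer set, e.g. the one in the threaded executor
// Destruction: removing the node is what deletes the handle; ~DummyVarHandle() logs via VLOG(4).
graph->RemoveNode(dummy_leaf->Node());    // assumes VarHandleBase exposes Node(), as OpHandleBase does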
...@@ -53,6 +53,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") ...@@ -53,6 +53,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
cc_library(pass_builder SRCS pass_builder.cc DEPS pass) cc_library(pass_builder SRCS pass_builder.cc DEPS pass)
cc_test(node_test SRCS node_test.cc DEPS node)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
......
...@@ -102,6 +102,15 @@ class Graph { ...@@ -102,6 +102,15 @@ class Graph {
attr_dels_[attr_name] = []() {}; attr_dels_[attr_name] = []() {};
} }
template <typename AttrType>
void Erase(const std::string &attr_name) {
PADDLE_ENFORCE(attrs_.count(attr_name) != 0, "%s not set in the graph",
attr_name);
attr_dels_[attr_name]();
attrs_.erase(attr_name);
attr_dels_.erase(attr_name);
}
const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; } const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; }
// Create a normal variable with non-null VarDesc. // Create a normal variable with non-null VarDesc.
......
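Graph::Erase is the missing counterpart to Set/Get: a pass can now drop an attribute once it is done with it instead of leaving the owned object parked on the graph (this is what lets later code stop carrying the kGraphOps attribute around). A minimal hedged usage sketch; the attribute name is invented, and Set/Get are assumed to follow the same ownership convention they have elsewhere in this code base:

// Hedged sketch; "pass_scratch" is an invented attribute name.
graph->Set<int>("pass_scratch", new int(0));   // the graph takes ownership and records a deleter
++graph->Get<int>("pass_scratch");             // Get returns a reference to the stored object
graph->Erase<int>("pass_scratch");             // runs the stored deleter, then forgets the attribute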
...@@ -37,6 +37,15 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph); ...@@ -37,6 +37,15 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList( std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph); const Graph &graph);
template <typename T>
std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
std::vector<T *> ret;
for (ir::Node *n : graph.Nodes()) {
if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>());
}
return ret;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
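FilterByNodeWrapper is what replaces reads of the kGraphOps attribute in the pass and executor changes earlier in this commit: any object wrapped into a node via Node::WrappedBy (see the node.h hunk below) can be recovered from the graph by its type. A hedged usage fragment matching the call sites that appear above:

// Hedged usage fragment, mirroring reference_count_pass.cc and threaded_ssa_graph_executor.cc above.
auto ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);   // std::vector<OpHandleBase *>
for (OpHandleBase *op : ops) {
  VLOG(4) << "op handle found in graph: " << op->DebugString();
}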
...@@ -15,7 +15,10 @@ limitations under the License. */ ...@@ -15,7 +15,10 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include <typeindex>
#include <typeinfo>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
...@@ -24,9 +27,33 @@ namespace paddle { ...@@ -24,9 +27,33 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// Node should normally be created by Graph::CreateXXXNode(). // Node should only be created by Graph::CreateXXXNode().
// 1. Every Node should be part of a graph. No dangling Node exists.
// 2. Node only contains members necessary for building graph structure.
// It doesn't contain other unrelated members, such as device, etc.
//
// Sometimes, for specific usages, Node needs to have additional members,
// such as device_placement, version in order to be executed. It is suggested
// to use composition pattern.
//
// class RunnableOp {
// RunnableOp(ir::Node* n) : n_(n) { n_.WrappedBy(this); }
//
// int any_thing_;
// }
//
// RunnableOp is owned by the ir::Node that composes it. In other words,
// ir::Node will be responsible for deleting RunnableOp, say, when ir::Node
// is deleted from the graph.
class Node { class Node {
public: public:
virtual ~Node() {
if (!wrapper_.empty()) {
VLOG(4) << "ir::Node deleting a wrapper node " << Name();
wrapper_deleter_();
}
}
enum class Type { kOperation, kVariable }; enum class Type { kOperation, kVariable };
static constexpr char kControlDepVarName[] = "__control_var"; static constexpr char kControlDepVarName[] = "__control_var";
...@@ -44,6 +71,29 @@ class Node { ...@@ -44,6 +71,29 @@ class Node {
return op_desc_.get(); return op_desc_.get();
} }
// Set the `wrapper` that wraps the Node. `wrapper` is owned by Node.
template <typename T>
void WrappedBy(T* wrapper) {
if (!wrapper_.empty()) {
wrapper_deleter_();
}
wrapper_ = wrapper;
wrapper_deleter_ = [wrapper]() { delete wrapper; };
wrapper_type_ = std::type_index(typeid(T));
}
// Return a reference to the `wrapper`.
template <typename T>
T& Wrapper() {
return *boost::any_cast<T*>(wrapper_);
}
// Test if the Node is wrapped by type T.
template <typename T>
bool IsWrappedBy() {
return std::type_index(typeid(T)) == wrapper_type_;
}
// Please don't use this API! // Please don't use this API!
int id() const { return id_; } int id() const { return id_; }
...@@ -95,6 +145,11 @@ class Node { ...@@ -95,6 +145,11 @@ class Node {
static int count_; static int count_;
// Please don't use this API or make this public. // Please don't use this API or make this public.
static void ResetId() { count_ = 0; } static void ResetId() { count_ = 0; }
boost::any wrapper_;
std::function<void(void)> wrapper_deleter_;
std::type_index wrapper_type_ = std::type_index(typeid(void));
DISABLE_COPY_AND_ASSIGN(Node); DISABLE_COPY_AND_ASSIGN(Node);
}; };
......
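The wrapper plumbing added to ir::Node is small enough to model in isolation. The standalone sketch below re-implements the same idea with std::any (C++17) in place of boost::any; it is not Paddle code, only the pattern, and the new node_test.cc right below exercises the real implementation in the same way:

#include <any>
#include <cassert>
#include <functional>
#include <typeindex>
#include <typeinfo>

// Standalone model of the ownership pattern: a node owns one type-erased "wrapper" object,
// remembers its concrete type, and deletes it when the node itself is destroyed.
class MiniNode {
 public:
  ~MiniNode() {
    if (wrapper_.has_value()) wrapper_deleter_();
  }

  template <typename T>
  void WrappedBy(T *wrapper) {
    if (wrapper_.has_value()) wrapper_deleter_();  // replace (and delete) any previous wrapper
    wrapper_ = wrapper;
    wrapper_deleter_ = [wrapper]() { delete wrapper; };
    wrapper_type_ = std::type_index(typeid(T));
  }

  template <typename T>
  T &Wrapper() {
    return *std::any_cast<T *>(wrapper_);
  }

  template <typename T>
  bool IsWrappedBy() const {
    return wrapper_type_ == std::type_index(typeid(T));
  }

 private:
  std::any wrapper_;
  std::function<void(void)> wrapper_deleter_;
  std::type_index wrapper_type_ = std::type_index(typeid(void));
};

// Stand-in for an OpHandleBase subclass: registers itself with the node in its constructor.
struct MiniOpHandle {
  MiniOpHandle(MiniNode *node, bool *alive) : alive_(alive) { node->WrappedBy(this); }
  ~MiniOpHandle() { *alive_ = false; }
  bool *alive_;
};

int main() {
  bool alive = true;
  {
    MiniNode node;
    new MiniOpHandle(&node, &alive);  // no unique_ptr and no delete: the node owns the handle
    assert(node.IsWrappedBy<MiniOpHandle>());
    assert(node.Wrapper<MiniOpHandle>().alive_ == &alive);
  }  // ~MiniNode runs the stored deleter, which runs ~MiniOpHandle
  assert(!alive);
  return 0;
}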
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class RunnableOp {
public:
RunnableOp(Node* node, bool* alive) : node_(node), alive_(alive) {
node_->WrappedBy(this);
}
virtual ~RunnableOp() { *alive_ = false; }
private:
Node* node_;
bool* alive_;
};
class RunnableOp2 {
public:
RunnableOp2(Node* node, bool* alive) : node_(node), alive_(alive) {
node_->WrappedBy(this);
}
virtual ~RunnableOp2() { *alive_ = false; }
private:
Node* node_;
bool* alive_;
};
TEST(NodeTest, Basic) {
bool alive1 = true;
bool alive2 = true;
std::unique_ptr<Node> n1(CreateNodeForTest("n1", Node::Type::kVariable));
std::unique_ptr<Node> n2(CreateNodeForTest("n2", Node::Type::kVariable));
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp>());
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp2>());
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp>());
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp2>());
new RunnableOp(n1.get(), &alive1);
new RunnableOp2(n2.get(), &alive2);
EXPECT_TRUE(n1->IsWrappedBy<RunnableOp>());
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp2>());
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp>());
EXPECT_TRUE(n2->IsWrappedBy<RunnableOp2>());
EXPECT_TRUE(alive1);
EXPECT_TRUE(alive2);
n1.reset(nullptr);
n2.reset(nullptr);
EXPECT_FALSE(alive1);
EXPECT_FALSE(alive2);
}
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -358,7 +358,7 @@ static bool VarIsTensor(const Variable& var) { ...@@ -358,7 +358,7 @@ static bool VarIsTensor(const Variable& var) {
return var.IsType<LoDTensor>() || var.IsType<SelectedRows>(); return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
} }
const Tensor* GetTensorFromVar(const Variable& var) { const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) {
if (var.IsType<LoDTensor>()) { if (var.IsType<LoDTensor>()) {
return static_cast<const Tensor*>(&(var.Get<LoDTensor>())); return static_cast<const Tensor*>(&(var.Get<LoDTensor>()));
} else if (var.IsType<SelectedRows>()) { } else if (var.IsType<SelectedRows>()) {
...@@ -369,7 +369,7 @@ const Tensor* GetTensorFromVar(const Variable& var) { ...@@ -369,7 +369,7 @@ const Tensor* GetTensorFromVar(const Variable& var) {
} }
} }
static Tensor* GetMutableTensorFromVar(Variable* var) { Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>(); return var->GetMutable<LoDTensor>();
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
...@@ -414,8 +414,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const { ...@@ -414,8 +414,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
template <> template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const { const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name); return Input<LoDTensor>(name);
return var == nullptr ? nullptr : GetTensorFromVar(*var);
} }
template <> template <>
...@@ -425,17 +424,21 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>( ...@@ -425,17 +424,21 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
std::vector<const Tensor*> res; std::vector<const Tensor*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { [&](const std::string& sub_name) -> const Tensor* {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr : GetTensorFromVar(*var); if (var == nullptr) return nullptr;
PADDLE_ENFORCE(
var->IsType<LoDTensor>(),
"%s should be LoDTensor, but the received type is %s",
sub_name, var->Type().name());
return &(var->Get<LoDTensor>());
}); });
return res; return res;
} }
template <> template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const { Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
auto var = OutputVar(name); return Output<LoDTensor>(name);
return var == nullptr ? nullptr : GetMutableTensorFromVar(var);
} }
template <> template <>
...@@ -445,10 +448,14 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>( ...@@ -445,10 +448,14 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
std::vector<Tensor*> res; std::vector<Tensor*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { [&](const std::string& sub_name) -> Tensor* {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr if (var == nullptr) return nullptr;
: GetMutableTensorFromVar(var); PADDLE_ENFORCE(
var->IsType<LoDTensor>(),
"%s should be LoDTensor, but the received type is %s",
sub_name, var->Type().name());
return var->GetMutable<LoDTensor>();
}); });
return res; return res;
} }
...@@ -768,11 +775,12 @@ void OperatorWithKernel::TransferInplaceVarsBack( ...@@ -768,11 +775,12 @@ void OperatorWithKernel::TransferInplaceVarsBack(
const Scope& transfer_scope) const { const Scope& transfer_scope) const {
for (auto& var_name : inplace_vars) { for (auto& var_name : inplace_vars) {
VLOG(3) << "share inplace var " + var_name + " back to it's original scope"; VLOG(3) << "share inplace var " + var_name + " back to it's original scope";
auto* original_tensor = GetMutableTensorFromVar(scope.FindVar(var_name)); auto* original_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(scope.FindVar(var_name));
auto* var = transfer_scope.FindVar(var_name); auto* var = transfer_scope.FindVar(var_name);
PADDLE_ENFORCE(var != nullptr, "The var[%s] should not be nullptr", PADDLE_ENFORCE(var != nullptr, "The var[%s] should not be nullptr",
var_name); var_name);
auto* transformed_tensor = GetTensorFromVar(*var); auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
original_tensor->ShareDataWith(*transformed_tensor); original_tensor->ShareDataWith(*transformed_tensor);
} }
} }
...@@ -789,7 +797,7 @@ Scope* OperatorWithKernel::TryTransferData( ...@@ -789,7 +797,7 @@ Scope* OperatorWithKernel::TryTransferData(
continue; continue;
} }
auto* tensor_in = GetTensorFromVar(*var); auto* tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
if (!tensor_in->IsInitialized()) { if (!tensor_in->IsInitialized()) {
continue; continue;
} }
......
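With Input<Tensor> and Output<Tensor> now hard-wired to LoDTensor, a kernel that genuinely wants "the dense payload, whether the variable is a LoDTensor or a SelectedRows" goes through the renamed helpers instead, exactly as the rewritten scale kernel at the end of this diff does. A hedged fragment of that pattern inside some OpKernel<T>::Compute(ctx):

// Hedged fragment inside an OpKernel<T>::Compute(ctx); handles both LoDTensor and SelectedRows "X".
auto *in_var = ctx.InputVar("X");
const framework::Tensor *in =
    framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);   // the tensor itself, or SelectedRows::value()
auto *out_var = ctx.OutputVar("Out");
framework::Tensor *out =
    framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var);
out->mutable_data<T>(in->place());                                // allocate on the same place as the input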
...@@ -54,6 +54,9 @@ constexpr char kGradVarSuffix[] = "@GRAD"; ...@@ -54,6 +54,9 @@ constexpr char kGradVarSuffix[] = "@GRAD";
/// Variables with this suffix are supposed to be filled up with zeros. /// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO"; constexpr char kZeroVarSuffix[] = "@ZERO";
/// Variables with this suffix are the new Gradient.
constexpr char kNewGradSuffix[] = "@NEWGRAD@";
// define some kernel priority // define some kernel priority
/* Define multiple kernel type fallback order*/ /* Define multiple kernel type fallback order*/
extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority; extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
...@@ -63,7 +66,8 @@ inline std::string GradVarName(const std::string& var_name) { ...@@ -63,7 +66,8 @@ inline std::string GradVarName(const std::string& var_name) {
} }
proto::VarType::Type GetDataTypeOfVar(const Variable* var); proto::VarType::Type GetDataTypeOfVar(const Variable* var);
const Tensor* GetTensorFromVar(const Variable& var); const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
class OperatorBase; class OperatorBase;
class ExecutionContext; class ExecutionContext;
...@@ -224,7 +228,7 @@ class ExecutionContext { ...@@ -224,7 +228,7 @@ class ExecutionContext {
std::vector<const T*> res; std::vector<const T*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { [&](const std::string& sub_name) -> const T* {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr : &var->Get<T>(); return var == nullptr ? nullptr : &var->Get<T>();
}); });
...@@ -237,7 +241,7 @@ class ExecutionContext { ...@@ -237,7 +241,7 @@ class ExecutionContext {
std::vector<T*> res; std::vector<T*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { [&](const std::string& sub_name) -> T* {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr : var->GetMutable<T>(); return var == nullptr ? nullptr : var->GetMutable<T>();
}); });
......
...@@ -37,8 +37,8 @@ if(WITH_TESTING) ...@@ -37,8 +37,8 @@ if(WITH_TESTING)
ARGS --word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${PYTHON_TESTS_DIR}/book) ARGS --word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${PYTHON_TESTS_DIR}/book)
set_tests_properties(test_api_impl PROPERTIES DEPENDS test_image_classification) set_tests_properties(test_api_impl PROPERTIES DEPENDS test_image_classification)
endif() endif()
cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_predictor ${inference_deps} paddle_inference_api cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_predictor ${inference_deps}
ARGS --dirname=${PYTHON_TESTS_DIR}/book) ARGS --dirname=${WORD2VEC_MODEL_DIR})
if(WITH_GPU AND TENSORRT_FOUND) if(WITH_GPU AND TENSORRT_FOUND)
cc_library(paddle_inference_tensorrt_subgraph_engine cc_library(paddle_inference_tensorrt_subgraph_engine
......
...@@ -24,7 +24,7 @@ using contrib::AnalysisConfig; ...@@ -24,7 +24,7 @@ using contrib::AnalysisConfig;
TEST(AnalysisPredictor, ZeroCopy) { TEST(AnalysisPredictor, ZeroCopy) {
AnalysisConfig config; AnalysisConfig config;
config.model_dir = FLAGS_dirname + "/word2vec.inference.model"; config.model_dir = FLAGS_dirname;
config.use_feed_fetch_ops = false; config.use_feed_fetch_ops = false;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config); auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
......
...@@ -296,7 +296,6 @@ op_library(cos_sim_op DEPS cos_sim_functor) ...@@ -296,7 +296,6 @@ op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor) op_library(parallel_do_op DEPS executor)
op_library(unsqueeze_op DEPS reshape_op) op_library(unsqueeze_op DEPS reshape_op)
op_library(squeeze_op DEPS reshape_op) op_library(squeeze_op DEPS reshape_op)
op_library(extract_rows_op DEPS memory)
op_library(flatten_op DEPS reshape_op) op_library(flatten_op DEPS reshape_op)
op_library(sequence_pad_op DEPS sequence_padding) op_library(sequence_pad_op DEPS sequence_padding)
op_library(unstack_op DEPS stack_op) op_library(unstack_op DEPS stack_op)
......
...@@ -375,8 +375,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -375,8 +375,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_md = platform::MKLDNNMemDesc( auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc( auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
std::vector<int> bias_tz; // TODO(mgallus): avoid empty vector creation. std::vector<int> bias_tz; // TODO(mgallus): avoid empty vector creation.
// Currently used whenever bias is != nullptr. // Currently used whenever bias is != nullptr.
auto dst_md = platform::MKLDNNMemDesc( auto dst_md = platform::MKLDNNMemDesc(
......
...@@ -28,9 +28,9 @@ struct AddFunctor { ...@@ -28,9 +28,9 @@ struct AddFunctor {
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void default_elementwise_add(const framework::ExecutionContext& ctx, void default_elementwise_add(const framework::ExecutionContext &ctx,
const framework::Tensor* x, const framework::Tensor *x,
const framework::Tensor* y, framework::Tensor* z) { const framework::Tensor *y, framework::Tensor *z) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis, ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
AddFunctor<T>(), z); AddFunctor<T>(), z);
...@@ -40,9 +40,9 @@ template <typename DeviceContext, typename T> ...@@ -40,9 +40,9 @@ template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
std::is_floating_point<T>::value && std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const framework::ExecutionContext& ctx, elementwise_add(const framework::ExecutionContext &ctx,
const framework::Tensor* x, const framework::Tensor* y, const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor* z) { framework::Tensor *z) {
auto eigen_x = framework::EigenVector<T>::Flatten(*x); auto eigen_x = framework::EigenVector<T>::Flatten(*x);
auto eigen_y = framework::EigenVector<T>::Flatten(*y); auto eigen_y = framework::EigenVector<T>::Flatten(*y);
auto eigen_z = framework::EigenVector<T>::Flatten(*z); auto eigen_z = framework::EigenVector<T>::Flatten(*z);
...@@ -55,21 +55,20 @@ template <typename DeviceContext, typename T> ...@@ -55,21 +55,20 @@ template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
!std::is_floating_point<T>::value || !std::is_floating_point<T>::value ||
!std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const framework::ExecutionContext& ctx, elementwise_add(const framework::ExecutionContext &ctx,
const framework::Tensor* x, const framework::Tensor* y, const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor* z) { framework::Tensor *z) {
default_elementwise_add<DeviceContext, T>(ctx, x, y, z); default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseAddKernel : public framework::OpKernel<T> { class ElementwiseAddKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
using Tensor = framework::Tensor; auto *x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *z = ctx.Output<framework::LoDTensor>("Out");
const auto x = ctx.Input<Tensor>("X");
const auto y = ctx.Input<Tensor>("Y");
auto z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
auto dims_equal = x->dims() == y->dims(); auto dims_equal = x->dims() == y->dims();
...@@ -87,13 +86,13 @@ struct IdentityGrad { ...@@ -87,13 +86,13 @@ struct IdentityGrad {
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void default_elementwise_add_grad(const framework::ExecutionContext& ctx, void default_elementwise_add_grad(const framework::ExecutionContext &ctx,
const framework::Tensor* x, const framework::Tensor *x,
const framework::Tensor* y, const framework::Tensor *y,
const framework::Tensor* out, const framework::Tensor *out,
const framework::Tensor* dout, const framework::Tensor *dout,
framework::Tensor* dx, framework::Tensor *dx,
framework::Tensor* dy) { framework::Tensor *dy) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>, ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,
...@@ -106,11 +105,11 @@ template <typename DeviceContext, typename T> ...@@ -106,11 +105,11 @@ template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
std::is_floating_point<T>::value && std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext& ctx, elementwise_add_grad(const framework::ExecutionContext &ctx,
const framework::Tensor* x, const framework::Tensor* y, const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor* out, const framework::Tensor *out,
const framework::Tensor* dout, framework::Tensor* dx, const framework::Tensor *dout, framework::Tensor *dx,
framework::Tensor* dy) { framework::Tensor *dy) {
auto blas = math::GetBlas<DeviceContext, T>(ctx); auto blas = math::GetBlas<DeviceContext, T>(ctx);
if (dx) { if (dx) {
...@@ -128,27 +127,27 @@ template <typename DeviceContext, typename T> ...@@ -128,27 +127,27 @@ template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
!std::is_floating_point<T>::value || !std::is_floating_point<T>::value ||
!std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext& ctx, elementwise_add_grad(const framework::ExecutionContext &ctx,
const framework::Tensor* x, const framework::Tensor* y, const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor* out, const framework::Tensor *out,
const framework::Tensor* dout, framework::Tensor* dx, const framework::Tensor *dout, framework::Tensor *dx,
framework::Tensor* dy) { framework::Tensor *dy) {
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> { class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx); ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
// skip out, x, y // skip out, x, y
auto* out = dout; auto *out = dout;
auto *x = dout, *y = dout; auto *x = dout, *y = dout;
if (platform::is_cpu_place(ctx.GetPlace()) && dx != nullptr && if (platform::is_cpu_place(ctx.GetPlace()) && dx != nullptr &&
......
...@@ -28,11 +28,10 @@ template <typename DeviceContext, typename T> ...@@ -28,11 +28,10 @@ template <typename DeviceContext, typename T>
class ElementwiseDivKernel : public framework::OpKernel<T> { class ElementwiseDivKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor; auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis, ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
......
...@@ -29,11 +29,10 @@ template <typename DeviceContext, typename T> ...@@ -29,11 +29,10 @@ template <typename DeviceContext, typename T>
class ElementwiseMaxKernel : public framework::OpKernel<T> { class ElementwiseMaxKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor; auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis, ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
......
...@@ -28,11 +28,10 @@ template <typename DeviceContext, typename T> ...@@ -28,11 +28,10 @@ template <typename DeviceContext, typename T>
class ElementwiseMinKernel : public framework::OpKernel<T> { class ElementwiseMinKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor; auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx, x, y, axis, ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
......
...@@ -60,11 +60,10 @@ template <typename DeviceContext, typename T> ...@@ -60,11 +60,10 @@ template <typename DeviceContext, typename T>
class ElementwiseMulKernel : public framework::OpKernel<T> { class ElementwiseMulKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor; auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
if (x->numel() == y->numel()) { if (x->numel() == y->numel()) {
elementwise_mul<DeviceContext, T>(ctx, x, y, z); elementwise_mul<DeviceContext, T>(ctx, x, y, z);
......
...@@ -13,10 +13,12 @@ See the License for the specific language governing permissions and ...@@ -13,10 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
...@@ -29,7 +31,8 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -29,7 +31,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of elementwise op should not be null."); "Input(X) of elementwise op should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), PADDLE_ENFORCE(ctx->HasInput("Y"),
...@@ -37,6 +40,17 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -37,6 +40,17 @@ class ElementwiseOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of elementwise op should not be null."); "Output(Out) of elementwise op should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("X").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("X").front(), ctx->GetInputsVarType("X").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Y").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Y").front(), ctx->GetInputsVarType("Y").front());
auto x_dim = ctx->GetInputDim("X"); auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y"); auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
...@@ -47,9 +61,8 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -47,9 +61,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
framework::ToDataType(ctx.Input<Tensor>("X")->type());
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) { if (platform::CanMKLDNNBeUsed(ctx)) {
...@@ -64,12 +77,12 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -64,12 +77,12 @@ class ElementwiseOp : public framework::OperatorWithKernel {
class ElementwiseOpInferVarType : public framework::VarTypeInference { class ElementwiseOpInferVarType : public framework::VarTypeInference {
public: public:
void operator()(const framework::OpDesc& op_desc, void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc* block) const override { framework::BlockDesc *block) const override {
auto x_name = op_desc.Input("X")[0]; auto x_name = op_desc.Input("X")[0];
auto out_name = op_desc.Output("Out")[0]; auto out_name = op_desc.Output("Out")[0];
auto& x = block->FindRecursiveOrCreateVar(x_name); auto &x = block->FindRecursiveOrCreateVar(x_name);
auto& out = block->FindRecursiveOrCreateVar(out_name); auto &out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(x.GetType()); out.SetType(x.GetType());
out.SetDataType(x.GetDataType()); out.SetDataType(x.GetDataType());
} }
...@@ -131,6 +144,7 @@ But the output only shares the LoD information with the input $X$. ...@@ -131,6 +144,7 @@ But the output only shares the LoD information with the input $X$.
protected: protected:
virtual std::string GetName() const = 0; virtual std::string GetName() const = 0;
virtual std::string GetEquation() const = 0; virtual std::string GetEquation() const = 0;
}; };
...@@ -139,7 +153,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -139,7 +153,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
...@@ -165,7 +179,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -165,7 +179,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::ToDataType( auto input_data_type = framework::ToDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type()); ctx.Input<Tensor>(framework::GradVarName("Out"))->type());
...@@ -187,7 +201,7 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad { ...@@ -187,7 +201,7 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
using operators::ElementwiseOpGrad::GetExpectedKernelType; using operators::ElementwiseOpGrad::GetExpectedKernelType;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
...@@ -209,11 +223,11 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad { ...@@ -209,11 +223,11 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
template <typename T> template <typename T>
class ElemwiseGradKernel : public framework::OpKernel<T> { class ElemwiseGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* dx = auto *dx =
context.Output<framework::LoDTensor>(framework::GradVarName("X")); context.Output<framework::LoDTensor>(framework::GradVarName("X"));
if (dx != nullptr) { if (dx != nullptr) {
auto& dout = auto &dout =
*context.Input<framework::LoDTensor>(framework::GradVarName("Out")); *context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
dx->set_lod(dout.lod()); dx->set_lod(dout.lod());
} }
...@@ -234,7 +248,7 @@ class ElemwiseGradKernel : public framework::OpKernel<T> { ...@@ -234,7 +248,7 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
\ \
protected: \ protected: \
std::unique_ptr<paddle::framework::OpDesc> Apply() const override { \ std::unique_ptr<paddle::framework::OpDesc> Apply() const override { \
auto* op = new paddle::framework::OpDesc(); \ auto *op = new paddle::framework::OpDesc(); \
op->SetType(#kernel_type "_grad"); \ op->SetType(#kernel_type "_grad"); \
op->SetInput("Y", Input("Y")); \ op->SetInput("Y", Input("Y")); \
op->SetInput(::paddle::framework::GradVarName("Out"), \ op->SetInput(::paddle::framework::GradVarName("Out"), \
......
...@@ -28,11 +28,10 @@ template <typename DeviceContext, typename T> ...@@ -28,11 +28,10 @@ template <typename DeviceContext, typename T>
class ElementwiseSubKernel : public framework::OpKernel<T> { class ElementwiseSubKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor; auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis, ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class ExtractRowsOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ExtractRowsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ExtractRowsOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("X")[0],
framework::proto::VarType::SELECTED_ROWS,
"The type of input(X) must be SelectedRows.");
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(
"Out", framework::make_ddim(std::vector<int64_t>{in_dims[0], 1}));
}
};
class ExtractRowsOp : public framework::OperatorBase {
public:
ExtractRowsOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &in = scope.FindVar(Input("X"))->Get<framework::SelectedRows>();
auto out = scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
auto &in_rows = in.rows();
auto out_dim = framework::make_ddim(
std::vector<int64_t>{static_cast<int64_t>(in_rows.size()), 1});
auto dst_ptr = out->mutable_data<int64_t>(out_dim, in.place());
if (paddle::platform::is_gpu_place(in.place())) {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto *dev_ctx = pool.Get(in.place());
auto src_ptr = in_rows.Data(in.place());
auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(*dev_ctx)
.stream();
memory::Copy(boost::get<platform::CUDAPlace>(out->place()), dst_ptr,
boost::get<platform::CUDAPlace>(in.place()), src_ptr,
in_rows.size() * sizeof(int64_t), stream);
#else
PADDLE_THROW("Not compiled with CUDA.");
#endif
} else {
memory::Copy(platform::CPUPlace(), dst_ptr, platform::CPUPlace(),
in_rows.data(), in_rows.size() * sizeof(int64_t));
}
}
};
class ExtractRowsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(SelectedRows). The input tensor of extract_rows operator,"
" and its type is SelectedRows.");
AddOutput("Out", "(Tensor). The the rows of input(X).");
AddComment(R"DOC(
ExtractRows Operator.
The function of extract_rows_op is extracting the rows from the input(X)
whose type is SelectedRows.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(extract_rows, ops::ExtractRowsOp, ops::ExtractRowsOpMaker,
ops::ExtractRowsOpInferShape);
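As a rough illustration of the operator registered above, a minimal NumPy sketch of the extract_rows semantics follows; the helper name extract_rows_reference is hypothetical and only mirrors what RunImpl copies out of the SelectedRows input.

import numpy as np

def extract_rows_reference(rows):
    # extract_rows copies the row-index vector of a SelectedRows input
    # into an int64 tensor of shape [len(rows), 1].
    return np.asarray(rows, dtype=np.int64).reshape(len(rows), 1)

# e.g. a SelectedRows holding rows [0, 4, 4, 7] yields [[0], [4], [4], [7]].
print(extract_rows_reference([0, 4, 4, 7]))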
...@@ -64,6 +64,8 @@ struct SelectedRowsSumTo { ...@@ -64,6 +64,8 @@ struct SelectedRowsSumTo {
framework::SelectedRows* input2); framework::SelectedRows* input2);
}; };
// FIXME: The result of SelectedRowsAddToTensor may be non-deterministic,
// because it uses CudaAtomicAdd.
// input2 = input1 + input2 // input2 = input1 + input2
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct SelectedRowsAddToTensor { struct SelectedRowsAddToTensor {
......
...@@ -24,19 +24,13 @@ class ScaleKernel : public framework::OpKernel<T> { ...@@ -24,19 +24,13 @@ class ScaleKernel : public framework::OpKernel<T> {
public: public:
virtual void Compute(const framework::ExecutionContext& ctx) const { virtual void Compute(const framework::ExecutionContext& ctx) const {
auto* in_var = ctx.InputVar("X"); auto* in_var = ctx.InputVar("X");
auto* in = ctx.Input<framework::Tensor>("X"); auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);
auto* out_var = ctx.OutputVar("Out");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(in->place());
PADDLE_ENFORCE_EQ(in->dims(), out->dims(),
"in and out should have the same dim");
auto scale = static_cast<T>(ctx.Attr<float>("scale")); auto scale = static_cast<T>(ctx.Attr<float>("scale"));
auto bias = static_cast<T>(ctx.Attr<float>("bias")); auto bias = static_cast<T>(ctx.Attr<float>("bias"));
auto bias_after_scale = ctx.Attr<bool>("bias_after_scale"); auto bias_after_scale = ctx.Attr<bool>("bias_after_scale");
auto* out_var = ctx.OutputVar("Out");
if (in_var->IsType<framework::SelectedRows>() && in_var != out_var) { if (in_var->IsType<framework::SelectedRows>() && in_var != out_var) {
auto& in_slr = in_var->Get<framework::SelectedRows>(); auto& in_slr = in_var->Get<framework::SelectedRows>();
auto* out_slr = out_var->GetMutable<framework::SelectedRows>(); auto* out_slr = out_var->GetMutable<framework::SelectedRows>();
...@@ -44,6 +38,13 @@ class ScaleKernel : public framework::OpKernel<T> { ...@@ -44,6 +38,13 @@ class ScaleKernel : public framework::OpKernel<T> {
out_slr->set_height(in_slr.height()); out_slr->set_height(in_slr.height());
} }
auto* out =
framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var);
out->mutable_data<T>(in->place());
PADDLE_ENFORCE_EQ(in->dims(), out->dims(),
"in and out should have the same dim");
auto eigen_out = framework::EigenVector<T>::Flatten(*out); auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in); auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device(); auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/space_to_depth_op.h"
#include <string>
#include <vector>
namespace paddle {
namespace operators {
class SpaceToDepthOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SpaceToDepthOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SpaceToDepthOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "input should be a 4D tensor");
auto blocksize = ctx->Attrs().Get<int64_t>("blocksize");
PADDLE_ENFORCE_GT(blocksize, 1, "The blocksize should be greater than 1");
PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be greater than 0");
PADDLE_ENFORCE_GT(x_dims[2], 0, "input height should be greater than 0");
PADDLE_ENFORCE_GT(x_dims[3], 0, "input width should be greater than 0");
PADDLE_ENFORCE_EQ(x_dims[1] % (blocksize * blocksize), 0,
"input channel should be divisible by the square of the "
"SpaceToDepthOp blocksize");
PADDLE_ENFORCE_EQ(x_dims[2] % (blocksize), 0,
"input height should be divisible by the "
"SpaceToDepthOp blocksize");
PADDLE_ENFORCE_EQ(x_dims[3] % (blocksize), 0,
"input width should be divisible by the "
"SpaceToDepthOp blocksize");
VLOG(3) << "SpaceToDepthOp operator x.shape=" << x_dims
<< ", Attribute blocksize=" << blocksize << std::endl;
std::vector<int64_t> output_shape(4, 0); // [B,C,H,W]
output_shape[0] = x_dims[0];
output_shape[1] = x_dims[1] * blocksize * blocksize;
output_shape[2] = x_dims[2] / blocksize;
output_shape[3] = x_dims[3] / blocksize;
auto out_dims = framework::make_ddim(output_shape);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", /*->*/ "Out");
}
}
};
class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor). The input should be a 4D tensor B * C * W * H of "
"SpaceToDepthOp "
"operator.");
AddOutput("Out",
"(Tensor), The output should be a 4D tensor B * C2 * W2 * H2 of "
"SpaceToDepthOp operator.");
AddAttr<int64_t>(
"blocksize",
"(int64_t, default 2) blocksize used to do change Space To Depth.")
.SetDefault(2)
.GreaterThan(1);
AddComment(R"DOC(
reorg operator used in Yolo v2.
The output shape is computed from the input shape as: C2 = C1 * blocksize * blocksize, H2 = H1 / blocksize, W2 = W1 / blocksize.
Reshape Input(X) into the shape according to Attr(blocksize). The
data in Input(X) are unchanged.
Examples:
1. Given a 4-D tensor Input(X) with a shape [128, 2048, 26, 26], and the blocksize is 2, the reorg operator will transform Input(X)
into a 4-D tensor with shape [128, 8192, 13, 13], leaving Input(X)'s data unchanged.
)DOC");
}
};
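For reference, a small Python sketch of the shape inference above; infer_space_to_depth_shape is a hypothetical helper that only restates the divisibility checks and the [B, C, H, W] -> [B, C*blocksize*blocksize, H/blocksize, W/blocksize] rule from InferShape.

def infer_space_to_depth_shape(x_shape, blocksize):
    # Mirrors SpaceToDepthOp::InferShape: check divisibility, then move
    # blocksize x blocksize spatial blocks into the channel dimension.
    b, c, h, w = x_shape
    assert blocksize > 1
    assert c % (blocksize * blocksize) == 0
    assert h % blocksize == 0 and w % blocksize == 0
    return [b, c * blocksize * blocksize, h // blocksize, w // blocksize]

# [128, 2048, 26, 26] with blocksize 2 -> [128, 8192, 13, 13]
print(infer_space_to_depth_shape([128, 2048, 26, 26], 2))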
class SpaceToDepthGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(space_to_depth, ops::SpaceToDepthOp, ops::SpaceToDepthOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(space_to_depth_grad, ops::SpaceToDepthGradOp);
REGISTER_OP_CPU_KERNEL(
space_to_depth,
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, float>,
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, double>,
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
space_to_depth_grad,
ops::SpaceToDepthGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SpaceToDepthGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::SpaceToDepthGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/space_to_depth_op.h"
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
space_to_depth,
ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, float>,
ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, double>,
ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
space_to_depth_grad,
ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
#define PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
template <typename T>
class space_to_depth_compute {
public:
HOSTDEVICE space_to_depth_compute(const T *x, int64_t w, int64_t h, int64_t c,
int64_t batch, int64_t blocksize,
int64_t forward, T *out)
: x_(x),
w_(w),
h_(h),
c_(c),
batch_(batch),
blocksize_(blocksize),
forward_(forward),
out_(out) {}
HOSTDEVICE void operator()(int64_t in_index) {
int64_t out_c = c_ / (blocksize_ * blocksize_);
// calculate each dim position with index of tensor
int64_t b = in_index / (c_ * h_ * w_);
int64_t k = (in_index % (c_ * h_ * w_)) / (h_ * w_);
int64_t j = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) / w_;
int64_t i = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) % w_;
int64_t c2 = k % out_c;
int64_t offset = k / out_c;
int64_t w2 = i * blocksize_ + offset % blocksize_;
int64_t h2 = j * blocksize_ + offset / blocksize_;
int64_t out_index =
w2 + w_ * blocksize_ * (h2 + h_ * blocksize_ * (c2 + out_c * b));
if (forward_)
out_[out_index] = x_[in_index];
else
out_[in_index] = x_[out_index];
}
private:
const T *x_;
int64_t w_, h_, c_, batch_, blocksize_, forward_;
T *out_;
};
template <typename DeviceContext, typename T>
class SpaceToDepthKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *out = context.Output<framework::LoDTensor>("Out");
auto *x = context.Input<framework::LoDTensor>("X");
auto blocksize = context.Attr<int64_t>("blocksize");
auto in_dims = x->dims();
out->mutable_data(context.GetPlace(), x->type());
auto out_dims = out->dims();
auto B = in_dims[0];
auto C = in_dims[1];
auto H = in_dims[2];
auto W = in_dims[3];
platform::ForRange<DeviceContext> for_range(
context.template device_context<DeviceContext>(),
static_cast<size_t>(x->numel()));
auto *x_data = x->data<T>();
auto *out_data = out->data<T>();
paddle::operators::space_to_depth_compute<T> computer(
x_data, W, H, C, B, blocksize, 1, out_data);
for_range(computer);
out->Resize(out_dims);
}
};
template <typename DeviceContext, typename T>
class SpaceToDepthGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *d_out =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *d_x =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto blocksize = context.Attr<int64_t>("blocksize");
auto in_dims = d_x->dims();
d_x->mutable_data(context.GetPlace(), d_out->type());
auto B = in_dims[0];
auto C = in_dims[1];
auto H = in_dims[2];
auto W = in_dims[3];
platform::ForRange<DeviceContext> for_range(
context.template device_context<DeviceContext>(),
static_cast<size_t>(d_x->numel()));
auto *dx_data = d_x->data<T>();
auto *dout_data = d_out->data<T>();
paddle::operators::space_to_depth_compute<T> computer(
dout_data, W, H, C, B, blocksize, 0, dx_data);
for_range(computer);
d_x->Resize(in_dims);
}
};
} // namespace operators
} // namespace paddle
#endif  // PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
...@@ -64,8 +64,7 @@ class SplitIdsOp : public framework::OperatorWithKernel { ...@@ -64,8 +64,7 @@ class SplitIdsOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( framework::GetDataTypeOfVar(ctx.MultiInputVar("Ids").front()),
ctx.MultiInput<framework::Tensor>("Ids").front()->type()),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -113,6 +113,10 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -113,6 +113,10 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
row_width * sizeof(T)); row_width * sizeof(T));
} }
} }
} else {
PADDLE_THROW(
"% should be LoDTensor or SelectedRows, but the received type is %s",
ctx.Inputs("Ids")[0], ids_var->Type().name());
} }
} }
}; };
......
...@@ -85,8 +85,8 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -85,8 +85,8 @@ class SumOp : public framework::OperatorWithKernel {
for (size_t idx = 0; idx < x_vars.size(); ++idx) { for (size_t idx = 0; idx < x_vars.size(); ++idx) {
PADDLE_ENFORCE(x_vars[idx] != nullptr, PADDLE_ENFORCE(x_vars[idx] != nullptr,
"Input var[%s] should not be nullptr", x_vars_name[idx]); "Input var[%s] should not be nullptr", x_vars_name[idx]);
// FIXME(zcd): The input x_var may be SelectedRows or LoDTensor. auto tensor =
auto tensor = framework::GetTensorFromVar(*x_vars[idx]); framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_vars[idx]);
if (tensor->numel() == 0) { if (tensor->numel() == 0) {
continue; continue;
} }
......
...@@ -27,6 +27,7 @@ void BindConstValue(pybind11::module* m) { ...@@ -27,6 +27,7 @@ void BindConstValue(pybind11::module* m) {
m->def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; }); m->def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; });
m->def("kControlDepVarName", m->def("kControlDepVarName",
[] { return framework::ir::Node::kControlDepVarName; }); [] { return framework::ir::Node::kControlDepVarName; });
m->def("kNewGradSuffix", [] { return framework::kNewGradSuffix; });
auto op_proto_and_checker_maker = auto op_proto_and_checker_maker =
m->def_submodule("op_proto_and_checker_maker"); m->def_submodule("op_proto_and_checker_maker");
......
...@@ -367,7 +367,12 @@ function run_test() { ...@@ -367,7 +367,12 @@ function run_test() {
Running unit tests ... Running unit tests ...
======================================== ========================================
EOF EOF
ctest --output-on-failure if [ ${TESTING_DEBUG_MODE:-OFF} == "ON" ] ; then
ctest -V
else
ctest --output-on-failure
fi
# make install should also be tested when running unit tests # make install should also be tested when running unit tests
make install -j `nproc` make install -j `nproc`
pip install ${INSTALL_PREFIX:-/paddle/build}/opt/paddle/share/wheels/*.whl pip install ${INSTALL_PREFIX:-/paddle/build}/opt/paddle/share/wheels/*.whl
......
...@@ -30,7 +30,8 @@ from ..unique_name import generate as unique_name ...@@ -30,7 +30,8 @@ from ..unique_name import generate as unique_name
__all__ = [ __all__ = [
'data', 'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer', 'data', 'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer',
'random_data_generator', 'py_reader', 'Preprocessor', 'load' 'random_data_generator', 'py_reader', 'create_py_reader_by_data',
'Preprocessor', 'load'
] ]
...@@ -475,6 +476,159 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True): ...@@ -475,6 +476,159 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
def _py_reader(capacity,
shapes,
dtypes,
lod_levels=None,
name=None,
use_double_buffer=True,
feed_list=None):
if feed_list is not None:
if not isinstance(feed_list, list):
raise TypeError("feed_list should be a list of Variable"
" instead of " + str(type(feed_list)))
lod_levels = []
dtypes = []
shape_concat = []
ranks = []
shapes = []
for feed_data in feed_list:
dtypes.append(feed_data.dtype)
shape_concat.extend(feed_data.shape)
ranks.append(len(feed_data.shape))
shapes.append(feed_data.shape)
lod_levels.append(feed_data.lod_level)
else:
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
ranks = []
for shape in shapes:
shape_concat.extend(shape)
ranks.append(len(shape))
if lod_levels is None:
lod_levels = [0] * len(shapes)
if name is None:
queue_name = unique_name('lod_tensor_blocking_queue')
reader_name = unique_name('create_py_reader')
double_buffer_name = unique_name('double_buffer')
else:
queue_name = "_".join([name, "queue"])
reader_name = "_".join([name, "reader"])
double_buffer_name = "_".join([name, "double_buffer"])
var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes)
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=reader_name)
startup_blk.append_op(
type='create_py_reader',
inputs={'blocking_queue': [queue_name]},
outputs={'Out': [startup_var]},
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks
})
startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var)
reader = monkey_patch_reader_methods(main_prog_var)
if use_double_buffer:
double_buffer_reader = double_buffer(reader, name=double_buffer_name)
# we return a double buffer reader. However, the reset method comes from
# py_reader.
double_buffer_reader.reset = reader.reset
reader = double_buffer_reader
# monkey patch py_reader special methods
reader.queue = feed_queue
current_reset_method = reader.reset
reader.thread = None
reader.tensor_provider = None
reader.exited = False
def start_provide_thread(func):
def __provider_thread__():
for tensors in func():
array = core.LoDTensorArray()
for item in tensors:
if not isinstance(item, core.LoDTensor):
tmp = core.LoDTensor()
tmp.set(item, core.CPUPlace())
item = tmp
array.append(item)
if reader.exited:
break
feed_queue.push(array)
if reader.exited:
break
feed_queue.close()
reader.thread = threading.Thread(target=__provider_thread__)
reader.thread.daemon = True
reader.thread.start()
def __set_tensor_provider__(func):
reader.tensor_provider = func
def __set_paddle_reader__(paddle_reader):
with program_guard(Program(), Program()):
actual_feed_list = feed_list
if actual_feed_list is None:
actual_feed_list = []
counter = 0
for dtype, shape, lod_level in zip(dtypes, shapes, lod_levels):
name = str(counter)
actual_feed_list.append(
data(
name=name,
dtype=dtype,
shape=shape,
lod_level=lod_level))
counter += 1
data_names = [feed_data.name for feed_data in actual_feed_list]
feeder = DataFeeder(
feed_list=actual_feed_list, place=core.CPUPlace())
paddle_reader = feeder.decorate_reader(
paddle_reader, multi_devices=False)
def __tensor_provider__():
for slots in paddle_reader():
yield [slots[data_name] for data_name in data_names]
__set_tensor_provider__(__tensor_provider__)
def __reset__():
current_reset_method()
if reader.thread is not None and reader.tensor_provider is not None:
reader.exited = True
reader.thread.join()
reader.exited = False
def __start__():
start_provide_thread(reader.tensor_provider)
reader.reset = __reset__
reader.decorate_tensor_provider = __set_tensor_provider__
reader.decorate_paddle_reader = __set_paddle_reader__
reader.start = __start__
return reader
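A hedged sketch of the tensor-provider protocol implied by _py_reader above: a generator yields lists of numpy arrays (or LoDTensors) matching the declared shapes and dtypes, and __provider_thread__ pushes them into the blocking queue. The generator name my_generator and the concrete shapes are illustrative only.

import numpy as np
import paddle.fluid as fluid

reader = fluid.layers.py_reader(
    capacity=8,
    shapes=[[-1, 784], [-1, 1]],
    dtypes=['float32', 'int64'])

def my_generator():
    # Each yielded list must match the declared shapes/dtypes; plain numpy
    # arrays are wrapped into LoDTensors by the provider thread.
    for _ in range(4):
        yield [np.random.rand(32, 784).astype('float32'),
               np.random.randint(0, 10, (32, 1)).astype('int64')]

reader.decorate_tensor_provider(my_generator)
reader.start()  # launches the provider thread that feeds the queue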
def py_reader(capacity, def py_reader(capacity,
shapes, shapes,
dtypes, dtypes,
...@@ -599,128 +753,72 @@ def py_reader(capacity, ...@@ -599,128 +753,72 @@ def py_reader(capacity,
>>> except fluid.core.EOFException: >>> except fluid.core.EOFException:
>>> test_reader.reset() >>> test_reader.reset()
""" """
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] return _py_reader(
shape_concat = [] capacity=capacity,
ranks = [] shapes=shapes,
dtypes=dtypes,
for shape in shapes: lod_levels=lod_levels,
shape_concat.extend(shape) name=name,
ranks.append(len(shape)) use_double_buffer=use_double_buffer)
if lod_levels is None:
lod_levels = [0] * len(shapes)
if name is None:
queue_name = unique_name('lod_tensor_blocking_queue')
reader_name = unique_name('create_py_reader')
double_buffer_name = unique_name('double_buffer')
else:
queue_name = "_".join([name, "queue"])
reader_name = "_".join([name, "reader"])
double_buffer_name = "_".join([name, "double_buffer"])
var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes)
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=reader_name)
startup_blk.append_op(
type='create_py_reader',
inputs={'blocking_queue': [queue_name]},
outputs={'Out': [startup_var]},
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks
})
startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var)
reader = monkey_patch_reader_methods(main_prog_var)
if use_double_buffer:
double_buffer_reader = double_buffer(reader, name=double_buffer_name)
# we return a double buffer reader. However, the reset method comes from
# py_reader.
double_buffer_reader.reset = reader.reset
reader = double_buffer_reader
# monkey patch py_reader special methods
reader.queue = feed_queue
current_reset_method = reader.reset
reader.thread = None
reader.tensor_provider = None
reader.exited = False
def start_provide_thread(func):
def __provider_thread__():
for tensors in func():
array = core.LoDTensorArray()
for item in tensors:
if not isinstance(item, core.LoDTensor):
tmp = core.LoDTensor()
tmp.set(item, core.CPUPlace())
item = tmp
array.append(item)
if reader.exited:
break
feed_queue.push(array)
if reader.exited:
break
feed_queue.close()
reader.thread = threading.Thread(target=__provider_thread__)
reader.thread.daemon = True
reader.thread.start()
def __set_tensor_provider__(func): def create_py_reader_by_data(capacity,
reader.tensor_provider = func feed_list,
name=None,
use_double_buffer=True):
"""
Create a Python reader for data feeding in Python
def __set_paddle_reader__(paddle_reader): This layer returns a Reader Variable.
with program_guard(Program(), Program()):
feed_list = []
counter = 0
for dtype, shape, lod_level in zip(dtypes, shapes, lod_levels):
name = str(counter)
feed_list.append(
data(
name=name,
dtype=dtype,
shape=shape,
lod_level=lod_level))
counter += 1
feeder = DataFeeder(feed_list=feed_list, place=core.CPUPlace())
paddle_reader = feeder.decorate_reader(
paddle_reader, multi_devices=False)
def __tensor_provider__(): Works much like py_reader except that its input is feed_list
for slots in paddle_reader(): instead of shapes, dtypes and lod_levels
yield [slots[str(idx)] for idx in six.moves.xrange(counter)]
__set_tensor_provider__(__tensor_provider__) Args:
capacity(int): The buffer capacity maintained by :code:`py_reader`.
feed_list(list(Variable)): The data feed list.
name(basestring): The prefix of the Python queue name and Reader name. If
None, the name will be generated automatically.
use_double_buffer(bool): Whether to use double buffer or not.
def __reset__(): Returns:
current_reset_method() Variable: A Reader from which we can get feeding data.
if reader.thread is not None and reader.tensor_provider is not None:
reader.exited = True
reader.thread.join()
reader.exited = False
def __start__(): Examples:
start_provide_thread(reader.tensor_provider)
reader.reset = __reset__ 1. The basic usage of :code:`create_py_reader_by_data` is as follows:
reader.decorate_tensor_provider = __set_tensor_provider__
reader.decorate_paddle_reader = __set_paddle_reader__
reader.start = __start__
return reader >>> import paddle.fluid as fluid
>>> import paddle.dataset.mnist as mnist
>>>
>>> image = fluid.layers.data(name='image', shape=[3,224,224], dtype='float32')
>>> label = fluid.layers.data(name='label', shape=[1], dtype='int64')
>>> reader = fluid.layers.create_py_reader_by_data(capacity=64, feed_list=[image, label])
>>> reader.decorate_paddle_reader(
>>> paddle.reader.shuffle(paddle.batch(mnist.train())
>>>
>>> img, label = fluid.layers.read_file(reader)
>>> loss = network(img, label) # some network definition
>>>
>>> fluid.Executor(fluid.CUDAPlace(0)).run(fluid.default_startup_program())
>>>
>>> exe = fluid.ParallelExecutor(use_cuda=True, loss_name=loss.name)
>>> for epoch_id in range(10):
>>> reader.start()
>>> try:
>>> while True:
>>> exe.run(fetch_list=[loss.name])
>>> except fluid.core.EOFException:
>>> reader.reset()
"""
return _py_reader(
capacity=capacity,
shapes=None,
dtypes=None,
lod_levels=None,
name=name,
use_double_buffer=use_double_buffer,
feed_list=feed_list)
def open_files(filenames, def open_files(filenames,
......
...@@ -154,6 +154,7 @@ __all__ = [ ...@@ -154,6 +154,7 @@ __all__ = [
'mul', 'mul',
'sigmoid_cross_entropy_with_logits', 'sigmoid_cross_entropy_with_logits',
'maxout', 'maxout',
'space_to_depth',
'affine_grid', 'affine_grid',
'sequence_reverse', 'sequence_reverse',
'affine_channel', 'affine_channel',
...@@ -3060,7 +3061,7 @@ def sequence_pad(x, pad_value, maxlen=None, name=None): ...@@ -3060,7 +3061,7 @@ def sequence_pad(x, pad_value, maxlen=None, name=None):
x = fluid.layers.data(name='y', shape=[10, 5], x = fluid.layers.data(name='y', shape=[10, 5],
dtype='float32', lod_level=1) dtype='float32', lod_level=1)
pad_value = fluid.layers.assign( pad_value = fluid.layers.assign(
input=numpy.array([0], dtype=numpy.float32)) input=numpy.array([0.0], dtype=numpy.float32))
out = fluid.layers.sequence_pad(x=x, pad_value=pad_value) out = fluid.layers.sequence_pad(x=x, pad_value=pad_value)
""" """
...@@ -7674,6 +7675,66 @@ def maxout(x, groups, name=None): ...@@ -7674,6 +7675,66 @@ def maxout(x, groups, name=None):
return out return out
def space_to_depth(x, blocksize, name=None):
"""
Given a blocksize, space_to_depth rearranges the input LoDTensor with layout [batch, channel, height, width].
This op rearranges blocks of spatial data into depth. More specifically, it outputs a copy of the
input LoDTensor where values from the height and width dimensions are moved to the channel dimension.
The attr blocksize indicates the input block size.
space_to_depth reorganizes the elements of an input with shape [batch, channel, height, width] according
to blocksize to construct an output with shape [batch, channel * blocksize * blocksize, height/blocksize, width/blocksize]:
This operation is useful for resizing the activations between convolutions
(but keeping all data)
- Non-overlapping blocks of size blocksize x blocksize are rearranged into depth at each location.
- The depth of the output tensor is blocksize * blocksize * input channel
- The Y, X coordinates within each block of the input become the high order component of the output channel index
- channel should be divisible by the square of blocksize
- height and width should be divisible by blocksize
Args:
x(Variable): The input LoDTensor.
blocksize(int): The blocksize used to select the element on each feature map; it must be an integer greater than 1.
Returns:
Variable: The output LoDTensor.
Raises:
ValueError: If blocksize is not a python int.
Examples:
.. code-block:: python
data = fluid.layers.data(
name='data', shape=[1, 4, 2, 2], dtype='float32')
space_to_depthed = fluid.layers.space_to_depth(
x=data, blocksize=2)
"""
helper = LayerHelper("space_to_depth", **locals())
if not (isinstance(blocksize, int)):
raise ValueError("blocksize must be a python Int")
if name is None:
out = helper.create_variable_for_type_inference(
dtype=x.dtype) #fix create
else:
out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="space_to_depth",
inputs={"X": x},
attrs={"blocksize": blocksize},
outputs={"Out": out})
return out
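A short usage sketch of the layer defined above, assuming a standard executor setup; the feed values are illustrative and only the output shape is checked.

import numpy as np
import paddle.fluid as fluid

data = fluid.layers.data(
    name='data', shape=[1, 4, 2, 2], append_batch_size=False, dtype='float32')
out = fluid.layers.space_to_depth(x=data, blocksize=2)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
result, = exe.run(feed={'data': np.random.rand(1, 4, 2, 2).astype('float32')},
                  fetch_list=[out])
# With blocksize 2, an input of shape [1, 4, 2, 2] becomes [1, 16, 1, 1].
print(result.shape)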
@templatedoc() @templatedoc()
def sequence_reverse(x, name=None): def sequence_reverse(x, name=None):
""" """
......
...@@ -108,6 +108,8 @@ class OpDescCreationMethod(object): ...@@ -108,6 +108,8 @@ class OpDescCreationMethod(object):
new_attr.i = user_defined_attr new_attr.i = user_defined_attr
elif attr.type == framework_pb2.FLOAT: elif attr.type == framework_pb2.FLOAT:
new_attr.f = user_defined_attr new_attr.f = user_defined_attr
elif attr.type == framework_pb2.LONG:
new_attr.l = user_defined_attr
elif attr.type == framework_pb2.STRING: elif attr.type == framework_pb2.STRING:
new_attr.s = user_defined_attr new_attr.s = user_defined_attr
elif attr.type == framework_pb2.BOOLEAN: elif attr.type == framework_pb2.BOOLEAN:
......
...@@ -61,14 +61,25 @@ def append_regularization_ops(parameters_and_grads, regularization=None): ...@@ -61,14 +61,25 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
params_and_grads.append((param, grad)) params_and_grads.append((param, grad))
continue continue
assert grad.shape == regularization_term.shape new_grad = grad
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
# FIXME(zcd): If the grad is SELECTED_ROWS, after regularization,
# the grad's type and name will be changed. But the gradient's name
# is used in ParallelExecutor Reduce mode, so I add a flag for
# the new_grad here.
new_grad = grad.block.create_var(
name=grad.name + core.kNewGradSuffix(),
dtype=param.dtype,
shape=param.shape,
lod_level=param.lod_level,
type=core.VarDesc.VarType.LOD_TENSOR)
grad.block.append_op( grad.block.append_op(
type='elementwise_add', type='sum',
inputs={"X": grad, inputs={"X": [grad, regularization_term]},
"Y": regularization_term}, outputs={"Out": new_grad})
outputs={"Out": grad})
params_and_grads.append((param, grad)) params_and_grads.append((param, new_grad))
return params_and_grads return params_and_grads
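A minimal NumPy sketch of the dense-gradient path above, assuming the decay term produced by the preceding scale op is coeff * param; it only illustrates why the regularized gradient can be formed as a sum of the original gradient and the decay term written into a separate new_grad variable.

import numpy as np

def append_l2_decay(param, grad, coeff):
    # Mirrors the op sequence ["scale", "sum"]: scale the parameter to get
    # the decay term, then sum it with the gradient into a new variable
    # instead of adding in place.
    decay = coeff * param          # scale op
    new_grad = grad + decay        # sum op writes into new_grad
    return new_grad

param = np.ones((3, 4), dtype=np.float32)
grad = np.full((3, 4), 0.5, dtype=np.float32)
print(append_l2_decay(param, grad, coeff=0.01))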
...@@ -142,26 +153,7 @@ class L2DecayRegularizer(WeightDecayRegularizer): ...@@ -142,26 +153,7 @@ class L2DecayRegularizer(WeightDecayRegularizer):
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
decay = block.create_var( decay = block.create_var(
dtype="float32", shape=param.shape, lod_level=param.lod_level) dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
idx = block.create_var(
dtype="int64",
shape=param.shape,
type=core.VarDesc.VarType.LOD_TENSOR)
decay = block.create_var(
dtype="float32",
shape=param.shape,
type=core.VarDesc.VarType.LOD_TENSOR)
block.append_op(
type='extract_rows', inputs={'X': grad}, outputs={'Out': idx})
block.append_op(
type='lookup_table',
inputs={'W': param,
'Ids': idx},
outputs={'Out': decay},
attrs={'is_sparse': True})
param = decay
# Append Op to calculate decay # Append Op to calculate decay
block.append_op( block.append_op(
...@@ -218,27 +210,9 @@ class L1DecayRegularizer(WeightDecayRegularizer): ...@@ -218,27 +210,9 @@ class L1DecayRegularizer(WeightDecayRegularizer):
""" """
assert isinstance(param, framework.Parameter) assert isinstance(param, framework.Parameter)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
decay = block.create_var( decay = block.create_var(
dtype="float32", shape=param.shape, lod_level=param.lod_level) dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
idx = block.create_var(
dtype="int64",
shape=param.shape,
type=core.VarDesc.VarType.LOD_TENSOR)
decay = block.create_var(
dtype="float32",
shape=param.shape,
type=core.VarDesc.VarType.LOD_TENSOR)
block.append_op(
type='extract_rows', inputs={'X': grad}, outputs={'Out': idx})
block.append_op(
type='lookup_table',
inputs={'W': param,
'Ids': idx},
outputs={'Out': decay},
attrs={'is_sparse': True})
param = decay
# Append sign op # Append sign op
block.append_op( block.append_op(
......
...@@ -225,29 +225,29 @@ class TestWithInput1x1Filter1x1(TestConv2dOp): ...@@ -225,29 +225,29 @@ class TestWithInput1x1Filter1x1(TestConv2dOp):
#----------------Conv2dCUDNN---------------- #----------------Conv2dCUDNN----------------
def create_test_cudnn_class(parent, cls_name): def create_test_cudnn_class(parent):
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestCUDNNCase(parent): class TestCUDNNCase(parent):
def init_kernel_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
cls_name = "{0}".format(cls_name) cls_name = "{0}_{1}".format(parent.__name__, "CUDNN")
TestCUDNNCase.__name__ = cls_name TestCUDNNCase.__name__ = cls_name
globals()[cls_name] = TestCUDNNCase globals()[cls_name] = TestCUDNNCase
create_test_cudnn_class(TestConv2dOp, "TestPool2DCUDNNOp") create_test_cudnn_class(TestConv2dOp)
create_test_cudnn_class(TestWithPad, "TestPool2DCUDNNOpCase1") create_test_cudnn_class(TestWithPad)
create_test_cudnn_class(TestWithStride, "TestPool2DCUDNNOpCase2") create_test_cudnn_class(TestWithStride)
create_test_cudnn_class(TestWithGroup, "TestPool2DCUDNNOpCase3") create_test_cudnn_class(TestWithGroup)
create_test_cudnn_class(TestWith1x1, "TestPool2DCUDNNOpCase4") create_test_cudnn_class(TestWith1x1)
create_test_cudnn_class(TestWithInput1x1Filter1x1, "TestPool2DCUDNNOpCase4") create_test_cudnn_class(TestWithInput1x1Filter1x1)
#----------------Conv2dCUDNN---------------- #----------------Conv2dCUDNN----------------
def create_test_cudnn_fp16_class(parent, cls_name, grad_check=True): def create_test_cudnn_fp16_class(parent, grad_check=True):
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestConv2DCUDNNFp16(parent): class TestConv2DCUDNNFp16(parent):
...@@ -279,23 +279,17 @@ def create_test_cudnn_fp16_class(parent, cls_name, grad_check=True): ...@@ -279,23 +279,17 @@ def create_test_cudnn_fp16_class(parent, cls_name, grad_check=True):
max_relative_error=0.02, max_relative_error=0.02,
no_grad_set=set(['Input'])) no_grad_set=set(['Input']))
cls_name = "{0}".format(cls_name) cls_name = "{0}_{1}".format(parent.__name__, "CUDNNFp16")
TestConv2DCUDNNFp16.__name__ = cls_name TestConv2DCUDNNFp16.__name__ = cls_name
globals()[cls_name] = TestConv2DCUDNNFp16 globals()[cls_name] = TestConv2DCUDNNFp16
create_test_cudnn_fp16_class( create_test_cudnn_fp16_class(TestConv2dOp, grad_check=False)
TestConv2dOp, "TestPool2DCUDNNFp16Op", grad_check=False) create_test_cudnn_fp16_class(TestWithPad, grad_check=False)
create_test_cudnn_fp16_class( create_test_cudnn_fp16_class(TestWithStride, grad_check=False)
TestWithPad, "TestPool2DCUDNNFp16OpCase1", grad_check=False) create_test_cudnn_fp16_class(TestWithGroup, grad_check=False)
create_test_cudnn_fp16_class( create_test_cudnn_fp16_class(TestWith1x1, grad_check=False)
TestWithStride, "TestPool2DCUDNNFp16OpCase2", grad_check=False) create_test_cudnn_fp16_class(TestWithInput1x1Filter1x1, grad_check=False)
create_test_cudnn_fp16_class(
TestWithGroup, "TestPool2DCUDNNFp16OpCase3", grad_check=False)
create_test_cudnn_fp16_class(
TestWith1x1, "TestPool2DCUDNNFp16OpCase4", grad_check=False)
create_test_cudnn_fp16_class(
TestWithInput1x1Filter1x1, "TestPool2DCUDNNFp16OpCase4", grad_check=False)
# -------TestDepthwiseConv # -------TestDepthwiseConv
......
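The refactor above follows a common pattern for parameterizing unittest cases; a self-contained sketch of that pattern, with illustrative class names, is:

import unittest

class TestBase(unittest.TestCase):
    def init_kernel_type(self):
        self.use_cudnn = False

    def test_kernel_type(self):
        self.init_kernel_type()
        self.assertIn(self.use_cudnn, (True, False))

def create_test_cudnn_class(parent):
    class TestCUDNNCase(parent):
        def init_kernel_type(self):
            self.use_cudnn = True

    # Deriving the generated name from parent.__name__ keeps the names
    # unique, which the previous hard-coded cls_name argument did not.
    cls_name = "{0}_{1}".format(parent.__name__, "CUDNN")
    TestCUDNNCase.__name__ = cls_name
    globals()[cls_name] = TestCUDNNCase

create_test_cudnn_class(TestBase)

if __name__ == '__main__':
    unittest.main()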
...@@ -98,17 +98,18 @@ class TestDistRunnerBase(object): ...@@ -98,17 +98,18 @@ class TestDistRunnerBase(object):
strategy.allow_op_delay = False strategy.allow_op_delay = False
build_stra = fluid.BuildStrategy() build_stra = fluid.BuildStrategy()
if args.batch_merge_repeat > 1:
pass_builder = build_stra._create_passes_from_strategy()
mypass = pass_builder.insert_pass(
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
mypass.set_int("num_repeats", args.batch_merge_repeat)
if args.use_reduce: if args.use_reduce:
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
else: else:
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
if args.batch_merge_repeat > 1:
pass_builder = build_stra._create_passes_from_strategy()
mypass = pass_builder.insert_pass(
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
mypass.set_int("num_repeats", args.batch_merge_repeat)
exe = fluid.ParallelExecutor( exe = fluid.ParallelExecutor(
args.use_cuda, args.use_cuda,
loss_name=avg_cost.name, loss_name=avg_cost.name,
......
...@@ -373,9 +373,8 @@ class TestL2Decay(TranspilerTest): ...@@ -373,9 +373,8 @@ class TestL2Decay(TranspilerTest):
self.assertEqual(len(pserver.blocks), 3) self.assertEqual(len(pserver.blocks), 3)
self.assertEqual([op.type for op in pserver.blocks[1].ops], self.assertEqual([op.type for op in pserver.blocks[1].ops],
["sum", "scale", "clip", "sgd"]) ["sum", "scale", "clip", "sgd"])
self.assertEqual( self.assertEqual([op.type for op in pserver.blocks[2].ops],
[op.type for op in pserver.blocks[2].ops], ["sum", "scale", "clip", "scale", "sum", "sgd"])
["sum", "scale", "clip", "scale", "elementwise_add", "sgd"])
# TODO(typhoonzero): test clipping and L2Decay ops are removed from trainer # TODO(typhoonzero): test clipping and L2Decay ops are removed from trainer
...@@ -416,12 +415,10 @@ class TestL2DecayWithPiecewise(TranspilerTest): ...@@ -416,12 +415,10 @@ class TestL2DecayWithPiecewise(TranspilerTest):
"logical_and", "conditional_block", "fill_constant", "logical_and", "conditional_block", "fill_constant",
"conditional_block" "conditional_block"
]) ])
self.assertEqual( self.assertEqual([op.type for op in pserver.blocks[7].ops],
[op.type for op in pserver.blocks[7].ops], ["sum", "scale", "scale", "sum", "momentum"])
["sum", "scale", "scale", "elementwise_add", "momentum"]) self.assertEqual([op.type for op in pserver.blocks[8].ops],
self.assertEqual( ["sum", "scale", "scale", "sum", "momentum"])
[op.type for op in pserver.blocks[8].ops],
["sum", "scale", "scale", "elementwise_add", "momentum"])
class TestEmptyPserverOptimizeBlocks(TranspilerTest): class TestEmptyPserverOptimizeBlocks(TranspilerTest):
......
...@@ -117,56 +117,5 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp): ...@@ -117,56 +117,5 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
} }
class TestElementWiseMulSelectedRows(OpTest):
def setUp(self):
self.rows = [0, 1, 2, 3, 4, 5, 6]
self.feature = 12
self.height = 100
self.input_shape = (len(self.rows), self.feature)
def prepare_input(self, scope, place):
self.input = {
"X": np.random.random(self.input_shape).astype("float32"),
"Y": np.random.random(self.input_shape).astype("float32")
}
def init_input(in_name):
x_selected_rows = scope.var(in_name).get_selected_rows()
x_selected_rows.set_height(self.height)
x_selected_rows.set_rows(self.rows)
x_array = self.input[in_name]
x_tensor = x_selected_rows.get_tensor()
x_tensor.set(x_array, place)
init_input("X")
init_input("Y")
def create_out_selected_row(self, scope):
return scope.var('Out').get_selected_rows()
def check_result(self, out_selected_rows):
assert out_selected_rows.height() == self.height
assert out_selected_rows.rows() == self.rows
out_tensor = np.array(out_selected_rows.get_tensor())
assert out_tensor.shape == self.input_shape
def check_with_place(self, place):
scope = core.Scope()
self.prepare_input(scope, place)
out_selected_rows = self.create_out_selected_row(scope)
out_selected_rows.set_height(0)
out_selected_rows.set_rows([])
elementwise_mul = Operator("elementwise_mul", X='X', Y='Y', Out='Out')
elementwise_mul.run(scope, place)
self.check_result(out_selected_rows)
def test_elewisemul_with_selected_rows_input(self):
places = [core.CPUPlace()]
for place in places:
self.check_with_place(place)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest
class TestExtractRows(OpTest):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Variable
feature_len = 12
rows = [0, 4, 4, 7]
np_array = np.ones((len(rows), feature_len)).astype("float32")
in_x = scope.var('X').get_selected_rows()
in_x.set_height(len(rows))
in_x.set_rows(rows)
in_x_tensor = in_x.get_tensor()
in_x_tensor.set(np_array, place)
# create Out Variable
out_tensor = scope.var('Out').get_tensor()
# create and run lookup_table operator
extract_rows_op = Operator("extract_rows", X='X', Out='Out')
extract_rows_op.run(scope, place)
# get result from Out
result_array = np.array(out_tensor)
result_array = [ele[0] for ele in result_array]
assert result_array == rows
def test_extract_rows(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
if __name__ == '__main__':
unittest.main()
...@@ -248,6 +248,17 @@ class TestBook(unittest.TestCase): ...@@ -248,6 +248,17 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(layers.softmax(hid)) self.assertIsNotNone(layers.softmax(hid))
print(str(program)) print(str(program))
def test_space_to_depth(self):
program = Program()
with program_guard(program):
data = layers.data(
name='data',
shape=[32, 9, 6, 6],
append_batch_size=False,
dtype='float32')
self.assertIsNotNone(layers.space_to_depth(data, 3))
print(str(program))
def test_sequence_unsqueeze(self): def test_sequence_unsqueeze(self):
program = Program() program = Program()
with program_guard(program): with program_guard(program):
......
...@@ -53,15 +53,24 @@ def simple_fc_net(in_size, ...@@ -53,15 +53,24 @@ def simple_fc_net(in_size,
hidden_sizes, hidden_sizes,
batch_size, batch_size,
queue_capacity, queue_capacity,
use_double_buffer=False): use_double_buffer=False,
reader = fluid.layers.py_reader( use_feed_list=True):
capacity=queue_capacity, if use_feed_list:
shapes=[[-1, in_size], [-1, 1]], data = fluid.layers.data(name="data", dtype='float32', shape=[in_size])
lod_levels=[0, 0], label = fluid.layers.data(name='label', dtype='int64', shape=[1])
dtypes=['float32', 'int64'], py_reader = fluid.layers.create_py_reader_by_data(
use_double_buffer=False) capacity=queue_capacity,
feed_queue = reader.queue use_double_buffer=False,
reader = fluid.layers.batch(reader, batch_size=batch_size) feed_list=[data, label])
else:
py_reader = fluid.layers.py_reader(
capacity=queue_capacity,
shapes=[[-1, in_size], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'],
use_double_buffer=False)
feed_queue = py_reader.queue
reader = fluid.layers.batch(py_reader, batch_size=batch_size)
if use_double_buffer: if use_double_buffer:
reader = fluid.layers.double_buffer(reader) reader = fluid.layers.double_buffer(reader)
...@@ -83,7 +92,7 @@ def simple_fc_net(in_size, ...@@ -83,7 +92,7 @@ def simple_fc_net(in_size,
optimizer = fluid.optimizer.Adam() optimizer = fluid.optimizer.Adam()
optimizer.minimize(loss) optimizer.minimize(loss)
return in_data, label, loss, optimizer, feed_queue return in_data, label, loss, optimizer, feed_queue, py_reader
class TestPyReaderUsingExecutor(unittest.TestCase): class TestPyReaderUsingExecutor(unittest.TestCase):
...@@ -100,16 +109,22 @@ class TestPyReaderUsingExecutor(unittest.TestCase): ...@@ -100,16 +109,22 @@ class TestPyReaderUsingExecutor(unittest.TestCase):
if core.is_compiled_with_cuda() else [False]): if core.is_compiled_with_cuda() else [False]):
for use_parallel_executor in [False, True]: for use_parallel_executor in [False, True]:
for use_double_buffer in [False, True]: for use_double_buffer in [False, True]:
print('Test Parameters:'), for use_feed_list in [False, True]:
print({ for use_decorate_paddle_reader in [False, True]:
'use_cuda': use_cuda, print('Test Parameters:'),
'use_parallel_executor': use_parallel_executor, print({
'use_double_buffer': use_double_buffer 'use_cuda': use_cuda,
}) 'use_parallel_executor': use_parallel_executor,
self.main(use_cuda, use_parallel_executor, 'use_double_buffer': use_double_buffer,
use_double_buffer) 'use_feed_list': use_feed_list,
'use_decorate_paddle_reader':
def random_reader(self): use_decorate_paddle_reader
})
self.main(use_cuda, use_parallel_executor,
use_double_buffer, use_feed_list,
use_decorate_paddle_reader)
def tensor_reader(self, use_decorate_paddle_reader):
def reader(): def reader():
self.inputs = [] self.inputs = []
cnt = 0 cnt = 0
...@@ -133,34 +148,43 @@ class TestPyReaderUsingExecutor(unittest.TestCase): ...@@ -133,34 +148,43 @@ class TestPyReaderUsingExecutor(unittest.TestCase):
elif not self.use_double_buffer: elif not self.use_double_buffer:
break break
yield tensors if use_decorate_paddle_reader:
yield [(in_data, label)]
else:
yield tensors
cnt += 1 cnt += 1
yield None if not use_decorate_paddle_reader:
yield None
return reader return reader
def main(self, def main(self,
use_cuda=True, use_cuda=True,
use_parallel_executor=False, use_parallel_executor=False,
use_double_buffer=False): use_double_buffer=False,
use_feed_list=False,
use_decorate_paddle_reader=False):
assert not use_cuda or use_cuda and core.is_compiled_with_cuda() assert not use_cuda or use_cuda and core.is_compiled_with_cuda()
self.use_cuda = use_cuda self.use_cuda = use_cuda
self.use_parallel_executor = use_parallel_executor self.use_parallel_executor = use_parallel_executor
self.use_double_buffer = use_double_buffer self.use_double_buffer = use_double_buffer
self.use_feed_list = use_feed_list
self.use_decorate_paddle_reader = use_decorate_paddle_reader
startup_program = fluid.Program() startup_program = fluid.Program()
main_program = fluid.Program() main_program = fluid.Program()
with fluid.program_guard(main_program, startup_program): with fluid.program_guard(main_program, startup_program):
in_data, label, loss, optimizer, feed_queue = simple_fc_net( in_data, label, loss, optimizer, feed_queue, py_reader = simple_fc_net(
in_size=self.in_size, in_size=self.in_size,
class_num=self.class_num, class_num=self.class_num,
hidden_sizes=self.hidden_sizes, hidden_sizes=self.hidden_sizes,
batch_size=self.batch_size, batch_size=self.batch_size,
queue_capacity=self.queue_capacity, queue_capacity=self.queue_capacity,
use_double_buffer=self.use_double_buffer) use_double_buffer=self.use_double_buffer,
use_feed_list=self.use_feed_list)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
...@@ -178,10 +202,14 @@ class TestPyReaderUsingExecutor(unittest.TestCase): ...@@ -178,10 +202,14 @@ class TestPyReaderUsingExecutor(unittest.TestCase):
main_exe = startup_exe main_exe = startup_exe
self.batch_size_times = 1 self.batch_size_times = 1
reader = self.random_reader() reader = self.tensor_reader(use_decorate_paddle_reader)
thread = threading.Thread( if use_decorate_paddle_reader:
target=feed_data, args=(feed_queue, reader)) py_reader.decorate_paddle_reader(reader)
thread.start() py_reader.start()
else:
thread = threading.Thread(
target=feed_data, args=(feed_queue, reader))
thread.start()
self.outputs = [] self.outputs = []
for _ in range(self.iterations): for _ in range(self.iterations):
......
...@@ -55,7 +55,7 @@ class TestL2DecayRegularizer(unittest.TestCase): ...@@ -55,7 +55,7 @@ class TestL2DecayRegularizer(unittest.TestCase):
params_grads = optimizer.append_regularization_ops(params_grads) params_grads = optimizer.append_regularization_ops(params_grads)
self.assertEqual(len(params_grads), 1) self.assertEqual(len(params_grads), 1)
self.assertEqual(len(block.ops), count_ops + 2) self.assertEqual(len(block.ops), count_ops + 2)
self.assertEqual(block.ops[-1].type, 'elementwise_add') self.assertEqual(block.ops[-1].type, 'sum')
self.assertEqual(block.ops[-2].type, 'scale') self.assertEqual(block.ops[-2].type, 'scale')
...@@ -92,7 +92,7 @@ class TestL1DecayRegularizer(unittest.TestCase): ...@@ -92,7 +92,7 @@ class TestL1DecayRegularizer(unittest.TestCase):
params_grads = optimizer.append_regularization_ops(params_grads) params_grads = optimizer.append_regularization_ops(params_grads)
self.assertEqual(len(params_grads), 1) self.assertEqual(len(params_grads), 1)
self.assertEqual(len(block.ops), count_ops + 3) self.assertEqual(len(block.ops), count_ops + 3)
self.assertEqual(block.ops[-1].type, 'elementwise_add') self.assertEqual(block.ops[-1].type, 'sum')
self.assertEqual(block.ops[-2].type, 'scale') self.assertEqual(block.ops[-2].type, 'scale')
self.assertEqual(block.ops[-3].type, 'sign') self.assertEqual(block.ops[-3].type, 'sign')
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from op_test import OpTest
class TestSpaceToDepthOp(OpTest):
@staticmethod
def helper(in_, width, height, channel, batch, blocksize, forward, out_):
channel_out = channel // (blocksize * blocksize)
for b in range(batch):
for k in range(channel):
for j in range(height):
for i in range(width):
in_index = i + width * (j + height * (k + channel * b))
channel2 = k % channel_out
offset = k // channel_out
width2 = i * blocksize + offset % blocksize
height2 = j * blocksize + offset // blocksize
out_index = width2 + width * blocksize * (
height2 + height * blocksize *
(channel2 + channel_out * b))
if forward:
out_[out_index] = in_[in_index]
else:
out_[in_index] = in_[out_index]
def setUp(self):
self.init_data()
self.op_type = "space_to_depth"
self.inputs = {"X": self.x}
self.helper(self.x_1d, self.x.shape[3], self.x.shape[2],
self.x.shape[1], self.x.shape[0], self.blocksize,
self.forward, self.out_1d)
self.out = np.reshape(self.out_1d, self.infered_shape)
self.attrs = {"blocksize": self.blocksize}
self.outputs = {"Out": self.out}
def init_data(self):
self.ori_shape = (32, 12, 6, 6)
self.infered_shape = (32, 48, 3, 3)
self.one_d_len = 32 * 48 * 3 * 3
self.blocksize = 2
self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32')
self.out_1d = np.reshape(self.out, self.one_d_len)
self.forward = 1
def test_check_output(self):
place = fluid.core.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.core.CPUPlace()
self.check_output_with_place(place, 1e-5, None, False)
def test_check_grad(self):
place = fluid.core.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.core.CPUPlace()
self.check_grad_with_place(place, ['X'], 'Out')
class TestSpaceToDepthOpBasic(TestSpaceToDepthOp):
def init_data(self):
self.ori_shape = (32, 8, 6, 6)
self.infered_shape = (32, 32, 3, 3)
self.one_d_len = 32 * 32 * 3 * 3
self.blocksize = 2
self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32')
self.out_1d = np.reshape(self.out, self.one_d_len)
self.forward = 1
class TestSpaceToDepthOpDoubleBasic(TestSpaceToDepthOp):
def init_data(self):
self.ori_shape = (32, 8, 6, 6)
self.infered_shape = (32, 32, 3, 3)
self.one_d_len = 32 * 32 * 3 * 3
self.blocksize = 2
self.x = np.random.random(self.ori_shape).astype('float64')
self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float64')
self.out_1d = np.reshape(self.out, self.one_d_len)
self.forward = 1
class TestSpaceToDepthOpWithStride3(TestSpaceToDepthOp):
def init_data(self):
self.ori_shape = (32, 9, 6, 6)
self.infered_shape = (32, 81, 2, 2)
self.one_d_len = 32 * 81 * 2 * 2
self.blocksize = 3
self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32')
self.out_1d = np.reshape(self.out, self.one_d_len)
self.forward = 1
class TestSpaceToDepthOpWithNotSquare(TestSpaceToDepthOp):
def init_data(self):
self.ori_shape = (32, 9, 9, 6)
self.infered_shape = (32, 81, 3, 2)
self.one_d_len = 32 * 81 * 3 * 2
self.blocksize = 3
self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32')
self.out_1d = np.reshape(self.out, self.one_d_len)
self.forward = 1
if __name__ == '__main__':
unittest.main()
...@@ -49,11 +49,14 @@ class TestSumOp(OpTest): ...@@ -49,11 +49,14 @@ class TestSumOp(OpTest):
class TestSelectedRowsSumOp(OpTest): class TestSelectedRowsSumOp(OpTest):
def check_with_place(self, place, inplace): def setUp(self):
self.height = 10 self.height = 10
self.row_numel = 12 self.row_numel = 12
self.rows = [0, 1, 2, 3, 4, 5, 6] self.rows = [0, 1, 2, 3, 4, 5, 6]
self.dtype = np.float32
self.init_kernel_type()
def check_with_place(self, place, inplace):
self.check_input_and_optput(core.Scope(), place, inplace, True, True, self.check_input_and_optput(core.Scope(), place, inplace, True, True,
True) True)
self.check_input_and_optput(core.Scope(), place, inplace, False, True, self.check_input_and_optput(core.Scope(), place, inplace, False, True,
...@@ -64,12 +67,12 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -64,12 +67,12 @@ class TestSelectedRowsSumOp(OpTest):
False) False)
def init_kernel_type(self): def init_kernel_type(self):
self.dtype = np.float32 pass
def _get_array(self, row_num, row_numel): def _get_array(self, rows, row_numel):
array = np.ones((row_num, row_numel)).astype(self.dtype) array = np.ones((len(rows), row_numel)).astype(self.dtype)
for i in range(row_num): for i in range(len(rows)):
array[i] *= i array[i] *= rows[i]
return array return array
def check_input_and_optput(self, def check_input_and_optput(self,
...@@ -105,7 +108,7 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -105,7 +108,7 @@ class TestSelectedRowsSumOp(OpTest):
self.assertTrue( self.assertTrue(
np.array_equal( np.array_equal(
np.array(out.get_tensor()), np.array(out.get_tensor()),
self._get_array(len(self.rows), self.row_numel) * self._get_array(self.rows, self.row_numel) *
has_data_w_num)) has_data_w_num))
else: else:
self.assertEqual(len(out.rows()), 0) self.assertEqual(len(out.rows()), 0)
...@@ -121,7 +124,7 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -121,7 +124,7 @@ class TestSelectedRowsSumOp(OpTest):
w_selected_rows = var.get_selected_rows() w_selected_rows = var.get_selected_rows()
w_selected_rows.set_height(self.height) w_selected_rows.set_height(self.height)
w_selected_rows.set_rows(rows) w_selected_rows.set_rows(rows)
w_array = self._get_array(len(rows), self.row_numel) w_array = self._get_array(self.rows, self.row_numel)
w_tensor = w_selected_rows.get_tensor() w_tensor = w_selected_rows.get_tensor()
w_tensor.set(w_array, place) w_tensor.set(w_array, place)
...@@ -136,36 +139,91 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -136,36 +139,91 @@ class TestSelectedRowsSumOp(OpTest):
self.check_with_place(place, inplace) self.check_with_place(place, inplace)
class TestLoDTensorAndSelectedRowsOp(TestSelectedRowsSumOp):
def setUp(self):
self.height = 10
self.row_numel = 12
self.rows = [0, 1, 2, 2, 4, 5, 6]
def check_with_place(self, place, inplace):
scope = core.Scope()
if inplace:
self.create_lod_tensor(scope, place, "x1")
self.create_selected_rows(scope, place, "x2", True)
out = scope.var("x1").get_tensor()
out_name = "x1"
else:
self.create_selected_rows(scope, place, "x1", True)
self.create_lod_tensor(scope, place, "x2")
out = scope.var("out").get_tensor()
out_name = "out"
# create and run sum operator
sum_op = Operator("sum", X=["x1", "x2"], Out=out_name)
sum_op.run(scope, place)
result = np.ones((1, self.height)).astype(np.int32).tolist()[0]
for ele in self.rows:
result[ele] += 1
out_t = np.array(out)
self.assertEqual(out_t.shape[0], self.height)
self.assertTrue(
np.array_equal(out_t,
self._get_array([i for i in range(
self.height)], self.row_numel) * np.tile(
np.array(result).reshape(self.height, 1),
self.row_numel)))
def create_lod_tensor(self, scope, place, var_name):
var = scope.var(var_name)
w_tensor = var.get_tensor()
w_array = self._get_array([i for i in range(self.height)],
self.row_numel)
w_tensor.set(w_array, place)
return var
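TestLoDTensorAndSelectedRowsOp checks that summing a dense LoDTensor with a SelectedRows input only accumulates into the rows listed in self.rows (row 2 appears twice, so it is added twice). A small standalone NumPy sketch of the expected result that the np.tile(...) assertion above encodes, using the same height/row_numel/rows values (illustration only, not the test's code path):

import numpy as np

height, row_numel = 10, 12
rows = [0, 1, 2, 2, 4, 5, 6]

# Dense input: row i is filled with the value i (what _get_array(range(height), ...) builds).
dense = np.tile(np.arange(height, dtype=np.float32).reshape(height, 1), row_numel)

# SelectedRows input: its i-th stored row holds the value rows[i] and is scattered
# into output row rows[i]; duplicated row ids accumulate.
expected = dense.copy()
for r in rows:
    expected[r] += r

# Expected row i therefore equals i * (1 + number of times i occurs in rows),
# which is exactly the result/np.tile product asserted in the test.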
#----------- test fp16 -----------
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFP16SumOp(TestSumOp): class TestFP16SumOp(TestSumOp):
def init_kernel_type(self): def init_kernel_type(self):
self.dtype = np.float16 self.dtype = np.float16
def test_check_output(self): def test_check_output(self):
if core.is_compiled_with_cuda(): place = core.CUDAPlace(0)
place = core.CUDAPlace(0) if core.is_float16_supported(place):
if core.is_float16_supported(place): self.check_output_with_place(place, atol=2e-2)
self.check_output_with_place(place, atol=2e-2)
# FIXME: Because of the precision fp16, max_relative_error # FIXME: Because of the precision fp16, max_relative_error
# should be 0.15 here. # should be 0.15 here.
def test_check_grad(self): def test_check_grad(self):
if core.is_compiled_with_cuda(): place = core.CUDAPlace(0)
place = core.CUDAPlace(0) if core.is_float16_supported(place):
if core.is_float16_supported(place): self.check_grad(['x0'], 'Out', max_relative_error=0.15)
self.check_grad(['x0'], 'Out', max_relative_error=0.15)
class TestFP16SelectedRowsSumOp(TestSelectedRowsSumOp): def create_test_sum_fp16_class(parent):
def init_kernel_type(self): @unittest.skipIf(not core.is_compiled_with_cuda(),
self.dtype = np.float16 "core is not compiled with CUDA")
class TestSumFp16Case(parent):
def init_kernel_type(self):
self.dtype = np.float16
def test_w_is_selected_rows(self): def test_w_is_selected_rows(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
if core.is_float16_supported(place): if core.is_float16_supported(place):
for inplace in [True, False]: for inplace in [True, False]:
self.check_with_place(place, inplace) self.check_with_place(place, inplace)
cls_name = "{0}_{1}".format(parent.__name__, "SumFp16Test")
TestSumFp16Case.__name__ = cls_name
globals()[cls_name] = TestSumFp16Case
create_test_sum_fp16_class(TestSelectedRowsSumOp)
create_test_sum_fp16_class(TestLoDTensorAndSelectedRowsOp)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
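create_test_sum_fp16_class replaces the old standalone TestFP16SelectedRowsSumOp with a factory that stamps an fp16 variant out of any sum test class and registers it for unittest discovery. A stripped-down sketch of the same pattern outside the Paddle harness (BaseCase and the names below are placeholders, and the CUDA skipIf decorator is omitted):

import unittest
import numpy as np

class BaseCase(unittest.TestCase):
    def init_kernel_type(self):
        self.dtype = np.float32

    def test_dtype_is_set(self):
        self.init_kernel_type()
        self.assertIn(self.dtype, (np.float32, np.float16))

def create_fp16_variant(parent):
    # Subclass the parent, override only the kernel dtype, give the class a unique
    # name, and put it into the module namespace so the test runner collects it.
    class Fp16Case(parent):
        def init_kernel_type(self):
            self.dtype = np.float16

    Fp16Case.__name__ = "{0}_Fp16".format(parent.__name__)
    globals()[Fp16Case.__name__] = Fp16Case

create_fp16_variant(BaseCase)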
...@@ -1706,13 +1706,27 @@ to transpile() call.") ...@@ -1706,13 +1706,27 @@ to transpile() call.")
outputs=outputs, outputs=outputs,
attrs=opt_op.all_attrs()) attrs=opt_op.all_attrs())
def _is_splited_grad_var(self, var, var_dict): def _get_pserver_grad_param_var(self, var, var_dict):
"""
Return pserver side grad/param variable, return None
if the variable is not grad/param, e.g.
a@GRAD -> a@GRAD.block0
a@GRAD -> a@GRAD (a is not splited)
fc_0.w_0 -> fc_0.w_0.block_0
fc_0.w_0 -> fc_0.w_0 (weight is not splited)
_generated_var_123 -> None
"""
grad_block = None grad_block = None
for _, g in six.iteritems(var_dict): for _, g in six.iteritems(var_dict):
if self._orig_varname(g.name) == self._orig_varname(var.name): if self._orig_varname(g.name) == self._orig_varname(var.name):
# skip per trainer vars
if g.name.find(".trainer_") == -1: if g.name.find(".trainer_") == -1:
grad_block = g # only param or grads have splited blocks
break if self._orig_varname(g.name) in self.grad_name_to_param_name or\
self._orig_varname(g.name) in self.param_name_to_grad_name:
grad_block = g
break
return grad_block return grad_block
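The renamed _get_pserver_grad_param_var is essentially a name mapping: strip the split/trainer suffix from each pserver variable, match it against the original name of the incoming var, and only accept variables that really are parameters or gradients. A simplified sketch of that matching over plain strings (the orig_varname helper and the sample names are illustrative, not the transpiler's exact implementation):

def orig_varname(name):
    # "a@GRAD.block0" -> "a@GRAD", "fc_0.w_0.trainer_1" -> "fc_0.w_0" (simplified)
    for marker in (".block", ".trainer_"):
        idx = name.find(marker)
        if idx != -1:
            return name[:idx]
    return name

def get_pserver_grad_param_var(var_name, pserver_var_names, grad_to_param, param_to_grad):
    for g in pserver_var_names:
        if orig_varname(g) != orig_varname(var_name):
            continue
        if ".trainer_" in g:  # skip per-trainer copies
            continue
        # only params and grads have split blocks on the pserver
        if orig_varname(g) in grad_to_param or orig_varname(g) in param_to_grad:
            return g
    return None

# e.g. get_pserver_grad_param_var("a@GRAD", ["a@GRAD.block0"], {"a@GRAD": "a"}, {})
# returns "a@GRAD.block0", while a generated temporary variable maps to None.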
def _clone_lr_op(self, program, block, op): def _clone_lr_op(self, program, block, op):
...@@ -1745,32 +1759,38 @@ to transpile() call.") ...@@ -1745,32 +1759,38 @@ to transpile() call.")
for key, varlist in six.iteritems(inputs): for key, varlist in six.iteritems(inputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for i in range(len(varlist)):
# for ops like clipping and weight decay, get the splited var var = varlist[i]
# for ops like clipping and weight decay, get the splited var (xxx.block0)
# for inputs/outputs # for inputs/outputs
grad_block = self._is_splited_grad_var( grad_block = self._get_pserver_grad_param_var(
var, program.global_block().vars) var, program.global_block().vars)
if grad_block: if grad_block:
inputs[key] = grad_block varlist[i] = grad_block
elif var.name not in program.global_block().vars: elif var.name not in program.global_block().vars:
program.global_block().create_var( tmpvar = program.global_block()._clone_variable(var)
name=var.name, varlist[i] = tmpvar
persistable=var.persistable, else:
dtype=var.dtype, varlist[i] = program.global_block().vars[var.name]
shape=var.shape) inputs[key] = varlist
outputs = self._get_output_map_from_op( outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op) self.origin_program.global_block().vars, opt_op)
for key, varlist in six.iteritems(outputs): for key, varlist in six.iteritems(outputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for i in range(len(varlist)):
grad_block = self._is_splited_grad_var( var = varlist[i]
grad_block = self._get_pserver_grad_param_var(
var, program.global_block().vars) var, program.global_block().vars)
if grad_block: if grad_block:
outputs[key] = grad_block varlist[i] = grad_block
elif var.name not in program.global_block().vars: elif var.name not in program.global_block().vars:
program.global_block()._clone_variable(var) tmpvar = program.global_block()._clone_variable(var)
varlist[i] = tmpvar
else:
varlist[i] = program.global_block().vars[var.name]
outputs[key] = varlist
return optimize_block.append_op( return optimize_block.append_op(
type=opt_op.type, type=opt_op.type,
......
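The two rewritten loops above now substitute each element of varlist in place and write the whole list back into inputs[key] / outputs[key], instead of collapsing the key to a single variable: use the split pserver block when one exists, clone the variable onto the pserver if it is missing there, otherwise reuse the pserver's existing variable. A schematic sketch of that per-element substitution in plain Python (get_split_var and clone_var stand in for the transpiler helpers):

def remap_varlist(varlist, pserver_vars, get_split_var, clone_var):
    # Return varlist with every entry mapped to its pserver-side counterpart.
    if not isinstance(varlist, list):
        varlist = [varlist]
    for i, var in enumerate(varlist):
        split_var = get_split_var(var)            # e.g. fc_0.w_0 -> fc_0.w_0.block_0
        if split_var is not None:
            varlist[i] = split_var
        elif var.name not in pserver_vars:
            varlist[i] = clone_var(var)           # copy trainer-only vars onto the pserver
        else:
            varlist[i] = pserver_vars[var.name]
    return varlist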