提交 4d8fed23 编写于 作者: W wangguibao

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into async_executor

...@@ -28,3 +28,4 @@ third_party/ ...@@ -28,3 +28,4 @@ third_party/
build_* build_*
# clion workspace. # clion workspace.
cmake-build-* cmake-build-*
model_test
...@@ -69,6 +69,7 @@ option(WITH_ANAKIN "Compile with Anakin library" OFF) ...@@ -69,6 +69,7 @@ option(WITH_ANAKIN "Compile with Anakin library" OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF) option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF)
option(WITH_INFERENCE "Compile fluid inference library" ON) option(WITH_INFERENCE "Compile fluid inference library" ON)
option(ON_INFER "Turn on inference optimization." OFF)
option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF) option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF)
option(WITH_SYSTEM_BLAS "Use system blas library" OFF) option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
...@@ -179,6 +180,7 @@ include(external/eigen) # download eigen3 ...@@ -179,6 +180,7 @@ include(external/eigen) # download eigen3
include(external/pybind11) # download pybind11 include(external/pybind11) # download pybind11
include(external/cares) include(external/cares)
include(external/cub) include(external/cub)
include(external/xxhash) # download xxhash
if (NOT WIN32) if (NOT WIN32)
# there is no official support of snappystream, warpctc, nccl, cupti in windows # there is no official support of snappystream, warpctc, nccl, cupti in windows
...@@ -301,3 +303,8 @@ if(WITH_DOC) ...@@ -301,3 +303,8 @@ if(WITH_DOC)
find_python_module(recommonmark REQUIRED) find_python_module(recommonmark REQUIRED)
add_subdirectory(doc) add_subdirectory(doc)
endif() endif()
# When the ON_INFER option is enabled, emit a configure-time warning and add
# the PADDLE_ON_INFERENCE preprocessor definition. add_definitions() is
# directory-scoped, so the macro applies to targets created in this directory
# and its subdirectories.
if (ON_INFER)
message(WARNING "On inference mode, will take place some specific optimization.")
add_definitions(-DPADDLE_ON_INFERENCE)
endif()
...@@ -75,14 +75,14 @@ RUN pip3 install -U wheel && \ ...@@ -75,14 +75,14 @@ RUN pip3 install -U wheel && \
pip3 install -U docopt PyYAML sphinx==1.5.6 && \ pip3 install -U docopt PyYAML sphinx==1.5.6 && \
pip3 install sphinx-rtd-theme==0.1.9 recommonmark && \ pip3 install sphinx-rtd-theme==0.1.9 recommonmark && \
easy_install -U pip && \ easy_install -U pip && \
pip install -U wheel && \ pip install -U pip setuptools wheel && \
pip install -U docopt PyYAML sphinx==1.5.6 && \ pip install -U docopt PyYAML sphinx==1.5.6 && \
pip install sphinx-rtd-theme==0.1.9 recommonmark pip install sphinx-rtd-theme==0.1.9 recommonmark
RUN pip3 install pre-commit 'ipython==5.3.0' && \ RUN pip3 install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
pip3 install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ pip3 install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
pip3 install opencv-python && \ pip3 install opencv-python && \
pip install pre-commit 'ipython==5.3.0' && \ pip install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
pip install opencv-python pip install opencv-python
......
# Download, build and install xxHash (tag v0.6.5) as an external project and
# expose it to the rest of the build as the imported static library `xxhash`.
#
# Reads:   THIRD_PARTY_PATH, WITH_STATIC_LIB, EXTERNAL_PROJECT_LOG_ARGS,
#          WITH_C_API, ANDROID, ANDROID_ABI
# Defines: XXHASH_SOURCE_DIR, XXHASH_INSTALL_DIR, XXHASH_INCLUDE_DIR,
#          XXHASH_LIBRARIES, imported target `xxhash`
# Appends: `xxhash` to external_project_dependencies
include(ExternalProject)

set(XXHASH_SOURCE_DIR ${THIRD_PARTY_PATH}/xxhash)
set(XXHASH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/xxhash)
set(XXHASH_INCLUDE_DIR "${XXHASH_INSTALL_DIR}/include")

# The steps below are expressed without shell operators (`&&`) or shell
# builtins (`export`): those only work when the generator hands the command
# line to a shell, and break as soon as a step is wrapped in a log script
# (LOG_BUILD / LOG_INSTALL), where every token becomes a plain argument.
if(WITH_STATIC_LIB)
  set(BUILD_CMD make lib)
else()
  # Add -fPIC to the upstream Makefile's flags before building, then build.
  # The extra COMMAND keyword chains a second command onto BUILD_COMMAND.
  set(BUILD_CMD
      sed -i "s/-Wstrict-prototypes -Wundef/-Wstrict-prototypes -Wundef -fPIC/g" ${XXHASH_SOURCE_DIR}/src/extern_xxhash/Makefile
      COMMAND make lib)
endif()

ExternalProject_Add(
  extern_xxhash
  ${EXTERNAL_PROJECT_LOG_ARGS}
  GIT_REPOSITORY    "https://github.com/Cyan4973/xxHash"
  GIT_TAG           "v0.6.5"
  PREFIX            ${XXHASH_SOURCE_DIR}
  DOWNLOAD_NAME     "xxhash"
  UPDATE_COMMAND    ""
  CONFIGURE_COMMAND ""
  BUILD_IN_SOURCE   1
  PATCH_COMMAND     ""
  BUILD_COMMAND     ${BUILD_CMD}
  # PREFIX is passed as a make command-line variable instead of being
  # exported through the environment; it selects the install root.
  INSTALL_COMMAND   make install PREFIX=${XXHASH_INSTALL_DIR}
  TEST_COMMAND      ""
)

set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.a")

# Make the installed headers visible (issued once; the duplicate call was
# redundant).
include_directories(${XXHASH_INCLUDE_DIR})

# Imported target wrapping the installed static archive. The dependency on
# extern_xxhash makes sure the archive exists before anything links to it.
add_library(xxhash STATIC IMPORTED GLOBAL)
set_property(TARGET xxhash PROPERTY IMPORTED_LOCATION ${XXHASH_LIBRARIES})
add_dependencies(xxhash extern_xxhash)

list(APPEND external_project_dependencies xxhash)

# Ship headers and the archive alongside the C API when it is enabled.
if(WITH_C_API)
  install(DIRECTORY ${XXHASH_INCLUDE_DIR} DESTINATION third_party/xxhash)
  if(ANDROID)
    install(FILES ${XXHASH_LIBRARIES} DESTINATION third_party/xxhash/lib/${ANDROID_ABI})
  else()
    install(FILES ${XXHASH_LIBRARIES} DESTINATION third_party/xxhash/lib)
  endif()
endif()
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
# make package for paddle fluid shared and static library # make package for paddle fluid shared and static library
function(copy TARGET) function(copy TARGET)
if (NOT ON_INFER)
message(WARNING "Turn on the ON_INFER flag when building inference_lib only.")
endif()
set(options "") set(options "")
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DSTS DEPS) set(multiValueArgs SRCS DSTS DEPS)
...@@ -31,7 +34,7 @@ function(copy TARGET) ...@@ -31,7 +34,7 @@ function(copy TARGET)
foreach(index RANGE ${len}) foreach(index RANGE ${len})
list(GET copy_lib_SRCS ${index} src) list(GET copy_lib_SRCS ${index} src)
list(GET copy_lib_DSTS ${index} dst) list(GET copy_lib_DSTS ${index} dst)
add_custom_command(TARGET ${TARGET} PRE_BUILD add_custom_command(TARGET ${TARGET} PRE_BUILD
COMMAND mkdir -p "${dst}" COMMAND mkdir -p "${dst}"
COMMAND cp -r "${src}" "${dst}" COMMAND cp -r "${src}" "${dst}"
COMMENT "copying ${src} -> ${dst}") COMMENT "copying ${src} -> ${dst}")
...@@ -67,6 +70,13 @@ copy(boost_lib ...@@ -67,6 +70,13 @@ copy(boost_lib
DEPS boost DEPS boost
) )
# Package xxhash into the fluid install tree. copy() pairs each SRCS entry
# with the DSTS entry at the same index: the include directory is copied into
# ${dst_dir} and the library file into ${dst_dir}/lib. DEPS names the xxhash
# target passed to copy() as a dependency.
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/xxhash")
copy(xxhash_lib
SRCS ${XXHASH_INCLUDE_DIR} ${XXHASH_LIBRARIES}
DSTS ${dst_dir} ${dst_dir}/lib
DEPS xxhash
)
if(NOT PROTOBUF_FOUND) if(NOT PROTOBUF_FOUND)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/protobuf") set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/protobuf")
copy(protobuf_lib copy(protobuf_lib
...@@ -186,7 +196,7 @@ copy(cmake_cache ...@@ -186,7 +196,7 @@ copy(cmake_cache
DSTS ${FLUID_INSTALL_DIR}) DSTS ${FLUID_INSTALL_DIR})
# This command generates a complete fluid library for both train and inference # This command generates a complete fluid library for both train and inference
add_custom_target(fluid_lib_dist DEPENDS ${fluid_lib_dist_dep}) add_custom_target(fluid_lib_dist DEPENDS ${fluid_lib_dist_dep})
# Following commands generate a inference-only fluid library # Following commands generate a inference-only fluid library
# third_party, version.txt and CMakeCache.txt are the same position with ${FLUID_INSTALL_DIR} # third_party, version.txt and CMakeCache.txt are the same position with ${FLUID_INSTALL_DIR}
......
...@@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name' ...@@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name'
paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'], varargs=None, keywords=None, defaults=(False, None, None)) paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name', 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer'))
paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None)) paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None))
paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None)) paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None))
...@@ -107,7 +107,7 @@ paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', ...@@ -107,7 +107,7 @@ paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label',
paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1)) paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))
paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, True, None)) paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, False, None))
paddle.fluid.layers.squeeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.squeeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.unsqueeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.unsqueeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lod_reset ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.lod_reset ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None))
...@@ -174,7 +174,9 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None ...@@ -174,7 +174,9 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)) paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
......
...@@ -120,19 +120,25 @@ size_t GraphNum(const Graph &graph) { ...@@ -120,19 +120,25 @@ size_t GraphNum(const Graph &graph) {
std::deque<ir::Node *> q_nodes; std::deque<ir::Node *> q_nodes;
std::vector<std::unordered_set<ir::Node *>> graph_nodes; std::vector<std::unordered_set<ir::Node *>> graph_nodes;
std::unordered_set<ir::Node *> g_nodes; std::unordered_set<ir::Node *> g_nodes;
// q_set records the nodes that are currently in the queue.
std::unordered_set<ir::Node *> q_set;
size_t graph_count = 0; size_t graph_count = 0;
auto traverse_nodes = [&visited_nodes, auto traverse_nodes = [&visited_nodes, &q_nodes,
&q_nodes](const std::vector<ir::Node *> &nodes) { &q_set](const std::vector<ir::Node *> &nodes) {
std::copy_if( for (auto n : nodes) {
nodes.begin(), nodes.end(), std::back_inserter(q_nodes), if (visited_nodes.count(n) == 0 && q_set.count(n) == 0) {
[&visited_nodes](Node *node) { return !visited_nodes.count(node); }); q_nodes.push_back(n);
q_set.insert(n);
}
}
}; };
while (visited_nodes.size() != nodes.size()) { while (visited_nodes.size() != nodes.size()) {
if (!q_nodes.empty()) { if (!q_nodes.empty()) {
auto cur_node = q_nodes.front(); auto cur_node = q_nodes.front();
q_nodes.pop_front(); q_nodes.pop_front();
q_set.erase(cur_node);
visited_nodes.insert(cur_node); visited_nodes.insert(cur_node);
g_nodes.insert(cur_node); g_nodes.insert(cur_node);
traverse_nodes(cur_node->inputs); traverse_nodes(cur_node->inputs);
...@@ -146,6 +152,7 @@ size_t GraphNum(const Graph &graph) { ...@@ -146,6 +152,7 @@ size_t GraphNum(const Graph &graph) {
for (auto &n : nodes) { for (auto &n : nodes) {
if (visited_nodes.count(n) == 0) { if (visited_nodes.count(n) == 0) {
q_nodes.push_back(n); q_nodes.push_back(n);
q_set.insert(n);
break; break;
} }
} }
......
...@@ -18,6 +18,82 @@ limitations under the License. */ ...@@ -18,6 +18,82 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// NOTE The vector<LoDTensor> can't be replaced with the class LoDTensorArray
// directly, because there are many vector<LoDTensor> used across the project,
// and some of them are treated as LoDTensorArray.
#if !defined(PADDLE_ON_INFERENCE)
using LoDTensorArray = std::vector<LoDTensor>; using LoDTensorArray = std::vector<LoDTensor>;
}
#else // !PADDLE_ON_INFERENCE
#pragma message "LoDTensorArray is replaced with the inference one."
/*
 * Inference-mode LoDTensorArray.
 *
 * A vector-like container of LoDTensor that tracks its logical length
 * (size_) separately from its backing storage (array_). Shrinking through
 * resize() or clear() only lowers size_; the slots in array_ are kept, so the
 * tensors stored there are not destroyed and can be reused when the array
 * grows again. This avoids deallocating buffers on resize, fixes the data
 * diff in inference, and is more performance friendly in concurrency
 * scenarios.
 *
 * Invariant: size_ <= array_.size().
 */
class LoDTensorArray {
 public:
  LoDTensorArray() = default;

  using iterator = std::vector<LoDTensor>::iterator;
  using const_iterator = std::vector<LoDTensor>::const_iterator;

  // Iteration covers the logical range [0, size_) only, never the spare slots.
  const_iterator begin() const { return array_.begin(); }
  const_iterator end() const { return array_.begin() + size_; }
  iterator begin() { return array_.begin(); }
  iterator end() { return array_.begin() + size_; }

  // Appends a copy of x, overwriting a retained spare slot when one exists.
  void push_back(const LoDTensor& x) {
    if (size_ < array_.size()) {
      array_[size_] = x;
    } else {
      array_.push_back(x);
    }
    ++size_;
  }

  // Sets the logical size. Storage only ever grows; shrinking keeps the slots.
  void resize(size_t size) {
    if (array_.size() < size) {
      array_.resize(size);
    }
    size_ = size;
  }

  // Appends one element and advances the logical size.
  // If a retained spare slot exists it is re-exposed as-is (the same way
  // resize() re-exposes retained slots); otherwise a default-constructed
  // tensor is appended. Previously this grew array_ without updating size_,
  // leaving size()/end()/empty() stale.
  void emplace_back() {
    if (size_ == array_.size()) {
      array_.emplace_back();
    }
    ++size_;
  }

  // Appends x by move at the logical end and advances the logical size.
  // Previously the element was placed after the spare slots and size_ was
  // not updated.
  void emplace_back(LoDTensor&& x) {
    if (size_ < array_.size()) {
      array_[size_] = std::move(x);
    } else {
      array_.emplace_back(std::move(x));
    }
    ++size_;
  }

  // Last logical element (not the last spare slot). The array must not be
  // empty.
  LoDTensor& back() { return array_[size_ - 1]; }

  // Number of slots currently held in storage, including spare ones.
  size_t space() const { return array_.size(); }

  void reserve(size_t size) {
    // Naive warning to tell the user this array might be too large. The
    // memory and buffers held by this TensorArray are not released during the
    // training and inference phase, so take care not to let it grow too long.
    if (size > 800UL) {
      LOG(WARNING) << "TensorArray has more than 800 items";
    }
    array_.reserve(size);
  }

  bool empty() const { return size_ == 0UL; }

  // Resets the logical size only; stored tensors are kept for reuse.
  void clear() { size_ = 0UL; }

  // Unchecked element access.
  LoDTensor& operator[](size_t id) { return array_[id]; }
  const LoDTensor& operator[](size_t id) const { return array_[id]; }

  // NOTE(review): bounds are checked against the storage size (space()), not
  // the logical size(), so an index in [size(), space()) does not throw.
  LoDTensor& at(size_t id) { return array_.at(id); }
  const LoDTensor& at(size_t id) const { return array_.at(id); }

  size_t size() const { return size_; }

 private:
  size_t size_{0};                // logical number of elements
  std::vector<LoDTensor> array_;  // backing storage; never shrinks
};
#endif // !PADDLE_ON_INFERENCE
} // namespace framework
} // namespace paddle } // namespace paddle
...@@ -542,6 +542,33 @@ class CPUVector : public std::vector<T, std::allocator<T>> { ...@@ -542,6 +542,33 @@ class CPUVector : public std::vector<T, std::allocator<T>> {
this->reserve(this->size() + size_t(end - begin)); this->reserve(this->size() + size_t(end - begin));
this->insert(this->end(), begin, end); this->insert(this->end(), begin, end);
} }
// CPU-only stub: always raises through PADDLE_THROW; `place` is not used.
const T *CUDAData(platform::Place place) const {
PADDLE_THROW(
"Vector::CUDAData() method is not supported in CPU-only version");
}
// CPU-only stub: always raises through PADDLE_THROW; `place` is not used.
T *CUDAMutableData(platform::Place place) {
PADDLE_THROW(
"Vector::CUDAMutableData() method is not supported in CPU-only "
"version");
}
// Returns a read-only pointer to the underlying vector storage.
// Enforces that `place` is a CPU place before returning.
const T *Data(platform::Place place) const {
PADDLE_ENFORCE(
platform::is_cpu_place(place),
"Vector::Data() method is not supported when not in CPUPlace");
return this->data();
}
// Returns a writable pointer to the underlying vector storage.
// Enforces that `place` is a CPU place before returning.
T *MutableData(platform::Place place) {
PADDLE_ENFORCE(
platform::is_cpu_place(place),
"Vector::MutableData() method is not supported when not in CPUPlace");
return this->data();
}
// Returns the address of this vector object as an opaque const pointer.
const void *Handle() const { return static_cast<const void *>(this); }
}; };
template <typename T> template <typename T>
......
...@@ -146,22 +146,5 @@ void NaiveExecutor::CleanFeedFetchOps() { ...@@ -146,22 +146,5 @@ void NaiveExecutor::CleanFeedFetchOps() {
ops_.swap(ops); ops_.swap(ops);
} }
void NaiveExecutor::EnableMKLDNN(const ProgramDesc &program) {
#ifdef PADDLE_WITH_MKLDNN
VLOG(3) << "use_mkldnn=True";
for (size_t block_id = 0; block_id < program.Size(); ++block_id) {
auto *block = const_cast<ProgramDesc &>(program).MutableBlock(block_id);
for (auto *op : block->AllOps()) {
if (op->HasAttr("use_mkldnn")) {
op->SetAttr("use_mkldnn", true);
}
}
}
#else
LOG(WARNING)
<< "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
#endif
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -48,8 +48,6 @@ class NaiveExecutor { ...@@ -48,8 +48,6 @@ class NaiveExecutor {
void CleanFeedFetchOps(); void CleanFeedFetchOps();
void EnableMKLDNN(const ProgramDesc& program);
protected: protected:
void CreateVariables(const ProgramDesc& desc, Scope* scope, int block_id); void CreateVariables(const ProgramDesc& desc, Scope* scope, int block_id);
......
...@@ -28,12 +28,12 @@ enum class OpRole { ...@@ -28,12 +28,12 @@ enum class OpRole {
kBackward = 0x0001, kBackward = 0x0001,
kOptimize = 0x0002, kOptimize = 0x0002,
// RPC role is for send/recv releated op // RPC role is for send/recv releated op
kRPC = 0x0003, kRPC = 0x0004,
// Dist role is for split_byref/split_selected_rows/concat // Dist role is for split_byref/split_selected_rows/concat
// used for distributed training. // used for distributed training.
kDist = 0x0004, kDist = 0x0008,
// Tag all learning rate scheduler operators. // Tag all learning rate scheduler operators.
kLRSched = 0x0005, kLRSched = 0x0016,
kLoss = 0x0100, kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and // The default value of op's role. This should be only used for unittests and
......
...@@ -156,6 +156,12 @@ ParallelExecutor::ParallelExecutor( ...@@ -156,6 +156,12 @@ ParallelExecutor::ParallelExecutor(
params, member_->local_scopes_, member_->use_cuda_); params, member_->local_scopes_, member_->use_cuda_);
#endif #endif
// If loss_var_name is given, the number of graphs must be exactly one.
if (loss_var_name.size()) {
PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
"The number of graph should be only one");
}
if (exec_strategy.type_ == ExecutionStrategy::kDefault) { if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph))); exec_strategy, member_->local_scopes_, places, std::move(graph)));
......
...@@ -78,6 +78,8 @@ class Scope { ...@@ -78,6 +78,8 @@ class Scope {
/// Drop all kids scopes belonged to this scope. /// Drop all kids scopes belonged to this scope.
void DropKids(); void DropKids();
/// Returns a reference to the list of kid scopes held in kids_.
/// NOTE(review): this const method hands out a non-const reference, so
/// callers can modify the list; presumably kids_ is declared mutable --
/// confirm against the member declaration.
std::list<Scope*>& kids() const { return kids_; }
/// Find if a scope exists in the kid scopes /// Find if a scope exists in the kid scopes
bool HasKid(const Scope* scope) const; bool HasKid(const Scope* scope) const;
......
...@@ -25,7 +25,6 @@ DEFINE_int32(dist_threadpool_size, 0, ...@@ -25,7 +25,6 @@ DEFINE_int32(dist_threadpool_size, 0,
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::unique_ptr<ThreadPool> ThreadPool::threadpool_(nullptr); std::unique_ptr<ThreadPool> ThreadPool::threadpool_(nullptr);
std::once_flag ThreadPool::init_flag_; std::once_flag ThreadPool::init_flag_;
...@@ -47,8 +46,7 @@ void ThreadPool::Init() { ...@@ -47,8 +46,7 @@ void ThreadPool::Init() {
} }
} }
ThreadPool::ThreadPool(int num_threads) ThreadPool::ThreadPool(int num_threads) : running_(true) {
: total_threads_(num_threads), idle_threads_(num_threads), running_(true) {
threads_.resize(num_threads); threads_.resize(num_threads);
for (auto& thread : threads_) { for (auto& thread : threads_) {
// TODO(Yancey1989): binding the thread on the specify CPU number // TODO(Yancey1989): binding the thread on the specify CPU number
...@@ -59,6 +57,7 @@ ThreadPool::ThreadPool(int num_threads) ...@@ -59,6 +57,7 @@ ThreadPool::ThreadPool(int num_threads)
ThreadPool::~ThreadPool() { ThreadPool::~ThreadPool() {
{ {
// notify all threads to stop running // notify all threads to stop running
std::lock_guard<std::mutex> l(mutex_);
running_ = false; running_ = false;
scheduled_.notify_all(); scheduled_.notify_all();
} }
...@@ -69,36 +68,24 @@ ThreadPool::~ThreadPool() { ...@@ -69,36 +68,24 @@ ThreadPool::~ThreadPool() {
} }
} }
void ThreadPool::Wait() {
std::unique_lock<std::mutex> lock(mutex_);
completed_.wait(lock, [=] { return Done() == true; });
}
void ThreadPool::TaskLoop() { void ThreadPool::TaskLoop() {
while (running_) { while (true) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
if (!running_) { scheduled_.wait(
break; lock, [this] { return !this->tasks_.empty() || !this->running_; });
if (!running_ || tasks_.empty()) {
return;
} }
// pop a task from the task queue // pop a task from the task queue
auto task = std::move(tasks_.front()); auto task = std::move(tasks_.front());
tasks_.pop(); tasks_.pop();
--idle_threads_;
lock.unlock(); lock.unlock();
// run the task // run the task
task(); task();
{
std::unique_lock<std::mutex> lock(mutex_);
++idle_threads_;
if (Done()) {
completed_.notify_all();
}
}
} }
} }
......
...@@ -57,15 +57,6 @@ class ThreadPool { ...@@ -57,15 +57,6 @@ class ThreadPool {
~ThreadPool(); ~ThreadPool();
// Returns the number of threads created by the constructor.
size_t Threads() const { return total_threads_; }
// Returns the number of currently idle threads.
size_t IdleThreads() {
std::unique_lock<std::mutex> lock(mutex_);
return idle_threads_;
}
// Run pushes a function to the task queue and returns a std::future // Run pushes a function to the task queue and returns a std::future
// object. To wait for the completion of the task, call // object. To wait for the completion of the task, call
// std::future::wait(). // std::future::wait().
...@@ -94,25 +85,13 @@ class ThreadPool { ...@@ -94,25 +85,13 @@ class ThreadPool {
}); });
std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future(); std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
tasks_.push(std::move(task)); tasks_.push(std::move(task));
lock.unlock();
scheduled_.notify_one(); scheduled_.notify_one();
return f; return f;
} }
// Wait until all the tasks are completed.
void Wait();
private: private:
DISABLE_COPY_AND_ASSIGN(ThreadPool); DISABLE_COPY_AND_ASSIGN(ThreadPool);
// If the task queue is empty and avaialbe is equal to the number of
// threads, means that all tasks are completed. Note: this function
// is not thread-safe. Returns true if all tasks are completed.
// Note: don't delete the data member total_threads_ and use
// threads_.size() instead; because you'd need to lock the mutex
// before accessing threads_.
bool Done() { return tasks_.empty() && idle_threads_ == total_threads_; }
// The constructor starts threads to run TaskLoop, which retrieves // The constructor starts threads to run TaskLoop, which retrieves
// and runs tasks from the queue. // and runs tasks from the queue.
void TaskLoop(); void TaskLoop();
...@@ -125,14 +104,11 @@ class ThreadPool { ...@@ -125,14 +104,11 @@ class ThreadPool {
static std::once_flag init_flag_; static std::once_flag init_flag_;
std::vector<std::unique_ptr<std::thread>> threads_; std::vector<std::unique_ptr<std::thread>> threads_;
const size_t total_threads_;
size_t idle_threads_;
std::queue<Task> tasks_; std::queue<Task> tasks_;
std::mutex mutex_; std::mutex mutex_;
bool running_; bool running_;
std::condition_variable scheduled_; std::condition_variable scheduled_;
std::condition_variable completed_;
}; };
class ThreadPoolIO : ThreadPool { class ThreadPoolIO : ThreadPool {
......
...@@ -19,10 +19,11 @@ limitations under the License. */ ...@@ -19,10 +19,11 @@ limitations under the License. */
namespace framework = paddle::framework; namespace framework = paddle::framework;
void do_sum(framework::ThreadPool* pool, std::atomic<int>* sum, int cnt) { void do_sum(std::vector<std::future<void>>* fs, std::mutex* mu,
std::vector<std::future<void>> fs; std::atomic<int>* sum, int cnt) {
for (int i = 0; i < cnt; ++i) { for (int i = 0; i < cnt; ++i) {
fs.push_back(framework::Async([sum]() { sum->fetch_add(1); })); std::lock_guard<std::mutex> l(*mu);
fs->push_back(framework::Async([sum]() { sum->fetch_add(1); }));
} }
} }
...@@ -40,18 +41,21 @@ TEST(ThreadPool, ConcurrentInit) { ...@@ -40,18 +41,21 @@ TEST(ThreadPool, ConcurrentInit) {
} }
TEST(ThreadPool, ConcurrentRun) { TEST(ThreadPool, ConcurrentRun) {
framework::ThreadPool* pool = framework::ThreadPool::GetInstance();
std::atomic<int> sum(0); std::atomic<int> sum(0);
std::vector<std::thread> threads; std::vector<std::thread> threads;
std::vector<std::future<void>> fs;
std::mutex fs_mu;
int n = 50; int n = 50;
// sum = (n * (n + 1)) / 2 // sum = (n * (n + 1)) / 2
for (int i = 1; i <= n; ++i) { for (int i = 1; i <= n; ++i) {
std::thread t(do_sum, pool, &sum, i); std::thread t(do_sum, &fs, &fs_mu, &sum, i);
threads.push_back(std::move(t)); threads.push_back(std::move(t));
} }
for (auto& t : threads) { for (auto& t : threads) {
t.join(); t.join();
} }
pool->Wait(); for (auto& t : fs) {
t.wait();
}
EXPECT_EQ(sum, ((n + 1) * n) / 2); EXPECT_EQ(sum, ((n + 1) * n) / 2);
} }
...@@ -31,7 +31,7 @@ if (WITH_GPU AND TENSORRT_FOUND) ...@@ -31,7 +31,7 @@ if (WITH_GPU AND TENSORRT_FOUND)
endif() endif()
# Create static library # Create static library
cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor) cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array)
if(NOT APPLE) if(NOT APPLE)
# TODO(liuyiqu: Temporarily disable the link flag because it is not support on Mac. # TODO(liuyiqu: Temporarily disable the link flag because it is not support on Mac.
...@@ -41,7 +41,7 @@ endif() ...@@ -41,7 +41,7 @@ endif()
# Create shared library # Create shared library
cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS} cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS}
DEPS ${fluid_modules} paddle_fluid_api) DEPS ${fluid_modules} paddle_fluid_api reset_tensor_array)
set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid) set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid)
if(NOT APPLE) if(NOT APPLE)
......
...@@ -107,6 +107,9 @@ void Analyzer::Run(Argument* argument) { ...@@ -107,6 +107,9 @@ void Analyzer::Run(Argument* argument) {
passes.push_back("mkldnn_placement_pass"); passes.push_back("mkldnn_placement_pass");
} }
#endif #endif
// infer_clean_graph_pass should be the first default pass
// after mkldnn_placement_pass.
passes.push_back("infer_clean_graph_pass");
for (auto& pass : ir_passes_) { for (auto& pass : ir_passes_) {
if (!disabled_ir_passes_.count(pass)) { if (!disabled_ir_passes_.count(pass)) {
passes.push_back(pass); passes.push_back(pass);
......
...@@ -67,7 +67,6 @@ class Analyzer : public OrderedRegistry<PassManager> { ...@@ -67,7 +67,6 @@ class Analyzer : public OrderedRegistry<PassManager> {
// larger fusion. // larger fusion.
const std::vector<std::string> all_ir_passes_{{ const std::vector<std::string> all_ir_passes_{{
// Manual update the passes here. // Manual update the passes here.
"infer_clean_graph_pass", //
"attention_lstm_fuse_pass", // "attention_lstm_fuse_pass", //
"seqconv_eltadd_relu_fuse_pass", // "seqconv_eltadd_relu_fuse_pass", //
"embedding_fc_lstm_fuse_pass", // "embedding_fc_lstm_fuse_pass", //
......
...@@ -18,7 +18,8 @@ if(APPLE) ...@@ -18,7 +18,8 @@ if(APPLE)
endif(APPLE) endif(APPLE)
set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager naive_executor ${GLOB_PASS_LIB}) set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager naive_executor ${GLOB_PASS_LIB}
)
if(WITH_GPU AND TENSORRT_FOUND) if(WITH_GPU AND TENSORRT_FOUND)
set(inference_deps ${inference_deps} paddle_inference_tensorrt_subgraph_engine analysis_predictor) set(inference_deps ${inference_deps} paddle_inference_tensorrt_subgraph_engine analysis_predictor)
...@@ -31,10 +32,17 @@ function(inference_api_test TARGET_NAME) ...@@ -31,10 +32,17 @@ function(inference_api_test TARGET_NAME)
set(multiValueArgs ARGS) set(multiValueArgs ARGS)
cmake_parse_arguments(inference_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(inference_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cc_test(${TARGET_NAME} if (WITH_GPU)
SRCS ${inference_test_SRC} cc_test(${TARGET_NAME}
DEPS "${inference_deps}" SRCS ${inference_test_SRC}
ARGS --dirname=${PYTHON_TESTS_DIR}/book/) DEPS "${inference_deps}"
ARGS --dirname=${PYTHON_TESTS_DIR}/book/ --fraction_of_gpu_memory_to_use=0.15)
else()
cc_test(${TARGET_NAME}
SRCS ${inference_test_SRC}
DEPS "${inference_deps}"
ARGS --dirname=${PYTHON_TESTS_DIR}/book/)
endif()
if(inference_test_ARGS) if(inference_test_ARGS)
set_tests_properties(${TARGET_NAME} set_tests_properties(${TARGET_NAME}
PROPERTIES DEPENDS "${inference_test_ARGS}") PROPERTIES DEPENDS "${inference_test_ARGS}")
...@@ -42,7 +50,8 @@ function(inference_api_test TARGET_NAME) ...@@ -42,7 +50,8 @@ function(inference_api_test TARGET_NAME)
endif(WITH_TESTING) endif(WITH_TESTING)
endfunction(inference_api_test) endfunction(inference_api_test)
cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor scope) cc_library(reset_tensor_array SRCS details/reset_tensor_array.cc DEPS lod_tensor scope)
cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS reset_tensor_array lod_tensor scope)
cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor) cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor)
cc_library(zero_copy_tensor SRCS details/zero_copy_tensor.cc DEPS paddle_inference_api) cc_library(zero_copy_tensor SRCS details/zero_copy_tensor.cc DEPS paddle_inference_api)
cc_library(zero_copy_tensor_dummy SRCS details/zero_copy_tensor_dummy.cc DEPS paddle_inference_api) cc_library(zero_copy_tensor_dummy SRCS details/zero_copy_tensor_dummy.cc DEPS paddle_inference_api)
......
...@@ -82,6 +82,7 @@ bool AnalysisPredictor::Init( ...@@ -82,6 +82,7 @@ bool AnalysisPredictor::Init(
// Get the feed_target_names and fetch_target_names // Get the feed_target_names and fetch_target_names
PrepareFeedFetch(); PrepareFeedFetch();
return true; return true;
} }
...@@ -109,6 +110,10 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs, ...@@ -109,6 +110,10 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
return false; return false;
} }
VLOG(3) << "predict cost: " << timer.toc() << "ms"; VLOG(3) << "predict cost: " << timer.toc() << "ms";
// Fix TensorArray reuse not cleaned bug.
tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get());
tensor_array_batch_cleaner_.ResetTensorArray();
return true; return true;
} }
...@@ -322,6 +327,9 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor( ...@@ -322,6 +327,9 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor(
bool AnalysisPredictor::ZeroCopyRun() { bool AnalysisPredictor::ZeroCopyRun() {
executor_->Run(); executor_->Run();
// Fix TensorArray reuse not cleaned bug.
tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get());
tensor_array_batch_cleaner_.ResetTensorArray();
return true; return true;
} }
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h" #include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/details/reset_tensor_array.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
...@@ -88,6 +89,7 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -88,6 +89,7 @@ class AnalysisPredictor : public PaddlePredictor {
// Memory buffer for feed inputs. The temporary LoDTensor will cause serious // Memory buffer for feed inputs. The temporary LoDTensor will cause serious
// concurrency problems, so cache them. // concurrency problems, so cache them.
std::vector<framework::LoDTensor> feed_tensors_; std::vector<framework::LoDTensor> feed_tensors_;
details::TensorArrayBatchCleaner tensor_array_batch_cleaner_;
}; };
} // namespace paddle } // namespace paddle
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/inference/api/api_impl.h" #include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/details/reset_tensor_array.h"
#include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -157,6 +158,10 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs, ...@@ -157,6 +158,10 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
return false; return false;
} }
VLOG(3) << "predict cost: " << timer.toc() << "ms"; VLOG(3) << "predict cost: " << timer.toc() << "ms";
// Fix TensorArray reuse not cleaned bug.
tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get());
tensor_array_batch_cleaner_.ResetTensorArray();
return true; return true;
} }
......
...@@ -26,11 +26,11 @@ limitations under the License. */ ...@@ -26,11 +26,11 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/inference/api/details/reset_tensor_array.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/init.h"
...@@ -77,6 +77,7 @@ class NativePaddlePredictor : public PaddlePredictor { ...@@ -77,6 +77,7 @@ class NativePaddlePredictor : public PaddlePredictor {
std::vector<framework::OpDesc *> fetchs_; std::vector<framework::OpDesc *> fetchs_;
// Do not use unique_ptr, use parent scope to delete // Do not use unique_ptr, use parent scope to delete
framework::Scope *sub_scope_{nullptr}; framework::Scope *sub_scope_{nullptr};
details::TensorArrayBatchCleaner tensor_array_batch_cleaner_;
}; };
} // namespace paddle } // namespace paddle
...@@ -52,6 +52,7 @@ include_directories("${PADDLE_LIB}") ...@@ -52,6 +52,7 @@ include_directories("${PADDLE_LIB}")
include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") include_directories("${PADDLE_LIB}/third_party/install/protobuf/include")
include_directories("${PADDLE_LIB}/third_party/install/glog/include") include_directories("${PADDLE_LIB}/third_party/install/glog/include")
include_directories("${PADDLE_LIB}/third_party/install/gflags/include") include_directories("${PADDLE_LIB}/third_party/install/gflags/include")
include_directories("${PADDLE_LIB}/third_party/install/xxhash/include")
if (NOT WIN32) if (NOT WIN32)
include_directories("${PADDLE_LIB}/third_party/install/snappy/include") include_directories("${PADDLE_LIB}/third_party/install/snappy/include")
include_directories("${PADDLE_LIB}/third_party/install/snappystream/include") include_directories("${PADDLE_LIB}/third_party/install/snappystream/include")
...@@ -61,8 +62,8 @@ endif(NOT WIN32) ...@@ -61,8 +62,8 @@ endif(NOT WIN32)
include_directories("${PADDLE_LIB}/third_party/boost") include_directories("${PADDLE_LIB}/third_party/boost")
include_directories("${PADDLE_LIB}/third_party/eigen3") include_directories("${PADDLE_LIB}/third_party/eigen3")
if (NOT WIN32) if (NOT WIN32)
if (USE_TENSORRT AND WITH_GPU) if (USE_TENSORRT AND WITH_GPU)
include_directories("${TENSORRT_INCLUDE_DIR}") include_directories("${TENSORRT_INCLUDE_DIR}")
link_directories("${TENSORRT_LIB_DIR}") link_directories("${TENSORRT_LIB_DIR}")
endif() endif()
...@@ -77,13 +78,14 @@ endif(NOT WIN32) ...@@ -77,13 +78,14 @@ endif(NOT WIN32)
link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
link_directories("${PADDLE_LIB}/third_party/install/glog/lib") link_directories("${PADDLE_LIB}/third_party/install/glog/lib")
link_directories("${PADDLE_LIB}/third_party/install/gflags/lib") link_directories("${PADDLE_LIB}/third_party/install/gflags/lib")
link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib")
link_directories("${PADDLE_LIB}/paddle/lib") link_directories("${PADDLE_LIB}/paddle/lib")
add_executable(${DEMO_NAME} ${DEMO_NAME}.cc) add_executable(${DEMO_NAME} ${DEMO_NAME}.cc)
if(WITH_MKL) if(WITH_MKL)
include_directories("${PADDLE_LIB}/third_party/install/mklml/include") include_directories("${PADDLE_LIB}/third_party/install/mklml/include")
set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn") set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn")
if(EXISTS ${MKLDNN_PATH}) if(EXISTS ${MKLDNN_PATH})
...@@ -107,7 +109,7 @@ if (NOT WIN32) ...@@ -107,7 +109,7 @@ if (NOT WIN32)
set(EXTERNAL_LIB "-lrt -ldl -lpthread") set(EXTERNAL_LIB "-lrt -ldl -lpthread")
set(DEPS ${DEPS} set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB} ${MATH_LIB} ${MKLDNN_LIB}
glog gflags protobuf snappystream snappy z glog gflags protobuf snappystream snappy z xxhash
${EXTERNAL_LIB}) ${EXTERNAL_LIB})
else() else()
set(DEPS ${DEPS} set(DEPS ${DEPS}
...@@ -120,7 +122,7 @@ endif(NOT WIN32) ...@@ -120,7 +122,7 @@ endif(NOT WIN32)
if(WITH_GPU) if(WITH_GPU)
if(NOT WIN32) if(NOT WIN32)
if (USE_TENSORRT) if (USE_TENSORRT)
set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_STATIC_LIBRARY_SUFFIX}) set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_STATIC_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX}) set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX})
endif() endif()
......
...@@ -16,7 +16,7 @@ if [ $2 == ON ]; then ...@@ -16,7 +16,7 @@ if [ $2 == ON ]; then
fi fi
if [ $3 == ON ]; then if [ $3 == ON ]; then
use_gpu_list='true false' use_gpu_list='true false'
else else
use_gpu_list='false' use_gpu_list='false'
fi fi
...@@ -60,7 +60,8 @@ for WITH_STATIC_LIB in ON OFF; do ...@@ -60,7 +60,8 @@ for WITH_STATIC_LIB in ON OFF; do
-DWITH_MKL=$TURN_ON_MKL \ -DWITH_MKL=$TURN_ON_MKL \
-DDEMO_NAME=simple_on_word2vec \ -DDEMO_NAME=simple_on_word2vec \
-DWITH_GPU=$TEST_GPU_CPU \ -DWITH_GPU=$TEST_GPU_CPU \
-DWITH_STATIC_LIB=$WITH_STATIC_LIB -DWITH_STATIC_LIB=$WITH_STATIC_LIB \
-DON_INFER=ON
make -j make -j
word2vec_model=${PADDLE_ROOT}'/build/python/paddle/fluid/tests/book/word2vec.inference.model' word2vec_model=${PADDLE_ROOT}'/build/python/paddle/fluid/tests/book/word2vec.inference.model'
if [ -d $word2vec_model ]; then if [ -d $word2vec_model ]; then
...@@ -80,10 +81,11 @@ for WITH_STATIC_LIB in ON OFF; do ...@@ -80,10 +81,11 @@ for WITH_STATIC_LIB in ON OFF; do
-DWITH_MKL=$TURN_ON_MKL \ -DWITH_MKL=$TURN_ON_MKL \
-DDEMO_NAME=vis_demo \ -DDEMO_NAME=vis_demo \
-DWITH_GPU=$TEST_GPU_CPU \ -DWITH_GPU=$TEST_GPU_CPU \
-DWITH_STATIC_LIB=$WITH_STATIC_LIB -DWITH_STATIC_LIB=$WITH_STATIC_LIB \
-DON_INFER=ON
make -j make -j
for use_gpu in $use_gpu_list; do for use_gpu in $use_gpu_list; do
for vis_demo_name in $vis_demo_list; do for vis_demo_name in $vis_demo_list; do
./vis_demo \ ./vis_demo \
--modeldir=$DATA_DIR/$vis_demo_name/model \ --modeldir=$DATA_DIR/$vis_demo_name/model \
--data=$DATA_DIR/$vis_demo_name/data.txt \ --data=$DATA_DIR/$vis_demo_name/data.txt \
...@@ -95,7 +97,7 @@ for WITH_STATIC_LIB in ON OFF; do ...@@ -95,7 +97,7 @@ for WITH_STATIC_LIB in ON OFF; do
fi fi
done done
done done
# --------tensorrt mobilenet------ # --------tensorrt mobilenet------
if [ $USE_TENSORRT == ON -a $TEST_GPU_CPU == ON ]; then if [ $USE_TENSORRT == ON -a $TEST_GPU_CPU == ON ]; then
rm -rf * rm -rf *
...@@ -106,8 +108,9 @@ for WITH_STATIC_LIB in ON OFF; do ...@@ -106,8 +108,9 @@ for WITH_STATIC_LIB in ON OFF; do
-DWITH_STATIC_LIB=$WITH_STATIC_LIB \ -DWITH_STATIC_LIB=$WITH_STATIC_LIB \
-DUSE_TENSORRT=$USE_TENSORRT \ -DUSE_TENSORRT=$USE_TENSORRT \
-DTENSORRT_INCLUDE_DIR=$TENSORRT_INCLUDE_DIR \ -DTENSORRT_INCLUDE_DIR=$TENSORRT_INCLUDE_DIR \
-DTENSORRT_LIB_DIR=$TENSORRT_LIB_DIR -DTENSORRT_LIB_DIR=$TENSORRT_LIB_DIR \
make -j -DON_INFER=ON
make -j
./trt_mobilenet_demo \ ./trt_mobilenet_demo \
--modeldir=$DATA_DIR/mobilenet/model \ --modeldir=$DATA_DIR/mobilenet/model \
--data=$DATA_DIR/mobilenet/data.txt \ --data=$DATA_DIR/mobilenet/data.txt \
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/details/reset_tensor_array.h"
namespace paddle {
namespace details {
// Should be called after the parameters are loaded.
void TensorArrayBatchCleaner::CollectTensorArrays(framework::Scope *scope) {
if (flag_) {
for (auto &var_name : scope->LocalVarNames()) {
auto *var = scope->FindVar(var_name);
// TODO(Superjomn) should avoid the case when a TensorArray is a
// parameter.
if (var_name == "feed" || var_name == "fetch") continue;
if (var->Type() == typeid(framework::LoDTensorArray)) {
VLOG(4) << "collect " << var_name;
arrays_.push_back(var->GetMutable<framework::LoDTensorArray>());
}
}
for (auto *kid : scope->kids()) {
CollectTensorArrays(kid);
}
VLOG(3) << "Collect " << arrays_.size() << " arrays";
flag_ = false;
}
}
// Empties every TensorArray that CollectTensorArrays() has gathered.
// Should be called when `Run` finished.
void TensorArrayBatchCleaner::ResetTensorArray() {
  for (auto it = arrays_.begin(); it != arrays_.end(); ++it) {
    (*it)->clear();
  }
}
} // namespace details
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace details {
// Clears the TensorArrays after each batch so that inference behaves the same
// as the training phase.
//
// Usage: call CollectTensorArrays() with the predictor's scope, then call
// ResetTensorArray() once a run has finished.
struct TensorArrayBatchCleaner {
// Records the LoDTensorArray variables found in `scope` and its kid scopes.
// Does nothing once `flag_` has been cleared by an earlier collection.
void CollectTensorArrays(framework::Scope *scope);
// Calls clear() on every TensorArray recorded by CollectTensorArrays().
void ResetTensorArray();
private:
// True until a collection has completed; guards CollectTensorArrays().
bool flag_{true};
// Pointers to the collected TensorArrays. ResetTensorArray() clears the
// arrays they point to; it does not free them.
std::vector<framework::LoDTensorArray *> arrays_;
};
} // namespace details
} // namespace paddle
...@@ -124,7 +124,7 @@ class ZeroCopyTensor { ...@@ -124,7 +124,7 @@ class ZeroCopyTensor {
std::vector<std::vector<size_t>> lod() const; std::vector<std::vector<size_t>> lod() const;
protected: protected:
ZeroCopyTensor(void* scope) : scope_{scope} {} explicit ZeroCopyTensor(void* scope) : scope_{scope} {}
void SetName(const std::string& name) { name_ = name; } void SetName(const std::string& name) { name_ = name; }
void* FindTensor() const; void* FindTensor() const;
...@@ -259,12 +259,6 @@ struct AnalysisConfig : public NativeConfig { ...@@ -259,12 +259,6 @@ struct AnalysisConfig : public NativeConfig {
kExclude // Specify the disabled passes in `ir_passes`. kExclude // Specify the disabled passes in `ir_passes`.
}; };
void SetIncludeMode() {
ir_mode = IrPassMode::kInclude;
// this pass has to be run at the beginning of all fuse passes
ir_passes = {"infer_clean_graph_pass"};
}
// Determine whether to perform graph optimization. // Determine whether to perform graph optimization.
bool enable_ir_optim = true; bool enable_ir_optim = true;
// Manually determine the IR passes to run. // Manually determine the IR passes to run.
......
...@@ -228,6 +228,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) { ...@@ -228,6 +228,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
TEST(Analyzer_rnn1, profile) { TEST(Analyzer_rnn1, profile) {
contrib::AnalysisConfig cfg; contrib::AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
cfg.use_gpu = false;
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
......
...@@ -268,6 +268,7 @@ if (WITH_GPU AND TENSORRT_FOUND) ...@@ -268,6 +268,7 @@ if (WITH_GPU AND TENSORRT_FOUND)
else() else()
set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op) set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op)
endif() endif()
op_library(hash_op DEPS xxhash)
op_library(clip_by_norm_op DEPS selected_rows_functor selected_rows) op_library(clip_by_norm_op DEPS selected_rows_functor selected_rows)
op_library(sum_op DEPS selected_rows_functor) op_library(sum_op DEPS selected_rows_functor)
op_library(sgd_op DEPS selected_rows_functor) op_library(sgd_op DEPS selected_rows_functor)
......
...@@ -79,6 +79,9 @@ struct BeamSearchDecodeFunctor { ...@@ -79,6 +79,9 @@ struct BeamSearchDecodeFunctor {
bool tensor_on_gpu_; bool tensor_on_gpu_;
size_t beam_size_; size_t beam_size_;
int end_id_; int end_id_;
// TODO(Superjomn) This might cause serious performance issues in
// concurrent scenarios.
const LoDTensorArray& step_ids_origin_; const LoDTensorArray& step_ids_origin_;
const LoDTensorArray& step_scores_origin_; const LoDTensorArray& step_scores_origin_;
LoDTensorArray step_ids_ = LoDTensorArray(); LoDTensorArray step_ids_ = LoDTensorArray();
......
...@@ -284,7 +284,7 @@ static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, ...@@ -284,7 +284,7 @@ static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox,
selected_indices.push_back(idx); selected_indices.push_back(idx);
++selected_num; ++selected_num;
} }
sorted_indices.erase(sorted_indices.end()); sorted_indices.erase(sorted_indices.end() - 1);
if (flag && eta < 1 && adaptive_threshold > 0.5) { if (flag && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta; adaptive_threshold *= eta;
} }
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/operators/dropout_op.h"
#include <string>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -57,6 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -57,6 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
"will be dropped.") "will be dropped.")
.SetDefault(false); .SetDefault(false);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0); AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
AddAttr<std::string>(
"dropout_implementation",
"[\"downgrade_in_infer\"|\"upscale_in_train\"]"
"There are two kinds of ways to implement dropout"
"(the mask below is a tensor have the same shape with input"
"the value of mask is 0 or 1, the ratio of 0 is dropout_prob)"
"1. downgrade_in_infer(default), downgrade the outcome at inference "
"time"
" train: out = input * mask"
" inference: out = input * dropout_prob"
"2. upscale_in_train, upscale the outcome at training time, do nothing "
"in inference"
" train: out = input * mask / ( 1.0 - dropout_prob )"
" inference: out = input"
" dropout op can be removed from the program. the program will be "
"efficient")
.SetDefault("downgrade_in_infer")
.AddCustomChecker([](const std::string& type) {
PADDLE_ENFORCE(
type == "downgrade_in_infer" || type == "upscale_in_train",
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train");
});
AddComment(R"DOC( AddComment(R"DOC(
Dropout Operator. Dropout Operator.
...@@ -104,7 +128,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker, ...@@ -104,7 +128,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad); REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>); dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
dropout_grad, dropout_grad,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>); ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <thrust/iterator/counting_iterator.h> #include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h> #include <thrust/random.h>
#include <thrust/transform.h> #include <thrust/transform.h>
#include <string>
#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
...@@ -26,7 +27,8 @@ namespace operators { ...@@ -26,7 +27,8 @@ namespace operators {
template <typename T> template <typename T>
__global__ void RandomGenerator(const size_t n, const int seed, __global__ void RandomGenerator(const size_t n, const int seed,
const float dropout_prob, const T* src, const float dropout_prob, const T* src,
T* mask_data, T* dst) { T* mask_data, T* dst,
bool is_upscale_in_train) {
thrust::minstd_rand rng; thrust::minstd_rand rng;
rng.seed(seed); rng.seed(seed);
thrust::uniform_real_distribution<float> dist(0, 1); thrust::uniform_real_distribution<float> dist(0, 1);
...@@ -47,7 +49,11 @@ __global__ void RandomGenerator(const size_t n, const int seed, ...@@ -47,7 +49,11 @@ __global__ void RandomGenerator(const size_t n, const int seed,
if (dist(rng) < dropout_prob) { if (dist(rng) < dropout_prob) {
mask = static_cast<T>(0); mask = static_cast<T>(0);
} else { } else {
mask = static_cast<T>(1); if (is_upscale_in_train) {
mask = static_cast<T>(1.0f / (1.0f - dropout_prob));
} else {
mask = static_cast<T>(1);
}
} }
dest = s * mask; dest = s * mask;
mask_data[idx] = mask; mask_data[idx] = mask;
...@@ -67,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> { ...@@ -67,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
y->mutable_data<T>(context.GetPlace()); y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.Attr<float>("dropout_prob"); float dropout_prob = context.Attr<float>("dropout_prob");
auto dropout_implementation =
context.Attr<std::string>("dropout_implementation");
auto& place = *context.template device_context<Place>().eigen_device(); auto& place = *context.template device_context<Place>().eigen_device();
if (!context.Attr<bool>("is_test")) { if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask"); auto* mask = context.Output<Tensor>("Mask");
...@@ -83,11 +91,16 @@ class GPUDropoutKernel : public framework::OpKernel<T> { ...@@ -83,11 +91,16 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
int grid = (x->numel() + threads - 1) / threads; int grid = (x->numel() + threads - 1) / threads;
RandomGenerator< RandomGenerator<
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>( T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
size, seed, dropout_prob, x_data, mask_data, y_data); size, seed, dropout_prob, x_data, mask_data, y_data,
(dropout_implementation == "upscale_in_train"));
} else { } else {
auto X = EigenMatrix<T>::Reshape(*x, 1); auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1); auto Y = EigenMatrix<T>::Reshape(*y, 1);
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob); if (dropout_implementation == "upscale_in_train") {
Y.device(place) = X;
} else {
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
}
} }
} }
}; };
...@@ -99,6 +112,8 @@ namespace ops = paddle::operators; ...@@ -99,6 +112,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>, dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>); ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
REGISTER_OP_CUDA_KERNEL(dropout_grad, ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);
ops::DropoutGradKernel<plat::CUDADeviceContext, float>); REGISTER_OP_CUDA_KERNEL(
dropout_grad, ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
ops::DropoutGradKernel<plat::CUDADeviceContext, double>);
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <random> #include <random>
#include <string>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -36,6 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> { ...@@ -36,6 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto* y_data = y->mutable_data<T>(context.GetPlace()); auto* y_data = y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.Attr<float>("dropout_prob"); float dropout_prob = context.Attr<float>("dropout_prob");
auto dropout_implementation =
context.Attr<std::string>("dropout_implementation");
if (!context.Attr<bool>("is_test")) { if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask"); auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace()); auto* mask_data = mask->mutable_data<T>(context.GetPlace());
...@@ -49,14 +52,20 @@ class CPUDropoutKernel : public framework::OpKernel<T> { ...@@ -49,14 +52,20 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
engine.seed(seed); engine.seed(seed);
std::uniform_real_distribution<float> dist(0, 1); std::uniform_real_distribution<float> dist(0, 1);
size_t size = framework::product(mask->dims()); size_t size = framework::product(mask->dims());
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
if (dist(engine) < dropout_prob) { if (dist(engine) < dropout_prob) {
mask_data[i] = 0; mask_data[i] = 0;
y_data[i] = 0; y_data[i] = 0;
} else { } else {
mask_data[i] = 1; if (dropout_implementation == "upscale_in_train") {
y_data[i] = x_data[i]; mask_data[i] = 1.0f / static_cast<T>(1.0f - dropout_prob);
y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
} else {
mask_data[i] = 1;
y_data[i] = x_data[i];
}
} }
} }
} else { } else {
...@@ -64,7 +73,11 @@ class CPUDropoutKernel : public framework::OpKernel<T> { ...@@ -64,7 +73,11 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto Y = EigenMatrix<T>::Reshape(*y, 1); auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto& place = auto& place =
*context.template device_context<DeviceContext>().eigen_device(); *context.template device_context<DeviceContext>().eigen_device();
Y.device(place) = X * (1.0f - dropout_prob); if (dropout_implementation == "upscale_in_train") {
Y.device(place) = X;
} else {
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
}
} }
} }
}; };
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/hash_op.h"
#include <string>
#include <vector>
namespace paddle {
namespace operators {
// Operator definition for `hash`: checks that the input and output are wired
// up and infers the output shape from the input shape and the `num_hash`
// attribute.
class HashOp : public framework::OperatorWithKernel {
 public:
  HashOp(const std::string &type, const framework::VariableNameMap &inputs,
         const framework::VariableNameMap &outputs,
         const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  // For a 2-D input X the output shape is [X.dims[0], num_hash, 1]; the LoD of
  // X is shared with Out.
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of HashOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of HashOp should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    PADDLE_ENFORCE_EQ(x_dims.size(), 2UL,
                      "The input of hash_op's dimensions must be 2");

    std::vector<int64_t> out_shape;
    out_shape.reserve(x_dims.size() + 1);

    // Every dimension except the last one is carried over unchanged.
    size_t axis = 0u;
    while (axis != x_dims.size() - 1) {
      out_shape.emplace_back(x_dims[axis]);
      ++axis;
    }

    // The last input dimension is replaced by `num_hash`, followed by a
    // trailing dimension of 1.
    const int num_hash = ctx->Attrs().Get<int>("num_hash");
    out_shape.emplace_back(num_hash);
    out_shape.emplace_back(1);

    ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};
// Declares the proto (inputs, outputs, attributes and user-facing
// documentation) of the `hash` operator.
class HashOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(LoDTensor) Input tensor of hash operator. It must be a 2-D "
             "integer tensor; every row is hashed as one unit.");
    AddOutput("Out",
              "(LoDTensor) Output tensor of hash operator, with shape "
              "[X.dims[0], num_hash, 1].");
    AddComment(R"DOC(
**Hash Operator**

For every row of X this operator computes `num_hash` hash values with the
XXH64 algorithm, using the seeds 0, 1, ..., num_hash - 1, and reduces each
value modulo `mod_by`:

$$Out[i, j, 0] = XXH64(X[i, :], seed = j) \bmod mod\_by$$
)DOC");
    // The descriptions below were previously empty strings.
    AddAttr<int>("num_hash",
                 "(int, default 1) The number of hash values computed for "
                 "each row of X.")
        .SetDefault(1);
    AddAttr<int>("mod_by",
                 "(int, default 100000) The modulus applied to every hash "
                 "value, i.e. the size of the hash space.")
        .SetDefault(100000);
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// Register the `hash` operator without a gradient operator, together with
// its CPU kernels for int and int64_t inputs.
REGISTER_OP_WITHOUT_GRADIENT(hash, ops::HashOp, ops::HashOpMaker);
REGISTER_OP_CPU_KERNEL(hash, ops::HashKerel<int>, ops::HashKerel<int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
extern "C" {
#include <xxhash.h>
}
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
// template <typename DeviceContext, typename T>
// CPU kernel of the `hash` operator.
//
// For every row of the input X it writes `num_hash` values to Out, where the
// j-th value is XXH64(row bytes, seed = j) % mod_by.
//
// NOTE: the class name keeps its historical spelling because the kernel
// registration refers to it by this name.
template <typename T>
class HashKerel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* out_t = context.Output<framework::LoDTensor>("Out");
    auto* in_t = context.Input<framework::LoDTensor>("X");
    int mod_by = context.Attr<int>("mod_by");
    int num_hash = context.Attr<int>("num_hash");
    auto* output = out_t->mutable_data<T>(context.GetPlace());

    auto in_dims = in_t->dims();
    const auto& in_lod = in_t->lod();
    // Fail with a clear message instead of indexing into an empty LoD.
    PADDLE_ENFORCE(!in_lod.empty(),
                   "Input(X) of HashOp must carry LoD information.");
    PADDLE_ENFORCE_EQ(
        static_cast<uint64_t>(in_dims[0]), in_lod[0].back(),
        "The actual input data's size mismatched with LoD information.");

    const int64_t seq_length = in_dims[0];
    const int64_t last_dim = in_dims[in_dims.size() - 1];
    // Size of one row in bytes. This must be derived from T: with
    // sizeof(int) the int64_t kernel would hash only half of each row.
    const size_t row_bytes = sizeof(T) * static_cast<size_t>(last_dim);

    auto* input = in_t->data<T>();
    for (int64_t idx = 0; idx < seq_length; ++idx) {
      // The hash index doubles as the XXH64 seed, so each of the num_hash
      // outputs for a row comes from a differently seeded hash.
      for (int ihash = 0; ihash != num_hash; ++ihash) {
        output[idx * num_hash + ihash] =
            XXH64(input, row_bytes, ihash) % mod_by;
      }
      input += last_dim;
    }
  }
};
} // namespace operators
} // namespace paddle
...@@ -81,6 +81,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -81,6 +81,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"Otherwise the given value indicates padding the output " "Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.") "with zeros whenever lookup encounters it in Ids.")
.SetDefault(kNoPadding); .SetDefault(kNoPadding);
// NOTE(minqiyang): grad_inplace is a temporary attribute,
// please do NOT set this attribute in python layer.
AddAttr<bool>("grad_inplace",
"(boolean, default false) "
"If the grad op reuse the input's variable.")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Lookup Table Operator. Lookup Table Operator.
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -68,6 +69,7 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -68,6 +69,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
const auto *table = table_t.value().data<T>(); const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace()); auto *output = output_t->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int64_t i = 0; i < ids_numel; ++i) { for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) { if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T)); memset(output + i * row_width, 0, row_width * sizeof(T));
...@@ -75,8 +77,8 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -75,8 +77,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_GE(ids[i], 0); PADDLE_ENFORCE_GE(ids[i], 0);
auto id_index = table_t.Index(ids[i]); auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists."); PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
memcpy(output + i * row_width, table + id_index * row_width, blas.VCOPY(row_width, table + id_index * row_width,
row_width * sizeof(T)); output + i * row_width);
} }
} }
} }
...@@ -111,27 +113,37 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -111,27 +113,37 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
auto *ids_data = ids->data<int64_t>(); auto *ids_data = ids->data<int64_t>();
int64_t ids_num = ids->numel(); int64_t ids_num = ids->numel();
framework::Vector<int64_t> new_rows; std::vector<int64_t> new_rows;
new_rows.reserve(ids_num); new_rows.resize(ids_num);
for (int64_t i = 0; i < ids_num; i++) { std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t));
new_rows.push_back(ids_data[i]);
}
d_table->set_rows(new_rows); d_table->set_rows(new_rows);
auto *d_table_value = d_table->mutable_value(); auto *d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table_dim[1]}); d_table_value->Resize({ids_num, table_dim[1]});
d_table_value->mutable_data<T>(context.GetPlace()); // FIXME(minqiyang):
// memory optimization will NOT reuse Tensor with SelectedRows
d_table->set_height(table_dim[0]); // so we could just share the tensor here directly.
// However, the InferVarType method will infer the output SelectedRows
auto *d_output_data = d_output->data<T>(); // to Tensor sometimes, which is a bug, so we will add an attribute
auto *d_table_data = d_table_value->data<T>(); // here to indicate the inplace and remove this attribute after
// the InferVarType's bug was fixed
auto d_output_dims = d_output->dims(); bool grad_inplace = context.Attr<bool>("grad_inplace");
PADDLE_ENFORCE_EQ( if (grad_inplace) {
d_table_value->dims(), d_table_value->ShareDataWith(*d_output);
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1)); } else {
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); d_table_value->mutable_data<T>(context.GetPlace());
d_table->set_height(table_dim[0]);
auto *d_output_data = d_output->data<T>();
auto *d_table_data = d_table_value->data<T>();
auto d_output_dims = d_output->dims();
PADDLE_ENFORCE_EQ(
d_table_value->dims(),
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
}
} else { } else {
auto *ids = context.Input<LoDTensor>("Ids"); auto *ids = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out")); auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
......
...@@ -39,6 +39,52 @@ HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) { ...@@ -39,6 +39,52 @@ HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) {
return -1; return -1;
} }
// Returns the index of the first element in the sorted range x[0, num) that
// is not less than `val`, or `num` if there is no such element.
template <typename T>
HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) {
#ifdef __CUDA_ARCH__
  // Hand-written binary search used when __CUDA_ARCH__ is defined. It
  // follows the reference implementation at
  // https://en.cppreference.com/w/cpp/algorithm/lower_bound
  const T *lo = x;
  int64_t remaining = static_cast<int64_t>(num);
  while (remaining > 0) {
    const int64_t half = (remaining >> 1);
    const T *mid = lo + half;
    if (*mid < val) {
      // The answer lies strictly to the right of `mid`.
      lo = mid + 1;
      remaining -= (half + 1);
    } else {
      remaining = half;
    }
  }
  return static_cast<size_t>(lo - x);
#else
  const T *last = x + num;
  return static_cast<size_t>(std::lower_bound(x, last, val) - x);
#endif
}
// Returns the index of the first element in the sorted range x[0, num) that
// is greater than `val`, or `num` if there is no such element.
template <typename T>
HOSTDEVICE inline size_t UpperBound(const T *x, size_t num, const T &val) {
#ifdef __CUDA_ARCH__
  // Hand-written binary search used when __CUDA_ARCH__ is defined. It
  // follows the reference implementation at
  // https://en.cppreference.com/w/cpp/algorithm/upper_bound
  const T *lo = x;
  int64_t remaining = static_cast<int64_t>(num);
  while (remaining > 0) {
    const int64_t half = (remaining >> 1);
    const T *mid = lo + half;
    if (val < *mid) {
      remaining = half;
    } else {
      // `mid` is not greater than val, so the answer is to its right.
      lo = mid + 1;
      remaining -= (half + 1);
    }
  }
  return static_cast<size_t>(lo - x);
#else
  const T *last = x + num;
  return static_cast<size_t>(std::upper_bound(x, last, val) - x);
#endif
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/sequence_reverse_op.h"
namespace ops = paddle::operators;
// Register the sequence_reverse operator with its proto maker and its
// grad-op-desc maker.
REGISTER_OPERATOR(sequence_reverse, ops::SequenceReverseOp,
ops::SequenceReverseOpMaker,
ops::SequenceReverseGradOpDescMaker);
// Register the CPU kernels, one per supported element type.
REGISTER_OP_CPU_KERNEL(
sequence_reverse,
ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, double>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/sequence_reverse_op.h"
namespace ops = paddle::operators;
// Register the CUDA kernels of sequence_reverse, one per supported element
// type (the same set of types as the CPU registration).
REGISTER_OP_CUDA_KERNEL(
sequence_reverse,
ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, double>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
// Operator definition for `sequence_reverse`. The output Y has the same
// dims and LoD as the input X.
class SequenceReverseOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Validates that X and Y exist and that X has rank >= 2, then copies the
  // shape and LoD of X to Y.
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist");
    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist");

    const auto input_dims = ctx->GetInputDim("X");
    PADDLE_ENFORCE_GE(input_dims.size(), 2,
                      "Rank of Input(X) must be not less than 2.");

    ctx->SetOutputDim("Y", input_dims);
    ctx->ShareLoD("X", "Y");
  }
};
// Declares the proto of the sequence_reverse operator: one input X, one
// output Y, no attributes, plus the user-facing documentation string.
class SequenceReverseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
// Registers the input, the output and the operator comment.
void Make() override {
AddInput("X", "The input LoDTensor of sequence_reverse op.");
AddOutput("Y", "The output LoDTensor of sequence_reverse op.");
AddComment(R"DOC(
SequenceReverse Operator.
Reverse each sequence in input X along dim 0.
Assuming X is a LoDTensor with dims [5, 4] and lod [[0, 2, 5]], where:
X.data() = [
[1, 2, 3, 4],
[5, 6, 7, 8], # the 0-th sequence with length 2
[9, 10, 11, 12],
[13, 14, 15, 16],
[17, 18, 19, 20] # the 1-st sequence with length 3
]
The output Y would be a LoDTensor sharing the same dims and lod with input X,
and:
Y.data() = [
[5, 6, 7, 8],
[1, 2, 3, 4], # the reversed 0-th sequence with length 2
[17, 18, 19, 20],
[13, 14, 15, 16],
[9, 10, 11, 12] # the reversed 1-st sequence with length 3
]
This Operator is useful to build a reverse dynamic RNN network.
This Operator only supports one-level lod currently.
)DOC");
}
};
// Element-wise functor that copies one element of x into its reversed
// position in y. It is invoked once per flat element index of x.
//
// x, y       : source and destination buffers
// lod        : offsets of the sequences (level-0 LoD), `lod_count` entries
// row_numel  : number of elements in one row
template <typename T>
struct SequenceReverseFunctor {
  SequenceReverseFunctor(const T *x, T *y, const size_t *lod, size_t lod_count,
                         size_t row_numel)
      : x_(x), y_(y), lod_(lod), lod_count_(lod_count), row_numel_(row_numel) {}

  HOSTDEVICE void operator()(size_t idx_x) const {
    // Split the flat index into (row, column).
    const size_t src_row = idx_x / row_numel_;
    const size_t col = idx_x % row_numel_;

    // lod_[seq - 1] and lod_[seq] are the begin/end rows of the sequence
    // that contains src_row.
    const size_t seq = math::UpperBound(lod_, lod_count_, src_row);
    const size_t seq_begin = lod_[seq - 1];
    const size_t seq_end = lod_[seq];

    // Mirror the row inside its own sequence; the column is unchanged.
    const size_t dst_row = seq_begin + (seq_end - 1 - src_row);
    y_[dst_row * row_numel_ + col] = x_[idx_x];
  }

  const T *x_;
  T *y_;
  const size_t *lod_;
  size_t lod_count_;
  size_t row_numel_;
};
// Kernel of the sequence_reverse operator: reverses the rows of every
// sequence of X (as described by its one-level LoD) and writes the result
// to Y. The actual element copy is done by SequenceReverseFunctor, launched
// over all elements of X through platform::ForRange.
template <typename DeviceContext, typename T>
class SequenceReverseOpKernel : public framework::OpKernel<T> {
  using LoDTensor = framework::LoDTensor;

 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto &x = *ctx.Input<LoDTensor>("X");
    auto *y = ctx.Output<LoDTensor>("Y");

    PADDLE_ENFORCE_EQ(x.lod().size(), 1,
                      "SequenceReverse Op only support one level lod.");

    auto &dev_ctx = ctx.template device_context<DeviceContext>();
    const size_t *lod;
    size_t lod_count = x.lod()[0].size();

    // Pick the LoD pointer that matches the execution place. The statement
    // below is shared between the CUDA `else` branch and the CPU-only build,
    // hence the interleaved preprocessor guards.
    // NOTE(review): CUDAData presumably yields a device-side copy of the
    // LoD - confirm against framework::Vector.
#ifdef PADDLE_WITH_CUDA
    if (platform::is_gpu_place(ctx.GetPlace())) {
      lod = x.lod()[0].CUDAData(ctx.GetPlace());
    } else {
#endif
      lod = x.lod()[0].data();
#ifdef PADDLE_WITH_CUDA
    }
#endif

    size_t limit = static_cast<size_t>(x.numel());
    // Guard the division: an input with zero elements (in particular one
    // with zero rows) would otherwise divide by x.dims()[0] == 0. With
    // limit == 0 the functor is launched over an empty range, so the value
    // of row_numel is irrelevant in that case.
    size_t row_numel =
        (limit == 0) ? 0 : static_cast<size_t>(limit / x.dims()[0]);

    auto *x_data = x.data<T>();
    auto *y_data = y->mutable_data<T>(ctx.GetPlace());

    // Each element is written to a different position than it is read
    // from, so reading and writing the same buffer is rejected.
    PADDLE_ENFORCE_NE(x_data, y_data,
                      "SequenceReverse Op does not support in-place operation");

    SequenceReverseFunctor<T> functor(x_data, y_data, lod, lod_count,
                                      row_numel);
    platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
    for_range(functor);
  }
};
// Builds the gradient op description of sequence_reverse. The gradient is
// expressed as another sequence_reverse op that takes the gradient of Y as
// its input X and produces the gradient of X as its output Y.
class SequenceReverseGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> grad_desc(new framework::OpDesc());
    grad_desc->SetType("sequence_reverse");
    // Forward output gradient goes in, forward input gradient comes out.
    grad_desc->SetInput("X", OutputGrad("Y"));
    grad_desc->SetOutput("Y", InputGrad("X"));
    grad_desc->SetAttrMap(Attrs());
    return grad_desc;
  }
};
} // namespace operators
} // namespace paddle
...@@ -76,6 +76,8 @@ namespace ops = paddle::operators; ...@@ -76,6 +76,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
ops::SoftmaxCUDNNKernel<float>, ops::SoftmaxCUDNNKernel<float>,
ops::SoftmaxCUDNNKernel<double>,
ops::SoftmaxCUDNNKernel<plat::float16>); ops::SoftmaxCUDNNKernel<plat::float16>);
REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
ops::SoftmaxGradCUDNNKernel<float>); ops::SoftmaxGradCUDNNKernel<float>,
ops::SoftmaxGradCUDNNKernel<double>);
...@@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker, ...@@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker,
REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad); REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>); transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
transpose_grad, transpose_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>); ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker, REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
ops::Transpose2GradMaker); ops::Transpose2GradMaker);
REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad); REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
transpose2, transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>); ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
transpose2_grad, transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>); ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -16,15 +16,18 @@ limitations under the License. */ ...@@ -16,15 +16,18 @@ limitations under the License. */
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
transpose, transpose, ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>); ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
transpose_grad, transpose_grad,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>); ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
transpose2, transpose2,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>); ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
transpose2_grad, transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>); ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>);
...@@ -15,6 +15,7 @@ include_directories("${PADDLE_LIB}") ...@@ -15,6 +15,7 @@ include_directories("${PADDLE_LIB}")
include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") include_directories("${PADDLE_LIB}/third_party/install/protobuf/include")
include_directories("${PADDLE_LIB}/third_party/install/glog/include") include_directories("${PADDLE_LIB}/third_party/install/glog/include")
include_directories("${PADDLE_LIB}/third_party/install/gflags/include") include_directories("${PADDLE_LIB}/third_party/install/gflags/include")
include_directories("${PADDLE_LIB}/third_party/install/xxhash/include")
include_directories("${PADDLE_LIB}/third_party/install/snappy/include") include_directories("${PADDLE_LIB}/third_party/install/snappy/include")
include_directories("${PADDLE_LIB}/third_party/install/snappystream/include") include_directories("${PADDLE_LIB}/third_party/install/snappystream/include")
include_directories("${PADDLE_LIB}/third_party/install/zlib/include") include_directories("${PADDLE_LIB}/third_party/install/zlib/include")
...@@ -27,6 +28,7 @@ link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib") ...@@ -27,6 +28,7 @@ link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib")
link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
link_directories("${PADDLE_LIB}/third_party/install/glog/lib") link_directories("${PADDLE_LIB}/third_party/install/glog/lib")
link_directories("${PADDLE_LIB}/third_party/install/gflags/lib") link_directories("${PADDLE_LIB}/third_party/install/gflags/lib")
link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib")
link_directories("${PADDLE_LIB}/third_party/install/zlib/lib") link_directories("${PADDLE_LIB}/third_party/install/zlib/lib")
add_executable(demo_trainer demo_trainer.cc) add_executable(demo_trainer demo_trainer.cc)
...@@ -62,5 +64,5 @@ target_link_libraries(demo_trainer ...@@ -62,5 +64,5 @@ target_link_libraries(demo_trainer
${ARCHIVE_END} ${ARCHIVE_END}
${MATH_LIB} ${MATH_LIB}
${MKLDNN_LIB} ${MKLDNN_LIB}
glog gflags protobuf snappystream snappy z glog gflags protobuf snappystream snappy z xxhash
${EXTERNAL_LIB}) ${EXTERNAL_LIB})
...@@ -95,9 +95,9 @@ function cmake_gen() { ...@@ -95,9 +95,9 @@ function cmake_gen() {
exit 1 exit 1
fi fi
fi fi
else else
if [ "$1" != "" ]; then if [ "$1" != "" ]; then
echo "using python abi: $1" echo "using python abi: $1"
if [ "$1" == "cp27-cp27m" ]; then if [ "$1" == "cp27-cp27m" ]; then
export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs4/lib:} export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs4/lib:}
export PATH=/opt/python/cp27-cp27m/bin/:${PATH} export PATH=/opt/python/cp27-cp27m/bin/:${PATH}
...@@ -119,7 +119,7 @@ function cmake_gen() { ...@@ -119,7 +119,7 @@ function cmake_gen() {
fi fi
fi fi
fi fi
if [ "$SYSTEM" == "Darwin" ]; then if [ "$SYSTEM" == "Darwin" ]; then
WITH_DISTRIBUTE=${WITH_DISTRIBUTE:-ON} WITH_DISTRIBUTE=${WITH_DISTRIBUTE:-ON}
WITH_AVX=${WITH_AVX:-ON} WITH_AVX=${WITH_AVX:-ON}
...@@ -127,7 +127,7 @@ function cmake_gen() { ...@@ -127,7 +127,7 @@ function cmake_gen() {
else else
INFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR:-/root/.cache/inference_demo} INFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR:-/root/.cache/inference_demo}
fi fi
cat <<EOF cat <<EOF
======================================== ========================================
Configuring cmake in /paddle/build ... Configuring cmake in /paddle/build ...
...@@ -394,8 +394,8 @@ EOF ...@@ -394,8 +394,8 @@ EOF
export http_proxy= export http_proxy=
export https_proxy= export https_proxy=
# TODO: jiabin need to refine this part when these tests fixed on mac # TODO: jiabin need to refine this part when these tests fixed on mac
ctest --output-on-failure -j $1 ctest --output-on-failure -j $1
# make install should also be test when unittest # make install should also be test when unittest
make install -j 8 make install -j 8
pip install --user ${INSTALL_PREFIX:-/paddle/build}/opt/paddle/share/wheels/*.whl pip install --user ${INSTALL_PREFIX:-/paddle/build}/opt/paddle/share/wheels/*.whl
if [[ ${WITH_FLUID_ONLY:-OFF} == "OFF" ]] ; then if [[ ${WITH_FLUID_ONLY:-OFF} == "OFF" ]] ; then
...@@ -659,7 +659,7 @@ function gen_fluid_lib() { ...@@ -659,7 +659,7 @@ function gen_fluid_lib() {
Generating fluid library for train and inference ... Generating fluid library for train and inference ...
======================================== ========================================
EOF EOF
cmake .. -DWITH_DISTRIBUTE=OFF cmake .. -DWITH_DISTRIBUTE=OFF -DON_INFER=ON
make -j `nproc` fluid_lib_dist make -j `nproc` fluid_lib_dist
make -j `nproc` inference_lib_dist make -j `nproc` inference_lib_dist
fi fi
......
...@@ -78,7 +78,7 @@ def __build_dict(tar_file, dict_size, save_path, lang): ...@@ -78,7 +78,7 @@ def __build_dict(tar_file, dict_size, save_path, lang):
six.iteritems(word_dict), key=lambda x: x[1], six.iteritems(word_dict), key=lambda x: x[1],
reverse=True)): reverse=True)):
if idx + 3 == dict_size: break if idx + 3 == dict_size: break
fout.write("%s\n" % (word[0])) fout.write("%s\n" % (cpt.to_bytes(word[0])))
def __load_dict(tar_file, dict_size, lang, reverse=False): def __load_dict(tar_file, dict_size, lang, reverse=False):
......
...@@ -272,7 +272,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): ...@@ -272,7 +272,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
) )
square = grad * grad square = grad * grad
local_norm_var = layers.cast(layers.reduce_sum(input=square), 'float64') local_norm_var = layers.reduce_sum(input=square)
context[self.group_name].append(local_norm_var) context[self.group_name].append(local_norm_var)
self.context = context self.context = context
...@@ -282,7 +282,6 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): ...@@ -282,7 +282,6 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
if group_scale_name not in self.context: if group_scale_name not in self.context:
group_norm_var = layers.sums(input=self.context[self.group_name]) group_norm_var = layers.sums(input=self.context[self.group_name])
group_norm_var = layers.sqrt(x=group_norm_var) group_norm_var = layers.sqrt(x=group_norm_var)
group_norm_var = layers.cast(group_norm_var, 'float32')
clip_var = self.context[self.group_name + "_clip"] clip_var = self.context[self.group_name + "_clip"]
group_scale_var = layers.elementwise_div( group_scale_var = layers.elementwise_div(
x=clip_var, x=clip_var,
......
...@@ -316,7 +316,7 @@ class DetectionMAP(Evaluator): ...@@ -316,7 +316,7 @@ class DetectionMAP(Evaluator):
gt_label (Variable): The ground truth label index, which is a LoDTensor gt_label (Variable): The ground truth label index, which is a LoDTensor
with shape [N, 1]. with shape [N, 1].
gt_box (Variable): The ground truth bounding box (bbox), which is a gt_box (Variable): The ground truth bounding box (bbox), which is a
LoDTensor with shape [N, 6]. The layout is [xmin, ymin, xmax, ymax]. LoDTensor with shape [N, 4]. The layout is [xmin, ymin, xmax, ymax].
gt_difficult (Variable|None): Whether this ground truth is a difficult gt_difficult (Variable|None): Whether this ground truth is a difficult
bounding bbox, which can be a LoDTensor [N, 1] or not set. If None, bounding bbox, which can be a LoDTensor [N, 1] or not set. If None,
it means all the ground truth labels are not difficult bbox. it means all the ground truth labels are not difficult bbox.
......
...@@ -154,7 +154,9 @@ __all__ = [ ...@@ -154,7 +154,9 @@ __all__ = [
'mul', 'mul',
'sigmoid_cross_entropy_with_logits', 'sigmoid_cross_entropy_with_logits',
'maxout', 'maxout',
'sequence_reverse',
'affine_channel', 'affine_channel',
'hash',
] ]
...@@ -980,7 +982,12 @@ def cos_sim(X, Y): ...@@ -980,7 +982,12 @@ def cos_sim(X, Y):
return out return out
def dropout(x, dropout_prob, is_test=False, seed=None, name=None): def dropout(x,
dropout_prob,
is_test=False,
seed=None,
name=None,
dropout_implementation="downgrade_in_infer"):
""" """
Computes dropout. Computes dropout.
...@@ -1000,6 +1007,21 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None): ...@@ -1000,6 +1007,21 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
units will be dropped. DO NOT use a fixed seed in training. units will be dropped. DO NOT use a fixed seed in training.
name (str|None): A name for this layer(optional). If set None, the layer name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. will be named automatically.
dropout_implementation(string): ['downgrade_in_infer'(default)|'upscale_in_train']
1. downgrade_in_infer(default), downgrade the outcome at inference
train: out = input * mask
inference: out = input * (1.0 - dropout_prob)
(mask is a tensor with the same shape as input, its values are 0 or 1,
and the ratio of 0 is dropout_prob)
2. upscale_in_train, upscale the outcome at training time
train: out = input * mask / ( 1.0 - dropout_prob )
inference: out = input
(mask is a tensor with the same shape as input, its values are 0 or 1,
and the ratio of 0 is dropout_prob)
dropout op can be removed from the program.
the program will be efficient
Returns: Returns:
Variable: A tensor variable is the shape with `x`. Variable: A tensor variable is the shape with `x`.
...@@ -1029,7 +1051,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None): ...@@ -1029,7 +1051,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
'dropout_prob': dropout_prob, 'dropout_prob': dropout_prob,
'is_test': is_test, 'is_test': is_test,
'fix_seed': seed is not None, 'fix_seed': seed is not None,
'seed': seed if seed is not None else 0 'seed': seed if seed is not None else 0,
'dropout_implementation': dropout_implementation,
}) })
return out return out
...@@ -1969,17 +1992,17 @@ def sequence_slice(input, offset, length, name=None): ...@@ -1969,17 +1992,17 @@ def sequence_slice(input, offset, length, name=None):
""" """
**Sequence Slice Layer** **Sequence Slice Layer**
The layer crops a subsequence from given sequence with given start The layer crops a subsequence from given sequence with given start
offset and subsequence length. offset and subsequence length.
It only supports sequence data (LoDTensor with lod_level equal to 1). It only supports sequence data (LoDTensor with lod_level equal to 1).
.. code-block:: text .. code-block:: text
- Case: - Case:
Given the input Variable **input**: Given the input Variable **input**:
input.data = [[a1, a2], [b1, b2], [c1, c2], [d1, d2], [e1, e2]], input.data = [[a1, a2], [b1, b2], [c1, c2], [d1, d2], [e1, e2]],
input.lod = [[3, 2]], input.lod = [[3, 2]],
input.dims = (5, 2), input.dims = (5, 2),
...@@ -1987,16 +2010,16 @@ def sequence_slice(input, offset, length, name=None): ...@@ -1987,16 +2010,16 @@ def sequence_slice(input, offset, length, name=None):
with offset.data = [[0], [1]] and length.data = [[2], [1]], with offset.data = [[0], [1]] and length.data = [[2], [1]],
the output Variable will be the output Variable will be
out.data = [[a1, a2], [b1, b2], [e1, e2]], out.data = [[a1, a2], [b1, b2], [e1, e2]],
out.lod = [[2, 1]], out.lod = [[2, 1]],
out.dims = (3, 2). out.dims = (3, 2).
NOTE: The first dimension size of **input**, **offset** and **length** NOTE: The first dimension size of **input**, **offset** and **length**
should be equal. The **offset** should start from 0. should be equal. The **offset** should start from 0.
Args: Args:
input(Variable): The input Variable which consists of the complete input(Variable): The input Variable which consists of the complete
sequences. sequences.
offset(Variable): The offset to slice each sequence. offset(Variable): The offset to slice each sequence.
length(Variable): The length of each subsequence. length(Variable): The length of each subsequence.
...@@ -2015,7 +2038,7 @@ def sequence_slice(input, offset, length, name=None): ...@@ -2015,7 +2038,7 @@ def sequence_slice(input, offset, length, name=None):
dtype='float32', lod_level=1) dtype='float32', lod_level=1)
offset = fluid.layers.assign(input=np.array([[0, 1]]).astype("int32")) offset = fluid.layers.assign(input=np.array([[0, 1]]).astype("int32"))
length = fluid.layers.assign(input=np.array([[2, 1]]).astype("int32")) length = fluid.layers.assign(input=np.array([[2, 1]]).astype("int32"))
subseqs = fluid.layers.sequence_slice(input=seqs, offset=offset, subseqs = fluid.layers.sequence_slice(input=seqs, offset=offset,
length=length) length=length)
""" """
helper = LayerHelper("sequence_slice", **locals()) helper = LayerHelper("sequence_slice", **locals())
...@@ -2398,12 +2421,12 @@ def layer_norm(input, ...@@ -2398,12 +2421,12 @@ def layer_norm(input,
param_attr(ParamAttr|None): The parameter attribute for the learnable param_attr(ParamAttr|None): The parameter attribute for the learnable
gain :math:`g`. If :attr:`scale` is False, :attr:`param_attr` is gain :math:`g`. If :attr:`scale` is False, :attr:`param_attr` is
omitted. If :attr:`scale` is True and :attr:`param_attr` is None, omitted. If :attr:`scale` is True and :attr:`param_attr` is None,
a default :code:`ParamAttr` would be added as scale. The a default :code:`ParamAttr` would be added as scale. The
:attr:`param_attr` is initialized as 1 if it is added. Default None. :attr:`param_attr` is initialized as 1 if it is added. Default None.
bias_attr(ParamAttr|None): The parameter attribute for the learnable bias_attr(ParamAttr|None): The parameter attribute for the learnable
bias :math:`b`. If :attr:`shift` is False, :attr:`bias_attr` is bias :math:`b`. If :attr:`shift` is False, :attr:`bias_attr` is
omitted. If :attr:`shift` is True and :attr:`param_attr` is None, omitted. If :attr:`shift` is True and :attr:`param_attr` is None,
a default :code:`ParamAttr` would be added as bias. The a default :code:`ParamAttr` would be added as bias. The
:attr:`bias_attr` is initialized as 0 if it is added. Default None. :attr:`bias_attr` is initialized as 0 if it is added. Default None.
act(str): Activation to be applied to the output of layer normalizaiton. act(str): Activation to be applied to the output of layer normalizaiton.
Default None. Default None.
...@@ -3021,8 +3044,8 @@ def sequence_unpad(x, length, name=None): ...@@ -3021,8 +3044,8 @@ def sequence_unpad(x, length, name=None):
""" """
**Sequence Unpad Layer** **Sequence Unpad Layer**
This layer removes the padding data in the input sequences and convert This layer removes the padding data in the input sequences and convert
them into sequences with actual length as output, identitied by lod them into sequences with actual length as output, identitied by lod
information. information.
.. code-block:: text .. code-block:: text
...@@ -3032,9 +3055,9 @@ def sequence_unpad(x, length, name=None): ...@@ -3032,9 +3055,9 @@ def sequence_unpad(x, length, name=None):
Given input Variable **x**: Given input Variable **x**:
x.data = [[ 1.0, 2.0, 3.0, 4.0, 5.0], x.data = [[ 1.0, 2.0, 3.0, 4.0, 5.0],
[ 6.0, 7.0, 8.0, 9.0, 10.0], [ 6.0, 7.0, 8.0, 9.0, 10.0],
[11.0, 12.0, 13.0, 14.0, 15.0]], [11.0, 12.0, 13.0, 14.0, 15.0]],
in which there are 3 sequences padded to length 5, and the acutal length in which there are 3 sequences padded to length 5, and the acutal length
specified by input Variable **length**: specified by input Variable **length**:
length.data = [[2], [3], [4]], length.data = [[2], [3], [4]],
...@@ -3042,7 +3065,7 @@ def sequence_unpad(x, length, name=None): ...@@ -3042,7 +3065,7 @@ def sequence_unpad(x, length, name=None):
after unpadding, the output Variable will be: after unpadding, the output Variable will be:
out.data = [[1.0, 2.0, 6.0, 7.0, 8.0, 11.0, 12.0, 13.0, 14.0]] out.data = [[1.0, 2.0, 6.0, 7.0, 8.0, 11.0, 12.0, 13.0, 14.0]]
out.lod = [[2, 3, 4]] out.lod = [[2, 3, 4]]
Args: Args:
x(Variable): Input Variable which contains the padded sequences with x(Variable): Input Variable which contains the padded sequences with
...@@ -4844,7 +4867,7 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1): ...@@ -4844,7 +4867,7 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
return counter return counter
def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
""" """
Gives a new shape to the input Tensor without changing its data. Gives a new shape to the input Tensor without changing its data.
...@@ -4892,15 +4915,22 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): ...@@ -4892,15 +4915,22 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
:attr:`shape` specifying shape. That is to :attr:`shape` specifying shape. That is to
say :attr:`actual_shape` has a higher priority say :attr:`actual_shape` has a higher priority
than :attr:`shape`. than :attr:`shape`.
act (str): The non-linear activation to be applied to output variable. act (str): The non-linear activation to be applied to the reshaped tensor
inplace(bool): If this flag is set true, the output variable.
shares data with input without copying, otherwise inplace(bool): Must use :attr:`False` if :attr:`x` is used in multiple
a new output tensor is created operators. If this flag is set :attr:`True`, reuse input
whose data is copied from input x. :attr:`x` to reshape, which will change the shape of
tensor variable :attr:`x` and might cause errors when
:attr:`x` is used in multiple operators. If :attr:`False`,
preserve the shape :attr:`x` and create a new output tensor
variable whose data is copied from input x but reshaped.
name (str): The name of this layer. It is optional. name (str): The name of this layer. It is optional.
Returns: Returns:
Variable: The output tensor. Variable: The reshaped tensor variable if :attr:`act` is None. It is a \
new tensor variable if :attr:`inplace` is :attr:`False`, \
otherwise it is :attr:`x`. If :attr:`act` is not None, return \
the activated tensor variable.
Raises: Raises:
TypeError: if actual_shape is neither Variable nor None. TypeError: if actual_shape is neither Variable nor None.
...@@ -4911,7 +4941,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): ...@@ -4911,7 +4941,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
data = fluid.layers.data( data = fluid.layers.data(
name='data', shape=[2, 4, 6], dtype='float32') name='data', shape=[2, 4, 6], dtype='float32')
reshaped = fluid.layers.reshape( reshaped = fluid.layers.reshape(
x=data, shape=[-1, 0, 3, 2], act='tanh', inplace=True) x=data, shape=[-1, 0, 3, 2], inplace=True)
""" """
if not (isinstance(shape, list) or isinstance(shape, tuple)): if not (isinstance(shape, list) or isinstance(shape, tuple)):
...@@ -4938,7 +4968,8 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): ...@@ -4938,7 +4968,8 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
"except one unknown dimension.") "except one unknown dimension.")
helper = LayerHelper("reshape2", **locals()) helper = LayerHelper("reshape2", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = x if inplace else helper.create_variable_for_type_inference(
dtype=x.dtype)
x_shape = helper.create_variable_for_type_inference(dtype=x.dtype) x_shape = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type="reshape2", type="reshape2",
...@@ -5469,9 +5500,9 @@ def roi_align(input, ...@@ -5469,9 +5500,9 @@ def roi_align(input,
Examples: Examples:
.. code-block:: python .. code-block:: python
align_out = fluid.layers.roi_align(input=x, align_out = fluid.layers.roi_align(input=x,
rois=rois, rois=rois,
pooled_height=7, pooled_height=7,
pooled_width=7, pooled_width=7,
spatial_scale=0.5, spatial_scale=0.5,
sampling_ratio=-1) sampling_ratio=-1)
...@@ -7455,13 +7486,40 @@ def maxout(x, groups, name=None): ...@@ -7455,13 +7486,40 @@ def maxout(x, groups, name=None):
return out return out
@templatedoc()
def sequence_reverse(x, name=None):
    """
    ${comment}

    Args:
        x(${x_type}): ${x_comment}
        name(basestring|None): Name of the output.

    Returns:
        out(${y_type}): ${y_comment}
    """
    # locals() is captured before any other local is bound, so the helper
    # receives exactly the call arguments `x` and `name`.
    helper = LayerHelper("sequence_reverse", **locals())

    # An explicitly named output gets a non-persistable variable with that
    # name; otherwise an anonymous variable is created for type inference.
    reversed_out = (
        helper.create_variable_for_type_inference(dtype=x.dtype)
        if name is None else helper.create_variable(
            name=name, dtype=x.dtype, persistable=False))

    helper.append_op(
        type="sequence_reverse",
        inputs={"X": x},
        outputs={"Y": reversed_out},
        attrs={})
    return reversed_out
def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None): def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
""" """
Applies a separate affine transformation to each channel of the input. Applies a separate affine transformation to each channel of the input.
Useful for replacing spatial batch norm with its equivalent fixed Useful for replacing spatial batch norm with its equivalent fixed
transformation. The input also can be 2D tensor and applies a affine transformation. The input also can be 2D tensor and applies a affine
transformation in second dimension. transformation in second dimension.
Args: Args:
x (Variable): Feature map input can be a 4D tensor with order NCHW x (Variable): Feature map input can be a 4D tensor with order NCHW
or NHWC. It also can be a 2D tensor and the affine transformation or NHWC. It also can be a 2D tensor and the affine transformation
...@@ -7494,3 +7552,31 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None): ...@@ -7494,3 +7552,31 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
attrs={"data_layout": data_layout}, attrs={"data_layout": data_layout},
outputs={"Out": out}) outputs={"Out": out})
return out return out
def hash(input, hash_size, num_hash=1, name=None):
    """
    Hash the input.

    Args:
        input (Variable): The input variable which is a one-hot word.
        hash_size (int): The space size for hash algorithm. It is forwarded
            to the operator as the ``mod_by`` attribute.
        num_hash (int): The times of hash, default 1.
        name (str, default None): The name of this layer.

    Returns:
        Variable: The hash result variable which is a LoDTensor.

    Examples:
        .. code-block:: python

            word_dict = paddle.dataset.imdb.word_dict()
            x = fluid.layers.data(
                name='x', shape=[1], dtype='int32', lod_level=1)
            out = fluid.layers.hash(input=x, hash_size=len(word_dict))
    """
    helper = LayerHelper('hash', **locals())
    # The output is created with stop_gradient=True, so no gradient is
    # propagated through the hash result.
    out = helper.create_variable_for_type_inference(
        helper.input_dtype(), stop_gradient=True)
    helper.append_op(
        type='hash',
        inputs={'X': input},
        outputs={'Out': out},
        attrs={'num_hash': num_hash,
               'mod_by': hash_size})
    return out
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
# limitations under the License. # limitations under the License.
""" """
Fluid Metrics Fluid Metrics
The metrics are accomplished via Python natively.
""" """
from __future__ import print_function from __future__ import print_function
...@@ -24,6 +22,12 @@ import copy ...@@ -24,6 +22,12 @@ import copy
import warnings import warnings
import six import six
from .layer_helper import LayerHelper
from .initializer import Constant
from . import unique_name
from .framework import Program, Variable, program_guard
from . import layers
__all__ = [ __all__ = [
'MetricBase', 'MetricBase',
'CompositeMetric', 'CompositeMetric',
...@@ -474,71 +478,10 @@ class EditDistance(MetricBase): ...@@ -474,71 +478,10 @@ class EditDistance(MetricBase):
"There is no data in EditDistance Metric. Please check layers.edit_distance output has been added to EditDistance." "There is no data in EditDistance Metric. Please check layers.edit_distance output has been added to EditDistance."
) )
avg_distance = self.total_distance / self.seq_num avg_distance = self.total_distance / self.seq_num
avg_instance_error = self.instance_error / self.seq_num avg_instance_error = self.instance_error / float(self.seq_num)
return avg_distance, avg_instance_error return avg_distance, avg_instance_error
class DetectionMAP(MetricBase):
"""
Calculate the detection mean average precision (mAP).
mAP is the metric to measure the accuracy of object detectors
like Faster R-CNN, SSD, etc.
It is the average of the maximum precisions at different recall values.
Please get more information from the following articles:
https://sanchom.wordpress.com/tag/average-precision/
https://arxiv.org/abs/1512.02325
The general steps are as follows:
1. calculate the true positive and false positive according to the input
of detection and labels.
2. calculate mAP value, support two versions: '11 point' and 'integral'.
Examples:
.. code-block:: python
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
batch_map = layers.detection_map(
input,
label,
class_num,
background_label,
overlap_threshold=overlap_threshold,
evaluate_difficult=evaluate_difficult,
ap_version=ap_version)
metric = fluid.metrics.DetectionMAP()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, batch_map])
batch_size = data[0]
metric.update(value=batch_map, weight=batch_size)
numpy_map = metric.eval()
"""
def __init__(self, name=None):
super(DetectionMAP, self).__init__(name)
# the current map value
self.value = .0
self.weight = .0
def update(self, value, weight):
if not _is_number_or_matrix_(value):
raise ValueError(
"The 'value' must be a number(int, float) or a numpy ndarray.")
if not _is_number_(weight):
raise ValueError("The 'weight' must be a number(int, float).")
self.value += value
self.weight += weight
def eval(self):
if self.weight == 0:
raise ValueError(
"There is no data in DetectionMAP Metrics. "
"Please check layers.detection_map output has added to DetectionMAP."
)
return self.value / self.weight
class Auc(MetricBase): class Auc(MetricBase):
""" """
Auc metric adapts to the binary classification. Auc metric adapts to the binary classification.
...@@ -616,3 +559,179 @@ class Auc(MetricBase): ...@@ -616,3 +559,179 @@ class Auc(MetricBase):
idx -= 1 idx -= 1
return auc / tot_pos / tot_neg if tot_pos > 0.0 and tot_neg > 0.0 else 0.0 return auc / tot_pos / tot_neg if tot_pos > 0.0 and tot_neg > 0.0 else 0.0
class DetectionMAP(object):
    """
    Calculate the detection mean average precision (mAP).

    The general steps are as follows:

    1. calculate the true positive and false positive according to the input
       of detection and labels.
    2. calculate mAP value, support two versions: '11 point' and 'integral'.

    Please get more information from the following articles:

      https://sanchom.wordpress.com/tag/average-precision/

      https://arxiv.org/abs/1512.02325

    Args:
        input (Variable): The detection results, which is a LoDTensor with shape
            [M, 6]. The layout is [label, confidence, xmin, ymin, xmax, ymax].
        gt_label (Variable): The ground truth label index, which is a LoDTensor
            with shape [N, 1].
        gt_box (Variable): The ground truth bounding box (bbox), which is a
            LoDTensor with shape [N, 4]. The layout is [xmin, ymin, xmax, ymax].
        gt_difficult (Variable|None): Whether this ground truth is a difficult
            bounding bbox, which can be a LoDTensor [N, 1] or not set. If None,
            it means all the ground truth labels are not difficult bbox.
        class_num (int): The class number.
        background_label (int): The index of background label, the background
            label will be ignored. If set to -1, then all categories will be
            considered, 0 by default.
        overlap_threshold (float): The threshold for deciding true/false
            positive, 0.5 by default.
        evaluate_difficult (bool): Whether to consider difficult ground truth
            for evaluation, True by default. This argument does not work when
            gt_difficult is None.
        ap_version (string): The average precision calculation ways, it must be
            'integral' or '11point'. Please check
            https://sanchom.wordpress.com/tag/average-precision/ for details.

            - 11point: the 11-point interpolated average precision.
            - integral: the natural integral of the precision-recall curve.

    Examples:
        .. code-block:: python

            exe = fluid.Executor(place)
            map_evaluator = fluid.metrics.DetectionMAP(input,
                gt_label, gt_box, gt_difficult)
            cur_map, accum_map = map_evaluator.get_map_var()
            fetch = [cost, cur_map, accum_map]
            for epoch in range(PASS_NUM):
                map_evaluator.reset(exe)
                for data in batches:
                    loss, cur_map_v, accum_map_v = exe.run(fetch_list=fetch)

        In the above example:

            'cur_map_v' is the mAP of current mini-batch.
            'accum_map_v' is the accumulative mAP of one pass.
    """

    def __init__(self,
                 input,
                 gt_label,
                 gt_box,
                 gt_difficult=None,
                 class_num=None,
                 background_label=0,
                 overlap_threshold=0.5,
                 evaluate_difficult=True,
                 ap_version='integral'):
        self.helper = LayerHelper('map_eval')

        # Labels are cast to the box dtype so they can be concatenated with
        # the box coordinates into a single label tensor.
        gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype)
        # Compare against None explicitly: the argument is documented as
        # Variable|None, and the truth value of a Variable is not a reliable
        # "was it provided" signal.
        if gt_difficult is not None:
            gt_difficult = layers.cast(x=gt_difficult, dtype=gt_box.dtype)
            label = layers.concat([gt_label, gt_difficult, gt_box], axis=1)
        else:
            label = layers.concat([gt_label, gt_box], axis=1)

        # calculate mean average precision (mAP) of current mini-batch
        cur_map = layers.detection_map(
            input,
            label,
            class_num,
            background_label,
            overlap_threshold=overlap_threshold,
            evaluate_difficult=evaluate_difficult,
            ap_version=ap_version)

        # Persistable accumulators handed to the accumulative detection_map
        # op below as both its input and its output states.
        states = []
        states.append(
            self._create_state(
                dtype='int32', shape=None, suffix='accum_pos_count'))
        states.append(
            self._create_state(
                dtype='float32', shape=None, suffix='accum_true_pos'))
        states.append(
            self._create_state(
                dtype='float32', shape=None, suffix='accum_false_pos'))

        # Flag variable initialised to 0; reset() writes 0 back into it.
        var = self._create_state(dtype='int32', shape=[1], suffix='has_state')
        self.helper.set_variable_initializer(
            var, initializer=Constant(value=int(0)))
        self.has_state = var

        # calculate accumulative mAP
        accum_map = layers.detection_map(
            input,
            label,
            class_num,
            background_label,
            overlap_threshold=overlap_threshold,
            evaluate_difficult=evaluate_difficult,
            has_state=self.has_state,
            input_states=states,
            out_states=states,
            ap_version=ap_version)

        # Appended after the accumulative op: sets the flag to 1 once a batch
        # has been processed.
        layers.fill_constant(
            shape=self.has_state.shape,
            value=1,
            dtype=self.has_state.dtype,
            out=self.has_state)

        self.cur_map = cur_map
        self.accum_map = accum_map

    def _create_state(self, suffix, dtype, shape):
        """
        Create state variable.

        Args:
            suffix(str): the state suffix.
            dtype(str|core.VarDesc.VarType): the state data type
            shape(tuple|list): the shape of state

        Returns: State variable
        """
        state = self.helper.create_variable(
            name="_".join([unique_name.generate(self.helper.name), suffix]),
            persistable=True,
            dtype=dtype,
            shape=shape)
        return state

    def get_map_var(self):
        """
        Returns: mAP variable of current mini-batch and
            accumulative mAP variable cross mini-batches.
        """
        return self.cur_map, self.accum_map

    def reset(self, executor, reset_program=None):
        """
        Reset metric states at the begin of each pass/user specified batch.

        Args:
            executor(Executor): a executor for executing
                the reset_program.
            reset_program(Program|None): a single Program for reset process.
                If None, will create a Program.
        """

        def _clone_var_(block, var):
            # Re-declare `var` inside `block` under the same name and
            # metadata so the reset program can write to it.
            assert isinstance(var, Variable)
            return block.create_var(
                name=var.name,
                shape=var.shape,
                dtype=var.dtype,
                type=var.type,
                lod_level=var.lod_level,
                persistable=var.persistable)

        if reset_program is None:
            reset_program = Program()
        with program_guard(main_program=reset_program):
            var = _clone_var_(reset_program.current_block(), self.has_state)
            layers.fill_constant(
                shape=var.shape, value=0, dtype=var.dtype, out=var)
        executor.run(reset_program)
...@@ -1159,6 +1159,7 @@ def prepare_encoder(src_word, ...@@ -1159,6 +1159,7 @@ def prepare_encoder(src_word,
name=pos_enc_param_name, name=pos_enc_param_name,
trainable=False, trainable=False,
initializer=fluid.initializer.ConstantInitializer(0.001))) initializer=fluid.initializer.ConstantInitializer(0.001)))
src_pos_enc.stop_gradient = True
enc_input = src_word_emb + src_pos_enc enc_input = src_word_emb + src_pos_enc
return layers.dropout( return layers.dropout(
enc_input, enc_input,
......
...@@ -23,9 +23,8 @@ class TestDistCTR2x2(TestDistBase): ...@@ -23,9 +23,8 @@ class TestDistCTR2x2(TestDistBase):
self._sync_mode = True self._sync_mode = True
self._enforce_place = "CPU" self._enforce_place = "CPU"
def test_dist_ctr(self):
def test_dist_ctr(self): self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False)
self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -40,8 +40,7 @@ class TestDistMnistAsync(TestDistBase): ...@@ -40,8 +40,7 @@ class TestDistMnistAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._use_reduce = False self._use_reduce = False
# FIXME(typhoonzero): fix async mode test later def test_dist_train(self):
def no_test_dist_train(self):
self.check_with_place("dist_mnist.py", delta=200) self.check_with_place("dist_mnist.py", delta=200)
......
...@@ -40,8 +40,7 @@ class TestDistSeResneXt2x2Async(TestDistBase): ...@@ -40,8 +40,7 @@ class TestDistSeResneXt2x2Async(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._use_reader_alloc = False self._use_reader_alloc = False
#FIXME(typhoonzero): fix async mode later def test_dist_train(self):
def no_test_dist_train(self):
self.check_with_place("dist_se_resnext.py", delta=100) self.check_with_place("dist_se_resnext.py", delta=100)
......
...@@ -79,8 +79,7 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase): ...@@ -79,8 +79,7 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._enforce_place = "CPU" self._enforce_place = "CPU"
#FIXME(typhoonzero): fix async tests later def test_simnet_bow(self):
def no_test_simnet_bow(self):
need_envs = { need_envs = {
"IS_DISTRIBUTED": '0', "IS_DISTRIBUTED": '0',
"IS_SPARSE": '1', "IS_SPARSE": '1',
......
...@@ -85,6 +85,69 @@ class TestDropoutOp5(OpTest): ...@@ -85,6 +85,69 @@ class TestDropoutOp5(OpTest):
self.check_output() self.check_output()
class TestDropoutOp6(TestDropoutOp):
    """Training-mode 'upscale_in_train' dropout with dropout_prob=1.0.

    The expected output and mask are both all zeros.
    """

    def setUp(self):
        dims = (32, 64)
        self.op_type = "dropout"
        self.inputs = {'X': np.random.random(dims).astype("float32")}
        self.attrs = dict([
            ('dropout_prob', 1.0),
            ('fix_seed', True),
            ('is_test', False),
            ('dropout_implementation', 'upscale_in_train'),
        ])
        # Out and Mask are two independent zero arrays, not one shared array.
        expected_out = np.zeros(dims).astype('float32')
        expected_mask = np.zeros(dims).astype('float32')
        self.outputs = {'Out': expected_out, 'Mask': expected_mask}
class TestDropoutOp7(TestDropoutOp):
    """Training-mode 'upscale_in_train' dropout with dropout_prob=0.0.

    The expected output is the input itself and the mask is all ones.
    """

    def setUp(self):
        dims = (32, 64, 2)
        x = np.random.random(dims).astype("float32")
        self.op_type = "dropout"
        self.inputs = {'X': x}
        self.attrs = dict([
            ('dropout_prob', 0.0),
            ('fix_seed', True),
            ('is_test', False),
            ('dropout_implementation', 'upscale_in_train'),
        ])
        # 'Out' refers to the very same array object as the input.
        self.outputs = {'Out': x, 'Mask': np.ones(dims).astype('float32')}
class TestDropoutOp8(OpTest):
    """Test-mode 'upscale_in_train' dropout (prob 0.35, fixed seed).

    The expected output is the unmodified input.
    """

    def setUp(self):
        x = np.random.random((32, 64)).astype("float32")
        self.op_type = "dropout"
        self.inputs = {'X': x}
        self.attrs = dict([
            ('dropout_prob', 0.35),
            ('fix_seed', True),
            ('is_test', True),
            ('dropout_implementation', 'upscale_in_train'),
        ])
        # 'Out' refers to the very same array object as the input.
        self.outputs = {'Out': x}

    def test_check_output(self):
        self.check_output()
class TestDropoutOp9(OpTest):
    """Test-mode 'upscale_in_train' dropout on a 3-D input (prob 0.75).

    No 'fix_seed' attribute is set; the expected output is the unmodified
    input.
    """

    def setUp(self):
        x = np.random.random((32, 64, 3)).astype("float32")
        self.op_type = "dropout"
        self.inputs = {'X': x}
        self.attrs = dict([
            ('dropout_prob', 0.75),
            ('is_test', True),
            ('dropout_implementation', 'upscale_in_train'),
        ])
        # 'Out' refers to the very same array object as the input.
        self.outputs = {'Out': x}

    def test_check_output(self):
        self.check_output()
class TestFP16DropoutOp(OpTest): class TestFP16DropoutOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "dropout" self.op_type = "dropout"
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestScaleOp(OpTest):
    """Checks the ``hash`` operator against a table of precomputed outputs.

    NOTE(review): the class name looks copied from a scale-op test. It is
    kept unchanged so that test discovery and any references to it still work.
    """

    def setUp(self):
        self.op_type = "hash"
        self.init_test_case()
        self.inputs = {'X': (self.in_seq, self.lod)}
        self.attrs = {'num_hash': 4, 'mod_by': 10000}
        self.outputs = {'Out': (self.out_seq, self.lod)}

    def init_test_case(self):
        # The previous `np.random.seed = 1` was an assignment, not a call: it
        # seeded nothing and replaced NumPy's `seed` function with an int for
        # the rest of the process. It is removed rather than turned into
        # `np.random.seed(1)` because the expected hashes below are tied to
        # the values drawn under the RNG state that is active here
        # (presumably seeded by the OpTest base class -- TODO confirm);
        # re-seeding would change `in_seq` and invalidate the table.
        self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
        self.lod = [[9, 4, 11, 6]]
        # One row per input element, one entry per hash (num_hash == 4).
        self.out_seq = [
            [[9662], [9217], [1129], [8487]], [[9662], [9217], [1129], [8487]],
            [[8310], [1327], [1654], [4567]], [[6897], [3218], [2013], [1241]],
            [[9407], [6715], [6949], [8094]], [[8473], [694], [5142], [2479]],
            [[8310], [1327], [1654], [4567]], [[6897], [3218], [2013], [1241]],
            [[4372], [9456], [8204], [6695]], [[6897], [3218], [2013], [1241]],
            [[8473], [694], [5142], [2479]], [[4372], [9456], [8204], [6695]],
            [[4372], [9456], [8204], [6695]], [[8473], [694], [5142], [2479]],
            [[9407], [6715], [6949], [8094]], [[9369], [4525], [8935], [9210]],
            [[4372], [9456], [8204], [6695]], [[4372], [9456], [8204], [6695]],
            [[9369], [4525], [8935], [9210]], [[6897], [3218], [2013], [1241]],
            [[9038], [7951], [5953], [8657]], [[9407], [6715], [6949], [8094]],
            [[9662], [9217], [1129], [8487]], [[9369], [4525], [8935], [9210]],
            [[9038], [7951], [5953], [8657]], [[9662], [9217], [1129], [8487]],
            [[9369], [4525], [8935], [9210]], [[1719], [5986], [9919], [3421]],
            [[4372], [9456], [8204], [6695]], [[9038], [7951], [5953], [8657]]
        ]
        self.out_seq = np.array(self.out_seq)

    def test_check_output(self):
        self.check_output()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
from paddle.fluid.framework import Program, program_guard
class TestMetricsDetectionMap(unittest.TestCase):
    """Builds a DetectionMAP metric in a fresh program and checks that both
    mAP variables are returned."""

    def test_detection_map(self):
        def declare_input(var_name, width):
            # Every input is a float32 [10, width] tensor declared without an
            # implicit batch dimension.
            return fluid.layers.data(
                name=var_name,
                shape=[10, width],
                append_batch_size=False,
                dtype='float32')

        program = fluid.Program()
        with program_guard(program):
            detect_res = declare_input('detect_res', 6)
            label = declare_input('label', 1)
            box = declare_input('bbox', 4)

            map_eval = fluid.metrics.DetectionMAP(
                detect_res, label, box, class_num=21)
            map_vars = map_eval.get_map_var()
            cur_map, accm_map = map_vars
            for map_var in (cur_map, accm_map):
                self.assertIsNotNone(map_var)
        print(str(program))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
from op_test import OpTest
import numpy as np
class TestSequenceReverseBase(OpTest):
    """Base case for the ``sequence_reverse`` operator.

    Subclasses override ``initParameters`` to change ``size``, ``lod`` or
    ``dtype`` before the input and the reference output are generated.
    """

    def initParameters(self):
        """Hook for subclasses; the base case keeps the defaults."""
        pass

    def setUp(self):
        # Defaults first, then let the subclass hook override them.
        self.size = (10, 3, 4)
        self.lod = [2, 3, 5]
        self.dtype = 'float32'
        self.initParameters()

        self.op_type = 'sequence_reverse'
        self.x = np.random.random(self.size).astype(self.dtype)
        self.y = self.get_output()
        self.inputs = {'X': (self.x, [self.lod])}
        self.outputs = {'Y': (self.y, [self.lod])}

    def get_output(self):
        """Reference result: reverse the rows of ``x`` within each sequence
        whose length is listed in ``lod``."""
        rows = self.x.reshape(self.x.shape[0], -1)
        flipped = np.empty(rows.shape).astype(self.dtype)
        start = 0
        for seq_len in self.lod:
            stop = start + seq_len
            # Reversing the row order of the segment is a flip along axis 0.
            flipped[start:stop, :] = rows[start:stop, :][::-1]
            start = stop
        return flipped.reshape(self.x.shape).astype(self.dtype)

    def test_output(self):
        self.check_output(0)

    def test_grad(self):
        self.check_grad(['X'], 'Y')
class TestSequenceReserve1(TestSequenceReverseBase):
    """2-D input of shape (12, 10) split into sequences of length 4, 5, 3."""

    def initParameters(self):
        self.size, self.lod = (12, 10), [4, 5, 3]
class TestSequenceReverse2(TestSequenceReverseBase):
    """2-D input of shape (12, 10) treated as one sequence of length 12."""

    def initParameters(self):
        self.size, self.lod = (12, 10), [12]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册