diff --git a/.gitignore b/.gitignore index 90138f996cf9cacc3c1cbff0cf2600eefca3f305..fa0c8882606b76ac71b43dcde7e1df6770c46c31 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ third_party/ build_* # clion workspace. cmake-build-* +model_test diff --git a/CMakeLists.txt b/CMakeLists.txt index 6aa2e1715b92d73aa4e5e97d5e52ffac51451d80..d3379a663db4613e529cdba4ce22111765ff59cc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,6 +69,7 @@ option(WITH_ANAKIN "Compile with Anakin library" OFF) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF) option(WITH_INFERENCE "Compile fluid inference library" ON) +option(ON_INFER "Turn on inference optimization." OFF) option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF) option(WITH_SYSTEM_BLAS "Use system blas library" OFF) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) @@ -179,6 +180,7 @@ include(external/eigen) # download eigen3 include(external/pybind11) # download pybind11 include(external/cares) include(external/cub) +include(external/xxhash) # download xxhash if (NOT WIN32) # there is no official support of snappystream, warpctc, nccl, cupti in windows @@ -301,3 +303,8 @@ if(WITH_DOC) find_python_module(recommonmark REQUIRED) add_subdirectory(doc) endif() + +if (ON_INFER) + message(WARNING "On inference mode, will take place some specific optimization.") + add_definitions(-DPADDLE_ON_INFERENCE) +endif() diff --git a/Dockerfile b/Dockerfile index 738bba9bc2e1ab19709722fe04f1490b1b13bd4f..c8b9eed6d60e5d3b32fc14c0c7af80a785145d1b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -75,14 +75,14 @@ RUN pip3 install -U wheel && \ pip3 install -U docopt PyYAML sphinx==1.5.6 && \ pip3 install sphinx-rtd-theme==0.1.9 recommonmark && \ easy_install -U pip && \ - pip install -U wheel && \ + pip install -U pip setuptools wheel && \ pip install -U docopt PyYAML sphinx==1.5.6 && \ pip install sphinx-rtd-theme==0.1.9 recommonmark -RUN pip3 install pre-commit 'ipython==5.3.0' && \ +RUN pip3 install 'pre-commit==1.10.4' 'ipython==5.3.0' && \ pip3 install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ pip3 install opencv-python && \ - pip install pre-commit 'ipython==5.3.0' && \ + pip install 'pre-commit==1.10.4' 'ipython==5.3.0' && \ pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ pip install opencv-python diff --git a/cmake/external/xxhash.cmake b/cmake/external/xxhash.cmake new file mode 100644 index 0000000000000000000000000000000000000000..4deaab7545c20002fedcad1cca6df54fe9783eb0 --- /dev/null +++ b/cmake/external/xxhash.cmake @@ -0,0 +1,46 @@ +INCLUDE(ExternalProject) + +set(XXHASH_SOURCE_DIR ${THIRD_PARTY_PATH}/xxhash) +set(XXHASH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/xxhash) +set(XXHASH_INCLUDE_DIR "${XXHASH_INSTALL_DIR}/include") + +IF(WITH_STATIC_LIB) + SET(BUILD_CMD make lib) +ELSE() + SET(BUILD_CMD sed -i "s/-Wstrict-prototypes -Wundef/-Wstrict-prototypes -Wundef -fPIC/g" ${XXHASH_SOURCE_DIR}/src/extern_xxhash/Makefile && make lib) +ENDIF() + +ExternalProject_Add( + extern_xxhash + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY "https://github.com/Cyan4973/xxHash" + GIT_TAG "v0.6.5" + PREFIX ${XXHASH_SOURCE_DIR} + DOWNLOAD_NAME "xxhash" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_IN_SOURCE 1 + PATCH_COMMAND + BUILD_COMMAND ${BUILD_CMD} + INSTALL_COMMAND export PREFIX=${XXHASH_INSTALL_DIR}/ && make install + TEST_COMMAND "" +) + +set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.a") +INCLUDE_DIRECTORIES(${XXHASH_INCLUDE_DIR}) + +add_library(xxhash STATIC IMPORTED GLOBAL) +set_property(TARGET xxhash PROPERTY IMPORTED_LOCATION ${XXHASH_LIBRARIES}) +include_directories(${XXHASH_INCLUDE_DIR}) +add_dependencies(xxhash extern_xxhash) + +LIST(APPEND external_project_dependencies xxhash) + +IF(WITH_C_API) + INSTALL(DIRECTORY ${XXHASH_INCLUDE_DIR} DESTINATION third_party/xxhash) + IF(ANDROID) + INSTALL(FILES ${XXHASH_LIBRARIES} DESTINATION third_party/xxhash/lib/${ANDROID_ABI}) + ELSE() + INSTALL(FILES ${XXHASH_LIBRARIES} DESTINATION third_party/xxhash/lib) + ENDIF() +ENDIF() diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 67cca09b64c1ed7a503a886e78347d786eae0de7..1047b6f998a74e42114b9deab4f0e7ba1af36835 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -14,6 +14,9 @@ # make package for paddle fluid shared and static library function(copy TARGET) + if (NOT ON_INFER) + message(WARNING "Turn on the ON_INFER flag when building inference_lib only.") + endif() set(options "") set(oneValueArgs "") set(multiValueArgs SRCS DSTS DEPS) @@ -31,7 +34,7 @@ function(copy TARGET) foreach(index RANGE ${len}) list(GET copy_lib_SRCS ${index} src) list(GET copy_lib_DSTS ${index} dst) - add_custom_command(TARGET ${TARGET} PRE_BUILD + add_custom_command(TARGET ${TARGET} PRE_BUILD COMMAND mkdir -p "${dst}" COMMAND cp -r "${src}" "${dst}" COMMENT "copying ${src} -> ${dst}") @@ -67,6 +70,13 @@ copy(boost_lib DEPS boost ) +set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/xxhash") +copy(xxhash_lib + SRCS ${XXHASH_INCLUDE_DIR} ${XXHASH_LIBRARIES} + DSTS ${dst_dir} ${dst_dir}/lib + DEPS xxhash +) + if(NOT PROTOBUF_FOUND) set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/protobuf") copy(protobuf_lib @@ -186,7 +196,7 @@ copy(cmake_cache DSTS ${FLUID_INSTALL_DIR}) # This command generates a complete fluid library for both train and inference -add_custom_target(fluid_lib_dist DEPENDS ${fluid_lib_dist_dep}) +add_custom_target(fluid_lib_dist DEPENDS ${fluid_lib_dist_dep}) # Following commands generate a inference-only fluid library # third_party, version.txt and CMakeCache.txt are the same position with ${FLUID_INSTALL_DIR} diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 19ef23cdfa90912ff6fbd050a685d10861d909d2..0d90bf3cc12e285f5cafd80180c12ddeb4ad8b51 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name' paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'], varargs=None, keywords=None, defaults=(False, None, None)) +paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name', 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer')) paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None)) paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None)) @@ -107,7 +107,7 @@ paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1)) -paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, True, None)) +paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, False, None)) paddle.fluid.layers.squeeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.unsqueeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.lod_reset ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None)) @@ -174,7 +174,9 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)) paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) +paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index c54766d95a61ac1a4b61566c6de62cbc86685a1d..01e878089171e4620f32b57a65d92d1c86d307db 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -120,19 +120,25 @@ size_t GraphNum(const Graph &graph) { std::deque q_nodes; std::vector> graph_nodes; std::unordered_set g_nodes; + // q_set used to record records in the queue. + std::unordered_set q_set; size_t graph_count = 0; - auto traverse_nodes = [&visited_nodes, - &q_nodes](const std::vector &nodes) { - std::copy_if( - nodes.begin(), nodes.end(), std::back_inserter(q_nodes), - [&visited_nodes](Node *node) { return !visited_nodes.count(node); }); + auto traverse_nodes = [&visited_nodes, &q_nodes, + &q_set](const std::vector &nodes) { + for (auto n : nodes) { + if (visited_nodes.count(n) == 0 && q_set.count(n) == 0) { + q_nodes.push_back(n); + q_set.insert(n); + } + } }; while (visited_nodes.size() != nodes.size()) { if (!q_nodes.empty()) { auto cur_node = q_nodes.front(); q_nodes.pop_front(); + q_set.erase(cur_node); visited_nodes.insert(cur_node); g_nodes.insert(cur_node); traverse_nodes(cur_node->inputs); @@ -146,6 +152,7 @@ size_t GraphNum(const Graph &graph) { for (auto &n : nodes) { if (visited_nodes.count(n) == 0) { q_nodes.push_back(n); + q_set.insert(n); break; } } diff --git a/paddle/fluid/framework/lod_tensor_array.h b/paddle/fluid/framework/lod_tensor_array.h index 6d7b6a4ada8729e3698dab5d2b1861aac632be79..0ad6a709008406257d6c0a220bce38bb24e188cd 100644 --- a/paddle/fluid/framework/lod_tensor_array.h +++ b/paddle/fluid/framework/lod_tensor_array.h @@ -18,6 +18,82 @@ limitations under the License. */ namespace paddle { namespace framework { + +// NOTE The vector can't be replaced with the class LoDTensorArray +// directly, because there are many vector used accross the project, +// and some of them are treated as LoDTensorArray. +#if !defined(PADDLE_ON_INFERENCE) + using LoDTensorArray = std::vector; -} + +#else // !PADDLE_ON_INFERENCE + +#pragma message "LoDTensorArray is replaced with the inference one." +/* + * A LoDTensorArray which will not deallocate buffer when resized, fix the data + * diff in inference, and more performance friendly in the concurrency + * scenerios. + */ +class LoDTensorArray { + public: + LoDTensorArray() = default; + + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; + + const_iterator begin() const { return array_.begin(); } + const_iterator end() const { return array_.begin() + size_; } + iterator begin() { return array_.begin(); } + iterator end() { return array_.begin() + size_; } + + void push_back(const LoDTensor& x) { + if (size_ < array_.size()) { + array_[size_++] = x; + } else { + array_.push_back(x); + ++size_; + } + } + void resize(size_t size) { + if (array_.size() < size) { + array_.resize(size); + } + size_ = size; + } + + void emplace_back() { array_.emplace_back(); } + + void emplace_back(LoDTensor&& x) { array_.emplace_back(std::move(x)); } + + LoDTensor& back() { return array_.back(); } + + size_t space() const { return array_.size(); } + + void reserve(size_t size) { + // Naive warning to tell user this array might be to large. The memory and + // buffer used by this TensorArray will not be deleted during the training + // and inference phase, so attention not to make it expand too long. + if (size > 800UL) { + LOG(WARNING) << "TensorArray has more than 800 items"; + } + array_.reserve(size); + } + + bool empty() const { return size_ == 0UL; } + void clear() { size_ = 0UL; } + + LoDTensor& operator[](size_t id) { return array_[id]; } + const LoDTensor& operator[](size_t id) const { return array_[id]; } + LoDTensor& at(size_t id) { return array_.at(id); } + const LoDTensor& at(size_t id) const { return array_.at(id); } + + size_t size() const { return size_; } + + private: + size_t size_{0}; + std::vector array_; +}; +#endif // !PADDLE_ON_INFERENCE + +} // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/mixed_vector.h b/paddle/fluid/framework/mixed_vector.h index 77386f4f069489b6ff7b927a281bdc286ff816e0..e1aac6dc5a92fb616f00de5806f044b83c2f503f 100644 --- a/paddle/fluid/framework/mixed_vector.h +++ b/paddle/fluid/framework/mixed_vector.h @@ -542,6 +542,33 @@ class CPUVector : public std::vector> { this->reserve(this->size() + size_t(end - begin)); this->insert(this->end(), begin, end); } + + const T *CUDAData(platform::Place place) const { + PADDLE_THROW( + "Vector::CUDAData() method is not supported in CPU-only version"); + } + + T *CUDAMutableData(platform::Place place) { + PADDLE_THROW( + "Vector::CUDAMutableData() method is not supported in CPU-only " + "version"); + } + + const T *Data(platform::Place place) const { + PADDLE_ENFORCE( + platform::is_cpu_place(place), + "Vector::Data() method is not supported when not in CPUPlace"); + return this->data(); + } + + T *MutableData(platform::Place place) { + PADDLE_ENFORCE( + platform::is_cpu_place(place), + "Vector::MutableData() method is not supported when not in CPUPlace"); + return this->data(); + } + + const void *Handle() const { return static_cast(this); } }; template diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc index 2840d503f1454271afb309efdd435225ab077dc0..7fb42feb95b4d54aec693228721c052f683f4d80 100644 --- a/paddle/fluid/framework/naive_executor.cc +++ b/paddle/fluid/framework/naive_executor.cc @@ -146,22 +146,5 @@ void NaiveExecutor::CleanFeedFetchOps() { ops_.swap(ops); } -void NaiveExecutor::EnableMKLDNN(const ProgramDesc &program) { -#ifdef PADDLE_WITH_MKLDNN - VLOG(3) << "use_mkldnn=True"; - for (size_t block_id = 0; block_id < program.Size(); ++block_id) { - auto *block = const_cast(program).MutableBlock(block_id); - for (auto *op : block->AllOps()) { - if (op->HasAttr("use_mkldnn")) { - op->SetAttr("use_mkldnn", true); - } - } - } -#else - LOG(WARNING) - << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option"; -#endif -} - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/naive_executor.h b/paddle/fluid/framework/naive_executor.h index 9374f3f4a35cc0f90e5b2d6e8b397784b8eae123..ddfa6e1f4d8b73f594fc381ab505797491cdd378 100644 --- a/paddle/fluid/framework/naive_executor.h +++ b/paddle/fluid/framework/naive_executor.h @@ -48,8 +48,6 @@ class NaiveExecutor { void CleanFeedFetchOps(); - void EnableMKLDNN(const ProgramDesc& program); - protected: void CreateVariables(const ProgramDesc& desc, Scope* scope, int block_id); diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h index 5527783faad6f7ea6e0b9e776cbde3c743277f16..678c14a44b5c259ddc6f914b3fbb5b50e649993c 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -28,12 +28,12 @@ enum class OpRole { kBackward = 0x0001, kOptimize = 0x0002, // RPC role is for send/recv releated op - kRPC = 0x0003, + kRPC = 0x0004, // Dist role is for split_byref/split_selected_rows/concat // used for distributed training. - kDist = 0x0004, + kDist = 0x0008, // Tag all learning rate scheduler operators. - kLRSched = 0x0005, + kLRSched = 0x0016, kLoss = 0x0100, // The default value of op's role. This should be only used for unittests and diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 7dad872dd04754d2866c361b74ab1236ff743da5..3368ae2ee4cf65b85abf1fcd89dee14f43522e1f 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -156,6 +156,12 @@ ParallelExecutor::ParallelExecutor( params, member_->local_scopes_, member_->use_cuda_); #endif + // If the loss_var_name is given, the number of graph should be only one. + if (loss_var_name.size()) { + PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1, + "The number of graph should be only one"); + } + if (exec_strategy.type_ == ExecutionStrategy::kDefault) { member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, std::move(graph))); diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index 14f9f36812d690fc4a7440f2e7e6a85e9993a535..9462620e829ec815e1553f6378a67463ea3b8aa3 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -78,6 +78,8 @@ class Scope { /// Drop all kids scopes belonged to this scope. void DropKids(); + std::list& kids() const { return kids_; } + /// Find if a scope exists in the kid scopes bool HasKid(const Scope* scope) const; diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index 18cdca3a658a6a89e6ab637a7f38825756acfea8..a588cb417aebe94bd4aeda02b1bc8ba07a04b960 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -25,7 +25,6 @@ DEFINE_int32(dist_threadpool_size, 0, namespace paddle { namespace framework { - std::unique_ptr ThreadPool::threadpool_(nullptr); std::once_flag ThreadPool::init_flag_; @@ -47,8 +46,7 @@ void ThreadPool::Init() { } } -ThreadPool::ThreadPool(int num_threads) - : total_threads_(num_threads), idle_threads_(num_threads), running_(true) { +ThreadPool::ThreadPool(int num_threads) : running_(true) { threads_.resize(num_threads); for (auto& thread : threads_) { // TODO(Yancey1989): binding the thread on the specify CPU number @@ -59,6 +57,7 @@ ThreadPool::ThreadPool(int num_threads) ThreadPool::~ThreadPool() { { // notify all threads to stop running + std::lock_guard l(mutex_); running_ = false; scheduled_.notify_all(); } @@ -69,36 +68,24 @@ ThreadPool::~ThreadPool() { } } -void ThreadPool::Wait() { - std::unique_lock lock(mutex_); - completed_.wait(lock, [=] { return Done() == true; }); -} - void ThreadPool::TaskLoop() { - while (running_) { + while (true) { std::unique_lock lock(mutex_); - scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; }); - if (!running_) { - break; + scheduled_.wait( + lock, [this] { return !this->tasks_.empty() || !this->running_; }); + + if (!running_ || tasks_.empty()) { + return; } + // pop a task from the task queue auto task = std::move(tasks_.front()); tasks_.pop(); - - --idle_threads_; lock.unlock(); // run the task task(); - - { - std::unique_lock lock(mutex_); - ++idle_threads_; - if (Done()) { - completed_.notify_all(); - } - } } } diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index 94111ee335b1a5df327b3e46d62069b4735c54f6..0687e628aaa4fb7b2e67938fa09a319c8bb35aff 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -57,15 +57,6 @@ class ThreadPool { ~ThreadPool(); - // Returns the number of threads created by the constructor. - size_t Threads() const { return total_threads_; } - - // Returns the number of currently idle threads. - size_t IdleThreads() { - std::unique_lock lock(mutex_); - return idle_threads_; - } - // Run pushes a function to the task queue and returns a std::future // object. To wait for the completion of the task, call // std::future::wait(). @@ -94,25 +85,13 @@ class ThreadPool { }); std::future> f = task.get_future(); tasks_.push(std::move(task)); - lock.unlock(); scheduled_.notify_one(); return f; } - // Wait until all the tasks are completed. - void Wait(); - private: DISABLE_COPY_AND_ASSIGN(ThreadPool); - // If the task queue is empty and avaialbe is equal to the number of - // threads, means that all tasks are completed. Note: this function - // is not thread-safe. Returns true if all tasks are completed. - // Note: don't delete the data member total_threads_ and use - // threads_.size() instead; because you'd need to lock the mutex - // before accessing threads_. - bool Done() { return tasks_.empty() && idle_threads_ == total_threads_; } - // The constructor starts threads to run TaskLoop, which retrieves // and runs tasks from the queue. void TaskLoop(); @@ -125,14 +104,11 @@ class ThreadPool { static std::once_flag init_flag_; std::vector> threads_; - const size_t total_threads_; - size_t idle_threads_; std::queue tasks_; std::mutex mutex_; bool running_; std::condition_variable scheduled_; - std::condition_variable completed_; }; class ThreadPoolIO : ThreadPool { diff --git a/paddle/fluid/framework/threadpool_test.cc b/paddle/fluid/framework/threadpool_test.cc index 27a4ffd4fcbf293a3dea1744b29384d0bee0c137..884d61e23428a0ad758946295ca9c470767e93ef 100644 --- a/paddle/fluid/framework/threadpool_test.cc +++ b/paddle/fluid/framework/threadpool_test.cc @@ -19,10 +19,11 @@ limitations under the License. */ namespace framework = paddle::framework; -void do_sum(framework::ThreadPool* pool, std::atomic* sum, int cnt) { - std::vector> fs; +void do_sum(std::vector>* fs, std::mutex* mu, + std::atomic* sum, int cnt) { for (int i = 0; i < cnt; ++i) { - fs.push_back(framework::Async([sum]() { sum->fetch_add(1); })); + std::lock_guard l(*mu); + fs->push_back(framework::Async([sum]() { sum->fetch_add(1); })); } } @@ -40,18 +41,21 @@ TEST(ThreadPool, ConcurrentInit) { } TEST(ThreadPool, ConcurrentRun) { - framework::ThreadPool* pool = framework::ThreadPool::GetInstance(); std::atomic sum(0); std::vector threads; + std::vector> fs; + std::mutex fs_mu; int n = 50; // sum = (n * (n + 1)) / 2 for (int i = 1; i <= n; ++i) { - std::thread t(do_sum, pool, &sum, i); + std::thread t(do_sum, &fs, &fs_mu, &sum, i); threads.push_back(std::move(t)); } for (auto& t : threads) { t.join(); } - pool->Wait(); + for (auto& t : fs) { + t.wait(); + } EXPECT_EQ(sum, ((n + 1) * n) / 2); } diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 7202fd44a4de00f3aa90b0047c859a637e16a29c..e7f13b19da0ef2d6b73e6c02345e122d11e77dd3 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -31,7 +31,7 @@ if (WITH_GPU AND TENSORRT_FOUND) endif() # Create static library -cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor) +cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array) if(NOT APPLE) # TODO(liuyiqu: Temporarily disable the link flag because it is not support on Mac. @@ -41,7 +41,7 @@ endif() # Create shared library cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS} - DEPS ${fluid_modules} paddle_fluid_api) + DEPS ${fluid_modules} paddle_fluid_api reset_tensor_array) set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid) if(NOT APPLE) diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 2e79d495d5ff00000000029ac0f6eb486aaea94a..ef4142f334e503380dc7ccd74c348404ffe52ee6 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -107,6 +107,9 @@ void Analyzer::Run(Argument* argument) { passes.push_back("mkldnn_placement_pass"); } #endif + // infer_clean_graph_pass should be the first default pass + // after mkldnn_placement_pass. + passes.push_back("infer_clean_graph_pass"); for (auto& pass : ir_passes_) { if (!disabled_ir_passes_.count(pass)) { passes.push_back(pass); diff --git a/paddle/fluid/inference/analysis/analyzer.h b/paddle/fluid/inference/analysis/analyzer.h index c51a4fdb2f6b27e54637481c23bf6f1f6ec97718..7114f5222c5904bb4422b9e67ad035b85bbb770c 100644 --- a/paddle/fluid/inference/analysis/analyzer.h +++ b/paddle/fluid/inference/analysis/analyzer.h @@ -67,7 +67,6 @@ class Analyzer : public OrderedRegistry { // larger fusion. const std::vector all_ir_passes_{{ // Manual update the passes here. - "infer_clean_graph_pass", // "attention_lstm_fuse_pass", // "seqconv_eltadd_relu_fuse_pass", // "embedding_fc_lstm_fuse_pass", // diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 0ddd5d53f836131fe37d412fc867cb38f11ee2b5..e2027b7cb4d584ffcc48624d2c01e65a61829975 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -18,7 +18,8 @@ if(APPLE) endif(APPLE) -set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager naive_executor ${GLOB_PASS_LIB}) +set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager naive_executor ${GLOB_PASS_LIB} + ) if(WITH_GPU AND TENSORRT_FOUND) set(inference_deps ${inference_deps} paddle_inference_tensorrt_subgraph_engine analysis_predictor) @@ -31,10 +32,17 @@ function(inference_api_test TARGET_NAME) set(multiValueArgs ARGS) cmake_parse_arguments(inference_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - cc_test(${TARGET_NAME} - SRCS ${inference_test_SRC} - DEPS "${inference_deps}" - ARGS --dirname=${PYTHON_TESTS_DIR}/book/) + if (WITH_GPU) + cc_test(${TARGET_NAME} + SRCS ${inference_test_SRC} + DEPS "${inference_deps}" + ARGS --dirname=${PYTHON_TESTS_DIR}/book/ --fraction_of_gpu_memory_to_use=0.15) + else() + cc_test(${TARGET_NAME} + SRCS ${inference_test_SRC} + DEPS "${inference_deps}" + ARGS --dirname=${PYTHON_TESTS_DIR}/book/) + endif() if(inference_test_ARGS) set_tests_properties(${TARGET_NAME} PROPERTIES DEPENDS "${inference_test_ARGS}") @@ -42,7 +50,8 @@ function(inference_api_test TARGET_NAME) endif(WITH_TESTING) endfunction(inference_api_test) -cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor scope) +cc_library(reset_tensor_array SRCS details/reset_tensor_array.cc DEPS lod_tensor scope) +cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS reset_tensor_array lod_tensor scope) cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor) cc_library(zero_copy_tensor SRCS details/zero_copy_tensor.cc DEPS paddle_inference_api) cc_library(zero_copy_tensor_dummy SRCS details/zero_copy_tensor_dummy.cc DEPS paddle_inference_api) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index eec665767164dc6e79738890947c54d7f7217037..54c37fe64590aa82d7100c93c4c5c4ee36491421 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -82,6 +82,7 @@ bool AnalysisPredictor::Init( // Get the feed_target_names and fetch_target_names PrepareFeedFetch(); + return true; } @@ -109,6 +110,10 @@ bool AnalysisPredictor::Run(const std::vector &inputs, return false; } VLOG(3) << "predict cost: " << timer.toc() << "ms"; + + // Fix TensorArray reuse not cleaned bug. + tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get()); + tensor_array_batch_cleaner_.ResetTensorArray(); return true; } @@ -322,6 +327,9 @@ std::unique_ptr AnalysisPredictor::GetOutputTensor( bool AnalysisPredictor::ZeroCopyRun() { executor_->Run(); + // Fix TensorArray reuse not cleaned bug. + tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get()); + tensor_array_batch_cleaner_.ResetTensorArray(); return true; } diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 5a9f4d36959d4ee7ca16dec769d9d1283b8787cb..b7dc2067332278c1c38df4beefb5059efe76417f 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -18,6 +18,7 @@ #include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/api/api_impl.h" +#include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/string/printf.h" @@ -88,6 +89,7 @@ class AnalysisPredictor : public PaddlePredictor { // Memory buffer for feed inputs. The temporary LoDTensor will cause serious // concurrency problems, so cache them. std::vector feed_tensors_; + details::TensorArrayBatchCleaner tensor_array_batch_cleaner_; }; } // namespace paddle diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 7cda9c5d8a8366bd097491f37f5352a10e4fb16c..d06ab8f8c8e3c0adf4a4074eb1450012126e83ea 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/inference/api/api_impl.h" +#include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/profiler.h" @@ -157,6 +158,10 @@ bool NativePaddlePredictor::Run(const std::vector &inputs, return false; } VLOG(3) << "predict cost: " << timer.toc() << "ms"; + + // Fix TensorArray reuse not cleaned bug. + tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get()); + tensor_array_batch_cleaner_.ResetTensorArray(); return true; } diff --git a/paddle/fluid/inference/api/api_impl.h b/paddle/fluid/inference/api/api_impl.h index 7882f6a53c7ce9a2486158ea9b50c018d1814091..4e4ab47ca9c5e37f2714ebd48d250c23c7e9b117 100644 --- a/paddle/fluid/inference/api/api_impl.h +++ b/paddle/fluid/inference/api/api_impl.h @@ -26,11 +26,11 @@ limitations under the License. */ #include #include -#include "paddle/fluid/inference/api/paddle_inference_api.h" - #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/naive_executor.h" +#include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/io.h" #include "paddle/fluid/platform/init.h" @@ -77,6 +77,7 @@ class NativePaddlePredictor : public PaddlePredictor { std::vector fetchs_; // Do not use unique_ptr, use parent scope to delete framework::Scope *sub_scope_{nullptr}; + details::TensorArrayBatchCleaner tensor_array_batch_cleaner_; }; } // namespace paddle diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index 03f0f726eb61c2619c7719a865383090f86b5b7f..49683eab07a2f5bc008272038a27bdb277396284 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -52,6 +52,7 @@ include_directories("${PADDLE_LIB}") include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") include_directories("${PADDLE_LIB}/third_party/install/glog/include") include_directories("${PADDLE_LIB}/third_party/install/gflags/include") +include_directories("${PADDLE_LIB}/third_party/install/xxhash/include") if (NOT WIN32) include_directories("${PADDLE_LIB}/third_party/install/snappy/include") include_directories("${PADDLE_LIB}/third_party/install/snappystream/include") @@ -61,8 +62,8 @@ endif(NOT WIN32) include_directories("${PADDLE_LIB}/third_party/boost") include_directories("${PADDLE_LIB}/third_party/eigen3") -if (NOT WIN32) - if (USE_TENSORRT AND WITH_GPU) +if (NOT WIN32) + if (USE_TENSORRT AND WITH_GPU) include_directories("${TENSORRT_INCLUDE_DIR}") link_directories("${TENSORRT_LIB_DIR}") endif() @@ -77,13 +78,14 @@ endif(NOT WIN32) link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") link_directories("${PADDLE_LIB}/third_party/install/glog/lib") link_directories("${PADDLE_LIB}/third_party/install/gflags/lib") +link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib") link_directories("${PADDLE_LIB}/paddle/lib") add_executable(${DEMO_NAME} ${DEMO_NAME}.cc) if(WITH_MKL) include_directories("${PADDLE_LIB}/third_party/install/mklml/include") - set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} + set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn") if(EXISTS ${MKLDNN_PATH}) @@ -107,7 +109,7 @@ if (NOT WIN32) set(EXTERNAL_LIB "-lrt -ldl -lpthread") set(DEPS ${DEPS} ${MATH_LIB} ${MKLDNN_LIB} - glog gflags protobuf snappystream snappy z + glog gflags protobuf snappystream snappy z xxhash ${EXTERNAL_LIB}) else() set(DEPS ${DEPS} @@ -120,7 +122,7 @@ endif(NOT WIN32) if(WITH_GPU) if(NOT WIN32) - if (USE_TENSORRT) + if (USE_TENSORRT) set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_STATIC_LIBRARY_SUFFIX}) set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX}) endif() diff --git a/paddle/fluid/inference/api/demo_ci/run.sh b/paddle/fluid/inference/api/demo_ci/run.sh index 6e682b69583e00ab1bbe1c0d22e21ae114a61a76..340e84d9312c20e2d10eb4c0a13066511d5d2eb5 100755 --- a/paddle/fluid/inference/api/demo_ci/run.sh +++ b/paddle/fluid/inference/api/demo_ci/run.sh @@ -16,7 +16,7 @@ if [ $2 == ON ]; then fi if [ $3 == ON ]; then use_gpu_list='true false' -else +else use_gpu_list='false' fi @@ -60,7 +60,8 @@ for WITH_STATIC_LIB in ON OFF; do -DWITH_MKL=$TURN_ON_MKL \ -DDEMO_NAME=simple_on_word2vec \ -DWITH_GPU=$TEST_GPU_CPU \ - -DWITH_STATIC_LIB=$WITH_STATIC_LIB + -DWITH_STATIC_LIB=$WITH_STATIC_LIB \ + -DON_INFER=ON make -j word2vec_model=${PADDLE_ROOT}'/build/python/paddle/fluid/tests/book/word2vec.inference.model' if [ -d $word2vec_model ]; then @@ -80,10 +81,11 @@ for WITH_STATIC_LIB in ON OFF; do -DWITH_MKL=$TURN_ON_MKL \ -DDEMO_NAME=vis_demo \ -DWITH_GPU=$TEST_GPU_CPU \ - -DWITH_STATIC_LIB=$WITH_STATIC_LIB + -DWITH_STATIC_LIB=$WITH_STATIC_LIB \ + -DON_INFER=ON make -j for use_gpu in $use_gpu_list; do - for vis_demo_name in $vis_demo_list; do + for vis_demo_name in $vis_demo_list; do ./vis_demo \ --modeldir=$DATA_DIR/$vis_demo_name/model \ --data=$DATA_DIR/$vis_demo_name/data.txt \ @@ -95,7 +97,7 @@ for WITH_STATIC_LIB in ON OFF; do fi done done - + # --------tensorrt mobilenet------ if [ $USE_TENSORRT == ON -a $TEST_GPU_CPU == ON ]; then rm -rf * @@ -106,8 +108,9 @@ for WITH_STATIC_LIB in ON OFF; do -DWITH_STATIC_LIB=$WITH_STATIC_LIB \ -DUSE_TENSORRT=$USE_TENSORRT \ -DTENSORRT_INCLUDE_DIR=$TENSORRT_INCLUDE_DIR \ - -DTENSORRT_LIB_DIR=$TENSORRT_LIB_DIR - make -j + -DTENSORRT_LIB_DIR=$TENSORRT_LIB_DIR \ + -DON_INFER=ON + make -j ./trt_mobilenet_demo \ --modeldir=$DATA_DIR/mobilenet/model \ --data=$DATA_DIR/mobilenet/data.txt \ diff --git a/paddle/fluid/inference/api/details/reset_tensor_array.cc b/paddle/fluid/inference/api/details/reset_tensor_array.cc new file mode 100644 index 0000000000000000000000000000000000000000..4ae6c6dc9f44650c1c62f5be5448864d817513b1 --- /dev/null +++ b/paddle/fluid/inference/api/details/reset_tensor_array.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/api/details/reset_tensor_array.h" + +namespace paddle { +namespace details { + +// Should be called after the parameters are loaded. +void TensorArrayBatchCleaner::CollectTensorArrays(framework::Scope *scope) { + if (flag_) { + for (auto &var_name : scope->LocalVarNames()) { + auto *var = scope->FindVar(var_name); + // TODO(Superjomn) should avoid the case when a TensorArray is a + // parameter. + if (var_name == "feed" || var_name == "fetch") continue; + if (var->Type() == typeid(framework::LoDTensorArray)) { + VLOG(4) << "collect " << var_name; + arrays_.push_back(var->GetMutable()); + } + } + for (auto *kid : scope->kids()) { + CollectTensorArrays(kid); + } + + VLOG(3) << "Collect " << arrays_.size() << " arrays"; + flag_ = false; + } +} + +// Should be called when `Run` finished. +void TensorArrayBatchCleaner::ResetTensorArray() { + for (auto *arr : arrays_) { + arr->clear(); + } +} + +} // namespace details +} // namespace paddle diff --git a/paddle/fluid/inference/api/details/reset_tensor_array.h b/paddle/fluid/inference/api/details/reset_tensor_array.h new file mode 100644 index 0000000000000000000000000000000000000000..a39449ff0e67786815dfb8d2d30d79dcdba757d7 --- /dev/null +++ b/paddle/fluid/inference/api/details/reset_tensor_array.h @@ -0,0 +1,37 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace details { + +// Clean the TensorArray each batch to make the behavior the same with the +// training phase. +struct TensorArrayBatchCleaner { + // Fix the tensor array not clear in the inference scenarios. + void CollectTensorArrays(framework::Scope *scope); + void ResetTensorArray(); + + private: + bool flag_{true}; + std::vector arrays_; +}; + +} // namespace details +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index 07ee6e72d1053d2271b8f8d69ce38003f5e038a0..a755ccb93bdee018dfeaf91157e7971b4d4cd832 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -124,7 +124,7 @@ class ZeroCopyTensor { std::vector> lod() const; protected: - ZeroCopyTensor(void* scope) : scope_{scope} {} + explicit ZeroCopyTensor(void* scope) : scope_{scope} {} void SetName(const std::string& name) { name_ = name; } void* FindTensor() const; @@ -259,12 +259,6 @@ struct AnalysisConfig : public NativeConfig { kExclude // Specify the disabled passes in `ir_passes`. }; - void SetIncludeMode() { - ir_mode = IrPassMode::kInclude; - // this pass has to be run at the beginning of all fuse passes - ir_passes = {"infer_clean_graph_pass"}; - } - // Determine whether to perform graph optimization. bool enable_ir_optim = true; // Manually determine the IR passes to run. diff --git a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc index 6399476680c0af83a6d26aea952c58543bdce9ae..e0416ff953b61f56a2ca1a45cb382d40a6cffa4a 100644 --- a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc @@ -228,6 +228,7 @@ void SetInput(std::vector> *inputs) { TEST(Analyzer_rnn1, profile) { contrib::AnalysisConfig cfg; SetConfig(&cfg); + cfg.use_gpu = false; std::vector outputs; std::vector> input_slots_all; diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 78ef6f207eadea6799864fe22889103b468d1780..0d51cb92618170cb422cb49ba63ba54ae6608ef4 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -268,6 +268,7 @@ if (WITH_GPU AND TENSORRT_FOUND) else() set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op) endif() +op_library(hash_op DEPS xxhash) op_library(clip_by_norm_op DEPS selected_rows_functor selected_rows) op_library(sum_op DEPS selected_rows_functor) op_library(sgd_op DEPS selected_rows_functor) diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc index b6cb935814e25b31d4104f9ce24fe952680cb491..0d32cae0e1e5ff274793df50e854283d8e2f7bf8 100644 --- a/paddle/fluid/operators/beam_search_decode_op.cc +++ b/paddle/fluid/operators/beam_search_decode_op.cc @@ -79,6 +79,9 @@ struct BeamSearchDecodeFunctor { bool tensor_on_gpu_; size_t beam_size_; int end_id_; + // TODO(Superjomn) Here might result serious performance issue in the + // concurrency + // scenarios. const LoDTensorArray& step_ids_origin_; const LoDTensorArray& step_scores_origin_; LoDTensorArray step_ids_ = LoDTensorArray(); diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc index a69d9c9a529f26b3981ca8d1ba226fda71b8820a..709c2dfc4b7c67d7d04074c58ce6da85b6e790fe 100644 --- a/paddle/fluid/operators/detection/generate_proposals_op.cc +++ b/paddle/fluid/operators/detection/generate_proposals_op.cc @@ -284,7 +284,7 @@ static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, selected_indices.push_back(idx); ++selected_num; } - sorted_indices.erase(sorted_indices.end()); + sorted_indices.erase(sorted_indices.end() - 1); if (flag && eta < 1 && adaptive_threshold > 0.5) { adaptive_threshold *= eta; } diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 07322e720f26213ea777be3cd22f2fead28507f0..3c28ef30922e6d6ba09b96282619eef15867631e 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/dropout_op.h" +#include namespace paddle { namespace operators { @@ -57,6 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { "will be dropped.") .SetDefault(false); AddAttr("seed", "Dropout random seed.").SetDefault(0); + AddAttr( + "dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "There are two kinds of ways to implement dropout" + "(the mask below is a tensor have the same shape with input" + "the value of mask is 0 or 1, the ratio of 0 is dropout_prob)" + "1. downgrade_in_infer(default), downgrade the outcome at inference " + "time" + " train: out = input * mask" + " inference: out = input * dropout_prob" + "2. upscale_in_train, upscale the outcome at training time, do nothing " + "in inference" + " train: out = input * mask / ( 1.0 - dropout_prob )" + " inference: out = input" + " dropout op can be removed from the program. the program will be " + "efficient") + .SetDefault("downgrade_in_infer") + .AddCustomChecker([](const std::string& type) { + PADDLE_ENFORCE( + type == "downgrade_in_infer" || type == "upscale_in_train", + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train"); + }); AddComment(R"DOC( Dropout Operator. @@ -104,7 +128,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad); REGISTER_OP_CPU_KERNEL( - dropout, ops::CPUDropoutKernel); + dropout, ops::CPUDropoutKernel, + ops::CPUDropoutKernel); REGISTER_OP_CPU_KERNEL( dropout_grad, - ops::DropoutGradKernel); + ops::DropoutGradKernel, + ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index 1dd66e0280c46c0624ff70e822cb6fa6f06b7aa9..e011f47e086183a4ef3a3373c17acd6c21b6cf7e 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/platform/float16.h" @@ -26,7 +27,8 @@ namespace operators { template __global__ void RandomGenerator(const size_t n, const int seed, const float dropout_prob, const T* src, - T* mask_data, T* dst) { + T* mask_data, T* dst, + bool is_upscale_in_train) { thrust::minstd_rand rng; rng.seed(seed); thrust::uniform_real_distribution dist(0, 1); @@ -47,7 +49,11 @@ __global__ void RandomGenerator(const size_t n, const int seed, if (dist(rng) < dropout_prob) { mask = static_cast(0); } else { - mask = static_cast(1); + if (is_upscale_in_train) { + mask = static_cast(1.0f / (1.0f - dropout_prob)); + } else { + mask = static_cast(1); + } } dest = s * mask; mask_data[idx] = mask; @@ -67,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel { y->mutable_data(context.GetPlace()); float dropout_prob = context.Attr("dropout_prob"); + auto dropout_implementation = + context.Attr("dropout_implementation"); auto& place = *context.template device_context().eigen_device(); if (!context.Attr("is_test")) { auto* mask = context.Output("Mask"); @@ -83,11 +91,16 @@ class GPUDropoutKernel : public framework::OpKernel { int grid = (x->numel() + threads - 1) / threads; RandomGenerator< T><<>>( - size, seed, dropout_prob, x_data, mask_data, y_data); + size, seed, dropout_prob, x_data, mask_data, y_data, + (dropout_implementation == "upscale_in_train")); } else { auto X = EigenMatrix::Reshape(*x, 1); auto Y = EigenMatrix::Reshape(*y, 1); - Y.device(place) = X * static_cast(1.0f - dropout_prob); + if (dropout_implementation == "upscale_in_train") { + Y.device(place) = X; + } else { + Y.device(place) = X * static_cast(1.0f - dropout_prob); + } } } }; @@ -99,6 +112,8 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( dropout, ops::GPUDropoutKernel, - ops::GPUDropoutKernel); -REGISTER_OP_CUDA_KERNEL(dropout_grad, - ops::DropoutGradKernel); + ops::GPUDropoutKernel, + ops::GPUDropoutKernel); +REGISTER_OP_CUDA_KERNEL( + dropout_grad, ops::DropoutGradKernel, + ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h index 0628b4b826d2730a8e3fb4842e4ae550b8c00569..6c629b7b6d255828023ed25680675ca104a33e12 100644 --- a/paddle/fluid/operators/dropout_op.h +++ b/paddle/fluid/operators/dropout_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -36,6 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel { auto* y_data = y->mutable_data(context.GetPlace()); float dropout_prob = context.Attr("dropout_prob"); + auto dropout_implementation = + context.Attr("dropout_implementation"); if (!context.Attr("is_test")) { auto* mask = context.Output("Mask"); auto* mask_data = mask->mutable_data(context.GetPlace()); @@ -49,14 +52,20 @@ class CPUDropoutKernel : public framework::OpKernel { engine.seed(seed); std::uniform_real_distribution dist(0, 1); + size_t size = framework::product(mask->dims()); for (size_t i = 0; i < size; ++i) { if (dist(engine) < dropout_prob) { mask_data[i] = 0; y_data[i] = 0; } else { - mask_data[i] = 1; - y_data[i] = x_data[i]; + if (dropout_implementation == "upscale_in_train") { + mask_data[i] = 1.0f / static_cast(1.0f - dropout_prob); + y_data[i] = x_data[i] / static_cast(1.0f - dropout_prob); + } else { + mask_data[i] = 1; + y_data[i] = x_data[i]; + } } } } else { @@ -64,7 +73,11 @@ class CPUDropoutKernel : public framework::OpKernel { auto Y = EigenMatrix::Reshape(*y, 1); auto& place = *context.template device_context().eigen_device(); - Y.device(place) = X * (1.0f - dropout_prob); + if (dropout_implementation == "upscale_in_train") { + Y.device(place) = X; + } else { + Y.device(place) = X * static_cast(1.0f - dropout_prob); + } } } }; diff --git a/paddle/fluid/operators/hash_op.cc b/paddle/fluid/operators/hash_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b9ebe71a3d7ae270a10a45f4805652415078b363 --- /dev/null +++ b/paddle/fluid/operators/hash_op.cc @@ -0,0 +1,74 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/hash_op.h" +#include +#include + +namespace paddle { +namespace operators { + +class HashOp : public framework::OperatorWithKernel { + public: + HashOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of HashOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of HashOp should not be null."); + + auto dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(dims.size(), 2UL, + "The input of hash_op's dimensions must be 2"); + std::vector out_dims; + out_dims.reserve(dims.size() + 1); + // copy all dims except the last one + for (size_t i = 0u; i != dims.size() - 1; ++i) { + out_dims.emplace_back(dims[i]); + } + int num_hash = ctx->Attrs().Get("num_hash"); + out_dims.emplace_back(num_hash); + // keep the last dim to 1 + out_dims.emplace_back(1); + + ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class HashOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor) Input tensor of scale operator."); + AddOutput("Out", "(Tensor) Output tensor of scale operator."); + AddComment(R"DOC( +**Hash Operator** +$$Out = scale * X$$ +)DOC"); + AddAttr("num_hash", "").SetDefault(1); + AddAttr("mod_by", "").SetDefault(100000); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_WITHOUT_GRADIENT(hash, ops::HashOp, ops::HashOpMaker); +REGISTER_OP_CPU_KERNEL(hash, ops::HashKerel, ops::HashKerel); diff --git a/paddle/fluid/operators/hash_op.h b/paddle/fluid/operators/hash_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9781bb0f453642cefb3eb59a05389c339a7de39d --- /dev/null +++ b/paddle/fluid/operators/hash_op.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +extern "C" { +#include +} +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { +// template +template +class HashKerel : public framework::OpKernel { + public: + virtual void Compute(const framework::ExecutionContext& context) const { + auto* out_t = context.Output("Out"); + auto* in_t = context.Input("X"); + int mod_by = context.Attr("mod_by"); + int num_hash = context.Attr("num_hash"); + auto* output = out_t->mutable_data(context.GetPlace()); + + auto in_dims = in_t->dims(); + auto in_lod = in_t->lod(); + PADDLE_ENFORCE_EQ( + static_cast(in_dims[0]), in_lod[0].back(), + "The actual input data's size mismatched with LoD information."); + + auto seq_length = in_dims[0]; + auto last_dim = in_dims[in_dims.size() - 1]; + auto* input = in_t->data(); + for (int idx = 0; idx < seq_length; ++idx) { + for (int ihash = 0; ihash != num_hash; ++ihash) { + output[idx * num_hash + ihash] = + XXH64(input, sizeof(int) * last_dim, ihash) % mod_by; + } + input += last_dim; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index b9ac54e446811889b647397ae1fbb11c28f46777..d7f6cd5ab0acd2b677a3e5bd51bbcffe82eb1e50 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -81,6 +81,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { "Otherwise the given value indicates padding the output " "with zeros whenever lookup encounters it in Ids.") .SetDefault(kNoPadding); + // NOTE(minqiyang): grad_inplace is an temporal attribute, + // please do NOT set this attribute in python layer. + AddAttr("grad_inplace", + "(boolean, default false) " + "If the grad op reuse the input's variable.") + .SetDefault(false); AddComment(R"DOC( Lookup Table Operator. diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index 58463dc4d6fd7cc3454de766814a947fee161070..e504c4f0cd5c0feaef4a251fad57b389a10a2ce7 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/math/blas.h" namespace paddle { namespace operators { @@ -68,6 +69,7 @@ class LookupTableKernel : public framework::OpKernel { const auto *table = table_t.value().data(); auto *output = output_t->mutable_data(context.GetPlace()); + auto blas = math::GetBlas(context); for (int64_t i = 0; i < ids_numel; ++i) { if (padding_idx != kNoPadding && ids[i] == padding_idx) { memset(output + i * row_width, 0, row_width * sizeof(T)); @@ -75,8 +77,8 @@ class LookupTableKernel : public framework::OpKernel { PADDLE_ENFORCE_GE(ids[i], 0); auto id_index = table_t.Index(ids[i]); PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists."); - memcpy(output + i * row_width, table + id_index * row_width, - row_width * sizeof(T)); + blas.VCOPY(row_width, table + id_index * row_width, + output + i * row_width); } } } @@ -111,27 +113,37 @@ class LookupTableGradKernel : public framework::OpKernel { auto *ids_data = ids->data(); int64_t ids_num = ids->numel(); - framework::Vector new_rows; - new_rows.reserve(ids_num); - for (int64_t i = 0; i < ids_num; i++) { - new_rows.push_back(ids_data[i]); - } + std::vector new_rows; + new_rows.resize(ids_num); + std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t)); d_table->set_rows(new_rows); auto *d_table_value = d_table->mutable_value(); d_table_value->Resize({ids_num, table_dim[1]}); - d_table_value->mutable_data(context.GetPlace()); - - d_table->set_height(table_dim[0]); - - auto *d_output_data = d_output->data(); - auto *d_table_data = d_table_value->data(); - - auto d_output_dims = d_output->dims(); - PADDLE_ENFORCE_EQ( - d_table_value->dims(), - framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1)); - memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); + // FIXME(minqiyang): + // memory optimization will NOT reuse Tensor with SelectedRows + // so we could just share the tensor here directly. + // However, the InferVarType method will infer the output SelectedRows + // to Tensor sometimes, which is a bug, so we will add an attribute + // here to indicate the inplace and remove this attribute after + // the InferVarType's bug was fixed + bool grad_inplace = context.Attr("grad_inplace"); + if (grad_inplace) { + d_table_value->ShareDataWith(*d_output); + } else { + d_table_value->mutable_data(context.GetPlace()); + + d_table->set_height(table_dim[0]); + + auto *d_output_data = d_output->data(); + auto *d_table_data = d_table_value->data(); + + auto d_output_dims = d_output->dims(); + PADDLE_ENFORCE_EQ( + d_table_value->dims(), + framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1)); + memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); + } } else { auto *ids = context.Input("Ids"); auto *d_output = context.Input(framework::GradVarName("Out")); diff --git a/paddle/fluid/operators/math/algorithm.h b/paddle/fluid/operators/math/algorithm.h index 262469beea7449eb5820b86de1ac4f790a833e79..2e75b6abce5e1f43742ee15bff1dac4801186cd4 100644 --- a/paddle/fluid/operators/math/algorithm.h +++ b/paddle/fluid/operators/math/algorithm.h @@ -39,6 +39,52 @@ HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) { return -1; } +template +HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) { +#ifdef __CUDA_ARCH__ + // The following code is from + // https://en.cppreference.com/w/cpp/algorithm/lower_bound + auto *first = x; + int64_t count = static_cast(num); + while (count > 0) { + int64_t step = (count >> 1); + auto *it = first + step; + if (*it < val) { + first = ++it; + count -= (step + 1); + } else { + count = step; + } + } + return static_cast(first - x); +#else + return static_cast(std::lower_bound(x, x + num, val) - x); +#endif +} + +template +HOSTDEVICE inline size_t UpperBound(const T *x, size_t num, const T &val) { +#ifdef __CUDA_ARCH__ + // The following code is from + // https://en.cppreference.com/w/cpp/algorithm/upper_bound + auto *first = x; + int64_t count = static_cast(num); + while (count > 0) { + auto step = (count >> 1); + auto *it = first + step; + if (val < *it) { + count = step; + } else { + first = ++it; + count -= (step + 1); + } + } + return static_cast(first - x); +#else + return static_cast(std::upper_bound(x, x + num, val) - x); +#endif +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/sequence_reverse_op.cc b/paddle/fluid/operators/sequence_reverse_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..1428cca1a6bf6150594f9cb72dbf00cd0eff7df5 --- /dev/null +++ b/paddle/fluid/operators/sequence_reverse_op.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/sequence_reverse_op.h" + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(sequence_reverse, ops::SequenceReverseOp, + ops::SequenceReverseOpMaker, + ops::SequenceReverseGradOpDescMaker); + +REGISTER_OP_CPU_KERNEL( + sequence_reverse, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel); diff --git a/paddle/fluid/operators/sequence_reverse_op.cu b/paddle/fluid/operators/sequence_reverse_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..ce65f4799e8661adca60d212eaa9c3f0f92c4c29 --- /dev/null +++ b/paddle/fluid/operators/sequence_reverse_op.cu @@ -0,0 +1,25 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/sequence_reverse_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + sequence_reverse, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel); diff --git a/paddle/fluid/operators/sequence_reverse_op.h b/paddle/fluid/operators/sequence_reverse_op.h new file mode 100644 index 0000000000000000000000000000000000000000..39dad2311b2bcf29f808723caf7bfaef4c88cef2 --- /dev/null +++ b/paddle/fluid/operators/sequence_reverse_op.h @@ -0,0 +1,157 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/algorithm.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +class SequenceReverseOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist"); + PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist"); + + auto x_dim = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GE(x_dim.size(), 2, + "Rank of Input(X) must be not less than 2."); + + ctx->SetOutputDim("Y", x_dim); + ctx->ShareLoD("X", "Y"); + } +}; + +class SequenceReverseOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input LoDTensor of sequence_reverse op."); + AddOutput("Y", "The output LoDTensor of sequence_reverse op."); + AddComment(R"DOC( +SequenceReverse Operator. + +Reverse each sequence in input X along dim 0. + +Assuming X is a LoDTensor with dims [5, 4] and lod [[0, 2, 5]], where: + +X.data() = [ + [1, 2, 3, 4], + [5, 6, 7, 8], # the 0-th sequence with length 2 + [9, 10, 11, 12], + [13, 14, 15, 16], + [17, 18, 19, 20] # the 1-st sequence with length 3 +] + +The output Y would be a LoDTensor sharing the same dims and lod with input X, +and: + +Y.data() = [ + [5, 6, 7, 8], + [1, 2, 3, 4], # the reversed 0-th sequence with length 2 + [17, 18, 19, 20], + [13, 14, 15, 16], + [9, 10, 11, 12] # the reversed 1-st sequence with length 3 +] + +This Operator is useful to build a reverse dynamic RNN network. + +This Operator only supports one-level lod currently. + )DOC"); + } +}; + +template +struct SequenceReverseFunctor { + SequenceReverseFunctor(const T *x, T *y, const size_t *lod, size_t lod_count, + size_t row_numel) + : x_(x), y_(y), lod_(lod), lod_count_(lod_count), row_numel_(row_numel) {} + + HOSTDEVICE void operator()(size_t idx_x) const { + auto row_idx_x = idx_x / row_numel_; + auto lod_idx = math::UpperBound(lod_, lod_count_, row_idx_x); + auto row_idx_y = lod_[lod_idx - 1] + (lod_[lod_idx] - 1 - row_idx_x); + auto idx_y = row_idx_y * row_numel_ + idx_x % row_numel_; + y_[idx_y] = x_[idx_x]; + } + + const T *x_; + T *y_; + const size_t *lod_; + size_t lod_count_; + size_t row_numel_; +}; + +template +class SequenceReverseOpKernel : public framework::OpKernel { + using LoDTensor = framework::LoDTensor; + + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto &x = *ctx.Input("X"); + auto *y = ctx.Output("Y"); + + PADDLE_ENFORCE_EQ(x.lod().size(), 1, + "SequenceReverse Op only support one level lod."); + + auto &dev_ctx = ctx.template device_context(); + const size_t *lod; + size_t lod_count = x.lod()[0].size(); + +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(ctx.GetPlace())) { + lod = x.lod()[0].CUDAData(ctx.GetPlace()); + } else { +#endif + lod = x.lod()[0].data(); +#ifdef PADDLE_WITH_CUDA + } +#endif + + size_t limit = static_cast(x.numel()); + size_t row_numel = static_cast(limit / x.dims()[0]); + auto *x_data = x.data(); + auto *y_data = y->mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE_NE(x_data, y_data, + "SequenceReverse Op does not support in-place operation"); + + SequenceReverseFunctor functor(x_data, y_data, lod, lod_count, + row_numel); + platform::ForRange for_range(dev_ctx, limit); + for_range(functor); + } +}; + +class SequenceReverseGradOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("sequence_reverse"); + op->SetInput("X", OutputGrad("Y")); + op->SetOutput("Y", InputGrad("X")); + op->SetAttrMap(Attrs()); + return op; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc index 2bdb23e999621b10799b5163f326bc4b66a437e6..f6e241af0634650f4a32be6a4547617f8ec3ee60 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc @@ -76,6 +76,8 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, ops::SoftmaxCUDNNKernel, + ops::SoftmaxCUDNNKernel, ops::SoftmaxCUDNNKernel); REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, - ops::SoftmaxGradCUDNNKernel); + ops::SoftmaxGradCUDNNKernel, + ops::SoftmaxGradCUDNNKernel); diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 6a9fc6611a8f8eaa6749aefac0673ccabaebbcfe..bbd71db6062107f6ba40343c84d942b54b3958e6 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker, REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad); REGISTER_OP_CPU_KERNEL( - transpose, ops::TransposeKernel); + transpose, ops::TransposeKernel, + ops::TransposeKernel); REGISTER_OP_CPU_KERNEL( transpose_grad, - ops::TransposeGradKernel); + ops::TransposeGradKernel, + ops::TransposeGradKernel); REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker, ops::Transpose2GradMaker); REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad); REGISTER_OP_CPU_KERNEL( - transpose2, - ops::TransposeKernel); + transpose2, ops::TransposeKernel, + ops::TransposeKernel); REGISTER_OP_CPU_KERNEL( transpose2_grad, - ops::TransposeGradKernel); + ops::TransposeGradKernel, + ops::TransposeGradKernel); diff --git a/paddle/fluid/operators/transpose_op.cu.cc b/paddle/fluid/operators/transpose_op.cu.cc index c1b5a8b31be243fab3af06a18c8e51986c953700..b4025350fa9f3610bde43eee91cd059f3063813f 100644 --- a/paddle/fluid/operators/transpose_op.cu.cc +++ b/paddle/fluid/operators/transpose_op.cu.cc @@ -16,15 +16,18 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( - transpose, - ops::TransposeKernel); + transpose, ops::TransposeKernel, + ops::TransposeKernel); REGISTER_OP_CUDA_KERNEL( transpose_grad, - ops::TransposeGradKernel); + ops::TransposeGradKernel, + ops::TransposeGradKernel); REGISTER_OP_CUDA_KERNEL( transpose2, - ops::TransposeKernel); + ops::TransposeKernel, + ops::TransposeKernel); REGISTER_OP_CUDA_KERNEL( transpose2_grad, - ops::TransposeGradKernel); + ops::TransposeGradKernel, + ops::TransposeGradKernel); diff --git a/paddle/fluid/train/demo/CMakeLists.txt b/paddle/fluid/train/demo/CMakeLists.txt index 78d6e5ff554b9cd9facae85be166a697e0b75337..eabb51d370aff709e289e1fc727aa2dbb551d82e 100644 --- a/paddle/fluid/train/demo/CMakeLists.txt +++ b/paddle/fluid/train/demo/CMakeLists.txt @@ -15,6 +15,7 @@ include_directories("${PADDLE_LIB}") include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") include_directories("${PADDLE_LIB}/third_party/install/glog/include") include_directories("${PADDLE_LIB}/third_party/install/gflags/include") +include_directories("${PADDLE_LIB}/third_party/install/xxhash/include") include_directories("${PADDLE_LIB}/third_party/install/snappy/include") include_directories("${PADDLE_LIB}/third_party/install/snappystream/include") include_directories("${PADDLE_LIB}/third_party/install/zlib/include") @@ -27,6 +28,7 @@ link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib") link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") link_directories("${PADDLE_LIB}/third_party/install/glog/lib") link_directories("${PADDLE_LIB}/third_party/install/gflags/lib") +link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib") link_directories("${PADDLE_LIB}/third_party/install/zlib/lib") add_executable(demo_trainer demo_trainer.cc) @@ -62,5 +64,5 @@ target_link_libraries(demo_trainer ${ARCHIVE_END} ${MATH_LIB} ${MKLDNN_LIB} - glog gflags protobuf snappystream snappy z + glog gflags protobuf snappystream snappy z xxhash ${EXTERNAL_LIB}) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 85493c10549c290330ed09b9f28accb7a980de6a..5a71382fb14b64989502c34d8ac0aa13c62bc7d0 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -95,9 +95,9 @@ function cmake_gen() { exit 1 fi fi - else + else if [ "$1" != "" ]; then - echo "using python abi: $1" + echo "using python abi: $1" if [ "$1" == "cp27-cp27m" ]; then export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs4/lib:} export PATH=/opt/python/cp27-cp27m/bin/:${PATH} @@ -119,7 +119,7 @@ function cmake_gen() { fi fi fi - + if [ "$SYSTEM" == "Darwin" ]; then WITH_DISTRIBUTE=${WITH_DISTRIBUTE:-ON} WITH_AVX=${WITH_AVX:-ON} @@ -127,7 +127,7 @@ function cmake_gen() { else INFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR:-/root/.cache/inference_demo} fi - + cat < 0.0 and tot_neg > 0.0 else 0.0 + + +class DetectionMAP(object): + """ + Calculate the detection mean average precision (mAP). + + The general steps are as follows: + 1. calculate the true positive and false positive according to the input + of detection and labels. + 2. calculate mAP value, support two versions: '11 point' and 'integral'. + + Please get more information from the following articles: + https://sanchom.wordpress.com/tag/average-precision/ + https://arxiv.org/abs/1512.02325 + + Args: + input (Variable): The detection results, which is a LoDTensor with shape + [M, 6]. The layout is [label, confidence, xmin, ymin, xmax, ymax]. + gt_label (Variable): The ground truth label index, which is a LoDTensor + with shape [N, 1]. + gt_box (Variable): The ground truth bounding box (bbox), which is a + LoDTensor with shape [N, 4]. The layout is [xmin, ymin, xmax, ymax]. + gt_difficult (Variable|None): Whether this ground truth is a difficult + bounding bbox, which can be a LoDTensor [N, 1] or not set. If None, + it means all the ground truth labels are not difficult bbox. + class_num (int): The class number. + background_label (int): The index of background label, the background + label will be ignored. If set to -1, then all categories will be + considered, 0 by defalut. + overlap_threshold (float): The threshold for deciding true/false + positive, 0.5 by defalut. + evaluate_difficult (bool): Whether to consider difficult ground truth + for evaluation, True by defalut. This argument does not work when + gt_difficult is None. + ap_version (string): The average precision calculation ways, it must be + 'integral' or '11point'. Please check + https://sanchom.wordpress.com/tag/average-precision/ for details. + - 11point: the 11-point interpolated average precision. + - integral: the natural integral of the precision-recall curve. + + Examples: + .. code-block:: python + + exe = fluid.Executor(place) + map_evaluator = fluid.Evaluator.DetectionMAP(input, + gt_label, gt_box, gt_difficult) + cur_map, accum_map = map_evaluator.get_map_var() + fetch = [cost, cur_map, accum_map] + for epoch in PASS_NUM: + map_evaluator.reset(exe) + for data in batches: + loss, cur_map_v, accum_map_v = exe.run(fetch_list=fetch) + + In the above example: + + 'cur_map_v' is the mAP of current mini-batch. + 'accum_map_v' is the accumulative mAP of one pass. + """ + + def __init__(self, + input, + gt_label, + gt_box, + gt_difficult=None, + class_num=None, + background_label=0, + overlap_threshold=0.5, + evaluate_difficult=True, + ap_version='integral'): + + self.helper = LayerHelper('map_eval') + gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype) + if gt_difficult: + gt_difficult = layers.cast(x=gt_difficult, dtype=gt_box.dtype) + label = layers.concat([gt_label, gt_difficult, gt_box], axis=1) + else: + label = layers.concat([gt_label, gt_box], axis=1) + + # calculate mean average precision (mAP) of current mini-batch + map = layers.detection_map( + input, + label, + class_num, + background_label, + overlap_threshold=overlap_threshold, + evaluate_difficult=evaluate_difficult, + ap_version=ap_version) + + states = [] + states.append( + self._create_state( + dtype='int32', shape=None, suffix='accum_pos_count')) + states.append( + self._create_state( + dtype='float32', shape=None, suffix='accum_true_pos')) + states.append( + self._create_state( + dtype='float32', shape=None, suffix='accum_false_pos')) + var = self._create_state(dtype='int32', shape=[1], suffix='has_state') + self.helper.set_variable_initializer( + var, initializer=Constant(value=int(0))) + self.has_state = var + + # calculate accumulative mAP + accum_map = layers.detection_map( + input, + label, + class_num, + background_label, + overlap_threshold=overlap_threshold, + evaluate_difficult=evaluate_difficult, + has_state=self.has_state, + input_states=states, + out_states=states, + ap_version=ap_version) + + layers.fill_constant( + shape=self.has_state.shape, + value=1, + dtype=self.has_state.dtype, + out=self.has_state) + + self.cur_map = map + self.accum_map = accum_map + + def _create_state(self, suffix, dtype, shape): + """ + Create state variable. + Args: + suffix(str): the state suffix. + dtype(str|core.VarDesc.VarType): the state data type + shape(tuple|list): the shape of state + Returns: State variable + """ + state = self.helper.create_variable( + name="_".join([unique_name.generate(self.helper.name), suffix]), + persistable=True, + dtype=dtype, + shape=shape) + return state + + def get_map_var(self): + """ + Returns: mAP variable of current mini-batch and + accumulative mAP variable cross mini-batches. + """ + return self.cur_map, self.accum_map + + def reset(self, executor, reset_program=None): + """ + Reset metric states at the begin of each pass/user specified batch. + + Args: + executor(Executor): a executor for executing + the reset_program. + reset_program(Program|None): a single Program for reset process. + If None, will create a Program. + """ + + def _clone_var_(block, var): + assert isinstance(var, Variable) + return block.create_var( + name=var.name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + lod_level=var.lod_level, + persistable=var.persistable) + + if reset_program is None: + reset_program = Program() + with program_guard(main_program=reset_program): + var = _clone_var_(reset_program.current_block(), self.has_state) + layers.fill_constant( + shape=var.shape, value=0, dtype=var.dtype, out=var) + executor.run(reset_program) diff --git a/python/paddle/fluid/tests/unittests/dist_transformer.py b/python/paddle/fluid/tests/unittests/dist_transformer.py index ab44954811562b8f74e368a551e855948f90af87..27c67edf4f62dd3c5d396826348f8da4513667ba 100644 --- a/python/paddle/fluid/tests/unittests/dist_transformer.py +++ b/python/paddle/fluid/tests/unittests/dist_transformer.py @@ -1159,6 +1159,7 @@ def prepare_encoder(src_word, name=pos_enc_param_name, trainable=False, initializer=fluid.initializer.ConstantInitializer(0.001))) + src_pos_enc.stop_gradient = True enc_input = src_word_emb + src_pos_enc return layers.dropout( enc_input, diff --git a/python/paddle/fluid/tests/unittests/test_dist_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_ctr.py index 3575fd07fc727bd6c6b07a19a60b1df6656ae9e2..390393e04f8a1ff7b994da66cf1fa104ccb61793 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_ctr.py @@ -23,9 +23,8 @@ class TestDistCTR2x2(TestDistBase): self._sync_mode = True self._enforce_place = "CPU" - -def test_dist_ctr(self): - self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False) + def test_dist_ctr(self): + self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py index 94b66a40233be4378e1a003f01d9375d00794743..f65dd7e2a28c4ace3988c0cc1267ebe981fbd9cb 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py @@ -40,8 +40,7 @@ class TestDistMnistAsync(TestDistBase): self._sync_mode = False self._use_reduce = False - # FIXME(typhoonzero): fix async mode test later - def no_test_dist_train(self): + def test_dist_train(self): self.check_with_place("dist_mnist.py", delta=200) diff --git a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py index c1e60dc9e420d11677468e0c62357437ecdf9e35..c0989ca709e100d8f147a08970b0e858c81ce09b 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py @@ -40,8 +40,7 @@ class TestDistSeResneXt2x2Async(TestDistBase): self._sync_mode = False self._use_reader_alloc = False - #FIXME(typhoonzero): fix async mode later - def no_test_dist_train(self): + def test_dist_train(self): self.check_with_place("dist_se_resnext.py", delta=100) diff --git a/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py index e1e6ef61090dfb439a3b43c4baf5ba88f61310ba..fcf793da07eb8985fb19c9e36ff9b7d5a8f51212 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py +++ b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py @@ -79,8 +79,7 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase): self._sync_mode = False self._enforce_place = "CPU" - #FIXME(typhoonzero): fix async tests later - def no_test_simnet_bow(self): + def test_simnet_bow(self): need_envs = { "IS_DISTRIBUTED": '0', "IS_SPARSE": '1', diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index 0296bc2af4e0b79478c34b4cceab32b5a8a50f2f..be3c5f3b9558ec522803ed9a5acedea75cda6ccc 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -85,6 +85,69 @@ class TestDropoutOp5(OpTest): self.check_output() +class TestDropoutOp6(TestDropoutOp): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = { + 'dropout_prob': 1.0, + 'fix_seed': True, + 'is_test': False, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = { + 'Out': np.zeros((32, 64)).astype('float32'), + 'Mask': np.zeros((32, 64)).astype('float32') + } + + +class TestDropoutOp7(TestDropoutOp): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")} + self.attrs = { + 'dropout_prob': 0.0, + 'fix_seed': True, + 'is_test': False, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = { + 'Out': self.inputs['X'], + 'Mask': np.ones((32, 64, 2)).astype('float32') + } + + +class TestDropoutOp8(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = { + 'dropout_prob': 0.35, + 'fix_seed': True, + 'is_test': True, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = {'Out': self.inputs['X']} + + def test_check_output(self): + self.check_output() + + +class TestDropoutOp9(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")} + self.attrs = { + 'dropout_prob': 0.75, + 'is_test': True, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = {'Out': self.inputs['X']} + + def test_check_output(self): + self.check_output() + + class TestFP16DropoutOp(OpTest): def setUp(self): self.op_type = "dropout" diff --git a/python/paddle/fluid/tests/unittests/test_hash_op.py b/python/paddle/fluid/tests/unittests/test_hash_op.py new file mode 100644 index 0000000000000000000000000000000000000000..1130ea39c42204283885ab1072a52db8c22f8b2e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_hash_op.py @@ -0,0 +1,57 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class TestScaleOp(OpTest): + def setUp(self): + self.op_type = "hash" + self.init_test_case() + self.inputs = {'X': (self.in_seq, self.lod)} + self.attrs = {'num_hash': 4, 'mod_by': 10000} + self.outputs = {'Out': (self.out_seq, self.lod)} + + def init_test_case(self): + np.random.seed = 1 + self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + self.lod = [[9, 4, 11, 6]] + # self.out_seq = np.ones([30, 4, 1], dtype=np.int32) + self.out_seq = [ + [[9662], [9217], [1129], [8487]], [[9662], [9217], [1129], [8487]], + [[8310], [1327], [1654], [4567]], [[6897], [3218], [2013], [1241]], + [[9407], [6715], [6949], [8094]], [[8473], [694], [5142], [2479]], + [[8310], [1327], [1654], [4567]], [[6897], [3218], [2013], [1241]], + [[4372], [9456], [8204], [6695]], [[6897], [3218], [2013], [1241]], + [[8473], [694], [5142], [2479]], [[4372], [9456], [8204], [6695]], + [[4372], [9456], [8204], [6695]], [[8473], [694], [5142], [2479]], + [[9407], [6715], [6949], [8094]], [[9369], [4525], [8935], [9210]], + [[4372], [9456], [8204], [6695]], [[4372], [9456], [8204], [6695]], + [[9369], [4525], [8935], [9210]], [[6897], [3218], [2013], [1241]], + [[9038], [7951], [5953], [8657]], [[9407], [6715], [6949], [8094]], + [[9662], [9217], [1129], [8487]], [[9369], [4525], [8935], [9210]], + [[9038], [7951], [5953], [8657]], [[9662], [9217], [1129], [8487]], + [[9369], [4525], [8935], [9210]], [[1719], [5986], [9919], [3421]], + [[4372], [9456], [8204], [6695]], [[9038], [7951], [5953], [8657]] + ] + self.out_seq = np.array(self.out_seq) + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_metrics.py b/python/paddle/fluid/tests/unittests/test_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..ec27884cae2b0462951f6597b1b83e58d1c8af5d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_metrics.py @@ -0,0 +1,49 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle.fluid as fluid +from paddle.fluid.framework import Program, program_guard + + +class TestMetricsDetectionMap(unittest.TestCase): + def test_detection_map(self): + program = fluid.Program() + with program_guard(program): + detect_res = fluid.layers.data( + name='detect_res', + shape=[10, 6], + append_batch_size=False, + dtype='float32') + label = fluid.layers.data( + name='label', + shape=[10, 1], + append_batch_size=False, + dtype='float32') + box = fluid.layers.data( + name='bbox', + shape=[10, 4], + append_batch_size=False, + dtype='float32') + map_eval = fluid.metrics.DetectionMAP( + detect_res, label, box, class_num=21) + cur_map, accm_map = map_eval.get_map_var() + self.assertIsNotNone(cur_map) + self.assertIsNotNone(accm_map) + print(str(program)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sequence_reverse.py b/python/paddle/fluid/tests/unittests/test_sequence_reverse.py new file mode 100644 index 0000000000000000000000000000000000000000..eebd25e0975f1711ea86093f007212cadc6334f5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sequence_reverse.py @@ -0,0 +1,69 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle.fluid as fluid +import paddle.fluid.core as core +from op_test import OpTest +import numpy as np + + +class TestSequenceReverseBase(OpTest): + def initParameters(self): + pass + + def setUp(self): + self.size = (10, 3, 4) + self.lod = [2, 3, 5] + self.dtype = 'float32' + self.initParameters() + self.op_type = 'sequence_reverse' + self.x = np.random.random(self.size).astype(self.dtype) + self.y = self.get_output() + + self.inputs = {'X': (self.x, [self.lod, ]), } + self.outputs = {'Y': (self.y, [self.lod, ]), } + + def get_output(self): + tmp_x = np.reshape(self.x, newshape=[self.x.shape[0], -1]) + tmp_y = np.ndarray(tmp_x.shape).astype(self.dtype) + prev_idx = 0 + for cur_len in self.lod: + idx_range = range(prev_idx, prev_idx + cur_len) + tmp_y[idx_range, :] = np.flip(tmp_x[idx_range, :], 0) + prev_idx += cur_len + + return np.reshape(tmp_y, newshape=self.x.shape).astype(self.dtype) + + def test_output(self): + self.check_output(0) + + def test_grad(self): + self.check_grad(['X'], 'Y') + + +class TestSequenceReserve1(TestSequenceReverseBase): + def initParameters(self): + self.size = (12, 10) + self.lod = [4, 5, 3] + + +class TestSequenceReverse2(TestSequenceReverseBase): + def initParameters(self): + self.size = (12, 10) + self.lod = [12] + + +if __name__ == '__main__': + unittest.main()