diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index b4458eb9551724021636b628c5bf8c96f6e659aa..fb8c9ab96d372bde1fb4e1d86488cd5b831b93e0 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -26,7 +26,10 @@ nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) cc_test(variable_test SRCS variable_test.cc) -cc_library(scope SRCS scope.cc DEPS glog) +cc_library(threadpool SRCS threadpool.cc) +cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool) + +cc_library(scope SRCS scope.cc DEPS glog threadpool) cc_test(scope_test SRCS scope_test.cc DEPS scope) cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto) @@ -70,8 +73,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) -cc_library(threadpool SRCS threadpool.cc) -cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool) + cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece) cc_test(init_test SRCS init_test.cc DEPS init) diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index 0c01d605bcd95f5796fba1e5a3351a2640b2898a..4e80e3d974e2b646ad62d26991e4629f8c450578 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include // for unique_ptr #include // for call_once #include "glog/logging.h" +#include "paddle/framework/threadpool.h" #include "paddle/string/printf.h" namespace paddle { @@ -87,7 +88,8 @@ void Scope::DeleteScope(Scope* scope) { auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope); this->kids_.erase(it); - delete scope; + // Make delete async. + Async([scope] { delete scope; }); } void Scope::Rename(const std::string& origin_name, diff --git a/paddle/framework/threadpool.h b/paddle/framework/threadpool.h index bcd8190755083ec30687675602a1c95a9c15c69e..3ac345851c38557f82698786dd3bc8e1202a4256 100644 --- a/paddle/framework/threadpool.h +++ b/paddle/framework/threadpool.h @@ -29,7 +29,6 @@ namespace framework { class ThreadPool { public: typedef std::packaged_task Task; - typedef std::function Fun; /** * @brief Get a instance of threadpool, the thread number will @@ -67,7 +66,8 @@ class ThreadPool { * @return std::future, we could wait for the task finished by * f.wait(). */ - std::future Run(const Fun& fn) { + template + std::future Run(Callback fn) { std::unique_lock lock(mutex_); Task task(std::bind(fn)); std::future f = task.get_future(); @@ -159,5 +159,13 @@ class ThreadPool { std::condition_variable completed_; }; +// Run a function asynchronously. +// NOTE: The function must return void. If the function need to return a value, +// you can use lambda to capture a value pointer. +template +std::future Async(Callback callback) { + return ThreadPool::GetInstance()->Run(callback); +} + } // namespace framework } // namespace paddle