未验证 提交 564dba17 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #7196 from reyoung/feature/async_drop_kid

Async to drop kid
......@@ -26,7 +26,10 @@ nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_test(variable_test SRCS variable_test.cc)
cc_library(scope SRCS scope.cc DEPS glog)
cc_library(threadpool SRCS threadpool.cc)
cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(scope SRCS scope.cc DEPS glog threadpool)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto)
......@@ -70,8 +73,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_library(threadpool SRCS threadpool.cc)
cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
cc_test(init_test SRCS init_test.cc DEPS init)
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <memory> // for unique_ptr
#include <mutex> // for call_once
#include "glog/logging.h"
#include "paddle/framework/threadpool.h"
#include "paddle/string/printf.h"
namespace paddle {
......@@ -87,7 +88,8 @@ void Scope::DeleteScope(Scope* scope) {
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
this->kids_.erase(it);
delete scope;
// Make delete async.
Async([scope] { delete scope; });
}
void Scope::Rename(const std::string& origin_name,
......
......@@ -29,7 +29,6 @@ namespace framework {
class ThreadPool {
public:
typedef std::packaged_task<void()> Task;
typedef std::function<void()> Fun;
/**
* @brief Get a instance of threadpool, the thread number will
......@@ -67,7 +66,8 @@ class ThreadPool {
* @return std::future<void>, we could wait for the task finished by
* f.wait().
*/
std::future<void> Run(const Fun& fn) {
template <typename Callback>
std::future<void> Run(Callback fn) {
std::unique_lock<std::mutex> lock(mutex_);
Task task(std::bind(fn));
std::future<void> f = task.get_future();
......@@ -159,5 +159,13 @@ class ThreadPool {
std::condition_variable completed_;
};
// Run a function asynchronously.
// NOTE: The function must return void. If the function need to return a value,
// you can use lambda to capture a value pointer.
template <typename Callback>
std::future<void> Async(Callback callback) {
return ThreadPool::GetInstance()->Run(callback);
}
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册