Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
5022ee63
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5022ee63
编写于
12月 28, 2017
作者:
Y
Yancey
提交者:
GitHub
12月 28, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ThreadPool::Run interface return std::future (#7099)
* Run interface return future * delete unused comments
上级
18311767
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
25 addition
and
13 deletion
+25
-13
paddle/framework/threadpool.h
paddle/framework/threadpool.h
+13
-6
paddle/framework/threadpool_test.cc
paddle/framework/threadpool_test.cc
+12
-7
未找到文件。
paddle/framework/threadpool.h
浏览文件 @
5022ee63
...
@@ -16,6 +16,7 @@ limitations under the License. */
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <condition_variable>
#include <condition_variable>
#include <functional>
#include <functional>
#include <future>
#include <mutex>
#include <mutex>
#include <queue>
#include <queue>
#include <thread>
#include <thread>
...
@@ -25,10 +26,11 @@ limitations under the License. */
...
@@ -25,10 +26,11 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
typedef
std
::
function
<
void
()
>
Task
;
class
ThreadPool
{
class
ThreadPool
{
public:
public:
typedef
std
::
packaged_task
<
void
()
>
Task
;
typedef
std
::
function
<
void
()
>
Fun
;
/**
/**
* @brief Get a instance of threadpool, the thread number will
* @brief Get a instance of threadpool, the thread number will
* be specified as the number of hardware thread contexts
* be specified as the number of hardware thread contexts
...
@@ -61,13 +63,18 @@ class ThreadPool {
...
@@ -61,13 +63,18 @@ class ThreadPool {
/**
/**
* @brief Push a function to the queue, and will be scheduled and
* @brief Push a function to the queue, and will be scheduled and
* executed if a thread is available.
* executed if a thread is available.
* @param[in] Task will be pushed to the task queue.
* @param[in] Task, will be pushed to the task queue.
* @return std::future<void>, we could wait for the task finished by
* f.wait().
*/
*/
void
Run
(
const
Task
&
fn
)
{
std
::
future
<
void
>
Run
(
const
Fun
&
fn
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
tasks_
.
push
(
fn
);
Task
task
(
std
::
bind
(
fn
));
std
::
future
<
void
>
f
=
task
.
get_future
();
tasks_
.
push
(
std
::
move
(
task
));
lock
.
unlock
();
lock
.
unlock
();
scheduled_
.
notify_one
();
scheduled_
.
notify_one
();
return
f
;
}
}
/**
/**
...
@@ -110,7 +117,7 @@ class ThreadPool {
...
@@ -110,7 +117,7 @@ class ThreadPool {
break
;
break
;
}
}
// pop a task from the task queue
// pop a task from the task queue
auto
task
=
tasks_
.
front
(
);
auto
task
=
std
::
move
(
tasks_
.
front
()
);
tasks_
.
pop
();
tasks_
.
pop
();
--
available_
;
--
available_
;
...
...
paddle/framework/threadpool_test.cc
浏览文件 @
5022ee63
...
@@ -20,16 +20,21 @@ limitations under the License. */
...
@@ -20,16 +20,21 @@ limitations under the License. */
namespace
framework
=
paddle
::
framework
;
namespace
framework
=
paddle
::
framework
;
void
do_sum
(
framework
::
ThreadPool
*
pool
,
std
::
atomic
<
int
>&
sum
,
int
cnt
)
{
void
do_sum
(
framework
::
ThreadPool
*
pool
,
std
::
atomic
<
int
>&
sum
,
int
cnt
)
{
std
::
vector
<
std
::
future
<
void
>>
fs
;
for
(
int
i
=
0
;
i
<
cnt
;
++
i
)
{
for
(
int
i
=
0
;
i
<
cnt
;
++
i
)
{
pool
->
Run
([
&
sum
]()
{
sum
.
fetch_add
(
1
);
});
auto
f
=
pool
->
Run
([
&
sum
]()
{
sum
.
fetch_add
(
1
);
});
fs
.
push_back
(
std
::
move
(
f
));
}
for
(
auto
&
f
:
fs
)
{
f
.
wait
();
}
}
}
}
TEST
(
ThreadPool
,
ConcurrentInit
)
{
TEST
(
ThreadPool
,
ConcurrentInit
)
{
framework
::
ThreadPool
*
pool
;
framework
::
ThreadPool
*
pool
;
int
concurrent_cnt
=
50
;
int
n
=
50
;
std
::
vector
<
std
::
thread
>
threads
;
std
::
vector
<
std
::
thread
>
threads
;
for
(
int
i
=
0
;
i
<
concurrent_cnt
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
std
::
thread
t
([
&
pool
]()
{
pool
=
framework
::
ThreadPool
::
GetInstance
();
});
std
::
thread
t
([
&
pool
]()
{
pool
=
framework
::
ThreadPool
::
GetInstance
();
});
threads
.
push_back
(
std
::
move
(
t
));
threads
.
push_back
(
std
::
move
(
t
));
}
}
...
@@ -38,13 +43,13 @@ TEST(ThreadPool, ConcurrentInit) {
...
@@ -38,13 +43,13 @@ TEST(ThreadPool, ConcurrentInit) {
}
}
}
}
TEST
(
ThreadPool
,
Concurrent
Start
)
{
TEST
(
ThreadPool
,
Concurrent
Run
)
{
framework
::
ThreadPool
*
pool
=
framework
::
ThreadPool
::
GetInstance
();
framework
::
ThreadPool
*
pool
=
framework
::
ThreadPool
::
GetInstance
();
std
::
atomic
<
int
>
sum
(
0
);
std
::
atomic
<
int
>
sum
(
0
);
std
::
vector
<
std
::
thread
>
threads
;
std
::
vector
<
std
::
thread
>
threads
;
int
concurrent_cnt
=
50
;
int
n
=
50
;
// sum = (n * (n + 1)) / 2
// sum = (n * (n + 1)) / 2
for
(
int
i
=
1
;
i
<=
concurrent_cnt
;
++
i
)
{
for
(
int
i
=
1
;
i
<=
n
;
++
i
)
{
std
::
thread
t
(
do_sum
,
pool
,
std
::
ref
(
sum
),
i
);
std
::
thread
t
(
do_sum
,
pool
,
std
::
ref
(
sum
),
i
);
threads
.
push_back
(
std
::
move
(
t
));
threads
.
push_back
(
std
::
move
(
t
));
}
}
...
@@ -52,5 +57,5 @@ TEST(ThreadPool, ConcurrentStart) {
...
@@ -52,5 +57,5 @@ TEST(ThreadPool, ConcurrentStart) {
t
.
join
();
t
.
join
();
}
}
pool
->
Wait
();
pool
->
Wait
();
EXPECT_EQ
(
sum
,
((
concurrent_cnt
+
1
)
*
concurrent_cnt
)
/
2
);
EXPECT_EQ
(
sum
,
((
n
+
1
)
*
n
)
/
2
);
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录