Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b60da672
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b60da672
编写于
2月 03, 2018
作者:
C
chengduo
提交者:
Abhinav Arora
2月 03, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine buffer channel (#8098)
* refine buffer channel * refine Receive and Send * follow comments
上级
022e5dee
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
23 addition
and
20 deletion
+23
-20
paddle/framework/channel.h
paddle/framework/channel.h
+2
-2
paddle/framework/details/buffered_channel.h
paddle/framework/details/buffered_channel.h
+11
-14
paddle/framework/details/unbuffered_channel.h
paddle/framework/details/unbuffered_channel.h
+10
-4
未找到文件。
paddle/framework/channel.h
浏览文件 @
b60da672
...
@@ -23,8 +23,8 @@ namespace framework {
...
@@ -23,8 +23,8 @@ namespace framework {
template
<
typename
T
>
template
<
typename
T
>
class
Channel
{
class
Channel
{
public:
public:
virtual
void
Send
(
T
*
)
=
0
;
virtual
bool
Send
(
T
*
)
=
0
;
virtual
void
Receive
(
T
*
)
=
0
;
virtual
bool
Receive
(
T
*
)
=
0
;
virtual
size_t
Cap
()
=
0
;
virtual
size_t
Cap
()
=
0
;
virtual
void
Close
()
=
0
;
virtual
void
Close
()
=
0
;
virtual
~
Channel
()
{}
virtual
~
Channel
()
{}
...
...
paddle/framework/details/buffered_channel.h
浏览文件 @
b60da672
...
@@ -30,8 +30,8 @@ class Buffered : public paddle::framework::Channel<T> {
...
@@ -30,8 +30,8 @@ class Buffered : public paddle::framework::Channel<T> {
friend
void
paddle
::
framework
::
CloseChannel
<
T
>
(
Channel
<
T
>*
);
friend
void
paddle
::
framework
::
CloseChannel
<
T
>
(
Channel
<
T
>*
);
public:
public:
virtual
void
Send
(
T
*
);
virtual
bool
Send
(
T
*
);
virtual
void
Receive
(
T
*
);
virtual
bool
Receive
(
T
*
);
virtual
size_t
Cap
()
{
return
cap_
;
}
virtual
size_t
Cap
()
{
return
cap_
;
}
virtual
void
Close
();
virtual
void
Close
();
virtual
~
Buffered
();
virtual
~
Buffered
();
...
@@ -48,33 +48,36 @@ class Buffered : public paddle::framework::Channel<T> {
...
@@ -48,33 +48,36 @@ class Buffered : public paddle::framework::Channel<T> {
PADDLE_ENFORCE_GT
(
cap
,
0
);
PADDLE_ENFORCE_GT
(
cap
,
0
);
}
}
void
NotifyAllSenders
(
std
::
unique_lock
<
std
::
mutex
>*
);
void
NotifyAllParticipants
(
std
::
unique_lock
<
std
::
mutex
>*
);
void
NotifyAllParticipants
(
std
::
unique_lock
<
std
::
mutex
>*
);
};
};
template
<
typename
T
>
template
<
typename
T
>
void
Buffered
<
T
>::
Send
(
T
*
item
)
{
bool
Buffered
<
T
>::
Send
(
T
*
item
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mu_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mu_
);
full_cond_var_
.
wait
(
lock
,
full_cond_var_
.
wait
(
lock
,
[
this
]()
{
return
channel_
.
size
()
<
cap_
||
closed_
;
});
[
this
]()
{
return
channel_
.
size
()
<
cap_
||
closed_
;
});
bool
ret
=
false
;
if
(
!
closed_
)
{
if
(
!
closed_
)
{
channel_
.
push_back
(
std
::
move
(
*
item
));
channel_
.
push_back
(
std
::
move
(
*
item
));
lock
.
unlock
();
lock
.
unlock
();
empty_cond_var_
.
notify_one
();
empty_cond_var_
.
notify_one
();
ret
=
true
;
}
}
return
ret
;
}
}
template
<
typename
T
>
template
<
typename
T
>
void
Buffered
<
T
>::
Receive
(
T
*
item
)
{
bool
Buffered
<
T
>::
Receive
(
T
*
item
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mu_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mu_
);
empty_cond_var_
.
wait
(
lock
,
[
this
]()
{
return
!
channel_
.
empty
()
||
closed_
;
});
empty_cond_var_
.
wait
(
lock
,
[
this
]()
{
return
!
channel_
.
empty
()
||
closed_
;
});
bool
ret
=
false
;
if
(
!
closed_
)
{
if
(
!
closed_
)
{
*
item
=
std
::
move
(
channel_
.
front
());
*
item
=
std
::
move
(
channel_
.
front
());
channel_
.
pop_front
();
channel_
.
pop_front
();
NotifyAllSenders
(
&
lock
);
full_cond_var_
.
notify_one
();
}
else
{
ret
=
true
;
item
=
nullptr
;
}
}
return
ret
;
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -92,12 +95,6 @@ Buffered<T>::~Buffered() {
...
@@ -92,12 +95,6 @@ Buffered<T>::~Buffered() {
NotifyAllParticipants
(
&
lock
);
NotifyAllParticipants
(
&
lock
);
}
}
template
<
typename
T
>
void
Buffered
<
T
>::
NotifyAllSenders
(
std
::
unique_lock
<
std
::
mutex
>*
lock
)
{
lock
->
unlock
();
full_cond_var_
.
notify_all
();
}
template
<
typename
T
>
template
<
typename
T
>
void
Buffered
<
T
>::
NotifyAllParticipants
(
std
::
unique_lock
<
std
::
mutex
>*
lock
)
{
void
Buffered
<
T
>::
NotifyAllParticipants
(
std
::
unique_lock
<
std
::
mutex
>*
lock
)
{
lock
->
unlock
();
lock
->
unlock
();
...
...
paddle/framework/details/unbuffered_channel.h
浏览文件 @
b60da672
...
@@ -29,8 +29,8 @@ class UnBuffered : public paddle::framework::Channel<T> {
...
@@ -29,8 +29,8 @@ class UnBuffered : public paddle::framework::Channel<T> {
friend
void
paddle
::
framework
::
CloseChannel
<
T
>
(
Channel
<
T
>*
);
friend
void
paddle
::
framework
::
CloseChannel
<
T
>
(
Channel
<
T
>*
);
public:
public:
virtual
void
Send
(
T
*
);
virtual
bool
Send
(
T
*
);
virtual
void
Receive
(
T
*
);
virtual
bool
Receive
(
T
*
);
virtual
size_t
Cap
()
{
return
0
;
}
virtual
size_t
Cap
()
{
return
0
;
}
virtual
void
Close
();
virtual
void
Close
();
virtual
~
UnBuffered
();
virtual
~
UnBuffered
();
...
@@ -57,7 +57,7 @@ class UnBuffered : public paddle::framework::Channel<T> {
...
@@ -57,7 +57,7 @@ class UnBuffered : public paddle::framework::Channel<T> {
// This function implements the concept of how data should
// This function implements the concept of how data should
// be sent from a writer to a reader.
// be sent from a writer to a reader.
template
<
typename
T
>
template
<
typename
T
>
void
UnBuffered
<
T
>::
Send
(
T
*
data
)
{
bool
UnBuffered
<
T
>::
Send
(
T
*
data
)
{
// Prevent other writers from entering
// Prevent other writers from entering
std
::
unique_lock
<
std
::
recursive_mutex
>
writer_lock
(
mu_write_
);
std
::
unique_lock
<
std
::
recursive_mutex
>
writer_lock
(
mu_write_
);
writer_found_
=
true
;
writer_found_
=
true
;
...
@@ -66,6 +66,7 @@ void UnBuffered<T>::Send(T* data) {
...
@@ -66,6 +66,7 @@ void UnBuffered<T>::Send(T* data) {
cv_writer_
.
wait
(
cv_lock
,
cv_writer_
.
wait
(
cv_lock
,
[
this
]()
{
return
reader_found_
==
true
||
closed_
;
});
[
this
]()
{
return
reader_found_
==
true
||
closed_
;
});
cv_reader_
.
notify_one
();
cv_reader_
.
notify_one
();
bool
ret
=
false
;
if
(
!
closed_
)
{
if
(
!
closed_
)
{
std
::
unique_lock
<
std
::
mutex
>
channel_lock
(
mu_ch_
);
std
::
unique_lock
<
std
::
mutex
>
channel_lock
(
mu_ch_
);
item
=
data
;
item
=
data
;
...
@@ -74,14 +75,16 @@ void UnBuffered<T>::Send(T* data) {
...
@@ -74,14 +75,16 @@ void UnBuffered<T>::Send(T* data) {
channel_lock
.
lock
();
channel_lock
.
lock
();
cv_channel_
.
wait
(
channel_lock
,
cv_channel_
.
wait
(
channel_lock
,
[
this
]()
{
return
item
==
nullptr
||
closed_
;
});
[
this
]()
{
return
item
==
nullptr
||
closed_
;
});
ret
=
true
;
}
}
writer_found_
=
false
;
writer_found_
=
false
;
return
ret
;
}
}
// This function implements the concept of how
// This function implements the concept of how
// data that was sent by a writer is read from a reader.
// data that was sent by a writer is read from a reader.
template
<
typename
T
>
template
<
typename
T
>
void
UnBuffered
<
T
>::
Receive
(
T
*
data
)
{
bool
UnBuffered
<
T
>::
Receive
(
T
*
data
)
{
// Prevent other readers from entering
// Prevent other readers from entering
std
::
unique_lock
<
std
::
recursive_mutex
>
read_lock
{
mu_read_
};
std
::
unique_lock
<
std
::
recursive_mutex
>
read_lock
{
mu_read_
};
reader_found_
=
true
;
reader_found_
=
true
;
...
@@ -90,6 +93,7 @@ void UnBuffered<T>::Receive(T* data) {
...
@@ -90,6 +93,7 @@ void UnBuffered<T>::Receive(T* data) {
cv_reader_
.
wait
(
cv_lock
,
cv_reader_
.
wait
(
cv_lock
,
[
this
]()
{
return
writer_found_
==
true
||
closed_
;
});
[
this
]()
{
return
writer_found_
==
true
||
closed_
;
});
cv_writer_
.
notify_one
();
cv_writer_
.
notify_one
();
bool
ret
=
false
;
if
(
!
closed_
)
{
if
(
!
closed_
)
{
std
::
unique_lock
<
std
::
mutex
>
lock_ch
{
mu_ch_
};
std
::
unique_lock
<
std
::
mutex
>
lock_ch
{
mu_ch_
};
// Reader should wait for the writer to first write its data
// Reader should wait for the writer to first write its data
...
@@ -98,10 +102,12 @@ void UnBuffered<T>::Receive(T* data) {
...
@@ -98,10 +102,12 @@ void UnBuffered<T>::Receive(T* data) {
*
data
=
std
::
move
(
*
item
);
*
data
=
std
::
move
(
*
item
);
item
=
nullptr
;
item
=
nullptr
;
lock_ch
.
unlock
();
lock_ch
.
unlock
();
ret
=
true
;
}
}
cv_channel_
.
notify_one
();
cv_channel_
.
notify_one
();
}
}
reader_found_
=
false
;
reader_found_
=
false
;
return
ret
;
}
}
// This function implements the sequence of events
// This function implements the sequence of events
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录