Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
41894da1
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
41894da1
编写于
3月 14, 2018
作者:
A
Abhinav Arora
提交者:
GitHub
3月 14, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add changes to channel that are needed for select op (#9084)
上级
a4b0e4a1
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
255 addition
and
7 deletion
+255
-7
paddle/fluid/framework/channel.h
paddle/fluid/framework/channel.h
+109
-0
paddle/fluid/framework/channel_impl.h
paddle/fluid/framework/channel_impl.h
+146
-7
未找到文件。
paddle/fluid/framework/channel.h
浏览文件 @
41894da1
...
...
@@ -15,23 +15,43 @@ limitations under the License. */
#pragma once
#include <stddef.h> // for size_t
#include <condition_variable>
#include <typeindex>
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
enum
class
ChannelAction
{
SEND
=
0
,
RECEIVE
=
1
,
CLOSE
=
2
,
};
// Channel is the abstract class of buffered and un-buffered channels.
template
<
typename
T
>
class
Channel
{
public:
virtual
bool
CanSend
()
=
0
;
virtual
bool
CanReceive
()
=
0
;
virtual
bool
Send
(
T
*
)
=
0
;
virtual
bool
Receive
(
T
*
)
=
0
;
virtual
size_t
Cap
()
=
0
;
virtual
void
Lock
()
=
0
;
virtual
void
Unlock
()
=
0
;
virtual
bool
IsClosed
()
=
0
;
virtual
void
Close
()
=
0
;
virtual
~
Channel
()
{}
virtual
void
AddToSendQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
)
=
0
;
virtual
void
AddToReceiveQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
)
=
0
;
virtual
void
RemoveFromSendQ
(
const
void
*
referrer
)
=
0
;
virtual
void
RemoveFromReceiveQ
(
const
void
*
referrer
)
=
0
;
};
// Forward declaration of channel implementations.
...
...
@@ -80,6 +100,27 @@ class ChannelHolder {
return
channel
!=
nullptr
?
channel
->
Receive
(
data
)
:
false
;
}
bool
IsClosed
()
{
if
(
IsInitialized
())
{
return
holder_
->
IsClosed
();
}
return
false
;
}
bool
CanSend
()
{
if
(
IsInitialized
())
{
return
holder_
->
CanSend
();
}
return
false
;
}
bool
CanReceive
()
{
if
(
IsInitialized
())
{
return
holder_
->
CanReceive
();
}
return
false
;
}
void
close
()
{
if
(
IsInitialized
())
holder_
->
Close
();
}
...
...
@@ -97,6 +138,50 @@ class ChannelHolder {
if
(
IsInitialized
())
holder_
->
Unlock
();
}
template
<
typename
T
>
void
AddToSendQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
)
{
if
(
IsInitialized
())
{
Channel
<
T
>*
channel
=
static_cast
<
Channel
<
T
>*>
(
holder_
->
Ptr
());
if
(
channel
!=
nullptr
)
{
channel
->
AddToSendQ
(
referrer
,
data
,
cond
,
cb
);
}
}
}
template
<
typename
T
>
void
AddToReceiveQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
)
{
if
(
IsInitialized
())
{
Channel
<
T
>*
channel
=
static_cast
<
Channel
<
T
>*>
(
holder_
->
Ptr
());
if
(
channel
!=
nullptr
)
{
channel
->
AddToReceiveQ
(
referrer
,
data
,
cond
,
cb
);
}
}
}
template
<
typename
T
>
void
RemoveFromSendQ
(
const
void
*
referrer
)
{
if
(
IsInitialized
())
{
Channel
<
T
>*
channel
=
static_cast
<
Channel
<
T
>*>
(
holder_
->
Ptr
());
if
(
channel
!=
nullptr
)
{
channel
->
RemoveFromSendQ
(
referrer
);
}
}
}
template
<
typename
T
>
void
RemoveFromReceiveQ
(
const
void
*
referrer
)
{
if
(
IsInitialized
())
{
Channel
<
T
>*
channel
=
static_cast
<
Channel
<
T
>*>
(
holder_
->
Ptr
());
if
(
channel
!=
nullptr
)
{
channel
->
RemoveFromReceiveQ
(
referrer
);
}
}
}
inline
bool
IsInitialized
()
const
{
return
holder_
!=
nullptr
;
}
inline
const
std
::
type_index
Type
()
{
...
...
@@ -113,6 +198,9 @@ class ChannelHolder {
virtual
~
Placeholder
()
{}
virtual
const
std
::
type_index
Type
()
const
=
0
;
virtual
void
*
Ptr
()
const
=
0
;
virtual
bool
IsClosed
()
=
0
;
virtual
bool
CanSend
()
=
0
;
virtual
bool
CanReceive
()
=
0
;
virtual
void
Close
()
=
0
;
virtual
void
Lock
()
=
0
;
virtual
void
Unlock
()
=
0
;
...
...
@@ -129,6 +217,27 @@ class ChannelHolder {
virtual
void
*
Ptr
()
const
{
return
static_cast
<
void
*>
(
channel_
.
get
());
}
virtual
bool
IsClosed
()
{
if
(
channel_
)
{
return
channel_
->
IsClosed
();
}
return
false
;
}
virtual
bool
CanSend
()
{
if
(
channel_
)
{
return
channel_
->
CanSend
();
}
return
false
;
}
virtual
bool
CanReceive
()
{
if
(
channel_
)
{
return
channel_
->
CanReceive
();
}
return
false
;
}
virtual
void
Close
()
{
if
(
channel_
)
channel_
->
Close
();
}
...
...
paddle/fluid/framework/channel_impl.h
浏览文件 @
41894da1
...
...
@@ -29,32 +29,50 @@ class ChannelImpl : public paddle::framework::Channel<T> {
friend
void
paddle
::
framework
::
CloseChannel
<
T
>
(
Channel
<
T
>
*
);
public:
virtual
bool
CanSend
();
virtual
bool
CanReceive
();
virtual
bool
Send
(
T
*
);
virtual
bool
Receive
(
T
*
);
virtual
size_t
Cap
()
{
return
cap_
;
}
virtual
void
Lock
();
virtual
void
Unlock
();
virtual
bool
IsClosed
();
virtual
void
Close
();
ChannelImpl
(
size_t
);
virtual
~
ChannelImpl
();
virtual
void
AddToSendQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
);
virtual
void
AddToReceiveQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
);
virtual
void
RemoveFromSendQ
(
const
void
*
referrer
);
virtual
void
RemoveFromReceiveQ
(
const
void
*
referrer
);
private:
struct
QueueMessage
{
T
*
data
;
std
::
condition_variable_any
cond
;
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
;
bool
chan_closed
=
false
;
bool
completed
=
false
;
const
void
*
referrer
;
// TODO(thuan): figure out better way to do this
std
::
function
<
bool
(
ChannelAction
)
>
callback
;
QueueMessage
(
T
*
item
)
:
data
(
item
)
{}
QueueMessage
(
T
*
item
)
:
data
(
item
),
cond
(
std
::
make_shared
<
std
::
condition_variable_any
>
())
{}
QueueMessage
(
T
*
item
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
)
:
data
(
item
),
cond
(
cond
)
{}
void
Wait
(
std
::
unique_lock
<
std
::
recursive_mutex
>
&
lock
)
{
cond
.
wait
(
lock
,
[
this
]()
{
return
completed
;
});
cond
->
wait
(
lock
,
[
this
]()
{
return
completed
;
});
}
void
Notify
()
{
completed
=
true
;
cond
.
notify_all
();
cond
->
notify_all
();
}
};
...
...
@@ -87,6 +105,18 @@ ChannelImpl<T>::ChannelImpl(size_t capacity)
PADDLE_ENFORCE_GE
(
capacity
,
0
);
}
template
<
typename
T
>
bool
ChannelImpl
<
T
>::
CanSend
()
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
mu_
};
return
!
closed_
&&
(
!
recvq
.
empty
()
||
buf_
.
size
()
<
cap_
);
}
template
<
typename
T
>
bool
ChannelImpl
<
T
>::
CanReceive
()
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
mu_
};
return
!
(
closed_
&&
buf_
.
empty
())
&&
(
!
sendq
.
empty
()
||
buf_
.
size
()
>
0
);
}
template
<
typename
T
>
bool
ChannelImpl
<
T
>::
Send
(
T
*
item
)
{
send_ctr
++
;
...
...
@@ -105,7 +135,24 @@ bool ChannelImpl<T>::Send(T *item) {
std
::
shared_ptr
<
QueueMessage
>
m
=
recvq
.
front
();
recvq
.
pop_front
();
// Do the data transfer
*
(
m
->
data
)
=
std
::
move
(
*
item
);
// We will do this data transfer if either of the following
// cases are true
// 1. callback == nullptr // This means it was a regular channel send
// 2. callback returns true
bool
do_send
=
true
;
if
(
m
->
callback
!=
nullptr
)
do_send
=
m
->
callback
(
ChannelAction
::
SEND
);
if
(
do_send
)
*
(
m
->
data
)
=
std
::
move
(
*
item
);
else
// We cannot do the data transfer because
// this QueueMessage was added by Select
// and some other case was executed.
// So call the Send function again.
// We do not care about notifying other
// because they would have been notified
// by the executed select case.
return
Send
(
item
);
// Wake up the blocked process and unlock
m
->
Notify
();
lock
.
unlock
();
...
...
@@ -150,7 +197,25 @@ bool ChannelImpl<T>::Receive(T *item) {
std
::
shared_ptr
<
QueueMessage
>
m
=
sendq
.
front
();
sendq
.
pop_front
();
// Do the data transfer
*
item
=
std
::
move
(
*
(
m
->
data
));
// We will do this data transfer if either of the following
// cases are true
// 1. callback == nullptr // This means it was a regular channel send
// 2. callback returns true
bool
do_receive
=
true
;
if
(
m
->
callback
!=
nullptr
)
do_receive
=
m
->
callback
(
ChannelAction
::
RECEIVE
);
if
(
do_receive
)
*
item
=
std
::
move
(
*
(
m
->
data
));
else
// We cannot do the data transfer because
// this QueueMessage was added by Select
// and some other case was executed.
// So call the Receive function again.
// We do not care about notifying other
// because they would have been notified
// by the executed select case.
return
Receive
(
item
);
// Wake up the blocked process and unlock
m
->
Notify
();
lock
.
unlock
();
...
...
@@ -186,6 +251,12 @@ void ChannelImpl<T>::Unlock() {
mu_
.
unlock
();
}
template
<
typename
T
>
bool
ChannelImpl
<
T
>::
IsClosed
()
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
mu_
};
return
closed_
;
}
template
<
typename
T
>
void
ChannelImpl
<
T
>::
Close
()
{
std
::
unique_lock
<
std
::
recursive_mutex
>
lock
{
mu_
};
...
...
@@ -203,6 +274,12 @@ void ChannelImpl<T>::Close() {
std
::
shared_ptr
<
QueueMessage
>
m
=
recvq
.
front
();
recvq
.
pop_front
();
m
->
chan_closed
=
true
;
// Execute callback function (if any)
if
(
m
->
callback
!=
nullptr
)
{
m
->
callback
(
ChannelAction
::
CLOSE
);
}
m
->
Notify
();
}
...
...
@@ -211,10 +288,72 @@ void ChannelImpl<T>::Close() {
std
::
shared_ptr
<
QueueMessage
>
m
=
sendq
.
front
();
sendq
.
pop_front
();
m
->
chan_closed
=
true
;
// Execute callback function (if any)
if
(
m
->
callback
!=
nullptr
)
{
m
->
callback
(
ChannelAction
::
CLOSE
);
}
m
->
Notify
();
}
}
template
<
typename
T
>
void
ChannelImpl
<
T
>::
AddToSendQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
mu_
};
auto
m
=
std
::
make_shared
<
QueueMessage
>
(
data
,
cond
);
m
->
referrer
=
referrer
;
m
->
callback
=
cb
;
sendq
.
push_back
(
m
);
}
template
<
typename
T
>
void
ChannelImpl
<
T
>::
AddToReceiveQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
mu_
};
auto
m
=
std
::
make_shared
<
QueueMessage
>
(
data
,
cond
);
m
->
referrer
=
referrer
;
m
->
callback
=
cb
;
recvq
.
push_back
(
m
);
}
template
<
typename
T
>
void
ChannelImpl
<
T
>::
RemoveFromSendQ
(
const
void
*
referrer
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
mu_
};
for
(
auto
it
=
sendq
.
begin
();
it
!=
sendq
.
end
();)
{
std
::
shared_ptr
<
QueueMessage
>
sendMsg
=
(
std
::
shared_ptr
<
QueueMessage
>
)
*
it
;
if
(
sendMsg
->
referrer
==
referrer
)
{
it
=
sendq
.
erase
(
it
);
send_ctr
--
;
}
else
{
++
it
;
}
}
}
template
<
typename
T
>
void
ChannelImpl
<
T
>::
RemoveFromReceiveQ
(
const
void
*
referrer
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
mu_
};
for
(
auto
it
=
recvq
.
begin
();
it
!=
recvq
.
end
();)
{
std
::
shared_ptr
<
QueueMessage
>
recvMsg
=
(
std
::
shared_ptr
<
QueueMessage
>
)
*
it
;
if
(
recvMsg
->
referrer
==
referrer
)
{
it
=
recvq
.
erase
(
it
);
recv_ctr
--
;
}
else
{
++
it
;
}
}
}
template
<
typename
T
>
ChannelImpl
<
T
>::~
ChannelImpl
()
{
Close
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录