Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
425a8821
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
425a8821
编写于
9月 29, 2018
作者:
X
Xin Pan
提交者:
GitHub
9月 29, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13643 from panyx0718/ir2
clean up channel
上级
23644940
64290595
变更
31
显示空白变更内容
内联
并排
Showing
31 changed file
with
11 addition
and
3742 deletion
+11
-3742
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+0
-7
paddle/fluid/framework/channel.h
paddle/fluid/framework/channel.h
+0
-291
paddle/fluid/framework/channel_impl.h
paddle/fluid/framework/channel_impl.h
+0
-369
paddle/fluid/framework/channel_test.cc
paddle/fluid/framework/channel_test.cc
+0
-1008
paddle/fluid/framework/concurrency_test.cc
paddle/fluid/framework/concurrency_test.cc
+0
-292
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+1
-4
paddle/fluid/framework/framework.proto
paddle/fluid/framework/framework.proto
+0
-7
paddle/fluid/framework/naive_executor.cc
paddle/fluid/framework/naive_executor.cc
+4
-4
paddle/fluid/framework/tuple.h
paddle/fluid/framework/tuple.h
+0
-1
paddle/fluid/framework/var_desc.cc
paddle/fluid/framework/var_desc.cc
+2
-52
paddle/fluid/framework/var_desc.h
paddle/fluid/framework/var_desc.h
+0
-4
paddle/fluid/framework/var_type.h
paddle/fluid/framework/var_type.h
+0
-6
paddle/fluid/inference/analysis/analysis_pass.h
paddle/fluid/inference/analysis/analysis_pass.h
+0
-6
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+0
-5
paddle/fluid/operators/channel_close_op.cc
paddle/fluid/operators/channel_close_op.cc
+0
-70
paddle/fluid/operators/channel_create_op.cc
paddle/fluid/operators/channel_create_op.cc
+0
-113
paddle/fluid/operators/channel_recv_op.cc
paddle/fluid/operators/channel_recv_op.cc
+0
-98
paddle/fluid/operators/channel_send_op.cc
paddle/fluid/operators/channel_send_op.cc
+0
-76
paddle/fluid/operators/concurrency/CMakeLists.txt
paddle/fluid/operators/concurrency/CMakeLists.txt
+0
-1
paddle/fluid/operators/concurrency/channel_util.cc
paddle/fluid/operators/concurrency/channel_util.cc
+0
-111
paddle/fluid/operators/concurrency/channel_util.h
paddle/fluid/operators/concurrency/channel_util.h
+0
-38
paddle/fluid/operators/distributed/grpc_client.h
paddle/fluid/operators/distributed/grpc_client.h
+1
-0
paddle/fluid/operators/distributed/request_handler.h
paddle/fluid/operators/distributed/request_handler.h
+1
-0
paddle/fluid/operators/distributed/rpc_server.h
paddle/fluid/operators/distributed/rpc_server.h
+1
-0
paddle/fluid/operators/select_op.cc
paddle/fluid/operators/select_op.cc
+0
-419
paddle/fluid/pybind/protobuf.cc
paddle/fluid/pybind/protobuf.cc
+0
-2
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+0
-1
python/paddle/fluid/concurrency.py
python/paddle/fluid/concurrency.py
+0
-454
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+1
-2
python/paddle/fluid/tests/no_test_concurrency.py
python/paddle/fluid/tests/no_test_concurrency.py
+0
-260
python/paddle/fluid/tests/notest_concurrency.py
python/paddle/fluid/tests/notest_concurrency.py
+0
-41
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
425a8821
...
@@ -169,15 +169,8 @@ cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
...
@@ -169,15 +169,8 @@ cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_test
(
op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto
)
cc_test
(
op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto
)
cc_test
(
cow_ptr_tests SRCS details/cow_ptr_test.cc
)
cc_test
(
cow_ptr_tests SRCS details/cow_ptr_test.cc
)
# cc_test(channel_test SRCS channel_test.cc)
cc_test
(
tuple_test SRCS tuple_test.cc
)
cc_test
(
tuple_test SRCS tuple_test.cc
)
if
(
NOT WIN32
)
if
(
NOT WIN32
)
cc_test
(
rw_lock_test SRCS rw_lock_test.cc
)
cc_test
(
rw_lock_test SRCS rw_lock_test.cc
)
endif
(
NOT WIN32
)
endif
(
NOT WIN32
)
# disable test temporarily.
# TODO https://github.com/PaddlePaddle/Paddle/issues/11971
# cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op
# channel_send_op channel_recv_op sum_op select_op elementwise_add_op compare_op
# conditional_block_op while_op assign_op print_op executor proto_desc)
paddle/fluid/framework/channel.h
已删除
100644 → 0
浏览文件 @
23644940
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stddef.h> // for size_t
#include <condition_variable> // NOLINT
#include <typeindex>
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
enum
class
ChannelAction
{
SEND
=
0
,
RECEIVE
=
1
,
CLOSE
=
2
,
};
// Channel is the abstract class of buffered and un-buffered channels.
template
<
typename
T
>
class
Channel
{
public:
virtual
bool
CanSend
()
=
0
;
virtual
bool
CanReceive
()
=
0
;
virtual
void
Send
(
T
*
)
=
0
;
virtual
bool
Receive
(
T
*
)
=
0
;
virtual
size_t
Cap
()
=
0
;
virtual
void
Lock
()
=
0
;
virtual
void
Unlock
()
=
0
;
virtual
bool
IsClosed
()
=
0
;
virtual
void
Close
()
=
0
;
virtual
~
Channel
()
{}
virtual
void
AddToSendQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
)
=
0
;
virtual
void
AddToReceiveQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
)
=
0
;
virtual
void
RemoveFromSendQ
(
const
void
*
referrer
)
=
0
;
virtual
void
RemoveFromReceiveQ
(
const
void
*
referrer
)
=
0
;
};
// Forward declaration of channel implementations.
template
<
typename
T
>
class
ChannelImpl
;
template
<
typename
T
>
Channel
<
T
>*
MakeChannel
(
size_t
buffer_size
)
{
return
new
ChannelImpl
<
T
>
(
buffer_size
);
}
template
<
typename
T
>
void
CloseChannel
(
Channel
<
T
>*
ch
)
{
ch
->
Close
();
}
/*
* The ChannelHolder class serves two main purposes:
* 1. It acts as a unified wrapper for the different kinds of
* channels, i.e. Buffered and Unbuffered channels. This is
* similar to the ReaderHolder class.
* 2. It also helps us in TypeHiding. This is similar to the
* PlaceHolder implementations in variable.h and tensor.h.
*/
class
ChannelHolder
{
public:
template
<
typename
T
>
void
Reset
(
size_t
buffer_size
)
{
holder_
.
reset
(
new
PlaceholderImpl
<
T
>
(
buffer_size
));
}
template
<
typename
T
>
void
Send
(
T
*
data
)
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
PADDLE_ENFORCE_EQ
(
holder_
->
Type
(),
std
::
type_index
(
typeid
(
T
)),
"Channel type is not same as the type of the data being sent"
);
// Static cast should be safe because we have ensured that types are same
Channel
<
T
>*
channel
=
static_cast
<
Channel
<
T
>*>
(
holder_
->
Ptr
());
PADDLE_ENFORCE_EQ
(
channel
!=
nullptr
,
true
,
"Channel should not be null."
);
channel
->
Send
(
data
);
}
template
<
typename
T
>
bool
Receive
(
T
*
data
)
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
PADDLE_ENFORCE_EQ
(
holder_
->
Type
(),
std
::
type_index
(
typeid
(
T
)),
"Channel type is not same as the type of the data being sent"
);
Channel
<
T
>*
channel
=
static_cast
<
Channel
<
T
>*>
(
holder_
->
Ptr
());
PADDLE_ENFORCE_EQ
(
channel
!=
nullptr
,
true
,
"Channel should not be null."
);
return
channel
->
Receive
(
data
);
}
bool
IsClosed
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
return
holder_
->
IsClosed
();
}
bool
CanSend
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
return
holder_
->
CanSend
();
}
bool
CanReceive
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
return
holder_
->
CanReceive
();
}
void
close
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
holder_
->
Close
();
}
size_t
Cap
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
return
holder_
->
Cap
();
}
void
Lock
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
holder_
->
Lock
();
}
void
Unlock
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
holder_
->
Unlock
();
}
template
<
typename
T
>
void
AddToSendQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
)
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
Channel
<
T
>*
channel
=
static_cast
<
Channel
<
T
>*>
(
holder_
->
Ptr
());
if
(
channel
!=
nullptr
)
{
channel
->
AddToSendQ
(
referrer
,
data
,
cond
,
cb
);
}
}
template
<
typename
T
>
void
AddToReceiveQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
)
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
Channel
<
T
>*
channel
=
static_cast
<
Channel
<
T
>*>
(
holder_
->
Ptr
());
if
(
channel
!=
nullptr
)
{
channel
->
AddToReceiveQ
(
referrer
,
data
,
cond
,
cb
);
}
}
void
RemoveFromSendQ
(
const
void
*
referrer
)
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
holder_
->
RemoveFromSendQ
(
referrer
);
}
void
RemoveFromReceiveQ
(
const
void
*
referrer
)
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
holder_
->
RemoveFromReceiveQ
(
referrer
);
}
inline
bool
IsInitialized
()
const
{
return
holder_
!=
nullptr
;
}
inline
const
std
::
type_index
Type
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
return
holder_
->
Type
();
}
private:
/**
* @note Placeholder hides type T, so it doesn't appear as a template
* parameter of ChannelHolder.
*/
struct
Placeholder
{
virtual
~
Placeholder
()
{}
virtual
const
std
::
type_index
Type
()
const
=
0
;
virtual
void
*
Ptr
()
const
=
0
;
virtual
bool
IsClosed
()
=
0
;
virtual
bool
CanSend
()
=
0
;
virtual
bool
CanReceive
()
=
0
;
virtual
void
RemoveFromSendQ
(
const
void
*
referrer
)
=
0
;
virtual
void
RemoveFromReceiveQ
(
const
void
*
referrer
)
=
0
;
virtual
void
Close
()
=
0
;
virtual
void
Lock
()
=
0
;
virtual
void
Unlock
()
=
0
;
virtual
size_t
Cap
()
=
0
;
};
template
<
typename
T
>
struct
PlaceholderImpl
:
public
Placeholder
{
explicit
PlaceholderImpl
(
size_t
buffer_size
)
:
type_
(
std
::
type_index
(
typeid
(
T
)))
{
channel_
.
reset
(
MakeChannel
<
T
>
(
buffer_size
));
}
virtual
const
std
::
type_index
Type
()
const
{
return
type_
;
}
virtual
void
*
Ptr
()
const
{
return
static_cast
<
void
*>
(
channel_
.
get
());
}
virtual
bool
IsClosed
()
{
if
(
channel_
)
{
return
channel_
->
IsClosed
();
}
return
false
;
}
virtual
bool
CanSend
()
{
if
(
channel_
)
{
return
channel_
->
CanSend
();
}
return
false
;
}
virtual
bool
CanReceive
()
{
if
(
channel_
)
{
return
channel_
->
CanReceive
();
}
return
false
;
}
virtual
void
RemoveFromSendQ
(
const
void
*
referrer
)
{
if
(
channel_
)
{
channel_
->
RemoveFromSendQ
(
referrer
);
}
}
virtual
void
RemoveFromReceiveQ
(
const
void
*
referrer
)
{
if
(
channel_
)
{
channel_
->
RemoveFromReceiveQ
(
referrer
);
}
}
virtual
void
Close
()
{
if
(
channel_
)
channel_
->
Close
();
}
virtual
size_t
Cap
()
{
if
(
channel_
)
return
channel_
->
Cap
();
else
return
-
1
;
}
virtual
void
Lock
()
{
if
(
channel_
)
channel_
->
Lock
();
}
virtual
void
Unlock
()
{
if
(
channel_
)
channel_
->
Unlock
();
}
std
::
unique_ptr
<
Channel
<
T
>>
channel_
;
const
std
::
type_index
type_
;
};
// Pointer to a PlaceholderImpl object
std
::
unique_ptr
<
Placeholder
>
holder_
;
};
}
// namespace framework
}
// namespace paddle
#include "paddle/fluid/framework/channel_impl.h"
paddle/fluid/framework/channel_impl.h
已删除
100644 → 0
浏览文件 @
23644940
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stddef.h> // for size_t
#include <atomic>
#include <condition_variable> // NOLINT
#include <deque>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
template
<
typename
T
>
class
ChannelImpl
:
public
paddle
::
framework
::
Channel
<
T
>
{
friend
Channel
<
T
>
*
paddle
::
framework
::
MakeChannel
<
T
>
(
size_t
);
friend
void
paddle
::
framework
::
CloseChannel
<
T
>
(
Channel
<
T
>
*
);
public:
virtual
bool
CanSend
();
virtual
bool
CanReceive
();
virtual
void
Send
(
T
*
);
virtual
bool
Receive
(
T
*
);
virtual
size_t
Cap
()
{
return
cap_
;
}
virtual
void
Lock
();
virtual
void
Unlock
();
virtual
bool
IsClosed
();
virtual
void
Close
();
explicit
ChannelImpl
(
size_t
);
virtual
~
ChannelImpl
();
virtual
void
AddToSendQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
);
virtual
void
AddToReceiveQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
);
virtual
void
RemoveFromSendQ
(
const
void
*
referrer
);
virtual
void
RemoveFromReceiveQ
(
const
void
*
referrer
);
private:
struct
QueueMessage
{
T
*
data
;
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
;
bool
chan_closed
=
false
;
bool
completed
=
false
;
const
void
*
referrer
;
// TODO(thuan): figure out better way to do this
std
::
function
<
bool
(
ChannelAction
)
>
callback
;
explicit
QueueMessage
(
T
*
item
)
:
data
(
item
),
cond
(
std
::
make_shared
<
std
::
condition_variable_any
>
())
{}
QueueMessage
(
T
*
item
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
)
:
data
(
item
),
cond
(
cond
)
{}
void
Wait
(
std
::
unique_lock
<
std
::
recursive_mutex
>
&
lock
)
{
cond
->
wait
(
lock
,
[
this
]()
{
return
completed
;
});
}
void
Notify
()
{
completed
=
true
;
cond
->
notify_all
();
}
};
void
send_return
()
{
send_ctr
--
;
destructor_cond_
.
notify_all
();
}
bool
recv_return
(
bool
value
)
{
recv_ctr
--
;
destructor_cond_
.
notify_all
();
return
value
;
}
std
::
shared_ptr
<
QueueMessage
>
get_first_message
(
std
::
deque
<
std
::
shared_ptr
<
QueueMessage
>>
*
queue
,
ChannelAction
action
)
{
while
(
!
queue
->
empty
())
{
// Check whether this message was added by Select
// If this was added by Select then execute the callback
// to check if you can execute this message. The callback
// can return false if some other case was executed in Select.
// In that case just discard this QueueMessage and process next.
std
::
shared_ptr
<
QueueMessage
>
m
=
queue
->
front
();
queue
->
pop_front
();
if
(
m
->
callback
==
nullptr
||
m
->
callback
(
action
))
return
m
;
}
return
nullptr
;
}
size_t
cap_
;
std
::
recursive_mutex
mu_
;
bool
closed_
;
std
::
deque
<
T
>
buf_
;
std
::
deque
<
std
::
shared_ptr
<
QueueMessage
>>
recvq
;
std
::
deque
<
std
::
shared_ptr
<
QueueMessage
>>
sendq
;
std
::
atomic
<
unsigned
>
send_ctr
{
0
};
std
::
atomic
<
unsigned
>
recv_ctr
{
0
};
std
::
condition_variable_any
destructor_cond_
;
};
template
<
typename
T
>
ChannelImpl
<
T
>::
ChannelImpl
(
size_t
capacity
)
:
cap_
(
capacity
),
closed_
(
false
),
send_ctr
(
0
),
recv_ctr
(
0
)
{
PADDLE_ENFORCE_GE
(
capacity
,
0
);
}
template
<
typename
T
>
bool
ChannelImpl
<
T
>::
CanSend
()
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
mu_
};
return
!
closed_
&&
(
!
recvq
.
empty
()
||
buf_
.
size
()
<
cap_
);
}
template
<
typename
T
>
bool
ChannelImpl
<
T
>::
CanReceive
()
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
mu_
};
return
!
(
closed_
&&
buf_
.
empty
())
&&
(
!
sendq
.
empty
()
||
buf_
.
size
()
>
0
);
}
template
<
typename
T
>
void
ChannelImpl
<
T
>::
Send
(
T
*
item
)
{
send_ctr
++
;
std
::
unique_lock
<
std
::
recursive_mutex
>
lock
{
mu_
};
// If channel is closed, throw exception
if
(
closed_
)
{
send_return
();
lock
.
unlock
();
PADDLE_THROW
(
"Cannot send on closed channel"
);
}
// If there is a receiver, directly pass the value we want
// to send to the receiver, bypassing the channel buffer if any
if
(
!
recvq
.
empty
())
{
std
::
shared_ptr
<
QueueMessage
>
m
=
get_first_message
(
&
recvq
,
ChannelAction
::
SEND
);
if
(
m
!=
nullptr
)
{
*
(
m
->
data
)
=
std
::
move
(
*
item
);
m
->
Notify
();
send_return
();
return
;
}
else
{
Send
(
item
);
send_return
();
return
;
}
}
// Unbuffered channel will always bypass this
// If buffered channel has space in buffer,
// write the element to the buffer.
if
(
buf_
.
size
()
<
cap_
)
{
// Copy to buffer
buf_
.
push_back
(
std
::
move
(
*
item
));
send_return
();
return
;
}
// Block on channel, because some receiver will complete
// the operation for us
auto
m
=
std
::
make_shared
<
QueueMessage
>
(
item
);
sendq
.
push_back
(
m
);
m
->
Wait
(
lock
);
if
(
m
->
chan_closed
)
{
send_return
();
lock
.
unlock
();
PADDLE_THROW
(
"Cannot send on closed channel"
);
}
send_return
();
}
template
<
typename
T
>
bool
ChannelImpl
<
T
>::
Receive
(
T
*
item
)
{
recv_ctr
++
;
std
::
unique_lock
<
std
::
recursive_mutex
>
lock
{
mu_
};
// If channel is closed and buffer is empty or
// channel is unbuffered
if
(
closed_
&&
buf_
.
empty
())
return
recv_return
(
false
);
// If there is a sender, directly receive the value we want
// from the sender. In case of a buffered channel, read from
// buffer and move front of send queue to the buffer
if
(
!
sendq
.
empty
())
{
std
::
shared_ptr
<
QueueMessage
>
m
=
get_first_message
(
&
sendq
,
ChannelAction
::
RECEIVE
);
if
(
buf_
.
size
()
>
0
)
{
// Case 1 : Channel is Buffered
// Do Data transfer from front of buffer
// and add a QueueMessage to the buffer
*
item
=
std
::
move
(
buf_
.
front
());
buf_
.
pop_front
();
// If first message from sendq is not null
// add it to the buffer and notify it
if
(
m
!=
nullptr
)
{
// Copy to buffer
buf_
.
push_back
(
std
::
move
(
*
(
m
->
data
)));
m
->
Notify
();
}
// Ignore if there is no first message
}
else
{
// Case 2: Channel is Unbuffered
// Do data transfer from front of SendQ
// If front is nullptr, then recursively call itself
if
(
m
!=
nullptr
)
{
*
item
=
std
::
move
(
*
(
m
->
data
));
m
->
Notify
();
}
else
{
return
recv_return
(
Receive
(
item
));
}
}
return
recv_return
(
true
);
}
// If this is a buffered channel and there are items in buffer
if
(
buf_
.
size
()
>
0
)
{
// Directly read from buffer
*
item
=
std
::
move
(
buf_
.
front
());
buf_
.
pop_front
();
// return true
return
recv_return
(
true
);
}
// No sender available, block on this channel
// Some receiver will complete the option for us
auto
m
=
std
::
make_shared
<
QueueMessage
>
(
item
);
recvq
.
push_back
(
m
);
m
->
Wait
(
lock
);
return
recv_return
(
!
m
->
chan_closed
);
}
template
<
typename
T
>
void
ChannelImpl
<
T
>::
Lock
()
{
mu_
.
lock
();
}
template
<
typename
T
>
void
ChannelImpl
<
T
>::
Unlock
()
{
mu_
.
unlock
();
}
template
<
typename
T
>
bool
ChannelImpl
<
T
>::
IsClosed
()
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
mu_
};
return
closed_
;
}
template
<
typename
T
>
void
ChannelImpl
<
T
>::
Close
()
{
std
::
unique_lock
<
std
::
recursive_mutex
>
lock
{
mu_
};
if
(
closed_
)
{
// TODO(abhinavarora): closing an already closed channel should panic
lock
.
unlock
();
return
;
}
closed_
=
true
;
// Empty the readers
while
(
!
recvq
.
empty
())
{
std
::
shared_ptr
<
QueueMessage
>
m
=
recvq
.
front
();
recvq
.
pop_front
();
m
->
chan_closed
=
true
;
// Execute callback function (if any)
if
(
m
->
callback
!=
nullptr
)
{
m
->
callback
(
ChannelAction
::
CLOSE
);
}
m
->
Notify
();
}
// Empty the senders
while
(
!
sendq
.
empty
())
{
std
::
shared_ptr
<
QueueMessage
>
m
=
sendq
.
front
();
sendq
.
pop_front
();
m
->
chan_closed
=
true
;
// Execute callback function (if any)
if
(
m
->
callback
!=
nullptr
)
{
m
->
callback
(
ChannelAction
::
CLOSE
);
}
m
->
Notify
();
}
}
template
<
typename
T
>
void
ChannelImpl
<
T
>::
AddToSendQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
mu_
};
auto
m
=
std
::
make_shared
<
QueueMessage
>
(
data
,
cond
);
m
->
referrer
=
referrer
;
m
->
callback
=
cb
;
sendq
.
push_back
(
m
);
}
template
<
typename
T
>
void
ChannelImpl
<
T
>::
AddToReceiveQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
mu_
};
auto
m
=
std
::
make_shared
<
QueueMessage
>
(
data
,
cond
);
m
->
referrer
=
referrer
;
m
->
callback
=
cb
;
recvq
.
push_back
(
m
);
}
template
<
typename
T
>
void
ChannelImpl
<
T
>::
RemoveFromSendQ
(
const
void
*
referrer
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
mu_
};
for
(
auto
it
=
sendq
.
begin
();
it
!=
sendq
.
end
();)
{
std
::
shared_ptr
<
QueueMessage
>
sendMsg
=
(
std
::
shared_ptr
<
QueueMessage
>
)
*
it
;
if
(
sendMsg
->
referrer
==
referrer
)
{
it
=
sendq
.
erase
(
it
);
}
else
{
++
it
;
}
}
}
template
<
typename
T
>
void
ChannelImpl
<
T
>::
RemoveFromReceiveQ
(
const
void
*
referrer
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
mu_
};
for
(
auto
it
=
recvq
.
begin
();
it
!=
recvq
.
end
();)
{
std
::
shared_ptr
<
QueueMessage
>
recvMsg
=
(
std
::
shared_ptr
<
QueueMessage
>
)
*
it
;
if
(
recvMsg
->
referrer
==
referrer
)
{
it
=
recvq
.
erase
(
it
);
}
else
{
++
it
;
}
}
}
template
<
typename
T
>
ChannelImpl
<
T
>::~
ChannelImpl
()
{
Close
();
// The destructor must wait for all readers and writers to complete their task
// The channel has been closed, so we will not accept new readers and writers
std
::
unique_lock
<
std
::
recursive_mutex
>
lock
{
mu_
};
destructor_cond_
.
wait
(
lock
,
[
this
]()
{
return
send_ctr
==
0
&&
recv_ctr
==
0
;
});
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/channel_test.cc
已删除
100644 → 0
浏览文件 @
23644940
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include <chrono> // NOLINT
#include <thread> // NOLINT
#include "gtest/gtest.h"
using
paddle
::
framework
::
Channel
;
using
paddle
::
framework
::
ChannelHolder
;
using
paddle
::
framework
::
MakeChannel
;
using
paddle
::
framework
::
CloseChannel
;
TEST
(
Channel
,
ChannelCapacityTest
)
{
const
size_t
buffer_size
=
10
;
auto
ch
=
MakeChannel
<
size_t
>
(
buffer_size
);
EXPECT_EQ
(
ch
->
Cap
(),
buffer_size
);
CloseChannel
(
ch
);
delete
ch
;
ch
=
MakeChannel
<
size_t
>
(
0
);
EXPECT_EQ
(
ch
->
Cap
(),
0U
);
CloseChannel
(
ch
);
delete
ch
;
}
void
RecevingOrderEqualToSendingOrder
(
Channel
<
int
>
*
ch
,
int
num_items
)
{
unsigned
sum_send
=
0
;
std
::
thread
t
([
&
]()
{
for
(
int
i
=
0
;
i
<
num_items
;
i
++
)
{
ch
->
Send
(
&
i
);
sum_send
+=
i
;
}
});
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
for
(
int
i
=
0
;
i
<
num_items
;
i
++
)
{
int
recv
=
-
1
;
EXPECT_EQ
(
ch
->
Receive
(
&
recv
),
true
);
EXPECT_EQ
(
recv
,
i
);
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
CloseChannel
(
ch
);
t
.
join
();
unsigned
expected_sum
=
(
num_items
*
(
num_items
-
1
))
/
2
;
EXPECT_EQ
(
sum_send
,
expected_sum
);
delete
ch
;
}
TEST
(
Channel
,
SufficientBufferSizeDoesntBlock
)
{
const
size_t
buffer_size
=
10
;
auto
ch
=
MakeChannel
<
size_t
>
(
buffer_size
);
for
(
size_t
i
=
0
;
i
<
buffer_size
;
++
i
)
{
ch
->
Send
(
&
i
);
}
size_t
out
;
for
(
size_t
i
=
0
;
i
<
buffer_size
;
++
i
)
{
EXPECT_EQ
(
ch
->
Receive
(
&
out
),
true
);
// should not block
EXPECT_EQ
(
out
,
i
);
}
CloseChannel
(
ch
);
delete
ch
;
}
// This tests that a channel must return false
// on send and receive performed after closing the channel.
// Receive will only return false after close when queue is empty.
// By creating separate threads for sending and receiving, we make this
// function able to test both buffered and unbuffered channels.
void
SendReceiveWithACloseChannelShouldPanic
(
Channel
<
size_t
>
*
ch
)
{
const
size_t
data
=
5
;
std
::
thread
send_thread
{[
&
]()
{
size_t
i
=
data
;
ch
->
Send
(
&
i
);
// should not block
}};
std
::
thread
recv_thread
{[
&
]()
{
size_t
i
;
EXPECT_EQ
(
ch
->
Receive
(
&
i
),
true
);
// should not block
EXPECT_EQ
(
i
,
data
);
}};
send_thread
.
join
();
recv_thread
.
join
();
// After closing send should panic. Receive should
// also false as there is no data in queue.
CloseChannel
(
ch
);
send_thread
=
std
::
thread
{[
&
]()
{
size_t
i
=
data
;
bool
is_exception
=
false
;
try
{
ch
->
Send
(
&
i
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
is_exception
=
true
;
}
EXPECT_EQ
(
is_exception
,
true
);
}};
recv_thread
=
std
::
thread
{[
&
]()
{
size_t
i
;
// should return false because channel is closed and queue is empty
EXPECT_EQ
(
ch
->
Receive
(
&
i
),
false
);
}};
send_thread
.
join
();
recv_thread
.
join
();
}
TEST
(
Channel
,
SendReceiveClosedBufferedChannelPanics
)
{
size_t
buffer_size
=
10
;
auto
ch
=
MakeChannel
<
size_t
>
(
buffer_size
);
SendReceiveWithACloseChannelShouldPanic
(
ch
);
delete
ch
;
}
TEST
(
Channel
,
SendReceiveClosedUnBufferedChannelPanics
)
{
auto
ch
=
MakeChannel
<
size_t
>
(
0
);
SendReceiveWithACloseChannelShouldPanic
(
ch
);
delete
ch
;
}
TEST
(
Channel
,
ReceiveFromBufferedChannelReturnResidualValuesTest
)
{
const
size_t
buffer_size
=
10
;
auto
ch
=
MakeChannel
<
size_t
>
(
buffer_size
);
for
(
size_t
i
=
0
;
i
<
buffer_size
;
++
i
)
{
ch
->
Send
(
&
i
);
// sending should not block
}
size_t
out
;
for
(
size_t
i
=
0
;
i
<
buffer_size
/
2
;
++
i
)
{
EXPECT_EQ
(
ch
->
Receive
(
&
out
),
true
);
// receiving should not block
EXPECT_EQ
(
out
,
i
);
}
CloseChannel
(
ch
);
for
(
size_t
i
=
buffer_size
/
2
;
i
<
buffer_size
;
++
i
)
{
EXPECT_EQ
(
ch
->
Receive
(
&
out
),
true
);
// receving should return residual values.
EXPECT_EQ
(
out
,
i
);
}
for
(
size_t
i
=
0
;
i
<
buffer_size
;
++
i
)
{
EXPECT_EQ
(
ch
->
Receive
(
&
out
),
false
);
// receiving on closed channel should return false
}
delete
ch
;
}
TEST
(
Channel
,
ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize
)
{
const
size_t
buffer_size
=
10
;
auto
ch
=
MakeChannel
<
size_t
>
(
buffer_size
);
std
::
thread
t
([
&
]()
{
// Try to write more than buffer size.
for
(
size_t
i
=
0
;
i
<
2
*
buffer_size
;
++
i
)
{
if
(
i
<
buffer_size
)
{
ch
->
Send
(
&
i
);
// should block after 10 iterations
}
else
{
bool
is_exception
=
false
;
try
{
ch
->
Send
(
&
i
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
is_exception
=
true
;
}
EXPECT_EQ
(
is_exception
,
true
);
}
}
});
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait 0.2 sec
CloseChannel
(
ch
);
t
.
join
();
delete
ch
;
}
TEST
(
Channel
,
RecevingOrderEqualToSendingOrderWithUnBufferedChannel
)
{
auto
ch
=
MakeChannel
<
int
>
(
0
);
RecevingOrderEqualToSendingOrder
(
ch
,
20
);
}
TEST
(
Channel
,
RecevingOrderEqualToSendingOrderWithBufferedChannel1
)
{
// Test that Receive Order is same as Send Order when number of items
// sent is less than size of buffer
auto
ch
=
MakeChannel
<
int
>
(
10
);
RecevingOrderEqualToSendingOrder
(
ch
,
5
);
}
TEST
(
Channel
,
RecevingOrderEqualToSendingOrderWithBufferedChannel2
)
{
// Test that Receive Order is same as Send Order when number of items
// sent is equal to size of buffer
auto
ch
=
MakeChannel
<
int
>
(
10
);
RecevingOrderEqualToSendingOrder
(
ch
,
10
);
}
TEST
(
Channel
,
RecevingOrderEqualToSendingOrderWithBufferedChannel3
)
{
// Test that Receive Order is same as Send Order when number of items
// sent is greater than the size of buffer
auto
ch
=
MakeChannel
<
int
>
(
10
);
RecevingOrderEqualToSendingOrder
(
ch
,
20
);
}
void
ChannelCloseUnblocksReceiversTest
(
Channel
<
int
>
*
ch
)
{
const
size_t
kNumThreads
=
5
;
std
::
thread
t
[
kNumThreads
];
bool
thread_ended
[
kNumThreads
];
// Launches threads that try to read and are blocked because of no writers
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
thread_ended
[
i
]
=
false
;
t
[
i
]
=
std
::
thread
(
[
&
](
bool
*
p
)
{
int
data
;
EXPECT_EQ
(
ch
->
Receive
(
&
data
),
false
);
*
p
=
true
;
},
&
thread_ended
[
i
]);
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait 0.2 sec
// Verify that all the threads are blocked
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
false
);
}
// Explicitly close the channel
// This should unblock all receivers
CloseChannel
(
ch
);
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait 0.2 sec
// Verify that all threads got unblocked
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
true
);
}
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
t
[
i
].
join
();
}
void
ChannelCloseUnblocksSendersTest
(
Channel
<
int
>
*
ch
,
bool
isBuffered
)
{
const
size_t
kNumThreads
=
5
;
std
::
thread
t
[
kNumThreads
];
bool
thread_ended
[
kNumThreads
];
bool
send_success
[
kNumThreads
];
// Launches threads that try to write and are blocked because of no readers
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
thread_ended
[
i
]
=
false
;
send_success
[
i
]
=
false
;
t
[
i
]
=
std
::
thread
(
[
&
](
bool
*
ended
,
bool
*
success
)
{
int
data
=
10
;
bool
is_exception
=
false
;
try
{
ch
->
Send
(
&
data
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
is_exception
=
true
;
}
*
success
=
!
is_exception
;
*
ended
=
true
;
},
&
thread_ended
[
i
],
&
send_success
[
i
]);
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait
if
(
isBuffered
)
{
// If ch is Buffered, atleast 4 threads must be blocked.
int
ct
=
0
;
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
if
(
!
thread_ended
[
i
])
ct
++
;
}
EXPECT_GE
(
ct
,
4
);
}
else
{
// If ch is UnBuffered, all the threads should be blocked.
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
false
);
}
}
// Explicitly close the thread
// This should unblock all senders
CloseChannel
(
ch
);
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait
// Verify that all threads got unblocked
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
true
);
}
if
(
isBuffered
)
{
// Verify that only 1 send was successful
int
ct
=
0
;
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
if
(
send_success
[
i
])
ct
++
;
}
// Only 1 send must be successful
EXPECT_EQ
(
ct
,
1
);
}
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
t
[
i
].
join
();
}
// This tests that closing a buffered channel also unblocks
// any receivers waiting on the channel
TEST
(
Channel
,
BufferedChannelCloseUnblocksReceiversTest
)
{
auto
ch
=
MakeChannel
<
int
>
(
1
);
ChannelCloseUnblocksReceiversTest
(
ch
);
delete
ch
;
}
// This tests that closing a buffered channel also unblocks
// any senders waiting for channel to have write space
TEST
(
Channel
,
BufferedChannelCloseUnblocksSendersTest
)
{
auto
ch
=
MakeChannel
<
int
>
(
1
);
ChannelCloseUnblocksSendersTest
(
ch
,
true
);
delete
ch
;
}
// This tests that closing an unbuffered channel also unblocks
// unblocks any receivers waiting for senders
TEST
(
Channel
,
UnbufferedChannelCloseUnblocksReceiversTest
)
{
auto
ch
=
MakeChannel
<
int
>
(
0
);
ChannelCloseUnblocksReceiversTest
(
ch
);
delete
ch
;
}
// This tests that closing an unbuffered channel also unblocks
// unblocks any senders waiting for senders
TEST
(
Channel
,
UnbufferedChannelCloseUnblocksSendersTest
)
{
auto
ch
=
MakeChannel
<
int
>
(
0
);
ChannelCloseUnblocksSendersTest
(
ch
,
false
);
delete
ch
;
}
TEST
(
Channel
,
UnbufferedLessReceiveMoreSendTest
)
{
auto
ch
=
MakeChannel
<
int
>
(
0
);
unsigned
sum_send
=
0
;
// Send should block after three iterations
// since we only have three receivers.
std
::
thread
t
([
&
]()
{
// Try to send more number of times
// than receivers
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
try
{
ch
->
Send
(
&
i
);
sum_send
+=
i
;
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
}
}
});
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
int
recv
;
ch
->
Receive
(
&
recv
);
EXPECT_EQ
(
recv
,
i
);
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait 0.2 sec
EXPECT_EQ
(
sum_send
,
3U
);
CloseChannel
(
ch
);
t
.
join
();
delete
ch
;
}
TEST
(
Channel
,
UnbufferedMoreReceiveLessSendTest
)
{
auto
ch
=
MakeChannel
<
int
>
(
0
);
unsigned
sum_send
=
0
;
unsigned
sum_receive
=
0
;
// The receiver should block after 5
// iterations, since there are only 5 senders.
std
::
thread
t
([
&
]()
{
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
int
recv
;
ch
->
Receive
(
&
recv
);
// should block after the fifth iteration.
EXPECT_EQ
(
recv
,
i
);
sum_receive
+=
i
;
}
});
for
(
int
i
=
0
;
i
<
5
;
i
++
)
{
ch
->
Send
(
&
i
);
sum_send
+=
i
;
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait 0.2 sec
EXPECT_EQ
(
sum_send
,
10U
);
EXPECT_EQ
(
sum_receive
,
10U
);
// send three more elements
for
(
int
i
=
5
;
i
<
8
;
i
++
)
{
ch
->
Send
(
&
i
);
sum_send
+=
i
;
}
CloseChannel
(
ch
);
t
.
join
();
EXPECT_EQ
(
sum_send
,
28U
);
EXPECT_EQ
(
sum_receive
,
28U
);
delete
ch
;
}
// This tests that destroying a channel unblocks
// any senders waiting for channel to have write space
void
ChannelDestroyUnblockSenders
(
Channel
<
int
>
*
ch
,
bool
isBuffered
)
{
const
size_t
kNumThreads
=
5
;
std
::
thread
t
[
kNumThreads
];
bool
thread_ended
[
kNumThreads
];
bool
send_success
[
kNumThreads
];
// Launches threads that try to write and are blocked because of no readers
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
thread_ended
[
i
]
=
false
;
send_success
[
i
]
=
false
;
t
[
i
]
=
std
::
thread
(
[
&
](
bool
*
ended
,
bool
*
success
)
{
int
data
=
10
;
bool
is_exception
=
false
;
try
{
ch
->
Send
(
&
data
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
is_exception
=
true
;
}
*
success
=
!
is_exception
;
*
ended
=
true
;
},
&
thread_ended
[
i
],
&
send_success
[
i
]);
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait 0.2 sec
if
(
isBuffered
)
{
// If channel is buffered, verify that atleast 4 threads are blocked
int
ct
=
0
;
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
if
(
thread_ended
[
i
]
==
false
)
ct
++
;
}
// Atleast 4 threads must be blocked
EXPECT_GE
(
ct
,
4
);
}
else
{
// Verify that all the threads are blocked
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
false
);
}
}
// Explicitly destroy the channel
delete
ch
;
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait
// Verify that all threads got unblocked
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
true
);
}
// Count number of successful sends
int
ct
=
0
;
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
if
(
send_success
[
i
])
ct
++
;
}
if
(
isBuffered
)
{
// Only 1 send must be successful
EXPECT_EQ
(
ct
,
1
);
}
else
{
// In unbuffered channel, no send should be successful
EXPECT_EQ
(
ct
,
0
);
}
// Join all threads
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
t
[
i
].
join
();
}
// This tests that destroying a channel also unblocks
// any receivers waiting on the channel
void
ChannelDestroyUnblockReceivers
(
Channel
<
int
>
*
ch
)
{
const
size_t
kNumThreads
=
5
;
std
::
thread
t
[
kNumThreads
];
bool
thread_ended
[
kNumThreads
];
// Launches threads that try to read and are blocked because of no writers
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
thread_ended
[
i
]
=
false
;
t
[
i
]
=
std
::
thread
(
[
&
](
bool
*
p
)
{
int
data
;
// All reads should return false
EXPECT_EQ
(
ch
->
Receive
(
&
data
),
false
);
*
p
=
true
;
},
&
thread_ended
[
i
]);
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
100
));
// wait
// Verify that all threads are blocked
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
false
);
}
// delete the channel
delete
ch
;
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait
// Verify that all threads got unblocked
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
true
);
}
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
t
[
i
].
join
();
}
TEST
(
Channel
,
BufferedChannelDestroyUnblocksReceiversTest
)
{
size_t
buffer_size
=
1
;
auto
ch
=
MakeChannel
<
int
>
(
buffer_size
);
ChannelDestroyUnblockReceivers
(
ch
);
}
TEST
(
Channel
,
BufferedChannelDestroyUnblocksSendersTest
)
{
size_t
buffer_size
=
1
;
auto
ch
=
MakeChannel
<
int
>
(
buffer_size
);
ChannelDestroyUnblockSenders
(
ch
,
true
);
}
// This tests that destroying an unbuffered channel also unblocks
// unblocks any receivers waiting for senders
TEST
(
Channel
,
UnbufferedChannelDestroyUnblocksReceiversTest
)
{
auto
ch
=
MakeChannel
<
int
>
(
0
);
ChannelDestroyUnblockReceivers
(
ch
);
}
TEST
(
Channel
,
UnbufferedChannelDestroyUnblocksSendersTest
)
{
auto
ch
=
MakeChannel
<
int
>
(
0
);
ChannelDestroyUnblockSenders
(
ch
,
false
);
}
TEST
(
ChannelHolder
,
ChannelHolderCapacityTest
)
{
const
size_t
buffer_size
=
10
;
ChannelHolder
*
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
buffer_size
);
EXPECT_EQ
(
ch
->
Cap
(),
buffer_size
);
delete
ch
;
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
0
);
EXPECT_EQ
(
ch
->
Cap
(),
0U
);
delete
ch
;
}
void
ChannelHolderSendReceive
(
ChannelHolder
*
ch
)
{
unsigned
sum_send
=
0
;
std
::
thread
t
([
&
]()
{
for
(
int
i
=
0
;
i
<
5
;
i
++
)
{
ch
->
Send
(
&
i
);
sum_send
+=
i
;
}
});
for
(
int
i
=
0
;
i
<
5
;
i
++
)
{
int
recv
;
EXPECT_EQ
(
ch
->
Receive
(
&
recv
),
true
);
EXPECT_EQ
(
recv
,
i
);
}
ch
->
close
();
t
.
join
();
EXPECT_EQ
(
sum_send
,
10U
);
}
TEST
(
ChannelHolder
,
ChannelHolderBufferedSendReceiveTest
)
{
ChannelHolder
*
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
10
);
ChannelHolderSendReceive
(
ch
);
delete
ch
;
}
TEST
(
ChannelHolder
,
ChannelHolderUnBufferedSendReceiveTest
)
{
ChannelHolder
*
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
0
);
ChannelHolderSendReceive
(
ch
);
delete
ch
;
}
TEST
(
ChannelHolder
,
ChannelUninitializedTest
)
{
ChannelHolder
*
ch
=
new
ChannelHolder
();
EXPECT_EQ
(
ch
->
IsInitialized
(),
false
);
int
i
=
10
;
bool
send_exception
=
false
;
try
{
ch
->
Send
(
&
i
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
send_exception
=
true
;
}
EXPECT_EQ
(
send_exception
,
true
);
bool
recv_exception
=
false
;
try
{
ch
->
Receive
(
&
i
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
recv_exception
=
true
;
}
EXPECT_EQ
(
recv_exception
,
true
);
bool
is_exception
=
false
;
try
{
ch
->
Type
();
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
is_exception
=
true
;
}
EXPECT_EQ
(
is_exception
,
true
);
delete
ch
;
}
TEST
(
ChannelHolder
,
ChannelInitializedTest
)
{
ChannelHolder
*
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
2
);
EXPECT_EQ
(
ch
->
IsInitialized
(),
true
);
// Channel should remain intialized even after close
ch
->
close
();
EXPECT_EQ
(
ch
->
IsInitialized
(),
true
);
delete
ch
;
}
TEST
(
ChannelHolder
,
TypeMismatchSendTest
)
{
// Test with unbuffered channel
ChannelHolder
*
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
0
);
bool
is_exception
=
false
;
bool
boolean_data
=
true
;
try
{
ch
->
Send
(
&
boolean_data
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
is_exception
=
true
;
}
EXPECT_EQ
(
is_exception
,
true
);
delete
ch
;
// Test with Buffered Channel
ch
=
new
ChannelHolder
();
ch
->
Reset
<
float
>
(
10
);
is_exception
=
false
;
int
int_data
=
23
;
try
{
ch
->
Send
(
&
int_data
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
is_exception
=
true
;
}
EXPECT_EQ
(
is_exception
,
true
);
delete
ch
;
}
TEST
(
ChannelHolder
,
TypeMismatchReceiveTest
)
{
// Test with unbuffered channel
ChannelHolder
*
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
0
);
bool
is_exception
=
false
;
bool
float_data
;
try
{
ch
->
Receive
(
&
float_data
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
is_exception
=
true
;
}
EXPECT_EQ
(
is_exception
,
true
);
delete
ch
;
// Test with Buffered Channel
ch
=
new
ChannelHolder
();
ch
->
Reset
<
float
>
(
10
);
is_exception
=
false
;
int
int_data
=
23
;
try
{
ch
->
Receive
(
&
int_data
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
is_exception
=
true
;
}
EXPECT_EQ
(
is_exception
,
true
);
delete
ch
;
}
void
ChannelHolderCloseUnblocksReceiversTest
(
ChannelHolder
*
ch
)
{
const
size_t
kNumThreads
=
5
;
std
::
thread
t
[
kNumThreads
];
bool
thread_ended
[
kNumThreads
];
// Launches threads that try to read and are blocked because of no writers
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
thread_ended
[
i
]
=
false
;
t
[
i
]
=
std
::
thread
(
[
&
](
bool
*
p
)
{
int
data
;
EXPECT_EQ
(
ch
->
Receive
(
&
data
),
false
);
*
p
=
true
;
},
&
thread_ended
[
i
]);
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait 0.2 sec
// Verify that all the threads are blocked
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
false
);
}
// Explicitly close the channel
// This should unblock all receivers
ch
->
close
();
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait 0.2 sec
// Verify that all threads got unblocked
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
true
);
}
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
t
[
i
].
join
();
}
void
ChannelHolderCloseUnblocksSendersTest
(
ChannelHolder
*
ch
,
bool
isBuffered
)
{
const
size_t
kNumThreads
=
5
;
std
::
thread
t
[
kNumThreads
];
bool
thread_ended
[
kNumThreads
];
bool
send_success
[
kNumThreads
];
// Launches threads that try to write and are blocked because of no readers
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
thread_ended
[
i
]
=
false
;
send_success
[
i
]
=
false
;
t
[
i
]
=
std
::
thread
(
[
&
](
bool
*
ended
,
bool
*
success
)
{
int
data
=
10
;
bool
is_exception
=
false
;
try
{
ch
->
Send
(
&
data
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
is_exception
=
true
;
}
*
success
=
!
is_exception
;
*
ended
=
true
;
},
&
thread_ended
[
i
],
&
send_success
[
i
]);
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait
if
(
isBuffered
)
{
// If ch is Buffered, atleast 4 threads must be blocked.
int
ct
=
0
;
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
if
(
!
thread_ended
[
i
])
ct
++
;
}
EXPECT_GE
(
ct
,
4
);
}
else
{
// If ch is UnBuffered, all the threads should be blocked.
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
false
);
}
}
// Explicitly close the thread
// This should unblock all senders
ch
->
close
();
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait
// Verify that all threads got unblocked
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
true
);
}
if
(
isBuffered
)
{
// Verify that only 1 send was successful
int
ct
=
0
;
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
if
(
send_success
[
i
])
ct
++
;
}
// Only 1 send must be successful
EXPECT_EQ
(
ct
,
1
);
}
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
t
[
i
].
join
();
}
// This tests that closing a channelholder unblocks
// any receivers waiting on the channel
TEST
(
ChannelHolder
,
ChannelHolderCloseUnblocksReceiversTest
)
{
// Check for buffered channel
ChannelHolder
*
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
1
);
ChannelHolderCloseUnblocksReceiversTest
(
ch
);
delete
ch
;
// Check for unbuffered channel
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
0
);
ChannelHolderCloseUnblocksReceiversTest
(
ch
);
delete
ch
;
}
// This tests that closing a channelholder unblocks
// any senders waiting for channel to have write space
TEST
(
Channel
,
ChannelHolderCloseUnblocksSendersTest
)
{
// Check for buffered channel
ChannelHolder
*
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
1
);
ChannelHolderCloseUnblocksSendersTest
(
ch
,
true
);
delete
ch
;
// Check for unbuffered channel
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
0
);
ChannelHolderCloseUnblocksSendersTest
(
ch
,
false
);
delete
ch
;
}
// This tests that destroying a channelholder unblocks
// any senders waiting for channel
void
ChannelHolderDestroyUnblockSenders
(
ChannelHolder
*
ch
,
bool
isBuffered
)
{
const
size_t
kNumThreads
=
5
;
std
::
thread
t
[
kNumThreads
];
bool
thread_ended
[
kNumThreads
];
bool
send_success
[
kNumThreads
];
// Launches threads that try to write and are blocked because of no readers
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
thread_ended
[
i
]
=
false
;
send_success
[
i
]
=
false
;
t
[
i
]
=
std
::
thread
(
[
&
](
bool
*
ended
,
bool
*
success
)
{
int
data
=
10
;
bool
is_exception
=
false
;
try
{
ch
->
Send
(
&
data
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
is_exception
=
true
;
}
*
success
=
!
is_exception
;
*
ended
=
true
;
},
&
thread_ended
[
i
],
&
send_success
[
i
]);
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait 0.2 sec
if
(
isBuffered
)
{
// If channel is buffered, verify that atleast 4 threads are blocked
int
ct
=
0
;
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
if
(
thread_ended
[
i
]
==
false
)
ct
++
;
}
// Atleast 4 threads must be blocked
EXPECT_GE
(
ct
,
4
);
}
else
{
// Verify that all the threads are blocked
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
false
);
}
}
// Explicitly destroy the channel
delete
ch
;
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait
// Verify that all threads got unblocked
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
true
);
}
// Count number of successfuld sends
int
ct
=
0
;
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
if
(
send_success
[
i
])
ct
++
;
}
if
(
isBuffered
)
{
// Only 1 send must be successful
EXPECT_EQ
(
ct
,
1
);
}
else
{
// In unbuffered channel, no send should be successful
EXPECT_EQ
(
ct
,
0
);
}
// Join all threads
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
t
[
i
].
join
();
}
// This tests that destroying a channelholder also unblocks
// any receivers waiting on the channel
void
ChannelHolderDestroyUnblockReceivers
(
ChannelHolder
*
ch
)
{
const
size_t
kNumThreads
=
5
;
std
::
thread
t
[
kNumThreads
];
bool
thread_ended
[
kNumThreads
];
// Launches threads that try to read and are blocked because of no writers
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
thread_ended
[
i
]
=
false
;
t
[
i
]
=
std
::
thread
(
[
&
](
bool
*
p
)
{
int
data
;
// All reads should return false
EXPECT_EQ
(
ch
->
Receive
(
&
data
),
false
);
*
p
=
true
;
},
&
thread_ended
[
i
]);
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait
// Verify that all threads are blocked
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
false
);
}
// delete the channel
delete
ch
;
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
200
));
// wait
// Verify that all threads got unblocked
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
true
);
}
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
t
[
i
].
join
();
}
TEST
(
ChannelHolder
,
ChannelHolderDestroyUnblocksReceiversTest
)
{
// Check for Buffered Channel
ChannelHolder
*
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
1
);
ChannelHolderDestroyUnblockReceivers
(
ch
);
// ch is already deleted already deleted in
// ChannelHolderDestroyUnblockReceivers
// Check for Unbuffered channel
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
0
);
ChannelHolderDestroyUnblockReceivers
(
ch
);
}
TEST
(
ChannelHolder
,
ChannelHolderDestroyUnblocksSendersTest
)
{
// Check for Buffered Channel
ChannelHolder
*
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
1
);
ChannelHolderDestroyUnblockSenders
(
ch
,
true
);
// ch is already deleted already deleted in
// ChannelHolderDestroyUnblockReceivers
// Check for Unbuffered channel
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
0
);
ChannelHolderDestroyUnblockSenders
(
ch
,
false
);
}
// This tests that closing a channelholder many times.
void
ChannelHolderManyTimesClose
(
ChannelHolder
*
ch
)
{
const
int
kNumThreads
=
15
;
std
::
thread
t
[
kNumThreads
];
bool
thread_ended
[
kNumThreads
];
// Launches threads that try to send data to channel.
for
(
size_t
i
=
0
;
i
<
kNumThreads
/
3
;
i
++
)
{
thread_ended
[
i
]
=
false
;
t
[
i
]
=
std
::
thread
(
[
&
](
bool
*
ended
)
{
int
data
=
10
;
ch
->
Send
(
&
data
);
*
ended
=
true
;
},
&
thread_ended
[
i
]);
}
// Launches threads that try to receive data to channel.
for
(
size_t
i
=
kNumThreads
/
3
;
i
<
2
*
kNumThreads
/
3
;
i
++
)
{
thread_ended
[
i
]
=
false
;
t
[
i
]
=
std
::
thread
(
[
&
](
bool
*
p
)
{
int
data
;
if
(
ch
->
Receive
(
&
data
))
{
EXPECT_EQ
(
data
,
10
);
}
*
p
=
true
;
},
&
thread_ended
[
i
]);
}
// Launches threads that try to close the channel.
for
(
size_t
i
=
2
*
kNumThreads
/
3
;
i
<
kNumThreads
;
i
++
)
{
thread_ended
[
i
]
=
false
;
t
[
i
]
=
std
::
thread
(
[
&
](
bool
*
p
)
{
if
(
!
ch
->
IsClosed
())
{
ch
->
close
();
}
*
p
=
true
;
},
&
thread_ended
[
i
]);
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
100
));
// wait
// Verify that all threads are unblocked
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
{
EXPECT_EQ
(
thread_ended
[
i
],
true
);
}
EXPECT_TRUE
(
ch
->
IsClosed
());
// delete the channel
delete
ch
;
for
(
size_t
i
=
0
;
i
<
kNumThreads
;
i
++
)
t
[
i
].
join
();
}
TEST
(
ChannelHolder
,
ChannelHolderManyTimesCloseTest
)
{
// Check for Buffered Channel
ChannelHolder
*
ch
=
new
ChannelHolder
();
ch
->
Reset
<
int
>
(
10
);
ChannelHolderManyTimesClose
(
ch
);
}
paddle/fluid/framework/concurrency_test.cc
已删除
100644 → 0
浏览文件 @
23644940
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
USE_NO_KERNEL_OP
(
go
);
USE_NO_KERNEL_OP
(
channel_close
);
USE_NO_KERNEL_OP
(
channel_create
);
USE_NO_KERNEL_OP
(
channel_recv
);
USE_NO_KERNEL_OP
(
channel_send
);
USE_NO_KERNEL_OP
(
elementwise_add
);
USE_NO_KERNEL_OP
(
select
);
USE_NO_KERNEL_OP
(
conditional_block
);
USE_NO_KERNEL_OP
(
equal
);
USE_NO_KERNEL_OP
(
assign
);
USE_NO_KERNEL_OP
(
while
);
USE_NO_KERNEL_OP
(
print
);
namespace
f
=
paddle
::
framework
;
namespace
p
=
paddle
::
platform
;
namespace
paddle
{
namespace
framework
{
template
<
typename
T
>
LoDTensor
*
CreateVariable
(
Scope
*
scope
,
const
p
::
CPUPlace
&
place
,
std
::
string
name
,
T
value
)
{
// Create LoDTensor<int> of dim [1]
auto
var
=
scope
->
Var
(
name
);
auto
tensor
=
var
->
GetMutable
<
LoDTensor
>
();
tensor
->
Resize
({
1
});
T
*
expect
=
tensor
->
mutable_data
<
T
>
(
place
);
expect
[
0
]
=
value
;
return
tensor
;
}
void
AddOp
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
AttributeMap
attrs
,
BlockDesc
*
block
)
{
// insert op
auto
op
=
block
->
AppendOp
();
op
->
SetType
(
type
);
for
(
auto
&
kv
:
inputs
)
{
op
->
SetInput
(
kv
.
first
,
kv
.
second
);
}
for
(
auto
&
kv
:
outputs
)
{
op
->
SetOutput
(
kv
.
first
,
kv
.
second
);
}
op
->
SetAttrMap
(
attrs
);
}
void
AddCase
(
ProgramDesc
*
program
,
Scope
*
scope
,
p
::
CPUPlace
*
place
,
BlockDesc
*
casesBlock
,
int
caseId
,
int
caseType
,
std
::
string
caseChannel
,
std
::
string
caseVarName
,
std
::
function
<
void
(
BlockDesc
*
,
Scope
*
)
>
func
)
{
std
::
string
caseCondName
=
std
::
string
(
"caseCond"
)
+
std
::
to_string
(
caseId
);
std
::
string
caseCondXVarName
=
std
::
string
(
"caseCondX"
)
+
std
::
to_string
(
caseId
);
BlockDesc
*
caseBlock
=
program
->
AppendBlock
(
*
casesBlock
);
func
(
caseBlock
,
scope
);
CreateVariable
(
scope
,
*
place
,
caseCondName
,
false
);
CreateVariable
(
scope
,
*
place
,
caseCondXVarName
,
caseId
);
CreateVariable
(
scope
,
*
place
,
caseVarName
,
caseId
);
scope
->
Var
(
"step_scope"
);
AddOp
(
"equal"
,
{{
"X"
,
{
caseCondXVarName
}},
{
"Y"
,
{
"caseToExecute"
}}},
{{
"Out"
,
{
caseCondName
}}},
{},
casesBlock
);
AddOp
(
"conditional_block"
,
{{
"X"
,
{
caseCondName
}},
{
"Params"
,
{}}},
{{
"Out"
,
{}},
{
"Scope"
,
{
"step_scope"
}}},
{{
"sub_block"
,
caseBlock
},
{
"is_scalar_condition"
,
true
}},
casesBlock
);
}
void
AddFibonacciSelect
(
Scope
*
scope
,
p
::
CPUPlace
*
place
,
ProgramDesc
*
program
,
BlockDesc
*
parentBlock
,
std
::
string
dataChanName
,
std
::
string
quitChanName
)
{
BlockDesc
*
whileBlock
=
program
->
AppendBlock
(
*
parentBlock
);
CreateVariable
(
scope
,
*
place
,
"whileExitCond"
,
true
);
CreateVariable
(
scope
,
*
place
,
"caseToExecute"
,
-
1
);
CreateVariable
(
scope
,
*
place
,
"case1var"
,
0
);
CreateVariable
(
scope
,
*
place
,
"xtemp"
,
0
);
// TODO(thuan): Need to create fibXToSend, since channel send moves the actual
// data,
// which causes the data to be no longer accessible to do the fib calculation
// TODO(abhinav): Change channel send to do a copy instead of a move!
CreateVariable
(
scope
,
*
place
,
"fibXToSend"
,
0
);
CreateVariable
(
scope
,
*
place
,
"fibX"
,
0
);
CreateVariable
(
scope
,
*
place
,
"fibY"
,
1
);
CreateVariable
(
scope
,
*
place
,
"quitVar"
,
0
);
BlockDesc
*
casesBlock
=
program
->
AppendBlock
(
*
whileBlock
);
std
::
function
<
void
(
BlockDesc
*
caseBlock
)
>
f
=
[](
BlockDesc
*
caseBlock
)
{};
// TODO(thuan): Remove this once we change channel send to do a copy instead
// of move
AddOp
(
"assign"
,
{{
"X"
,
{
"fibX"
}}},
{{
"Out"
,
{
"fibXToSend"
}}},
{},
whileBlock
);
// Case 0: Send to dataChanName
std
::
function
<
void
(
BlockDesc
*
caseBlock
,
Scope
*
scope
)
>
case0Func
=
[
&
](
BlockDesc
*
caseBlock
,
Scope
*
scope
)
{
AddOp
(
"assign"
,
{{
"X"
,
{
"fibX"
}}},
{{
"Out"
,
{
"xtemp"
}}},
{},
caseBlock
);
AddOp
(
"assign"
,
{{
"X"
,
{
"fibY"
}}},
{{
"Out"
,
{
"fibX"
}}},
{},
caseBlock
);
AddOp
(
"elementwise_add"
,
{{
"X"
,
{
"xtemp"
}},
{
"Y"
,
{
"fibY"
}}},
{{
"Out"
,
{
"fibY"
}}},
{},
caseBlock
);
};
AddCase
(
program
,
scope
,
place
,
casesBlock
,
0
,
1
,
dataChanName
,
"fibXToSend"
,
case0Func
);
std
::
string
case0Config
=
std
::
string
(
"0,1,"
)
+
dataChanName
+
std
::
string
(
",fibXToSend"
);
// Case 1: Receive from quitChanName
std
::
function
<
void
(
BlockDesc
*
caseBlock
,
Scope
*
scope
)
>
case2Func
=
[
&
](
BlockDesc
*
caseBlock
,
Scope
*
scope
)
{
// Exit the while loop after we receive from quit channel.
// We assign a false to "whileExitCond" variable, which will
// break out of while_op loop
CreateVariable
(
scope
,
*
place
,
"whileFalse"
,
false
);
AddOp
(
"assign"
,
{{
"X"
,
{
"whileFalse"
}}},
{{
"Out"
,
{
"whileExitCond"
}}},
{},
caseBlock
);
};
AddCase
(
program
,
scope
,
place
,
casesBlock
,
1
,
2
,
quitChanName
,
"quitVar"
,
case2Func
);
std
::
string
case1Config
=
std
::
string
(
"1,2,"
)
+
quitChanName
+
std
::
string
(
",quitVar"
);
// Select block
AddOp
(
"select"
,
{{
"X"
,
{
dataChanName
,
quitChanName
}},
{
"case_to_execute"
,
{
"caseToExecute"
}}},
{{
"Out"
,
{}}},
{{
"sub_block"
,
casesBlock
},
{
"cases"
,
std
::
vector
<
std
::
string
>
{
case0Config
,
case1Config
}}},
whileBlock
);
scope
->
Var
(
"stepScopes"
);
AddOp
(
"while"
,
{{
"X"
,
{
dataChanName
,
quitChanName
}},
{
"Condition"
,
{
"whileExitCond"
}}},
{{
"Out"
,
{}},
{
"StepScopes"
,
{
"stepScopes"
}}},
{{
"sub_block"
,
whileBlock
}},
parentBlock
);
}
TEST
(
Concurrency
,
Go_Op
)
{
Scope
scope
;
p
::
CPUPlace
place
;
// Initialize scope variables
p
::
CPUDeviceContext
ctx
(
place
);
// Create channel variable
scope
.
Var
(
"Channel"
);
// Create Variables, x0 will be put into channel,
// result will be pulled from channel
CreateVariable
(
&
scope
,
place
,
"Status"
,
false
);
CreateVariable
(
&
scope
,
place
,
"x0"
,
99
);
CreateVariable
(
&
scope
,
place
,
"result"
,
0
);
framework
::
Executor
executor
(
place
);
ProgramDesc
program
;
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
// Create channel OP
AddOp
(
"channel_create"
,
{},
{{
"Out"
,
{
"Channel"
}}},
{{
"capacity"
,
10
},
{
"data_type"
,
f
::
proto
::
VarType
::
LOD_TENSOR
}},
block
);
// Create Go Op routine
BlockDesc
*
goOpBlock
=
program
.
AppendBlock
(
program
.
Block
(
0
));
AddOp
(
"channel_send"
,
{{
"Channel"
,
{
"Channel"
}},
{
"X"
,
{
"x0"
}}},
{{
"Status"
,
{
"Status"
}}},
{},
goOpBlock
);
// Create Go Op
AddOp
(
"go"
,
{{
"X"
,
{
"Channel"
,
"x0"
}}},
{},
{{
"sub_block"
,
goOpBlock
}},
block
);
// Create Channel Receive Op
AddOp
(
"channel_recv"
,
{{
"Channel"
,
{
"Channel"
}}},
{{
"Status"
,
{
"Status"
}},
{
"Out"
,
{
"result"
}}},
{},
block
);
// Create Channel Close Op
AddOp
(
"channel_close"
,
{{
"Channel"
,
{
"Channel"
}}},
{},
{},
block
);
// Check the result tensor to make sure it is set to 0
const
LoDTensor
&
tensor
=
(
scope
.
FindVar
(
"result"
))
->
Get
<
LoDTensor
>
();
auto
*
initialData
=
tensor
.
data
<
int
>
();
EXPECT_EQ
(
initialData
[
0
],
0
);
executor
.
Run
(
program
,
&
scope
,
0
,
true
,
true
);
// After we call executor.run, the Go operator should do a channel_send to
// set the "result" variable to 99.
auto
*
finalData
=
tensor
.
data
<
int
>
();
EXPECT_EQ
(
finalData
[
0
],
99
);
}
/**
* This test implements the fibonacci function using go_op and select_op
*/
TEST
(
Concurrency
,
Select
)
{
Scope
scope
;
p
::
CPUPlace
place
;
// Initialize scope variables
p
::
CPUDeviceContext
ctx
(
place
);
CreateVariable
(
&
scope
,
place
,
"Status"
,
false
);
CreateVariable
(
&
scope
,
place
,
"result"
,
0
);
CreateVariable
(
&
scope
,
place
,
"currentXFib"
,
0
);
framework
::
Executor
executor
(
place
);
ProgramDesc
program
;
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
// Create channel OP
std
::
string
dataChanName
=
"Channel"
;
scope
.
Var
(
dataChanName
);
AddOp
(
"channel_create"
,
{},
{{
"Out"
,
{
dataChanName
}}},
{{
"capacity"
,
0
},
{
"data_type"
,
f
::
proto
::
VarType
::
LOD_TENSOR
}},
block
);
std
::
string
quitChanName
=
"Quit"
;
scope
.
Var
(
quitChanName
);
AddOp
(
"channel_create"
,
{},
{{
"Out"
,
{
quitChanName
}}},
{{
"capacity"
,
0
},
{
"data_type"
,
f
::
proto
::
VarType
::
LOD_TENSOR
}},
block
);
// Create Go Op routine, which loops 10 times over fibonacci sequence
CreateVariable
(
&
scope
,
place
,
"xReceiveVar"
,
0
);
BlockDesc
*
goOpBlock
=
program
.
AppendBlock
(
program
.
Block
(
0
));
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
AddOp
(
"channel_recv"
,
{{
"Channel"
,
{
dataChanName
}}},
{{
"Status"
,
{
"Status"
}},
{
"Out"
,
{
"currentXFib"
}}},
{},
goOpBlock
);
AddOp
(
"print"
,
{{
"In"
,
{
"currentXFib"
}}},
{{
"Out"
,
{
"currentXFib"
}}},
{{
"first_n"
,
100
},
{
"summarize"
,
-
1
},
{
"print_tensor_name"
,
false
},
{
"print_tensor_type"
,
true
},
{
"print_tensor_shape"
,
false
},
{
"print_tensor_lod"
,
false
},
{
"print_phase"
,
std
::
string
(
"FORWARD"
)},
{
"message"
,
std
::
string
(
"X: "
)}},
goOpBlock
);
}
CreateVariable
(
&
scope
,
place
,
"quitSignal"
,
0
);
AddOp
(
"channel_send"
,
{{
"Channel"
,
{
quitChanName
}},
{
"X"
,
{
"quitSignal"
}}},
{{
"Status"
,
{
"Status"
}}},
{},
goOpBlock
);
// Create Go Op
AddOp
(
"go"
,
{{
"X"
,
{
dataChanName
,
quitChanName
}}},
{},
{{
"sub_block"
,
goOpBlock
}},
block
);
AddFibonacciSelect
(
&
scope
,
&
place
,
&
program
,
block
,
dataChanName
,
quitChanName
);
// Create Channel Close Op
AddOp
(
"channel_close"
,
{{
"Channel"
,
{
dataChanName
}}},
{},
{},
block
);
AddOp
(
"channel_close"
,
{{
"Channel"
,
{
quitChanName
}}},
{},
{},
block
);
executor
.
Run
(
program
,
&
scope
,
0
,
true
,
true
);
// After we call executor.run, "result" variable should be equal to 34
// (which is 10 loops through fibonacci sequence)
const
LoDTensor
&
tensor
=
(
scope
.
FindVar
(
"currentXFib"
))
->
Get
<
LoDTensor
>
();
auto
*
finalData
=
tensor
.
data
<
int
>
();
EXPECT_EQ
(
finalData
[
0
],
34
);
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/executor.cc
浏览文件 @
425a8821
...
@@ -14,7 +14,6 @@ limitations under the License. */
...
@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
...
@@ -76,15 +75,13 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
...
@@ -76,15 +75,13 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
var
->
GetMutable
<
platform
::
PlaceList
>
();
var
->
GetMutable
<
platform
::
PlaceList
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
READER
)
{
}
else
if
(
var_type
==
proto
::
VarType
::
READER
)
{
var
->
GetMutable
<
ReaderHolder
>
();
var
->
GetMutable
<
ReaderHolder
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
CHANNEL
)
{
var
->
GetMutable
<
ChannelHolder
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
RAW
)
{
}
else
if
(
var_type
==
proto
::
VarType
::
RAW
)
{
// GetMutable will be called in operator
// GetMutable will be called in operator
}
else
{
}
else
{
PADDLE_THROW
(
PADDLE_THROW
(
"Variable type %d is not in "
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER,
CHANNEL,
RAW]"
,
"LOD_RANK_TABLE, PLACE_LIST, READER, RAW]"
,
var_type
);
var_type
);
}
}
}
}
...
...
paddle/fluid/framework/framework.proto
浏览文件 @
425a8821
...
@@ -126,7 +126,6 @@ message VarType {
...
@@ -126,7 +126,6 @@ message VarType {
LOD_TENSOR_ARRAY
=
13
;
LOD_TENSOR_ARRAY
=
13
;
PLACE_LIST
=
14
;
PLACE_LIST
=
14
;
READER
=
15
;
READER
=
15
;
CHANNEL
=
16
;
// Any runtime decided variable type is raw
// Any runtime decided variable type is raw
// raw variables should manage their own allocations
// raw variables should manage their own allocations
// in operators like nccl_op
// in operators like nccl_op
...
@@ -158,12 +157,6 @@ message VarType {
...
@@ -158,12 +157,6 @@ message VarType {
message
ReaderDesc
{
repeated
LoDTensorDesc
lod_tensor
=
1
;
}
message
ReaderDesc
{
repeated
LoDTensorDesc
lod_tensor
=
1
;
}
optional
ReaderDesc
reader
=
5
;
optional
ReaderDesc
reader
=
5
;
message
ChannelDesc
{
required
Type
data_type
=
1
;
required
int64
capacity
=
2
;
}
optional
ChannelDesc
channel
=
6
;
message
Tuple
{
repeated
Type
element_type
=
1
;
}
message
Tuple
{
repeated
Type
element_type
=
1
;
}
optional
Tuple
tuple
=
7
;
optional
Tuple
tuple
=
7
;
}
}
...
...
paddle/fluid/framework/naive_executor.cc
浏览文件 @
425a8821
...
@@ -12,11 +12,13 @@
...
@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/naive_executor.h"
#include <string>
#include "paddle/fluid/framework/channel.h"
#include <vector>
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/fluid/string/pretty_log.h"
...
@@ -44,8 +46,6 @@ static void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
...
@@ -44,8 +46,6 @@ static void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
var
->
GetMutable
<
platform
::
PlaceList
>
();
var
->
GetMutable
<
platform
::
PlaceList
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
READER
)
{
}
else
if
(
var_type
==
proto
::
VarType
::
READER
)
{
var
->
GetMutable
<
ReaderHolder
>
();
var
->
GetMutable
<
ReaderHolder
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
CHANNEL
)
{
var
->
GetMutable
<
ChannelHolder
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
RAW
)
{
}
else
if
(
var_type
==
proto
::
VarType
::
RAW
)
{
// GetMutable will be called in operator
// GetMutable will be called in operator
}
else
{
}
else
{
...
...
paddle/fluid/framework/tuple.h
浏览文件 @
425a8821
...
@@ -17,7 +17,6 @@ limitations under the License. */
...
@@ -17,7 +17,6 @@ limitations under the License. */
#include <stdexcept>
#include <stdexcept>
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_desc.h"
...
...
paddle/fluid/framework/var_desc.cc
浏览文件 @
425a8821
...
@@ -88,13 +88,7 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
...
@@ -88,13 +88,7 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
}
}
void
VarDesc
::
SetDataType
(
proto
::
VarType
::
Type
data_type
)
{
void
VarDesc
::
SetDataType
(
proto
::
VarType
::
Type
data_type
)
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
CHANNEL
:
mutable_channel_desc
()
->
set_data_type
(
data_type
);
break
;
default:
mutable_tensor_desc
()
->
set_data_type
(
data_type
);
mutable_tensor_desc
()
->
set_data_type
(
data_type
);
}
}
}
void
VarDesc
::
SetDataTypes
(
void
VarDesc
::
SetDataTypes
(
...
@@ -115,13 +109,7 @@ void VarDesc::SetDataTypes(
...
@@ -115,13 +109,7 @@ void VarDesc::SetDataTypes(
}
}
proto
::
VarType
::
Type
VarDesc
::
GetDataType
()
const
{
proto
::
VarType
::
Type
VarDesc
::
GetDataType
()
const
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
CHANNEL
:
return
channel_desc
().
data_type
();
break
;
default:
return
tensor_desc
().
data_type
();
return
tensor_desc
().
data_type
();
}
}
}
std
::
vector
<
proto
::
VarType
::
Type
>
VarDesc
::
GetDataTypes
()
const
{
std
::
vector
<
proto
::
VarType
::
Type
>
VarDesc
::
GetDataTypes
()
const
{
...
@@ -134,17 +122,6 @@ std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
...
@@ -134,17 +122,6 @@ std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
return
res
;
return
res
;
}
}
void
VarDesc
::
SetCapacity
(
int64_t
capacity
)
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
CHANNEL
:
desc_
.
mutable_type
()
->
mutable_channel
()
->
set_capacity
(
capacity
);
break
;
default:
PADDLE_THROW
(
"Setting 'capacity' is not supported by the type of var %s."
,
this
->
Name
());
}
}
void
VarDesc
::
SetLoDLevel
(
int32_t
lod_level
)
{
void
VarDesc
::
SetLoDLevel
(
int32_t
lod_level
)
{
switch
(
desc_
.
type
().
type
())
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
LOD_TENSOR
:
case
proto
::
VarType
::
LOD_TENSOR
:
...
@@ -214,19 +191,6 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
...
@@ -214,19 +191,6 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
}
}
}
}
const
proto
::
VarType
::
ChannelDesc
&
VarDesc
::
channel_desc
()
const
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var's type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
CHANNEL
:
return
desc_
.
type
().
channel
();
default:
PADDLE_THROW
(
"Getting 'channel_desc' is not supported by the type of var %s."
,
this
->
Name
());
}
}
const
proto
::
VarType
::
TensorDesc
&
VarDesc
::
tensor_desc
()
const
{
const
proto
::
VarType
::
TensorDesc
&
VarDesc
::
tensor_desc
()
const
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var's type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var's type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
...
@@ -262,20 +226,6 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
...
@@ -262,20 +226,6 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
}
}
}
}
proto
::
VarType
::
ChannelDesc
*
VarDesc
::
mutable_channel_desc
()
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
CHANNEL
:
return
desc_
.
mutable_type
()
->
mutable_channel
();
default:
PADDLE_THROW
(
"Getting 'mutable_channel_desc' is not supported by the type of var "
"%s."
,
this
->
Name
());
}
}
proto
::
VarType
::
TensorDesc
*
VarDesc
::
mutable_tensor_desc
()
{
proto
::
VarType
::
TensorDesc
*
VarDesc
::
mutable_tensor_desc
()
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
...
...
paddle/fluid/framework/var_desc.h
浏览文件 @
425a8821
...
@@ -87,8 +87,6 @@ class VarDesc {
...
@@ -87,8 +87,6 @@ class VarDesc {
void
SetDataTypes
(
void
SetDataTypes
(
const
std
::
vector
<
proto
::
VarType
::
Type
>
&
multiple_data_type
);
const
std
::
vector
<
proto
::
VarType
::
Type
>
&
multiple_data_type
);
void
SetCapacity
(
int64_t
capacity
);
proto
::
VarType
::
Type
GetDataType
()
const
;
proto
::
VarType
::
Type
GetDataType
()
const
;
std
::
vector
<
proto
::
VarType
::
Type
>
GetDataTypes
()
const
;
std
::
vector
<
proto
::
VarType
::
Type
>
GetDataTypes
()
const
;
...
@@ -110,10 +108,8 @@ class VarDesc {
...
@@ -110,10 +108,8 @@ class VarDesc {
void
SetPersistable
(
bool
persistable
)
{
desc_
.
set_persistable
(
persistable
);
}
void
SetPersistable
(
bool
persistable
)
{
desc_
.
set_persistable
(
persistable
);
}
private:
private:
const
proto
::
VarType
::
ChannelDesc
&
channel_desc
()
const
;
const
proto
::
VarType
::
TensorDesc
&
tensor_desc
()
const
;
const
proto
::
VarType
::
TensorDesc
&
tensor_desc
()
const
;
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
tensor_descs
()
const
;
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
tensor_descs
()
const
;
proto
::
VarType
::
ChannelDesc
*
mutable_channel_desc
();
proto
::
VarType
::
TensorDesc
*
mutable_tensor_desc
();
proto
::
VarType
::
TensorDesc
*
mutable_tensor_desc
();
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
mutable_tensor_descs
();
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
mutable_tensor_descs
();
...
...
paddle/fluid/framework/var_type.h
浏览文件 @
425a8821
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
...
@@ -41,8 +40,6 @@ inline proto::VarType::Type ToVarType(std::type_index type) {
...
@@ -41,8 +40,6 @@ inline proto::VarType::Type ToVarType(std::type_index type) {
return
proto
::
VarType_Type_SELECTED_ROWS
;
return
proto
::
VarType_Type_SELECTED_ROWS
;
}
else
if
(
IsType
<
ReaderHolder
>
(
type
))
{
}
else
if
(
IsType
<
ReaderHolder
>
(
type
))
{
return
proto
::
VarType_Type_READER
;
return
proto
::
VarType_Type_READER
;
}
else
if
(
IsType
<
ChannelHolder
>
(
type
))
{
return
proto
::
VarType_Type_CHANNEL
;
}
else
{
}
else
{
PADDLE_THROW
(
"ToVarType:Unsupported type %s"
,
type
.
name
());
PADDLE_THROW
(
"ToVarType:Unsupported type %s"
,
type
.
name
());
}
}
...
@@ -66,9 +63,6 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
...
@@ -66,9 +63,6 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
case
proto
::
VarType_Type_READER
:
case
proto
::
VarType_Type_READER
:
visitor
(
var
.
Get
<
ReaderHolder
>
());
visitor
(
var
.
Get
<
ReaderHolder
>
());
return
;
return
;
case
proto
::
VarType_Type_CHANNEL
:
visitor
(
var
.
Get
<
ChannelHolder
>
());
return
;
default:
default:
PADDLE_THROW
(
"Not supported visit type, %d"
,
ToVarType
(
var
.
Type
()));
PADDLE_THROW
(
"Not supported visit type, %d"
,
ToVarType
(
var
.
Type
()));
}
}
...
...
paddle/fluid/inference/analysis/analysis_pass.h
浏览文件 @
425a8821
...
@@ -41,12 +41,6 @@ class AnalysisPass {
...
@@ -41,12 +41,6 @@ class AnalysisPass {
// all passes have run.
// all passes have run.
virtual
bool
Finalize
()
{
return
false
;
}
virtual
bool
Finalize
()
{
return
false
;
}
// Get a Pass appropriate to print the Node this pass operates on.
virtual
AnalysisPass
*
CreatePrinterPass
(
std
::
ostream
&
os
,
const
std
::
string
&
banner
)
const
{
return
nullptr
;
}
// Create a debugger Pass that draw the DFG by graphviz toolkit.
// Create a debugger Pass that draw the DFG by graphviz toolkit.
virtual
AnalysisPass
*
CreateGraphvizDebugerPass
()
const
{
return
nullptr
;
}
virtual
AnalysisPass
*
CreateGraphvizDebugerPass
()
const
{
return
nullptr
;
}
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
425a8821
...
@@ -314,11 +314,6 @@ op_library(save_combine_op DEPS lod_tensor)
...
@@ -314,11 +314,6 @@ op_library(save_combine_op DEPS lod_tensor)
op_library
(
load_combine_op DEPS lod_tensor
)
op_library
(
load_combine_op DEPS lod_tensor
)
op_library
(
concat_op DEPS concat
)
op_library
(
concat_op DEPS concat
)
# FIXME(thuan): Move CSP operators to paddle/fluid/framework/operators/concurrency
add_subdirectory
(
concurrency
)
op_library
(
channel_send_op DEPS concurrency
)
op_library
(
channel_recv_op DEPS concurrency
)
list
(
REMOVE_ITEM GENERAL_OPS
${
DEPS_OPS
}
)
list
(
REMOVE_ITEM GENERAL_OPS
${
DEPS_OPS
}
)
foreach
(
src
${
GENERAL_OPS
}
)
foreach
(
src
${
GENERAL_OPS
}
)
...
...
paddle/fluid/operators/channel_close_op.cc
已删除
100644 → 0
浏览文件 @
23644940
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
pf
=
paddle
::
framework
;
static
constexpr
char
kChannel
[]
=
"Channel"
;
namespace
paddle
{
namespace
operators
{
class
ChannelCloseOp
:
public
framework
::
OperatorBase
{
public:
ChannelCloseOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
&
inp
=
*
scope
.
FindVar
(
Input
(
kChannel
));
// Get the mutable version of the channel variable and closes it.
pf
::
ChannelHolder
*
ch
=
inp
.
GetMutable
<
framework
::
ChannelHolder
>
();
ch
->
close
();
}
};
class
ChannelCloseOpOpInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{
PADDLE_ENFORCE
(
context
->
HasInput
(
"Channel"
),
"The input of ChannelClose op must be set"
);
}
};
class
ChannelCloseOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
kChannel
,
"The Channel Variable that should be closed by"
" the ChannelClose Op."
);
AddComment
(
R"DOC(
Channel Close Operator.
This operator closes an open channel.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OPERATOR
(
channel_close
,
paddle
::
operators
::
ChannelCloseOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
paddle
::
operators
::
ChannelCloseOpMaker
);
paddle/fluid/operators/channel_create_op.cc
已删除
100644 → 0
浏览文件 @
23644940
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
namespace
pf
=
paddle
::
framework
;
static
constexpr
char
kOutput
[]
=
"Out"
;
namespace
paddle
{
namespace
operators
{
class
ChannelCreateOp
:
public
framework
::
OperatorBase
{
public:
ChannelCreateOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
&
out
=
*
scope
.
FindVar
(
Output
(
kOutput
));
// Determine the datatype and capacity of the channel to be created
// from the attributes provided.
auto
dtype
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
Attr
<
int
>
(
"data_type"
));
auto
capacity
=
Attr
<
int
>
(
"capacity"
);
// Based on the datatype, create a new channel holder initialized with
// the given capacity. When capacity is 0, an unbuffered channel is
// created.
pf
::
ChannelHolder
*
ch
=
out
.
GetMutable
<
framework
::
ChannelHolder
>
();
if
(
dtype
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
ch
->
Reset
<
pf
::
LoDTensor
>
(
capacity
);
}
else
if
(
dtype
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
ch
->
Reset
<
pf
::
SelectedRows
>
(
capacity
);
}
else
if
(
dtype
==
framework
::
proto
::
VarType
::
LOD_RANK_TABLE
)
{
ch
->
Reset
<
pf
::
LoDRankTable
>
(
capacity
);
}
else
if
(
dtype
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
ch
->
Reset
<
pf
::
LoDTensorArray
>
(
capacity
);
}
else
if
(
dtype
==
framework
::
proto
::
VarType
::
READER
)
{
ch
->
Reset
<
pf
::
ReaderHolder
>
(
capacity
);
}
else
if
(
dtype
==
framework
::
proto
::
VarType
::
CHANNEL
)
{
ch
->
Reset
<
pf
::
ChannelHolder
>
(
capacity
);
}
else
if
(
dtype
==
framework
::
proto
::
VarType
::
BOOL
)
{
ch
->
Reset
<
bool
>
(
capacity
);
}
else
if
(
dtype
==
framework
::
proto
::
VarType
::
INT32
)
{
ch
->
Reset
<
int
>
(
capacity
);
}
else
if
(
dtype
==
framework
::
proto
::
VarType
::
INT64
)
{
ch
->
Reset
<
int64_t
>
(
capacity
);
}
else
if
(
dtype
==
framework
::
proto
::
VarType
::
FP32
)
{
ch
->
Reset
<
float
>
(
capacity
);
}
else
if
(
dtype
==
framework
::
proto
::
VarType
::
FP64
)
{
ch
->
Reset
<
double
>
(
capacity
);
}
else
{
PADDLE_THROW
(
"Data type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, LOD_RANK_TABLE, LOD_TENSOR_ARRAY, "
"READER, CHANNEL, BOOL, INT32, INT64, FP32, FP64]"
,
dtype
);
}
}
};
class
ChannelCreateOpOpInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{
PADDLE_ENFORCE
(
context
->
HasOutput
(
kOutput
),
"The output of ChannelCreate op must be set"
);
context
->
SetOutputDim
(
kOutput
,
{
1
});
}
};
class
ChannelCreateOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddOutput
(
kOutput
,
"The object of a Channel type created by ChannelCreate Op."
);
AddAttr
<
int
>
(
"capacity"
,
"The size of the buffer of Channel."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"data_type"
,
"The data type of elements inside the Channel."
);
AddComment
(
R"DOC(
Channel Create Operator.
This operator creates an object of the VarType Channel and returns it.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OPERATOR
(
channel_create
,
paddle
::
operators
::
ChannelCreateOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
paddle
::
operators
::
ChannelCreateOpMaker
);
paddle/fluid/operators/channel_recv_op.cc
已删除
100644 → 0
浏览文件 @
23644940
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include <paddle/fluid/framework/lod_rank_table.h>
#include <paddle/fluid/framework/lod_tensor_array.h>
#include <paddle/fluid/framework/reader.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/concurrency/channel_util.h"
#include "paddle/fluid/operators/math/math_function.h"
static
constexpr
char
Channel
[]
=
"Channel"
;
static
constexpr
char
Status
[]
=
"Status"
;
static
constexpr
char
Out
[]
=
"Out"
;
namespace
paddle
{
namespace
operators
{
void
SetReceiveStatus
(
const
platform
::
Place
&
dev_place
,
framework
::
Variable
*
status_var
,
bool
status
)
{
auto
cpu
=
platform
::
CPUPlace
();
auto
status_tensor
=
status_var
->
GetMutable
<
framework
::
LoDTensor
>
()
->
mutable_data
<
bool
>
({
1
},
cpu
);
status_tensor
[
0
]
=
status
;
}
class
ChannelRecvOp
:
public
framework
::
OperatorBase
{
public:
ChannelRecvOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
Channel
),
"Input(Channel) of ChannelRecvOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
Out
),
"Input(Channel) of ChannelRecvOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
Status
),
"Output(Status) of ChannelRecvOp should not be null."
);
ctx
->
SetOutputDim
(
"Status"
,
{
1
});
}
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
// Get the channel holder created by channel_create op, passed as input.
framework
::
ChannelHolder
*
ch
=
scope
.
FindVar
(
Input
(
Channel
))
->
GetMutable
<
framework
::
ChannelHolder
>
();
auto
output_var
=
scope
.
FindVar
(
Output
(
Out
));
// Receive the data from the channel.
bool
ok
=
concurrency
::
ChannelReceive
(
ch
,
output_var
);
// Set the status output of the `ChannelReceive` call.
SetReceiveStatus
(
dev_place
,
scope
.
FindVar
(
Output
(
Status
)),
ok
);
}
};
class
ChannelRecvOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
Channel
,
"(Channel) A variable which
\"
receives
\"
the a value sent"
"to it by a channel_send op."
)
.
AsDuplicable
();
AddOutput
(
Out
,
"(Variable) Output Variable that will hold the data received"
" from the Channel"
)
.
AsDuplicable
();
AddOutput
(
Status
,
"(Tensor) An LoD Tensor that returns a boolean status of the"
"result of the receive operation."
)
.
AsDuplicable
();
AddComment
(
R"DOC(
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OPERATOR
(
channel_recv
,
paddle
::
operators
::
ChannelRecvOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
paddle
::
operators
::
ChannelRecvOpMaker
);
paddle/fluid/operators/channel_send_op.cc
已删除
100644 → 0
浏览文件 @
23644940
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include <paddle/fluid/framework/lod_rank_table.h>
#include <paddle/fluid/framework/lod_tensor_array.h>
#include <paddle/fluid/framework/reader.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/concurrency/channel_util.h"
#include "paddle/fluid/operators/math/math_function.h"
static
constexpr
char
Channel
[]
=
"Channel"
;
static
constexpr
char
X
[]
=
"X"
;
namespace
paddle
{
namespace
operators
{
class
ChannelSendOp
:
public
framework
::
OperatorBase
{
public:
ChannelSendOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
Channel
),
"Input(Channel) of ChannelSendOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
X
),
"Input(X) of ChannelSendOp should not be null."
);
}
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
// Get the channel holder created by channel_create op, passed as input.
framework
::
ChannelHolder
*
ch
=
scope
.
FindVar
(
Input
(
Channel
))
->
GetMutable
<
framework
::
ChannelHolder
>
();
auto
input_var
=
scope
.
FindVar
(
Input
(
X
));
// Send the input data through the channel.
concurrency
::
ChannelSend
(
ch
,
input_var
);
}
};
class
ChannelSendOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
Channel
,
"(Channel) A variable which
\"
sends
\"
the passed in value to "
"a listening receiver."
)
.
AsDuplicable
();
AddInput
(
X
,
"(Variable) The value which gets sent by the channel."
)
.
AsDuplicable
();
AddComment
(
R"DOC(
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OPERATOR
(
channel_send
,
paddle
::
operators
::
ChannelSendOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
paddle
::
operators
::
ChannelSendOpMaker
);
paddle/fluid/operators/concurrency/CMakeLists.txt
已删除
100644 → 0
浏览文件 @
23644940
cc_library
(
concurrency SRCS channel_util.cc DEPS device_context framework_proto boost eigen3
)
paddle/fluid/operators/concurrency/channel_util.cc
已删除
100644 → 0
浏览文件 @
23644940
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/concurrency/channel_util.h"
#include "paddle/fluid/framework/var_type.h"
namespace
poc
=
paddle
::
operators
::
concurrency
;
void
poc
::
ChannelSend
(
framework
::
ChannelHolder
*
ch
,
framework
::
Variable
*
var
)
{
auto
type
=
framework
::
ToVarType
(
var
->
Type
());
if
(
type
==
framework
::
proto
::
VarType_Type_LOD_TENSOR
)
ch
->
Send
(
var
->
GetMutable
<
framework
::
LoDTensor
>
());
else
if
(
type
==
framework
::
proto
::
VarType_Type_LOD_RANK_TABLE
)
ch
->
Send
(
var
->
GetMutable
<
framework
::
LoDRankTable
>
());
else
if
(
type
==
framework
::
proto
::
VarType_Type_LOD_TENSOR_ARRAY
)
ch
->
Send
(
var
->
GetMutable
<
framework
::
LoDTensorArray
>
());
else
if
(
type
==
framework
::
proto
::
VarType_Type_SELECTED_ROWS
)
ch
->
Send
(
var
->
GetMutable
<
framework
::
SelectedRows
>
());
else
if
(
type
==
framework
::
proto
::
VarType_Type_READER
)
ch
->
Send
(
var
->
GetMutable
<
framework
::
ReaderHolder
>
());
else
if
(
type
==
framework
::
proto
::
VarType_Type_CHANNEL
)
ch
->
Send
(
var
->
GetMutable
<
framework
::
ChannelHolder
>
());
else
PADDLE_THROW
(
"ChannelSend:Unsupported type"
);
}
bool
poc
::
ChannelReceive
(
framework
::
ChannelHolder
*
ch
,
framework
::
Variable
*
var
)
{
// Get type of channel and use that to call mutable data for Variable
auto
type
=
framework
::
ToVarType
(
ch
->
Type
());
if
(
type
==
framework
::
proto
::
VarType_Type_LOD_TENSOR
)
return
ch
->
Receive
(
var
->
GetMutable
<
framework
::
LoDTensor
>
());
else
if
(
type
==
framework
::
proto
::
VarType_Type_LOD_RANK_TABLE
)
return
ch
->
Receive
(
var
->
GetMutable
<
framework
::
LoDRankTable
>
());
else
if
(
type
==
framework
::
proto
::
VarType_Type_LOD_TENSOR_ARRAY
)
return
ch
->
Receive
(
var
->
GetMutable
<
framework
::
LoDTensorArray
>
());
else
if
(
type
==
framework
::
proto
::
VarType_Type_SELECTED_ROWS
)
return
ch
->
Receive
(
var
->
GetMutable
<
framework
::
SelectedRows
>
());
else
if
(
type
==
framework
::
proto
::
VarType_Type_READER
)
return
ch
->
Receive
(
var
->
GetMutable
<
framework
::
ReaderHolder
>
());
else
if
(
type
==
framework
::
proto
::
VarType_Type_CHANNEL
)
return
ch
->
Receive
(
var
->
GetMutable
<
framework
::
ChannelHolder
>
());
else
PADDLE_THROW
(
"ChannelReceive:Unsupported type"
);
}
void
poc
::
ChannelAddToSendQ
(
framework
::
ChannelHolder
*
ch
,
const
void
*
referrer
,
framework
::
Variable
*
var
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
framework
::
ChannelAction
)
>
cb
)
{
auto
type
=
framework
::
ToVarType
(
var
->
Type
());
if
(
type
==
framework
::
proto
::
VarType_Type_LOD_TENSOR
)
{
ch
->
AddToSendQ
(
referrer
,
var
->
GetMutable
<
framework
::
LoDTensor
>
(),
cond
,
cb
);
}
else
if
(
type
==
framework
::
proto
::
VarType_Type_LOD_RANK_TABLE
)
{
ch
->
AddToSendQ
(
referrer
,
var
->
GetMutable
<
framework
::
LoDRankTable
>
(),
cond
,
cb
);
}
else
if
(
type
==
framework
::
proto
::
VarType_Type_LOD_TENSOR_ARRAY
)
{
ch
->
AddToSendQ
(
referrer
,
var
->
GetMutable
<
framework
::
LoDTensorArray
>
(),
cond
,
cb
);
}
else
if
(
type
==
framework
::
proto
::
VarType_Type_SELECTED_ROWS
)
{
ch
->
AddToSendQ
(
referrer
,
var
->
GetMutable
<
framework
::
SelectedRows
>
(),
cond
,
cb
);
}
else
if
(
type
==
framework
::
proto
::
VarType_Type_READER
)
{
ch
->
AddToSendQ
(
referrer
,
var
->
GetMutable
<
framework
::
ReaderHolder
>
(),
cond
,
cb
);
}
else
if
(
type
==
framework
::
proto
::
VarType_Type_CHANNEL
)
{
ch
->
AddToSendQ
(
referrer
,
var
->
GetMutable
<
framework
::
ChannelHolder
>
(),
cond
,
cb
);
}
else
{
PADDLE_THROW
(
"ChannelAddToSendQ:Unsupported type"
);
}
}
void
poc
::
ChannelAddToReceiveQ
(
framework
::
ChannelHolder
*
ch
,
const
void
*
referrer
,
framework
::
Variable
*
var
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
framework
::
ChannelAction
)
>
cb
)
{
auto
type
=
framework
::
ToVarType
(
var
->
Type
());
if
(
type
==
framework
::
proto
::
VarType_Type_LOD_TENSOR
)
{
ch
->
AddToReceiveQ
(
referrer
,
var
->
GetMutable
<
framework
::
LoDTensor
>
(),
cond
,
cb
);
}
else
if
(
type
==
framework
::
proto
::
VarType_Type_LOD_RANK_TABLE
)
{
ch
->
AddToReceiveQ
(
referrer
,
var
->
GetMutable
<
framework
::
LoDRankTable
>
(),
cond
,
cb
);
}
else
if
(
type
==
framework
::
proto
::
VarType_Type_LOD_TENSOR_ARRAY
)
{
ch
->
AddToReceiveQ
(
referrer
,
var
->
GetMutable
<
framework
::
LoDTensorArray
>
(),
cond
,
cb
);
}
else
if
(
type
==
framework
::
proto
::
VarType_Type_SELECTED_ROWS
)
{
ch
->
AddToReceiveQ
(
referrer
,
var
->
GetMutable
<
framework
::
SelectedRows
>
(),
cond
,
cb
);
}
else
if
(
type
==
framework
::
proto
::
VarType_Type_READER
)
{
ch
->
AddToReceiveQ
(
referrer
,
var
->
GetMutable
<
framework
::
ReaderHolder
>
(),
cond
,
cb
);
}
else
if
(
type
==
framework
::
proto
::
VarType_Type_CHANNEL
)
{
ch
->
AddToReceiveQ
(
referrer
,
var
->
GetMutable
<
framework
::
ChannelHolder
>
(),
cond
,
cb
);
}
else
{
PADDLE_THROW
(
"ChannelAddToReceiveQ:Unsupported type"
);
}
}
paddle/fluid/operators/concurrency/channel_util.h
已删除
100644 → 0
浏览文件 @
23644940
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/variable.h"
namespace
paddle
{
namespace
operators
{
namespace
concurrency
{
void
ChannelSend
(
framework
::
ChannelHolder
*
ch
,
framework
::
Variable
*
var
);
bool
ChannelReceive
(
framework
::
ChannelHolder
*
ch
,
framework
::
Variable
*
var
);
void
ChannelAddToSendQ
(
framework
::
ChannelHolder
*
ch
,
const
void
*
referrer
,
framework
::
Variable
*
var
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
framework
::
ChannelAction
)
>
cb
);
void
ChannelAddToReceiveQ
(
framework
::
ChannelHolder
*
ch
,
const
void
*
referrer
,
framework
::
Variable
*
var
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
framework
::
ChannelAction
)
>
cb
);
}
// namespace concurrency
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/grpc_client.h
浏览文件 @
425a8821
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <time.h>
#include <time.h>
#include <atomic>
#include <chrono> // NOLINT
#include <chrono> // NOLINT
#include <condition_variable> // NOLINT
#include <condition_variable> // NOLINT
...
...
paddle/fluid/operators/distributed/request_handler.h
浏览文件 @
425a8821
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#pragma once
#pragma once
#include <time.h>
#include <time.h>
#include <condition_variable> // NOLINT
#include <functional>
#include <functional>
#include <string>
#include <string>
...
...
paddle/fluid/operators/distributed/rpc_server.h
浏览文件 @
425a8821
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#pragma once
#pragma once
#include <atomic>
#include <set>
#include <set>
#include <string>
#include <string>
#include <thread> // NOLINT
#include <thread> // NOLINT
...
...
paddle/fluid/operators/select_op.cc
已删除
100644 → 0
浏览文件 @
23644940
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/concurrency/channel_util.h"
#include <boost/tokenizer.hpp>
namespace
paddle
{
namespace
operators
{
static
constexpr
char
kX
[]
=
"X"
;
static
constexpr
char
kCaseToExecute
[]
=
"case_to_execute"
;
static
constexpr
char
kOutputs
[]
=
"Out"
;
static
constexpr
char
kCases
[]
=
"cases"
;
static
constexpr
char
kCasesBlock
[]
=
"sub_block"
;
class
SelectOp
:
public
framework
::
OperatorBase
{
public:
SelectOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
enum
class
SelectOpCaseType
{
DEFAULT
=
0
,
SEND
=
1
,
RECEIVE
=
2
,
};
struct
SelectOpCase
{
int
caseIndex
;
SelectOpCaseType
caseType
;
std
::
string
channelName
;
std
::
string
varName
;
SelectOpCase
()
{}
SelectOpCase
(
int
caseIndex
,
SelectOpCaseType
caseType
,
std
::
string
channelName
,
std
::
string
varName
)
:
caseIndex
(
caseIndex
),
caseType
(
caseType
),
channelName
(
channelName
),
varName
(
varName
)
{}
};
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
std
::
vector
<
std
::
string
>
casesConfigs
=
Attr
<
std
::
vector
<
std
::
string
>>
(
kCases
);
framework
::
BlockDesc
*
casesBlock
=
Attr
<
framework
::
BlockDesc
*>
(
kCasesBlock
);
framework
::
Scope
&
casesBlockScope
=
scope
.
NewScope
();
std
::
string
caseToExecuteVarName
=
Input
(
kCaseToExecute
);
framework
::
Variable
*
caseToExecuteVar
=
casesBlockScope
.
FindVar
(
caseToExecuteVarName
);
// Construct cases from "conditional_block_op"(s) in the casesBlock
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>
cases
=
ParseAndShuffleCases
(
&
casesConfigs
);
// Get all unique channels involved in select
std
::
set
<
framework
::
ChannelHolder
*>
channelsSet
;
for
(
auto
c
:
cases
)
{
if
(
!
c
->
channelName
.
empty
())
{
auto
channelVar
=
scope
.
FindVar
(
c
->
channelName
);
framework
::
ChannelHolder
*
ch
=
channelVar
->
GetMutable
<
framework
::
ChannelHolder
>
();
if
(
channelsSet
.
find
(
ch
)
==
channelsSet
.
end
())
{
channelsSet
.
insert
(
ch
);
}
}
}
// Order all channels by their pointer address
std
::
vector
<
framework
::
ChannelHolder
*>
channels
(
channelsSet
.
begin
(),
channelsSet
.
end
());
std
::
sort
(
channels
.
begin
(),
channels
.
end
());
// Poll all cases
int32_t
caseToExecute
=
pollCases
(
&
scope
,
&
cases
,
channels
);
// At this point, the case to execute has already been determined,
// so we can proceed with executing the cases block
framework
::
LoDTensor
*
caseToExecuteTensor
=
caseToExecuteVar
->
GetMutable
<
framework
::
LoDTensor
>
();
caseToExecuteTensor
->
data
<
int32_t
>
()[
0
]
=
caseToExecute
;
// Execute the cases block, only one case will be executed since we set the
// case_to_execute value to the index of the case we want to execute
framework
::
Executor
executor
(
dev_place
);
framework
::
ProgramDesc
*
program
=
casesBlock
->
Program
();
executor
.
Run
(
*
program
,
&
casesBlockScope
,
casesBlock
->
ID
(),
false
/*create_local_scope*/
);
}
/**
* Goes through all operators in the casesConfigs and processes
* "conditional_block" operators. These operators are mapped to our
* SelectOpCase objects. We randomize the case orders, and set the
* default case (if any exists) as the last case)
* @param casesBlock
* @return
*/
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>
ParseAndShuffleCases
(
std
::
vector
<
std
::
string
>
*
casesConfigs
)
const
{
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>
cases
;
std
::
shared_ptr
<
SelectOpCase
>
defaultCase
;
if
(
casesConfigs
!=
nullptr
)
{
boost
::
char_delimiters_separator
<
char
>
sep
(
false
,
","
,
""
);
for
(
std
::
vector
<
std
::
string
>::
iterator
itr
=
casesConfigs
->
begin
();
itr
<
casesConfigs
->
end
();
++
itr
)
{
std
::
string
caseConfig
=
*
itr
;
boost
::
tokenizer
<>
tokens
(
caseConfig
,
sep
);
boost
::
tokenizer
<>::
iterator
tok_iter
=
tokens
.
begin
();
PADDLE_ENFORCE
(
tok_iter
!=
tokens
.
end
(),
"Cannot get case index"
);
std
::
string
caseIndexString
=
*
tok_iter
;
int
caseIndex
=
std
::
stoi
(
caseIndexString
);
++
tok_iter
;
PADDLE_ENFORCE
(
tok_iter
!=
tokens
.
end
(),
"Cannot get case type"
);
std
::
string
caseTypeString
=
*
tok_iter
;
SelectOpCaseType
caseType
=
(
SelectOpCaseType
)
std
::
stoi
(
caseTypeString
);
std
::
string
caseChannel
;
std
::
string
caseChannelVar
;
++
tok_iter
;
if
(
caseType
!=
SelectOpCaseType
::
DEFAULT
)
{
PADDLE_ENFORCE
(
tok_iter
!=
tokens
.
end
(),
"Cannot get case channel"
);
caseChannel
=
*
tok_iter
;
++
tok_iter
;
PADDLE_ENFORCE
(
tok_iter
!=
tokens
.
end
(),
"Cannot get case channel variable"
);
caseChannelVar
=
*
tok_iter
;
}
auto
c
=
std
::
make_shared
<
SelectOpCase
>
(
caseIndex
,
caseType
,
caseChannel
,
caseChannelVar
);
if
(
caseType
==
SelectOpCaseType
::
DEFAULT
)
{
PADDLE_ENFORCE
(
defaultCase
==
nullptr
,
"Select can only contain one default case."
);
defaultCase
=
c
;
}
else
{
cases
.
push_back
(
c
);
}
}
}
// Randomly sort cases, with default case being last
std
::
random_shuffle
(
cases
.
begin
(),
cases
.
end
());
if
(
defaultCase
!=
nullptr
)
{
cases
.
push_back
(
defaultCase
);
}
return
cases
;
}
/**
* This method will recursively poll the cases and determines if any case
* condition is true.
* If none of the cases conditions are true (and there is no default case),
* then block
* the thread. The thread may be woken up by a channel operation, at which
* point we
* execute the case.
* @param scope
* @param cases
* @param channels
* @return
*/
int32_t
pollCases
(
const
framework
::
Scope
*
scope
,
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>
*
cases
,
std
::
vector
<
framework
::
ChannelHolder
*>
channels
)
const
{
// Lock all involved channels
lockChannels
(
channels
);
std
::
atomic
<
int
>
caseToExecute
(
-
1
);
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>::
iterator
it
=
cases
->
begin
();
while
(
it
!=
cases
->
end
())
{
std
::
shared_ptr
<
SelectOpCase
>
c
=
*
it
;
auto
chVar
=
scope
->
FindVar
(
c
->
channelName
);
framework
::
ChannelHolder
*
ch
=
chVar
->
GetMutable
<
framework
::
ChannelHolder
>
();
switch
(
c
->
caseType
)
{
case
SelectOpCaseType
::
SEND
:
PADDLE_ENFORCE
(
!
ch
->
IsClosed
(),
"Cannot send to a closed channel"
);
if
(
ch
->
CanSend
())
{
// We can send to channel directly, send the data to channel
// and execute case
auto
chVar
=
scope
->
FindVar
(
c
->
varName
);
concurrency
::
ChannelSend
(
ch
,
chVar
);
caseToExecute
=
c
->
caseIndex
;
}
break
;
case
SelectOpCaseType
::
RECEIVE
:
if
(
ch
->
CanReceive
())
{
// We can receive from channel directly, send the data to channel
// and execute case
auto
chVar
=
scope
->
FindVar
(
c
->
varName
);
concurrency
::
ChannelReceive
(
ch
,
chVar
);
caseToExecute
=
c
->
caseIndex
;
}
break
;
case
SelectOpCaseType
::
DEFAULT
:
caseToExecute
=
c
->
caseIndex
;
break
;
}
if
(
caseToExecute
!=
-
1
)
{
// We found a case to execute, stop looking at other case statements
break
;
}
++
it
;
}
if
(
caseToExecute
==
-
1
)
{
// None of the cases are eligible to execute, enqueue current thread
// into all the sending/receiving queue of each involved channel
std
::
atomic
<
bool
>
completed
(
false
);
std
::
recursive_mutex
mutex
;
std
::
unique_lock
<
std
::
recursive_mutex
>
lock
{
mutex
};
// std::condition_variable_any selectCond;
auto
selectCond
=
std
::
make_shared
<
std
::
condition_variable_any
>
();
std
::
recursive_mutex
callbackMutex
;
pushThreadOnChannelQueues
(
scope
,
cases
,
selectCond
,
&
caseToExecute
,
&
completed
,
&
callbackMutex
);
// TODO(thuan): Atomically unlock all channels and sleep current thread
unlockChannels
(
channels
);
selectCond
->
wait
(
lock
,
[
&
completed
]()
{
return
completed
.
load
();
});
// Select has been woken up by case operation
lockChannels
(
channels
);
removeThreadOnChannelQueues
(
scope
,
cases
);
if
(
caseToExecute
==
-
1
)
{
// Recursively poll cases, since we were woken up by a channel close
// TODO(thuan): Need to test if this is a valid case
unlockChannels
(
channels
);
return
pollCases
(
scope
,
cases
,
channels
);
}
}
// At this point, caseToExecute != -1, and we can proceed with executing
// the case block
unlockChannels
(
channels
);
return
caseToExecute
;
}
void
lockChannels
(
std
::
vector
<
framework
::
ChannelHolder
*>
chs
)
const
{
std
::
vector
<
framework
::
ChannelHolder
*>::
iterator
it
=
chs
.
begin
();
while
(
it
!=
chs
.
end
())
{
framework
::
ChannelHolder
*
ch
=
*
it
;
ch
->
Lock
();
++
it
;
}
}
void
unlockChannels
(
std
::
vector
<
framework
::
ChannelHolder
*>
chs
)
const
{
std
::
vector
<
framework
::
ChannelHolder
*>::
reverse_iterator
it
=
chs
.
rbegin
();
while
(
it
!=
chs
.
rend
())
{
framework
::
ChannelHolder
*
ch
=
*
it
;
ch
->
Unlock
();
++
it
;
}
}
void
pushThreadOnChannelQueues
(
const
framework
::
Scope
*
scope
,
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>
*
cases
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
rCond
,
std
::
atomic
<
int
>
*
caseToExecute
,
std
::
atomic
<
bool
>
*
completed
,
std
::
recursive_mutex
*
callbackMutex
)
const
{
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>::
iterator
it
=
cases
->
begin
();
while
(
it
!=
cases
->
end
())
{
std
::
shared_ptr
<
SelectOpCase
>
c
=
*
it
;
auto
chVar
=
scope
->
FindVar
(
c
->
channelName
);
framework
::
ChannelHolder
*
ch
=
chVar
->
GetMutable
<
framework
::
ChannelHolder
>
();
std
::
function
<
bool
(
framework
::
ChannelAction
channelAction
)
>
cb
=
[
&
caseToExecute
,
&
completed
,
&
callbackMutex
,
c
](
framework
::
ChannelAction
channelAction
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
*
callbackMutex
};
bool
canProcess
=
false
;
if
(
!
(
*
completed
))
{
// If the channel wasn't closed, we set the caseToExecute index
// as this current case
if
(
channelAction
!=
framework
::
ChannelAction
::
CLOSE
)
{
*
caseToExecute
=
c
->
caseIndex
;
}
// This will allow our conditional variable to break out of wait
*
completed
=
true
;
canProcess
=
true
;
}
return
canProcess
;
};
switch
(
c
->
caseType
)
{
case
SelectOpCaseType
::
SEND
:
{
auto
chOutputVar
=
scope
->
FindVar
(
c
->
varName
);
concurrency
::
ChannelAddToSendQ
(
ch
,
this
,
chOutputVar
,
rCond
,
cb
);
break
;
}
case
SelectOpCaseType
::
RECEIVE
:
{
auto
chOutputVar
=
scope
->
FindVar
(
c
->
varName
);
concurrency
::
ChannelAddToReceiveQ
(
ch
,
this
,
chOutputVar
,
rCond
,
cb
);
break
;
}
default:
break
;
}
++
it
;
}
}
void
removeThreadOnChannelQueues
(
const
framework
::
Scope
*
scope
,
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>
*
cases
)
const
{
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>::
iterator
it
=
cases
->
begin
();
while
(
it
!=
cases
->
end
())
{
std
::
shared_ptr
<
SelectOpCase
>
c
=
*
it
;
auto
chVar
=
scope
->
FindVar
(
c
->
channelName
);
framework
::
ChannelHolder
*
ch
=
chVar
->
GetMutable
<
framework
::
ChannelHolder
>
();
switch
(
c
->
caseType
)
{
case
SelectOpCaseType
::
SEND
:
{
ch
->
RemoveFromSendQ
(
this
);
break
;
}
case
SelectOpCaseType
::
RECEIVE
:
{
ch
->
RemoveFromReceiveQ
(
this
);
break
;
}
default:
break
;
}
++
it
;
}
}
};
class
SelectOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
kX
,
"A set of variables, which are required by operators inside the "
"cases of Select Op"
)
.
AsDuplicable
();
AddInput
(
kCaseToExecute
,
"(Int) The variable the sets the index of the case to execute, "
"after evaluating the channels being sent to and received from"
)
.
AsDuplicable
();
AddOutput
(
kOutputs
,
"A set of variables, which will be assigned with values "
"generated by the operators inside the cases of Select Op."
)
.
AsDuplicable
();
AddAttr
<
std
::
vector
<
std
::
string
>>
(
kCases
,
"(String vector) Serialized list of"
"all cases in the select op. Each"
"case is serialized as: "
"'<index>,<type>,<channel>,<value>'"
"where type is 0 for default, 1 for"
"send, and 2 for receive"
"No channel and values are needed for"
"default cases."
);
AddAttr
<
framework
::
BlockDesc
*>
(
kCasesBlock
,
"The cases block inside select_op"
);
AddComment
(
R"DOC(
)DOC"
);
}
};
// TODO(thuan): Implement Gradient Operator for SELECT_OP
}
// namespace operators
}
// namespace paddle
REGISTER_OPERATOR
(
select
,
paddle
::
operators
::
SelectOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
paddle
::
operators
::
SelectOpMaker
);
paddle/fluid/pybind/protobuf.cc
浏览文件 @
425a8821
...
@@ -214,7 +214,6 @@ void BindVarDsec(pybind11::module *m) {
...
@@ -214,7 +214,6 @@ void BindVarDsec(pybind11::module *m) {
.
def
(
"set_shapes"
,
&
pd
::
VarDesc
::
SetShapes
)
.
def
(
"set_shapes"
,
&
pd
::
VarDesc
::
SetShapes
)
.
def
(
"set_dtype"
,
&
pd
::
VarDesc
::
SetDataType
)
.
def
(
"set_dtype"
,
&
pd
::
VarDesc
::
SetDataType
)
.
def
(
"set_dtypes"
,
&
pd
::
VarDesc
::
SetDataTypes
)
.
def
(
"set_dtypes"
,
&
pd
::
VarDesc
::
SetDataTypes
)
.
def
(
"set_capacity"
,
&
pd
::
VarDesc
::
SetCapacity
)
.
def
(
"shape"
,
&
pd
::
VarDesc
::
GetShape
,
.
def
(
"shape"
,
&
pd
::
VarDesc
::
GetShape
,
pybind11
::
return_value_policy
::
reference
)
pybind11
::
return_value_policy
::
reference
)
.
def
(
"shapes"
,
&
pd
::
VarDesc
::
GetShapes
,
.
def
(
"shapes"
,
&
pd
::
VarDesc
::
GetShapes
,
...
@@ -251,7 +250,6 @@ void BindVarDsec(pybind11::module *m) {
...
@@ -251,7 +250,6 @@ void BindVarDsec(pybind11::module *m) {
.
value
(
"STEP_SCOPES"
,
pd
::
proto
::
VarType
::
STEP_SCOPES
)
.
value
(
"STEP_SCOPES"
,
pd
::
proto
::
VarType
::
STEP_SCOPES
)
.
value
(
"LOD_RANK_TABLE"
,
pd
::
proto
::
VarType
::
LOD_RANK_TABLE
)
.
value
(
"LOD_RANK_TABLE"
,
pd
::
proto
::
VarType
::
LOD_RANK_TABLE
)
.
value
(
"LOD_TENSOR_ARRAY"
,
pd
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
.
value
(
"LOD_TENSOR_ARRAY"
,
pd
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
.
value
(
"CHANNEL"
,
pd
::
proto
::
VarType
::
CHANNEL
)
.
value
(
"PLACE_LIST"
,
pd
::
proto
::
VarType
::
PLACE_LIST
)
.
value
(
"PLACE_LIST"
,
pd
::
proto
::
VarType
::
PLACE_LIST
)
.
value
(
"READER"
,
pd
::
proto
::
VarType
::
READER
)
.
value
(
"READER"
,
pd
::
proto
::
VarType
::
READER
)
.
value
(
"RAW"
,
pd
::
proto
::
VarType
::
RAW
);
.
value
(
"RAW"
,
pd
::
proto
::
VarType
::
RAW
);
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
425a8821
...
@@ -21,7 +21,6 @@ limitations under the License. */
...
@@ -21,7 +21,6 @@ limitations under the License. */
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/framework.pb.h"
...
...
python/paddle/fluid/concurrency.py
已删除
100644 → 0
浏览文件 @
23644940
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
.layers.control_flow
import
BlockGuard
,
equal
from
.framework
import
Operator
from
.layer_helper
import
LayerHelper
,
unique_name
from
.layers
import
fill_constant
from
.
import
core
__all__
=
[
'make_channel'
,
'channel_send'
,
'channel_recv'
,
'channel_close'
,
'Select'
]
class
Go
(
BlockGuard
):
def
__init__
(
self
,
name
=
None
):
self
.
helper
=
LayerHelper
(
"go"
,
name
=
name
)
super
(
Go
,
self
).
__init__
(
self
.
helper
.
main_program
)
def
__enter__
(
self
):
super
(
Go
,
self
).
__enter__
()
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
if
exc_type
is
not
None
:
return
False
self
.
_construct_go_op
()
return
super
(
Go
,
self
).
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
def
_construct_go_op
(
self
):
main_program
=
self
.
helper
.
main_program
go_block
=
main_program
.
current_block
()
parent_block
=
main_program
.
block
(
main_program
.
current_block
()
.
parent_idx
)
inner_outputs
=
set
()
x_name_list
=
set
()
for
op
in
go_block
.
ops
:
# Iterate over all operators, get all the inputs
# and add as input to the Go operator.
for
iname
in
op
.
input_names
:
for
in_var_name
in
op
.
input
(
iname
):
if
in_var_name
not
in
inner_outputs
:
x_name_list
.
add
(
in_var_name
)
for
oname
in
op
.
output_names
:
for
out_var_name
in
op
.
output
(
oname
):
inner_outputs
.
add
(
out_var_name
)
# Iterate over all operators , get all the outputs
# add to the output list of Go operator only if
# they exist in the parent block.
out_vars
=
[]
for
inner_out_name
in
inner_outputs
:
if
inner_out_name
in
parent_block
.
vars
:
out_vars
.
append
(
parent_block
.
var
(
inner_out_name
))
parent_block
.
append_op
(
type
=
'go'
,
inputs
=
{
'X'
:
[
parent_block
.
_var_recursive
(
x_name
)
for
x_name
in
x_name_list
]
},
outputs
=
{},
attrs
=
{
'sub_block'
:
go_block
})
class
SelectCase
(
object
):
DEFAULT
=
0
SEND
=
1
RECEIVE
=
2
def
__init__
(
self
,
select
,
case_idx
,
case_to_execute
,
channel_action_fn
=
None
,
channel
=
None
,
value
=
None
,
is_copy
=
False
):
self
.
select
=
select
self
.
helper
=
LayerHelper
(
'conditional_block'
)
self
.
main_program
=
self
.
helper
.
main_program
self
.
is_scalar_condition
=
True
self
.
case_to_execute
=
case_to_execute
self
.
idx
=
case_idx
# Since we aren't going to use the `channel_send` or `channel_recv`
# functions directly, we just need to capture the name.
self
.
action
=
(
self
.
SEND
if
channel_action_fn
.
__name__
==
(
'channel_send'
)
else
self
.
RECEIVE
)
if
channel_action_fn
else
self
.
DEFAULT
X
=
value
if
self
.
action
==
self
.
SEND
and
is_copy
:
# We create of copy of the data we want to send
copied_X
=
self
.
select
.
parent_block
.
create_var
(
name
=
unique_name
.
generate
(
value
.
name
+
'_copy'
),
type
=
value
.
type
,
dtype
=
value
.
dtype
,
shape
=
value
.
shape
,
lod_level
=
value
.
lod_level
,
capacity
=
value
.
capacity
if
hasattr
(
value
,
'capacity'
)
else
None
,
)
self
.
select
.
parent_block
.
append_op
(
type
=
"assign"
,
inputs
=
{
"X"
:
value
},
outputs
=
{
"Out"
:
copied_X
})
X
=
copied_X
self
.
value
=
X
self
.
channel
=
channel
def
__enter__
(
self
):
self
.
block
=
self
.
main_program
.
_create_block
()
def
construct_op
(
self
):
main_program
=
self
.
helper
.
main_program
cases_block
=
main_program
.
current_block
()
inner_outputs
=
set
()
input_set
=
set
()
params
=
set
()
for
op
in
self
.
block
.
ops
:
# Iterate over all operators, get all the inputs
# and add as input to the SelectCase operator.
for
iname
in
op
.
input_names
:
for
in_var_name
in
op
.
input
(
iname
):
if
in_var_name
not
in
inner_outputs
:
input_set
.
add
(
in_var_name
)
for
oname
in
op
.
output_names
:
for
out_var_name
in
op
.
output
(
oname
):
inner_outputs
.
add
(
out_var_name
)
param_list
=
[
cases_block
.
var
(
each_name
)
for
each_name
in
params
if
each_name
not
in
input_set
]
# Iterate over all operators, get all the outputs
# add to the output list of SelectCase operator only if
# they exist in the parent block.
out_vars
=
[]
for
inner_out_name
in
inner_outputs
:
if
inner_out_name
in
cases_block
.
vars
:
out_vars
.
append
(
cases_block
.
var
(
inner_out_name
))
# First, create an op that will determine whether or not this is the
# conditional variable to execute.
should_execute_block
=
equal
(
fill_constant
(
shape
=
[
1
],
dtype
=
core
.
VarDesc
.
VarType
.
INT32
,
value
=
self
.
idx
),
self
.
case_to_execute
)
step_scope
=
cases_block
.
create_var
(
type
=
core
.
VarDesc
.
VarType
.
STEP_SCOPES
)
cases_block
.
append_op
(
type
=
'conditional_block'
,
inputs
=
{
'X'
:
[
should_execute_block
],
'Params'
:
param_list
},
outputs
=
{
'Out'
:
out_vars
,
'Scope'
:
[
step_scope
]},
attrs
=
{
'sub_block'
:
self
.
block
,
'is_scalar_condition'
:
self
.
is_scalar_condition
})
return
'%s,%s,%s,%s'
%
(
self
.
idx
,
self
.
action
,
self
.
channel
.
name
if
self
.
channel
else
''
,
self
.
value
.
name
if
self
.
value
else
''
)
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
self
.
main_program
.
_rollback
()
if
exc_type
is
not
None
:
return
False
# re-raise exception
return
True
class
Select
(
BlockGuard
):
def
__init__
(
self
,
name
=
None
):
self
.
helper
=
LayerHelper
(
'select'
,
name
=
name
)
self
.
parent_block
=
self
.
helper
.
main_program
.
current_block
()
self
.
cases
=
[]
super
(
Select
,
self
).
__init__
(
self
.
helper
.
main_program
)
self
.
case_to_execute
=
fill_constant
(
shape
=
[
1
],
dtype
=
core
.
VarDesc
.
VarType
.
INT32
,
value
=-
1
)
def
__enter__
(
self
):
super
(
Select
,
self
).
__enter__
()
return
self
def
case
(
self
,
channel_action_fn
,
channel
,
value
,
is_copy
=
False
):
"""Create a new block for this condition.
"""
select_case
=
SelectCase
(
self
,
len
(
self
.
cases
),
self
.
case_to_execute
,
channel_action_fn
,
channel
,
value
,
is_copy
)
self
.
cases
.
append
(
select_case
)
return
select_case
def
default
(
self
):
"""Create a default case block for this condition.
"""
default_case
=
SelectCase
(
self
,
len
(
self
.
cases
),
self
.
case_to_execute
)
self
.
cases
.
append
(
default_case
)
return
default_case
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
if
exc_type
is
not
None
:
return
False
# Create a select op and another block to wrap its
# case blocks.
select_block
=
self
.
helper
.
main_program
.
current_block
()
parent_block
=
self
.
helper
.
main_program
.
block
(
select_block
.
parent_idx
)
# Construct each case op, inside the newly created select block.
serialized_cases
=
[]
for
case
in
self
.
cases
:
serialized_cases
.
append
(
case
.
construct_op
())
intermediate
=
set
()
params
=
set
()
for
case_block
in
select_block
.
ops
:
if
case_block
.
attrs
and
'sub_block'
in
case_block
.
attrs
:
for
each_op
in
case_block
.
attrs
[
'sub_block'
].
ops
:
assert
isinstance
(
each_op
,
Operator
)
for
iname
in
each_op
.
input_names
:
for
in_var_name
in
each_op
.
input
(
iname
):
if
in_var_name
not
in
intermediate
:
params
.
add
(
in_var_name
)
for
oname
in
each_op
.
output_names
:
for
out_var_name
in
each_op
.
output
(
oname
):
intermediate
.
add
(
out_var_name
)
out_list
=
[
parent_block
.
var
(
var_name
)
for
var_name
in
parent_block
.
vars
if
var_name
in
intermediate
]
X
=
[
select_block
.
_var_recursive
(
x_name
)
for
x_name
in
params
]
# Needs to be used by `equal` inside the cases block.
X
.
append
(
self
.
case_to_execute
)
# Construct the select op.
parent_block
.
append_op
(
type
=
'select'
,
inputs
=
{
'X'
:
X
,
'case_to_execute'
:
self
.
case_to_execute
},
attrs
=
{
'sub_block'
:
select_block
,
'cases'
:
serialized_cases
},
outputs
=
{
'Out'
:
out_list
})
return
super
(
Select
,
self
).
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
def
make_channel
(
dtype
,
capacity
=
0
):
"""
Helps implementation of a concurrent program by creating a "channel" of
a defined data type. Channels allow for the passing of data in
concurrent scenarios - such as when using threads to divide computation.
Channels can be used to "send" and "receive" such data concurrently.
There are two kinds of channels: unbuffered and buffered. Unbuffered
channels have no capacity - and thus, block on send and only unblock only
once what they have sent has been received.
On the other hand, buffered channels are initialized with a capacity -
and do not block on sends.
Use this method in combination with `channel_send`, `channel_recv`,
`channel_close`, and `Go` to design a concurrent Paddle program.
Args:
dtype (ParamAttr|string): Data type of the data sent in the channel.
This data type should be the string name of a numpy data type.
capacity (ParamAttr|int): Size of the channel. Defaults to 0 for
to create an unbuffered channel.
Returns:
Variable: The channel variable that can be used to send an receive data
of the defined dtype.
Examples:
.. code-block:: python
ch = fluid.make_channel(dtype='int32', capacity=10)
...
# Code to execute in a Go block, which receives the channel data.
fluid.channel_send(ch, 100)
fluid.channel_close(ch)
"""
helper
=
LayerHelper
(
'channel_create'
,
**
locals
())
main_program
=
helper
.
main_program
make_channel_block
=
main_program
.
current_block
()
# Make a channel variable (using the channel data type) and make sure it
# persists into the global scope.
channel
=
helper
.
create_variable
(
name
=
unique_name
.
generate
(
'channel'
),
type
=
core
.
VarDesc
.
VarType
.
CHANNEL
,
persistable
=
True
)
create_channel_op
=
make_channel_block
.
append_op
(
type
=
"channel_create"
,
outputs
=
{
"Out"
:
channel
},
attrs
=
{
"data_type"
:
dtype
,
"capacity"
:
capacity
})
return
channel
def
channel_send
(
channel
,
value
,
is_copy
=
False
):
"""
Sends a value through a channel variable. Used by an unbuffered or buffered
channel to pass data from within or to a concurrent Go block, where
`channel_recv` to used to get the passed value.
Args:
channel (Variable|Channel): Channel variable created using
`make_channel`.
value (Variable): Value to send to channel
is_copy (bool): Copy data while channel send. If False, then data
is moved. The input cannot be used after move. (default False)
Returns:
Variable: The boolean status on whether or not the channel
successfully sent the passed value.
Examples:
.. code-block:: python
ch = fluid.make_channel(dtype='int32', capacity=10)
...
# Code to execute in a Go block, which receives the channel data.
fluid.channel_send(ch, 100)
"""
helper
=
LayerHelper
(
'channel_send'
,
**
locals
())
main_program
=
helper
.
main_program
channel_send_block
=
main_program
.
current_block
()
X
=
value
if
is_copy
:
copied_X
=
helper
.
create_variable
(
name
=
unique_name
.
generate
(
value
.
name
+
'_copy'
),
type
=
value
.
type
,
dtype
=
value
.
dtype
,
shape
=
value
.
shape
,
lod_level
=
value
.
lod_level
,
capacity
=
value
.
capacity
if
hasattr
(
value
,
'capacity'
)
else
None
)
assign_op
=
channel_send_block
.
append_op
(
type
=
"assign"
,
inputs
=
{
"X"
:
value
},
outputs
=
{
"Out"
:
copied_X
})
X
=
copied_X
channel_send_block
.
append_op
(
type
=
"channel_send"
,
inputs
=
{
"Channel"
:
channel
,
"X"
:
X
,
})
def
channel_recv
(
channel
,
return_value
):
"""
Receives a value through a channel variable. Used by an unbuffered or
buffered channel within a concurrent Go block to get data from originally
sent using `channel_send`, or from outside such a block where
`channel_send` is used to send the value.
Args:
channel (Variable|Channel): Channel variable created using
`make_channel`.
return_value (Variable): Variable to set as a result of running channel_recv_op
Returns:
Variable: The received value from the channel.
Variable: The boolean status on whether or not the channel
successfully received the passed value.
Examples:
.. code-block:: python
ch = fluid.make_channel(dtype='int32', capacity=10)
with fluid.Go():
returned_value, return_status = fluid.channel_recv(ch, 'int32')
# Code to send data through the channel.
"""
helper
=
LayerHelper
(
'channel_recv'
,
**
locals
())
main_program
=
helper
.
main_program
channel_recv_block
=
main_program
.
current_block
()
status
=
helper
.
create_variable
(
name
=
unique_name
.
generate
(
'status'
),
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
dtype
=
core
.
VarDesc
.
VarType
.
BOOL
)
channel_recv_op
=
channel_recv_block
.
append_op
(
type
=
"channel_recv"
,
inputs
=
{
"Channel"
:
channel
},
outputs
=
{
"Out"
:
return_value
,
"Status"
:
status
})
return
return_value
,
status
def
channel_close
(
channel
):
"""
Closes a channel created using `make_channel`.
Args:
channel (Variable|Channel): Channel variable created using
`make_channel`.
Examples:
.. code-block:: python
ch = fluid.make_channel(dtype='int32', capacity=10)
...
# Code to receive and send data through a channel
...
fluid.channel_close(ch)
"""
helper
=
LayerHelper
(
'channel_close'
,
**
locals
())
main_program
=
helper
.
main_program
channel_close_block
=
main_program
.
current_block
()
channel_close_op
=
channel_close_block
.
append_op
(
type
=
"channel_close"
,
inputs
=
{
"Channel"
:
channel
})
python/paddle/fluid/framework.py
浏览文件 @
425a8821
...
@@ -541,8 +541,7 @@ class Operator(object):
...
@@ -541,8 +541,7 @@ class Operator(object):
'feed'
,
'fetch'
,
'save'
,
'load'
,
'recurrent'
,
'go'
,
'feed'
,
'fetch'
,
'save'
,
'load'
,
'recurrent'
,
'go'
,
'rnn_memory_helper_grad'
,
'conditional_block'
,
'while'
,
'send'
,
'recv'
,
'rnn_memory_helper_grad'
,
'conditional_block'
,
'while'
,
'send'
,
'recv'
,
'listen_and_serv'
,
'parallel_do'
,
'save_combine'
,
'load_combine'
,
'listen_and_serv'
,
'parallel_do'
,
'save_combine'
,
'load_combine'
,
'ncclInit'
,
'channel_create'
,
'channel_close'
,
'channel_send'
,
'ncclInit'
,
'select'
,
'checkpoint_notify'
,
'gen_nccl_id'
'channel_recv'
,
'select'
,
'checkpoint_notify'
,
'gen_nccl_id'
}
}
def
__init__
(
self
,
def
__init__
(
self
,
...
...
python/paddle/fluid/tests/no_test_concurrency.py
已删除
100644 → 0
浏览文件 @
23644940
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid
import
framework
,
unique_name
,
layer_helper
from
paddle.fluid.executor
import
Executor
from
paddle.fluid.layers
import
fill_constant
,
assign
,
While
,
elementwise_add
,
Print
class
TestRoutineOp
(
unittest
.
TestCase
):
def
test_simple_routine
(
self
):
ch
=
fluid
.
make_channel
(
dtype
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
)
# Create LOD_TENSOR<INT64> and put it into the scope. This placeholder
# variable will be filled in and returned by fluid.channel_recv
result
=
self
.
_create_tensor
(
'return_value'
,
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
core
.
VarDesc
.
VarType
.
INT64
)
with
fluid
.
Go
():
input_value
=
fill_constant
(
shape
=
[
1
],
dtype
=
core
.
VarDesc
.
VarType
.
FP64
,
value
=
1234
)
fluid
.
channel_send
(
ch
,
input_value
)
result
,
status
=
fluid
.
channel_recv
(
ch
,
result
)
fluid
.
channel_close
(
ch
)
cpu
=
core
.
CPUPlace
()
exe
=
Executor
(
cpu
)
outs
=
exe
.
run
(
fetch_list
=
[
result
])
self
.
assertEqual
(
outs
[
0
],
1234
)
def
test_daisy_chain
(
self
):
'''
Mimics classic Daisy-chain test: https://talks.golang.org/2012/concurrency.slide#39
'''
n
=
100
leftmost
=
fluid
.
make_channel
(
dtype
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
)
left
=
leftmost
# TODO(thuan): Use fluid.While() after scope capture is implemented.
# https://github.com/PaddlePaddle/Paddle/issues/8502
for
i
in
range
(
n
):
right
=
fluid
.
make_channel
(
dtype
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
)
with
fluid
.
Go
():
one_tensor
=
self
.
_create_one_dim_tensor
(
1
)
result
=
self
.
_create_tensor
(
'return_value'
,
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
core
.
VarDesc
.
VarType
.
INT64
)
result
,
status
=
fluid
.
channel_recv
(
right
,
result
)
one_added
=
fluid
.
layers
.
elementwise_add
(
x
=
one_tensor
,
y
=
result
)
fluid
.
channel_send
(
left
,
one_added
)
left
=
right
# Trigger the channel propagation by sending a "1" to rightmost channel
with
fluid
.
Go
():
one_tensor
=
self
.
_create_one_dim_tensor
(
1
)
fluid
.
channel_send
(
right
,
one_tensor
)
leftmost_result
=
self
.
_create_tensor
(
'return_value'
,
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
core
.
VarDesc
.
VarType
.
INT64
)
leftmost_result
,
status
=
fluid
.
channel_recv
(
leftmost
,
leftmost_result
)
cpu
=
core
.
CPUPlace
()
exe
=
Executor
(
cpu
)
leftmost_data
=
exe
.
run
(
fetch_list
=
[
leftmost_result
])
# The leftmost_data should be equal to the number of channels + 1
self
.
assertEqual
(
leftmost_data
[
0
][
0
],
n
+
1
)
def
_create_one_dim_tensor
(
self
,
value
):
one_dim_tensor
=
fill_constant
(
shape
=
[
1
],
dtype
=
'int'
,
value
=
value
)
one_dim_tensor
.
stop_gradient
=
True
return
one_dim_tensor
def
_create_tensor
(
self
,
name
,
type
,
dtype
):
return
framework
.
default_main_program
().
current_block
().
create_var
(
name
=
unique_name
.
generate
(
name
),
type
=
type
,
dtype
=
dtype
)
def
_create_persistable_tensor
(
self
,
name
,
type
,
dtype
):
return
framework
.
default_main_program
().
current_block
().
create_var
(
name
=
unique_name
.
generate
(
name
),
type
=
type
,
dtype
=
dtype
,
persistable
=
True
)
def
test_select
(
self
):
with
framework
.
program_guard
(
framework
.
Program
()):
ch1
=
fluid
.
make_channel
(
dtype
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
capacity
=
1
)
result1
=
self
.
_create_tensor
(
'return_value'
,
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
core
.
VarDesc
.
VarType
.
FP64
)
input_value
=
fill_constant
(
shape
=
[
1
],
dtype
=
core
.
VarDesc
.
VarType
.
FP64
,
value
=
10
)
with
fluid
.
Select
()
as
select
:
with
select
.
case
(
fluid
.
channel_send
,
ch1
,
input_value
):
# Execute something.
pass
with
select
.
default
():
pass
# This should not block because we are using a buffered channel.
result1
,
status
=
fluid
.
channel_recv
(
ch1
,
result1
)
fluid
.
channel_close
(
ch1
)
cpu
=
core
.
CPUPlace
()
exe
=
Executor
(
cpu
)
result
=
exe
.
run
(
fetch_list
=
[
result1
])
self
.
assertEqual
(
result
[
0
][
0
],
10
)
def
test_fibonacci
(
self
):
"""
Mimics Fibonacci Go example: https://tour.golang.org/concurrency/5
"""
with
framework
.
program_guard
(
framework
.
Program
()):
quit_ch_input_var
=
self
.
_create_persistable_tensor
(
'quit_ch_input'
,
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
core
.
VarDesc
.
VarType
.
INT32
)
quit_ch_input
=
fill_constant
(
shape
=
[
1
],
dtype
=
core
.
VarDesc
.
VarType
.
INT32
,
value
=
0
,
out
=
quit_ch_input_var
)
result
=
self
.
_create_persistable_tensor
(
'result'
,
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
core
.
VarDesc
.
VarType
.
INT32
)
fill_constant
(
shape
=
[
1
],
dtype
=
core
.
VarDesc
.
VarType
.
INT32
,
value
=
0
,
out
=
result
)
x
=
fill_constant
(
shape
=
[
1
],
dtype
=
core
.
VarDesc
.
VarType
.
INT32
,
value
=
0
)
y
=
fill_constant
(
shape
=
[
1
],
dtype
=
core
.
VarDesc
.
VarType
.
INT32
,
value
=
1
)
while_cond
=
fill_constant
(
shape
=
[
1
],
dtype
=
core
.
VarDesc
.
VarType
.
BOOL
,
value
=
True
)
while_false
=
fill_constant
(
shape
=
[
1
],
dtype
=
core
.
VarDesc
.
VarType
.
BOOL
,
value
=
False
)
x_tmp
=
fill_constant
(
shape
=
[
1
],
dtype
=
core
.
VarDesc
.
VarType
.
INT32
,
value
=
0
)
def
fibonacci
(
channel
,
quit_channel
):
while_op
=
While
(
cond
=
while_cond
)
with
while_op
.
block
():
result2
=
fill_constant
(
shape
=
[
1
],
dtype
=
core
.
VarDesc
.
VarType
.
INT32
,
value
=
0
)
with
fluid
.
Select
()
as
select
:
with
select
.
case
(
fluid
.
channel_send
,
channel
,
x
,
is_copy
=
True
):
assign
(
input
=
x
,
output
=
x_tmp
)
assign
(
input
=
y
,
output
=
x
)
assign
(
elementwise_add
(
x
=
x_tmp
,
y
=
y
),
output
=
y
)
with
select
.
case
(
fluid
.
channel_recv
,
quit_channel
,
result2
):
# Quit
helper
=
layer_helper
.
LayerHelper
(
'assign'
)
helper
.
append_op
(
type
=
'assign'
,
inputs
=
{
'X'
:
[
while_false
]},
outputs
=
{
'Out'
:
[
while_cond
]})
ch1
=
fluid
.
make_channel
(
dtype
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
)
quit_ch
=
fluid
.
make_channel
(
dtype
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
)
with
fluid
.
Go
():
for
i
in
range
(
10
):
fluid
.
channel_recv
(
ch1
,
result
)
Print
(
result
)
fluid
.
channel_send
(
quit_ch
,
quit_ch_input
)
fibonacci
(
ch1
,
quit_ch
)
fluid
.
channel_close
(
ch1
)
fluid
.
channel_close
(
quit_ch
)
cpu
=
core
.
CPUPlace
()
exe
=
Executor
(
cpu
)
exe_result
=
exe
.
run
(
fetch_list
=
[
result
])
self
.
assertEqual
(
exe_result
[
0
][
0
],
34
)
def
test_ping_pong
(
self
):
"""
Mimics Ping Pong example: https://gobyexample.com/channel-directions
"""
with
framework
.
program_guard
(
framework
.
Program
()):
result
=
self
.
_create_tensor
(
'return_value'
,
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
core
.
VarDesc
.
VarType
.
FP64
)
ping_result
=
self
.
_create_tensor
(
'ping_return_value'
,
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
core
.
VarDesc
.
VarType
.
FP64
)
def
ping
(
ch
,
message
):
fluid
.
channel_send
(
ch
,
message
,
is_copy
=
True
)
def
pong
(
ch1
,
ch2
):
fluid
.
channel_recv
(
ch1
,
ping_result
)
fluid
.
channel_send
(
ch2
,
ping_result
,
is_copy
=
True
)
pings
=
fluid
.
make_channel
(
dtype
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
capacity
=
1
)
pongs
=
fluid
.
make_channel
(
dtype
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
capacity
=
1
)
msg
=
fill_constant
(
shape
=
[
1
],
dtype
=
core
.
VarDesc
.
VarType
.
FP64
,
value
=
9
)
ping
(
pings
,
msg
)
pong
(
pings
,
pongs
)
fluid
.
channel_recv
(
pongs
,
result
)
fluid
.
channel_close
(
pings
)
fluid
.
channel_close
(
pongs
)
cpu
=
core
.
CPUPlace
()
exe
=
Executor
(
cpu
)
exe_result
=
exe
.
run
(
fetch_list
=
[
result
])
self
.
assertEqual
(
exe_result
[
0
][
0
],
9
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/notest_concurrency.py
已删除
100644 → 0
浏览文件 @
23644940
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.executor
import
Executor
class
TestRoutineOp
(
unittest
.
TestCase
):
def
test_simple_routine
(
self
):
ch
=
fluid
.
make_channel
(
dtype
=
core
.
VarDesc
.
VarType
.
BOOL
,
name
=
"CreateChannel"
)
with
fluid
.
Go
():
fluid
.
channel_send
(
ch
,
True
)
result
=
fluid
.
channel_recv
(
ch
)
fluid
.
channel_close
(
ch
)
cpu
=
core
.
CPUPlace
()
exe
=
Executor
(
cpu
)
outs
=
exe
.
run
(
fetch_list
=
[
result
])
self
.
assertEqual
(
outs
[
0
],
True
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录