Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f63ab561
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f63ab561
编写于
9月 29, 2018
作者:
D
Dang Qingqing
浏览文件
操作
浏览文件
下载
差异文件
Fix conflict.
上级
8f5d918a
425a8821
变更
44
显示空白变更内容
内联
并排
Showing
44 changed file
with
1028 addition
and
3722 deletion
+1028
-3722
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+0
-7
paddle/fluid/framework/channel.h
paddle/fluid/framework/channel.h
+0
-291
paddle/fluid/framework/channel_impl.h
paddle/fluid/framework/channel_impl.h
+0
-369
paddle/fluid/framework/channel_test.cc
paddle/fluid/framework/channel_test.cc
+0
-1008
paddle/fluid/framework/concurrency_test.cc
paddle/fluid/framework/concurrency_test.cc
+0
-292
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+1
-4
paddle/fluid/framework/framework.proto
paddle/fluid/framework/framework.proto
+0
-7
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-0
paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
+243
-0
paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h
paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h
+40
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+18
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+17
-0
paddle/fluid/framework/naive_executor.cc
paddle/fluid/framework/naive_executor.cc
+4
-4
paddle/fluid/framework/tuple.h
paddle/fluid/framework/tuple.h
+0
-1
paddle/fluid/framework/var_desc.cc
paddle/fluid/framework/var_desc.cc
+2
-52
paddle/fluid/framework/var_desc.h
paddle/fluid/framework/var_desc.h
+0
-4
paddle/fluid/framework/var_type.h
paddle/fluid/framework/var_type.h
+0
-6
paddle/fluid/inference/analysis/analysis_pass.h
paddle/fluid/inference/analysis/analysis_pass.h
+0
-6
paddle/fluid/inference/analysis/analyzer.h
paddle/fluid/inference/analysis/analyzer.h
+9
-8
paddle/fluid/inference/api/api_impl_tester.cc
paddle/fluid/inference/api/api_impl_tester.cc
+11
-5
paddle/fluid/inference/api/paddle_inference_api.h
paddle/fluid/inference/api/paddle_inference_api.h
+1
-1
paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
...nference/tests/api/analyzer_text_classification_tester.cc
+13
-0
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+0
-5
paddle/fluid/operators/channel_close_op.cc
paddle/fluid/operators/channel_close_op.cc
+0
-70
paddle/fluid/operators/channel_create_op.cc
paddle/fluid/operators/channel_create_op.cc
+0
-113
paddle/fluid/operators/channel_recv_op.cc
paddle/fluid/operators/channel_recv_op.cc
+0
-98
paddle/fluid/operators/channel_send_op.cc
paddle/fluid/operators/channel_send_op.cc
+0
-76
paddle/fluid/operators/concurrency/CMakeLists.txt
paddle/fluid/operators/concurrency/CMakeLists.txt
+0
-1
paddle/fluid/operators/concurrency/channel_util.cc
paddle/fluid/operators/concurrency/channel_util.cc
+0
-111
paddle/fluid/operators/distributed/grpc_client.h
paddle/fluid/operators/distributed/grpc_client.h
+1
-0
paddle/fluid/operators/distributed/request_handler.h
paddle/fluid/operators/distributed/request_handler.h
+1
-0
paddle/fluid/operators/distributed/rpc_server.h
paddle/fluid/operators/distributed/rpc_server.h
+1
-0
paddle/fluid/operators/fused_embedding_fc_lstm_op.cc
paddle/fluid/operators/fused_embedding_fc_lstm_op.cc
+604
-0
paddle/fluid/operators/fused_embedding_fc_lstm_op.h
paddle/fluid/operators/fused_embedding_fc_lstm_op.h
+41
-0
paddle/fluid/operators/select_op.cc
paddle/fluid/operators/select_op.cc
+0
-419
paddle/fluid/pybind/protobuf.cc
paddle/fluid/pybind/protobuf.cc
+0
-2
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+0
-1
paddle/legacy/trainer/tests/CMakeLists.txt
paddle/legacy/trainer/tests/CMakeLists.txt
+5
-1
python/paddle/fluid/concurrency.py
python/paddle/fluid/concurrency.py
+0
-454
python/paddle/fluid/contrib/tests/test_quantize_transpiler.py
...on/paddle/fluid/contrib/tests/test_quantize_transpiler.py
+1
-0
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+1
-2
python/paddle/fluid/tests/book/high-level-api/recognize_digits/CMakeLists.txt
...tests/book/high-level-api/recognize_digits/CMakeLists.txt
+13
-3
python/paddle/fluid/tests/no_test_concurrency.py
python/paddle/fluid/tests/no_test_concurrency.py
+0
-260
python/paddle/fluid/tests/notest_concurrency.py
python/paddle/fluid/tests/notest_concurrency.py
+0
-41
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
f63ab561
...
...
@@ -169,15 +169,8 @@ cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_test
(
op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto
)
cc_test
(
cow_ptr_tests SRCS details/cow_ptr_test.cc
)
# cc_test(channel_test SRCS channel_test.cc)
cc_test
(
tuple_test SRCS tuple_test.cc
)
if
(
NOT WIN32
)
cc_test
(
rw_lock_test SRCS rw_lock_test.cc
)
endif
(
NOT WIN32
)
# disable test temporarily.
# TODO https://github.com/PaddlePaddle/Paddle/issues/11971
# cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op
# channel_send_op channel_recv_op sum_op select_op elementwise_add_op compare_op
# conditional_block_op while_op assign_op print_op executor proto_desc)
paddle/fluid/framework/channel.h
已删除
100644 → 0
浏览文件 @
8f5d918a
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stddef.h> // for size_t
#include <condition_variable> // NOLINT
#include <typeindex>
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
enum
class
ChannelAction
{
SEND
=
0
,
RECEIVE
=
1
,
CLOSE
=
2
,
};
// Channel is the abstract class of buffered and un-buffered channels.
//
// All operations are pure virtual; the concrete behaviour lives in
// ChannelImpl<T>. Data is passed by pointer: Send reads from *T and Receive
// writes into *T.
template <typename T>
class Channel {
 public:
  // Non-blocking probes for whether a Send/Receive could proceed right now.
  virtual bool CanSend() = 0;
  virtual bool CanReceive() = 0;
  // Transfers one element through the channel.
  virtual void Send(T*) = 0;
  // Returns a bool status; ChannelImpl returns false when the channel is
  // closed and nothing is left to read.
  virtual bool Receive(T*) = 0;
  // Buffer capacity of the channel.
  virtual size_t Cap() = 0;
  // Expose the channel's internal lock to callers.
  virtual void Lock() = 0;
  virtual void Unlock() = 0;
  virtual bool IsClosed() = 0;
  virtual void Close() = 0;
  virtual ~Channel() {}
  // Enqueue a pending send/receive on behalf of `referrer`. `cond` is the
  // condition variable notified when the message is completed, and `cb` is
  // consulted with the ChannelAction before the message is used.
  virtual void AddToSendQ(const void* referrer, T* data,
                          std::shared_ptr<std::condition_variable_any> cond,
                          std::function<bool(ChannelAction)> cb) = 0;
  virtual void AddToReceiveQ(const void* referrer, T* data,
                             std::shared_ptr<std::condition_variable_any> cond,
                             std::function<bool(ChannelAction)> cb) = 0;
  // Drop every pending message previously queued by `referrer`.
  virtual void RemoveFromSendQ(const void* referrer) = 0;
  virtual void RemoveFromReceiveQ(const void* referrer) = 0;
};
// Forward declaration of channel implementations.
template <typename T>
class ChannelImpl;

// Heap-allocates a ChannelImpl<T> with the requested buffer size and returns
// it through the abstract Channel<T> interface. The result is a raw pointer
// created with `new`; the caller is responsible for releasing it.
template <typename T>
Channel<T>* MakeChannel(size_t buffer_size) {
  Channel<T>* created = new ChannelImpl<T>(buffer_size);
  return created;
}
// Free-function convenience wrapper: closes `ch` by forwarding to
// Channel<T>::Close(). It does not delete the channel.
template <typename T>
void CloseChannel(Channel<T>* ch) {
  ch->Close();
}
/*
* The ChannelHolder class serves two main purposes:
* 1. It acts as a unified wrapper for the different kinds of
* channels, i.e. Buffered and Unbuffered channels. This is
* similar to the ReaderHolder class.
* 2. It also helps us in TypeHiding. This is similar to the
* PlaceHolder implementations in variable.h and tensor.h.
*/
class
ChannelHolder
{
public:
template
<
typename
T
>
void
Reset
(
size_t
buffer_size
)
{
holder_
.
reset
(
new
PlaceholderImpl
<
T
>
(
buffer_size
));
}
template
<
typename
T
>
void
Send
(
T
*
data
)
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
PADDLE_ENFORCE_EQ
(
holder_
->
Type
(),
std
::
type_index
(
typeid
(
T
)),
"Channel type is not same as the type of the data being sent"
);
// Static cast should be safe because we have ensured that types are same
Channel
<
T
>*
channel
=
static_cast
<
Channel
<
T
>*>
(
holder_
->
Ptr
());
PADDLE_ENFORCE_EQ
(
channel
!=
nullptr
,
true
,
"Channel should not be null."
);
channel
->
Send
(
data
);
}
template
<
typename
T
>
bool
Receive
(
T
*
data
)
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
PADDLE_ENFORCE_EQ
(
holder_
->
Type
(),
std
::
type_index
(
typeid
(
T
)),
"Channel type is not same as the type of the data being sent"
);
Channel
<
T
>*
channel
=
static_cast
<
Channel
<
T
>*>
(
holder_
->
Ptr
());
PADDLE_ENFORCE_EQ
(
channel
!=
nullptr
,
true
,
"Channel should not be null."
);
return
channel
->
Receive
(
data
);
}
bool
IsClosed
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
return
holder_
->
IsClosed
();
}
bool
CanSend
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
return
holder_
->
CanSend
();
}
bool
CanReceive
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
return
holder_
->
CanReceive
();
}
void
close
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
holder_
->
Close
();
}
size_t
Cap
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
return
holder_
->
Cap
();
}
void
Lock
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
holder_
->
Lock
();
}
void
Unlock
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
holder_
->
Unlock
();
}
template
<
typename
T
>
void
AddToSendQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
)
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
Channel
<
T
>*
channel
=
static_cast
<
Channel
<
T
>*>
(
holder_
->
Ptr
());
if
(
channel
!=
nullptr
)
{
channel
->
AddToSendQ
(
referrer
,
data
,
cond
,
cb
);
}
}
template
<
typename
T
>
void
AddToReceiveQ
(
const
void
*
referrer
,
T
*
data
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
ChannelAction
)
>
cb
)
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
Channel
<
T
>*
channel
=
static_cast
<
Channel
<
T
>*>
(
holder_
->
Ptr
());
if
(
channel
!=
nullptr
)
{
channel
->
AddToReceiveQ
(
referrer
,
data
,
cond
,
cb
);
}
}
void
RemoveFromSendQ
(
const
void
*
referrer
)
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
holder_
->
RemoveFromSendQ
(
referrer
);
}
void
RemoveFromReceiveQ
(
const
void
*
referrer
)
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
holder_
->
RemoveFromReceiveQ
(
referrer
);
}
inline
bool
IsInitialized
()
const
{
return
holder_
!=
nullptr
;
}
inline
const
std
::
type_index
Type
()
{
PADDLE_ENFORCE_EQ
(
IsInitialized
(),
true
,
"The Channel hasn't been initialized"
);
return
holder_
->
Type
();
}
private:
/**
* @note Placeholder hides type T, so it doesn't appear as a template
* parameter of ChannelHolder.
*/
struct
Placeholder
{
virtual
~
Placeholder
()
{}
virtual
const
std
::
type_index
Type
()
const
=
0
;
virtual
void
*
Ptr
()
const
=
0
;
virtual
bool
IsClosed
()
=
0
;
virtual
bool
CanSend
()
=
0
;
virtual
bool
CanReceive
()
=
0
;
virtual
void
RemoveFromSendQ
(
const
void
*
referrer
)
=
0
;
virtual
void
RemoveFromReceiveQ
(
const
void
*
referrer
)
=
0
;
virtual
void
Close
()
=
0
;
virtual
void
Lock
()
=
0
;
virtual
void
Unlock
()
=
0
;
virtual
size_t
Cap
()
=
0
;
};
template
<
typename
T
>
struct
PlaceholderImpl
:
public
Placeholder
{
explicit
PlaceholderImpl
(
size_t
buffer_size
)
:
type_
(
std
::
type_index
(
typeid
(
T
)))
{
channel_
.
reset
(
MakeChannel
<
T
>
(
buffer_size
));
}
virtual
const
std
::
type_index
Type
()
const
{
return
type_
;
}
virtual
void
*
Ptr
()
const
{
return
static_cast
<
void
*>
(
channel_
.
get
());
}
virtual
bool
IsClosed
()
{
if
(
channel_
)
{
return
channel_
->
IsClosed
();
}
return
false
;
}
virtual
bool
CanSend
()
{
if
(
channel_
)
{
return
channel_
->
CanSend
();
}
return
false
;
}
virtual
bool
CanReceive
()
{
if
(
channel_
)
{
return
channel_
->
CanReceive
();
}
return
false
;
}
virtual
void
RemoveFromSendQ
(
const
void
*
referrer
)
{
if
(
channel_
)
{
channel_
->
RemoveFromSendQ
(
referrer
);
}
}
virtual
void
RemoveFromReceiveQ
(
const
void
*
referrer
)
{
if
(
channel_
)
{
channel_
->
RemoveFromReceiveQ
(
referrer
);
}
}
virtual
void
Close
()
{
if
(
channel_
)
channel_
->
Close
();
}
virtual
size_t
Cap
()
{
if
(
channel_
)
return
channel_
->
Cap
();
else
return
-
1
;
}
virtual
void
Lock
()
{
if
(
channel_
)
channel_
->
Lock
();
}
virtual
void
Unlock
()
{
if
(
channel_
)
channel_
->
Unlock
();
}
std
::
unique_ptr
<
Channel
<
T
>>
channel_
;
const
std
::
type_index
type_
;
};
// Pointer to a PlaceholderImpl object
std
::
unique_ptr
<
Placeholder
>
holder_
;
};
}
// namespace framework
}
// namespace paddle
#include "paddle/fluid/framework/channel_impl.h"
paddle/fluid/framework/channel_impl.h
已删除
100644 → 0
浏览文件 @
8f5d918a
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stddef.h> // for size_t
#include <atomic>
#include <condition_variable> // NOLINT
#include <deque>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
// Concrete channel used for both buffered (cap_ > 0) and unbuffered
// (cap_ == 0) operation. All state is guarded by the recursive mutex mu_.
// Blocked senders/receivers are parked as QueueMessage entries in
// sendq/recvq and woken through their condition variable.
template <typename T>
class ChannelImpl : public paddle::framework::Channel<T> {
  friend Channel<T> *paddle::framework::MakeChannel<T>(size_t);
  friend void paddle::framework::CloseChannel<T>(Channel<T> *);

 public:
  virtual bool CanSend();
  virtual bool CanReceive();
  virtual void Send(T *);
  virtual bool Receive(T *);
  virtual size_t Cap() { return cap_; }
  virtual void Lock();
  virtual void Unlock();
  virtual bool IsClosed();
  virtual void Close();
  explicit ChannelImpl(size_t);
  virtual ~ChannelImpl();

  virtual void AddToSendQ(const void *referrer, T *data,
                          std::shared_ptr<std::condition_variable_any> cond,
                          std::function<bool(ChannelAction)> cb);
  virtual void AddToReceiveQ(const void *referrer, T *data,
                             std::shared_ptr<std::condition_variable_any> cond,
                             std::function<bool(ChannelAction)> cb);

  virtual void RemoveFromSendQ(const void *referrer);
  virtual void RemoveFromReceiveQ(const void *referrer);

 private:
  // One parked send or receive request.
  struct QueueMessage {
    T *data;  // source (send) or destination (receive) of the transfer
    std::shared_ptr<std::condition_variable_any> cond;
    bool chan_closed = false;  // set by Close() before notifying
    bool completed = false;    // wait predicate, set by Notify()
    // Identifies who queued the message via AddToSendQ/AddToReceiveQ.
    // Initialised to nullptr so that messages created by Send()/Receive()
    // (which never set it) are not read as indeterminate values when
    // RemoveFromSendQ/RemoveFromReceiveQ compare referrers.
    const void *referrer = nullptr;
    // TODO(thuan): figure out better way to do this
    std::function<bool(ChannelAction)> callback;

    explicit QueueMessage(T *item)
        : data(item), cond(std::make_shared<std::condition_variable_any>()) {}

    QueueMessage(T *item, std::shared_ptr<std::condition_variable_any> cond)
        : data(item), cond(cond) {}

    // Blocks on `cond` (releasing `lock` while waiting) until Notify() runs.
    void Wait(std::unique_lock<std::recursive_mutex> &lock) {
      cond->wait(lock, [this]() { return completed; });
    }

    // Marks the message as completed and wakes every waiter on `cond`.
    void Notify() {
      completed = true;
      cond->notify_all();
    }
  };

  // Bookkeeping for the destructor: one in-flight Send() has finished.
  void send_return() {
    send_ctr--;
    destructor_cond_.notify_all();
  }

  // Bookkeeping for the destructor: one in-flight Receive() has finished.
  // Returns `value` unchanged so it can wrap a return expression.
  bool recv_return(bool value) {
    recv_ctr--;
    destructor_cond_.notify_all();
    return value;
  }

  // Pops messages off `queue` until one is usable, and returns it. Returns
  // nullptr if the queue is exhausted; in that case `queue` is left empty.
  std::shared_ptr<QueueMessage> get_first_message(
      std::deque<std::shared_ptr<QueueMessage>> *queue, ChannelAction action) {
    while (!queue->empty()) {
      // Check whether this message was added by Select
      // If this was added by Select then execute the callback
      // to check if you can execute this message. The callback
      // can return false if some other case was executed in Select.
      // In that case just discard this QueueMessage and process next.
      std::shared_ptr<QueueMessage> m = queue->front();
      queue->pop_front();
      if (m->callback == nullptr || m->callback(action)) return m;
    }
    return nullptr;
  }

  size_t cap_;
  std::recursive_mutex mu_;
  bool closed_;
  std::deque<T> buf_;
  std::deque<std::shared_ptr<QueueMessage>> recvq;
  std::deque<std::shared_ptr<QueueMessage>> sendq;
  // Number of Send()/Receive() calls currently in progress; the destructor
  // waits for both to drop to zero.
  std::atomic<unsigned> send_ctr{0};
  std::atomic<unsigned> recv_ctr{0};
  std::condition_variable_any destructor_cond_;
};
// Creates an open channel with room for `capacity` buffered elements
// (0 means unbuffered).
//
// The previous PADDLE_ENFORCE_GE(capacity, 0) was removed: `capacity` is an
// unsigned size_t, so the comparison could never fail and only produced
// "comparison is always true" diagnostics.
template <typename T>
ChannelImpl<T>::ChannelImpl(size_t capacity)
    : cap_(capacity), closed_(false), send_ctr(0), recv_ctr(0) {}
// Reports whether a Send() could complete immediately: the channel must be
// open and either a receiver is already queued or the buffer has free space.
template <typename T>
bool ChannelImpl<T>::CanSend() {
  std::lock_guard<std::recursive_mutex> guard{mu_};
  if (closed_) {
    return false;
  }
  if (!recvq.empty()) {
    return true;
  }
  return buf_.size() < cap_;
}
// Reports whether a Receive() could complete immediately: false for a closed
// channel with an empty buffer, otherwise true when a sender is queued or the
// buffer holds at least one element.
template <typename T>
bool ChannelImpl<T>::CanReceive() {
  std::lock_guard<std::recursive_mutex> guard{mu_};
  const bool drained_and_closed = closed_ && buf_.empty();
  if (drained_and_closed) {
    return false;
  }
  if (!sendq.empty()) {
    return true;
  }
  return buf_.size() > 0;
}
// Sends *item through the channel.
//
// Order of attempts, all under mu_:
//   1. throw if the channel is closed;
//   2. hand the value directly to a queued receiver, if a usable one exists;
//   3. otherwise store it in the buffer if there is room;
//   4. otherwise park in sendq until a receiver (or Close) wakes us.
// Throws via PADDLE_THROW when the channel is, or becomes, closed.
template <typename T>
void ChannelImpl<T>::Send(T *item) {
  send_ctr++;
  std::unique_lock<std::recursive_mutex> lock{mu_};

  // If channel is closed, throw exception
  if (closed_) {
    send_return();
    lock.unlock();
    PADDLE_THROW("Cannot send on closed channel");
  }

  // If there is a receiver, directly pass the value we want
  // to send to the receiver, bypassing the channel buffer if any
  if (!recvq.empty()) {
    std::shared_ptr<QueueMessage> receiver =
        get_first_message(&recvq, ChannelAction::SEND);
    if (receiver != nullptr) {
      *(receiver->data) = std::move(*item);
      receiver->Notify();
      send_return();
      return;
    }
    // Every queued receiver was discarded by its callback, and
    // get_first_message has left recvq empty, so fall through to the
    // buffer / blocking paths below. The previous code re-entered Send()
    // here, which locked the recursive mutex a second time; if that inner
    // call then had to block, the condition-variable wait released only one
    // level of ownership and no other thread could ever acquire mu_ to wake
    // it up.
  }

  // Unbuffered channel will always bypass this
  // If buffered channel has space in buffer,
  // write the element to the buffer.
  if (buf_.size() < cap_) {
    // Copy to buffer
    buf_.push_back(std::move(*item));
    send_return();
    return;
  }

  // Block on channel, because some receiver will complete
  // the operation for us
  auto pending = std::make_shared<QueueMessage>(item);
  sendq.push_back(pending);
  pending->Wait(lock);
  if (pending->chan_closed) {
    send_return();
    lock.unlock();
    PADDLE_THROW("Cannot send on closed channel");
  }
  send_return();
}
// Receives one element into *item.
//
// Returns true when a value was written to *item, and false when the channel
// is closed with nothing left to read (either on entry, or because Close()
// woke this receiver while it was parked).
template <typename T>
bool ChannelImpl<T>::Receive(T *item) {
  recv_ctr++;
  std::unique_lock<std::recursive_mutex> lock{mu_};

  // If channel is closed and buffer is empty or
  // channel is unbuffered
  if (closed_ && buf_.empty()) return recv_return(false);

  // If there is a sender, directly receive the value we want
  // from the sender. In case of a buffered channel, read from
  // buffer and move front of send queue to the buffer
  if (!sendq.empty()) {
    std::shared_ptr<QueueMessage> sender =
        get_first_message(&sendq, ChannelAction::RECEIVE);
    if (buf_.size() > 0) {
      // Case 1 : Channel is Buffered
      // Do Data transfer from front of buffer
      // and add a QueueMessage to the buffer
      *item = std::move(buf_.front());
      buf_.pop_front();
      // If first message from sendq is not null
      // add it to the buffer and notify it
      if (sender != nullptr) {
        // Copy to buffer
        buf_.push_back(std::move(*(sender->data)));
        sender->Notify();
      }
      // Ignore if there is no first message
      return recv_return(true);
    }
    // Case 2: Channel is Unbuffered (or the buffer is currently empty)
    // Do data transfer from front of SendQ
    if (sender != nullptr) {
      *item = std::move(*(sender->data));
      sender->Notify();
      return recv_return(true);
    }
    // Every queued sender was discarded by its callback, leaving sendq
    // empty, and the buffer is empty too: fall through and block below.
    // The previous code re-entered Receive() here, locking the recursive
    // mutex twice; the nested wait then released only one level of
    // ownership, so no sender could ever acquire mu_ to complete the
    // transfer.
  }

  // If this is a buffered channel and there are items in buffer
  if (buf_.size() > 0) {
    // Directly read from buffer
    *item = std::move(buf_.front());
    buf_.pop_front();
    return recv_return(true);
  }

  // No sender available, block on this channel
  // Some sender (or Close) will complete the operation for us
  auto pending = std::make_shared<QueueMessage>(item);
  recvq.push_back(pending);
  pending->Wait(lock);

  return recv_return(!pending->chan_closed);
}
// Acquires the channel's recursive mutex on behalf of the caller. Must be
// balanced by a matching Unlock().
template <typename T>
void ChannelImpl<T>::Lock() {
  mu_.lock();
}
// Releases one level of ownership of the channel's recursive mutex,
// previously taken with Lock().
template <typename T>
void ChannelImpl<T>::Unlock() {
  mu_.unlock();
}
// Returns the closed flag, read while holding mu_.
template <typename T>
bool ChannelImpl<T>::IsClosed() {
  std::lock_guard<std::recursive_mutex> guard{mu_};
  const bool is_closed = closed_;
  return is_closed;
}
// Closes the channel. Closing an already-closed channel is a no-op.
// Otherwise every parked receiver, then every parked sender, is removed from
// its queue, flagged as chan_closed, given a ChannelAction::CLOSE callback
// (if it has one) and woken up.
template <typename T>
void ChannelImpl<T>::Close() {
  std::unique_lock<std::recursive_mutex> guard{mu_};

  if (closed_) {
    // TODO(abhinavarora): closing an already closed channel should panic
    return;  // `guard` releases mu_ on scope exit
  }
  closed_ = true;

  // Shared drain routine for both wait queues.
  auto wake_with_close = [](std::deque<std::shared_ptr<QueueMessage>> *waitq) {
    while (!waitq->empty()) {
      std::shared_ptr<QueueMessage> waiter = waitq->front();
      waitq->pop_front();
      waiter->chan_closed = true;

      // Execute callback function (if any)
      if (waiter->callback != nullptr) {
        waiter->callback(ChannelAction::CLOSE);
      }

      waiter->Notify();
    }
  };

  // Empty the readers
  wake_with_close(&recvq);
  // Empty the senders
  wake_with_close(&sendq);
}
// Appends a pending send to sendq on behalf of `referrer`. The message
// reuses the caller-supplied condition variable and carries `cb`, which
// get_first_message()/Close() consult before using the message.
template <typename T>
void ChannelImpl<T>::AddToSendQ(
    const void *referrer, T *data,
    std::shared_ptr<std::condition_variable_any> cond,
    std::function<bool(ChannelAction)> cb) {
  std::lock_guard<std::recursive_mutex> guard{mu_};
  std::shared_ptr<QueueMessage> queued =
      std::make_shared<QueueMessage>(data, cond);
  queued->callback = cb;
  queued->referrer = referrer;
  sendq.push_back(queued);
}
// Appends a pending receive to recvq on behalf of `referrer`. The message
// reuses the caller-supplied condition variable and carries `cb`, which
// get_first_message()/Close() consult before using the message.
template <typename T>
void ChannelImpl<T>::AddToReceiveQ(
    const void *referrer, T *data,
    std::shared_ptr<std::condition_variable_any> cond,
    std::function<bool(ChannelAction)> cb) {
  std::lock_guard<std::recursive_mutex> guard{mu_};
  std::shared_ptr<QueueMessage> queued =
      std::make_shared<QueueMessage>(data, cond);
  queued->callback = cb;
  queued->referrer = referrer;
  recvq.push_back(queued);
}
// Erases every message in sendq whose referrer equals `referrer`.
template <typename T>
void ChannelImpl<T>::RemoveFromSendQ(const void *referrer) {
  std::lock_guard<std::recursive_mutex> guard{mu_};

  auto pos = sendq.begin();
  while (pos != sendq.end()) {
    const bool queued_by_referrer = ((*pos)->referrer == referrer);
    if (queued_by_referrer) {
      // erase() returns the iterator following the removed element
      pos = sendq.erase(pos);
    } else {
      ++pos;
    }
  }
}
// Erases every message in recvq whose referrer equals `referrer`.
template <typename T>
void ChannelImpl<T>::RemoveFromReceiveQ(const void *referrer) {
  std::lock_guard<std::recursive_mutex> guard{mu_};

  auto pos = recvq.begin();
  while (pos != recvq.end()) {
    const bool queued_by_referrer = ((*pos)->referrer == referrer);
    if (queued_by_referrer) {
      // erase() returns the iterator following the removed element
      pos = recvq.erase(pos);
    } else {
      ++pos;
    }
  }
}
// Closes the channel, then blocks until no Send()/Receive() call is still
// in flight (both counters are zero).
template <typename T>
ChannelImpl<T>::~ChannelImpl() {
  Close();
  // The destructor must wait for all readers and writers to complete their task
  // The channel has been closed, so we will not accept new readers and writers
  std::unique_lock<std::recursive_mutex> guard{mu_};
  while (!(send_ctr == 0 && recv_ctr == 0)) {
    // send_return()/recv_return() notify this condition variable
    destructor_cond_.wait(guard);
  }
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/channel_test.cc
已删除
100644 → 0
浏览文件 @
8f5d918a
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include <chrono> // NOLINT
#include <thread> // NOLINT
#include "gtest/gtest.h"
using
paddle
::
framework
::
Channel
;
using
paddle
::
framework
::
ChannelHolder
;
using
paddle
::
framework
::
MakeChannel
;
using
paddle
::
framework
::
CloseChannel
;
// Cap() must report the buffer size the channel was created with, for both a
// buffered (10) and an unbuffered (0) channel.
TEST(Channel, ChannelCapacityTest) {
  const size_t buffer_size = 10;

  auto buffered = MakeChannel<size_t>(buffer_size);
  EXPECT_EQ(buffered->Cap(), buffer_size);
  CloseChannel(buffered);
  delete buffered;

  auto unbuffered = MakeChannel<size_t>(0);
  EXPECT_EQ(unbuffered->Cap(), 0U);
  CloseChannel(unbuffered);
  delete unbuffered;
}
// Helper: a background thread sends 0..num_items-1 while the calling thread
// receives num_items values and checks they arrive in the same order.
// Afterwards the channel is closed, the sender joined, the sum of sent values
// verified, and `ch` deleted (this helper takes ownership of it).
void RecevingOrderEqualToSendingOrder(Channel<int> *ch, int num_items) {
  unsigned sum_send = 0;

  std::thread sender([ch, num_items, &sum_send]() {
    for (int value = 0; value < num_items; value++) {
      ch->Send(&value);
      sum_send += value;
    }
  });

  std::this_thread::sleep_for(std::chrono::milliseconds(200));

  for (int expected = 0; expected < num_items; expected++) {
    int received = -1;
    EXPECT_EQ(ch->Receive(&received), true);
    EXPECT_EQ(received, expected);
  }

  std::this_thread::sleep_for(std::chrono::milliseconds(200));
  CloseChannel(ch);
  sender.join();

  // 0 + 1 + ... + (num_items - 1)
  unsigned expected_sum = (num_items * (num_items - 1)) / 2;
  EXPECT_EQ(sum_send, expected_sum);
  delete ch;
}
// Filling a buffered channel up to its capacity and then draining it must
// work from a single thread, i.e. none of the calls may block.
TEST(Channel, SufficientBufferSizeDoesntBlock) {
  const size_t buffer_size = 10;
  auto ch = MakeChannel<size_t>(buffer_size);

  for (size_t value = 0; value < buffer_size; ++value) {
    ch->Send(&value);
  }

  size_t out;
  for (size_t expected = 0; expected < buffer_size; ++expected) {
    EXPECT_EQ(ch->Receive(&out), true);  // should not block
    EXPECT_EQ(out, expected);
  }

  CloseChannel(ch);
  delete ch;
}
// This tests the behaviour of a channel after it has been closed:
// Send must throw paddle::platform::EnforceNotMet, and Receive must return
// false once the queue is empty.
// By creating separate threads for sending and receiving, we make this
// function able to test both buffered and unbuffered channels.
void SendReceiveWithACloseChannelShouldPanic(Channel<size_t> *ch) {
  const size_t data = 5;

  // Phase 1: one send and one receive on the open channel must pair up.
  std::thread send_thread{[&]() {
    size_t i = data;
    ch->Send(&i);  // should not block
  }};

  std::thread recv_thread{[&]() {
    size_t i;
    EXPECT_EQ(ch->Receive(&i), true);  // should not block
    EXPECT_EQ(i, data);
  }};

  send_thread.join();
  recv_thread.join();

  // Phase 2: after closing, send should panic. Receive should
  // also return false as there is no data in queue.
  CloseChannel(ch);

  send_thread = std::thread{[&]() {
    size_t i = data;
    bool is_exception = false;
    try {
      ch->Send(&i);
    } catch (const paddle::platform::EnforceNotMet &) {
      // Caught by const reference: avoids copying the exception object.
      is_exception = true;
    }
    EXPECT_EQ(is_exception, true);
  }};

  recv_thread = std::thread{[&]() {
    size_t i;
    // should return false because channel is closed and queue is empty
    EXPECT_EQ(ch->Receive(&i), false);
  }};

  send_thread.join();
  recv_thread.join();
}
// Runs the closed-channel send/receive scenario on a buffered channel.
TEST(Channel, SendReceiveClosedBufferedChannelPanics) {
  size_t buffer_size = 10;
  Channel<size_t> *buffered = MakeChannel<size_t>(buffer_size);
  SendReceiveWithACloseChannelShouldPanic(buffered);
  delete buffered;
}
// Runs the closed-channel send/receive scenario on an unbuffered channel.
TEST(Channel, SendReceiveClosedUnBufferedChannelPanics) {
  Channel<size_t> *unbuffered = MakeChannel<size_t>(0);
  SendReceiveWithACloseChannelShouldPanic(unbuffered);
  delete unbuffered;
}
// Values still sitting in the buffer when a channel is closed must remain
// receivable; only once the buffer is drained does Receive return false.
TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) {
  const size_t buffer_size = 10;
  const size_t half = buffer_size / 2;
  auto ch = MakeChannel<size_t>(buffer_size);

  for (size_t value = 0; value < buffer_size; ++value) {
    ch->Send(&value);  // sending should not block
  }

  size_t out;
  // Read the first half while the channel is still open.
  for (size_t expected = 0; expected < half; ++expected) {
    EXPECT_EQ(ch->Receive(&out), true);  // receiving should not block
    EXPECT_EQ(out, expected);
  }

  CloseChannel(ch);

  // The second half is still buffered and must be returned after the close.
  for (size_t expected = half; expected < buffer_size; ++expected) {
    EXPECT_EQ(ch->Receive(&out), true);  // receiving should return residual values.
    EXPECT_EQ(out, expected);
  }

  // The buffer is now empty.
  for (size_t attempt = 0; attempt < buffer_size; ++attempt) {
    EXPECT_EQ(ch->Receive(&out), false);  // receiving on closed channel should return false
  }

  delete ch;
}
// A sender thread writes 2 * buffer_size values with nobody receiving. The
// first buffer_size sends fit in the buffer; every later send is expected to
// end with paddle::platform::EnforceNotMet once the main thread closes the
// channel.
TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
  const size_t buffer_size = 10;
  auto ch = MakeChannel<size_t>(buffer_size);
  std::thread t([&]() {
    // Try to write more than buffer size.
    for (size_t i = 0; i < 2 * buffer_size; ++i) {
      if (i < buffer_size) {
        ch->Send(&i);  // fits in the buffer, should not block
      } else {
        // The buffer is full: this send blocks until the channel is closed
        // (or finds it already closed) and must then throw.
        bool is_exception = false;
        try {
          ch->Send(&i);
        } catch (const paddle::platform::EnforceNotMet &) {
          // Caught by const reference: avoids copying the exception object.
          is_exception = true;
        }
        EXPECT_EQ(is_exception, true);
      }
    }
  });
  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec
  CloseChannel(ch);
  t.join();
  delete ch;
}
// Ordering check on an unbuffered channel with 20 items.
// The helper closes and deletes the channel it is given.
TEST(Channel, RecevingOrderEqualToSendingOrderWithUnBufferedChannel) {
  RecevingOrderEqualToSendingOrder(MakeChannel<int>(0), 20);
}
// Test that Receive Order is same as Send Order when number of items
// sent (5) is less than size of buffer (10).
// The helper closes and deletes the channel it is given.
TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel1) {
  RecevingOrderEqualToSendingOrder(MakeChannel<int>(10), 5);
}
// Test that Receive Order is same as Send Order when number of items
// sent (10) is equal to size of buffer (10).
// The helper closes and deletes the channel it is given.
TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel2) {
  RecevingOrderEqualToSendingOrder(MakeChannel<int>(10), 10);
}
// Test that Receive Order is same as Send Order when number of items
// sent (20) is greater than the size of buffer (10).
// The helper closes and deletes the channel it is given.
TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel3) {
  RecevingOrderEqualToSendingOrder(MakeChannel<int>(10), 20);
}
// Helper: starts five receivers on an empty channel, checks (after a 0.2 s
// pause) that none has finished, closes the channel, and checks (after
// another pause) that every receiver returned false and finished.
// NOTE(review): the thread_ended flags are plain bools shared between threads
// without synchronisation; the test relies on the sleeps for visibility.
void ChannelCloseUnblocksReceiversTest(Channel<int> *ch) {
  const size_t kNumThreads = 5;
  std::thread t[kNumThreads];
  bool thread_ended[kNumThreads];

  // Body of each receiver thread: Receive must come back with `false`
  // because the channel is closed while the thread is waiting.
  auto blocked_receiver = [ch](bool *ended) {
    int data;
    EXPECT_EQ(ch->Receive(&data), false);
    *ended = true;
  };

  // Launches threads that try to read and are blocked because of no writers
  for (size_t idx = 0; idx < kNumThreads; idx++) {
    thread_ended[idx] = false;
    t[idx] = std::thread(blocked_receiver, &thread_ended[idx]);
  }

  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec

  // Verify that all the threads are blocked
  for (size_t idx = 0; idx < kNumThreads; idx++) {
    EXPECT_EQ(thread_ended[idx], false);
  }

  // Explicitly close the channel
  // This should unblock all receivers
  CloseChannel(ch);

  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec

  // Verify that all threads got unblocked
  for (size_t idx = 0; idx < kNumThreads; idx++) {
    EXPECT_EQ(thread_ended[idx], true);
  }

  for (size_t idx = 0; idx < kNumThreads; idx++) {
    t[idx].join();
  }
}
// Helper: starts five senders with no receiver, checks that they block,
// closes the channel and checks that all of them finish.
//
// `isBuffered` selects the expectations: for a buffered channel at least
// four senders must be blocked before the close and exactly one send must
// have succeeded overall; for an unbuffered channel all senders must be
// blocked before the close.
// NOTE(review): thread_ended/send_success are plain bools shared between
// threads without synchronisation; the test relies on the sleeps.
void ChannelCloseUnblocksSendersTest(Channel<int> *ch, bool isBuffered) {
  const size_t kNumThreads = 5;
  std::thread t[kNumThreads];
  bool thread_ended[kNumThreads];
  bool send_success[kNumThreads];

  // Launches threads that try to write and are blocked because of no readers
  for (size_t i = 0; i < kNumThreads; i++) {
    thread_ended[i] = false;
    send_success[i] = false;
    t[i] = std::thread(
        [&](bool *ended, bool *success) {
          int data = 10;
          bool is_exception = false;
          try {
            ch->Send(&data);
          } catch (const paddle::platform::EnforceNotMet &) {
            // Caught by const reference: avoids copying the exception object.
            is_exception = true;
          }
          *success = !is_exception;
          *ended = true;
        },
        &thread_ended[i], &send_success[i]);
  }

  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait

  if (isBuffered) {
    // If ch is Buffered, at least 4 threads must be blocked.
    int ct = 0;
    for (size_t i = 0; i < kNumThreads; i++) {
      if (!thread_ended[i]) ct++;
    }
    EXPECT_GE(ct, 4);
  } else {
    // If ch is UnBuffered, all the threads should be blocked.
    for (size_t i = 0; i < kNumThreads; i++) {
      EXPECT_EQ(thread_ended[i], false);
    }
  }

  // Explicitly close the channel
  // This should unblock all senders
  CloseChannel(ch);

  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait

  // Verify that all threads got unblocked
  for (size_t i = 0; i < kNumThreads; i++) {
    EXPECT_EQ(thread_ended[i], true);
  }

  if (isBuffered) {
    // Verify that only 1 send was successful
    int ct = 0;
    for (size_t i = 0; i < kNumThreads; i++) {
      if (send_success[i]) ct++;
    }
    // Only 1 send must be successful
    EXPECT_EQ(ct, 1);
  }

  for (size_t i = 0; i < kNumThreads; i++) t[i].join();
}
// Closing a buffered channel must unblock any receivers waiting on it.
TEST(Channel, BufferedChannelCloseUnblocksReceiversTest) {
  const size_t kCapacity = 1;
  auto channel = MakeChannel<int>(kCapacity);
  ChannelCloseUnblocksReceiversTest(channel);
  delete channel;
}
// Closing a buffered channel must unblock any senders that are waiting for
// the channel to have write space.
TEST(Channel, BufferedChannelCloseUnblocksSendersTest) {
  const size_t kCapacity = 1;
  auto channel = MakeChannel<int>(kCapacity);
  ChannelCloseUnblocksSendersTest(channel, true);
  delete channel;
}
// Closing an unbuffered channel must unblock any receivers that are waiting
// for senders.
TEST(Channel, UnbufferedChannelCloseUnblocksReceiversTest) {
  const size_t kCapacity = 0;  // zero capacity == unbuffered
  auto channel = MakeChannel<int>(kCapacity);
  ChannelCloseUnblocksReceiversTest(channel);
  delete channel;
}
// Closing an unbuffered channel must unblock any senders that are waiting
// for receivers.
TEST(Channel, UnbufferedChannelCloseUnblocksSendersTest) {
  const size_t kCapacity = 0;  // zero capacity == unbuffered
  auto channel = MakeChannel<int>(kCapacity);
  ChannelCloseUnblocksSendersTest(channel, false);
  delete channel;
}
// One sender attempts four sends on an unbuffered channel while this thread
// performs only three receives. The fourth send is expected to stay blocked
// until the channel is closed, so only 0 + 1 + 2 is added to `sum_send`.
TEST(Channel, UnbufferedLessReceiveMoreSendTest) {
  auto ch = MakeChannel<int>(0);
  unsigned sum_send = 0;
  // Send should block after three iterations
  // since we only have three receivers.
  std::thread t([&]() {
    // Try to send more number of times
    // than receivers
    for (int i = 0; i < 4; i++) {
      try {
        ch->Send(&i);
        sum_send += i;
      } catch (const paddle::platform::EnforceNotMet &) {
        // Deliberately ignored: the surplus send fails once the channel is
        // closed below. Caught by const reference to avoid copying the
        // exception object.
      }
    }
  });
  for (int i = 0; i < 3; i++) {
    int recv;
    ch->Receive(&recv);
    EXPECT_EQ(recv, i);
  }
  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec
  EXPECT_EQ(sum_send, 3U);

  CloseChannel(ch);
  t.join();
  delete ch;
}
// A receiver thread asks for eight values from an unbuffered channel while
// this thread first sends only five. After checking the intermediate sums
// (0+1+2+3+4 == 10) the remaining three values are sent and the final sums
// (0+...+7 == 28) are checked.
TEST(Channel, UnbufferedMoreReceiveLessSendTest) {
  auto channel = MakeChannel<int>(0);
  unsigned sent_total = 0;
  unsigned received_total = 0;

  // The receiver should block after five iterations, because only five
  // values have been sent at that point.
  std::thread receiver([&]() {
    for (int expected = 0; expected < 8; ++expected) {
      int value;
      channel->Receive(&value);  // should block after the fifth iteration.
      EXPECT_EQ(value, expected);
      received_total += expected;
    }
  });

  // Sends the values [first, last) and accumulates them into sent_total.
  auto send_range = [&](int first, int last) {
    for (int value = first; value < last; ++value) {
      channel->Send(&value);
      sent_total += value;
    }
  };

  send_range(0, 5);
  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec
  EXPECT_EQ(sent_total, 10U);
  EXPECT_EQ(received_total, 10U);

  // Send three more elements so the receiver can finish.
  send_range(5, 8);

  CloseChannel(channel);
  receiver.join();
  EXPECT_EQ(sent_total, 28U);
  EXPECT_EQ(received_total, 28U);
  delete channel;
}
// This tests that destroying a channel unblocks any senders waiting for the
// channel to have write space.
//
// ch          channel under test; it is deleted inside this function, so the
//             caller must not use or delete it afterwards.
// isBuffered  true when `ch` has a buffer; exactly one send is then expected
//             to succeed, otherwise none.
//
// NOTE(review): the per-thread flags are plain bools written by the workers
// and read by this thread; the 200 ms sleeps are the only thing ordering
// those accesses.
void ChannelDestroyUnblockSenders(Channel<int> *ch, bool isBuffered) {
  const size_t kNumThreads = 5;
  std::thread t[kNumThreads];
  bool thread_ended[kNumThreads];
  bool send_success[kNumThreads];

  // Launches threads that try to write and are blocked because of no readers
  for (size_t i = 0; i < kNumThreads; i++) {
    thread_ended[i] = false;
    send_success[i] = false;
    t[i] = std::thread(
        [&](bool *ended, bool *success) {
          int data = 10;
          bool is_exception = false;
          try {
            ch->Send(&data);
          } catch (const paddle::platform::EnforceNotMet &) {
            // Caught by const reference: no copy of the exception object and
            // no unused variable.
            is_exception = true;
          }
          *success = !is_exception;
          *ended = true;
        },
        &thread_ended[i], &send_success[i]);
  }

  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec

  if (isBuffered) {
    // If channel is buffered, verify that at least 4 threads are blocked
    int ct = 0;
    for (size_t i = 0; i < kNumThreads; i++) {
      if (thread_ended[i] == false) ct++;
    }
    // At least 4 threads must be blocked
    EXPECT_GE(ct, 4);
  } else {
    // Verify that all the threads are blocked
    for (size_t i = 0; i < kNumThreads; i++) {
      EXPECT_EQ(thread_ended[i], false);
    }
  }

  // Explicitly destroy the channel while the senders are still waiting.
  delete ch;
  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait

  // Verify that all threads got unblocked
  for (size_t i = 0; i < kNumThreads; i++) {
    EXPECT_EQ(thread_ended[i], true);
  }

  // Count number of successful sends
  int ct = 0;
  for (size_t i = 0; i < kNumThreads; i++) {
    if (send_success[i]) ct++;
  }

  if (isBuffered) {
    // Only 1 send must be successful
    EXPECT_EQ(ct, 1);
  } else {
    // In unbuffered channel, no send should be successful
    EXPECT_EQ(ct, 0);
  }

  // Join all threads
  for (size_t i = 0; i < kNumThreads; i++) t[i].join();
}
// Destroying a channel must also unblock any receivers waiting on it.
// Starts five receivers, checks they are still waiting, deletes `ch`, and
// checks that every receiver returned. The channel is deleted here, so the
// caller must not delete it again.
void ChannelDestroyUnblockReceivers(Channel<int> *ch) {
  const size_t kNumThreads = 5;
  std::thread workers[kNumThreads];
  bool finished[kNumThreads];

  // Launch receivers that block because there are no writers.
  for (size_t idx = 0; idx != kNumThreads; ++idx) {
    bool *flag = &finished[idx];
    *flag = false;
    workers[idx] = std::thread([ch, flag]() {
      int received;
      // Every read is expected to report failure.
      EXPECT_EQ(ch->Receive(&received), false);
      *flag = true;
    });
  }

  std::this_thread::sleep_for(std::chrono::milliseconds(100));  // wait

  // All receivers must still be blocked.
  for (bool done : finished) {
    EXPECT_EQ(done, false);
  }

  // Delete the channel while the receivers are waiting.
  delete ch;
  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait

  // All receivers must have returned.
  for (bool done : finished) {
    EXPECT_EQ(done, true);
  }

  for (std::thread &worker : workers) {
    worker.join();
  }
}
// Destroying a buffered channel must unblock waiting receivers.
TEST(Channel, BufferedChannelDestroyUnblocksReceiversTest) {
  // The helper deletes the channel, so it is not deleted here.
  ChannelDestroyUnblockReceivers(MakeChannel<int>(1));
}
// Destroying a buffered channel must unblock waiting senders.
TEST(Channel, BufferedChannelDestroyUnblocksSendersTest) {
  // The helper deletes the channel, so it is not deleted here.
  ChannelDestroyUnblockSenders(MakeChannel<int>(1), true);
}
// Destroying an unbuffered channel must unblock any receivers that are
// waiting for senders.
TEST(Channel, UnbufferedChannelDestroyUnblocksReceiversTest) {
  // The helper deletes the channel, so it is not deleted here.
  ChannelDestroyUnblockReceivers(MakeChannel<int>(0));
}
// Destroying an unbuffered channel must unblock waiting senders.
TEST(Channel, UnbufferedChannelDestroyUnblocksSendersTest) {
  // The helper deletes the channel, so it is not deleted here.
  ChannelDestroyUnblockSenders(MakeChannel<int>(0), false);
}
// Cap() must report the capacity the holder was Reset() with, for both a
// buffered and an unbuffered channel.
TEST(ChannelHolder, ChannelHolderCapacityTest) {
  {
    const size_t kCapacity = 10;
    ChannelHolder *buffered = new ChannelHolder();
    buffered->Reset<int>(kCapacity);
    EXPECT_EQ(buffered->Cap(), kCapacity);
    delete buffered;
  }
  {
    ChannelHolder *unbuffered = new ChannelHolder();
    unbuffered->Reset<int>(0);
    EXPECT_EQ(unbuffered->Cap(), 0U);
    delete unbuffered;
  }
}
// Sends the values 0..4 through `ch` from a helper thread while this thread
// receives them, checking that each receive succeeds and arrives in order.
// The holder is closed before the sender is joined; it is not deleted here.
void ChannelHolderSendReceive(ChannelHolder *ch) {
  const int kNumItems = 5;
  unsigned sent_total = 0;

  std::thread sender([&]() {
    for (int value = 0; value < kNumItems; ++value) {
      ch->Send(&value);
      sent_total += value;
    }
  });

  for (int expected = 0; expected < kNumItems; ++expected) {
    int received;
    EXPECT_EQ(ch->Receive(&received), true);
    EXPECT_EQ(received, expected);
  }

  ch->close();
  sender.join();
  EXPECT_EQ(sent_total, 10U);  // 0 + 1 + 2 + 3 + 4
}
// Send/receive round trip through a holder with a buffer of 10.
TEST(ChannelHolder, ChannelHolderBufferedSendReceiveTest) {
  ChannelHolder *holder = new ChannelHolder();
  holder->Reset<int>(10);
  ChannelHolderSendReceive(holder);
  delete holder;
}
// Send/receive round trip through a holder with no buffer.
TEST(ChannelHolder, ChannelHolderUnBufferedSendReceiveTest) {
  ChannelHolder *holder = new ChannelHolder();
  holder->Reset<int>(0);
  ChannelHolderSendReceive(holder);
  delete holder;
}
// A holder that was never Reset() must report itself as uninitialized, and
// Send(), Receive() and Type() are each expected to throw EnforceNotMet.
TEST(ChannelHolder, ChannelUninitializedTest) {
  ChannelHolder *ch = new ChannelHolder();
  EXPECT_EQ(ch->IsInitialized(), false);

  int i = 10;
  bool send_exception = false;
  try {
    ch->Send(&i);
  } catch (const paddle::platform::EnforceNotMet &) {
    // Caught by const reference here and below: avoids copying the
    // exception object and the unused-variable name.
    send_exception = true;
  }
  EXPECT_EQ(send_exception, true);

  bool recv_exception = false;
  try {
    ch->Receive(&i);
  } catch (const paddle::platform::EnforceNotMet &) {
    recv_exception = true;
  }
  EXPECT_EQ(recv_exception, true);

  bool is_exception = false;
  try {
    ch->Type();
  } catch (const paddle::platform::EnforceNotMet &) {
    is_exception = true;
  }
  EXPECT_EQ(is_exception, true);

  delete ch;
}
// A holder is initialized after Reset() and stays initialized after close().
TEST(ChannelHolder, ChannelInitializedTest) {
  ChannelHolder *holder = new ChannelHolder();
  holder->Reset<int>(2);
  EXPECT_EQ(holder->IsInitialized(), true);

  // The channel should remain initialized even after it is closed.
  holder->close();
  EXPECT_EQ(holder->IsInitialized(), true);

  delete holder;
}
// Sending a value whose type differs from the holder's element type is
// expected to throw EnforceNotMet, for both an unbuffered and a buffered
// channel.
TEST(ChannelHolder, TypeMismatchSendTest) {
  // Test with unbuffered channel: int channel, bool payload.
  ChannelHolder *ch = new ChannelHolder();
  ch->Reset<int>(0);
  bool is_exception = false;
  bool boolean_data = true;
  try {
    ch->Send(&boolean_data);
  } catch (const paddle::platform::EnforceNotMet &) {
    // Caught by const reference: no copy, no unused variable.
    is_exception = true;
  }
  EXPECT_EQ(is_exception, true);
  delete ch;

  // Test with buffered channel: float channel, int payload.
  ch = new ChannelHolder();
  ch->Reset<float>(10);
  is_exception = false;
  int int_data = 23;
  try {
    ch->Send(&int_data);
  } catch (const paddle::platform::EnforceNotMet &) {
    is_exception = true;
  }
  EXPECT_EQ(is_exception, true);
  delete ch;
}
// Receiving into a variable whose type differs from the holder's element
// type is expected to throw EnforceNotMet, for both an unbuffered and a
// buffered channel.
TEST(ChannelHolder, TypeMismatchReceiveTest) {
  // Test with unbuffered channel: int channel, bool destination.
  ChannelHolder *ch = new ChannelHolder();
  ch->Reset<int>(0);
  bool is_exception = false;
  // Renamed from `float_data`: the variable is (and always was) a bool.
  bool boolean_data;
  try {
    ch->Receive(&boolean_data);
  } catch (const paddle::platform::EnforceNotMet &) {
    // Caught by const reference: no copy, no unused variable.
    is_exception = true;
  }
  EXPECT_EQ(is_exception, true);
  delete ch;

  // Test with buffered channel: float channel, int destination.
  ch = new ChannelHolder();
  ch->Reset<float>(10);
  is_exception = false;
  int int_data = 23;
  try {
    ch->Receive(&int_data);
  } catch (const paddle::platform::EnforceNotMet &) {
    is_exception = true;
  }
  EXPECT_EQ(is_exception, true);
  delete ch;
}
// Starts five receivers on a holder that has no senders, checks that none of
// them has returned after 200 ms, closes the holder, and then checks that
// every receiver returned (with Receive() yielding false).
void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) {
  const size_t kNumThreads = 5;
  std::thread workers[kNumThreads];
  bool finished[kNumThreads];

  // Launch the receivers; each one sets its own flag once Receive() returns.
  for (size_t idx = 0; idx != kNumThreads; ++idx) {
    bool *flag = &finished[idx];
    *flag = false;
    workers[idx] = std::thread([ch, flag]() {
      int received;
      EXPECT_EQ(ch->Receive(&received), false);
      *flag = true;
    });
  }

  // Give the receivers 0.2 sec to reach the blocking call.
  std::this_thread::sleep_for(std::chrono::milliseconds(200));

  // Nobody has sent anything, so every receiver must still be waiting.
  for (bool done : finished) {
    EXPECT_EQ(done, false);
  }

  // Closing the holder should release all of the receivers.
  ch->close();

  std::this_thread::sleep_for(std::chrono::milliseconds(200));

  for (bool done : finished) {
    EXPECT_EQ(done, true);
  }

  for (std::thread &worker : workers) {
    worker.join();
  }
}
// Starts five senders on a holder that has no receivers, verifies that they
// are blocked, closes the holder, and verifies that every sender returned.
//
// ch          holder under test; it is closed here but not deleted.
// isBuffered  true when the held channel has a buffer. In that case one
//             sender is allowed to complete before the close, and exactly
//             one send is expected to succeed overall.
//
// NOTE(review): the per-thread flags are plain bools written by the workers
// and read by this thread; the 200 ms sleeps are the only thing ordering
// those accesses.
void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch,
                                           bool isBuffered) {
  const size_t kNumThreads = 5;
  std::thread t[kNumThreads];
  bool thread_ended[kNumThreads];
  bool send_success[kNumThreads];

  // Launches threads that try to write and are blocked because of no readers
  for (size_t i = 0; i < kNumThreads; i++) {
    thread_ended[i] = false;
    send_success[i] = false;
    t[i] = std::thread(
        [&](bool *ended, bool *success) {
          int data = 10;
          bool is_exception = false;
          try {
            ch->Send(&data);
          } catch (const paddle::platform::EnforceNotMet &) {
            // Caught by const reference: no copy of the exception object and
            // no unused variable.
            is_exception = true;
          }
          *success = !is_exception;
          *ended = true;
        },
        &thread_ended[i], &send_success[i]);
  }
  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait

  if (isBuffered) {
    // If ch is buffered, at least 4 threads must be blocked.
    int ct = 0;
    for (size_t i = 0; i < kNumThreads; i++) {
      if (!thread_ended[i]) ct++;
    }
    EXPECT_GE(ct, 4);
  } else {
    // If ch is unbuffered, all the threads should be blocked.
    for (size_t i = 0; i < kNumThreads; i++) {
      EXPECT_EQ(thread_ended[i], false);
    }
  }

  // Explicitly close the channel.
  // This should unblock all senders.
  ch->close();

  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait

  // Verify that all threads got unblocked
  for (size_t i = 0; i < kNumThreads; i++) {
    EXPECT_EQ(thread_ended[i], true);
  }

  if (isBuffered) {
    // Verify that only 1 send was successful
    int ct = 0;
    for (size_t i = 0; i < kNumThreads; i++) {
      if (send_success[i]) ct++;
    }
    // Only 1 send must be successful
    EXPECT_EQ(ct, 1);
  }

  for (size_t i = 0; i < kNumThreads; i++) t[i].join();
}
// Closing a channel holder must unblock any receivers waiting on the
// channel, whether the channel is buffered or not.
TEST(ChannelHolder, ChannelHolderCloseUnblocksReceiversTest) {
  // Buffered channel.
  ChannelHolder *buffered = new ChannelHolder();
  buffered->Reset<int>(1);
  ChannelHolderCloseUnblocksReceiversTest(buffered);
  delete buffered;

  // Unbuffered channel.
  ChannelHolder *unbuffered = new ChannelHolder();
  unbuffered->Reset<int>(0);
  ChannelHolderCloseUnblocksReceiversTest(unbuffered);
  delete unbuffered;
}
// Closing a channel holder must unblock any senders that are waiting for
// the channel to have write space, whether the channel is buffered or not.
//
// NOTE(review): this test is registered under the `Channel` suite while the
// other holder tests use `ChannelHolder`; looks unintentional, but renaming
// would change the test's full name, so it is left as is -- confirm.
TEST(Channel, ChannelHolderCloseUnblocksSendersTest) {
  // Buffered channel.
  ChannelHolder *buffered = new ChannelHolder();
  buffered->Reset<int>(1);
  ChannelHolderCloseUnblocksSendersTest(buffered, true);
  delete buffered;

  // Unbuffered channel.
  ChannelHolder *unbuffered = new ChannelHolder();
  unbuffered->Reset<int>(0);
  ChannelHolderCloseUnblocksSendersTest(unbuffered, false);
  delete unbuffered;
}
// This tests that destroying a channel holder unblocks any senders waiting
// on the channel.
//
// ch          holder under test; it is deleted inside this function, so the
//             caller must not use or delete it afterwards.
// isBuffered  true when the held channel has a buffer; exactly one send is
//             then expected to succeed, otherwise none.
//
// NOTE(review): the per-thread flags are plain bools written by the workers
// and read by this thread; the 200 ms sleeps are the only thing ordering
// those accesses.
void ChannelHolderDestroyUnblockSenders(ChannelHolder *ch, bool isBuffered) {
  const size_t kNumThreads = 5;
  std::thread t[kNumThreads];
  bool thread_ended[kNumThreads];
  bool send_success[kNumThreads];

  // Launches threads that try to write and are blocked because of no readers
  for (size_t i = 0; i < kNumThreads; i++) {
    thread_ended[i] = false;
    send_success[i] = false;
    t[i] = std::thread(
        [&](bool *ended, bool *success) {
          int data = 10;
          bool is_exception = false;
          try {
            ch->Send(&data);
          } catch (const paddle::platform::EnforceNotMet &) {
            // Caught by const reference: no copy of the exception object and
            // no unused variable.
            is_exception = true;
          }
          *success = !is_exception;
          *ended = true;
        },
        &thread_ended[i], &send_success[i]);
  }

  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait 0.2 sec

  if (isBuffered) {
    // If channel is buffered, verify that at least 4 threads are blocked
    int ct = 0;
    for (size_t i = 0; i < kNumThreads; i++) {
      if (thread_ended[i] == false) ct++;
    }
    // At least 4 threads must be blocked
    EXPECT_GE(ct, 4);
  } else {
    // Verify that all the threads are blocked
    for (size_t i = 0; i < kNumThreads; i++) {
      EXPECT_EQ(thread_ended[i], false);
    }
  }

  // Explicitly destroy the holder while the senders are still waiting.
  delete ch;
  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait

  // Verify that all threads got unblocked
  for (size_t i = 0; i < kNumThreads; i++) {
    EXPECT_EQ(thread_ended[i], true);
  }

  // Count number of successful sends
  int ct = 0;
  for (size_t i = 0; i < kNumThreads; i++) {
    if (send_success[i]) ct++;
  }

  if (isBuffered) {
    // Only 1 send must be successful
    EXPECT_EQ(ct, 1);
  } else {
    // In unbuffered channel, no send should be successful
    EXPECT_EQ(ct, 0);
  }

  // Join all threads
  for (size_t i = 0; i < kNumThreads; i++) t[i].join();
}
// Destroying a channel holder must also unblock any receivers waiting on
// the channel. Starts five receivers, checks they are still waiting, deletes
// `ch`, and checks that every receiver returned. The holder is deleted here,
// so the caller must not delete it again.
void ChannelHolderDestroyUnblockReceivers(ChannelHolder *ch) {
  const size_t kNumThreads = 5;
  std::thread workers[kNumThreads];
  bool finished[kNumThreads];

  // Launch receivers that block because there are no writers.
  for (size_t idx = 0; idx != kNumThreads; ++idx) {
    bool *flag = &finished[idx];
    *flag = false;
    workers[idx] = std::thread([ch, flag]() {
      int received;
      // Every read is expected to report failure.
      EXPECT_EQ(ch->Receive(&received), false);
      *flag = true;
    });
  }

  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait

  // All receivers must still be blocked.
  for (bool done : finished) {
    EXPECT_EQ(done, false);
  }

  // Delete the holder while the receivers are waiting.
  delete ch;
  std::this_thread::sleep_for(std::chrono::milliseconds(200));  // wait

  // All receivers must have returned.
  for (bool done : finished) {
    EXPECT_EQ(done, true);
  }

  for (std::thread &worker : workers) {
    worker.join();
  }
}
// Destroying a holder must unblock waiting receivers for both buffered and
// unbuffered channels.
TEST(ChannelHolder, ChannelHolderDestroyUnblocksReceiversTest) {
  // Buffered channel. ChannelHolderDestroyUnblockReceivers deletes the
  // holder it is given, so there is no delete here.
  ChannelHolder *buffered = new ChannelHolder();
  buffered->Reset<int>(1);
  ChannelHolderDestroyUnblockReceivers(buffered);

  // Unbuffered channel.
  ChannelHolder *unbuffered = new ChannelHolder();
  unbuffered->Reset<int>(0);
  ChannelHolderDestroyUnblockReceivers(unbuffered);
}
// Destroying a holder must unblock waiting senders for both buffered and
// unbuffered channels.
TEST(ChannelHolder, ChannelHolderDestroyUnblocksSendersTest) {
  // Buffered channel. ChannelHolderDestroyUnblockSenders deletes the holder
  // it is given, so there is no delete here.
  ChannelHolder *buffered = new ChannelHolder();
  buffered->Reset<int>(1);
  ChannelHolderDestroyUnblockSenders(buffered, true);

  // Unbuffered channel.
  ChannelHolder *unbuffered = new ChannelHolder();
  unbuffered->Reset<int>(0);
  ChannelHolderDestroyUnblockSenders(unbuffered, false);
}
// This tests closing a channel holder from many threads at once.
//
// Fifteen threads are started: the first third send the value 10, the
// second third receive, and the last third close the holder if it is not
// already closed. After 100 ms every thread is expected to have finished and
// the holder to report itself closed.
//
// ch  holder under test; it is deleted inside this function, so the caller
//     must not use or delete it afterwards.
void ChannelHolderManyTimesClose(ChannelHolder *ch) {
  // size_t so that the comparisons with the size_t loop indices below do not
  // mix signed and unsigned operands.
  const size_t kNumThreads = 15;
  std::thread t[kNumThreads];
  bool thread_ended[kNumThreads];

  // Launches threads that try to send data to channel.
  for (size_t i = 0; i < kNumThreads / 3; i++) {
    thread_ended[i] = false;
    t[i] = std::thread(
        [&](bool *ended) {
          int data = 10;
          try {
            ch->Send(&data);
          } catch (const paddle::platform::EnforceNotMet &) {
            // A closer thread may win the race and close the channel before
            // this send runs. Without this handler the exception would
            // escape the thread function and terminate the whole process.
          }
          *ended = true;
        },
        &thread_ended[i]);
  }

  // Launches threads that try to receive data from channel.
  for (size_t i = kNumThreads / 3; i < 2 * kNumThreads / 3; i++) {
    thread_ended[i] = false;
    t[i] = std::thread(
        [&](bool *p) {
          int data;
          // A receive may legitimately fail if the channel is already
          // closed; only a successful receive is checked.
          if (ch->Receive(&data)) {
            EXPECT_EQ(data, 10);
          }
          *p = true;
        },
        &thread_ended[i]);
  }

  // Launches threads that try to close the channel.
  for (size_t i = 2 * kNumThreads / 3; i < kNumThreads; i++) {
    thread_ended[i] = false;
    t[i] = std::thread(
        [&](bool *p) {
          if (!ch->IsClosed()) {
            ch->close();
          }
          *p = true;
        },
        &thread_ended[i]);
  }

  std::this_thread::sleep_for(std::chrono::milliseconds(100));  // wait

  // Verify that all threads are unblocked
  for (size_t i = 0; i < kNumThreads; i++) {
    EXPECT_EQ(thread_ended[i], true);
  }
  EXPECT_TRUE(ch->IsClosed());

  // Join before deleting so that no worker can still be touching the holder
  // when it is destroyed.
  for (size_t i = 0; i < kNumThreads; i++) t[i].join();

  // delete the channel
  delete ch;
}
// Runs the concurrent send/receive/close scenario on a buffered channel.
TEST(ChannelHolder, ChannelHolderManyTimesCloseTest) {
  ChannelHolder *holder = new ChannelHolder();
  holder->Reset<int>(10);
  // ChannelHolderManyTimesClose deletes the holder it is given.
  ChannelHolderManyTimesClose(holder);
}
paddle/fluid/framework/concurrency_test.cc
已删除
100644 → 0
浏览文件 @
8f5d918a
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
// Operators referenced by the programs built in this file.
// NOTE(review): USE_NO_KERNEL_OP is defined outside this file; presumably it
// makes sure each operator's registration is linked in -- confirm.
USE_NO_KERNEL_OP(go);
USE_NO_KERNEL_OP(channel_close);
USE_NO_KERNEL_OP(channel_create);
USE_NO_KERNEL_OP(channel_recv);
USE_NO_KERNEL_OP(channel_send);
USE_NO_KERNEL_OP(elementwise_add);
USE_NO_KERNEL_OP(select);
USE_NO_KERNEL_OP(conditional_block);
USE_NO_KERNEL_OP(equal);
USE_NO_KERNEL_OP(assign);
USE_NO_KERNEL_OP(while);
USE_NO_KERNEL_OP(print);

// Short aliases used throughout the tests below.
namespace f = paddle::framework;
namespace p = paddle::platform;
namespace
paddle
{
namespace
framework
{
// Creates (or fetches) the variable `name` in `scope`, makes it a LoDTensor
// of shape [1] holding elements of type T on `place`, stores `value` in its
// single element, and returns the tensor.
template <typename T>
LoDTensor *CreateVariable(Scope *scope, const p::CPUPlace &place,
                          std::string name, T value) {
  LoDTensor *tensor = scope->Var(name)->GetMutable<LoDTensor>();
  tensor->Resize({1});
  // mutable_data<T> hands back the buffer the single element is written to.
  T *buffer = tensor->mutable_data<T>(place);
  *buffer = value;
  return tensor;
}
// Appends a new operator of the given `type` to `block` and wires up its
// inputs, outputs and attributes.
void AddOp(const std::string &type, const VariableNameMap &inputs,
           const VariableNameMap &outputs, AttributeMap attrs,
           BlockDesc *block) {
  auto op_desc = block->AppendOp();
  op_desc->SetType(type);

  // Each map entry is (parameter name -> list of variable names).
  for (auto &input : inputs) {
    op_desc->SetInput(input.first, input.second);
  }
  for (auto &output : outputs) {
    op_desc->SetOutput(output.first, output.second);
  }

  op_desc->SetAttrMap(attrs);
}
// Adds one case to `casesBlock`.
//
// A new sub-block is appended to `program` and populated by `func`. Three
// [1]-shaped variables are created in `scope`: "caseCond<caseId>" (false),
// "caseCondX<caseId>" (caseId) and `caseVarName` (caseId). Two ops are then
// appended to `casesBlock`: an "equal" op comparing "caseCondX<caseId>" with
// "caseToExecute", and a "conditional_block" op that runs the new sub-block
// under that condition.
//
// `caseType` and `caseChannel` are accepted but not used by this function.
void AddCase(ProgramDesc *program, Scope *scope, p::CPUPlace *place,
             BlockDesc *casesBlock, int caseId, int caseType,
             std::string caseChannel, std::string caseVarName,
             std::function<void(BlockDesc *, Scope *)> func) {
  const std::string idSuffix = std::to_string(caseId);
  std::string caseCondName = "caseCond" + idSuffix;
  std::string caseCondXVarName = "caseCondX" + idSuffix;

  // Let the caller fill in the body of this case.
  BlockDesc *caseBlock = program->AppendBlock(*casesBlock);
  func(caseBlock, scope);

  CreateVariable(scope, *place, caseCondName, false);
  CreateVariable(scope, *place, caseCondXVarName, caseId);
  CreateVariable(scope, *place, caseVarName, caseId);

  scope->Var("step_scope");

  // caseCond<id> = (caseCondX<id> == caseToExecute)
  AddOp("equal", {{"X", {caseCondXVarName}}, {"Y", {"caseToExecute"}}},
        {{"Out", {caseCondName}}}, {}, casesBlock);

  // Run the case body only when caseCond<id> holds.
  AddOp("conditional_block", {{"X", {caseCondName}}, {"Params", {}}},
        {{"Out", {}}, {"Scope", {"step_scope"}}},
        {{"sub_block", caseBlock}, {"is_scalar_condition", true}},
        casesBlock);
}
// Appends to `parentBlock` a "while" op whose body is a "select" over two
// cases:
//   case 0 -- send "fibXToSend" on `dataChanName`, then advance the
//             Fibonacci pair (fibX, fibY);
//   case 1 -- receive "quitVar" from `quitChanName`, then set
//             "whileExitCond" to false so the while loop ends.
//
// All working variables are created in `scope` on `place`; the new blocks
// are appended to `program`.
void AddFibonacciSelect(Scope *scope, p::CPUPlace *place, ProgramDesc *program,
                        BlockDesc *parentBlock, std::string dataChanName,
                        std::string quitChanName) {
  BlockDesc *whileBlock = program->AppendBlock(*parentBlock);

  CreateVariable(scope, *place, "whileExitCond", true);
  CreateVariable(scope, *place, "caseToExecute", -1);
  CreateVariable(scope, *place, "case1var", 0);

  CreateVariable(scope, *place, "xtemp", 0);

  // TODO(thuan): Need to create fibXToSend, since channel send moves the actual
  // data,
  // which causes the data to be no longer accessible to do the fib calculation
  // TODO(abhinav): Change channel send to do a copy instead of a move!
  CreateVariable(scope, *place, "fibXToSend", 0);

  CreateVariable(scope, *place, "fibX", 0);
  CreateVariable(scope, *place, "fibY", 1);
  CreateVariable(scope, *place, "quitVar", 0);

  BlockDesc *casesBlock = program->AppendBlock(*whileBlock);
  // An unused local std::function named `f` used to be declared here; it
  // did nothing and shadowed the namespace alias `f`, so it was removed.

  // TODO(thuan): Remove this once we change channel send to do a copy instead
  // of move
  AddOp("assign", {{"X", {"fibX"}}}, {{"Out", {"fibXToSend"}}}, {},
        whileBlock);

  // Case 0: Send to dataChanName
  std::function<void(BlockDesc *, Scope *)> case0Func =
      [](BlockDesc *caseBlock, Scope *) {
        // (fibX, fibY) <- (fibY, fibX + fibY), using xtemp for the old fibX.
        AddOp("assign", {{"X", {"fibX"}}}, {{"Out", {"xtemp"}}}, {},
              caseBlock);
        AddOp("assign", {{"X", {"fibY"}}}, {{"Out", {"fibX"}}}, {}, caseBlock);
        AddOp("elementwise_add", {{"X", {"xtemp"}}, {"Y", {"fibY"}}},
              {{"Out", {"fibY"}}}, {}, caseBlock);
      };
  AddCase(program, scope, place, casesBlock, 0, 1, dataChanName, "fibXToSend",
          case0Func);
  std::string case0Config =
      std::string("0,1,") + dataChanName + std::string(",fibXToSend");

  // Case 1: Receive from quitChanName
  std::function<void(BlockDesc *, Scope *)> case1Func =
      [place](BlockDesc *caseBlock, Scope *caseScope) {
        // Exit the while loop after we receive from quit channel.
        // We assign a false to "whileExitCond" variable, which will
        // break out of while_op loop
        CreateVariable(caseScope, *place, "whileFalse", false);
        AddOp("assign", {{"X", {"whileFalse"}}}, {{"Out", {"whileExitCond"}}},
              {}, caseBlock);
      };
  AddCase(program, scope, place, casesBlock, 1, 2, quitChanName, "quitVar",
          case1Func);
  std::string case1Config =
      std::string("1,2,") + quitChanName + std::string(",quitVar");

  // Select block
  AddOp("select",
        {{"X", {dataChanName, quitChanName}},
         {"case_to_execute", {"caseToExecute"}}},
        {{"Out", {}}},
        {{"sub_block", casesBlock},
         {"cases", std::vector<std::string>{case0Config, case1Config}}},
        whileBlock);

  scope->Var("stepScopes");
  AddOp("while",
        {{"X", {dataChanName, quitChanName}},
         {"Condition", {"whileExitCond"}}},
        {{"Out", {}}, {"StepScopes", {"stepScopes"}}},
        {{"sub_block", whileBlock}}, parentBlock);
}
// Builds a program in which a "go" op sends the variable "x0" (99) through a
// channel and the main block receives it into "result", then checks that
// "result" changes from 0 to 99 after the program runs.
TEST(Concurrency, Go_Op) {
  Scope scope;
  p::CPUPlace place;

  // Initialize scope variables
  p::CPUDeviceContext ctx(place);

  // The variable that will hold the channel.
  scope.Var("Channel");

  // x0 is put into the channel; result is pulled out of it.
  CreateVariable(&scope, place, "Status", false);
  CreateVariable(&scope, place, "x0", 99);
  CreateVariable(&scope, place, "result", 0);

  framework::Executor executor(place);
  ProgramDesc program;
  BlockDesc *mainBlock = program.MutableBlock(0);

  // Channel creation (capacity 10).
  AddOp("channel_create", {}, {{"Out", {"Channel"}}},
        {{"capacity", 10}, {"data_type", f::proto::VarType::LOD_TENSOR}},
        mainBlock);

  // Body of the go routine: send x0 on the channel.
  BlockDesc *goBlock = program.AppendBlock(program.Block(0));
  AddOp("channel_send", {{"Channel", {"Channel"}}, {"X", {"x0"}}},
        {{"Status", {"Status"}}}, {}, goBlock);

  // The go op itself.
  AddOp("go", {{"X", {"Channel", "x0"}}}, {}, {{"sub_block", goBlock}},
        mainBlock);

  // Receive from the channel into "result".
  AddOp("channel_recv", {{"Channel", {"Channel"}}},
        {{"Status", {"Status"}}, {"Out", {"result"}}}, {}, mainBlock);

  // Close the channel.
  AddOp("channel_close", {{"Channel", {"Channel"}}}, {}, {}, mainBlock);

  // Before running, "result" must still hold its initial 0.
  const LoDTensor &resultTensor =
      (scope.FindVar("result"))->Get<LoDTensor>();
  EXPECT_EQ(resultTensor.data<int>()[0], 0);

  executor.Run(program, &scope, 0, true, true);

  // After running, the value sent by the go routine (99) must have been
  // received into "result".
  EXPECT_EQ(resultTensor.data<int>()[0], 99);
}
/**
 * This test implements the fibonacci function using go_op and select_op.
 *
 * A go routine receives ten values from the data channel (printing each)
 * and then sends on the quit channel; the main block produces the values
 * through the select loop built by AddFibonacciSelect.
 */
TEST(Concurrency, Select) {
  Scope scope;
  p::CPUPlace place;

  // Initialize scope variables
  p::CPUDeviceContext ctx(place);

  CreateVariable(&scope, place, "Status", false);
  CreateVariable(&scope, place, "result", 0);
  CreateVariable(&scope, place, "currentXFib", 0);

  framework::Executor executor(place);
  ProgramDesc program;
  BlockDesc *mainBlock = program.MutableBlock(0);

  // Data channel (unbuffered).
  const std::string dataChanName = "Channel";
  scope.Var(dataChanName);
  AddOp("channel_create", {}, {{"Out", {dataChanName}}},
        {{"capacity", 0}, {"data_type", f::proto::VarType::LOD_TENSOR}},
        mainBlock);

  // Quit channel (unbuffered).
  const std::string quitChanName = "Quit";
  scope.Var(quitChanName);
  AddOp("channel_create", {}, {{"Out", {quitChanName}}},
        {{"capacity", 0}, {"data_type", f::proto::VarType::LOD_TENSOR}},
        mainBlock);

  // Go routine body: ten receive-and-print steps over the fibonacci
  // sequence, followed by a send on the quit channel.
  CreateVariable(&scope, place, "xReceiveVar", 0);

  BlockDesc *goBlock = program.AppendBlock(program.Block(0));
  for (int step = 0; step != 10; ++step) {
    AddOp("channel_recv", {{"Channel", {dataChanName}}},
          {{"Status", {"Status"}}, {"Out", {"currentXFib"}}}, {}, goBlock);
    AddOp("print", {{"In", {"currentXFib"}}}, {{"Out", {"currentXFib"}}},
          {{"first_n", 100},
           {"summarize", -1},
           {"print_tensor_name", false},
           {"print_tensor_type", true},
           {"print_tensor_shape", false},
           {"print_tensor_lod", false},
           {"print_phase", std::string("FORWARD")},
           {"message", std::string("X: ")}},
          goBlock);
  }

  CreateVariable(&scope, place, "quitSignal", 0);
  AddOp("channel_send", {{"Channel", {quitChanName}}, {"X", {"quitSignal"}}},
        {{"Status", {"Status"}}}, {}, goBlock);

  // The go op itself.
  AddOp("go", {{"X", {dataChanName, quitChanName}}}, {},
        {{"sub_block", goBlock}}, mainBlock);

  AddFibonacciSelect(&scope, &place, &program, mainBlock, dataChanName,
                     quitChanName);

  // Close both channels.
  AddOp("channel_close", {{"Channel", {dataChanName}}}, {}, {}, mainBlock);
  AddOp("channel_close", {{"Channel", {quitChanName}}}, {}, {}, mainBlock);

  executor.Run(program, &scope, 0, true, true);

  // After the run, "currentXFib" holds the last value the go routine
  // received, which should be 34 (ten steps through the fibonacci sequence).
  const LoDTensor &fibTensor =
      (scope.FindVar("currentXFib"))->Get<LoDTensor>();
  EXPECT_EQ(fibTensor.data<int>()[0], 34);
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/executor.cc
浏览文件 @
f63ab561
...
...
@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
...
...
@@ -76,15 +75,13 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
var
->
GetMutable
<
platform
::
PlaceList
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
READER
)
{
var
->
GetMutable
<
ReaderHolder
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
CHANNEL
)
{
var
->
GetMutable
<
ChannelHolder
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
RAW
)
{
// GetMutable will be called in operator
}
else
{
PADDLE_THROW
(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER,
CHANNEL,
RAW]"
,
"LOD_RANK_TABLE, PLACE_LIST, READER, RAW]"
,
var_type
);
}
}
...
...
paddle/fluid/framework/framework.proto
浏览文件 @
f63ab561
...
...
@@ -126,7 +126,6 @@ message VarType {
LOD_TENSOR_ARRAY
=
13
;
PLACE_LIST
=
14
;
READER
=
15
;
CHANNEL
=
16
;
// Any runtime decided variable type is raw
// raw variables should manage their own allocations
// in operators like nccl_op
...
...
@@ -158,12 +157,6 @@ message VarType {
message
ReaderDesc
{
repeated
LoDTensorDesc
lod_tensor
=
1
;
}
optional
ReaderDesc
reader
=
5
;
message
ChannelDesc
{
required
Type
data_type
=
1
;
required
int64
capacity
=
2
;
}
optional
ChannelDesc
channel
=
6
;
message
Tuple
{
repeated
Type
element_type
=
1
;
}
optional
Tuple
tuple
=
7
;
}
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
f63ab561
...
...
@@ -34,6 +34,7 @@ endif ()
pass_library
(
attention_lstm_fuse_pass inference
)
pass_library
(
infer_clean_graph_pass inference
)
pass_library
(
fc_lstm_fuse_pass inference
)
pass_library
(
embedding_fc_lstm_fuse_pass inference
)
pass_library
(
fc_gru_fuse_pass inference
)
pass_library
(
seq_concat_fc_fuse_pass inference
)
...
...
paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
0 → 100644
浏览文件 @
f63ab561
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
#include <algorithm>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// Finds every lookup_table -> FC (mul [+ elementwise_add]) -> lstm chain in
// `graph` and replaces it with a single fused_embedding_fc_lstm op.
//
// For each match the embedding table W is multiplied by the FC weight, the
// first N values of the LSTM bias (plus the FC bias when `with_fc_bias` is
// set) are added to every row, and the result is stored in a new scope
// variable that becomes the "Embeddings" input of the fused op.
//
//   graph        - graph that is rewritten in place.
//   name_scope   - prefix used when naming the pattern nodes.
//   scope        - scope holding the parameter tensors; enforced non-null
//                  when a match is fused.
//   with_fc_bias - match the FC variant that has an elementwise_add bias.
//
// Returns the number of subgraphs that were fused.
static int BuildFusion(Graph* graph, const std::string& name_scope,
                       Scope* scope, bool with_fc_bias) {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // Build pattern
  PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x"))
                  ->assert_is_op_input("lookup_table")
                  ->assert_var_not_persistable();
  patterns::Embedding embedding_pattern(pattern, name_scope);
  // TODO(jczaja): Intermediate can only be for val that are not used anywhere
  //               but lookup table output may go into other LSTM (for reverse
  //               direction)
  auto* embedding_out = embedding_pattern(x);
  patterns::FC fc_pattern(pattern, name_scope);

  // fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
  auto* fc_out = fc_pattern(embedding_out, with_fc_bias)->AsIntermediate();
  patterns::LSTM lstm_pattern(pattern, name_scope);
  lstm_pattern(fc_out);

  // Create New OpDesc
  auto embedding_lstm_creator = [&](Node* embedding, Node* W, Node* lstm,
                                    Node* input, Node* weight_x,
                                    Node* weight_h, Node* bias, Node* hidden,
                                    Node* cell, Node* xx, Node* fc_bias) {
    OpDesc op_desc;
    op_desc.SetType("fused_embedding_fc_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
    SET_IN(Ids, input);
    SET_IN(WeightH, weight_h);
    // Need to have this passed as we need Wc data for peephole connections
    SET_IN(Bias, bias);
#undef SET_IN

    // Look up and validate every parameter tensor before anything is
    // created in the scope, so a failed check leaves no half-built state.
    PADDLE_ENFORCE(scope);
    auto* embedding_var = scope->FindVar(W->Name());
    PADDLE_ENFORCE(embedding_var);
    auto* weightx_var = scope->FindVar(weight_x->Name());
    PADDLE_ENFORCE(weightx_var);
    auto* lstm_bias_var = scope->FindVar(bias->Name());
    PADDLE_ENFORCE(lstm_bias_var);
    const auto& embedding_tensor = embedding_var->Get<framework::LoDTensor>();
    const auto& weightx_tensor = weightx_var->Get<framework::LoDTensor>();
    const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();

    // Embedding size: [dict_size, single_embedding]
    // WeightX size:   [single_embedding, fc_size]
    // New embeddings: [dict_size, fc_size]
    int m = embedding_tensor.dims()[0];
    int n = weightx_tensor.dims()[1];
    int k = embedding_tensor.dims()[1];
    PADDLE_ENFORCE(weightx_tensor.dims()[0] == embedding_tensor.dims()[1],
                   "Embedding width and FC weight height do not match.");
    PADDLE_ENFORCE(lstm_bias_tensor.numel() >= n,
                   "LSTM bias holds fewer values than the FC output width.");

    // Multiply embeddings with Weights
    const std::string embeddings = patterns::UniqueKey("Embeddings");
    auto* embeddings_var = scope->Var(embeddings);
    PADDLE_ENFORCE(embeddings_var);
    auto* embeddings_tensor =
        embeddings_var->GetMutable<framework::LoDTensor>();
    embeddings_tensor->Resize(
        {embedding_tensor.dims()[0], weightx_tensor.dims()[1]});

    // Multiply embeddings via WeightsX and add bias
    auto embedding_data = embedding_tensor.data<float>();
    auto weightx_data = weightx_tensor.data<float>();
    auto embeddings_data =
        embeddings_tensor->mutable_data<float>(platform::CPUPlace());

    auto alpha = 1.0f;
    auto beta = 1.0f;

    // Copy only gate biases values (only actual bias data, not peephole
    // weights)
    const float* lstm_bias_data = lstm_bias_tensor.data<float>();
    std::vector<float> combined_biases(lstm_bias_data, lstm_bias_data + n);

    if (with_fc_bias) {
      // Add FC-bias with LSTM-bias (into GEMM result to be)
      auto* fc_bias_var = scope->FindVar(fc_bias->Name());
      PADDLE_ENFORCE(fc_bias_var);
      const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
      // Guards the element-wise add below against writing past the vector.
      PADDLE_ENFORCE(fc_bias_tensor.numel() <= n,
                     "FC bias holds more values than the FC output width.");
      const float* fc_bias_data = fc_bias_tensor.data<float>();
      for (int64_t i = 0; i < fc_bias_tensor.numel(); ++i) {
        combined_biases[i] += fc_bias_data[i];
      }
    }

    // broadcast biases: ones[m x 1] * combined_biases[1 x n] fills every row
    // of the output with the bias vector (beta == 0 overwrites the buffer).
    std::vector<float> ones(m, 1.0f);
    paddle::operators::math::CBlas<float>::GEMM(
        CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, 1, alpha, ones.data(),
        1, combined_biases.data(), n, 0.0f, embeddings_data, n);

    // Wx*embeddings + biases (beta == 1 accumulates onto the biases)
    paddle::operators::math::CBlas<float>::GEMM(
        CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha,
        embedding_data, k, weightx_data, n, beta, embeddings_data, n);
    op_desc.SetInput("Embeddings", {embeddings});

    // Create temp variables.
    const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
    const std::string BatchedCellPreAct =
        patterns::UniqueKey("BatchedCellPreAct");
    const std::string BatchedGate = patterns::UniqueKey("BatchedGate");

    scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
    scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
    scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();

    op_desc.SetInput("H0", {});
    op_desc.SetInput("C0", {});
    op_desc.SetOutput("Hidden", {hidden->Name()});
    op_desc.SetOutput("Cell", {cell->Name()});
    op_desc.SetOutput("XX", {xx->Name()});
    op_desc.SetOutput("BatchedGate", {BatchedGate});
    op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
    op_desc.SetOutput("BatchedInput", {BatchedInput});
    op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
    op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
    // TODO(TJ): get from attr
    op_desc.SetAttr("use_seq", true);

    // The remaining outputs live in the scope attached to the graph. A
    // distinct name is used so the captured `scope` is not shadowed.
    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
    auto* param_scope = graph->Get<Scope*>(kParamScopeAttr);
#define OP_SET_OUT(x)                            \
  const std::string x = patterns::UniqueKey(#x); \
  op_desc.SetOutput(#x, {x});                    \
  param_scope->Var(x)->GetMutable<LoDTensor>()
    OP_SET_OUT(BatchedCell);
    OP_SET_OUT(BatchedHidden);
    OP_SET_OUT(ReorderedH0);
    OP_SET_OUT(ReorderedC0);
#undef OP_SET_OUT

    auto* op = graph->CreateOpNode(&op_desc);
    IR_NODE_LINK_TO(input, op);
    IR_NODE_LINK_TO(weight_x, op);
    IR_NODE_LINK_TO(weight_h, op);
    IR_NODE_LINK_TO(bias, op);
    IR_NODE_LINK_TO(op, hidden);
    return op;
  };

  int fusion_count{0};

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table, lookup_table, embedding_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(W, W, embedding_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);

    // TODO(jczaja): Add support for is_sparse / is_distributed
    auto is_sparse = boost::get<bool>(lookup_table->Op()->GetAttr("is_sparse"));
    auto is_distributed =
        boost::get<bool>(lookup_table->Op()->GetAttr("is_distributed"));
    if (is_sparse || is_distributed) {
      return;
    }

    if (with_fc_bias) {
      GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
      GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
      GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
      embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
                             Bias, Hidden, Cell, fc_out, fc_bias);
      // Remove unneeded nodes.
      // TODO(jczaja): Proper removing of lookup table (lookup_table and W are
      //               intentionally left in the graph for now).
      std::unordered_set<const Node*> marked_nodes(
          {mul, lstm, elementwise_add, fc_bias});
      GraphSafeRemoveNodes(graph, marked_nodes);
    } else {
      GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
      embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
                             Bias, Hidden, Cell, fc_out, nullptr);
      // Remove unneeded nodes.
      // TODO(jczaja): Proper removing of lookup table (lookup_table and W are
      //               intentionally left in the graph for now).
      std::unordered_set<const Node*> marked_nodes({mul, lstm});
      GraphSafeRemoveNodes(graph, marked_nodes);
    }

    ++fusion_count;
  };

  gpd(graph, handler);

  return fusion_count;
}
// Pass entry point: initialises the fuse-pass base for this graph, runs the
// embedding + FC (with bias) + LSTM fusion and hands the fusion count to
// AddStatis(). The same graph object is returned to the caller.
std::unique_ptr<ir::Graph> EmbeddingFCLSTMFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  FusePassBase::Init(name_scope_, graph.get());

  const bool with_fc_bias = true;
  const int fused_subgraphs =
      BuildFusion(graph.get(), name_scope_, param_scope(), with_fc_bias);
  AddStatis(fused_subgraphs);

  return graph;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
// Registers EmbeddingFCLSTMFusePass under the pass name
// "embedding_fc_lstm_fuse_pass".
REGISTER_PASS
(
embedding_fc_lstm_fuse_pass
,
paddle
::
framework
::
ir
::
EmbeddingFCLSTMFusePass
);
paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h
0 → 100644
浏览文件 @
f63ab561
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// Fuse pass that merges an Embedding (lookup_table), FC and LSTM op chain.
// NOTE: ApplyImpl in the .cc file runs the fusion with with_fc_bias = true,
// i.e. it matches the FC variant that carries a bias.
class
EmbeddingFCLSTMFusePass
:
public
FusePassBase
{
public:
virtual
~
EmbeddingFCLSTMFusePass
()
{}
protected:
// Applies the fusion to `graph` and returns the (same) graph.
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
// Name scope passed to FusePassBase::Init and used to name pattern nodes.
const
std
::
string
name_scope_
{
"embedding_fc_lstm_fuse"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
f63ab561
...
...
@@ -692,6 +692,24 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
}
}
// Builds the lookup_table sub-pattern around `x`.
// `x` is constrained to be the "Ids" input of a lookup_table op; new pattern
// nodes are created for the op itself, its "W" input and its "Out" output,
// and linked together. Returns the pattern node for "Out".
PDNode *patterns::Embedding::operator()(PDNode *x) {
  x->assert_is_op_input("lookup_table", "Ids");

  PDNode *table_op = pattern->NewNode(lookup_table_repr());
  table_op->assert_is_op("lookup_table");

  // Embedding table (input) and looked-up rows (output) of the op.
  PDNode *table_weight =
      pattern->NewNode(W_repr())->assert_is_op_input("lookup_table", "W");
  PDNode *table_out =
      pattern->NewNode(Out_repr())->assert_is_op_output("lookup_table", "Out");

  table_op->LinksFrom({x, table_weight});
  table_op->LinksTo({table_out});
  return table_out;
}
PDNode
*
patterns
::
LSTM
::
operator
()(
PDNode
*
x
)
{
x
->
assert_is_op_input
(
"lstm"
,
"Input"
);
auto
*
lstm_op
=
pattern
->
NewNode
(
lstm_repr
())
->
assert_is_op
(
"lstm"
);
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
f63ab561
...
...
@@ -418,6 +418,23 @@ struct FC : public PatternBase {
PATTERN_DECL_NODE
(
Out
);
};
// Embedding
// Pattern helper describing a single lookup_table op together with its
// weight input and output variable. Nodes are registered under the
// "embedding" repr prefix passed to PatternBase.
struct
Embedding
:
public
PatternBase
{
Embedding
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"embedding"
)
{}
// Attaches the pattern to input node `x`; defined in
// graph_pattern_detector.cc.
PDNode
*
operator
()(
PDNode
*
x
);
// declare operator node's name
PATTERN_DECL_NODE
(
lookup_table
);
// Inputs
// NOTE(review): Ids is declared here, but the operator() shown in this
// commit constrains the caller-supplied `x` instead of creating an Ids node.
PATTERN_DECL_NODE
(
Ids
);
PATTERN_DECL_NODE
(
W
);
// embeddings
// Outputs
PATTERN_DECL_NODE
(
Out
);
};
struct
LSTM
:
public
PatternBase
{
LSTM
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"lstm"
)
{}
...
...
paddle/fluid/framework/naive_executor.cc
浏览文件 @
f63ab561
...
...
@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/channel.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/string/pretty_log.h"
...
...
@@ -44,8 +46,6 @@ static void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
var
->
GetMutable
<
platform
::
PlaceList
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
READER
)
{
var
->
GetMutable
<
ReaderHolder
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
CHANNEL
)
{
var
->
GetMutable
<
ChannelHolder
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
RAW
)
{
// GetMutable will be called in operator
}
else
{
...
...
paddle/fluid/framework/tuple.h
浏览文件 @
f63ab561
...
...
@@ -17,7 +17,6 @@ limitations under the License. */
#include <stdexcept>
#include <string>
#include <vector>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/var_desc.h"
...
...
paddle/fluid/framework/var_desc.cc
浏览文件 @
f63ab561
...
...
@@ -88,13 +88,7 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
}
// Stores `data_type` in the descriptor that matches the variable's type:
// the channel descriptor for CHANNEL variables, the tensor descriptor for
// everything else.
void VarDesc::SetDataType(proto::VarType::Type data_type) {
  const bool is_channel = desc_.type().type() == proto::VarType::CHANNEL;
  if (is_channel) {
    mutable_channel_desc()->set_data_type(data_type);
    return;
  }
  mutable_tensor_desc()->set_data_type(data_type);
}
void
VarDesc
::
SetDataTypes
(
...
...
@@ -115,13 +109,7 @@ void VarDesc::SetDataTypes(
}
// Returns the element data type: read from the channel descriptor for
// CHANNEL variables and from the tensor descriptor otherwise.
proto::VarType::Type VarDesc::GetDataType() const {
  if (desc_.type().type() == proto::VarType::CHANNEL) {
    return channel_desc().data_type();
  }
  return tensor_desc().data_type();
}
std
::
vector
<
proto
::
VarType
::
Type
>
VarDesc
::
GetDataTypes
()
const
{
...
...
@@ -134,17 +122,6 @@ std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
return
res
;
}
// Sets the channel capacity. Only CHANNEL variables have a capacity; any
// other variable type raises via PADDLE_THROW.
void VarDesc::SetCapacity(int64_t capacity) {
  if (desc_.type().type() != proto::VarType::CHANNEL) {
    PADDLE_THROW("Setting 'capacity' is not supported by the type of var %s.",
                 this->Name());
  }
  desc_.mutable_type()->mutable_channel()->set_capacity(capacity);
}
void
VarDesc
::
SetLoDLevel
(
int32_t
lod_level
)
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
LOD_TENSOR
:
...
...
@@ -214,19 +191,6 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
}
}
// Read-only access to the channel descriptor.
// Enforces that the var type has been set and raises via PADDLE_THROW when
// the variable is not of type CHANNEL.
const proto::VarType::ChannelDesc &VarDesc::channel_desc() const {
  PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
  if (desc_.type().type() != proto::VarType::CHANNEL) {
    PADDLE_THROW(
        "Getting 'channel_desc' is not supported by the type of var %s.",
        this->Name());
  }
  return desc_.type().channel();
}
const
proto
::
VarType
::
TensorDesc
&
VarDesc
::
tensor_desc
()
const
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var's type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
...
...
@@ -262,20 +226,6 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
}
}
// Mutable access to the channel descriptor.
// Enforces that the var type has been set and raises via PADDLE_THROW when
// the variable is not of type CHANNEL.
proto::VarType::ChannelDesc *VarDesc::mutable_channel_desc() {
  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
  if (desc_.type().type() != proto::VarType::CHANNEL) {
    PADDLE_THROW(
        "Getting 'mutable_channel_desc' is not supported by the type of var "
        "%s.",
        this->Name());
  }
  return desc_.mutable_type()->mutable_channel();
}
proto
::
VarType
::
TensorDesc
*
VarDesc
::
mutable_tensor_desc
()
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
...
...
paddle/fluid/framework/var_desc.h
浏览文件 @
f63ab561
...
...
@@ -87,8 +87,6 @@ class VarDesc {
void
SetDataTypes
(
const
std
::
vector
<
proto
::
VarType
::
Type
>
&
multiple_data_type
);
void
SetCapacity
(
int64_t
capacity
);
proto
::
VarType
::
Type
GetDataType
()
const
;
std
::
vector
<
proto
::
VarType
::
Type
>
GetDataTypes
()
const
;
...
...
@@ -110,10 +108,8 @@ class VarDesc {
// Forwards the persistable flag to the underlying proto descriptor.
void
SetPersistable
(
bool
persistable
)
{
desc_
.
set_persistable
(
persistable
);
}
private:
const
proto
::
VarType
::
ChannelDesc
&
channel_desc
()
const
;
const
proto
::
VarType
::
TensorDesc
&
tensor_desc
()
const
;
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
tensor_descs
()
const
;
proto
::
VarType
::
ChannelDesc
*
mutable_channel_desc
();
proto
::
VarType
::
TensorDesc
*
mutable_tensor_desc
();
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
mutable_tensor_descs
();
...
...
paddle/fluid/framework/var_type.h
浏览文件 @
f63ab561
...
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
...
...
@@ -41,8 +40,6 @@ inline proto::VarType::Type ToVarType(std::type_index type) {
return
proto
::
VarType_Type_SELECTED_ROWS
;
}
else
if
(
IsType
<
ReaderHolder
>
(
type
))
{
return
proto
::
VarType_Type_READER
;
}
else
if
(
IsType
<
ChannelHolder
>
(
type
))
{
return
proto
::
VarType_Type_CHANNEL
;
}
else
{
PADDLE_THROW
(
"ToVarType:Unsupported type %s"
,
type
.
name
());
}
...
...
@@ -66,9 +63,6 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
case
proto
::
VarType_Type_READER
:
visitor
(
var
.
Get
<
ReaderHolder
>
());
return
;
case
proto
::
VarType_Type_CHANNEL
:
visitor
(
var
.
Get
<
ChannelHolder
>
());
return
;
default:
PADDLE_THROW
(
"Not supported visit type, %d"
,
ToVarType
(
var
.
Type
()));
}
...
...
paddle/fluid/inference/analysis/analysis_pass.h
浏览文件 @
f63ab561
...
...
@@ -41,12 +41,6 @@ class AnalysisPass {
// all passes have run.
// Default implementation just returns false; virtual so subclasses can
// override it.
virtual
bool
Finalize
()
{
return
false
;
}
// Get a Pass appropriate to print the Node this pass operates on.
// The default implementation ignores `os` and `banner` and returns nullptr.
virtual
AnalysisPass
*
CreatePrinterPass
(
std
::
ostream
&
os
,
const
std
::
string
&
banner
)
const
{
return
nullptr
;
}
// Create a debugger Pass that draws the DFG with the graphviz toolkit.
// The default implementation returns nullptr.
virtual
AnalysisPass
*
CreateGraphvizDebugerPass
()
const
{
return
nullptr
;
}
...
...
paddle/fluid/inference/analysis/analyzer.h
浏览文件 @
f63ab561
...
...
@@ -66,6 +66,7 @@ class Analyzer : public OrderedRegistry<PassManager> {
// Manual update the passes here.
"infer_clean_graph_pass"
,
//
"attention_lstm_fuse_pass"
,
//
"embedding_fc_lstm_fuse_pass"
,
//
"fc_lstm_fuse_pass"
,
//
"mul_lstm_fuse_pass"
,
//
"fc_gru_fuse_pass"
,
//
...
...
paddle/fluid/inference/api/api_impl_tester.cc
浏览文件 @
f63ab561
...
...
@@ -21,6 +21,12 @@ limitations under the License. */
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/tests/test_helper.h"
// Absolute tolerance used when comparing predictor outputs against the
// reference outputs in this file: 4e-3 when compiled with clang, 1e-3
// otherwise.
#ifdef __clang__
#define ACC_DIFF 4e-3
#else
#define ACC_DIFF 1e-3
#endif
DEFINE_string
(
dirname
,
""
,
"Directory of the inference model."
);
namespace
paddle
{
...
...
@@ -99,8 +105,8 @@ void MainWord2Vec(bool use_gpu) {
float
*
lod_data
=
output1
.
data
<
float
>
();
for
(
int
i
=
0
;
i
<
output1
.
numel
();
++
i
)
{
EXPECT_LT
(
lod_data
[
i
]
-
data
[
i
],
1e-3
);
EXPECT_GT
(
lod_data
[
i
]
-
data
[
i
],
-
1e-3
);
EXPECT_LT
(
lod_data
[
i
]
-
data
[
i
],
ACC_DIFF
);
EXPECT_GT
(
lod_data
[
i
]
-
data
[
i
],
-
ACC_DIFF
);
}
}
...
...
@@ -144,7 +150,7 @@ void MainImageClassification(bool use_gpu) {
float
*
data
=
static_cast
<
float
*>
(
outputs
[
0
].
data
.
data
());
float
*
lod_data
=
output1
.
data
<
float
>
();
for
(
size_t
j
=
0
;
j
<
len
/
sizeof
(
float
);
++
j
)
{
EXPECT_NEAR
(
lod_data
[
j
],
data
[
j
],
1e-3
);
EXPECT_NEAR
(
lod_data
[
j
],
data
[
j
],
ACC_DIFF
);
}
}
...
...
@@ -199,7 +205,7 @@ void MainThreadsWord2Vec(bool use_gpu) {
float
*
ref_data
=
refs
[
tid
].
data
<
float
>
();
EXPECT_EQ
(
refs
[
tid
].
numel
(),
static_cast
<
int64_t
>
(
len
/
sizeof
(
float
)));
for
(
int
i
=
0
;
i
<
refs
[
tid
].
numel
();
++
i
)
{
EXPECT_NEAR
(
ref_data
[
i
],
data
[
i
],
1e-3
);
EXPECT_NEAR
(
ref_data
[
i
],
data
[
i
],
ACC_DIFF
);
}
});
}
...
...
@@ -251,7 +257,7 @@ void MainThreadsImageClassification(bool use_gpu) {
float
*
ref_data
=
refs
[
tid
].
data
<
float
>
();
EXPECT_EQ
((
size_t
)
refs
[
tid
].
numel
(),
len
/
sizeof
(
float
));
for
(
int
i
=
0
;
i
<
refs
[
tid
].
numel
();
++
i
)
{
EXPECT_NEAR
(
ref_data
[
i
],
data
[
i
],
1e-3
);
EXPECT_NEAR
(
ref_data
[
i
],
data
[
i
],
ACC_DIFF
);
}
});
}
...
...
paddle/fluid/inference/api/paddle_inference_api.h
浏览文件 @
f63ab561
...
...
@@ -263,7 +263,7 @@ struct AnalysisConfig : public NativeConfig {
bool
enable_ir_optim
=
true
;
// Manually determine the IR passes to run.
IrPassMode
ir_mode
{
IrPassMode
::
kExclude
};
std
::
vector
<
std
::
string
>
ir_passes
;
std
::
vector
<
std
::
string
>
ir_passes
{
"embedding_fc_lstm_fuse_pass"
}
;
// NOT stable yet.
bool
use_feed_fetch_ops
{
true
};
...
...
paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
浏览文件 @
f63ab561
...
...
@@ -104,5 +104,18 @@ TEST(Analyzer_Text_Classification, compare) {
CompareNativeAndAnalysis
(
cfg
,
input_slots_all
);
}
// Compares native and analysis predictors with the embedding + FC + LSTM
// fusion turned on.
TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) {
  AnalysisConfig cfg;
  SetConfig(&cfg);

  // Enable embedding_fc_lstm_fuse_pass (disabled by default): drop its first
  // occurrence, if any, from the configured ir_passes list.
  auto &passes = cfg.ir_passes;
  const auto pos =
      std::find(passes.begin(), passes.end(), "embedding_fc_lstm_fuse_pass");
  if (pos != passes.end()) {
    passes.erase(pos);
  }

  std::vector<std::vector<PaddleTensor>> input_slots_all;
  SetInput(&input_slots_all);
  CompareNativeAndAnalysis(cfg, input_slots_all);
}
}
// namespace inference
}
// namespace paddle
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
f63ab561
...
...
@@ -314,11 +314,6 @@ op_library(save_combine_op DEPS lod_tensor)
op_library
(
load_combine_op DEPS lod_tensor
)
op_library
(
concat_op DEPS concat
)
# FIXME(thuan): Move CSP operators to paddle/fluid/framework/operators/concurrency
add_subdirectory
(
concurrency
)
op_library
(
channel_send_op DEPS concurrency
)
op_library
(
channel_recv_op DEPS concurrency
)
list
(
REMOVE_ITEM GENERAL_OPS
${
DEPS_OPS
}
)
foreach
(
src
${
GENERAL_OPS
}
)
...
...
paddle/fluid/operators/channel_close_op.cc
已删除
100644 → 0
浏览文件 @
8f5d918a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
pf
=
paddle
::
framework
;
static
constexpr
char
kChannel
[]
=
"Channel"
;
namespace
paddle
{
namespace
operators
{
// Operator that closes the channel passed as its "Channel" input.
class ChannelCloseOp : public framework::OperatorBase {
 public:
  ChannelCloseOp(const std::string &type,
                 const framework::VariableNameMap &inputs,
                 const framework::VariableNameMap &outputs,
                 const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

 private:
  // Looks the channel variable up in `scope`, obtains its mutable
  // ChannelHolder and closes it. `dev_place` is not used.
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    framework::Variable *channel_var = scope.FindVar(Input(kChannel));
    framework::Variable &channel_ref = *channel_var;
    auto *holder = channel_ref.GetMutable<framework::ChannelHolder>();
    holder->close();
  }
};
// Shape inference for channel_close: only verifies that the "Channel" input
// is present; no output dimensions are set.
class ChannelCloseOpOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    const bool has_channel = context->HasInput("Channel");
    PADDLE_ENFORCE(has_channel, "The input of ChannelClose op must be set");
  }
};
// Declares the proto of the channel_close op: a single input (kChannel) and
// the op-level documentation string.
class
ChannelCloseOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
// Registers the op's input and its comment.
void
Make
()
override
{
AddInput
(
kChannel
,
"The Channel Variable that should be closed by"
" the ChannelClose Op."
);
AddComment
(
R"DOC(
Channel Close Operator.
This operator closes an open channel.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
// Registers the channel_close op with its operator class, an empty gradient
// maker and its proto maker.
REGISTER_OPERATOR
(
channel_close
,
paddle
::
operators
::
ChannelCloseOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
paddle
::
operators
::
ChannelCloseOpMaker
);
paddle/fluid/operators/channel_create_op.cc
已删除
100644 → 0
浏览文件 @
8f5d918a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
namespace
pf
=
paddle
::
framework
;
static
constexpr
char
kOutput
[]
=
"Out"
;
namespace
paddle
{
namespace
operators
{
// Operator that (re)initialises the ChannelHolder stored in its "Out"
// variable for the element type and capacity given by its attributes.
class ChannelCreateOp : public framework::OperatorBase {
 public:
  ChannelCreateOp(const std::string &type,
                  const framework::VariableNameMap &inputs,
                  const framework::VariableNameMap &outputs,
                  const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

 private:
  // Reads the "data_type" and "capacity" attributes and resets the output
  // channel holder with the matching element type. Unsupported data types
  // raise via PADDLE_THROW. `dev_place` is not used.
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    using VarType = framework::proto::VarType;

    auto &out = *scope.FindVar(Output(kOutput));

    // Determine the datatype and capacity of the channel to be created
    // from the attributes provided.
    const auto dtype = static_cast<VarType::Type>(Attr<int>("data_type"));
    const auto capacity = Attr<int>("capacity");

    // Based on the datatype, create a new channel holder initialized with
    // the given capacity. When capacity is 0, an unbuffered channel is
    // created.
    pf::ChannelHolder *ch = out.GetMutable<framework::ChannelHolder>();
    switch (dtype) {
      case VarType::LOD_TENSOR:
        ch->Reset<pf::LoDTensor>(capacity);
        break;
      case VarType::SELECTED_ROWS:
        ch->Reset<pf::SelectedRows>(capacity);
        break;
      case VarType::LOD_RANK_TABLE:
        ch->Reset<pf::LoDRankTable>(capacity);
        break;
      case VarType::LOD_TENSOR_ARRAY:
        ch->Reset<pf::LoDTensorArray>(capacity);
        break;
      case VarType::READER:
        ch->Reset<pf::ReaderHolder>(capacity);
        break;
      case VarType::CHANNEL:
        ch->Reset<pf::ChannelHolder>(capacity);
        break;
      case VarType::BOOL:
        ch->Reset<bool>(capacity);
        break;
      case VarType::INT32:
        ch->Reset<int>(capacity);
        break;
      case VarType::INT64:
        ch->Reset<int64_t>(capacity);
        break;
      case VarType::FP32:
        ch->Reset<float>(capacity);
        break;
      case VarType::FP64:
        ch->Reset<double>(capacity);
        break;
      default:
        PADDLE_THROW(
            "Data type %d is not in "
            "[LOD_TENSOR, SELECTED_ROWS, LOD_RANK_TABLE, LOD_TENSOR_ARRAY, "
            "READER, CHANNEL, BOOL, INT32, INT64, FP32, FP64]",
            dtype);
    }
  }
};
// Shape inference for channel_create: requires the output to be present and
// gives it the shape {1}.
class ChannelCreateOpOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    const bool has_output = context->HasOutput(kOutput);
    PADDLE_ENFORCE(has_output, "The output of ChannelCreate op must be set");
    context->SetOutputDim(kOutput, {1});
  }
};
// Declares the proto of the channel_create op: one output (kOutput), the
// "capacity" attribute (default 0), the "data_type" attribute and the
// op-level documentation string.
class
ChannelCreateOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
// Registers the op's output, attributes and comment.
void
Make
()
override
{
AddOutput
(
kOutput
,
"The object of a Channel type created by ChannelCreate Op."
);
AddAttr
<
int
>
(
"capacity"
,
"The size of the buffer of Channel."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"data_type"
,
"The data type of elements inside the Channel."
);
AddComment
(
R"DOC(
Channel Create Operator.
This operator creates an object of the VarType Channel and returns it.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
// Registers the channel_create op with its operator class, an empty gradient
// maker and its proto maker.
REGISTER_OPERATOR
(
channel_create
,
paddle
::
operators
::
ChannelCreateOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
paddle
::
operators
::
ChannelCreateOpMaker
);
paddle/fluid/operators/channel_recv_op.cc
已删除
100644 → 0
浏览文件 @
8f5d918a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include <paddle/fluid/framework/lod_rank_table.h>
#include <paddle/fluid/framework/lod_tensor_array.h>
#include <paddle/fluid/framework/reader.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/concurrency/channel_util.h"
#include "paddle/fluid/operators/math/math_function.h"
static
constexpr
char
Channel
[]
=
"Channel"
;
static
constexpr
char
Status
[]
=
"Status"
;
static
constexpr
char
Out
[]
=
"Out"
;
namespace
paddle
{
namespace
operators
{
// Stores `status` as the single bool element of the LoDTensor held by
// `status_var`. The tensor buffer is always requested on the CPU place;
// `dev_place` is accepted but not consulted.
void SetReceiveStatus(const platform::Place &dev_place,
                      framework::Variable *status_var, bool status) {
  auto *tensor = status_var->GetMutable<framework::LoDTensor>();
  // Shape {1}: one flag per receive call.
  bool *flag = tensor->mutable_data<bool>({1}, platform::CPUPlace());
  *flag = status;
}
// Operator that receives one value from a channel.
//
// Inputs:
//   Channel - variable holding the framework::ChannelHolder to read from.
// Outputs:
//   Out     - variable that is filled with the received value.
//   Status  - LoDTensor of shape {1} holding the bool result of the receive.
class ChannelRecvOp : public framework::OperatorBase {
 public:
  ChannelRecvOp(const std::string &type,
                const framework::VariableNameMap &inputs,
                const framework::VariableNameMap &outputs,
                const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

  // Checks that all slots are present and fixes the shape of Status to {1}.
  // NOTE(review): this method is not marked `override`; confirm whether the
  // framework ever calls it for an OperatorBase-derived op.
  void InferShape(framework::InferShapeContext *ctx) const {
    PADDLE_ENFORCE(ctx->HasInput(Channel),
                   "Input(Channel) of ChannelRecvOp should not be null.");
    // The message used to repeat the Input(Channel) text, which pointed the
    // user at the wrong slot when Out was missing.
    PADDLE_ENFORCE(ctx->HasOutput(Out),
                   "Output(Out) of ChannelRecvOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(Status),
                   "Output(Status) of ChannelRecvOp should not be null.");
    // Use the shared slot-name constant rather than a repeated literal.
    ctx->SetOutputDim(Status, {1});
  }

 private:
  // Receives from the channel into Out and records the result in Status.
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    // Get the channel holder created by channel_create op, passed as input.
    framework::ChannelHolder *ch =
        scope.FindVar(Input(Channel))->GetMutable<framework::ChannelHolder>();
    auto output_var = scope.FindVar(Output(Out));
    // Receive the data from the channel.
    bool ok = concurrency::ChannelReceive(ch, output_var);

    // Set the status output of the `ChannelReceive` call.
    SetReceiveStatus(dev_place, scope.FindVar(Output(Status)), ok);
  }
};
// Declares the proto of the channel_recv op: one input slot (Channel) and two
// output slots (Out, Status), all marked duplicable.
class ChannelRecvOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Adjacent string literals are concatenated verbatim, so each fragment
    // that continues a sentence must end with a space.
    AddInput(Channel,
             "(Channel) A variable which \"receives\" a value sent "
             "to it by a channel_send op.")
        .AsDuplicable();
    AddOutput(Out,
              "(Variable) Output Variable that will hold the data received"
              " from the Channel")
        .AsDuplicable();
    AddOutput(Status,
              "(Tensor) An LoD Tensor that returns a boolean status of the "
              "result of the receive operation.")
        .AsDuplicable();
    AddComment(R"DOC(
)DOC");
  }
};
}
// namespace operators
}
// namespace paddle
// Register the `channel_recv` op with its operator class, an empty gradient
// maker, and its proto maker.
REGISTER_OPERATOR(channel_recv, paddle::operators::ChannelRecvOp,
                  paddle::framework::EmptyGradOpMaker,
                  paddle::operators::ChannelRecvOpMaker);
paddle/fluid/operators/channel_send_op.cc
已删除
100644 → 0
浏览文件 @
8f5d918a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include <paddle/fluid/framework/lod_rank_table.h>
#include <paddle/fluid/framework/lod_tensor_array.h>
#include <paddle/fluid/framework/reader.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/concurrency/channel_util.h"
#include "paddle/fluid/operators/math/math_function.h"
static
constexpr
char
Channel
[]
=
"Channel"
;
static
constexpr
char
X
[]
=
"X"
;
namespace
paddle
{
namespace
operators
{
// Operator that pushes the variable in slot X through the channel held by the
// variable in slot Channel. It declares no outputs.
class ChannelSendOp : public framework::OperatorBase {
 public:
  ChannelSendOp(const std::string &type,
                const framework::VariableNameMap &inputs,
                const framework::VariableNameMap &outputs,
                const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

  // Verifies that both input slots are present.
  void InferShape(framework::InferShapeContext *ctx) const {
    PADDLE_ENFORCE(ctx->HasInput(Channel),
                   "Input(Channel) of ChannelSendOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(X),
                   "Input(X) of ChannelSendOp should not be null.");
  }

 private:
  // Looks up both inputs in `scope` and forwards them to ChannelSend.
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    // The channel holder was created by the channel_create op and is passed
    // in as an input variable.
    framework::Variable *channel_var = scope.FindVar(Input(Channel));
    framework::ChannelHolder *holder =
        channel_var->GetMutable<framework::ChannelHolder>();

    // The payload to push through the channel.
    framework::Variable *payload = scope.FindVar(Input(X));

    concurrency::ChannelSend(holder, payload);
  }
};
// Declares the proto of the channel_send op: two input slots (Channel, X),
// both marked duplicable, and an empty op comment.
class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(Channel,
             "(Channel) A variable which \"sends\" the passed in value to "
             "a listening receiver.")
        .AsDuplicable();
    AddInput(X, "(Variable) The value which gets sent by the channel.")
        .AsDuplicable();
    // The op-level documentation string is intentionally left as-is (empty).
    AddComment(R"DOC(
)DOC");
  }
};
}
// namespace operators
}
// namespace paddle
// Register the `channel_send` op with its operator class, an empty gradient
// maker, and its proto maker.
REGISTER_OPERATOR(channel_send, paddle::operators::ChannelSendOp,
                  paddle::framework::EmptyGradOpMaker,
                  paddle::operators::ChannelSendOpMaker);
paddle/fluid/operators/concurrency/CMakeLists.txt
已删除
100644 → 0
浏览文件 @
8f5d918a
# Library target `concurrency`, built from channel_util.cc with the listed
# dependencies.
cc_library(concurrency SRCS channel_util.cc DEPS device_context framework_proto boost eigen3)
paddle/fluid/operators/concurrency/channel_util.cc
已删除
100644 → 0
浏览文件 @
8f5d918a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/concurrency/channel_util.h"
#include "paddle/fluid/framework/var_type.h"
// Short alias used to define the concurrency helpers below with qualified
// names.
namespace poc = paddle::operators::concurrency;
// Sends the content of `var` through `ch`.
//
// The concrete payload type is chosen from the runtime type of `var`; the
// matching typed pointer is obtained with GetMutable and handed to
// ChannelHolder::Send. Throws (PADDLE_THROW) for any other variable type.
void poc::ChannelSend(framework::ChannelHolder *ch, framework::Variable *var) {
  switch (framework::ToVarType(var->Type())) {
    case framework::proto::VarType_Type_LOD_TENSOR:
      ch->Send(var->GetMutable<framework::LoDTensor>());
      break;
    case framework::proto::VarType_Type_LOD_RANK_TABLE:
      ch->Send(var->GetMutable<framework::LoDRankTable>());
      break;
    case framework::proto::VarType_Type_LOD_TENSOR_ARRAY:
      ch->Send(var->GetMutable<framework::LoDTensorArray>());
      break;
    case framework::proto::VarType_Type_SELECTED_ROWS:
      ch->Send(var->GetMutable<framework::SelectedRows>());
      break;
    case framework::proto::VarType_Type_READER:
      ch->Send(var->GetMutable<framework::ReaderHolder>());
      break;
    case framework::proto::VarType_Type_CHANNEL:
      ch->Send(var->GetMutable<framework::ChannelHolder>());
      break;
    default:
      PADDLE_THROW("ChannelSend:Unsupported type");
  }
}
// Receives one value from `ch` into `var` and returns the result of
// ChannelHolder::Receive.
//
// Unlike the send path, the payload type is taken from the channel itself
// (ch->Type()), and `var` is materialised as that type via GetMutable.
// Throws (PADDLE_THROW) when the channel's type is not one of the handled
// variable types.
bool poc::ChannelReceive(framework::ChannelHolder *ch,
                         framework::Variable *var) {
  // Get type of channel and use that to call mutable data for Variable
  switch (framework::ToVarType(ch->Type())) {
    case framework::proto::VarType_Type_LOD_TENSOR:
      return ch->Receive(var->GetMutable<framework::LoDTensor>());
    case framework::proto::VarType_Type_LOD_RANK_TABLE:
      return ch->Receive(var->GetMutable<framework::LoDRankTable>());
    case framework::proto::VarType_Type_LOD_TENSOR_ARRAY:
      return ch->Receive(var->GetMutable<framework::LoDTensorArray>());
    case framework::proto::VarType_Type_SELECTED_ROWS:
      return ch->Receive(var->GetMutable<framework::SelectedRows>());
    case framework::proto::VarType_Type_READER:
      return ch->Receive(var->GetMutable<framework::ReaderHolder>());
    case framework::proto::VarType_Type_CHANNEL:
      return ch->Receive(var->GetMutable<framework::ChannelHolder>());
    default:
      PADDLE_THROW("ChannelReceive:Unsupported type");
  }
}
// Forwards `var` (as its concrete runtime type) together with `referrer`,
// `cond` and `cb` to ChannelHolder::AddToSendQ on `ch`.
//
// The payload type is selected from the runtime type of `var`. Throws
// (PADDLE_THROW) for any variable type that is not handled below.
void poc::ChannelAddToSendQ(
    framework::ChannelHolder *ch, const void *referrer,
    framework::Variable *var,
    std::shared_ptr<std::condition_variable_any> cond,
    std::function<bool(framework::ChannelAction)> cb) {
  switch (framework::ToVarType(var->Type())) {
    case framework::proto::VarType_Type_LOD_TENSOR:
      ch->AddToSendQ(referrer, var->GetMutable<framework::LoDTensor>(), cond,
                     cb);
      break;
    case framework::proto::VarType_Type_LOD_RANK_TABLE:
      ch->AddToSendQ(referrer, var->GetMutable<framework::LoDRankTable>(),
                     cond, cb);
      break;
    case framework::proto::VarType_Type_LOD_TENSOR_ARRAY:
      ch->AddToSendQ(referrer, var->GetMutable<framework::LoDTensorArray>(),
                     cond, cb);
      break;
    case framework::proto::VarType_Type_SELECTED_ROWS:
      ch->AddToSendQ(referrer, var->GetMutable<framework::SelectedRows>(),
                     cond, cb);
      break;
    case framework::proto::VarType_Type_READER:
      ch->AddToSendQ(referrer, var->GetMutable<framework::ReaderHolder>(),
                     cond, cb);
      break;
    case framework::proto::VarType_Type_CHANNEL:
      ch->AddToSendQ(referrer, var->GetMutable<framework::ChannelHolder>(),
                     cond, cb);
      break;
    default:
      PADDLE_THROW("ChannelAddToSendQ:Unsupported type");
  }
}
// Forwards `var` (as its concrete runtime type) together with `referrer`,
// `cond` and `cb` to ChannelHolder::AddToReceiveQ on `ch`.
//
// The payload type is selected from the runtime type of `var`. Throws
// (PADDLE_THROW) for any variable type that is not handled below.
void poc::ChannelAddToReceiveQ(
    framework::ChannelHolder *ch, const void *referrer,
    framework::Variable *var,
    std::shared_ptr<std::condition_variable_any> cond,
    std::function<bool(framework::ChannelAction)> cb) {
  switch (framework::ToVarType(var->Type())) {
    case framework::proto::VarType_Type_LOD_TENSOR:
      ch->AddToReceiveQ(referrer, var->GetMutable<framework::LoDTensor>(),
                        cond, cb);
      break;
    case framework::proto::VarType_Type_LOD_RANK_TABLE:
      ch->AddToReceiveQ(referrer, var->GetMutable<framework::LoDRankTable>(),
                        cond, cb);
      break;
    case framework::proto::VarType_Type_LOD_TENSOR_ARRAY:
      ch->AddToReceiveQ(referrer, var->GetMutable<framework::LoDTensorArray>(),
                        cond, cb);
      break;
    case framework::proto::VarType_Type_SELECTED_ROWS:
      ch->AddToReceiveQ(referrer, var->GetMutable<framework::SelectedRows>(),
                        cond, cb);
      break;
    case framework::proto::VarType_Type_READER:
      ch->AddToReceiveQ(referrer, var->GetMutable<framework::ReaderHolder>(),
                        cond, cb);
      break;
    case framework::proto::VarType_Type_CHANNEL:
      ch->AddToReceiveQ(referrer, var->GetMutable<framework::ChannelHolder>(),
                        cond, cb);
      break;
    default:
      PADDLE_THROW("ChannelAddToReceiveQ:Unsupported type");
  }
}
paddle/fluid/operators/distributed/grpc_client.h
浏览文件 @
f63ab561
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <time.h>
#include <atomic>
#include <chrono> // NOLINT
#include <condition_variable> // NOLINT
...
...
paddle/fluid/operators/distributed/request_handler.h
浏览文件 @
f63ab561
...
...
@@ -15,6 +15,7 @@
#pragma once
#include <time.h>
#include <condition_variable> // NOLINT
#include <functional>
#include <string>
...
...
paddle/fluid/operators/distributed/rpc_server.h
浏览文件 @
f63ab561
...
...
@@ -14,6 +14,7 @@
#pragma once
#include <atomic>
#include <set>
#include <string>
#include <thread> // NOLINT
...
...
paddle/fluid/operators/fused_embedding_fc_lstm_op.cc
0 → 100644
浏览文件 @
f63ab561
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused_embedding_fc_lstm_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
// Validates the inputs/outputs of fused_embedding_fc_lstm and sets the output
// shapes.
//
// Required inputs:  Ids (rank 2, last dim 1), Embeddings (rank 2),
//                   WeightH (D x 4D), Bias (1 x 4D, or 1 x 7D with peepholes).
// Optional inputs:  H0 and C0 (if H0 is given, C0 must be given too and both
//                   must have the same dims).
// Outputs set here: Hidden, Cell (T x D), XX, and -- when `use_seq` is false
//                   -- BatchedInput, BatchedHidden, BatchedCell.
// Hidden, Cell and XX share the LoD of Ids.
// Any violated check raises through PADDLE_ENFORCE*.
void FusedEmbeddingFCLSTMOp::InferShape(
    framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("Embeddings"),
                 "Assert only one Input(Embeddings) of LSTM.");
  PADDLE_ENFORCE(ctx->HasInput("WeightH"),
                 "Assert only one Input(WeightH) of LSTM.");
  PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM.");
  PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM.");
  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                 "Assert only one Output(Hidden) of LSTM.");
  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
                 "Assert only one Output(Cell) of LSTM.");
  PADDLE_ENFORCE(ctx->HasInput("Ids"),
                 "Input(Ids) of FusedEmbeddingFCLSTMOp should not be null.");

  // Every rank is checked before the corresponding dims are indexed.
  auto ids_dims = ctx->GetInputDim("Ids");
  PADDLE_ENFORCE_EQ(ids_dims.size(), 2, "Input(Ids)'s rank must be 2.");
  PADDLE_ENFORCE_EQ(ids_dims[1], 1,
                    "The last dimension of the 'Ids' tensor must be 1.");

  if (ctx->HasInput("H0")) {
    PADDLE_ENFORCE(ctx->HasInput("C0"),
                   "Input(H0) and Input(C0) of LSTM should be provided "
                   "at the same time.");
    auto h_dims = ctx->GetInputDim("H0");
    auto c_dims = ctx->GetInputDim("C0");
    PADDLE_ENFORCE(h_dims == c_dims,
                   "The dimension of Input(H0) and Input(C0) "
                   "should be the same.");
  }

  auto embeddings_dims = ctx->GetInputDim("Embeddings");
  PADDLE_ENFORCE_EQ(embeddings_dims.size(), 2,
                    "The rank of Input(Embeddings) should be 2.");

  auto wh_dims = ctx->GetInputDim("WeightH");
  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
                    "The rank of Input(WeightH) should be 2.");
  // D, derived from the 4D-wide second dimension of WeightH.
  int frame_size = wh_dims[1] / 4;
  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
                    "The first dimension of Input(WeightH) "
                    "should be %d.",
                    frame_size);
  PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
                    "The second dimension of Input(WeightH) "
                    "should be 4 * %d.",
                    frame_size);

  auto b_dims = ctx->GetInputDim("Bias");
  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
  PADDLE_ENFORCE_EQ(b_dims[0], 1,
                    "The first dimension of Input(Bias) should be 1.");
  PADDLE_ENFORCE_EQ(
      b_dims[1], (ctx->Attrs().Get<bool>("use_peepholes") ? 7 : 4) * frame_size,
      "The second dimension of Input(Bias) should be "
      "7 * %d if enable peepholes connection or "
      "4 * %d if disable peepholes",
      frame_size, frame_size);

  framework::DDim out_dims({ids_dims[0], frame_size});
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->SetOutputDim("Cell", out_dims);
  ctx->ShareLoD("Ids", "Hidden");
  ctx->ShareLoD("Ids", "Cell");

  int xx_width;
  if (ctx->Attrs().Get<bool>("use_seq")) {
    xx_width = wh_dims[1];
  } else {
    // NOTE(review): ids_dims[1] is enforced to be 1 above, so this evaluates
    // to min(1, 4D). The kernel in this file copies whole embedding rows into
    // XX -- confirm that this width is what batch mode intends.
    xx_width = ids_dims[1] > wh_dims[1] ? wh_dims[1] : ids_dims[1];
    PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
                   "Assert only one Output(BatchedInput) of LSTM.");
    PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
                   "Assert only one Output(BatchedHidden) of LSTM.");
    PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
                   "Assert only one Output(BatchedCell) of LSTM.");
    PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
                   "Assert only one Output(ReorderedH0) of LSTM");
    PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
                   "Assert only one Output(ReorderedC0) of LSTM.");
    ctx->SetOutputDim("BatchedInput", {ids_dims[0], wh_dims[1]});
    ctx->SetOutputDim("BatchedHidden", out_dims);
    ctx->SetOutputDim("BatchedCell", out_dims);
  }
  ctx->SetOutputDim("XX", {ids_dims[0], xx_width});
  ctx->ShareLoD("Ids", "XX");
}
// Selects the kernel from the element type of the Embeddings input and the
// device context the op is being executed with.
framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  const auto* embeddings = ctx.Input<framework::LoDTensor>("Embeddings");
  const auto data_type = framework::ToDataType(embeddings->type());
  return framework::OpKernelType(data_type, ctx.device_context());
}
void
FusedEmbeddingFCLSTMOpMaker
::
Make
()
{
AddInput
(
"Ids"
,
"An input with type int32 or int64 "
"contains the ids to be looked up in W. "
"The last dimension size must be 1."
);
AddInput
(
"Embeddings"
,
"(Tensor) the learnable weights of X."
" - The shape is (M x 4D), where M is the dim size of x, D is the "
"hidden size. "
" - Weight = {W_cx, W_ix, W_fx, W_ox}"
);
AddInput
(
"WeightH"
,
"(Tensor) same as LSTMOp, the learnable hidden-hidden weights."
" - The shape is (D x 4D), where D is the hidden size. "
" - Weight = {W_ch, W_ih, W_fh, W_oh}"
);
AddInput
(
"Bias"
,
"(Tensor) the learnable weights. Almost same as LSTMOp"
"Note: we should add the fc bias into this (1x4D) in bias."
"input-hidden bias weight and peephole connections weight if "
"setting `use_peepholes` True. "
"1. `use_peepholes = False` "
" - The shape is (1 x 4D). "
" - Bias = {b_c, b_i, b_f, b_o}."
"2. `use_peepholes = True` "
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."
);
AddInput
(
"H0"
,
"(Tensor, optional) (same as LSTMOp) the initial hidden state is an "
"optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size and D is the hidden size."
)
.
AsDispensable
();
AddInput
(
"C0"
,
"(Tensor, optional) (same as LSTMOp) (the initial cell state is an "
"optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size. `H0` and `C0` can be NULL but only at the same time."
)
.
AsDispensable
();
AddOutput
(
"Hidden"
,
"(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`."
);
AddOutput
(
"Cell"
,
"(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`."
);
AddOutput
(
"XX"
,
"(LoDTensor) the result after X * WeightX (size is T x 4D)"
" or batched_X (size is T x M), this will be automatically chosen,"
" where T is the total time steps in this mini-batch,"
" D is the hidden size, M is the dim size of x input."
)
.
AsIntermediate
();
AddOutput
(
"BatchedInput"
,
"(LoDTensor) (T x 4D)."
).
AsIntermediate
();
AddOutput
(
"BatchedHidden"
,
"(LoDTensor) (T x D)."
).
AsIntermediate
();
AddOutput
(
"BatchedCell"
,
"(LoDTensor) (T x D)."
).
AsIntermediate
();
AddOutput
(
"ReorderedH0"
,
"(LoDTensor) (N x D)."
).
AsIntermediate
();
AddOutput
(
"ReorderedC0"
,
"(LoDTensor) (N x D)."
).
AsIntermediate
();
AddAttr
<
bool
>
(
"use_peepholes"
,
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections."
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"is_reverse"
,
"(bool, defalut: False) "
"whether to compute reversed LSTM."
)
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"use_seq"
,
"(bool, defalut: True) "
"whether to use seq mode to compute."
)
.
SetDefault
(
true
);
AddAttr
<
std
::
string
>
(
"gate_activation"
,
"(string, default: sigmoid)"
"The activation for input gate, forget gate and output "
"gate, `sigmoid` by default."
)
.
SetDefault
(
"sigmoid"
)
.
InEnum
({
"sigmoid"
,
"tanh"
,
"relu"
,
"identity"
});
AddAttr
<
std
::
string
>
(
"cell_activation"
,
"(string, default: tanh)"
"The activation for cell output, `tanh` by defalut."
)
.
SetDefault
(
"tanh"
)
.
InEnum
({
"sigmoid"
,
"tanh"
,
"relu"
,
"identity"
});
AddAttr
<
std
::
string
>
(
"candidate_activation"
,
"(string, default: tanh)"
"The activation for candidate hidden state, "
"`tanh` by default."
)
.
SetDefault
(
"tanh"
)
.
InEnum
({
"sigmoid"
,
"tanh"
,
"relu"
,
"identity"
});
AddComment
(
R"DOC(
Fusion Long-Short Term Memory (LSTM) Operator.
This operator fuse the X into LSTM, more details can refer to LSTM op.
)DOC"
);
}
template
<
typename
T
>
class
FusedEmbeddingFCLSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
#define INIT_VEC_FUNC \
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand; \
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
auto& act_cell_str = ctx.Attr<std::string>("cell_activation"); \
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation"); \
if (platform::jit::MayIUse(platform::jit::avx)) { \
math::VecActivations<T, platform::jit::avx> act_functor; \
act_gate = act_functor(act_gate_str); \
act_cell = act_functor(act_cell_str); \
act_cand = act_functor(act_cand_str); \
} else { \
math::VecActivations<T, platform::jit::isa_any> act_functor; \
act_gate = act_functor(act_gate_str); \
act_cell = act_functor(act_cell_str); \
act_cand = act_functor(act_cand_str); \
}
#define INIT_BASE_INPUT_OUTPUT \
auto* ids = ctx.Input<LoDTensor>("Ids"); \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* c0 = ctx.Input<Tensor>("C0"); \
auto* embeddings = ctx.Input<Tensor>("Embeddings"); \
auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
bool use_peepholes = ctx.Attr<bool>("use_peepholes");
#define INIT_BASE_SIZES \
auto ids_dims = ids->dims();
/* T x M*/
\
auto ids_numel = ids->numel();
/* T x 1*/
\
auto wh_dims = wh->dims();
/* D x 4D*/
\
const int D = wh_dims[0]; \
const int D2 = D * 2; \
const int D3 = D * 3; \
int64_t row_number = embeddings->dims()[0]; \
int64_t row_width = embeddings->dims()[1]; \
const int D4 = wh_dims[1];
#define INIT_BASE_INPUT_DATAS \
const int64_t* ids_data = ids->data<int64_t>(); \
const T* embeddings_data = embeddings->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/
\
const T* wc_data = bias->data<T>() + D4; \
/* for peephole only*/
\
Tensor checked_cell; \
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/
\
checked_cell_data = checked_cell.mutable_data<T>({2, D}, place); \
}
/// Compute LSTM
#define GEMM_WH_ADDON(bs, prev, out) \
blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
wh_data, D4, static_cast<T>(1), out, D4)
// gates: W_ch, W_ih, W_fh, W_oh
#define GET_Ct(ct_1, gates, ct) \
/* C_t = C_t-1 * fgated + cand_gated * igated*/
\
act_cand(D, gates, gates); \
blas.VMUL(D, gates, gates + D, gates + D); \
blas.VMUL(D, ct_1, gates + D2, gates + D2); \
blas.VADD(D, gates + D, gates + D2, ct)
#define GET_Ht(ct, gates, ht) \
/* H_t = act_cell(C_t) * ogated */
\
act_cell(D, ct, gates + D2); \
blas.VMUL(D, gates + D2, gates + D3, ht)
#define GET_Ct_NOH0C0(gates, ct) \
/* C_t = igated * cgated*/
\
act_gate(D, gates + D, gates + D); \
act_cand(D, gates, gates); \
blas.VMUL(D, gates, gates + D, ct)
#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \
GET_Ct_NOH0C0(gates, ct); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \
GET_Ct_NOH0C0(gates, ct); \
/* get outgated, put W_oc * C_t on igated */
\
blas.VMUL(D, wc_data + D2, ct, gates + D); \
blas.VADD(D, gates + D, gates + D3, gates + D3); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt(gates, ct_1, ct, ht) \
act_gate(D3, gates + D, gates + D); \
GET_Ct(ct_1, gates, ct); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht) \
/* get fgated and igated*/
\
blas.VMUL(D, wc_data, ct_1, checked_cell_data); \
blas.VMUL(D, wc_data + D, ct_1, checked_cell_data + D); \
blas.VADD(D2, checked_cell_data, gates + D, gates + D); \
act_gate(D2, gates + D, gates + D); \
GET_Ct(ct_1, gates, ct); \
/* get ogated*/
\
blas.VMUL(D, wc_data + D2, ct, gates + D); \
blas.VADD(D, gates + D, gates + D3, gates + D3); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
void
SeqCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
INIT_VEC_FUNC
INIT_BASE_INPUT_DATAS
// std::cout << "====> SeqCompute" << std::endl;
auto
ids_lod
=
ids
->
lod
();
const
int
total_T
=
ids_dims
[
0
];
const
int
N
=
ids_lod
[
0
].
size
()
-
1
;
const
T
*
h0_data
=
h0
?
h0
->
data
<
T
>
()
:
nullptr
;
const
T
*
c0_data
=
c0
?
c0
->
data
<
T
>
()
:
nullptr
;
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
place
);
T
*
h_out_data
=
hidden_out
->
mutable_data
<
T
>
(
place
);
T
*
c_out_data
=
cell_out
->
mutable_data
<
T
>
(
place
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
for
(
int64_t
i
=
0
;
i
<
ids_numel
;
++
i
)
{
PADDLE_ENFORCE_LT
(
ids_data
[
i
],
row_number
);
PADDLE_ENFORCE_GE
(
ids_data
[
i
],
0
,
"ids %d"
,
i
);
memcpy
(
xx_data
+
i
*
row_width
,
embeddings_data
+
ids_data
[
i
]
*
row_width
,
row_width
*
sizeof
(
T
));
}
int
xx_offset
=
D4
;
int
gate_offset
=
D
;
if
(
is_reverse
)
{
const
int
offset
=
(
total_T
-
1
)
*
D
;
xx_data
=
xx_data
+
offset
*
4
;
h_out_data
=
h_out_data
+
offset
;
c_out_data
=
c_out_data
+
offset
;
xx_offset
=
-
D4
;
gate_offset
=
-
D
;
}
#define MOVE_ONE_STEP \
prev_h_data = h_out_data; \
prev_c_data = c_out_data; \
xx_data = xx_data + xx_offset; \
h_out_data = h_out_data + gate_offset; \
c_out_data = c_out_data + gate_offset
#define PROCESS_H0C0_DEFINES \
int bid = is_reverse ? N - 1 - i : i; \
int seq_len = ids_lod[0][bid + 1] - ids_lod[0][bid]; \
const T* prev_c_data = nullptr; \
const T* prev_h_data = nullptr; \
int tstart = 0
#define PROCESS_H0C0_PEEPHOLE \
PROCESS_H0C0_DEFINES; \
if (h0_data) { \
prev_h_data = h0_data + bid * D; \
prev_c_data = c0_data + bid * D; \
} else { \
COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \
MOVE_ONE_STEP; \
tstart = 1; \
}
#define PROCESS_H0C0 \
PROCESS_H0C0_DEFINES; \
if (h0_data) { \
prev_h_data = h0_data + bid * D; \
prev_c_data = c0_data + bid * D; \
} else { \
COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \
MOVE_ONE_STEP; \
tstart = 1; \
}
if
(
use_peepholes
)
{
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
PROCESS_H0C0_PEEPHOLE
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
GEMM_WH_ADDON
(
1
,
prev_h_data
,
xx_data
);
COMPUTE_CtHt_PEEPHOLE
(
xx_data
,
prev_c_data
,
c_out_data
,
h_out_data
);
MOVE_ONE_STEP
;
}
}
}
else
{
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
PROCESS_H0C0
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
GEMM_WH_ADDON
(
1
,
prev_h_data
,
xx_data
);
COMPUTE_CtHt
(
xx_data
,
prev_c_data
,
c_out_data
,
h_out_data
);
MOVE_ONE_STEP
;
}
}
}
#undef PROCESS_H0C0_DEFINES
#undef PROCESS_H0C0_PEEPHOLE
#undef PROCESS_H0C0
#undef MOVE_ONE_STEP
}
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
platform
::
CPUDeviceContext
;
INIT_BASE_INPUT_OUTPUT
if
(
ids
->
lod
()[
0
].
size
()
==
2
)
{
SeqCompute
(
ctx
);
return
;
}
INIT_BASE_SIZES
INIT_VEC_FUNC
INIT_BASE_INPUT_DATAS
// std::cout << "===> Batch Compute" << std::endl;
auto
*
reordered_h0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedH0"
);
auto
*
reordered_c0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedC0"
);
auto
*
batched_input
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedInput"
);
auto
*
batched_c_out
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedCell"
);
auto
*
batched_h_out
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedHidden"
);
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
place
);
T
*
batched_input_data
=
batched_input
->
mutable_data
<
T
>
(
place
);
T
*
batched_c_out_data
=
batched_c_out
->
mutable_data
<
T
>
(
place
);
T
*
batched_h_out_data
=
batched_h_out
->
mutable_data
<
T
>
(
place
);
hidden_out
->
mutable_data
<
T
>
(
place
);
cell_out
->
mutable_data
<
T
>
(
place
);
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
for
(
int64_t
i
=
0
;
i
<
ids_numel
;
++
i
)
{
PADDLE_ENFORCE_LT
(
ids_data
[
i
],
row_number
);
PADDLE_ENFORCE_GE
(
ids_data
[
i
],
0
,
"ids %d"
,
i
);
memcpy
(
xx_data
+
i
*
row_width
,
embeddings_data
+
ids_data
[
i
]
*
row_width
,
row_width
*
sizeof
(
T
));
}
to_batch
(
dev_ctx
,
*
xx
,
batched_input
,
true
,
is_reverse
);
auto
batched_lod
=
batched_input
->
lod
();
const
auto
&
seq_order
=
batched_lod
[
2
];
const
int
max_bs
=
seq_order
.
size
();
reordered_h0
->
Resize
({
max_bs
,
D
});
reordered_c0
->
Resize
({
max_bs
,
D
});
int
tstart
=
0
;
T
*
prev_h_data
=
nullptr
;
T
*
prev_c_data
=
nullptr
;
if
(
h0
)
{
// reorder h0, c0
T
*
reordered_h0_data
=
reordered_h0
->
mutable_data
<
T
>
(
place
);
T
*
reordered_c0_data
=
reordered_c0
->
mutable_data
<
T
>
(
place
);
const
T
*
h0_data
=
h0
->
data
<
T
>
();
const
T
*
c0_data
=
c0
->
data
<
T
>
();
prev_h_data
=
reordered_h0_data
;
prev_c_data
=
reordered_c0_data
;
size_t
sz
=
sizeof
(
T
)
*
D
;
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
std
::
memcpy
(
reordered_h0_data
,
h0_data
+
seq_order
[
i
]
*
D
,
sz
);
std
::
memcpy
(
reordered_c0_data
,
c0_data
+
seq_order
[
i
]
*
D
,
sz
);
reordered_h0_data
+=
D
;
reordered_c0_data
+=
D
;
}
}
else
{
// compute without h0, c0
T
*
cur_in_data
=
batched_input_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
GET_Ct_NOH0C0
(
cur_in_data
,
cur_c_out_data
);
if
(
use_peepholes
)
{
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cur_c_out_data
,
cur_in_data
+
D
);
blas
.
VADD
(
D
,
cur_in_data
+
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
}
act_gate
(
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
GET_Ht
(
cur_c_out_data
,
cur_in_data
,
cur_h_out_data
);
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
}
tstart
=
1
;
prev_h_data
=
batched_h_out_data
;
prev_c_data
=
batched_c_out_data
;
}
const
auto
&
batch_starts
=
batched_lod
[
0
];
const
int
max_seq_len
=
batch_starts
.
size
()
-
1
;
const
int
offset
=
tstart
*
max_bs
*
D
;
batched_input_data
=
batched_input_data
+
offset
*
4
;
batched_h_out_data
=
batched_h_out_data
+
offset
;
batched_c_out_data
=
batched_c_out_data
+
offset
;
#define DEFINE_CUR \
T* cur_in_data = batched_input_data; \
T* cur_prev_c_data = prev_c_data; \
T* cur_c_out_data = batched_c_out_data; \
T* cur_h_out_data = batched_h_out_data
#define MOVE_ONE_BATCH \
cur_in_data += D4; \
cur_prev_c_data += D; \
cur_c_out_data += D; \
cur_h_out_data += D
#define MOVE_ONE_STEP \
prev_c_data = batched_c_out_data; \
prev_h_data = batched_h_out_data; \
batched_c_out_data = cur_c_out_data; \
batched_h_out_data = cur_h_out_data; \
batched_input_data = cur_in_data
if
(
use_peepholes
)
{
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
GEMM_WH_ADDON
(
cur_bs
,
prev_h_data
,
batched_input_data
);
DEFINE_CUR
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
COMPUTE_CtHt_PEEPHOLE
(
cur_in_data
,
cur_prev_c_data
,
cur_c_out_data
,
cur_h_out_data
);
MOVE_ONE_BATCH
;
}
MOVE_ONE_STEP
;
}
}
else
{
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
GEMM_WH_ADDON
(
cur_bs
,
prev_h_data
,
batched_input_data
);
DEFINE_CUR
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
COMPUTE_CtHt
(
cur_in_data
,
cur_prev_c_data
,
cur_c_out_data
,
cur_h_out_data
);
MOVE_ONE_BATCH
;
}
MOVE_ONE_STEP
;
}
}
#undef MOVE_ONE_STEP
#undef MOVE_ONE_BATCH
#undef DEFINE_CUR
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
batched_h_out
->
set_lod
(
batched_lod
);
to_seq
(
dev_ctx
,
*
batched_h_out
,
hidden_out
);
batched_c_out
->
set_lod
(
batched_lod
);
to_seq
(
dev_ctx
,
*
batched_c_out
,
cell_out
);
}
// Kernel entry point: selects the computation path from the boolean
// "use_seq" attribute. When the attribute is true the work is delegated to
// SeqCompute, otherwise to BatchCompute.
void Compute(const framework::ExecutionContext& ctx) const override {
  const bool use_seq = ctx.Attr<bool>("use_seq");
  if (!use_seq) {
    BatchCompute(ctx);
    return;
  }
  SeqCompute(ctx);
}
#undef COMPUTE_CtHt_PEEPHOLE
#undef COMPUTE_CtHt
#undef GET_Ct_NOH0C0
#undef COMPUTE_CtHt_NOH0C0
#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
#undef GET_Ht
#undef GET_Ct
#undef GEMM_WH_ADDON
#undef INIT_BASE_INPUT_DATAS
#undef INIT_BASE_SIZES
#undef INIT_BASE_INPUT_OUTPUT
#undef INIT_VEC_FUNC
};
}
// namespace operators
}
// namespace paddle
// Short alias for the operators namespace used by the registrations below.
namespace ops = paddle::operators;

// Register the operator together with its proto maker and the default
// gradient op description maker.
REGISTER_OPERATOR(fused_embedding_fc_lstm, ops::FusedEmbeddingFCLSTMOp,
                  ops::FusedEmbeddingFCLSTMOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);

// Register the CPU kernels for float and double.
REGISTER_OP_CPU_KERNEL(fused_embedding_fc_lstm,
                       ops::FusedEmbeddingFCLSTMKernel<float>,
                       ops::FusedEmbeddingFCLSTMKernel<double>);
paddle/fluid/operators/
concurrency/channel_util
.h
→
paddle/fluid/operators/
fused_embedding_fc_lstm_op
.h
浏览文件 @
f63ab561
...
...
@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
...
...
@@ -13,26 +13,29 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
namespace
concurrency
{
void
ChannelSend
(
framework
::
ChannelHolder
*
ch
,
framework
::
Variable
*
var
);
bool
ChannelReceive
(
framework
::
ChannelHolder
*
ch
,
framework
::
Variable
*
var
);
using
LoDTensor
=
framework
::
LoDTensor
;
using
Tensor
=
framework
::
Tensor
;
class
FusedEmbeddingFCLSTMOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
};
void
ChannelAddToSendQ
(
framework
::
ChannelHolder
*
ch
,
const
void
*
referrer
,
framework
::
Variable
*
var
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
framework
::
ChannelAction
)
>
cb
);
void
ChannelAddToReceiveQ
(
framework
::
ChannelHolder
*
ch
,
const
void
*
referrer
,
framework
::
Variable
*
var
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
cond
,
std
::
function
<
bool
(
framework
::
ChannelAction
)
>
cb
);
class
FusedEmbeddingFCLSTMOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
;
};
}
// namespace concurrency
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/select_op.cc
已删除
100644 → 0
浏览文件 @
8f5d918a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/concurrency/channel_util.h"
#include <boost/tokenizer.hpp>
namespace
paddle
{
namespace
operators
{
// Names of the select op's input/output slots and attributes.
static constexpr char kX[] = "X";                            // input slot
static constexpr char kCaseToExecute[] = "case_to_execute";  // input slot
static constexpr char kOutputs[] = "Out";                    // output slot
static constexpr char kCases[] = "cases";                    // attribute
static constexpr char kCasesBlock[] = "sub_block";           // attribute
class
SelectOp
:
public
framework
::
OperatorBase
{
public:
SelectOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
enum
class
SelectOpCaseType
{
DEFAULT
=
0
,
SEND
=
1
,
RECEIVE
=
2
,
};
struct
SelectOpCase
{
int
caseIndex
;
SelectOpCaseType
caseType
;
std
::
string
channelName
;
std
::
string
varName
;
SelectOpCase
()
{}
SelectOpCase
(
int
caseIndex
,
SelectOpCaseType
caseType
,
std
::
string
channelName
,
std
::
string
varName
)
:
caseIndex
(
caseIndex
),
caseType
(
caseType
),
channelName
(
channelName
),
varName
(
varName
)
{}
};
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
std
::
vector
<
std
::
string
>
casesConfigs
=
Attr
<
std
::
vector
<
std
::
string
>>
(
kCases
);
framework
::
BlockDesc
*
casesBlock
=
Attr
<
framework
::
BlockDesc
*>
(
kCasesBlock
);
framework
::
Scope
&
casesBlockScope
=
scope
.
NewScope
();
std
::
string
caseToExecuteVarName
=
Input
(
kCaseToExecute
);
framework
::
Variable
*
caseToExecuteVar
=
casesBlockScope
.
FindVar
(
caseToExecuteVarName
);
// Construct cases from "conditional_block_op"(s) in the casesBlock
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>
cases
=
ParseAndShuffleCases
(
&
casesConfigs
);
// Get all unique channels involved in select
std
::
set
<
framework
::
ChannelHolder
*>
channelsSet
;
for
(
auto
c
:
cases
)
{
if
(
!
c
->
channelName
.
empty
())
{
auto
channelVar
=
scope
.
FindVar
(
c
->
channelName
);
framework
::
ChannelHolder
*
ch
=
channelVar
->
GetMutable
<
framework
::
ChannelHolder
>
();
if
(
channelsSet
.
find
(
ch
)
==
channelsSet
.
end
())
{
channelsSet
.
insert
(
ch
);
}
}
}
// Order all channels by their pointer address
std
::
vector
<
framework
::
ChannelHolder
*>
channels
(
channelsSet
.
begin
(),
channelsSet
.
end
());
std
::
sort
(
channels
.
begin
(),
channels
.
end
());
// Poll all cases
int32_t
caseToExecute
=
pollCases
(
&
scope
,
&
cases
,
channels
);
// At this point, the case to execute has already been determined,
// so we can proceed with executing the cases block
framework
::
LoDTensor
*
caseToExecuteTensor
=
caseToExecuteVar
->
GetMutable
<
framework
::
LoDTensor
>
();
caseToExecuteTensor
->
data
<
int32_t
>
()[
0
]
=
caseToExecute
;
// Execute the cases block, only one case will be executed since we set the
// case_to_execute value to the index of the case we want to execute
framework
::
Executor
executor
(
dev_place
);
framework
::
ProgramDesc
*
program
=
casesBlock
->
Program
();
executor
.
Run
(
*
program
,
&
casesBlockScope
,
casesBlock
->
ID
(),
false
/*create_local_scope*/
);
}
/**
* Goes through all operators in the casesConfigs and processes
* "conditional_block" operators. These operators are mapped to our
* SelectOpCase objects. We randomize the case orders, and set the
* default case (if any exists) as the last case)
* @param casesBlock
* @return
*/
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>
ParseAndShuffleCases
(
std
::
vector
<
std
::
string
>
*
casesConfigs
)
const
{
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>
cases
;
std
::
shared_ptr
<
SelectOpCase
>
defaultCase
;
if
(
casesConfigs
!=
nullptr
)
{
boost
::
char_delimiters_separator
<
char
>
sep
(
false
,
","
,
""
);
for
(
std
::
vector
<
std
::
string
>::
iterator
itr
=
casesConfigs
->
begin
();
itr
<
casesConfigs
->
end
();
++
itr
)
{
std
::
string
caseConfig
=
*
itr
;
boost
::
tokenizer
<>
tokens
(
caseConfig
,
sep
);
boost
::
tokenizer
<>::
iterator
tok_iter
=
tokens
.
begin
();
PADDLE_ENFORCE
(
tok_iter
!=
tokens
.
end
(),
"Cannot get case index"
);
std
::
string
caseIndexString
=
*
tok_iter
;
int
caseIndex
=
std
::
stoi
(
caseIndexString
);
++
tok_iter
;
PADDLE_ENFORCE
(
tok_iter
!=
tokens
.
end
(),
"Cannot get case type"
);
std
::
string
caseTypeString
=
*
tok_iter
;
SelectOpCaseType
caseType
=
(
SelectOpCaseType
)
std
::
stoi
(
caseTypeString
);
std
::
string
caseChannel
;
std
::
string
caseChannelVar
;
++
tok_iter
;
if
(
caseType
!=
SelectOpCaseType
::
DEFAULT
)
{
PADDLE_ENFORCE
(
tok_iter
!=
tokens
.
end
(),
"Cannot get case channel"
);
caseChannel
=
*
tok_iter
;
++
tok_iter
;
PADDLE_ENFORCE
(
tok_iter
!=
tokens
.
end
(),
"Cannot get case channel variable"
);
caseChannelVar
=
*
tok_iter
;
}
auto
c
=
std
::
make_shared
<
SelectOpCase
>
(
caseIndex
,
caseType
,
caseChannel
,
caseChannelVar
);
if
(
caseType
==
SelectOpCaseType
::
DEFAULT
)
{
PADDLE_ENFORCE
(
defaultCase
==
nullptr
,
"Select can only contain one default case."
);
defaultCase
=
c
;
}
else
{
cases
.
push_back
(
c
);
}
}
}
// Randomly sort cases, with default case being last
std
::
random_shuffle
(
cases
.
begin
(),
cases
.
end
());
if
(
defaultCase
!=
nullptr
)
{
cases
.
push_back
(
defaultCase
);
}
return
cases
;
}
/**
* This method will recursively poll the cases and determines if any case
* condition is true.
* If none of the cases conditions are true (and there is no default case),
* then block
* the thread. The thread may be woken up by a channel operation, at which
* point we
* execute the case.
* @param scope
* @param cases
* @param channels
* @return
*/
int32_t
pollCases
(
const
framework
::
Scope
*
scope
,
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>
*
cases
,
std
::
vector
<
framework
::
ChannelHolder
*>
channels
)
const
{
// Lock all involved channels
lockChannels
(
channels
);
std
::
atomic
<
int
>
caseToExecute
(
-
1
);
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>::
iterator
it
=
cases
->
begin
();
while
(
it
!=
cases
->
end
())
{
std
::
shared_ptr
<
SelectOpCase
>
c
=
*
it
;
auto
chVar
=
scope
->
FindVar
(
c
->
channelName
);
framework
::
ChannelHolder
*
ch
=
chVar
->
GetMutable
<
framework
::
ChannelHolder
>
();
switch
(
c
->
caseType
)
{
case
SelectOpCaseType
::
SEND
:
PADDLE_ENFORCE
(
!
ch
->
IsClosed
(),
"Cannot send to a closed channel"
);
if
(
ch
->
CanSend
())
{
// We can send to channel directly, send the data to channel
// and execute case
auto
chVar
=
scope
->
FindVar
(
c
->
varName
);
concurrency
::
ChannelSend
(
ch
,
chVar
);
caseToExecute
=
c
->
caseIndex
;
}
break
;
case
SelectOpCaseType
::
RECEIVE
:
if
(
ch
->
CanReceive
())
{
// We can receive from channel directly, send the data to channel
// and execute case
auto
chVar
=
scope
->
FindVar
(
c
->
varName
);
concurrency
::
ChannelReceive
(
ch
,
chVar
);
caseToExecute
=
c
->
caseIndex
;
}
break
;
case
SelectOpCaseType
::
DEFAULT
:
caseToExecute
=
c
->
caseIndex
;
break
;
}
if
(
caseToExecute
!=
-
1
)
{
// We found a case to execute, stop looking at other case statements
break
;
}
++
it
;
}
if
(
caseToExecute
==
-
1
)
{
// None of the cases are eligible to execute, enqueue current thread
// into all the sending/receiving queue of each involved channel
std
::
atomic
<
bool
>
completed
(
false
);
std
::
recursive_mutex
mutex
;
std
::
unique_lock
<
std
::
recursive_mutex
>
lock
{
mutex
};
// std::condition_variable_any selectCond;
auto
selectCond
=
std
::
make_shared
<
std
::
condition_variable_any
>
();
std
::
recursive_mutex
callbackMutex
;
pushThreadOnChannelQueues
(
scope
,
cases
,
selectCond
,
&
caseToExecute
,
&
completed
,
&
callbackMutex
);
// TODO(thuan): Atomically unlock all channels and sleep current thread
unlockChannels
(
channels
);
selectCond
->
wait
(
lock
,
[
&
completed
]()
{
return
completed
.
load
();
});
// Select has been woken up by case operation
lockChannels
(
channels
);
removeThreadOnChannelQueues
(
scope
,
cases
);
if
(
caseToExecute
==
-
1
)
{
// Recursively poll cases, since we were woken up by a channel close
// TODO(thuan): Need to test if this is a valid case
unlockChannels
(
channels
);
return
pollCases
(
scope
,
cases
,
channels
);
}
}
// At this point, caseToExecute != -1, and we can proceed with executing
// the case block
unlockChannels
(
channels
);
return
caseToExecute
;
}
void
lockChannels
(
std
::
vector
<
framework
::
ChannelHolder
*>
chs
)
const
{
std
::
vector
<
framework
::
ChannelHolder
*>::
iterator
it
=
chs
.
begin
();
while
(
it
!=
chs
.
end
())
{
framework
::
ChannelHolder
*
ch
=
*
it
;
ch
->
Lock
();
++
it
;
}
}
void
unlockChannels
(
std
::
vector
<
framework
::
ChannelHolder
*>
chs
)
const
{
std
::
vector
<
framework
::
ChannelHolder
*>::
reverse_iterator
it
=
chs
.
rbegin
();
while
(
it
!=
chs
.
rend
())
{
framework
::
ChannelHolder
*
ch
=
*
it
;
ch
->
Unlock
();
++
it
;
}
}
void
pushThreadOnChannelQueues
(
const
framework
::
Scope
*
scope
,
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>
*
cases
,
std
::
shared_ptr
<
std
::
condition_variable_any
>
rCond
,
std
::
atomic
<
int
>
*
caseToExecute
,
std
::
atomic
<
bool
>
*
completed
,
std
::
recursive_mutex
*
callbackMutex
)
const
{
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>::
iterator
it
=
cases
->
begin
();
while
(
it
!=
cases
->
end
())
{
std
::
shared_ptr
<
SelectOpCase
>
c
=
*
it
;
auto
chVar
=
scope
->
FindVar
(
c
->
channelName
);
framework
::
ChannelHolder
*
ch
=
chVar
->
GetMutable
<
framework
::
ChannelHolder
>
();
std
::
function
<
bool
(
framework
::
ChannelAction
channelAction
)
>
cb
=
[
&
caseToExecute
,
&
completed
,
&
callbackMutex
,
c
](
framework
::
ChannelAction
channelAction
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
{
*
callbackMutex
};
bool
canProcess
=
false
;
if
(
!
(
*
completed
))
{
// If the channel wasn't closed, we set the caseToExecute index
// as this current case
if
(
channelAction
!=
framework
::
ChannelAction
::
CLOSE
)
{
*
caseToExecute
=
c
->
caseIndex
;
}
// This will allow our conditional variable to break out of wait
*
completed
=
true
;
canProcess
=
true
;
}
return
canProcess
;
};
switch
(
c
->
caseType
)
{
case
SelectOpCaseType
::
SEND
:
{
auto
chOutputVar
=
scope
->
FindVar
(
c
->
varName
);
concurrency
::
ChannelAddToSendQ
(
ch
,
this
,
chOutputVar
,
rCond
,
cb
);
break
;
}
case
SelectOpCaseType
::
RECEIVE
:
{
auto
chOutputVar
=
scope
->
FindVar
(
c
->
varName
);
concurrency
::
ChannelAddToReceiveQ
(
ch
,
this
,
chOutputVar
,
rCond
,
cb
);
break
;
}
default:
break
;
}
++
it
;
}
}
void
removeThreadOnChannelQueues
(
const
framework
::
Scope
*
scope
,
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>
*
cases
)
const
{
std
::
vector
<
std
::
shared_ptr
<
SelectOpCase
>>::
iterator
it
=
cases
->
begin
();
while
(
it
!=
cases
->
end
())
{
std
::
shared_ptr
<
SelectOpCase
>
c
=
*
it
;
auto
chVar
=
scope
->
FindVar
(
c
->
channelName
);
framework
::
ChannelHolder
*
ch
=
chVar
->
GetMutable
<
framework
::
ChannelHolder
>
();
switch
(
c
->
caseType
)
{
case
SelectOpCaseType
::
SEND
:
{
ch
->
RemoveFromSendQ
(
this
);
break
;
}
case
SelectOpCaseType
::
RECEIVE
:
{
ch
->
RemoveFromReceiveQ
(
this
);
break
;
}
default:
break
;
}
++
it
;
}
}
};
// Declares the inputs, outputs and attributes of the select op.
class SelectOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(kX,
             "A set of variables, which are required by operators inside the "
             "cases of Select Op")
        .AsDuplicable();
    AddInput(kCaseToExecute,
             "(Int) The variable that sets the index of the case to execute, "
             "after evaluating the channels being sent to and received from")
        .AsDuplicable();
    AddOutput(kOutputs,
              "A set of variables, which will be assigned with values "
              "generated by the operators inside the cases of Select Op.")
        .AsDuplicable();
    // Adjacent string literals are concatenated verbatim, so every fragment
    // that continues a sentence must carry its own separating space.
    AddAttr<std::vector<std::string>>(
        kCases,
        "(String vector) Serialized list of "
        "all cases in the select op. Each "
        "case is serialized as: "
        "'<index>,<type>,<channel>,<value>' "
        "where type is 0 for default, 1 for "
        "send, and 2 for receive. "
        "No channel and values are needed for "
        "default cases.");
    AddAttr<framework::BlockDesc *>(kCasesBlock,
                                    "The cases block inside select_op");
    AddComment(R"DOC(
)DOC");
  }
};
// TODO(thuan): Implement Gradient Operator for SELECT_OP
}
// namespace operators
}
// namespace paddle
// Register the select op with its proto maker. EmptyGradOpMaker is supplied
// as the gradient maker.
REGISTER_OPERATOR(select, paddle::operators::SelectOp,
                  paddle::framework::EmptyGradOpMaker,
                  paddle::operators::SelectOpMaker);
paddle/fluid/pybind/protobuf.cc
浏览文件 @
f63ab561
...
...
@@ -214,7 +214,6 @@ void BindVarDsec(pybind11::module *m) {
.
def
(
"set_shapes"
,
&
pd
::
VarDesc
::
SetShapes
)
.
def
(
"set_dtype"
,
&
pd
::
VarDesc
::
SetDataType
)
.
def
(
"set_dtypes"
,
&
pd
::
VarDesc
::
SetDataTypes
)
.
def
(
"set_capacity"
,
&
pd
::
VarDesc
::
SetCapacity
)
.
def
(
"shape"
,
&
pd
::
VarDesc
::
GetShape
,
pybind11
::
return_value_policy
::
reference
)
.
def
(
"shapes"
,
&
pd
::
VarDesc
::
GetShapes
,
...
...
@@ -251,7 +250,6 @@ void BindVarDsec(pybind11::module *m) {
.
value
(
"STEP_SCOPES"
,
pd
::
proto
::
VarType
::
STEP_SCOPES
)
.
value
(
"LOD_RANK_TABLE"
,
pd
::
proto
::
VarType
::
LOD_RANK_TABLE
)
.
value
(
"LOD_TENSOR_ARRAY"
,
pd
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
.
value
(
"CHANNEL"
,
pd
::
proto
::
VarType
::
CHANNEL
)
.
value
(
"PLACE_LIST"
,
pd
::
proto
::
VarType
::
PLACE_LIST
)
.
value
(
"READER"
,
pd
::
proto
::
VarType
::
READER
)
.
value
(
"RAW"
,
pd
::
proto
::
VarType
::
RAW
);
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
f63ab561
...
...
@@ -21,7 +21,6 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
...
...
paddle/legacy/trainer/tests/CMakeLists.txt
浏览文件 @
f63ab561
...
...
@@ -16,7 +16,11 @@ endfunction()
trainer_test
(
test_Compare
)
trainer_test
(
test_PyDataProviderWrapper
)
trainer_test
(
test_recurrent_machine_generation
)
trainer_test
(
test_Trainer
)
if
(
NOT APPLE
)
trainer_test
(
test_Trainer
)
else
()
message
(
WARNING
"These tests has been disabled in OSX for random fail:
\n
test_Trainer"
)
endif
()
############### test_TrainerOnePass ##########################
if
(
WITH_PYTHON
)
...
...
python/paddle/fluid/concurrency.py
已删除
100644 → 0
浏览文件 @
8f5d918a
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
.layers.control_flow
import
BlockGuard
,
equal
from
.framework
import
Operator
from
.layer_helper
import
LayerHelper
,
unique_name
from
.layers
import
fill_constant
from
.
import
core
# Names exported by a star-import of this module.
# NOTE(review): `Go` and `SelectCase` are defined in this module but are not
# listed here, while the docstrings below refer to `Go` -- confirm that
# leaving `Go` unexported is intentional.
__all__
=
[
'make_channel'
,
'channel_send'
,
'channel_recv'
,
'channel_close'
,
'Select'
]
class Go(BlockGuard):
    """Block guard that turns the ops of its ``with`` body into one ``go`` op.

    On a clean exit, a ``go`` operator referencing the guarded block as its
    ``sub_block`` is appended to the parent block.
    """

    def __init__(self, name=None):
        self.helper = LayerHelper("go", name=name)
        super(Go, self).__init__(self.helper.main_program)

    def __enter__(self):
        super(Go, self).__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Do not build the op when the body raised; returning False lets the
        # exception propagate.
        if exc_type is not None:
            return False
        self._construct_go_op()
        return super(Go, self).__exit__(exc_type, exc_val, exc_tb)

    def _construct_go_op(self):
        """Append the ``go`` op for the current block to its parent block.

        The op's ``X`` input lists every variable that an op of the block
        reads before any op of the block has produced it. The op is created
        with no outputs.
        """
        main_program = self.helper.main_program
        go_block = main_program.current_block()
        parent_block = main_program.block(go_block.parent_idx)

        produced = set()
        external_inputs = set()
        for op in go_block.ops:
            # An input only counts as external when no earlier op (or this
            # op's own outputs, which are recorded afterwards) produced it.
            for slot in op.input_names:
                external_inputs.update(
                    name for name in op.input(slot) if name not in produced)
            for slot in op.output_names:
                produced.update(op.output(slot))

        parent_block.append_op(
            type='go',
            inputs={
                'X': [
                    parent_block._var_recursive(x_name)
                    for x_name in external_inputs
                ]
            },
            outputs={},
            attrs={'sub_block': go_block})
class SelectCase(object):
    """One case (send, receive or default) of a ``Select`` statement.

    Used as a context manager: entering creates a new block for the case
    body, leaving rolls the program back to the enclosing block.
    ``construct_op`` later wraps that block in a ``conditional_block`` op.
    """

    # Action codes; they are written into the serialized case string.
    DEFAULT = 0
    SEND = 1
    RECEIVE = 2

    def __init__(self,
                 select,
                 case_idx,
                 case_to_execute,
                 channel_action_fn=None,
                 channel=None,
                 value=None,
                 is_copy=False):
        self.select = select
        self.helper = LayerHelper('conditional_block')
        self.main_program = self.helper.main_program
        self.is_scalar_condition = True

        self.case_to_execute = case_to_execute
        self.idx = case_idx

        # The action function is never called here; only its name is used to
        # tell a send from a receive. No function at all means default case.
        if not channel_action_fn:
            self.action = self.DEFAULT
        elif channel_action_fn.__name__ == ('channel_send'):
            self.action = self.SEND
        else:
            self.action = self.RECEIVE

        payload = value
        if self.action == self.SEND and is_copy:
            # Send a copy: create a twin variable in the select's parent
            # block and assign the original value to it.
            owner_block = self.select.parent_block
            payload = owner_block.create_var(
                name=unique_name.generate(value.name + '_copy'),
                type=value.type,
                dtype=value.dtype,
                shape=value.shape,
                lod_level=value.lod_level,
                capacity=value.capacity
                if hasattr(value, 'capacity') else None, )
            self.select.parent_block.append_op(
                type="assign", inputs={"X": value}, outputs={"Out": payload})

        self.value = payload
        self.channel = channel

    def __enter__(self):
        self.block = self.main_program._create_block()

    def construct_op(self):
        """Append this case's ``conditional_block`` op to the current block.

        Returns:
            str: the serialized case, ``'<idx>,<action>,<channel>,<value>'``,
            where the channel/value names are empty strings when unset.
        """
        cases_block = self.helper.main_program.current_block()

        produced = set()
        consumed = set()
        # Never populated, so the 'Params' list built from it is always empty.
        params = set()

        for op in self.block.ops:
            # Record inputs that were not produced earlier in the block, then
            # record everything the op produces.
            for slot in op.input_names:
                for name in op.input(slot):
                    if name not in produced:
                        consumed.add(name)
            for slot in op.output_names:
                for name in op.output(slot):
                    produced.add(name)

        param_list = [
            cases_block.var(name) for name in params if name not in consumed
        ]

        # Only outputs that exist as variables of the current block are
        # exposed as outputs of the conditional block.
        out_vars = [
            cases_block.var(name) for name in produced
            if name in cases_block.vars
        ]

        # Condition: this case's index equals the selected case index.
        own_index = fill_constant(
            shape=[1], dtype=core.VarDesc.VarType.INT32, value=self.idx)
        should_execute_block = equal(own_index, self.case_to_execute)

        step_scope = cases_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES)

        cases_block.append_op(
            type='conditional_block',
            inputs={'X': [should_execute_block],
                    'Params': param_list},
            outputs={'Out': out_vars,
                     'Scope': [step_scope]},
            attrs={
                'sub_block': self.block,
                'is_scalar_condition': self.is_scalar_condition
            })

        channel_name = self.channel.name if self.channel else ''
        value_name = self.value.name if self.value else ''
        return '%s,%s,%s,%s' % (self.idx, self.action, channel_name,
                                value_name)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.main_program._rollback()
        # False re-raises an exception from the body; True otherwise.
        return exc_type is None
class Select(BlockGuard):
    """Block guard that collects cases and emits a single ``select`` op.

    Cases are created with ``case`` / ``default`` inside the ``with`` body;
    on a clean exit each case's op is constructed and a ``select`` op that
    references the guarded block is appended to the parent block.
    """

    def __init__(self, name=None):
        self.helper = LayerHelper('select', name=name)
        self.parent_block = self.helper.main_program.current_block()
        self.cases = []

        super(Select, self).__init__(self.helper.main_program)
        # Holds the index of the chosen case; initialised to -1.
        self.case_to_execute = fill_constant(
            shape=[1], dtype=core.VarDesc.VarType.INT32, value=-1)

    def __enter__(self):
        super(Select, self).__enter__()
        return self

    def case(self, channel_action_fn, channel, value, is_copy=False):
        """Create, register and return a send/receive case."""
        new_case = SelectCase(self,
                              len(self.cases), self.case_to_execute,
                              channel_action_fn, channel, value, is_copy)
        self.cases.append(new_case)
        return new_case

    def default(self):
        """Create, register and return the default case."""
        new_case = SelectCase(self, len(self.cases), self.case_to_execute)
        self.cases.append(new_case)
        return new_case

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            return False

        program = self.helper.main_program
        # The block guarded by this Select, and the block it is nested in.
        select_block = program.current_block()
        parent_block = program.block(select_block.parent_idx)

        # Building each case appends its op to the select block and yields
        # the case's serialized description.
        serialized_cases = [case.construct_op() for case in self.cases]

        intermediate = set()
        params = set()
        for case_op in select_block.ops:
            if not (case_op.attrs and 'sub_block' in case_op.attrs):
                continue
            for each_op in case_op.attrs['sub_block'].ops:
                assert isinstance(each_op, Operator)
                # Inputs not produced by an earlier inspected op become
                # inputs of the select op.
                for iname in each_op.input_names:
                    for in_var_name in each_op.input(iname):
                        if in_var_name not in intermediate:
                            params.add(in_var_name)
                for oname in each_op.output_names:
                    for out_var_name in each_op.output(oname):
                        intermediate.add(out_var_name)

        out_list = [
            parent_block.var(var_name) for var_name in parent_block.vars
            if var_name in intermediate
        ]

        X = [select_block._var_recursive(x_name) for x_name in params]
        # Needs to be used by `equal` inside the cases block.
        X.append(self.case_to_execute)

        # Construct the select op.
        parent_block.append_op(
            type='select',
            inputs={'X': X,
                    'case_to_execute': self.case_to_execute},
            attrs={'sub_block': select_block,
                   'cases': serialized_cases},
            outputs={'Out': out_list})

        return super(Select, self).__exit__(exc_type, exc_val, exc_tb)
def make_channel(dtype, capacity=0):
    """Create a channel variable for passing data between concurrent blocks.

    Channels allow data to be sent and received concurrently, e.g. between
    the main program and a `Go` block. There are two kinds of channels:
    unbuffered channels have no capacity, so a send blocks until the value
    has been received; buffered channels are created with a capacity and do
    not block on sends.

    Use this together with `channel_send`, `channel_recv`, `channel_close`
    and `Go` to design a concurrent Paddle program.

    Args:
        dtype (ParamAttr|string): Data type of the data sent in the channel.
            This should be the string name of a numpy data type.
        capacity (ParamAttr|int): Size of the channel. Defaults to 0, which
            creates an unbuffered channel.

    Returns:
        Variable: The channel variable that can be used to send and receive
            data of the defined dtype.

    Examples:
        .. code-block:: python

          ch = fluid.make_channel(dtype='int32', capacity=10)
          ...
          # Code to execute in a Go block, which receives the channel data.
          fluid.channel_send(ch, 100)
          fluid.channel_close(ch)
    """
    helper = LayerHelper('channel_create', **locals())
    block = helper.main_program.current_block()

    # The channel variable is created persistable, under a generated unique
    # name.
    channel = helper.create_variable(
        name=unique_name.generate('channel'),
        type=core.VarDesc.VarType.CHANNEL,
        persistable=True)

    block.append_op(
        type="channel_create",
        outputs={"Out": channel},
        attrs={"data_type": dtype,
               "capacity": capacity})

    return channel
def channel_send(channel, value, is_copy=False):
    """Append a ``channel_send`` op that sends `value` through `channel`.

    Used with an unbuffered or buffered channel to pass data from within or
    to a concurrent Go block, where `channel_recv` is used to get the value.

    Args:
        channel (Variable|Channel): Channel variable created using
            `make_channel`.
        value (Variable): Value to send to the channel.
        is_copy (bool): If True, `value` is first assigned to a freshly
            created variable and that copy is sent. If False, the data is
            moved and the input cannot be used after the send.
            (default False)

    Returns:
        None. The ``channel_send`` op is appended without any outputs, so no
        status variable is produced.

    Examples:
        .. code-block:: python

          ch = fluid.make_channel(dtype='int32', capacity=10)
          ...
          # Code to execute in a Go block, which receives the channel data.
          fluid.channel_send(ch, 100)
    """
    # Must stay the first statement: `locals()` forwards exactly the call
    # arguments to LayerHelper.
    helper = LayerHelper('channel_send', **locals())
    block = helper.main_program.current_block()

    payload = value
    if is_copy:
        # Create a twin of `value` and assign into it, so the original stays
        # usable after the send.
        payload = helper.create_variable(
            name=unique_name.generate(value.name + '_copy'),
            type=value.type,
            dtype=value.dtype,
            shape=value.shape,
            lod_level=value.lod_level,
            capacity=value.capacity if hasattr(value, 'capacity') else None)
        block.append_op(
            type="assign", inputs={"X": value}, outputs={"Out": payload})

    block.append_op(
        type="channel_send", inputs={
            "Channel": channel,
            "X": payload,
        })
def channel_recv(channel, return_value):
    """Append a ``channel_recv`` op that receives a value from `channel`.

    Used with an unbuffered or buffered channel within a concurrent Go block
    to get data originally sent using `channel_send`, or from outside such a
    block where `channel_send` is used to send the value.

    Args:
        channel (Variable|Channel): Channel variable created using
            `make_channel`.
        return_value (Variable): Variable that the ``channel_recv`` op writes
            the received value to (its ``Out`` output).

    Returns:
        Variable: `return_value`, the variable the received value is written
            to.
        Variable: A newly created boolean status variable, wired to the op's
            ``Status`` output.

    Examples:
        .. code-block:: python

          ch = fluid.make_channel(dtype='int32', capacity=10)
          # `result` must be an existing Variable, not a dtype string.
          with fluid.Go():
            returned_value, return_status = fluid.channel_recv(ch, result)

          # Code to send data through the channel.
    """
    # Must stay the first statement: `locals()` forwards exactly the call
    # arguments to LayerHelper.
    helper = LayerHelper('channel_recv', **locals())
    block = helper.main_program.current_block()

    status = helper.create_variable(
        name=unique_name.generate('status'),
        type=core.VarDesc.VarType.LOD_TENSOR,
        dtype=core.VarDesc.VarType.BOOL)

    block.append_op(
        type="channel_recv",
        inputs={"Channel": channel},
        outputs={"Out": return_value,
                 "Status": status})

    return return_value, status
def channel_close(channel):
    """Append a ``channel_close`` op for a channel made with `make_channel`.

    Args:
        channel (Variable|Channel): Channel variable created using
            `make_channel`.

    Examples:
        .. code-block:: python

          ch = fluid.make_channel(dtype='int32', capacity=10)
          ...
          # Code to receive and send data through a channel
          ...
          fluid.channel_close(ch)
    """
    helper = LayerHelper('channel_close', **locals())
    # The op goes into whichever block is current when this is called.
    block = helper.main_program.current_block()
    block.append_op(type="channel_close", inputs={"Channel": channel})
python/paddle/fluid/contrib/tests/test_quantize_transpiler.py
浏览文件 @
f63ab561
...
...
@@ -244,6 +244,7 @@ class TestQuantizeTranspiler(unittest.TestCase):
test_loss2
,
=
exe
.
run
(
program
=
test_program
,
feed
=
feeder
.
feed
(
test_data
),
fetch_list
=
[
loss
])
self
.
assertAlmostEqual
(
test_loss1
,
test_loss2
,
delta
=
5e-3
)
w_freeze
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
'conv2d_1.w_0'
)
.
get_tensor
())
# fail: -432.0 != -433.0, this is due to the calculation precision
...
...
python/paddle/fluid/framework.py
浏览文件 @
f63ab561
...
...
@@ -541,8 +541,7 @@ class Operator(object):
'feed'
,
'fetch'
,
'save'
,
'load'
,
'recurrent'
,
'go'
,
'rnn_memory_helper_grad'
,
'conditional_block'
,
'while'
,
'send'
,
'recv'
,
'listen_and_serv'
,
'parallel_do'
,
'save_combine'
,
'load_combine'
,
'ncclInit'
,
'channel_create'
,
'channel_close'
,
'channel_send'
,
'channel_recv'
,
'select'
,
'checkpoint_notify'
,
'gen_nccl_id'
'ncclInit'
,
'select'
,
'checkpoint_notify'
,
'gen_nccl_id'
}
def
__init__
(
self
,
...
...
python/paddle/fluid/tests/book/high-level-api/recognize_digits/CMakeLists.txt
浏览文件 @
f63ab561
...
...
@@ -2,6 +2,16 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string
(
REPLACE
".py"
""
TEST_OPS
"
${
TEST_OPS
}
"
)
# default test
foreach
(
src
${
TEST_OPS
}
)
if
(
NOT APPLE
)
foreach
(
src
${
TEST_OPS
}
)
py_test
(
${
src
}
SRCS
${
src
}
.py
)
endforeach
()
endforeach
()
else
()
foreach
(
src
${
TEST_OPS
}
)
if
(
${
src
}
STREQUAL
"test_recognize_digits_conv"
)
message
(
WARNING
"These tests has been disabled in OSX for random fail:
\n
"
${
src
}
)
else
()
py_test
(
${
src
}
SRCS
${
src
}
.py
)
endif
()
endforeach
()
endif
()
python/paddle/fluid/tests/no_test_concurrency.py
已删除
100644 → 0
浏览文件 @
8f5d918a
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid
import
framework
,
unique_name
,
layer_helper
from
paddle.fluid.executor
import
Executor
from
paddle.fluid.layers
import
fill_constant
,
assign
,
While
,
elementwise_add
,
Print
class TestRoutineOp(unittest.TestCase):
    """Builds small fluid programs that use Go blocks, channels and Select,
    runs each one on a CPU executor and checks the fetched value.

    Every test builds its graph inside a fresh ``framework.Program()`` so the
    operators appended by one test are never executed by another.
    """

    def test_simple_routine(self):
        """Send a constant from a Go block and receive it in the main block."""
        # Use a private Program like the other tests in this class; building
        # into the shared default program made the result order-dependent.
        with framework.program_guard(framework.Program()):
            ch = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)

            # Create LOD_TENSOR<INT64> and put it into the scope. This
            # placeholder variable will be filled in and returned by
            # fluid.channel_recv
            result = self._create_tensor('return_value',
                                         core.VarDesc.VarType.LOD_TENSOR,
                                         core.VarDesc.VarType.INT64)

            with fluid.Go():
                input_value = fill_constant(
                    shape=[1], dtype=core.VarDesc.VarType.FP64, value=1234)
                fluid.channel_send(ch, input_value)

            result, _ = fluid.channel_recv(ch, result)
            fluid.channel_close(ch)

            cpu = core.CPUPlace()
            exe = Executor(cpu)

            # exe.run() is called without a program, so it must stay inside
            # the guard to execute the program built above.
            outs = exe.run(fetch_list=[result])
            self.assertEqual(outs[0], 1234)

    def test_daisy_chain(self):
        """
        Mimics classic Daisy-chain test: https://talks.golang.org/2012/concurrency.slide#39

        ``n`` Go blocks are chained through channels; each one receives a
        value from its right channel, adds one and sends it to its left
        channel. A ``1`` is then sent into the rightmost channel and the value
        read from the leftmost channel is expected to be ``n + 1``.
        """
        n = 100

        # Isolate this graph from the other tests (see test_simple_routine).
        with framework.program_guard(framework.Program()):
            leftmost = fluid.make_channel(
                dtype=core.VarDesc.VarType.LOD_TENSOR)
            left = leftmost

            # TODO(thuan): Use fluid.While() after scope capture is implemented.
            # https://github.com/PaddlePaddle/Paddle/issues/8502
            for _ in range(n):
                right = fluid.make_channel(
                    dtype=core.VarDesc.VarType.LOD_TENSOR)
                with fluid.Go():
                    one_tensor = self._create_one_dim_tensor(1)
                    result = self._create_tensor(
                        'return_value', core.VarDesc.VarType.LOD_TENSOR,
                        core.VarDesc.VarType.INT64)

                    result, _ = fluid.channel_recv(right, result)
                    one_added = fluid.layers.elementwise_add(
                        x=one_tensor, y=result)
                    fluid.channel_send(left, one_added)
                # The next link in the chain feeds into this link's input.
                left = right

            # Trigger the channel propagation by sending a "1" to rightmost
            # channel
            with fluid.Go():
                one_tensor = self._create_one_dim_tensor(1)
                fluid.channel_send(right, one_tensor)

            leftmost_result = self._create_tensor(
                'return_value', core.VarDesc.VarType.LOD_TENSOR,
                core.VarDesc.VarType.INT64)
            leftmost_result, _ = fluid.channel_recv(leftmost, leftmost_result)

            cpu = core.CPUPlace()
            exe = Executor(cpu)
            leftmost_data = exe.run(fetch_list=[leftmost_result])

            # The leftmost_data should be equal to the number of channels + 1
            self.assertEqual(leftmost_data[0][0], n + 1)

    def _create_one_dim_tensor(self, value):
        """Return a shape-[1] 'int' constant holding ``value`` with
        ``stop_gradient`` set to True."""
        one_dim_tensor = fill_constant(shape=[1], dtype='int', value=value)
        one_dim_tensor.stop_gradient = True
        return one_dim_tensor

    def _create_tensor(self, name, type, dtype):
        """Create a variable in the current block of the default main program.

        The variable name is made unique via ``unique_name.generate(name)``;
        ``type`` and ``dtype`` are forwarded to ``create_var`` unchanged.
        """
        return framework.default_main_program().current_block().create_var(
            name=unique_name.generate(name), type=type, dtype=dtype)

    def _create_persistable_tensor(self, name, type, dtype):
        """Same as ``_create_tensor`` but passes ``persistable=True``."""
        return framework.default_main_program().current_block().create_var(
            name=unique_name.generate(name),
            type=type,
            dtype=dtype,
            persistable=True)

    def test_select(self):
        """Send through a capacity-1 channel from a Select case, then receive
        the value in the main block and expect 10."""
        with framework.program_guard(framework.Program()):
            ch1 = fluid.make_channel(
                dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1)

            result1 = self._create_tensor('return_value',
                                          core.VarDesc.VarType.LOD_TENSOR,
                                          core.VarDesc.VarType.FP64)

            input_value = fill_constant(
                shape=[1], dtype=core.VarDesc.VarType.FP64, value=10)

            with fluid.Select() as select:
                with select.case(fluid.channel_send, ch1, input_value):
                    # Execute something.
                    pass

                with select.default():
                    pass

            # This should not block because we are using a buffered channel.
            result1, _ = fluid.channel_recv(ch1, result1)
            fluid.channel_close(ch1)

            cpu = core.CPUPlace()
            exe = Executor(cpu)

            result = exe.run(fetch_list=[result1])
            self.assertEqual(result[0][0], 10)

    def test_fibonacci(self):
        """
        Mimics Fibonacci Go example: https://tour.golang.org/concurrency/5

        A Go block receives ten values from ``ch1`` into ``result`` and then
        sends on ``quit_ch``. The main block loops, selecting between sending
        the current ``x`` (and advancing the x/y pair) and receiving the quit
        signal. The last received value is expected to be 34.
        """
        with framework.program_guard(framework.Program()):
            quit_ch_input_var = self._create_persistable_tensor(
                'quit_ch_input', core.VarDesc.VarType.LOD_TENSOR,
                core.VarDesc.VarType.INT32)
            quit_ch_input = fill_constant(
                shape=[1],
                dtype=core.VarDesc.VarType.INT32,
                value=0,
                out=quit_ch_input_var)

            result = self._create_persistable_tensor(
                'result', core.VarDesc.VarType.LOD_TENSOR,
                core.VarDesc.VarType.INT32)
            fill_constant(
                shape=[1],
                dtype=core.VarDesc.VarType.INT32,
                value=0,
                out=result)

            # x and y hold two consecutive sequence values, starting at 0, 1.
            x = fill_constant(
                shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)
            y = fill_constant(
                shape=[1], dtype=core.VarDesc.VarType.INT32, value=1)

            while_cond = fill_constant(
                shape=[1], dtype=core.VarDesc.VarType.BOOL, value=True)

            while_false = fill_constant(
                shape=[1], dtype=core.VarDesc.VarType.BOOL, value=False)

            x_tmp = fill_constant(
                shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)

            def fibonacci(channel, quit_channel):
                """Loop until a value arrives on ``quit_channel``, sending the
                current ``x`` on ``channel`` on each successful send case."""
                while_op = While(cond=while_cond)
                with while_op.block():
                    result2 = fill_constant(
                        shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)

                    with fluid.Select() as select:
                        with select.case(
                                fluid.channel_send, channel, x, is_copy=True):
                            # (x, y) <- (y, x + y), using x_tmp as scratch.
                            assign(input=x, output=x_tmp)
                            assign(input=y, output=x)
                            assign(elementwise_add(x=x_tmp, y=y), output=y)

                        with select.case(fluid.channel_recv, quit_channel,
                                         result2):
                            # Quit: overwrite the loop condition with False.
                            helper = layer_helper.LayerHelper('assign')
                            helper.append_op(
                                type='assign',
                                inputs={'X': [while_false]},
                                outputs={'Out': [while_cond]})

            ch1 = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)
            quit_ch = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)

            with fluid.Go():
                for _ in range(10):
                    fluid.channel_recv(ch1, result)
                    Print(result)

                fluid.channel_send(quit_ch, quit_ch_input)

            fibonacci(ch1, quit_ch)

            fluid.channel_close(ch1)
            fluid.channel_close(quit_ch)

            cpu = core.CPUPlace()
            exe = Executor(cpu)

            exe_result = exe.run(fetch_list=[result])
            self.assertEqual(exe_result[0][0], 34)

    def test_ping_pong(self):
        """
        Mimics Ping Pong example: https://gobyexample.com/channel-directions

        A message is sent on ``pings``, relayed from ``pings`` to ``pongs``,
        and finally received from ``pongs``; the fetched value should be 9.
        """
        with framework.program_guard(framework.Program()):
            result = self._create_tensor('return_value',
                                         core.VarDesc.VarType.LOD_TENSOR,
                                         core.VarDesc.VarType.FP64)

            ping_result = self._create_tensor('ping_return_value',
                                              core.VarDesc.VarType.LOD_TENSOR,
                                              core.VarDesc.VarType.FP64)

            def ping(ch, message):
                """Send ``message`` on ``ch``."""
                fluid.channel_send(ch, message, is_copy=True)

            def pong(ch1, ch2):
                """Receive from ``ch1`` into ``ping_result`` and forward it on
                ``ch2``."""
                fluid.channel_recv(ch1, ping_result)
                fluid.channel_send(ch2, ping_result, is_copy=True)

            pings = fluid.make_channel(
                dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1)
            pongs = fluid.make_channel(
                dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1)

            msg = fill_constant(
                shape=[1], dtype=core.VarDesc.VarType.FP64, value=9)

            ping(pings, msg)
            pong(pings, pongs)

            fluid.channel_recv(pongs, result)

            fluid.channel_close(pings)
            fluid.channel_close(pongs)

            cpu = core.CPUPlace()
            exe = Executor(cpu)

            exe_result = exe.run(fetch_list=[result])
            self.assertEqual(exe_result[0][0], 9)
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/tests/notest_concurrency.py
已删除
100644 → 0
浏览文件 @
8f5d918a
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.executor
import
Executor
class TestRoutineOp(unittest.TestCase):
    """Sends a boolean through a channel from a Go block and fetches it."""

    def test_simple_routine(self):
        """Send ``True`` inside ``fluid.Go`` and expect to fetch it back."""
        # NOTE(review): channel_recv is called with a single argument and its
        # return value is fetched directly; this looks written against an
        # older channel API -- confirm the signature before re-enabling.
        channel = fluid.make_channel(
            dtype=core.VarDesc.VarType.BOOL, name="CreateChannel")

        with fluid.Go():
            fluid.channel_send(channel, True)

        received = fluid.channel_recv(channel)
        fluid.channel_close(channel)

        executor = Executor(core.CPUPlace())
        fetched = executor.run(fetch_list=[received])
        self.assertEqual(fetched[0], True)
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录