Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
5f586e22
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5f586e22
编写于
9月 05, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'ups/develop' into refine/op/fusion_lstm
上级
78d9ad57
04272c0d
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
283 addition
and
104 deletion
+283
-104
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+0
-1
paddle/fluid/inference/api/demo_ci/run.sh
paddle/fluid/inference/api/demo_ci/run.sh
+1
-1
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+25
-24
paddle/fluid/operators/fusion_lstm_op.cc
paddle/fluid/operators/fusion_lstm_op.cc
+192
-73
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+0
-5
python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
+65
-0
未找到文件。
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
浏览文件 @
5f586e22
...
...
@@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
...
...
paddle/fluid/inference/api/demo_ci/run.sh
浏览文件 @
5f586e22
...
...
@@ -14,7 +14,7 @@ else
fi
PREFIX
=
inference-vis-demos%2F
URL_ROOT
=
http://paddlemodels.
bj
.bcebos.com/
${
PREFIX
}
URL_ROOT
=
http://paddlemodels.
cdn
.bcebos.com/
${
PREFIX
}
# download vis_demo data
function
download
()
{
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
5f586e22
...
...
@@ -39,19 +39,6 @@ bool RequestSendHandler::Handle(const std::string& varname,
const
std
::
string
&
out_var_name
)
{
VLOG
(
4
)
<<
"RequestSendHandler:"
<<
varname
;
// Async
if
(
!
sync_mode_
)
{
rpc_server_
->
Profiler
().
OneStep
();
try
{
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
varname
].
get
(),
scope
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"async: run sub program error "
<<
e
.
what
();
return
false
;
}
return
true
;
}
// Sync
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv BATCH_BARRIER_MESSAGE"
;
...
...
@@ -60,17 +47,31 @@ bool RequestSendHandler::Handle(const std::string& varname,
VLOG
(
3
)
<<
"sync: recv complete message"
;
rpc_server_
->
Complete
();
}
else
{
VLOG
(
3
)
<<
"sync: received var_name: "
<<
varname
;
rpc_server_
->
WaitCond
(
kRequestSend
);
VLOG
(
3
)
<<
"sync: processing received var: "
<<
varname
;
if
(
invar
==
nullptr
)
{
LOG
(
FATAL
)
<<
"sync: Can not find server side var: "
<<
varname
;
return
false
;
}
if
(
invar
->
IsType
<
framework
::
SelectedRows
>
())
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_sparse_vars_
);
sparse_vars_
.
push_back
(
invar
);
// Async
if
(
!
sync_mode_
)
{
VLOG
(
3
)
<<
"async process var: "
<<
varname
;
rpc_server_
->
Profiler
().
OneStep
();
try
{
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
varname
].
get
(),
scope
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"async: run sub program error "
<<
e
.
what
();
return
false
;
}
return
true
;
}
else
{
// sync
rpc_server_
->
WaitCond
(
kRequestSend
);
VLOG
(
3
)
<<
"sync: processing received var: "
<<
varname
;
if
(
invar
==
nullptr
)
{
LOG
(
FATAL
)
<<
"sync: Can not find server side var: "
<<
varname
;
return
false
;
}
if
(
invar
->
IsType
<
framework
::
SelectedRows
>
())
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_sparse_vars_
);
sparse_vars_
.
push_back
(
invar
);
}
}
}
return
true
;
...
...
paddle/fluid/operators/fusion_lstm_op.cc
浏览文件 @
5f586e22
...
...
@@ -79,12 +79,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"The first dimension of Input(Bias) should be 1."
);
PADDLE_ENFORCE
(
!
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
),
"Do not support peephole yet."
);
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
4
*
frame_size
,
auto
use_peepholes
=
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
);
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
(
use_peepholes
?
7
:
4
)
*
frame_size
,
"The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes connection"
,
frame_size
);
"7 * %d if enable peepholes connection or"
"4 * %d if disable peepholes"
,
frame_size
,
frame_size
);
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
...
...
@@ -231,16 +231,17 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
act_cand = act_functor(act_cand_str); \
}
#define INIT_BASE_INPUT_OUTPUT \
auto* x = ctx.Input<LoDTensor>("X"); \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* c0 = ctx.Input<Tensor>("C0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
#define INIT_BASE_INPUT_OUTPUT \
auto* x = ctx.Input<LoDTensor>("X"); \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* c0 = ctx.Input<Tensor>("C0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
bool use_peepholes = ctx.Attr<bool>("use_peepholes"); \
bool is_reverse = ctx.Attr<bool>("is_reverse");
#define INIT_BASE_SIZES \
...
...
@@ -265,12 +266,21 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
h0_data
=
h0
?
h0
->
data
<
T
>
()
:
nullptr
;
const
T
*
c0_data
=
c0
?
c0
->
data
<
T
>
()
:
nullptr
;
const
T
*
bias_data
=
bias
->
data
<
T
>
();
const
T
*
wc_data
=
bias_data
+
D4
;
// w_ic, w_fc, w_oc
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
hidden_out_data
=
hidden_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
cell_out_data
=
cell_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// use local variable
framework
::
DDim
check_dims
({
3
,
D
});
Tensor
checked_cell
;
// w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
auto
checked_cell_data
=
checked_cell
.
mutable_data
<
T
>
(
check_dims
,
ctx
.
GetPlace
());
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
total_T
,
D4
,
M
,
x_data
,
wx_data
,
xx_data
,
bias
->
data
<
T
>
());
...
...
@@ -296,46 +306,86 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
int
seq_len
=
x_lod
[
0
][
bid
+
1
]
-
x_lod
[
0
][
bid
];
const
T
*
prev_c_data
=
nullptr
;
const
T
*
prev_h_data
=
nullptr
;
int
tstart
=
0
;
if
(
h0_data
)
{
prev_h_data
=
h0_data
+
bid
*
D
;
prev_c_data
=
c0_data
+
bid
*
D
;
}
else
{
// W_ch, W_ih, W_fh, W_oh
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
// If step == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros. Then W_h * H_t-1 can be skipped
// ~C_t
act_cand
(
D
,
xx_data
,
xx_data
);
// cell out= input*tilde
if
(
use_peepholes
)
{
// I_t, F_t
act_gate
(
D2
,
xx_data
+
D
,
xx_data
+
D
);
}
else
{
// I_t, F_t, O_t
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
}
// C_t = I_t * ~C_t
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
cell_out_data
);
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cell_out_data
,
checked_cell_data
+
D2
);
blas
.
VADD
(
D
,
xx_data
+
D3
,
checked_cell_data
+
D2
,
xx_data
+
D3
);
// O_t
act_gate
(
D
,
xx_data
+
D3
,
xx_data
+
D3
);
}
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cell_out_data
,
xx_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
// prev
prev_h_data
=
hidden_out_data
;
prev_c_data
=
cell_out_data
;
tstart
=
1
;
tstart
=
1
;
move_step
();
}
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
// + W_h * H_t-1
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D4
,
D
,
static_cast
<
T
>
(
1
),
prev_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
xx_data
,
D4
);
// W_ch, W_ih, W_fh, W_oh
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
// ~C_t
act_cand
(
D
,
xx_data
,
xx_data
);
// a = forget * prev_cell
if
(
use_peepholes
)
{
// + W_ic|W_fc * C_t-1 for peephole connection
blas
.
VMUL
(
D
,
wc_data
,
prev_c_data
,
checked_cell_data
);
blas
.
VMUL
(
D
,
wc_data
+
D
,
prev_c_data
,
checked_cell_data
+
D
);
blas
.
VADD
(
D2
,
xx_data
+
D
,
checked_cell_data
,
xx_data
+
D
);
// I_t, F_t
act_gate
(
D2
,
xx_data
+
D
,
xx_data
+
D
);
}
else
{
// I_t, F_t, O_t
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
}
// F_t * C_t-1
blas
.
VMUL
(
D
,
xx_data
+
D2
,
prev_c_data
,
xx_data
+
D2
);
// b = input * tilde
// I_t * ~C_t
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
xx_data
+
D
);
// cell out= a+b
// C_t = F_t * C_t-1 + I_t * ~C_t
blas
.
VADD
(
D
,
xx_data
+
D
,
xx_data
+
D2
,
cell_out_data
);
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cell_out_data
,
checked_cell_data
+
D2
);
blas
.
VADD
(
D
,
xx_data
+
D3
,
checked_cell_data
+
D2
,
xx_data
+
D3
);
// O_t
act_gate
(
D
,
xx_data
+
D3
,
xx_data
+
D3
);
}
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cell_out_data
,
xx_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
// prev
...
...
@@ -343,14 +393,14 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
prev_c_data
=
cell_out_data
;
move_step
();
}
}
}
// for each step in batch
}
// for each batch
}
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
platform
::
CPUDeviceContext
;
INIT_BASE_INPUT_OUTPUT
if
(
x
->
lod
()[
0
].
size
()
==
2
)
{
if
(
x
->
lod
()[
0
].
size
()
==
2
)
{
// batch size == 1
SeqCompute
(
ctx
);
return
;
}
...
...
@@ -366,6 +416,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
const
T
*
bias_data
=
bias
->
data
<
T
>
();
const
T
*
wc_data
=
bias_data
+
D4
;
// w_ic, w_fc, w_oc
auto
place
=
ctx
.
GetPlace
();
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
place
);
T
*
batched_input_data
=
batched_input
->
mutable_data
<
T
>
(
place
);
...
...
@@ -374,6 +426,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
hidden_out
->
mutable_data
<
T
>
(
place
);
cell_out
->
mutable_data
<
T
>
(
place
);
// use local variable
framework
::
DDim
check_dims
({
3
,
D
});
Tensor
checked_cell
;
// w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
auto
checked_cell_data
=
checked_cell
.
mutable_data
<
T
>
(
check_dims
,
ctx
.
GetPlace
());
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
...
...
@@ -395,17 +453,27 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
reordered_h0
->
Resize
({
max_bs
,
D
});
reordered_c0
->
Resize
({
max_bs
,
D
});
T
*
prev_batch_h_data
=
nullptr
;
T
*
prev_batch_c_data
=
nullptr
;
T
*
cur_batch_in_data
=
batched_input_data
;
T
*
cur_batch_h_out_data
=
batched_h_out_data
;
T
*
cur_batch_c_out_data
=
batched_c_out_data
;
auto
move_step
=
[
&
](
int
bs
)
{
cur_batch_in_data
+=
bs
*
D4
;
cur_batch_c_out_data
+=
bs
*
D
;
cur_batch_h_out_data
+=
bs
*
D
;
};
int
tstart
=
0
;
T
*
prev_h_data
=
nullptr
;
T
*
prev_c_data
=
nullptr
;
if
(
h0
)
{
// reorder h0, c0
T
*
reordered_h0_data
=
reordered_h0
->
mutable_data
<
T
>
(
place
);
T
*
reordered_c0_data
=
reordered_c0
->
mutable_data
<
T
>
(
place
);
const
T
*
h0_data
=
h0
->
data
<
T
>
();
const
T
*
c0_data
=
c0
->
data
<
T
>
();
prev_h_data
=
reordered_h0_data
;
prev_c_data
=
reordered_c0_data
;
prev_
batch_
h_data
=
reordered_h0_data
;
prev_
batch_
c_data
=
reordered_c0_data
;
size_t
sz
=
sizeof
(
T
)
*
D
;
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
std
::
memcpy
(
reordered_h0_data
,
h0_data
+
seq_order
[
i
]
*
D
,
sz
);
...
...
@@ -414,71 +482,122 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
reordered_c0_data
+=
D
;
}
}
else
{
// compute without h0, c0
T
*
cur_in_data
=
batched_input_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
// W_ch, W_ih, W_fh, W_oh
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
// Compute with no H0/C0
T
*
cur_in_data
=
cur_batch_in_data
;
T
*
cur_c_out_data
=
cur_batch_c_out_data
;
T
*
cur_h_out_data
=
cur_batch_h_out_data
;
// If step == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros. Then W_h * H_t-1 can be skiped
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
// iterate each data in 1st batch
// ~C_t
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
// cell out= input*tilde
if
(
use_peepholes
)
{
// I_t, F_t
act_gate
(
D2
,
cur_in_data
+
D
,
cur_in_data
+
D
);
}
else
{
// I_t, F_t, O_t
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
}
// C_t = I_t * ~C_t
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_c_out_data
);
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cur_c_out_data
,
checked_cell_data
+
D2
);
blas
.
VADD
(
D
,
cur_in_data
+
D3
,
checked_cell_data
+
D2
,
cur_in_data
+
D3
);
// O_t
act_gate
(
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
}
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
//
add offset
//
move to next data in the same batch
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
}
// move to data for next timestep
prev_batch_h_data
=
cur_batch_h_out_data
;
prev_batch_c_data
=
cur_batch_c_out_data
;
move_step
(
max_bs
);
tstart
=
1
;
prev_h_data
=
batched_h_out_data
;
prev_c_data
=
batched_c_out_data
;
}
// Then start from next
const
auto
&
batch_starts
=
batched_lod
[
0
];
const
int
max_seq_len
=
batch_starts
.
size
()
-
1
;
const
int
offset
=
tstart
*
max_bs
*
D
;
batched_input_data
=
batched_input_data
+
offset
*
4
;
batched_h_out_data
=
batched_h_out_data
+
offset
;
batched_c_out_data
=
batched_c_out_data
+
offset
;
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
// + W_h * H_t-1
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
cur_bs
,
D4
,
D
,
static_cast
<
T
>
(
1
),
prev_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
batched_input_data
,
D4
);
T
*
cur_in_data
=
batched_input_data
;
T
*
cur_prev_c_data
=
prev_c_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
// W_ch, W_ih, W_fh, W_oh
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
prev_batch_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
cur_batch_in_data
,
D4
);
T
*
cur_in_data
=
cur_batch_in_data
;
T
*
cur_c_out_data
=
cur_batch_c_out_data
;
T
*
cur_h_out_data
=
cur_batch_h_out_data
;
T
*
prev_c_data
=
prev_batch_c_data
;
// NULL if no C0 in step0
T
*
prev_h_data
=
prev_batch_h_data
;
// NULL if no H0 in step0
auto
next_data_in_batch
=
[
&
]()
{
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
prev_c_data
=
prev_c_data
?
prev_c_data
+
D
:
nullptr
;
prev_h_data
=
prev_h_data
?
prev_h_data
+
D
:
nullptr
;
};
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
// iterate each data in same batch
// ~C_t
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
// a = forget * prev_cell
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_prev_c_data
,
cur_in_data
+
D2
);
// b = input * tilde
if
(
use_peepholes
)
{
// + W_ic|W_fc * C_t-1 for peephole connection
blas
.
VMUL
(
D
,
wc_data
,
prev_c_data
,
checked_cell_data
);
blas
.
VMUL
(
D
,
wc_data
+
D
,
prev_c_data
,
checked_cell_data
+
D
);
blas
.
VADD
(
D2
,
cur_in_data
+
D
,
checked_cell_data
,
cur_in_data
+
D
);
// I_t, F_t
act_gate
(
D2
,
cur_in_data
+
D
,
cur_in_data
+
D
);
}
else
{
// I_t, F_t, O_t
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
}
// F_t * C_t-1
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
prev_c_data
,
cur_in_data
+
D2
);
// I_t * ~C_t
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_in_data
+
D
);
//
cell out= a+b
//
C_t = F_t * C_t-1 + I_t * ~C_t
blas
.
VADD
(
D
,
cur_in_data
+
D
,
cur_in_data
+
D2
,
cur_c_out_data
);
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cur_c_out_data
,
checked_cell_data
+
D2
);
blas
.
VADD
(
D
,
cur_in_data
+
D3
,
checked_cell_data
+
D2
,
cur_in_data
+
D3
);
// O_t
act_gate
(
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
}
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
cur_in_data
+=
D4
;
cur_prev_c_data
+=
D
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
// move to next data in same batch
next_data_in_batch
();
}
prev_c_data
=
batched_c_out_data
;
prev_h_data
=
batched_h_out_data
;
batched_c_out_data
=
cur_c_out_data
;
batched_h_out_data
=
cur_h_out_data
;
batched_input_data
=
cur_in_data
;
// move to data for next timestep
prev_batch_h_data
=
cur_batch_h_out_data
;
prev_batch_c_data
=
cur_batch_c_out_data
;
move_step
(
cur_bs
);
}
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
5f586e22
...
...
@@ -3546,11 +3546,6 @@ def topk(input, k, name=None):
top5_values, top5_indices = layers.topk(input, k=5)
"""
shape
=
input
.
shape
if
k
<
1
or
k
>=
shape
[
-
1
]:
raise
ValueError
(
"k must be greater than 0 and less than %d."
%
(
shape
[
-
1
]))
helper
=
LayerHelper
(
"top_k"
,
**
locals
())
values
=
helper
.
create_tmp_variable
(
dtype
=
input
.
dtype
)
indices
=
helper
.
create_tmp_variable
(
dtype
=
"int64"
)
...
...
python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
浏览文件 @
5f586e22
...
...
@@ -58,6 +58,7 @@ class TestFusionLSTMOp(OpTest):
self
.
act_cell
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
use_peepholes
=
False
self
.
use_seq
=
False
self
.
set_conf
()
T
=
sum
(
self
.
lod
[
0
])
...
...
@@ -107,6 +108,7 @@ class TestFusionLSTMOp(OpTest):
}
self
.
attrs
=
{
'use_peepholes'
:
self
.
use_peepholes
,
'use_seq'
:
self
.
use_seq
,
'is_reverse'
:
self
.
is_reverse
,
'gate_activation'
:
self
.
act_gate
,
'cell_activation'
:
self
.
act_cell
,
...
...
@@ -159,5 +161,68 @@ class TestFusionLSTMOpBS1(TestFusionLSTMOp):
self
.
D
=
16
class
TestFusionLSTMOpPeepholes
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_peepholes
=
True
class
TestFusionLSTMOpPeepholesInit
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_peepholes
=
True
self
.
has_initial_state
=
True
class
TestFusionLSTMOpPeepholesReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_peepholes
=
True
self
.
is_reverse
=
True
class
TestFusionLSTMOpPoopholesBS1
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_peepholes
=
True
self
.
lod
=
[[
3
]]
self
.
D
=
16
class
TestFusionLSTMOpSeqInit
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
has_initial_state
=
True
class
TestFusionLSTMOpSeqReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
is_reverse
=
True
class
TestFusionLSTMOpSeqInitReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
has_initial_state
=
True
self
.
is_reverse
=
True
class
TestFusionLSTMOpSeqPeepholes
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
use_peepholes
=
True
class
TestFusionLSTMOpSeqPeepholesInit
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
use_peepholes
=
True
self
.
has_initial_state
=
True
class
TestFusionLSTMOpSeqPeepholesReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
use_peepholes
=
True
self
.
is_reverse
=
True
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录