Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d1fbf50b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d1fbf50b
编写于
10月 19, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add unit testing for forwad implementation.
上级
2a8dbd13
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
233 addition
and
186 deletion
+233
-186
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+1
-1
paddle/operators/lstm_op.cc
paddle/operators/lstm_op.cc
+4
-4
paddle/operators/lstm_op.h
paddle/operators/lstm_op.h
+13
-6
paddle/operators/math/CMakeLists.txt
paddle/operators/math/CMakeLists.txt
+4
-2
paddle/operators/math/detail/hl_avx_functions.cc
paddle/operators/math/detail/hl_avx_functions.cc
+3
-1
paddle/operators/math/detail/hl_cpu_functions.cc
paddle/operators/math/detail/hl_cpu_functions.cc
+89
-0
paddle/operators/math/detail/hl_functions.h
paddle/operators/math/detail/hl_functions.h
+17
-72
paddle/operators/math/detail/lstm_gpu_kernel.h
paddle/operators/math/detail/lstm_gpu_kernel.h
+22
-28
paddle/operators/math/lstm_compute.cc
paddle/operators/math/lstm_compute.cc
+2
-0
paddle/operators/math/lstm_compute.cu
paddle/operators/math/lstm_compute.cu
+8
-34
paddle/operators/math/lstm_compute.h
paddle/operators/math/lstm_compute.h
+1
-1
paddle/operators/math/sequence2batch.h
paddle/operators/math/sequence2batch.h
+13
-10
python/paddle/v2/framework/tests/test_lstm_op.py
python/paddle/v2/framework/tests/test_lstm_op.py
+56
-27
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
d1fbf50b
...
...
@@ -127,7 +127,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library
(
sum_op DEPS net_op
)
op_library
(
pool_op DEPS pooling
)
op_library
(
pool_with_index_op DEPS pooling
)
op_library
(
lstm_op DEPS sequence2batch
)
op_library
(
lstm_op DEPS sequence2batch
lstm_compute math_function
)
list
(
REMOVE_ITEM GENERAL_OPS
${
DEPS_OPS
}
)
foreach
(
src
${
GENERAL_OPS
}
)
...
...
paddle/operators/lstm_op.cc
浏览文件 @
d1fbf50b
...
...
@@ -44,7 +44,7 @@ class LSTMOp : public framework::OperatorWithKernel {
"should be the same."
);
}
int
frame_size
=
x_dims
[
1
];
int
frame_size
=
x_dims
[
1
]
/
4
;
auto
w_dims
=
ctx
->
GetInputDim
(
"Weight"
);
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
2
,
"The rank of Input(Weight) should be 2."
);
...
...
@@ -71,9 +71,9 @@ class LSTMOp : public framework::OperatorWithKernel {
"4 * %d if diable peepholes connection"
,
frame_size
);
}
ctx
->
SetOutputDim
(
"Hidden"
,
x_dims
);
ctx
->
SetOutputDim
(
"Cell"
,
x_dims
);
ctx
->
SetOutputDim
(
"Batch"
,
x_dims
);
ctx
->
SetOutputDim
(
"Hidden"
,
{
x_dims
[
0
],
frame_size
}
);
ctx
->
SetOutputDim
(
"Cell"
,
{
x_dims
[
0
],
frame_size
}
);
ctx
->
SetOutputDim
(
"Batch
Gate
"
,
x_dims
);
ctx
->
ShareLoD
(
"Input"
,
"Hidden"
);
ctx
->
ShareLoD
(
"Input"
,
"Cell"
);
}
...
...
paddle/operators/lstm_op.h
浏览文件 @
d1fbf50b
...
...
@@ -52,9 +52,14 @@ class LSTMKernel : public framework::OpKernel<T> {
to_batch
(
ctx
.
device_context
(),
*
input
,
*
batch_gate
,
is_reverse
);
auto
in_dims
=
input
->
dims
();
int
frame_size
=
in_dims
[
1
];
int
frame_size
=
in_dims
[
1
]
/
4
;
framework
::
DDim
dims
({
in_dims
[
0
],
frame_size
});
if
(
bias
)
{
// framework::Tensor cpu_t;
// cpu_t.mutable_data<T>(in_dims, platform::CPUPlace());
// cpu_t.CopyFrom<T>(*batch_gate, platform::CPUPlace(),
// ctx.device_context());
Eigen
::
array
<
int
,
2
>
extents
({{
1
,
4
*
frame_size
}});
Eigen
::
array
<
int
,
2
>
offsets
({{
0
,
0
}});
auto
b
=
EigenMatrix
<
T
>::
From
(
*
bias
);
...
...
@@ -76,15 +81,14 @@ class LSTMKernel : public framework::OpKernel<T> {
lstm_value
.
prevStateValue
=
nullptr
;
framework
::
LoDTensor
batch_out
;
batch_out
.
mutable_data
<
T
>
(
in_
dims
,
ctx
.
GetPlace
());
batch_out
.
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
framework
::
LoDTensor
batch_cell
;
batch_cell
.
mutable_data
<
T
>
(
in_
dims
,
ctx
.
GetPlace
());
batch_cell
.
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
framework
::
LoDTensor
batch_cell_pre_act
;
batch_cell_pre_act
.
mutable_data
<
T
>
(
in_
dims
,
ctx
.
GetPlace
());
batch_cell_pre_act
.
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
auto
batch_lod
=
batch_gate
->
lod
()[
0
];
int
num_batch
=
batch_lod
.
size
()
-
1
;
auto
gate_act
=
ctx
.
Attr
<
std
::
string
>
(
"gateActivation"
);
auto
cell_act
=
ctx
.
Attr
<
std
::
string
>
(
"cellActivation"
);
auto
cand_act
=
ctx
.
Attr
<
std
::
string
>
(
"candidateActivation"
);
...
...
@@ -125,9 +129,12 @@ class LSTMKernel : public framework::OpKernel<T> {
// restore the output hidden in LoDTensor from the batch hidden
to_seq
(
ctx
.
device_context
(),
batch_out
,
*
hidden_out
);
batch_
out
.
set_lod
(
batch_gate
->
lod
());
batch_
cell
.
set_lod
(
batch_gate
->
lod
());
// restore the output cell state in LoDTensor from the batch cell
to_seq
(
ctx
.
device_context
(),
batch_cell
,
*
cell_out
);
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
batch_gate
);
t
.
device
(
ctx
.
GetEigenDevice
<
Place
>
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
}
};
...
...
paddle/operators/math/CMakeLists.txt
浏览文件 @
d1fbf50b
add_subdirectory
(
detail
)
if
(
WITH_GPU
)
nv_library
(
math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator
)
nv_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
...
...
@@ -6,7 +8,7 @@ if(WITH_GPU)
nv_library
(
pooling SRCS pooling.cc pooling.cu DEPS device_context
)
nv_library
(
vol2col SRCS vol2col.cc vol2col.cu DEPS device_context
)
nv_library
(
sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context
)
nv_library
(
lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context
)
nv_library
(
lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context
activation_functions
)
else
()
cc_library
(
math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator
)
cc_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
...
...
@@ -14,7 +16,7 @@ else()
cc_library
(
cross_entropy SRCS cross_entropy.cc DEPS operator
)
cc_library
(
pooling SRCS pooling.cc DEPS device_context
)
cc_library
(
sequence2batch SRCS sequence2batch.cc DEPS device_context
)
cc_library
(
lstm_compute SRCS lstm_compute.cc DEPS device_context
)
cc_library
(
lstm_compute SRCS lstm_compute.cc DEPS device_context
activation_functions
)
endif
()
cc_test
(
im2col_test SRCS im2col_test.cc DEPS math_function tensor
)
...
...
paddle/operators/math/detail/hl_avx_functions.cc
浏览文件 @
d1fbf50b
...
...
@@ -14,10 +14,12 @@ limitations under the License. */
#include <immintrin.h>
#include "hl_functions.h"
// TODO(qingqing) refine this dependence
#include "paddle/cuda/src/avx_mathfun.h"
namespace
hppl
{
extern
__m256
exp
(
__m256
a
);
__m256
exp
(
__m256
a
)
{
return
exp256_ps
(
a
);
}
__m256
relu
(
const
__m256
a
)
{
__m256
tmp
=
_mm256_set1_ps
(
0.0
f
);
...
...
paddle/operators/math/detail/hl_cpu_functions.cc
0 → 100644
浏览文件 @
d1fbf50b
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include "hl_functions.h"
namespace
hppl
{
namespace
typef
{
float
relu
(
const
float
a
)
{
return
a
>
static_cast
<
float
>
(
0.0
)
?
a
:
static_cast
<
float
>
(
0.0
);
}
float
sigmoid
(
const
float
a
)
{
const
float
min
=
SIGMOID_THRESHOLD_MIN
;
const
float
max
=
SIGMOID_THRESHOLD_MAX
;
float
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
static_cast
<
float
>
(
1.0
)
/
(
static_cast
<
float
>
(
1.0
)
+
exp
(
-
tmp
));
}
float
tanh
(
const
float
a
)
{
float
tmp
=
-
2.0
*
a
;
tmp
=
(
tmp
>
EXP_MAX_INPUT
)
?
EXP_MAX_INPUT
:
tmp
;
return
(
2.0
/
(
1.0
+
exp
(
tmp
)))
-
1.0
;
}
float
linear
(
const
float
a
)
{
return
a
;
}
float
relu
(
const
float
a
,
const
float
b
)
{
return
a
*
(
b
>
0.0
?
1.0
:
0.0
);
}
float
sigmoid
(
const
float
a
,
const
float
b
)
{
return
a
*
b
*
(
static_cast
<
float
>
(
1
)
-
b
);
}
float
tanh
(
const
float
a
,
const
float
b
)
{
return
a
*
(
static_cast
<
float
>
(
1
)
-
b
*
b
);
}
float
linear
(
const
float
a
,
const
float
b
)
{
return
a
;
}
}
// namespace typef
namespace
typed
{
double
relu
(
const
double
a
)
{
return
a
>
static_cast
<
double
>
(
0.0
)
?
a
:
static_cast
<
double
>
(
0.0
);
}
double
sigmoid
(
const
double
a
)
{
const
double
min
=
SIGMOID_THRESHOLD_MIN
;
const
double
max
=
SIGMOID_THRESHOLD_MAX
;
double
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
static_cast
<
double
>
(
1.0
)
/
(
static_cast
<
double
>
(
1.0
)
+
exp
(
-
tmp
));
}
double
tanh
(
const
double
a
)
{
double
tmp
=
-
2.0
*
a
;
tmp
=
(
tmp
>
EXP_MAX_INPUT
)
?
EXP_MAX_INPUT
:
tmp
;
return
(
2.0
/
(
1.0
+
exp
(
tmp
)))
-
1.0
;
}
double
linear
(
const
double
a
)
{
return
a
;
}
double
relu
(
const
double
a
,
const
double
b
)
{
return
a
*
(
b
>
0.0
?
1.0
:
0.0
);
}
double
sigmoid
(
const
double
a
,
const
double
b
)
{
return
a
*
b
*
(
static_cast
<
double
>
(
1
)
-
b
);
}
double
tanh
(
const
double
a
,
const
double
b
)
{
return
a
*
(
static_cast
<
double
>
(
1
)
-
b
*
b
);
}
double
linear
(
const
double
a
,
const
double
b
)
{
return
a
;
}
}
// namespace typed
}
// namespace hppl
paddle/operators/math/detail/hl_functions.h
浏览文件 @
d1fbf50b
...
...
@@ -34,83 +34,28 @@ limitations under the License. */
#ifndef __NVCC__
namespace
hppl
{
namespace
typef
{
/*
* forward activation
*/
float
relu
(
const
float
a
)
{
return
a
>
static_cast
<
float
>
(
0.0
)
?
a
:
static_cast
<
float
>
(
0.0
);
}
float
sigmoid
(
const
float
a
)
{
const
float
min
=
SIGMOID_THRESHOLD_MIN
;
const
float
max
=
SIGMOID_THRESHOLD_MAX
;
float
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
static_cast
<
float
>
(
1.0
)
/
(
static_cast
<
float
>
(
1.0
)
+
exp
(
-
tmp
));
}
float
tanh
(
const
float
a
)
{
float
tmp
=
-
2.0
*
a
;
tmp
=
(
tmp
>
EXP_MAX_INPUT
)
?
EXP_MAX_INPUT
:
tmp
;
return
(
2.0
/
(
1.0
+
exp
(
tmp
)))
-
1.0
;
}
float
linear
(
const
float
a
)
{
return
a
;
}
/*
* backward activation
*/
float
relu
(
const
float
a
,
const
float
b
)
{
return
a
*
(
b
>
0.0
?
1.0
:
0.0
);
}
float
relu
(
const
float
a
);
float
sigmoid
(
const
float
a
);
float
tanh
(
const
float
a
);
float
linear
(
const
float
a
);
float
sigmoid
(
const
float
a
,
const
float
b
)
{
return
a
*
b
*
(
static_cast
<
float
>
(
1
)
-
b
);
}
float
relu
(
const
float
a
,
const
float
b
);
float
sigmoid
(
const
float
a
,
const
float
b
);
float
tanh
(
const
float
a
,
const
float
b
);
float
linear
(
const
float
a
,
const
float
b
);
float
tanh
(
const
float
a
,
const
float
b
)
{
return
a
*
(
static_cast
<
float
>
(
1
)
-
b
*
b
);
}
float
linear
(
const
float
a
,
const
float
b
)
{
return
a
;
}
}
// namespace typef
namespace
typed
{
/*
* forward activation
*/
double
relu
(
const
double
a
)
{
return
a
>
static_cast
<
double
>
(
0.0
)
?
a
:
static_cast
<
double
>
(
0.0
);
}
double
sigmoid
(
const
double
a
)
{
const
double
min
=
SIGMOID_THRESHOLD_MIN
;
const
double
max
=
SIGMOID_THRESHOLD_MAX
;
double
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
static_cast
<
double
>
(
1.0
)
/
(
static_cast
<
double
>
(
1.0
)
+
exp
(
-
tmp
));
}
double
tanh
(
const
double
a
)
{
double
tmp
=
-
2.0
*
a
;
tmp
=
(
tmp
>
EXP_MAX_INPUT
)
?
EXP_MAX_INPUT
:
tmp
;
return
(
2.0
/
(
1.0
+
exp
(
tmp
)))
-
1.0
;
}
double
linear
(
const
double
a
)
{
return
a
;
}
/*
* backward activation
*/
double
relu
(
const
double
a
,
const
double
b
)
{
return
a
*
(
b
>
0.0
?
1.0
:
0.0
);
}
double
sigmoid
(
const
double
a
,
const
double
b
)
{
return
a
*
b
*
(
static_cast
<
double
>
(
1
)
-
b
);
}
double
tanh
(
const
double
a
,
const
double
b
)
{
return
a
*
(
static_cast
<
double
>
(
1
)
-
b
*
b
);
}
double
linear
(
const
double
a
,
const
double
b
)
{
return
a
;
}
double
relu
(
const
double
a
);
double
sigmoid
(
const
double
a
);
double
tanh
(
const
double
a
);
double
linear
(
const
double
a
);
double
relu
(
const
double
a
,
const
double
b
);
double
sigmoid
(
const
double
a
,
const
double
b
);
double
tanh
(
const
double
a
,
const
double
b
);
double
linear
(
const
double
a
,
const
double
b
);
}
// namespace typed
}
// namespace hppl
...
...
paddle/operators/math/detail/lstm_gpu_kernel.h
浏览文件 @
d1fbf50b
...
...
@@ -19,6 +19,8 @@ limitations under the License. */
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/device_context.h"
#include <glog/logging.h>
namespace
paddle
{
namespace
operators
{
namespace
math
{
...
...
@@ -29,11 +31,10 @@ namespace detail {
* grid(frameBlocks, batchBlocks)
*/
template
<
class
T
,
class
Op
,
bool
isBatch
>
__global__
void
KeLstmForward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frameSize
,
int
batchSize
,
typename
hppl
::
ForwardActType
<
T
>::
type
active_node
,
typename
hppl
::
ForwardActType
<
T
>::
type
active_gate
,
typename
hppl
::
ForwardActType
<
T
>::
type
active_state
)
{
__global__
void
KeLstmForward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
const
int
frameIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frameIdx
>=
frameSize
)
return
;
...
...
@@ -69,8 +70,10 @@ __global__ void KeLstmForward(
rPrevState
=
value
.
prevStateValue
[
frameIdx
];
}
hppl
::
gpu
::
ForwardAct
<
T
>
act
;
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rPrevState
,
rState
,
rStateAtv
,
rOut
,
rCheckI
,
rCheckF
,
rCheckO
,
active_node
,
active_gate
,
active_state
);
rOut
,
rCheckI
,
rCheckF
,
rCheckO
,
act
(
active_node
),
act
(
active_gate
),
act
(
active_state
));
value
.
gateValue
[
frameIdx
]
=
rValueIn
;
value
.
gateValue
[
frameIdx
+
frameSize
]
=
rValueIg
;
...
...
@@ -87,11 +90,11 @@ __global__ void KeLstmForward(
* grid(frameBlocks, batchBlocks)
*/
template
<
class
T
,
class
Op
,
bool
isBatch
>
__global__
void
KeLstmBackward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frameSize
,
int
batchSize
,
typename
hppl
::
BackwardActType
<
T
>::
type
active_node
,
typename
hppl
::
BackwardActType
<
T
>::
type
active_gate
,
typename
hppl
::
BackwardActType
<
T
>::
type
active_state
)
{
__global__
void
KeLstmBackward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
const
int
frameIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frameIdx
>=
frameSize
)
return
;
...
...
@@ -142,10 +145,11 @@ __global__ void KeLstmBackward(
rPrevState
=
value
.
prevStateValue
[
frameIdx
];
}
hppl
::
gpu
::
BackwardAct
<
T
>
act
;
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rGradIn
,
rGradIg
,
rGradFg
,
rGradOg
,
rPrevState
,
rPrevStateGrad
,
rState
,
rStateGrad
,
rStateAtv
,
rOutputGrad
,
rCheckI
,
rCheckF
,
rCheckO
,
rCheckIGrad
,
rCheckFGrad
,
rCheckOGrad
,
act
ive_node
,
active_gate
,
active_state
);
act
(
active_node
),
act
(
active_gate
),
act
(
active_state
)
);
grad
.
gateGrad
[
frameIdx
]
=
rGradIn
;
grad
.
gateGrad
[
frameIdx
+
frameSize
]
=
rGradIg
;
...
...
@@ -196,22 +200,16 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
grid
=
dim3
((
frameSize
+
32
-
1
)
/
32
,
(
batchSize
+
32
-
1
)
/
32
);
}
using
type
=
typename
hppl
::
ForwardActType
<
T
>::
type
;
hppl
::
gpu
::
ForwardAct
<
T
>
act
;
type
act_node
=
act
(
active_node
);
type
act_gate
=
act
(
active_gate
);
type
act_state
=
act
(
active_state
);
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
if
(
batchSize
==
1
)
{
KeLstmForward
<
T
,
Op
,
/* isBatch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
frameSize
,
batchSize
,
act
_node
,
act_gate
,
act_st
ate
);
op
,
value
,
frameSize
,
batchSize
,
act
ive_node
,
active_gate
,
active_g
ate
);
}
else
{
KeLstmForward
<
T
,
Op
,
/* isBatch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
frameSize
,
batchSize
,
act
_node
,
act_gate
,
act_st
ate
);
op
,
value
,
frameSize
,
batchSize
,
act
ive_node
,
active_gate
,
active_g
ate
);
}
}
...
...
@@ -235,22 +233,18 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
grid
=
dim3
((
frameSize
+
32
-
1
)
/
32
,
(
batchSize
+
32
-
1
)
/
32
);
}
using
type
=
typename
hppl
::
BackwardActType
<
T
>::
type
;
hppl
::
gpu
::
BackwardAct
<
T
>
act
;
type
act_node
=
act
(
active_node
);
type
act_gate
=
act
(
active_gate
);
type
act_state
=
act
(
active_state
);
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
if
(
batchSize
==
1
)
{
KeLstmBackward
<
T
,
Op
,
/* isBatch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
grad
,
frameSize
,
batchSize
,
act_node
,
act_gate
,
act_state
);
op
,
value
,
grad
,
frameSize
,
batchSize
,
active_node
,
active_gate
,
active_state
);
}
else
{
KeLstmBackward
<
T
,
Op
,
/* isBatch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
grad
,
frameSize
,
batchSize
,
act_node
,
act_gate
,
act_state
);
op
,
value
,
grad
,
frameSize
,
batchSize
,
active_node
,
active_gate
,
active_state
);
}
}
...
...
paddle/operators/math/lstm_compute.cc
浏览文件 @
d1fbf50b
...
...
@@ -72,6 +72,8 @@ struct LstmUnitGradFunctor<platform::CPUPlace, T> {
};
template
class
LstmUnitFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
LstmUnitFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
LstmUnitGradFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
LstmUnitGradFunctor
<
platform
::
CPUPlace
,
double
>;
}
// namespace math
...
...
paddle/operators/math/lstm_compute.cu
浏览文件 @
d1fbf50b
...
...
@@ -26,18 +26,9 @@ struct LstmUnitFunctor<platform::GPUPlace, T> {
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
detail
::
gpu_lstm_forward
(
context
,
detail
::
forward
::
lstm
<
T
>
(),
value
,
detail
::
gpu_lstm_forward
<
T
>
(
context
,
detail
::
forward
::
lstm
<
T
>
(),
value
,
frame_size
,
batch_size
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
value
.
gateValue
+=
frame_size
*
4
;
value
.
stateValue
+=
frame_size
;
value
.
stateActiveValue
+=
frame_size
;
value
.
outputValue
+=
frame_size
;
if
(
value
.
prevStateValue
)
{
value
.
prevStateValue
+=
frame_size
;
}
}
}
};
...
...
@@ -47,32 +38,15 @@ struct LstmUnitGradFunctor<platform::GPUPlace, T> {
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
detail
::
gpu_lstm_backward
(
context
,
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
frame_size
,
batch_size
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
value
.
gateValue
+=
frame_size
*
4
;
value
.
stateValue
+=
frame_size
;
value
.
stateActiveValue
+=
frame_size
;
value
.
outputValue
+=
frame_size
;
if
(
value
.
prevStateValue
)
{
value
.
prevStateValue
+=
frame_size
;
}
grad
.
gateGrad
+=
frame_size
*
4
;
grad
.
stateGrad
+=
frame_size
;
grad
.
stateActiveGrad
+=
frame_size
;
grad
.
outputGrad
+=
frame_size
;
if
(
grad
.
prevStateGrad
)
{
grad
.
prevStateGrad
+=
frame_size
;
}
}
detail
::
gpu_lstm_backward
(
context
,
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
frame_size
,
batch_size
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
}
};
template
class
LstmUnitFunctor
<
platform
::
GPUPlace
,
float
>;
template
class
LstmUnitFunctor
<
platform
::
GPUPlace
,
double
>;
template
class
LstmUnitGradFunctor
<
platform
::
GPUPlace
,
float
>;
template
class
LstmUnitGradFunctor
<
platform
::
GPUPlace
,
double
>;
}
// namespace math
...
...
paddle/operators/math/lstm_compute.h
浏览文件 @
d1fbf50b
...
...
@@ -53,7 +53,7 @@ struct LstmMetaGrad {
T
*
checkOgGrad
;
};
activation_mode_t
ActiveType
(
const
std
::
string
&
type
)
{
inline
activation_mode_t
ActiveType
(
const
std
::
string
&
type
)
{
if
(
type
==
"sigmoid"
)
{
return
HL_ACTIVATION_SIGMOID
;
}
else
if
(
type
==
"relu"
)
{
...
...
paddle/operators/math/sequence2batch.h
浏览文件 @
d1fbf50b
...
...
@@ -59,7 +59,7 @@ class LoDTensor2BatchFunctor {
};
std
::
vector
<
SeqInfo
>
seq_info
;
for
(
size_t
seq_id
=
0
;
seq_id
<
lod
.
size
();
++
seq_id
)
{
for
(
size_t
seq_id
=
0
;
seq_id
<
lod
.
size
()
-
1
;
++
seq_id
)
{
int
length
=
lod
[
seq_id
+
1
]
-
lod
[
seq_id
];
seq_info
.
emplace_back
(
lod
[
seq_id
],
length
,
seq_id
);
}
...
...
@@ -83,10 +83,11 @@ class LoDTensor2BatchFunctor {
// The batch number represents batch size after rearranging the
// input LodTensor. It is also the maximum length of input sequence.
auto
batch_lods
=
batch
.
lod
();
if
(
batch_lods
.
size
()
==
0
)
{
batch_lods
.
resize
(
2
);
}
paddle
::
framework
::
LoD
batch_lods
;
batch_lods
.
push_back
(
std
::
vector
<
size_t
>
{
0
});
batch_lods
.
push_back
(
std
::
vector
<
size_t
>
{
0
});
// batch_lods[0] is the start positions for batch LoDTensor
int
num_batch
=
(
size_t
)
seq_info
[
0
].
length
;
batch_lods
[
0
].
resize
(
num_batch
+
1
);
...
...
@@ -115,6 +116,7 @@ class LoDTensor2BatchFunctor {
}
batch_starts
[
n
+
1
]
=
batch_id
;
}
batch
.
set_lod
(
batch_lods
);
CopyMatrixRowsFunctor
<
Place
,
T
>
to_batch
;
to_batch
(
context
,
lod_tensor
,
seq2batch_idx
,
batch
,
true
);
...
...
@@ -130,12 +132,13 @@ class Batch2LoDTensorFunctor {
auto
in_lod
=
batch
.
lod
();
PADDLE_ENFORCE_EQ
(
in_lod
.
size
(),
2UL
,
"The LoD size of input `batch` should be 2."
);
auto
out_lod
=
lod_tensor
.
lod
();
PADDLE_ENFORCE_EQ
(
out_lod
[
0
][
0
],
out_lod
[
1
].
size
());
PADDLE_ENFORCE_EQ
(
out_lod
[
0
][
0
],
lod_tensor
.
dims
()[
0
]);
PADDLE_ENFORCE_EQ
(
out_lod
[
0
][
0
],
batch
.
dims
()[
0
]);
auto
out_lod
=
lod_tensor
.
lod
()[
0
];
auto
num
=
out_lod
[
out_lod
.
size
()
-
1
];
PADDLE_ENFORCE_EQ
(
num
,
lod_tensor
.
dims
()[
0
]);
PADDLE_ENFORCE_EQ
(
num
,
in_lod
[
1
].
size
());
PADDLE_ENFORCE_EQ
(
num
,
batch
.
dims
()[
0
]);
CopyMatrixRowsFunctor
<
Place
,
T
>
to_seq
;
size_t
*
index
=
out
_lod
[
1
].
data
();
size_t
*
index
=
in
_lod
[
1
].
data
();
to_seq
(
context
,
batch
,
index
,
lod_tensor
,
false
);
}
};
...
...
python/paddle/v2/framework/tests/test_lstm_op.py
浏览文件 @
d1fbf50b
...
...
@@ -2,17 +2,26 @@ import unittest
import
numpy
as
np
from
op_test
import
OpTest
SIGMOID_THRESHOLD_MIN
=
-
40.0
SIGMOID_THRESHOLD_MAX
=
13.0
EXP_MAX_INPUT
=
40.0
def
identity
(
x
):
return
x
def
sigmoid
(
x
):
return
1.
/
(
1.
+
np
.
exp
(
-
x
))
y
=
np
.
copy
(
x
)
y
[
x
<
SIGMOID_THRESHOLD_MIN
]
=
SIGMOID_THRESHOLD_MIN
y
[
x
>
SIGMOID_THRESHOLD_MAX
]
=
SIGMOID_THRESHOLD_MAX
return
1.
/
(
1.
+
np
.
exp
(
-
y
))
def
tanh
(
x
):
return
2.
*
sigmoid
(
2.
*
x
)
-
1.
y
=
-
2.
*
x
y
[
y
>
EXP_MAX_INPUT
]
=
EXP_MAX_INPUT
return
(
2.
/
(
1.
+
np
.
exp
(
y
)))
-
1.
def
relu
(
x
):
...
...
@@ -35,7 +44,7 @@ def lstm(
g
=
np
.
dot
(
h_pre
,
w_h
)
# 1 x 4D
g
=
g
+
x
g
=
np
.
reshape
(
g
,
(
1
,
g
.
size
))
c
,
g_i
,
g_f
,
g_o
=
np
.
split
(
g
,
4
,
axis
=
1
)
c
_tmp
,
g_i
,
g_f
,
g_o
=
np
.
split
(
g
,
4
,
axis
=
1
)
if
w_c
is
None
:
g_i
=
gate_act
(
g_i
)
# 1 x D
g_f
=
gate_act
(
g_f
)
# 1 x D
...
...
@@ -43,7 +52,7 @@ def lstm(
w_ic
,
w_fc
,
w_oc
=
np
.
split
(
w_c
,
3
,
axis
=
1
)
g_i
=
gate_act
(
g_i
+
w_ic
*
c_pre
)
# 1 x D
g_f
=
gate_act
(
g_f
+
w_fc
*
c_pre
)
# 1 x D
c
=
g_f
*
c_pre
+
g_i
*
cand_act
(
c
)
# 1 x D
c
=
g_f
*
c_pre
+
g_i
*
cand_act
(
c
_tmp
)
# 1 x D
if
w_c
is
None
:
g_o
=
gate_act
(
g_o
)
# 1 x D
...
...
@@ -51,12 +60,14 @@ def lstm(
_
,
_
,
w_oc
=
np
.
split
(
w_c
,
3
,
axis
=
1
)
g_o
=
gate_act
(
g_o
+
w_oc
*
c
)
# 1 x D
h
=
g_o
*
cell_act
(
c
)
return
h
,
c
bg
=
np
.
concatenate
((
cand_act
(
c_tmp
),
g_i
,
g_f
,
g_o
),
axis
=
1
)
return
h
,
c
,
bg
offset
=
lod
[
0
]
batch_size
=
len
(
offset
)
-
1
hidden
=
[]
cell
=
[]
gate
=
[]
if
w_b
is
not
None
:
input
=
input
+
np
.
tile
(
w_b
,
(
offset
[
-
1
],
1
))
for
i
in
range
(
batch_size
):
...
...
@@ -64,44 +75,62 @@ def lstm(
seq_len
=
offset
[
i
+
1
]
-
offset
[
i
]
x
=
input
[
offset
[
i
]:
offset
[
i
+
1
],
:]
h_pre
=
h0
[
i
]
# 1 x D
c_pre
=
h
0
[
i
]
# 1 x D
c_pre
=
c
0
[
i
]
# 1 x D
for
j
in
range
(
seq_len
):
# compute one step
h_pre
,
c_pre
=
_step
(
x
[
j
],
w_h
,
w_c
,
h_pre
,
c_pre
,
gate_act
,
h_pre
,
c_pre
,
g_pre
=
_step
(
x
[
j
],
w_h
,
w_c
,
h_pre
,
c_pre
,
gate_act
,
cell_act
,
cand_act
)
hidden
.
append
(
h_pre
.
flatten
())
cell
.
append
(
c_pre
.
flatten
())
gate
.
append
(
g_pre
.
flatten
())
hidden
=
np
.
array
(
hidden
).
astype
(
"float64"
)
cell
=
np
.
array
(
cell
).
astype
(
"float64"
)
gate
=
np
.
array
(
gate
).
astype
(
"float64"
)
assert
gate
.
shape
==
input
.
shape
assert
hidden
.
shape
==
(
input
.
shape
[
0
],
input
.
shape
[
1
]
/
4
)
assert
cell
.
shape
==
(
input
.
shape
[
0
],
input
.
shape
[
1
]
/
4
)
return
hidden
,
cell
return
hidden
,
cell
,
gate
class
LstmUnitTest
(
OpTest
):
def
set_data
(
self
):
lod
=
[[
0
,
2
,
6
,
9
]]
shape
=
(
9
,
64
)
x
=
np
.
random
.
normal
(
size
=
(
9
,
4
*
64
)).
astype
(
"float64"
)
h0
=
np
.
random
.
normal
(
size
=
(
4
,
64
)).
astype
(
"float64"
)
c0
=
np
.
random
.
normal
(
size
=
(
4
,
64
)).
astype
(
"float64"
)
w
=
np
.
random
.
normal
(
size
=
(
64
,
4
*
64
)).
astype
(
"float64"
)
b
=
np
.
random
.
normal
(
size
=
(
1
,
7
*
64
)).
astype
(
"float64"
)
w_b
=
b
[:,
4
*
64
]
w_c
=
b
[:,
4
*
64
:]
h
,
c
=
lstm
(
x
,
lod
,
h0
,
c0
,
w
,
w_b
,
w_c
,
False
,
sigmoid
,
tanh
,
tanh
)
self
.
inputs
=
{
'Input'
:
x
,
'H0'
:
h0
,
'C0'
:
c0
,
'Weight'
:
w
,
'Bias'
:
b
}
self
.
inputs
=
{
'Hidden'
:
h
,
'Cell'
:
c
}
D
=
4
#lod = [[0, 2, 6, 9]]
lod
=
[[
0
,
1
]]
shape
=
(
1
,
D
)
x
=
np
.
random
.
normal
(
size
=
(
1
,
4
*
D
)).
astype
(
"float64"
)
h0
=
np
.
zeros
((
4
,
D
)).
astype
(
"float64"
)
c0
=
np
.
zeros
((
4
,
D
)).
astype
(
"float64"
)
w
=
np
.
random
.
normal
(
size
=
(
D
,
4
*
D
)).
astype
(
"float64"
)
b
=
np
.
random
.
normal
(
size
=
(
1
,
7
*
D
)).
astype
(
"float64"
)
w_b
=
b
[:,
0
:
4
*
D
]
w_c
=
b
[:,
4
*
D
:]
#h, c, g = lstm(x, lod, h0, c0, w, w_b, w_c, False, sigmoid, tanh, tanh)
h
,
c
,
g
=
lstm
(
x
,
lod
,
h0
,
c0
,
w
,
w_b
,
w_c
,
False
,
identity
,
identity
,
identity
)
g_sort
=
np
.
zeros_like
(
x
)
#idx = [2,6,0,3,7,1,4,8,5]
#for i, j in enumerate(idx):
# g_sort[i, :] = g[j, :]
self
.
inputs
=
{
'Input'
:
(
x
,
lod
),
'H0'
:
h0
,
'C0'
:
c0
,
'Weight'
:
w
,
'Bias'
:
b
}
self
.
outputs
=
{
'Hidden'
:
h
,
'Cell'
:
c
,
'BatchGate'
:
g_sort
}
self
.
attrs
=
{
'usePeepholes'
:
True
,
'isReverse'
:
False
,
'gateActivation'
:
'
sigmoid
'
,
'cellActivation'
:
'
tanh
'
,
'candidateActivation'
:
'
tanh
'
'gateActivation'
:
'
linear
'
,
'cellActivation'
:
'
linear
'
,
'candidateActivation'
:
'
linear
'
}
def
setUp
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录