Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d760b6a5
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d760b6a5
编写于
12月 25, 2017
作者:
Q
qingqing01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine the activation type getting in the LSTM operator to speed.
上级
d3c42f7d
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
102 addition
and
67 deletion
+102
-67
paddle/operators/lstm_op.h
paddle/operators/lstm_op.h
+13
-6
paddle/operators/math/detail/activation_functions.h
paddle/operators/math/detail/activation_functions.h
+22
-0
paddle/operators/math/detail/lstm_cpu_kernel.h
paddle/operators/math/detail/lstm_cpu_kernel.h
+18
-18
paddle/operators/math/detail/lstm_gpu_kernel.h
paddle/operators/math/detail/lstm_gpu_kernel.h
+12
-12
paddle/operators/math/detail/lstm_kernel.h
paddle/operators/math/detail/lstm_kernel.h
+11
-11
paddle/operators/math/lstm_compute.cc
paddle/operators/math/lstm_compute.cc
+8
-8
paddle/operators/math/lstm_compute.cu
paddle/operators/math/lstm_compute.cu
+10
-8
paddle/operators/math/lstm_compute.h
paddle/operators/math/lstm_compute.h
+8
-4
未找到文件。
paddle/operators/lstm_op.h
浏览文件 @
d760b6a5
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"
#include "paddle/operators/math/sequence2batch.h"
#include "paddle/operators/math/detail/activation_functions.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -102,9 +103,12 @@ class LSTMKernel : public framework::OpKernel<T> {
...
@@ -102,9 +103,12 @@ class LSTMKernel : public framework::OpKernel<T> {
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
auto
gate_act
=
ctx
.
Attr
<
std
::
string
>
(
"gate_activation"
);
auto
gate_act
=
math
::
detail
::
GetActivationType
(
auto
cell_act
=
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
);
ctx
.
Attr
<
std
::
string
>
(
"gate_activation"
));
auto
cand_act
=
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
);
auto
cell_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
));
auto
cand_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
));
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
...
@@ -264,9 +268,12 @@ class LSTMGradKernel : public framework::OpKernel<T> {
...
@@ -264,9 +268,12 @@ class LSTMGradKernel : public framework::OpKernel<T> {
batch_gate_g
.
mutable_data
<
T
>
(
batch_gate
->
dims
(),
ctx
.
GetPlace
());
batch_gate_g
.
mutable_data
<
T
>
(
batch_gate
->
dims
(),
ctx
.
GetPlace
());
batch_gate_g
.
set_lod
(
batch_gate
->
lod
());
batch_gate_g
.
set_lod
(
batch_gate
->
lod
());
auto
gate_act
=
ctx
.
Attr
<
std
::
string
>
(
"gate_activation"
);
auto
gate_act
=
math
::
detail
::
GetActivationType
(
auto
cell_act
=
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
);
ctx
.
Attr
<
std
::
string
>
(
"gate_activation"
));
auto
cand_act
=
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
);
auto
cell_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
));
auto
cand_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
));
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
...
...
paddle/operators/math/detail/activation_functions.h
浏览文件 @
d760b6a5
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <math.h>
#include <math.h>
#include "paddle/platform/hostdevice.h"
#include "paddle/platform/hostdevice.h"
#include "paddle/platform/enforce.h"
#ifdef __AVX__
#ifdef __AVX__
#include <immintrin.h>
#include <immintrin.h>
...
@@ -29,6 +30,27 @@ namespace detail {
...
@@ -29,6 +30,27 @@ namespace detail {
#define SIGMOID_THRESHOLD_MAX 13.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define EXP_MAX_INPUT 40.0
enum
ActivationType
{
kSigmoid
,
kReLU
,
kTanh
,
kIdentity
,
};
inline
ActivationType
GetActivationType
(
const
std
::
string
&
type
)
{
if
(
type
==
"sigmoid"
)
{
return
ActivationType
::
kSigmoid
;
}
else
if
(
type
==
"relu"
)
{
return
ActivationType
::
kReLU
;
}
else
if
(
type
==
"tanh"
)
{
return
ActivationType
::
kTanh
;
}
else
if
(
type
==
"identity"
)
{
return
ActivationType
::
kIdentity
;
}
PADDLE_THROW
(
"Not support type %s."
,
type
);
}
namespace
forward
{
namespace
forward
{
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/operators/math/detail/lstm_cpu_kernel.h
浏览文件 @
d760b6a5
...
@@ -27,9 +27,9 @@ namespace detail {
...
@@ -27,9 +27,9 @@ namespace detail {
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
naive_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
void
naive_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
frame_size
,
activation_mode_t
active_node
,
ActivationType
active_node
,
activation_mode_t
active_gate
,
ActivationType
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_state
)
{
T
r_value_in
;
T
r_value_in
;
T
r_value_ig
;
T
r_value_ig
;
T
r_value_fg
;
T
r_value_fg
;
...
@@ -77,9 +77,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
...
@@ -77,9 +77,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
naive_lstm_backward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
void
naive_lstm_backward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
activation_mode_t
active_node
,
ActivationType
active_node
,
activation_mode_t
active_gate
,
ActivationType
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_state
)
{
T
r_value_in
;
T
r_value_in
;
T
r_value_ig
;
T
r_value_ig
;
T
r_value_fg
;
T
r_value_fg
;
...
@@ -150,9 +150,9 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
...
@@ -150,9 +150,9 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
avx_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
void
avx_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
frame_size
,
activation_mode_t
active_node
,
ActivationType
active_node
,
activation_mode_t
active_gate
,
ActivationType
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_state
)
{
#ifdef __AVX__
#ifdef __AVX__
__m256
r_value_in
;
__m256
r_value_in
;
__m256
r_value_ig
;
__m256
r_value_ig
;
...
@@ -204,9 +204,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
...
@@ -204,9 +204,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
avx_lstm_backward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
void
avx_lstm_backward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
activation_mode_t
active_node
,
ActivationType
active_node
,
activation_mode_t
active_gate
,
ActivationType
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_state
)
{
#ifdef __AVX__
#ifdef __AVX__
__m256
r_value_in
;
__m256
r_value_in
;
__m256
r_value_ig
;
__m256
r_value_ig
;
...
@@ -281,9 +281,9 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
...
@@ -281,9 +281,9 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
cpu_lstm_forward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
void
cpu_lstm_forward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
activation_mode_t
active_node
,
ActivationType
active_node
,
activation_mode_t
active_gate
,
ActivationType
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_state
)
{
if
(
Op
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
std
::
is_same
<
T
,
float
>::
value
))
{
if
(
Op
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
std
::
is_same
<
T
,
float
>::
value
))
{
avx_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frame_size
,
active_node
,
avx_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frame_size
,
active_node
,
active_gate
,
active_state
);
active_gate
,
active_state
);
...
@@ -295,9 +295,9 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
...
@@ -295,9 +295,9 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
cpu_lstm_backward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
void
cpu_lstm_backward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
activation_mode_t
active_node
,
int
frame_size
,
ActivationType
active_node
,
activation_mode_t
active_gate
,
ActivationType
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_state
)
{
if
(
Op
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
std
::
is_same
<
T
,
float
>::
value
))
{
if
(
Op
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
std
::
is_same
<
T
,
float
>::
value
))
{
avx_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frame_size
,
active_node
,
avx_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frame_size
,
active_node
,
active_gate
,
active_state
);
active_gate
,
active_state
);
...
...
paddle/operators/math/detail/lstm_gpu_kernel.h
浏览文件 @
d760b6a5
...
@@ -31,9 +31,9 @@ namespace detail {
...
@@ -31,9 +31,9 @@ namespace detail {
*/
*/
template
<
class
T
,
class
Op
,
bool
is_batch
>
template
<
class
T
,
class
Op
,
bool
is_batch
>
__global__
void
KeLstmForward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
__global__
void
KeLstmForward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
int
batch_size
,
ActivationType
active_node
,
activation_mode_t
active_gate
,
ActivationType
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_state
)
{
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frame_idx
>=
frame_size
)
return
;
if
(
frame_idx
>=
frame_size
)
return
;
...
@@ -91,9 +91,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
...
@@ -91,9 +91,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
template
<
class
T
,
class
Op
,
bool
is_batch
>
template
<
class
T
,
class
Op
,
bool
is_batch
>
__global__
void
KeLstmBackward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
__global__
void
KeLstmBackward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
int
batch_size
,
ActivationType
active_node
,
activation_mode_t
active_gate
,
ActivationType
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_state
)
{
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frame_idx
>=
frame_size
)
return
;
if
(
frame_idx
>=
frame_size
)
return
;
...
@@ -185,9 +185,9 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
...
@@ -185,9 +185,9 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
gpu_lstm_forward
(
const
platform
::
DeviceContext
&
context
,
Op
op
,
void
gpu_lstm_forward
(
const
platform
::
DeviceContext
&
context
,
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
ActivationType
active_node
,
activation_mode_t
active_gate
,
ActivationType
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_state
)
{
dim3
threads
;
dim3
threads
;
dim3
grid
;
dim3
grid
;
if
(
batch_size
==
1
)
{
if
(
batch_size
==
1
)
{
...
@@ -220,9 +220,9 @@ template <class T, class Op>
...
@@ -220,9 +220,9 @@ template <class T, class Op>
void
gpu_lstm_backward
(
const
platform
::
DeviceContext
&
context
,
Op
op
,
void
gpu_lstm_backward
(
const
platform
::
DeviceContext
&
context
,
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
ActivationType
active_node
,
activation_mode_t
active_gate
,
ActivationType
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_state
)
{
dim3
threads
;
dim3
threads
;
dim3
grid
;
dim3
grid
;
if
(
batch_size
==
1
)
{
if
(
batch_size
==
1
)
{
...
...
paddle/operators/math/detail/lstm_kernel.h
浏览文件 @
d760b6a5
...
@@ -30,9 +30,9 @@ class lstm {
...
@@ -30,9 +30,9 @@ class lstm {
HOSTDEVICE
void
operator
()(
T
&
value_in
,
T
&
value_ig
,
T
&
value_fg
,
T
&
value_og
,
HOSTDEVICE
void
operator
()(
T
&
value_in
,
T
&
value_ig
,
T
&
value_fg
,
T
&
value_og
,
T
&
prev_state
,
T
&
state
,
T
&
state_atv
,
T
&
output
,
T
&
prev_state
,
T
&
state
,
T
&
state_atv
,
T
&
output
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
activation_mode_t
active_node
,
ActivationType
active_node
,
activation_mode_t
active_gate
,
ActivationType
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_state
)
{
value_in
=
activation
(
value_in
,
active_node
);
value_in
=
activation
(
value_in
,
active_node
);
value_ig
=
activation
(
value_ig
+
prev_state
*
checkI
,
active_gate
);
value_ig
=
activation
(
value_ig
+
prev_state
*
checkI
,
active_gate
);
value_fg
=
activation
(
value_fg
+
prev_state
*
checkF
,
active_gate
);
value_fg
=
activation
(
value_fg
+
prev_state
*
checkF
,
active_gate
);
...
@@ -53,9 +53,9 @@ class lstm {
...
@@ -53,9 +53,9 @@ class lstm {
__m256
&
prev_state
,
__m256
&
state
,
__m256
&
prev_state
,
__m256
&
state
,
__m256
&
state_atv
,
__m256
&
output
,
__m256
&
checkI
,
__m256
&
state_atv
,
__m256
&
output
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
__m256
&
checkF
,
__m256
&
checkO
,
activation_mode_t
active_node
,
ActivationType
active_node
,
activation_mode_t
active_gate
,
ActivationType
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_state
)
{
value_in
=
activation
(
value_in
,
active_node
);
value_in
=
activation
(
value_in
,
active_node
);
value_ig
=
value_ig
=
activation
(
_mm256_add_ps
(
value_ig
,
_mm256_mul_ps
(
prev_state
,
checkI
)),
activation
(
_mm256_add_ps
(
value_ig
,
_mm256_mul_ps
(
prev_state
,
checkI
)),
...
@@ -87,9 +87,9 @@ class lstm {
...
@@ -87,9 +87,9 @@ class lstm {
T
&
state_grad
,
T
&
state_atv
,
T
&
output_grad
,
T
&
state_grad
,
T
&
state_atv
,
T
&
output_grad
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
T
&
checkIGrad
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
T
&
checkIGrad
,
T
&
checkFGrad
,
T
&
checkOGrad
,
T
&
checkFGrad
,
T
&
checkOGrad
,
activation_mode_t
active_node
,
ActivationType
active_node
,
activation_mode_t
active_gate
,
ActivationType
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_state
)
{
grad_og
=
activation
(
output_grad
*
state_atv
,
value_og
,
active_gate
);
grad_og
=
activation
(
output_grad
*
state_atv
,
value_og
,
active_gate
);
state_grad
+=
activation
(
output_grad
*
value_og
,
state_atv
,
active_state
)
+
state_grad
+=
activation
(
output_grad
*
value_og
,
state_atv
,
active_state
)
+
grad_og
*
checkO
;
grad_og
*
checkO
;
...
@@ -114,8 +114,8 @@ class lstm {
...
@@ -114,8 +114,8 @@ class lstm {
__m256
&
prev_state
,
__m256
&
prev_state_grad
,
__m256
&
state
,
__m256
&
prev_state
,
__m256
&
prev_state_grad
,
__m256
&
state
,
__m256
&
state_grad
,
__m256
&
state_atv
,
__m256
&
output_grad
,
__m256
&
state_grad
,
__m256
&
state_atv
,
__m256
&
output_grad
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
__m256
&
checkIGrad
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
__m256
&
checkIGrad
,
__m256
&
checkFGrad
,
__m256
&
checkOGrad
,
activation_mode_t
active_node
,
__m256
&
checkFGrad
,
__m256
&
checkOGrad
,
ActivationType
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_gate
,
ActivationType
active_state
)
{
grad_og
=
activation
(
_mm256_mul_ps
(
output_grad
,
state_atv
),
value_og
,
grad_og
=
activation
(
_mm256_mul_ps
(
output_grad
,
state_atv
),
value_og
,
active_gate
);
active_gate
);
state_grad
=
_mm256_add_ps
(
activation
(
_mm256_mul_ps
(
output_grad
,
value_og
),
state_grad
=
_mm256_add_ps
(
activation
(
_mm256_mul_ps
(
output_grad
,
value_og
),
...
...
paddle/operators/math/lstm_compute.cc
浏览文件 @
d760b6a5
...
@@ -24,12 +24,12 @@ template <class T>
...
@@ -24,12 +24,12 @@ template <class T>
struct
LstmUnitFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
struct
LstmUnitFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
static
void
compute
(
const
platform
::
CPUDeviceContext
&
context
,
static
void
compute
(
const
platform
::
CPUDeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
detail
::
ActivationType
&
gate_act
,
const
std
::
string
&
cand_act
)
{
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
detail
::
cpu_lstm_forward
(
detail
::
forward
::
lstm
<
T
>
(),
value
,
frame_size
,
detail
::
cpu_lstm_forward
(
detail
::
forward
::
lstm
<
T
>
(),
value
,
frame_size
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
cand_act
,
gate_act
,
cell_act
);
ActiveType
(
cell_act
));
value
.
gate_value
+=
frame_size
*
4
;
value
.
gate_value
+=
frame_size
*
4
;
value
.
state_value
+=
frame_size
;
value
.
state_value
+=
frame_size
;
value
.
state_active_value
+=
frame_size
;
value
.
state_active_value
+=
frame_size
;
...
@@ -46,12 +46,12 @@ struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> {
...
@@ -46,12 +46,12 @@ struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> {
static
void
compute
(
const
platform
::
CPUDeviceContext
&
context
,
static
void
compute
(
const
platform
::
CPUDeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
detail
::
ActivationType
&
gate_act
,
const
std
::
string
&
cand_act
)
{
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
detail
::
cpu_lstm_backward
(
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
detail
::
cpu_lstm_backward
(
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
frame_size
,
ActiveType
(
cand_act
),
frame_size
,
cand_act
,
gate_act
,
cell_act
);
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
value
.
gate_value
+=
frame_size
*
4
;
value
.
gate_value
+=
frame_size
*
4
;
value
.
state_value
+=
frame_size
;
value
.
state_value
+=
frame_size
;
...
...
paddle/operators/math/lstm_compute.cu
浏览文件 @
d760b6a5
...
@@ -24,11 +24,12 @@ template <class T>
...
@@ -24,11 +24,12 @@ template <class T>
struct
LstmUnitFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
struct
LstmUnitFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
static
void
compute
(
const
platform
::
CUDADeviceContext
&
context
,
static
void
compute
(
const
platform
::
CUDADeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
detail
::
ActivationType
&
gate_act
,
const
std
::
string
&
cand_act
)
{
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
)
{
detail
::
gpu_lstm_forward
<
T
>
(
context
,
detail
::
forward
::
lstm
<
T
>
(),
value
,
detail
::
gpu_lstm_forward
<
T
>
(
context
,
detail
::
forward
::
lstm
<
T
>
(),
value
,
frame_size
,
batch_size
,
ActiveType
(
cand_act
)
,
frame_size
,
batch_size
,
cand_act
,
ActiveType
(
gate_act
),
ActiveType
(
cell_act
)
);
gate_act
,
cell_act
);
}
}
};
};
...
@@ -37,11 +38,12 @@ struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> {
...
@@ -37,11 +38,12 @@ struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> {
static
void
compute
(
const
platform
::
CUDADeviceContext
&
context
,
static
void
compute
(
const
platform
::
CUDADeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
detail
::
ActivationType
&
gate_act
,
const
std
::
string
&
cand_act
)
{
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
)
{
detail
::
gpu_lstm_backward
(
context
,
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
detail
::
gpu_lstm_backward
(
context
,
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
frame_size
,
batch_size
,
ActiveType
(
cand_act
)
,
frame_size
,
batch_size
,
cand_act
,
ActiveType
(
gate_act
),
ActiveType
(
cell_act
)
);
gate_act
,
cell_act
);
}
}
};
};
...
...
paddle/operators/math/lstm_compute.h
浏览文件 @
d760b6a5
...
@@ -16,6 +16,7 @@ limitations under the License. */
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/platform/device_context.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/enforce.h"
#include "paddle/operators/math/detail/activation_functions.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -29,6 +30,7 @@ typedef enum {
...
@@ -29,6 +30,7 @@ typedef enum {
HL_ACTIVATION_END
HL_ACTIVATION_END
}
activation_mode_t
;
}
activation_mode_t
;
template
<
class
T
>
template
<
class
T
>
struct
LstmMetaValue
{
struct
LstmMetaValue
{
T
*
gate_value
;
T
*
gate_value
;
...
@@ -72,8 +74,9 @@ class LstmUnitFunctor {
...
@@ -72,8 +74,9 @@ class LstmUnitFunctor {
public:
public:
static
void
compute
(
const
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
static
void
compute
(
const
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
detail
::
ActivationType
&
gate_act
,
const
std
::
string
&
cand_act
);
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
);
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
...
@@ -81,8 +84,9 @@ class LstmUnitGradFunctor {
...
@@ -81,8 +84,9 @@ class LstmUnitGradFunctor {
public:
public:
static
void
compute
(
const
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
static
void
compute
(
const
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
detail
::
ActivationType
&
gate_act
,
const
std
::
string
&
cand_act
);
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
);
};
};
}
// namespace math
}
// namespace math
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录