Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
f8391545
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
f8391545
编写于
12月 26, 2017
作者:
Q
qingqing01
提交者:
GitHub
12月 26, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #6996 from qingqing01/lstm_active_type
Refine the activation type getting in the LSTM operator to speed.
上级
1398854f
a8e18549
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
97 addition
and
69 deletion
+97
-69
paddle/operators/lstm_op.h
paddle/operators/lstm_op.h
+13
-6
paddle/operators/math/detail/activation_functions.h
paddle/operators/math/detail/activation_functions.h
+21
-0
paddle/operators/math/detail/lstm_cpu_kernel.h
paddle/operators/math/detail/lstm_cpu_kernel.h
+17
-20
paddle/operators/math/detail/lstm_gpu_kernel.h
paddle/operators/math/detail/lstm_gpu_kernel.h
+10
-12
paddle/operators/math/detail/lstm_kernel.h
paddle/operators/math/detail/lstm_kernel.h
+11
-11
paddle/operators/math/lstm_compute.cc
paddle/operators/math/lstm_compute.cc
+8
-8
paddle/operators/math/lstm_compute.cu
paddle/operators/math/lstm_compute.cu
+10
-8
paddle/operators/math/lstm_compute.h
paddle/operators/math/lstm_compute.h
+7
-4
未找到文件。
paddle/operators/lstm_op.h
浏览文件 @
f8391545
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"
...
...
@@ -102,9 +103,12 @@ class LSTMKernel : public framework::OpKernel<T> {
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
auto
gate_act
=
ctx
.
Attr
<
std
::
string
>
(
"gate_activation"
);
auto
cell_act
=
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
);
auto
cand_act
=
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
);
auto
gate_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"gate_activation"
));
auto
cell_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
));
auto
cand_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
));
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
...
...
@@ -264,9 +268,12 @@ class LSTMGradKernel : public framework::OpKernel<T> {
batch_gate_g
.
mutable_data
<
T
>
(
batch_gate
->
dims
(),
ctx
.
GetPlace
());
batch_gate_g
.
set_lod
(
batch_gate
->
lod
());
auto
gate_act
=
ctx
.
Attr
<
std
::
string
>
(
"gate_activation"
);
auto
cell_act
=
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
);
auto
cand_act
=
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
);
auto
gate_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"gate_activation"
));
auto
cell_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
));
auto
cand_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
));
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
...
...
paddle/operators/math/detail/activation_functions.h
浏览文件 @
f8391545
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <math.h>
#include "paddle/platform/enforce.h"
#include "paddle/platform/hostdevice.h"
#ifdef __AVX__
...
...
@@ -29,6 +30,26 @@ namespace detail {
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
enum
ActivationType
{
kSigmoid
,
kReLU
,
kTanh
,
kIdentity
,
};
inline
ActivationType
GetActivationType
(
const
std
::
string
&
type
)
{
if
(
type
==
"sigmoid"
)
{
return
ActivationType
::
kSigmoid
;
}
else
if
(
type
==
"relu"
)
{
return
ActivationType
::
kReLU
;
}
else
if
(
type
==
"tanh"
)
{
return
ActivationType
::
kTanh
;
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
ActivationType
::
kIdentity
;
}
PADDLE_THROW
(
"Not support type %s."
,
type
);
}
namespace
forward
{
template
<
typename
T
>
...
...
paddle/operators/math/detail/lstm_cpu_kernel.h
浏览文件 @
f8391545
...
...
@@ -26,10 +26,9 @@ namespace detail {
template
<
class
T
,
class
Op
>
void
naive_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
int
frame_size
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
T
r_value_in
;
T
r_value_ig
;
T
r_value_fg
;
...
...
@@ -77,9 +76,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
void
naive_lstm_backward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
T
r_value_in
;
T
r_value_ig
;
T
r_value_fg
;
...
...
@@ -149,10 +148,9 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
void
avx_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
int
frame_size
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
#ifdef __AVX__
__m256
r_value_in
;
__m256
r_value_ig
;
...
...
@@ -204,9 +202,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
void
avx_lstm_backward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
#ifdef __AVX__
__m256
r_value_in
;
__m256
r_value_ig
;
...
...
@@ -281,9 +279,8 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
void
cpu_lstm_forward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
if
(
Op
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
std
::
is_same
<
T
,
float
>::
value
))
{
avx_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frame_size
,
active_node
,
active_gate
,
active_state
);
...
...
@@ -295,9 +292,9 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
template
<
class
T
,
class
Op
>
void
cpu_lstm_backward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
int
frame_size
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
if
(
Op
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
std
::
is_same
<
T
,
float
>::
value
))
{
avx_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frame_size
,
active_node
,
active_gate
,
active_state
);
...
...
paddle/operators/math/detail/lstm_gpu_kernel.h
浏览文件 @
f8391545
...
...
@@ -31,9 +31,9 @@ namespace detail {
*/
template
<
class
T
,
class
Op
,
bool
is_batch
>
__global__
void
KeLstmForward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
int
batch_size
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frame_idx
>=
frame_size
)
return
;
...
...
@@ -91,9 +91,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
template
<
class
T
,
class
Op
,
bool
is_batch
>
__global__
void
KeLstmBackward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
int
batch_size
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frame_idx
>=
frame_size
)
return
;
...
...
@@ -185,9 +185,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
void
gpu_lstm_forward
(
const
platform
::
DeviceContext
&
context
,
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
dim3
threads
;
dim3
grid
;
if
(
batch_size
==
1
)
{
...
...
@@ -220,9 +219,8 @@ template <class T, class Op>
void
gpu_lstm_backward
(
const
platform
::
DeviceContext
&
context
,
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
dim3
threads
;
dim3
grid
;
if
(
batch_size
==
1
)
{
...
...
paddle/operators/math/detail/lstm_kernel.h
浏览文件 @
f8391545
...
...
@@ -30,9 +30,9 @@ class lstm {
HOSTDEVICE
void
operator
()(
T
&
value_in
,
T
&
value_ig
,
T
&
value_fg
,
T
&
value_og
,
T
&
prev_state
,
T
&
state
,
T
&
state_atv
,
T
&
output
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
value_in
=
activation
(
value_in
,
active_node
);
value_ig
=
activation
(
value_ig
+
prev_state
*
checkI
,
active_gate
);
value_fg
=
activation
(
value_fg
+
prev_state
*
checkF
,
active_gate
);
...
...
@@ -53,9 +53,9 @@ class lstm {
__m256
&
prev_state
,
__m256
&
state
,
__m256
&
state_atv
,
__m256
&
output
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
value_in
=
activation
(
value_in
,
active_node
);
value_ig
=
activation
(
_mm256_add_ps
(
value_ig
,
_mm256_mul_ps
(
prev_state
,
checkI
)),
...
...
@@ -87,9 +87,9 @@ class lstm {
T
&
state_grad
,
T
&
state_atv
,
T
&
output_grad
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
T
&
checkIGrad
,
T
&
checkFGrad
,
T
&
checkOGrad
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
grad_og
=
activation
(
output_grad
*
state_atv
,
value_og
,
active_gate
);
state_grad
+=
activation
(
output_grad
*
value_og
,
state_atv
,
active_state
)
+
grad_og
*
checkO
;
...
...
@@ -114,8 +114,8 @@ class lstm {
__m256
&
prev_state
,
__m256
&
prev_state_grad
,
__m256
&
state
,
__m256
&
state_grad
,
__m256
&
state_atv
,
__m256
&
output_grad
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
__m256
&
checkIGrad
,
__m256
&
checkFGrad
,
__m256
&
checkOGrad
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
__m256
&
checkFGrad
,
__m256
&
checkOGrad
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
grad_og
=
activation
(
_mm256_mul_ps
(
output_grad
,
state_atv
),
value_og
,
active_gate
);
state_grad
=
_mm256_add_ps
(
activation
(
_mm256_mul_ps
(
output_grad
,
value_og
),
...
...
paddle/operators/math/lstm_compute.cc
浏览文件 @
f8391545
...
...
@@ -24,12 +24,12 @@ template <class T>
struct
LstmUnitFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
static
void
compute
(
const
platform
::
CPUDeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
)
{
const
detail
::
ActivationType
&
gate_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
detail
::
cpu_lstm_forward
(
detail
::
forward
::
lstm
<
T
>
(),
value
,
frame_size
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
cand_act
,
gate_act
,
cell_act
);
value
.
gate_value
+=
frame_size
*
4
;
value
.
state_value
+=
frame_size
;
value
.
state_active_value
+=
frame_size
;
...
...
@@ -46,12 +46,12 @@ struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> {
static
void
compute
(
const
platform
::
CPUDeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
)
{
const
detail
::
ActivationType
&
gate_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
detail
::
cpu_lstm_backward
(
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
frame_size
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
frame_size
,
cand_act
,
gate_act
,
cell_act
);
value
.
gate_value
+=
frame_size
*
4
;
value
.
state_value
+=
frame_size
;
...
...
paddle/operators/math/lstm_compute.cu
浏览文件 @
f8391545
...
...
@@ -24,11 +24,12 @@ template <class T>
struct
LstmUnitFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
static
void
compute
(
const
platform
::
CUDADeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
)
{
const
detail
::
ActivationType
&
gate_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
)
{
detail
::
gpu_lstm_forward
<
T
>
(
context
,
detail
::
forward
::
lstm
<
T
>
(),
value
,
frame_size
,
batch_size
,
ActiveType
(
cand_act
)
,
ActiveType
(
gate_act
),
ActiveType
(
cell_act
)
);
frame_size
,
batch_size
,
cand_act
,
gate_act
,
cell_act
);
}
};
...
...
@@ -37,11 +38,12 @@ struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> {
static
void
compute
(
const
platform
::
CUDADeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
)
{
const
detail
::
ActivationType
&
gate_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
)
{
detail
::
gpu_lstm_backward
(
context
,
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
frame_size
,
batch_size
,
ActiveType
(
cand_act
)
,
ActiveType
(
gate_act
),
ActiveType
(
cell_act
)
);
frame_size
,
batch_size
,
cand_act
,
gate_act
,
cell_act
);
}
};
...
...
paddle/operators/math/lstm_compute.h
浏览文件 @
f8391545
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
...
...
@@ -72,8 +73,9 @@ class LstmUnitFunctor {
public:
static
void
compute
(
const
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
);
const
detail
::
ActivationType
&
gate_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
);
};
template
<
typename
DeviceContext
,
typename
T
>
...
...
@@ -81,8 +83,9 @@ class LstmUnitGradFunctor {
public:
static
void
compute
(
const
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
);
const
detail
::
ActivationType
&
gate_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
);
};
}
// namespace math
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录