Paddle (fork of PaddlePaddle / Paddle)

Commit 3e1b914f
Authored Jan 07, 2019 by Qiao Longfei

update gru op forward kernel

Parent: 7a81ab86
Showing 9 changed files with 61 additions and 29 deletions (+61 −29)
paddle/fluid/operators/gru_op.cc                     +7  -2
paddle/fluid/operators/gru_op.cu.cc                  +2  -1
paddle/fluid/operators/gru_op.h                      +2  -1
paddle/fluid/operators/math/detail/gru_cpu_kernel.h  +11 -8
paddle/fluid/operators/math/detail/gru_gpu_kernel.h  +3  -2
paddle/fluid/operators/math/detail/gru_kernel.h      +22 -7
paddle/fluid/operators/math/gru_compute.cc           +6  -3
paddle/fluid/operators/math/gru_compute.cu           +4  -3
paddle/fluid/operators/math/gru_compute.h            +4  -2
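Every file in this commit threads a single new boolean, `origin_mode`, from the operator attribute down to the elementwise gate functors. As the `gru_kernel.h` hunk below shows, the flag selects between two conventions for the final GRU output; writing u_t for the update gate, h~_t for the activated candidate (frame) state, and h_{t-1} for the previous output:

```latex
% default (origin_mode = false), the existing behaviour:
h_t = h_{t-1} - u_t \odot h_{t-1} + u_t \odot \tilde{h}_t
    = (1 - u_t) \odot h_{t-1} + u_t \odot \tilde{h}_t

% origin_mode = true, following https://arxiv.org/abs/1412.3555:
h_t = u_t \odot h_{t-1} + \tilde{h}_t - u_t \odot \tilde{h}_t
    = u_t \odot h_{t-1} + (1 - u_t) \odot \tilde{h}_t
```

The two forms differ only in which term the update gate retains: the cited paper's convention keeps u_t of the previous state, while the default keeps u_t of the candidate.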
paddle/fluid/operators/gru_op.cc

```diff
@@ -137,6 +137,10 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(bool, default: False) "
                   "whether to compute reversed GRU.")
         .SetDefault(false);
+    AddAttr<bool>("origin_mode",
+                  "bool"
+                  "use origin mode in article https://arxiv.org/abs/1412.3555")
+        .SetDefault(false);
     AddComment(R"DOC(
 GRU Operator implements part calculations of the complete GRU as following:
@@ -221,6 +225,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
  public:
   void BatchCompute(const framework::ExecutionContext& context) const {
     using DeviceContext = paddle::platform::CPUDeviceContext;
+    bool origin_mode = context.Attr<bool>("origin_mode");
     auto* input = context.Input<LoDTensor>("Input");
     auto* h0 = context.Input<Tensor>("H0");
     auto* weight = context.Input<Tensor>("Weight");
@@ -327,7 +332,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
         math::detail::forward_final_output(
             math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
-            cur_batch_size, active_node);
+            cur_batch_size, active_node, origin_mode);
         gru_value.prev_out_value = gru_value.output_value;
       }
@@ -351,7 +356,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
         math::GRUUnitFunctor<DeviceContext, T>::compute(
             dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
-            active_gate);
+            active_gate, origin_mode);
         gru_value.prev_out_value = gru_value.output_value;
       }
```
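This is the standard attribute round trip: `AddAttr<bool>(...).SetDefault(false)` registers the attribute on the op proto, and each kernel later re-reads it by name via `context.Attr<bool>("origin_mode")`. A toy standalone model of that lookup, with hypothetical names rather than Paddle's real framework types:

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for Paddle's attribute storage; illustration only.
struct ToyAttrContext {
  std::map<std::string, bool> attrs;
  bool Attr(const std::string& name) const { return attrs.at(name); }
};

int main() {
  ToyAttrContext ctx;
  ctx.attrs["origin_mode"] = false;              // SetDefault(false) at op definition
  std::cout << ctx.Attr("origin_mode") << "\n";  // kernel-side read -> 0
  ctx.attrs["origin_mode"] = true;               // user overrides at op creation
  std::cout << ctx.Attr("origin_mode") << "\n";  // kernel-side read -> 1
  return 0;
}
```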
paddle/fluid/operators/gru_op.cu.cc

```diff
@@ -21,6 +21,7 @@ template <typename DeviceContext, typename T>
 class GRUKernel : public framework::OpKernel<T> {
  public:
   void BatchCompute(const framework::ExecutionContext& context) const {
+    bool origin_mode = context.Attr<bool>("origin_mode");
     auto* input = context.Input<LoDTensor>("Input");
     auto* h0 = context.Input<Tensor>("H0");
     auto* weight = context.Input<Tensor>("Weight");
@@ -87,7 +88,7 @@ class GRUKernel : public framework::OpKernel<T> {
       gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
       math::GRUUnitFunctor<DeviceContext, T>::compute(
           dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
-          active_gate);
+          active_gate, origin_mode);
       gru_value.prev_out_value = gru_value.output_value;
     }
```
paddle/fluid/operators/gru_op.h

```diff
@@ -41,6 +41,7 @@ template <typename DeviceContext, typename T>
 class GRUGradKernel : public framework::OpKernel<T> {
  public:
   void BatchCompute(const framework::ExecutionContext& context) const {
+    bool origin_mode = context.Attr<bool>("origin_mode");
     auto* h0 = context.Input<Tensor>("H0");
     auto* weight = context.Input<Tensor>("Weight");
     const T* weight_data = weight->data<T>();
@@ -146,7 +147,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
       math::GRUUnitGradFunctor<DeviceContext, T>::compute(
           dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node,
-          active_gate);
+          active_gate, origin_mode);
     }
     if (input_grad) {
       input_grad->mutable_data<T>(context.GetPlace());
```
paddle/fluid/operators/math/detail/gru_cpu_kernel.h

```diff
@@ -56,7 +56,8 @@ template <class OpFinalOutput, typename T>
 void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
                                        T *gate_value, T *prev_output_value,
                                        T *output_value, int frame_size,
-                                       ActivationType active_node) {
+                                       ActivationType active_node,
+                                       bool origin_mode) {
   T r_value_update_gate;
   T r_value_frame_state;
   T r_prev_out = 0;
@@ -72,7 +73,7 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
     }
     op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
-                    &r_output, active_node);
+                    &r_output, active_node, origin_mode);
     frame_state[i] = r_value_frame_state;
     output_value[i] = r_output;
@@ -146,7 +147,8 @@ template <class OpFinalOutput, typename T>
 void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
                                      T *gate_value, T *prev_output_value,
                                      T *output_value, int frame_size,
-                                     ActivationType active_node) {
+                                     ActivationType active_node,
+                                     bool origin_mode) {
 #ifdef __AVX__
   __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
   __m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f);
@@ -180,7 +182,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
     }
     op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
-                    &r_output, active_node);
+                    &r_output, active_node, origin_mode);
     _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
                      r_value_frame_state);
@@ -190,7 +192,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
   if (rest > 0) {
     i = n - block;
     op_final_output(&r_value_update_gate_last, &r_value_frame_state_last,
-                    &r_prev_out_last, &r_output, active_node);
+                    &r_prev_out_last, &r_output, active_node, origin_mode);
     _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
                      r_value_frame_state_last);
@@ -227,17 +229,18 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
 template <class OpFinalOutput, typename T>
 inline void forward_final_output(OpFinalOutput op_final_output,
                                  GRUMetaValue<T> value, int frame_size,
-                                 int batch_size, ActivationType active_node) {
+                                 int batch_size, ActivationType active_node,
+                                 bool origin_mode) {
   for (int b = 0; b < batch_size; b++) {
     if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
         (sizeof(T) == 4)) {
       hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
                                       value.prev_out_value, value.output_value,
-                                      frame_size, active_node);
+                                      frame_size, active_node, origin_mode);
     } else {
       hl_naive_gru_forward_final_output(op_final_output, value.gate_value,
                                         value.prev_out_value,
-                                        value.output_value, frame_size, active_node);
+                                        value.output_value, frame_size,
+                                        active_node, origin_mode);
     }
     value.gate_value += frame_size * 3;
```
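To make the dispatch above concrete: `forward_final_output` walks the batch and, per row, either the AVX path or the scalar path applies `gru_finalOutput` to every frame. A self-contained scalar sketch of one row (assuming tanh as `active_node`; variable names are illustrative, not Paddle's):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const int frame_size = 4;
  // One batch row: update gate u, pre-activation candidate c, previous output h.
  float update_gate[frame_size] = {0.1f, 0.5f, 0.9f, 0.3f};
  float frame_state[frame_size] = {0.2f, -0.4f, 0.7f, 0.0f};
  float prev_out[frame_size] = {1.0f, 1.0f, 1.0f, 1.0f};
  bool origin_mode = true;

  for (int i = 0; i < frame_size; ++i) {
    float c = std::tanh(frame_state[i]);  // activation(*value_frame_state, act_input)
    float u = update_gate[i];
    float h = prev_out[i];
    // The two branches of gru_finalOutput:
    float out = origin_mode ? u * h + c - u * c   // h_t = u*h + (1-u)*c
                            : h - u * h + u * c;  // h_t = (1-u)*h + u*c
    std::printf("out[%d] = %f\n", i, out);
  }
  return 0;
}
```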
paddle/fluid/operators/math/detail/gru_gpu_kernel.h

```diff
@@ -72,7 +72,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
                                         T *gate_value, T *prev_output_value,
                                         T *output_value, int frame_size,
                                         int batch_size,
-                                        ActivationType active_node) {
+                                        ActivationType active_node,
+                                        bool origin_mode) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
   int batch_idx = 0;
@@ -94,7 +95,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
   }
   op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
-                  &r_output, active_node);
+                  &r_output, active_node, origin_mode);
   gate_value[frame_idx + frame_size * 2] = r_value_frame_state;
   output_value[frame_idx] = r_output;
```
paddle/fluid/operators/math/detail/gru_kernel.h

```diff
@@ -57,10 +57,16 @@ class gru_finalOutput {
  public:
   HOSTDEVICE void operator()(T *value_update_gate, T *value_frame_state,
                              T *prev_out, T *value_output,
-                             ActivationType act_input) {
+                             ActivationType act_input, bool origin_mode) {
     *value_frame_state = activation(*value_frame_state, act_input);
-    *value_output = *prev_out - ((*value_update_gate) * (*prev_out)) +
-                    ((*value_update_gate) * (*value_frame_state));
+    if (origin_mode) {
+      *value_output = ((*value_update_gate) * (*prev_out)) +
+                      *value_frame_state -
+                      ((*value_update_gate) * (*value_frame_state));
+    } else {
+      *value_output = *prev_out - ((*value_update_gate) * (*prev_out)) +
+                      ((*value_update_gate) * (*value_frame_state));
+    }
   }
 #ifndef __NVCC__
 #ifndef __AVX__
@@ -69,11 +75,20 @@ class gru_finalOutput {
   static const bool avx = true;
   HOSTDEVICE void operator()(__m256 *value_update_gate,
                              __m256 *value_frame_state, __m256 *prev_out,
-                             __m256 *value_output, ActivationType act_input) {
+                             __m256 *value_output, ActivationType act_input,
+                             bool origin_mode) {
     *value_frame_state = activation(*value_frame_state, act_input);
-    *value_output = _mm256_add_ps(
-        _mm256_sub_ps(*prev_out, _mm256_mul_ps(*value_update_gate, *prev_out)),
-        _mm256_mul_ps(*value_update_gate, *value_frame_state));
+    if (origin_mode) {
+      *value_output = _mm256_sub_ps(
+          _mm256_add_ps(_mm256_mul_ps(*value_update_gate, *prev_out),
+                        *value_frame_state),
+          _mm256_mul_ps(*value_update_gate, *value_frame_state));
+    } else {
+      *value_output = _mm256_add_ps(
+          _mm256_sub_ps(*prev_out,
+                        _mm256_mul_ps(*value_update_gate, *prev_out)),
+          _mm256_mul_ps(*value_update_gate, *value_frame_state));
+    }
   }
 #endif
 #endif
```
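The intrinsic nesting in the AVX overload is the scalar algebra vectorized eight lanes at a time. A minimal sketch of the same composition outside the class (assumes an AVX-capable build; names are illustrative, not Paddle's):

```cpp
#include <immintrin.h>

// u = update gate, c = activated frame state, h = previous output (8 lanes each).
__m256 gru_final_output(__m256 u, __m256 c, __m256 h, bool origin_mode) {
  if (origin_mode) {
    // u*h + c - u*c
    return _mm256_sub_ps(_mm256_add_ps(_mm256_mul_ps(u, h), c),
                         _mm256_mul_ps(u, c));
  }
  // h - u*h + u*c
  return _mm256_add_ps(_mm256_sub_ps(h, _mm256_mul_ps(u, h)),
                       _mm256_mul_ps(u, c));
}
```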
paddle/fluid/operators/math/gru_compute.cc

```diff
@@ -23,7 +23,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
   static void compute(const platform::CPUDeviceContext &context,
                       GRUMetaValue<T> value, int frame_size, int batch_size,
                       const detail::ActivationType active_node,
-                      const detail::ActivationType active_gate) {
+                      const detail::ActivationType active_gate,
+                      bool origin_mode) {
 #ifndef __NVCC__
     auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
     if (value.prev_out_value) {
@@ -43,7 +44,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
     }
     detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
-                                 frame_size, batch_size, active_node);
+                                 frame_size, batch_size, active_node,
+                                 origin_mode);
 #endif
   }
 };
@@ -54,7 +56,8 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
                       GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                       int frame_size, int batch_size,
                       const detail::ActivationType active_node,
-                      const detail::ActivationType active_gate) {
+                      const detail::ActivationType active_gate,
+                      bool origin_mode) {
 #ifndef __NVCC__
     detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
                                 grad, frame_size, batch_size, active_node);
```
paddle/fluid/operators/math/gru_compute.cu

```diff
@@ -24,7 +24,8 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
   static void compute(const platform::CUDADeviceContext &context,
                       GRUMetaValue<T> value, int frame_size, int batch_size,
                       const detail::ActivationType active_node,
-                      const detail::ActivationType active_gate) {
+                      const detail::ActivationType active_gate,
+                      bool origin_mode) {
     auto stream = context.stream();
     dim3 threads;
     dim3 grid;
@@ -73,14 +74,14 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
           T><<<grid, threads, 0, stream>>>(
           detail::forward::gru_finalOutput<T>(), value.gate_value,
           value.prev_out_value, value.output_value, frame_size, batch_size,
-          active_node);
+          active_node, origin_mode);
     } else {
       detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
                                       /* is_batch= */ true,
                                       T><<<grid, threads, 0, stream>>>(
           detail::forward::gru_finalOutput<T>(), value.gate_value,
           value.prev_out_value, value.output_value, frame_size, batch_size,
-          active_node);
+          active_node, origin_mode);
     }
   }
 };
```
paddle/fluid/operators/math/gru_compute.h

```diff
@@ -44,7 +44,8 @@ struct GRUUnitFunctor {
   static void compute(const DeviceContext &context, GRUMetaValue<T> value,
                       int frame_size, int batch_size,
                       const detail::ActivationType active_node,
-                      const detail::ActivationType active_gate);
+                      const detail::ActivationType active_gate,
+                      bool origin_mode);
 };

 template <typename DeviceContext, typename T>
@@ -52,7 +53,8 @@ struct GRUUnitGradFunctor {
   static void compute(const DeviceContext &context, GRUMetaValue<T> value,
                       GRUMetaGrad<T> grad, int frame_size, int batch_size,
                       const detail::ActivationType active_node,
-                      const detail::ActivationType active_gate);
+                      const detail::ActivationType active_gate,
+                      bool origin_mode);
 };

 }  // namespace math
```
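Because `GRUUnitFunctor` and `GRUUnitGradFunctor` are declared here and specialized per device in `gru_compute.cc` and `gru_compute.cu`, the new trailing parameter has to change in all three files at once, plus every call site in the op kernels. A toy model of that declare-in-header, specialize-per-device pattern (simplified names, not Paddle's actual headers):

```cpp
#include <iostream>

struct CPUDeviceContext {};  // stand-in for paddle::platform::CPUDeviceContext

// "gru_compute.h": the primary template only declares the interface.
template <typename DeviceContext, typename T>
struct ToyGRUUnitFunctor {
  static void compute(const DeviceContext &ctx, T value, bool origin_mode);
};

// "gru_compute.cc": the CPU specialization supplies the definition.
template <typename T>
struct ToyGRUUnitFunctor<CPUDeviceContext, T> {
  static void compute(const CPUDeviceContext &, T value, bool origin_mode) {
    std::cout << "cpu compute, origin_mode=" << origin_mode << "\n";
  }
};

int main() {
  // A call site like gru_op.cc: the argument list must match the declaration.
  ToyGRUUnitFunctor<CPUDeviceContext, float>::compute(CPUDeviceContext{}, 1.f,
                                                      true);
  return 0;
}
```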