Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3e1b914f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3e1b914f
编写于
1月 07, 2019
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update gru op forward kernel
上级
7a81ab86
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
61 addition
and
29 deletion
+61
-29
paddle/fluid/operators/gru_op.cc
paddle/fluid/operators/gru_op.cc
+7
-2
paddle/fluid/operators/gru_op.cu.cc
paddle/fluid/operators/gru_op.cu.cc
+2
-1
paddle/fluid/operators/gru_op.h
paddle/fluid/operators/gru_op.h
+2
-1
paddle/fluid/operators/math/detail/gru_cpu_kernel.h
paddle/fluid/operators/math/detail/gru_cpu_kernel.h
+11
-8
paddle/fluid/operators/math/detail/gru_gpu_kernel.h
paddle/fluid/operators/math/detail/gru_gpu_kernel.h
+3
-2
paddle/fluid/operators/math/detail/gru_kernel.h
paddle/fluid/operators/math/detail/gru_kernel.h
+22
-7
paddle/fluid/operators/math/gru_compute.cc
paddle/fluid/operators/math/gru_compute.cc
+6
-3
paddle/fluid/operators/math/gru_compute.cu
paddle/fluid/operators/math/gru_compute.cu
+4
-3
paddle/fluid/operators/math/gru_compute.h
paddle/fluid/operators/math/gru_compute.h
+4
-2
未找到文件。
paddle/fluid/operators/gru_op.cc
浏览文件 @
3e1b914f
...
@@ -137,6 +137,10 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -137,6 +137,10 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, defalut: False) "
"(bool, defalut: False) "
"whether to compute reversed GRU."
)
"whether to compute reversed GRU."
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"origin_mode"
,
"bool"
"use origin mode in article https://arxiv.org/abs/1412.3555"
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
GRU Operator implements part calculations of the complete GRU as following:
GRU Operator implements part calculations of the complete GRU as following:
...
@@ -221,6 +225,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
...
@@ -221,6 +225,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
public:
public:
void
BatchCompute
(
const
framework
::
ExecutionContext
&
context
)
const
{
void
BatchCompute
(
const
framework
::
ExecutionContext
&
context
)
const
{
using
DeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
using
DeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
bool
origin_mode
=
context
.
Attr
<
bool
>
(
"origin_mode"
);
auto
*
input
=
context
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
input
=
context
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
h0
=
context
.
Input
<
Tensor
>
(
"H0"
);
auto
*
h0
=
context
.
Input
<
Tensor
>
(
"H0"
);
auto
*
weight
=
context
.
Input
<
Tensor
>
(
"Weight"
);
auto
*
weight
=
context
.
Input
<
Tensor
>
(
"Weight"
);
...
@@ -327,7 +332,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
...
@@ -327,7 +332,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
math
::
detail
::
forward_final_output
(
math
::
detail
::
forward_final_output
(
math
::
detail
::
forward
::
gru_finalOutput
<
T
>
(),
gru_value
,
frame_size
,
math
::
detail
::
forward
::
gru_finalOutput
<
T
>
(),
gru_value
,
frame_size
,
cur_batch_size
,
active_node
);
cur_batch_size
,
active_node
,
origin_mode
);
gru_value
.
prev_out_value
=
gru_value
.
output_value
;
gru_value
.
prev_out_value
=
gru_value
.
output_value
;
}
}
...
@@ -351,7 +356,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
...
@@ -351,7 +356,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
math
::
GRUUnitFunctor
<
DeviceContext
,
T
>::
compute
(
math
::
GRUUnitFunctor
<
DeviceContext
,
T
>::
compute
(
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
active_node
,
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
active_node
,
active_gate
);
active_gate
,
origin_mode
);
gru_value
.
prev_out_value
=
gru_value
.
output_value
;
gru_value
.
prev_out_value
=
gru_value
.
output_value
;
}
}
...
...
paddle/fluid/operators/gru_op.cu.cc
浏览文件 @
3e1b914f
...
@@ -21,6 +21,7 @@ template <typename DeviceContext, typename T>
...
@@ -21,6 +21,7 @@ template <typename DeviceContext, typename T>
class
GRUKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GRUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
BatchCompute
(
const
framework
::
ExecutionContext
&
context
)
const
{
void
BatchCompute
(
const
framework
::
ExecutionContext
&
context
)
const
{
bool
origin_mode
=
context
.
Attr
<
bool
>
(
"origin_mode"
);
auto
*
input
=
context
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
input
=
context
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
h0
=
context
.
Input
<
Tensor
>
(
"H0"
);
auto
*
h0
=
context
.
Input
<
Tensor
>
(
"H0"
);
auto
*
weight
=
context
.
Input
<
Tensor
>
(
"Weight"
);
auto
*
weight
=
context
.
Input
<
Tensor
>
(
"Weight"
);
...
@@ -87,7 +88,7 @@ class GRUKernel : public framework::OpKernel<T> {
...
@@ -87,7 +88,7 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value
.
reset_output_value
=
reset_hidden_prev_t
.
data
<
T
>
();
gru_value
.
reset_output_value
=
reset_hidden_prev_t
.
data
<
T
>
();
math
::
GRUUnitFunctor
<
DeviceContext
,
T
>::
compute
(
math
::
GRUUnitFunctor
<
DeviceContext
,
T
>::
compute
(
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
active_node
,
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
active_node
,
active_gate
);
active_gate
,
origin_mode
);
gru_value
.
prev_out_value
=
gru_value
.
output_value
;
gru_value
.
prev_out_value
=
gru_value
.
output_value
;
}
}
...
...
paddle/fluid/operators/gru_op.h
浏览文件 @
3e1b914f
...
@@ -41,6 +41,7 @@ template <typename DeviceContext, typename T>
...
@@ -41,6 +41,7 @@ template <typename DeviceContext, typename T>
class
GRUGradKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GRUGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
BatchCompute
(
const
framework
::
ExecutionContext
&
context
)
const
{
void
BatchCompute
(
const
framework
::
ExecutionContext
&
context
)
const
{
bool
origin_mode
=
context
.
Attr
<
bool
>
(
"origin_mode"
);
auto
*
h0
=
context
.
Input
<
Tensor
>
(
"H0"
);
auto
*
h0
=
context
.
Input
<
Tensor
>
(
"H0"
);
auto
*
weight
=
context
.
Input
<
Tensor
>
(
"Weight"
);
auto
*
weight
=
context
.
Input
<
Tensor
>
(
"Weight"
);
const
T
*
weight_data
=
weight
->
data
<
T
>
();
const
T
*
weight_data
=
weight
->
data
<
T
>
();
...
@@ -146,7 +147,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
...
@@ -146,7 +147,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
math
::
GRUUnitGradFunctor
<
DeviceContext
,
T
>::
compute
(
math
::
GRUUnitGradFunctor
<
DeviceContext
,
T
>::
compute
(
dev_ctx
,
gru_value
,
gru_grad
,
frame_size
,
cur_batch_size
,
active_node
,
dev_ctx
,
gru_value
,
gru_grad
,
frame_size
,
cur_batch_size
,
active_node
,
active_gate
);
active_gate
,
origin_mode
);
}
}
if
(
input_grad
)
{
if
(
input_grad
)
{
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
...
paddle/fluid/operators/math/detail/gru_cpu_kernel.h
浏览文件 @
3e1b914f
...
@@ -56,7 +56,8 @@ template <class OpFinalOutput, typename T>
...
@@ -56,7 +56,8 @@ template <class OpFinalOutput, typename T>
void
hl_naive_gru_forward_final_output
(
OpFinalOutput
op_final_output
,
void
hl_naive_gru_forward_final_output
(
OpFinalOutput
op_final_output
,
T
*
gate_value
,
T
*
prev_output_value
,
T
*
gate_value
,
T
*
prev_output_value
,
T
*
output_value
,
int
frame_size
,
T
*
output_value
,
int
frame_size
,
ActivationType
active_node
)
{
ActivationType
active_node
,
bool
origin_mode
)
{
T
r_value_update_gate
;
T
r_value_update_gate
;
T
r_value_frame_state
;
T
r_value_frame_state
;
T
r_prev_out
=
0
;
T
r_prev_out
=
0
;
...
@@ -72,7 +73,7 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
...
@@ -72,7 +73,7 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
}
}
op_final_output
(
&
r_value_update_gate
,
&
r_value_frame_state
,
&
r_prev_out
,
op_final_output
(
&
r_value_update_gate
,
&
r_value_frame_state
,
&
r_prev_out
,
&
r_output
,
active_node
);
&
r_output
,
active_node
,
origin_mode
);
frame_state
[
i
]
=
r_value_frame_state
;
frame_state
[
i
]
=
r_value_frame_state
;
output_value
[
i
]
=
r_output
;
output_value
[
i
]
=
r_output
;
...
@@ -146,7 +147,8 @@ template <class OpFinalOutput, typename T>
...
@@ -146,7 +147,8 @@ template <class OpFinalOutput, typename T>
void
hl_avx_gru_forward_final_output
(
OpFinalOutput
op_final_output
,
void
hl_avx_gru_forward_final_output
(
OpFinalOutput
op_final_output
,
T
*
gate_value
,
T
*
prev_output_value
,
T
*
gate_value
,
T
*
prev_output_value
,
T
*
output_value
,
int
frame_size
,
T
*
output_value
,
int
frame_size
,
ActivationType
active_node
)
{
ActivationType
active_node
,
bool
origin_mode
)
{
#ifdef __AVX__
#ifdef __AVX__
__m256
r_value_update_gate
,
r_value_update_gate_last
=
_mm256_set1_ps
(
0.0
f
);
__m256
r_value_update_gate
,
r_value_update_gate_last
=
_mm256_set1_ps
(
0.0
f
);
__m256
r_value_frame_state
,
r_value_frame_state_last
=
_mm256_set1_ps
(
0.0
f
);
__m256
r_value_frame_state
,
r_value_frame_state_last
=
_mm256_set1_ps
(
0.0
f
);
...
@@ -180,7 +182,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
...
@@ -180,7 +182,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
}
}
op_final_output
(
&
r_value_update_gate
,
&
r_value_frame_state
,
&
r_prev_out
,
op_final_output
(
&
r_value_update_gate
,
&
r_value_frame_state
,
&
r_prev_out
,
&
r_output
,
active_node
);
&
r_output
,
active_node
,
origin_mode
);
_mm256_storeu_ps
(
reinterpret_cast
<
float
*>
(
frame_state
+
i
),
_mm256_storeu_ps
(
reinterpret_cast
<
float
*>
(
frame_state
+
i
),
r_value_frame_state
);
r_value_frame_state
);
...
@@ -190,7 +192,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
...
@@ -190,7 +192,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
if
(
rest
>
0
)
{
if
(
rest
>
0
)
{
i
=
n
-
block
;
i
=
n
-
block
;
op_final_output
(
&
r_value_update_gate_last
,
&
r_value_frame_state_last
,
op_final_output
(
&
r_value_update_gate_last
,
&
r_value_frame_state_last
,
&
r_prev_out_last
,
&
r_output
,
active_node
);
&
r_prev_out_last
,
&
r_output
,
active_node
,
origin_mode
);
_mm256_storeu_ps
(
reinterpret_cast
<
float
*>
(
frame_state
+
i
),
_mm256_storeu_ps
(
reinterpret_cast
<
float
*>
(
frame_state
+
i
),
r_value_frame_state_last
);
r_value_frame_state_last
);
...
@@ -227,17 +229,18 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
...
@@ -227,17 +229,18 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
template
<
class
OpFinalOutput
,
typename
T
>
template
<
class
OpFinalOutput
,
typename
T
>
inline
void
forward_final_output
(
OpFinalOutput
op_final_output
,
inline
void
forward_final_output
(
OpFinalOutput
op_final_output
,
GRUMetaValue
<
T
>
value
,
int
frame_size
,
GRUMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
ActivationType
active_node
)
{
int
batch_size
,
ActivationType
active_node
,
bool
origin_mode
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
if
(
OpFinalOutput
::
avx
&&
(
frame_size
>
static_cast
<
int
>
(
8
-
1
))
&&
if
(
OpFinalOutput
::
avx
&&
(
frame_size
>
static_cast
<
int
>
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
(
sizeof
(
T
)
==
4
))
{
hl_avx_gru_forward_final_output
(
op_final_output
,
value
.
gate_value
,
hl_avx_gru_forward_final_output
(
op_final_output
,
value
.
gate_value
,
value
.
prev_out_value
,
value
.
output_value
,
value
.
prev_out_value
,
value
.
output_value
,
frame_size
,
active_node
);
frame_size
,
active_node
,
origin_mode
);
}
else
{
}
else
{
hl_naive_gru_forward_final_output
(
hl_naive_gru_forward_final_output
(
op_final_output
,
value
.
gate_value
,
value
.
prev_out_value
,
op_final_output
,
value
.
gate_value
,
value
.
prev_out_value
,
value
.
output_value
,
frame_size
,
active_node
);
value
.
output_value
,
frame_size
,
active_node
,
origin_mode
);
}
}
value
.
gate_value
+=
frame_size
*
3
;
value
.
gate_value
+=
frame_size
*
3
;
...
...
paddle/fluid/operators/math/detail/gru_gpu_kernel.h
浏览文件 @
3e1b914f
...
@@ -72,7 +72,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
...
@@ -72,7 +72,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
T
*
gate_value
,
T
*
prev_output_value
,
T
*
gate_value
,
T
*
prev_output_value
,
T
*
output_value
,
int
frame_size
,
T
*
output_value
,
int
frame_size
,
int
batch_size
,
int
batch_size
,
ActivationType
active_node
)
{
ActivationType
active_node
,
bool
origin_mode
)
{
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frame_idx
>=
frame_size
)
return
;
if
(
frame_idx
>=
frame_size
)
return
;
int
batch_idx
=
0
;
int
batch_idx
=
0
;
...
@@ -94,7 +95,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
...
@@ -94,7 +95,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
}
}
op_final_output
(
&
r_value_update_gate
,
&
r_value_frame_state
,
&
r_prev_out
,
op_final_output
(
&
r_value_update_gate
,
&
r_value_frame_state
,
&
r_prev_out
,
&
r_output
,
active_node
);
&
r_output
,
active_node
,
origin_mode
);
gate_value
[
frame_idx
+
frame_size
*
2
]
=
r_value_frame_state
;
gate_value
[
frame_idx
+
frame_size
*
2
]
=
r_value_frame_state
;
output_value
[
frame_idx
]
=
r_output
;
output_value
[
frame_idx
]
=
r_output
;
...
...
paddle/fluid/operators/math/detail/gru_kernel.h
浏览文件 @
3e1b914f
...
@@ -57,11 +57,17 @@ class gru_finalOutput {
...
@@ -57,11 +57,17 @@ class gru_finalOutput {
public:
public:
HOSTDEVICE
void
operator
()(
T
*
value_update_gate
,
T
*
value_frame_state
,
HOSTDEVICE
void
operator
()(
T
*
value_update_gate
,
T
*
value_frame_state
,
T
*
prev_out
,
T
*
value_output
,
T
*
prev_out
,
T
*
value_output
,
ActivationType
act_input
)
{
ActivationType
act_input
,
bool
origin_mode
)
{
*
value_frame_state
=
activation
(
*
value_frame_state
,
act_input
);
*
value_frame_state
=
activation
(
*
value_frame_state
,
act_input
);
if
(
origin_mode
)
{
*
value_output
=
((
*
value_update_gate
)
*
(
*
prev_out
))
+
*
value_frame_state
-
((
*
value_update_gate
)
*
(
*
value_frame_state
));
}
else
{
*
value_output
=
*
prev_out
-
((
*
value_update_gate
)
*
(
*
prev_out
))
+
*
value_output
=
*
prev_out
-
((
*
value_update_gate
)
*
(
*
prev_out
))
+
((
*
value_update_gate
)
*
(
*
value_frame_state
));
((
*
value_update_gate
)
*
(
*
value_frame_state
));
}
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
static
const
bool
avx
=
false
;
static
const
bool
avx
=
false
;
...
@@ -69,12 +75,21 @@ class gru_finalOutput {
...
@@ -69,12 +75,21 @@ class gru_finalOutput {
static
const
bool
avx
=
true
;
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
*
value_update_gate
,
HOSTDEVICE
void
operator
()(
__m256
*
value_update_gate
,
__m256
*
value_frame_state
,
__m256
*
prev_out
,
__m256
*
value_frame_state
,
__m256
*
prev_out
,
__m256
*
value_output
,
ActivationType
act_input
)
{
__m256
*
value_output
,
ActivationType
act_input
,
bool
origin_mode
)
{
*
value_frame_state
=
activation
(
*
value_frame_state
,
act_input
);
*
value_frame_state
=
activation
(
*
value_frame_state
,
act_input
);
if
(
origin_mode
)
{
*
value_output
=
_mm256_sub_ps
(
_mm256_add_ps
(
_mm256_mul_ps
(
*
value_update_gate
,
*
prev_out
),
*
value_frame_state
),
_mm256_mul_ps
(
*
value_update_gate
,
*
value_frame_state
));
}
else
{
*
value_output
=
_mm256_add_ps
(
*
value_output
=
_mm256_add_ps
(
_mm256_sub_ps
(
*
prev_out
,
_mm256_mul_ps
(
*
value_update_gate
,
*
prev_out
)),
_mm256_sub_ps
(
*
prev_out
,
_mm256_mul_ps
(
*
value_update_gate
,
*
prev_out
)),
_mm256_mul_ps
(
*
value_update_gate
,
*
value_frame_state
));
_mm256_mul_ps
(
*
value_update_gate
,
*
value_frame_state
));
}
}
}
#endif
#endif
#endif
#endif
};
};
...
...
paddle/fluid/operators/math/gru_compute.cc
浏览文件 @
3e1b914f
...
@@ -23,7 +23,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
...
@@ -23,7 +23,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
static
void
compute
(
const
platform
::
CPUDeviceContext
&
context
,
static
void
compute
(
const
platform
::
CPUDeviceContext
&
context
,
GRUMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
GRUMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
const
detail
::
ActivationType
active_node
,
const
detail
::
ActivationType
active_node
,
const
detail
::
ActivationType
active_gate
)
{
const
detail
::
ActivationType
active_gate
,
bool
origin_mode
)
{
#ifndef __NVCC__
#ifndef __NVCC__
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
context
);
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
context
);
if
(
value
.
prev_out_value
)
{
if
(
value
.
prev_out_value
)
{
...
@@ -43,7 +44,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
...
@@ -43,7 +44,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
}
}
detail
::
forward_final_output
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
,
detail
::
forward_final_output
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
,
frame_size
,
batch_size
,
active_node
);
frame_size
,
batch_size
,
active_node
,
origin_mode
);
#endif
#endif
}
}
};
};
...
@@ -54,7 +56,8 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
...
@@ -54,7 +56,8 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
GRUMetaValue
<
T
>
value
,
GRUMetaGrad
<
T
>
grad
,
GRUMetaValue
<
T
>
value
,
GRUMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
int
frame_size
,
int
batch_size
,
const
detail
::
ActivationType
active_node
,
const
detail
::
ActivationType
active_node
,
const
detail
::
ActivationType
active_gate
)
{
const
detail
::
ActivationType
active_gate
,
bool
origin_mode
)
{
#ifndef __NVCC__
#ifndef __NVCC__
detail
::
backward_state_grad
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
,
detail
::
backward_state_grad
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
,
grad
,
frame_size
,
batch_size
,
active_node
);
grad
,
frame_size
,
batch_size
,
active_node
);
...
...
paddle/fluid/operators/math/gru_compute.cu
浏览文件 @
3e1b914f
...
@@ -24,7 +24,8 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
...
@@ -24,7 +24,8 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
static
void
compute
(
const
platform
::
CUDADeviceContext
&
context
,
static
void
compute
(
const
platform
::
CUDADeviceContext
&
context
,
GRUMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
GRUMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
const
detail
::
ActivationType
active_node
,
const
detail
::
ActivationType
active_node
,
const
detail
::
ActivationType
active_gate
)
{
const
detail
::
ActivationType
active_gate
,
bool
origin_mode
)
{
auto
stream
=
context
.
stream
();
auto
stream
=
context
.
stream
();
dim3
threads
;
dim3
threads
;
dim3
grid
;
dim3
grid
;
...
@@ -73,14 +74,14 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
...
@@ -73,14 +74,14 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate_value
,
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate_value
,
value
.
prev_out_value
,
value
.
output_value
,
frame_size
,
batch_size
,
value
.
prev_out_value
,
value
.
output_value
,
frame_size
,
batch_size
,
active_node
);
active_node
,
origin_mode
);
}
else
{
}
else
{
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
/* is_batch= */
true
,
/* is_batch= */
true
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate_value
,
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate_value
,
value
.
prev_out_value
,
value
.
output_value
,
frame_size
,
batch_size
,
value
.
prev_out_value
,
value
.
output_value
,
frame_size
,
batch_size
,
active_node
);
active_node
,
origin_mode
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/math/gru_compute.h
浏览文件 @
3e1b914f
...
@@ -44,7 +44,8 @@ struct GRUUnitFunctor {
...
@@ -44,7 +44,8 @@ struct GRUUnitFunctor {
static
void
compute
(
const
DeviceContext
&
context
,
GRUMetaValue
<
T
>
value
,
static
void
compute
(
const
DeviceContext
&
context
,
GRUMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
int
frame_size
,
int
batch_size
,
const
detail
::
ActivationType
active_node
,
const
detail
::
ActivationType
active_node
,
const
detail
::
ActivationType
active_gate
);
const
detail
::
ActivationType
active_gate
,
bool
origin_mode
);
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
...
@@ -52,7 +53,8 @@ struct GRUUnitGradFunctor {
...
@@ -52,7 +53,8 @@ struct GRUUnitGradFunctor {
static
void
compute
(
const
DeviceContext
&
context
,
GRUMetaValue
<
T
>
value
,
static
void
compute
(
const
DeviceContext
&
context
,
GRUMetaValue
<
T
>
value
,
GRUMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
GRUMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
const
detail
::
ActivationType
active_node
,
const
detail
::
ActivationType
active_node
,
const
detail
::
ActivationType
active_gate
);
const
detail
::
ActivationType
active_gate
,
bool
origin_mode
);
};
};
}
// namespace math
}
// namespace math
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录