Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
3e552cdc
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3e552cdc
编写于
11月 29, 2017
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix gru_op related code style
上级
dcf3ffd9
变更
7
展开全部
显示空白变更内容
内联
并排
Showing
7 changed file
with
617 addition
and
599 deletion
+617
-599
paddle/operators/gru_op.h
paddle/operators/gru_op.h
+23
-23
paddle/operators/math/detail/gru_cpu_kernel.h
paddle/operators/math/detail/gru_cpu_kernel.h
+272
-268
paddle/operators/math/detail/gru_gpu_kernel.h
paddle/operators/math/detail/gru_gpu_kernel.h
+126
-126
paddle/operators/math/detail/gru_kernel.h
paddle/operators/math/detail/gru_kernel.h
+72
-63
paddle/operators/math/gru_compute.cc
paddle/operators/math/gru_compute.cc
+33
-31
paddle/operators/math/gru_compute.cu
paddle/operators/math/gru_compute.cu
+75
-73
paddle/operators/math/gru_compute.h
paddle/operators/math/gru_compute.h
+16
-15
未找到文件。
paddle/operators/gru_op.h
浏览文件 @
3e552cdc
...
@@ -71,8 +71,8 @@ class GRUKernel : public framework::OpKernel<T> {
...
@@ -71,8 +71,8 @@ class GRUKernel : public framework::OpKernel<T> {
int
frame_size
=
hidden_dims
[
1
];
int
frame_size
=
hidden_dims
[
1
];
math
::
hl_gru_value
<
T
>
gru_value
;
math
::
hl_gru_value
<
T
>
gru_value
;
gru_value
.
gate
W
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
gate
_w
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state
W
eight
=
gru_value
.
state
_w
eight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
Tensor
ordered_h0
;
Tensor
ordered_h0
;
const
size_t
*
order
=
batch_gate
->
lod
()[
2
].
data
();
const
size_t
*
order
=
batch_gate
->
lod
()[
2
].
data
();
...
@@ -82,9 +82,9 @@ class GRUKernel : public framework::OpKernel<T> {
...
@@ -82,9 +82,9 @@ class GRUKernel : public framework::OpKernel<T> {
// to reorder.
// to reorder.
ReorderInitState
<
Place
,
T
>
(
context
.
device_context
(),
*
h0
,
order
,
ReorderInitState
<
Place
,
T
>
(
context
.
device_context
(),
*
h0
,
order
,
&
ordered_h0
,
true
);
&
ordered_h0
,
true
);
gru_value
.
prev
OutV
alue
=
ordered_h0
.
data
<
T
>
();
gru_value
.
prev
_out_v
alue
=
ordered_h0
.
data
<
T
>
();
}
else
{
}
else
{
gru_value
.
prev
OutV
alue
=
nullptr
;
gru_value
.
prev
_out_v
alue
=
nullptr
;
}
}
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
...
@@ -96,14 +96,14 @@ class GRUKernel : public framework::OpKernel<T> {
...
@@ -96,14 +96,14 @@ class GRUKernel : public framework::OpKernel<T> {
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
gru_value
.
output
V
alue
=
hidden_t
.
data
<
T
>
();
gru_value
.
output
_v
alue
=
hidden_t
.
data
<
T
>
();
gru_value
.
gate
V
alue
=
gate_t
.
data
<
T
>
();
gru_value
.
gate
_v
alue
=
gate_t
.
data
<
T
>
();
gru_value
.
reset
OutputV
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
gru_value
.
reset
_output_v
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
math
::
GRUUnitFunctor
<
Place
,
T
>::
compute
(
math
::
GRUUnitFunctor
<
Place
,
T
>::
compute
(
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"activation"
)),
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"activation"
)),
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"gate_activation"
)));
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"gate_activation"
)));
gru_value
.
prev
OutValue
=
gru_value
.
outputV
alue
;
gru_value
.
prev
_out_value
=
gru_value
.
output_v
alue
;
}
}
math
::
Batch2LoDTensorFunctor
<
Place
,
T
>
to_seq
;
math
::
Batch2LoDTensorFunctor
<
Place
,
T
>
to_seq
;
...
@@ -169,20 +169,20 @@ class GRUGradKernel : public framework::OpKernel<T> {
...
@@ -169,20 +169,20 @@ class GRUGradKernel : public framework::OpKernel<T> {
to_batch
(
dev_ctx
,
*
hidden_grad
,
batch_hidden_grad
,
false
,
is_reverse
);
to_batch
(
dev_ctx
,
*
hidden_grad
,
batch_hidden_grad
,
false
,
is_reverse
);
math
::
hl_gru_value
<
T
>
gru_value
;
math
::
hl_gru_value
<
T
>
gru_value
;
gru_value
.
gate
W
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
gate
_w
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state
W
eight
=
gru_value
.
state
_w
eight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
math
::
hl_gru_grad
<
T
>
gru_grad
;
math
::
hl_gru_grad
<
T
>
gru_grad
;
if
(
weight_grad
)
{
if
(
weight_grad
)
{
gru_grad
.
gate
WeightG
rad
=
gru_grad
.
gate
_weight_g
rad
=
weight_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
weight_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
zero
(
dev_ctx
,
weight_grad
,
static_cast
<
T
>
(
0.0
));
zero
(
dev_ctx
,
weight_grad
,
static_cast
<
T
>
(
0.0
));
gru_grad
.
state
WeightG
rad
=
gru_grad
.
state
_weight_g
rad
=
weight_grad
->
data
<
T
>
()
+
2
*
frame_size
*
frame_size
;
weight_grad
->
data
<
T
>
()
+
2
*
frame_size
*
frame_size
;
}
else
{
}
else
{
gru_grad
.
gate
WeightG
rad
=
nullptr
;
gru_grad
.
gate
_weight_g
rad
=
nullptr
;
gru_grad
.
state
WeightG
rad
=
nullptr
;
gru_grad
.
state
_weight_g
rad
=
nullptr
;
}
}
auto
batch_starts
=
batch_hidden_grad
.
lod
()[
0
];
auto
batch_starts
=
batch_hidden_grad
.
lod
()[
0
];
...
@@ -193,27 +193,27 @@ class GRUGradKernel : public framework::OpKernel<T> {
...
@@ -193,27 +193,27 @@ class GRUGradKernel : public framework::OpKernel<T> {
int
cur_batch_size
=
bend
-
bstart
;
int
cur_batch_size
=
bend
-
bstart
;
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
gru_value
.
gate
V
alue
=
gate_t
.
data
<
T
>
();
gru_value
.
gate
_v
alue
=
gate_t
.
data
<
T
>
();
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
gru_value
.
reset
OutputV
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
gru_value
.
reset
_output_v
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
Tensor
hidden_grad_t
=
batch_hidden_grad
.
Slice
(
bstart
,
bend
);
Tensor
hidden_grad_t
=
batch_hidden_grad
.
Slice
(
bstart
,
bend
);
gru_grad
.
output
G
rad
=
hidden_grad_t
.
data
<
T
>
();
gru_grad
.
output
_g
rad
=
hidden_grad_t
.
data
<
T
>
();
Tensor
gate_grad_t
=
batch_gate_grad
.
Slice
(
bstart
,
bend
);
Tensor
gate_grad_t
=
batch_gate_grad
.
Slice
(
bstart
,
bend
);
gru_grad
.
gate
G
rad
=
gate_grad_t
.
data
<
T
>
();
gru_grad
.
gate
_g
rad
=
gate_grad_t
.
data
<
T
>
();
Tensor
reset_hidden_prev_grad_t
=
Tensor
reset_hidden_prev_grad_t
=
batch_reset_hidden_prev_grad
.
Slice
(
bstart
,
bend
);
batch_reset_hidden_prev_grad
.
Slice
(
bstart
,
bend
);
gru_grad
.
reset
OutputG
rad
=
reset_hidden_prev_grad_t
.
data
<
T
>
();
gru_grad
.
reset
_output_g
rad
=
reset_hidden_prev_grad_t
.
data
<
T
>
();
if
(
n
==
0
)
{
if
(
n
==
0
)
{
gru_value
.
prev
OutV
alue
=
h0
?
ordered_h0
.
data
<
T
>
()
:
nullptr
;
gru_value
.
prev
_out_v
alue
=
h0
?
ordered_h0
.
data
<
T
>
()
:
nullptr
;
gru_grad
.
prev
OutG
rad
=
gru_grad
.
prev
_out_g
rad
=
h0
&&
h0_grad
?
ordered_h0_grad
.
data
<
T
>
()
:
nullptr
;
h0
&&
h0_grad
?
ordered_h0_grad
.
data
<
T
>
()
:
nullptr
;
}
else
{
}
else
{
int
bstart_pre
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
bstart_pre
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
Tensor
hidden_prev_t
=
batch_hidden
->
Slice
(
bstart_pre
,
bstart
);
Tensor
hidden_prev_t
=
batch_hidden
->
Slice
(
bstart_pre
,
bstart
);
gru_value
.
prev
OutV
alue
=
hidden_prev_t
.
data
<
T
>
();
gru_value
.
prev
_out_v
alue
=
hidden_prev_t
.
data
<
T
>
();
Tensor
hidden_prev_grad_t
=
batch_hidden_grad
.
Slice
(
bstart_pre
,
bstart
);
Tensor
hidden_prev_grad_t
=
batch_hidden_grad
.
Slice
(
bstart_pre
,
bstart
);
gru_grad
.
prev
OutG
rad
=
hidden_prev_grad_t
.
data
<
T
>
();
gru_grad
.
prev
_out_g
rad
=
hidden_prev_grad_t
.
data
<
T
>
();
}
}
math
::
GRUUnitGradFunctor
<
Place
,
T
>::
compute
(
math
::
GRUUnitGradFunctor
<
Place
,
T
>::
compute
(
...
...
paddle/operators/math/detail/gru_cpu_kernel.h
浏览文件 @
3e552cdc
此差异已折叠。
点击以展开。
paddle/operators/math/detail/gru_gpu_kernel.h
浏览文件 @
3e552cdc
...
@@ -27,174 +27,174 @@ namespace math {
...
@@ -27,174 +27,174 @@ namespace math {
namespace
detail
{
namespace
detail
{
/*
/*
* threads(frame
PerBlock, batchPerB
lock)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
Blocks, batchB
locks)
* grid(frame
_blocks, batch_b
locks)
*/
*/
template
<
class
OpResetOutput
,
bool
is
B
atch
,
typename
T
>
template
<
class
OpResetOutput
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruForwardResetOutput
(
OpResetOutput
op
ResetO
utput
,
__global__
void
KeGruForwardResetOutput
(
OpResetOutput
op
_reset_o
utput
,
T
*
gate
Value
,
T
*
resetOutputV
alue
,
T
*
gate
_value
,
T
*
reset_output_v
alue
,
T
*
prev
OutputValue
,
int
frameS
ize
,
T
*
prev
_output_value
,
int
frame_s
ize
,
int
batch
S
ize
,
int
batch
_s
ize
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
I
dx
=
0
;
int
batch
_i
dx
=
0
;
if
(
is
B
atch
)
{
if
(
is
_b
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
reset
OutputValue
+=
batchIdx
*
frameS
ize
;
reset
_output_value
+=
batch_idx
*
frame_s
ize
;
}
}
T
r
PrevO
ut
=
0
;
T
r
_prev_o
ut
=
0
;
T
r
ValueResetO
utput
;
T
r
_value_reset_o
utput
;
T
r
ValueUpdateGate
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
_value_update_gate
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
ValueResetGate
=
gateValue
[
frameIdx
+
frameS
ize
*
1
];
T
r
_value_reset_gate
=
gate_value
[
frame_idx
+
frame_s
ize
*
1
];
if
(
prev
OutputV
alue
)
{
if
(
prev
_output_v
alue
)
{
if
(
is
Batch
)
prevOutputValue
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_output_value
+=
batch_idx
*
frame_s
ize
;
r
PrevOut
=
prevOutputValue
[
frameI
dx
];
r
_prev_out
=
prev_output_value
[
frame_i
dx
];
}
}
op
ResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevOut
,
rValueResetOutp
ut
,
op
_reset_output
(
r_value_update_gate
,
r_value_reset_gate
,
r_prev_o
ut
,
active_gate
);
r_value_reset_output
,
active_gate
);
gate
Value
[
frameIdx
+
frameSize
*
0
]
=
rValueUpdateG
ate
;
gate
_value
[
frame_idx
+
frame_size
*
0
]
=
r_value_update_g
ate
;
gate
Value
[
frameIdx
+
frameSize
*
1
]
=
rValueResetG
ate
;
gate
_value
[
frame_idx
+
frame_size
*
1
]
=
r_value_reset_g
ate
;
reset
OutputValue
[
frameIdx
]
=
rValueResetO
utput
;
reset
_output_value
[
frame_idx
]
=
r_value_reset_o
utput
;
}
}
/*
/*
* threads(frame
PerBlock, batchPerB
lock)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
Blocks, batchB
locks)
* grid(frame
_blocks, batch_b
locks)
*/
*/
template
<
class
OpFinalOutput
,
bool
is
B
atch
,
typename
T
>
template
<
class
OpFinalOutput
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruForwardFinalOutput
(
OpFinalOutput
op
FinalO
utput
,
__global__
void
KeGruForwardFinalOutput
(
OpFinalOutput
op
_final_o
utput
,
T
*
gate
Value
,
T
*
prevOutputV
alue
,
T
*
gate
_value
,
T
*
prev_output_v
alue
,
T
*
output
Value
,
int
frameS
ize
,
T
*
output
_value
,
int
frame_s
ize
,
int
batch
S
ize
,
int
batch
_s
ize
,
activation_mode_t
active_node
)
{
activation_mode_t
active_node
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
I
dx
=
0
;
int
batch
_i
dx
=
0
;
if
(
is
B
atch
)
{
if
(
is
_b
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
output
Value
+=
batchIdx
*
frameS
ize
;
output
_value
+=
batch_idx
*
frame_s
ize
;
}
}
T
r
O
utput
;
T
r
_o
utput
;
T
r
PrevO
ut
=
0
;
T
r
_prev_o
ut
=
0
;
T
r
ValueUpdateGate
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
_value_update_gate
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
ValueFrameState
=
gateValue
[
frameIdx
+
frameS
ize
*
2
];
T
r
_value_frame_state
=
gate_value
[
frame_idx
+
frame_s
ize
*
2
];
if
(
prev
OutputV
alue
)
{
if
(
prev
_output_v
alue
)
{
if
(
is
Batch
)
prevOutputValue
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_output_value
+=
batch_idx
*
frame_s
ize
;
r
PrevOut
=
prevOutputValue
[
frameI
dx
];
r
_prev_out
=
prev_output_value
[
frame_i
dx
];
}
}
op
FinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutp
ut
,
op
_final_output
(
r_value_update_gate
,
r_value_frame_state
,
r_prev_o
ut
,
active_node
);
r_output
,
active_node
);
gate
Value
[
frameIdx
+
frameSize
*
2
]
=
rValueFrameS
tate
;
gate
_value
[
frame_idx
+
frame_size
*
2
]
=
r_value_frame_s
tate
;
output
Value
[
frameIdx
]
=
rO
utput
;
output
_value
[
frame_idx
]
=
r_o
utput
;
}
}
/*
/*
* threads(frame
PerBlock, batchPerB
lock)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
Blocks, batchB
locks)
* grid(frame
_blocks, batch_b
locks)
*/
*/
template
<
class
OpStateGrad
,
bool
is
B
atch
,
typename
T
>
template
<
class
OpStateGrad
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruBackwardStateGrad
(
OpStateGrad
op
StateGrad
,
T
*
gateV
alue
,
__global__
void
KeGruBackwardStateGrad
(
OpStateGrad
op
_state_grad
,
T
*
gate_v
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
OutGrad
,
T
*
outputG
rad
,
T
*
prev
_out_grad
,
T
*
output_g
rad
,
int
frame
Size
,
int
batchS
ize
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
)
{
activation_mode_t
active_node
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
I
dx
=
0
;
int
batch
_i
dx
=
0
;
if
(
is
B
atch
)
{
if
(
is
_b
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
gate
Grad
+=
batchIdx
*
3
*
frameS
ize
;
gate
_grad
+=
batch_idx
*
3
*
frame_s
ize
;
output
Grad
+=
batchIdx
*
frameS
ize
;
output
_grad
+=
batch_idx
*
frame_s
ize
;
}
}
T
r
UpdateGateG
rad
;
T
r
_update_gate_g
rad
;
T
r
FrameStateG
rad
;
T
r
_frame_state_g
rad
;
T
r
PrevOutV
alue
=
0
;
T
r
_prev_out_v
alue
=
0
;
T
r
PrevOutG
rad
=
0
;
T
r
_prev_out_g
rad
=
0
;
T
r
UpdateGateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
_update_gate_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
FrameStateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
2
];
T
r
_frame_state_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
2
];
T
r
OutGrad
=
outputGrad
[
frameI
dx
];
T
r
_out_grad
=
output_grad
[
frame_i
dx
];
if
(
prev
OutValue
&&
prevOutG
rad
)
{
if
(
prev
_out_value
&&
prev_out_g
rad
)
{
if
(
is
Batch
)
prevOutValue
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_out_value
+=
batch_idx
*
frame_s
ize
;
r
PrevOutValue
=
prevOutValue
[
frameI
dx
];
r
_prev_out_value
=
prev_out_value
[
frame_i
dx
];
if
(
is
Batch
)
prevOutGrad
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_out_grad
+=
batch_idx
*
frame_s
ize
;
r
PrevOutGrad
=
prevOutGrad
[
frameI
dx
];
r
_prev_out_grad
=
prev_out_grad
[
frame_i
dx
];
}
}
op
StateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateV
alue
,
op
_state_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_frame_state_v
alue
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutG
rad
,
r_frame_state_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
active_node
);
r_out_grad
,
active_node
);
gate
Grad
[
frameIdx
+
frameSize
*
0
]
=
rUpdateGateG
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
0
]
=
r_update_gate_g
rad
;
gate
Grad
[
frameIdx
+
frameSize
*
2
]
=
rFrameStateG
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
2
]
=
r_frame_state_g
rad
;
if
(
prev
OutG
rad
)
{
if
(
prev
_out_g
rad
)
{
prev
OutGrad
[
frameIdx
]
=
rPrevOutG
rad
;
prev
_out_grad
[
frame_idx
]
=
r_prev_out_g
rad
;
}
}
}
}
/*
/*
* threads(frame
PerBlock, batchPerB
lock)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
Blocks, batchB
locks)
* grid(frame
_blocks, batch_b
locks)
*/
*/
template
<
class
OpResetGrad
,
bool
is
B
atch
,
typename
T
>
template
<
class
OpResetGrad
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruBackwardResetGrad
(
OpResetGrad
op
ResetGrad
,
T
*
gateV
alue
,
__global__
void
KeGruBackwardResetGrad
(
OpResetGrad
op
_reset_grad
,
T
*
gate_v
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
OutGrad
,
T
*
resetOutputG
rad
,
T
*
prev
_out_grad
,
T
*
reset_output_g
rad
,
int
frame
Size
,
int
batchS
ize
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
I
dx
=
0
;
int
batch
_i
dx
=
0
;
if
(
is
B
atch
)
{
if
(
is
_b
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
gate
Grad
+=
batchIdx
*
3
*
frameS
ize
;
gate
_grad
+=
batch_idx
*
3
*
frame_s
ize
;
reset
OutputGrad
+=
batchIdx
*
frameS
ize
;
reset
_output_grad
+=
batch_idx
*
frame_s
ize
;
}
}
T
r
ResetGateG
rad
;
T
r
_reset_gate_g
rad
;
T
r
PrevOutV
alue
=
0
;
T
r
_prev_out_v
alue
=
0
;
T
r
PrevOutG
rad
=
0
;
T
r
_prev_out_g
rad
=
0
;
T
r
ResetOutputG
rad
=
0
;
T
r
_reset_output_g
rad
=
0
;
T
r
UpdateGateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
_update_gate_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
UpdateGateGrad
=
gateGrad
[
frameIdx
+
frameS
ize
*
0
];
T
r
_update_gate_grad
=
gate_grad
[
frame_idx
+
frame_s
ize
*
0
];
T
r
ResetGateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
1
];
T
r
_reset_gate_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
1
];
if
(
prev
OutValue
&&
prevOutG
rad
)
{
if
(
prev
_out_value
&&
prev_out_g
rad
)
{
if
(
is
Batch
)
prevOutValue
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_out_value
+=
batch_idx
*
frame_s
ize
;
if
(
is
Batch
)
prevOutGrad
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_out_grad
+=
batch_idx
*
frame_s
ize
;
r
PrevOutValue
=
prevOutValue
[
frameI
dx
];
r
_prev_out_value
=
prev_out_value
[
frame_i
dx
];
r
PrevOutGrad
=
prevOutGrad
[
frameI
dx
];
r
_prev_out_grad
=
prev_out_grad
[
frame_i
dx
];
r
ResetOutputGrad
=
resetOutputGrad
[
frameI
dx
];
r
_reset_output_grad
=
reset_output_grad
[
frame_i
dx
];
}
}
op
ResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateV
alue
,
op
_reset_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_reset_gate_v
alue
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputG
rad
,
r_reset_gate_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
active_gate
);
r_reset_output_grad
,
active_gate
);
gate
Grad
[
frameIdx
+
frameSize
*
0
]
=
rUpdateGateG
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
0
]
=
r_update_gate_g
rad
;
gate
Grad
[
frameIdx
+
frameSize
*
1
]
=
rResetGateG
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
1
]
=
r_reset_gate_g
rad
;
if
(
prev
OutG
rad
)
{
if
(
prev
_out_g
rad
)
{
prev
OutGrad
[
frameIdx
]
=
rPrevOutG
rad
;
prev
_out_grad
[
frame_idx
]
=
r_prev_out_g
rad
;
}
}
}
}
}
// namespace detail
}
// namespace detail
...
...
paddle/operators/math/detail/gru_kernel.h
浏览文件 @
3e552cdc
...
@@ -28,23 +28,25 @@ namespace forward {
...
@@ -28,23 +28,25 @@ namespace forward {
template
<
typename
T
>
template
<
typename
T
>
class
gru_resetOutput
{
class
gru_resetOutput
{
public:
public:
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
valueResetGate
,
T
&
prevOut
,
HOSTDEVICE
void
operator
()(
T
&
value_update_gate
,
T
&
value_reset_gate
,
T
&
valueResetOutput
,
activation_mode_t
actGate
)
{
T
&
prev_out
,
T
&
value_reset_output
,
valueUpdateGate
=
activation
(
valueUpdateGate
,
actGate
);
activation_mode_t
act_gate
)
{
valueResetGate
=
activation
(
valueResetGate
,
actGate
);
value_update_gate
=
activation
(
value_update_gate
,
act_gate
);
valueResetOutput
=
prevOut
*
valueResetGate
;
value_reset_gate
=
activation
(
value_reset_gate
,
act_gate
);
value_reset_output
=
prev_out
*
value_reset_gate
;
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
static
const
bool
avx
=
false
;
static
const
bool
avx
=
false
;
#else
#else
static
const
bool
avx
=
true
;
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
valueResetGate
,
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
prevOut
,
__m256
&
valueResetOutput
,
__m256
&
value_reset_gate
,
__m256
&
prev_out
,
activation_mode_t
actGate
)
{
__m256
&
value_reset_output
,
valueUpdateGate
=
activation
(
valueUpdateGate
,
actGate
);
activation_mode_t
act_gate
)
{
valueResetGate
=
activation
(
valueResetGate
,
actGate
);
value_update_gate
=
activation
(
value_update_gate
,
act_gate
);
valueResetOutput
=
_mm256_mul_ps
(
prevOut
,
valueResetGate
);
value_reset_gate
=
activation
(
value_reset_gate
,
act_gate
);
value_reset_output
=
_mm256_mul_ps
(
prev_out
,
value_reset_gate
);
}
}
#endif
#endif
#endif
#endif
...
@@ -53,24 +55,26 @@ class gru_resetOutput {
...
@@ -53,24 +55,26 @@ class gru_resetOutput {
template
<
typename
T
>
template
<
typename
T
>
class
gru_finalOutput
{
class
gru_finalOutput
{
public:
public:
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
valueFrameState
,
T
&
prevOut
,
HOSTDEVICE
void
operator
()(
T
&
value_update_gate
,
T
&
value_frame_state
,
T
&
valueOutput
,
activation_mode_t
actInput
)
{
T
&
prev_out
,
T
&
value_output
,
valueFrameState
=
activation
(
valueFrameState
,
actInput
);
activation_mode_t
act_input
)
{
valueOutput
=
prevOut
-
(
valueUpdateGate
*
prevOut
)
+
value_frame_state
=
activation
(
value_frame_state
,
act_input
);
(
valueUpdateGate
*
valueFrameState
);
value_output
=
prev_out
-
(
value_update_gate
*
prev_out
)
+
(
value_update_gate
*
value_frame_state
);
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
static
const
bool
avx
=
false
;
static
const
bool
avx
=
false
;
#else
#else
static
const
bool
avx
=
true
;
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
valueFrameState
,
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
prevOut
,
__m256
&
valueOutput
,
__m256
&
value_frame_state
,
__m256
&
prev_out
,
activation_mode_t
actInput
)
{
__m256
&
value_output
,
valueFrameState
=
activation
(
valueFrameState
,
actInput
);
activation_mode_t
act_input
)
{
valueOutput
=
_mm256_add_ps
(
value_frame_state
=
activation
(
value_frame_state
,
act_input
);
_mm256_sub_ps
(
prevOut
,
_mm256_mul_ps
(
valueUpdateGate
,
prevOut
)),
value_output
=
_mm256_add_ps
(
_mm256_mul_ps
(
valueUpdateGate
,
valueFrameState
));
_mm256_sub_ps
(
prev_out
,
_mm256_mul_ps
(
value_update_gate
,
prev_out
)),
_mm256_mul_ps
(
value_update_gate
,
value_frame_state
));
}
}
#endif
#endif
#endif
#endif
...
@@ -82,34 +86,37 @@ namespace backward {
...
@@ -82,34 +86,37 @@ namespace backward {
template
<
typename
T
>
template
<
typename
T
>
class
gru_stateGrad
{
class
gru_stateGrad
{
public:
public:
HOSTDEVICE
void
operator
()(
T
&
value
UpdateGate
,
T
&
gradUpdateG
ate
,
HOSTDEVICE
void
operator
()(
T
&
value
_update_gate
,
T
&
grad_update_g
ate
,
T
&
value
FrameState
,
T
&
gradFrameS
tate
,
T
&
value
_frame_state
,
T
&
grad_frame_s
tate
,
T
&
value
PrevOut
,
T
&
gradPrevOut
,
T
&
gradOutp
ut
,
T
&
value
_prev_out
,
T
&
grad_prev_o
ut
,
activation_mode_t
actI
nput
)
{
T
&
grad_output
,
activation_mode_t
act_i
nput
)
{
grad
UpdateGate
=
(
gradOutput
*
valueFrameS
tate
);
grad
_update_gate
=
(
grad_output
*
value_frame_s
tate
);
grad
UpdateGate
-=
(
gradOutput
*
valuePrevO
ut
);
grad
_update_gate
-=
(
grad_output
*
value_prev_o
ut
);
grad
PrevOut
-=
(
gradOutput
*
valueUpdateG
ate
);
grad
_prev_out
-=
(
grad_output
*
value_update_g
ate
);
grad
PrevOut
+=
gradO
utput
;
grad
_prev_out
+=
grad_o
utput
;
grad
FrameState
=
grad
_frame_state
=
activation
(
grad_output
*
value_update_gate
,
activation
(
gradOutput
*
valueUpdateGate
,
valueFrameState
,
actI
nput
);
value_frame_state
,
act_i
nput
);
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
static
const
bool
avx
=
false
;
static
const
bool
avx
=
false
;
#else
#else
static
const
bool
avx
=
true
;
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
gradUpdateGate
,
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
valueFrameState
,
__m256
&
gradFrameState
,
__m256
&
grad_update_gate
,
__m256
&
valuePrevOut
,
__m256
&
gradPrevOut
,
__m256
&
value_frame_state
,
__m256
&
gradOutput
,
activation_mode_t
actInput
)
{
__m256
&
grad_frame_state
,
__m256
&
value_prev_out
,
gradUpdateGate
=
_mm256_mul_ps
(
gradOutput
,
valueFrameState
);
__m256
&
grad_prev_out
,
__m256
&
grad_output
,
gradUpdateGate
=
activation_mode_t
act_input
)
{
_mm256_sub_ps
(
gradUpdateGate
,
_mm256_mul_ps
(
gradOutput
,
valuePrevOut
));
grad_update_gate
=
_mm256_mul_ps
(
grad_output
,
value_frame_state
);
gradPrevOut
=
_mm256_add_ps
(
grad_update_gate
=
_mm256_sub_ps
(
_mm256_sub_ps
(
gradPrevOut
,
_mm256_mul_ps
(
gradOutput
,
valueUpdateGate
)),
grad_update_gate
,
_mm256_mul_ps
(
grad_output
,
value_prev_out
));
gradOutput
);
grad_prev_out
=
_mm256_add_ps
(
gradFrameState
=
activation
(
_mm256_mul_ps
(
gradOutput
,
valueUpdateGate
),
_mm256_sub_ps
(
grad_prev_out
,
valueFrameState
,
actInput
);
_mm256_mul_ps
(
grad_output
,
value_update_gate
)),
grad_output
);
grad_frame_state
=
activation
(
_mm256_mul_ps
(
grad_output
,
value_update_gate
),
value_frame_state
,
act_input
);
}
}
#endif
#endif
#endif
#endif
...
@@ -118,30 +125,32 @@ class gru_stateGrad {
...
@@ -118,30 +125,32 @@ class gru_stateGrad {
template
<
typename
T
>
template
<
typename
T
>
class
gru_resetGrad
{
class
gru_resetGrad
{
public:
public:
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
gradUpdateGate
,
HOSTDEVICE
void
operator
()(
T
&
value_update_gate
,
T
&
grad_update_gate
,
T
&
valueResetGate
,
T
&
gradResetGate
,
T
&
value_reset_gate
,
T
&
grad_reset_gate
,
T
&
valuePrevOut
,
T
&
gradPrevOut
,
T
&
value_prev_out
,
T
&
grad_prev_out
,
T
&
gradResetOutput
,
activation_mode_t
actGate
)
{
T
&
grad_reset_output
,
activation_mode_t
act_gate
)
{
gradResetGate
=
(
gradResetOutput
*
valuePrevOut
);
grad_reset_gate
=
(
grad_reset_output
*
value_prev_out
);
gradPrevOut
+=
(
gradResetOutput
*
valueResetGate
);
grad_prev_out
+=
(
grad_reset_output
*
value_reset_gate
);
gradUpdateGate
=
activation
(
gradUpdateGate
,
valueUpdateGate
,
actGate
);
grad_update_gate
=
gradResetGate
=
activation
(
gradResetGate
,
valueResetGate
,
actGate
);
activation
(
grad_update_gate
,
value_update_gate
,
act_gate
);
grad_reset_gate
=
activation
(
grad_reset_gate
,
value_reset_gate
,
act_gate
);
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
static
const
bool
avx
=
false
;
static
const
bool
avx
=
false
;
#else
#else
static
const
bool
avx
=
true
;
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
gradUpdateGate
,
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
valueResetGate
,
__m256
&
gradResetGate
,
__m256
&
grad_update_gate
,
__m256
&
value_reset_gate
,
__m256
&
valuePrevOut
,
__m256
&
gradPrevOut
,
__m256
&
grad_reset_gate
,
__m256
&
value_prev_out
,
__m256
&
gradResetOutput
,
__m256
&
grad_prev_out
,
__m256
&
grad_reset_output
,
activation_mode_t
actGate
)
{
activation_mode_t
act_gate
)
{
gradResetGate
=
_mm256_mul_ps
(
gradResetOutput
,
valuePrevOut
);
grad_reset_gate
=
_mm256_mul_ps
(
grad_reset_output
,
value_prev_out
);
gradPrevOut
=
_mm256_add_ps
(
gradPrevOut
,
grad_prev_out
=
_mm256_add_ps
(
_mm256_mul_ps
(
gradResetOutput
,
valueResetGate
));
grad_prev_out
,
_mm256_mul_ps
(
grad_reset_output
,
value_reset_gate
));
gradUpdateGate
=
activation
(
gradUpdateGate
,
valueUpdateGate
,
actGate
);
grad_update_gate
=
gradResetGate
=
activation
(
gradResetGate
,
valueResetGate
,
actGate
);
activation
(
grad_update_gate
,
value_update_gate
,
act_gate
);
grad_reset_gate
=
activation
(
grad_reset_gate
,
value_reset_gate
,
act_gate
);
}
}
#endif
#endif
#endif
#endif
...
...
paddle/operators/math/gru_compute.cc
浏览文件 @
3e552cdc
...
@@ -21,29 +21,29 @@ namespace math {
...
@@ -21,29 +21,29 @@ namespace math {
template
<
typename
T
>
template
<
typename
T
>
struct
GRUUnitFunctor
<
platform
::
CPUPlace
,
T
>
{
struct
GRUUnitFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
int
frame
Size
,
int
batchS
ize
,
hl_gru_value
<
T
>
value
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
#ifndef __NVCC__
#ifndef __NVCC__
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
*
2
,
frameS
ize
,
1
,
context
,
false
,
false
,
batch
_size
,
frame_size
*
2
,
frame_s
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
value
.
gateWeight
,
frameSize
*
2
,
1
,
value
.
prev
_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
value
.
gateValue
,
frameS
ize
*
3
);
1
,
value
.
gate_value
,
frame_s
ize
*
3
);
}
}
detail
::
forward_reset_output
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
,
detail
::
forward_reset_output
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
,
frame
Size
,
batchS
ize
,
active_gate
);
frame
_size
,
batch_s
ize
,
active_gate
);
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
context
,
false
,
false
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
value
.
reset
OutputValue
,
frameSize
,
value
.
stateWeight
,
frameSize
,
1
,
value
.
reset
_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
value
.
gateValue
+
frameSize
*
2
,
frameS
ize
*
3
);
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_s
ize
*
3
);
}
}
detail
::
forward_final_output
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
,
detail
::
forward_final_output
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
,
frame
Size
,
batchS
ize
,
active_node
);
frame
_size
,
batch_s
ize
,
active_node
);
#endif
#endif
}
}
};
};
...
@@ -51,41 +51,43 @@ struct GRUUnitFunctor<platform::CPUPlace, T> {
...
@@ -51,41 +51,43 @@ struct GRUUnitFunctor<platform::CPUPlace, T> {
template
<
typename
T
>
template
<
typename
T
>
struct
GRUUnitGradFunctor
<
platform
::
CPUPlace
,
T
>
{
struct
GRUUnitGradFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
batchSize
,
activation_mode_t
active_node
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
#ifndef __NVCC__
#ifndef __NVCC__
detail
::
backward_state_grad
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
,
detail
::
backward_state_grad
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
,
grad
,
frame
Size
,
batchS
ize
,
active_node
);
grad
,
frame
_size
,
batch_s
ize
,
active_node
);
if
(
value
.
prev
OutValue
&&
grad
.
prevOutG
rad
)
{
if
(
value
.
prev
_out_value
&&
grad
.
prev_out_g
rad
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
grad
.
gate
Grad
+
frameSize
*
2
,
frameSize
*
3
,
value
.
stateW
eight
,
grad
.
gate
_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_w
eight
,
frame
Size
,
0
,
grad
.
resetOutputGrad
,
frameS
ize
);
frame
_size
,
0
,
grad
.
reset_output_grad
,
frame_s
ize
);
if
(
grad
.
state
WeightG
rad
)
{
if
(
grad
.
state
_weight_g
rad
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
true
,
false
,
frameSize
,
frameSize
,
batchSize
,
1
,
context
,
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
value
.
resetOutputValue
,
frameSize
,
grad
.
gateGrad
+
frameSize
*
2
,
value
.
reset_output_value
,
frame_size
,
frameSize
*
3
,
1
,
grad
.
stateWeightGrad
,
frameSize
);
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
1
,
grad
.
state_weight_grad
,
frame_size
);
}
}
}
}
detail
::
backward_reset_grad
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
,
detail
::
backward_reset_grad
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
,
grad
,
frame
Size
,
batchS
ize
,
active_gate
);
grad
,
frame
_size
,
batch_s
ize
,
active_gate
);
if
(
grad
.
prev
OutGrad
&&
value
.
prevOutV
alue
)
{
if
(
grad
.
prev
_out_grad
&&
value
.
prev_out_v
alue
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
*
2
,
1
,
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
*
2
,
1
,
grad
.
gate
Grad
,
frameSize
*
3
,
value
.
gateWeight
,
frameS
ize
*
2
,
1
,
grad
.
gate
_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_s
ize
*
2
,
1
,
grad
.
prev
OutGrad
,
frameS
ize
);
grad
.
prev
_out_grad
,
frame_s
ize
);
if
(
grad
.
gate
WeightG
rad
)
{
if
(
grad
.
gate
_weight_g
rad
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
true
,
false
,
frame
Size
,
frameSize
*
2
,
batchS
ize
,
1
,
context
,
true
,
false
,
frame
_size
,
frame_size
*
2
,
batch_s
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
grad
.
gateGrad
,
frameS
ize
*
3
,
1
,
value
.
prev
_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_s
ize
*
3
,
1
,
grad
.
gate
WeightGrad
,
frameS
ize
*
2
);
grad
.
gate
_weight_grad
,
frame_s
ize
*
2
);
}
}
}
}
#endif
#endif
...
...
paddle/operators/math/gru_compute.cu
浏览文件 @
3e552cdc
...
@@ -21,66 +21,66 @@ namespace math {
...
@@ -21,66 +21,66 @@ namespace math {
template
<
typename
T
>
template
<
typename
T
>
struct
GRUUnitFunctor
<
platform
::
GPUPlace
,
T
>
{
struct
GRUUnitFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
int
frame
Size
,
int
batchS
ize
,
hl_gru_value
<
T
>
value
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
auto
stream
=
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
dim3
threads
;
dim3
threads
;
dim3
grid
;
dim3
grid
;
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
int
frame
PerBlock
=
frameSize
<=
1024
?
frameS
ize
:
1024
;
int
frame
_per_block
=
frame_size
<=
1024
?
frame_s
ize
:
1024
;
int
frame
Blocks
=
(
frameS
ize
+
1024
-
1
)
/
1024
;
int
frame
_blocks
=
(
frame_s
ize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame
PerB
lock
,
1
);
threads
=
dim3
(
frame
_per_b
lock
,
1
);
grid
=
dim3
(
frame
B
locks
,
1
);
grid
=
dim3
(
frame
_b
locks
,
1
);
}
else
{
}
else
{
threads
=
dim3
(
32
,
32
);
threads
=
dim3
(
32
,
32
);
grid
=
dim3
((
frame
Size
+
32
-
1
)
/
32
,
(
batchS
ize
+
32
-
1
)
/
32
);
grid
=
dim3
((
frame
_size
+
32
-
1
)
/
32
,
(
batch_s
ize
+
32
-
1
)
/
32
);
}
}
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
*
2
,
frameS
ize
,
1
,
context
,
false
,
false
,
batch
_size
,
frame_size
*
2
,
frame_s
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
value
.
gateWeight
,
frameSize
*
2
,
1
,
value
.
prev
_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
value
.
gateValue
,
frameS
ize
*
3
);
1
,
value
.
gate_value
,
frame_s
ize
*
3
);
}
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
/* is
B
atch= */
false
,
/* is
_b
atch= */
false
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
V
alue
,
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
reset
OutputValue
,
value
.
prevOutValue
,
frameSize
,
batchS
ize
,
value
.
reset
_output_value
,
value
.
prev_out_value
,
frame_s
ize
,
active_gate
);
batch_size
,
active_gate
);
}
else
{
}
else
{
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
/* is
B
atch= */
true
,
/* is
_b
atch= */
true
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
V
alue
,
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
reset
OutputValue
,
value
.
prevOutValue
,
frameSize
,
batchS
ize
,
value
.
reset
_output_value
,
value
.
prev_out_value
,
frame_s
ize
,
active_gate
);
batch_size
,
active_gate
);
}
}
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
context
,
false
,
false
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
value
.
reset
OutputValue
,
frameSize
,
value
.
stateWeight
,
frameSize
,
1
,
value
.
reset
_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
value
.
gateValue
+
frameSize
*
2
,
frameS
ize
*
3
);
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_s
ize
*
3
);
}
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
/* is
B
atch= */
false
,
/* is
_b
atch= */
false
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
V
alue
,
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
prev
OutValue
,
value
.
outputValue
,
frameSize
,
batchS
ize
,
value
.
prev
_out_value
,
value
.
output_value
,
frame_size
,
batch_s
ize
,
active_node
);
active_node
);
}
else
{
}
else
{
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
/* is
B
atch= */
true
,
/* is
_b
atch= */
true
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
V
alue
,
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
prev
OutValue
,
value
.
outputValue
,
frameSize
,
batchS
ize
,
value
.
prev
_out_value
,
value
.
output_value
,
frame_size
,
batch_s
ize
,
active_node
);
active_node
);
}
}
}
}
...
@@ -89,80 +89,82 @@ struct GRUUnitFunctor<platform::GPUPlace, T> {
...
@@ -89,80 +89,82 @@ struct GRUUnitFunctor<platform::GPUPlace, T> {
template
<
typename
T
>
template
<
typename
T
>
struct
GRUUnitGradFunctor
<
platform
::
GPUPlace
,
T
>
{
struct
GRUUnitGradFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
batchSize
,
activation_mode_t
active_node
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
auto
stream
=
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
dim3
threads
;
dim3
threads
;
dim3
grid
;
dim3
grid
;
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
int
frame
PerBlock
=
frameSize
<=
1024
?
frameS
ize
:
1024
;
int
frame
_per_block
=
frame_size
<=
1024
?
frame_s
ize
:
1024
;
int
frame
Blocks
=
(
frameS
ize
+
1024
-
1
)
/
1024
;
int
frame
_blocks
=
(
frame_s
ize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame
PerB
lock
,
1
);
threads
=
dim3
(
frame
_per_b
lock
,
1
);
grid
=
dim3
(
frame
B
locks
,
1
);
grid
=
dim3
(
frame
_b
locks
,
1
);
}
else
{
}
else
{
threads
=
dim3
(
32
,
32
);
threads
=
dim3
(
32
,
32
);
grid
=
dim3
((
frame
Size
+
32
-
1
)
/
32
,
(
batchS
ize
+
32
-
1
)
/
32
);
grid
=
dim3
((
frame
_size
+
32
-
1
)
/
32
,
(
batch_s
ize
+
32
-
1
)
/
32
);
}
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruBackwardStateGrad
<
detail
::
KeGruBackwardStateGrad
<
detail
::
backward
::
gru_stateGrad
<
T
>
,
detail
::
backward
::
gru_stateGrad
<
T
>
,
/* is
B
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is
_b
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
_value
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
outputGrad
,
frameSize
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
batchS
ize
,
active_node
);
grad
.
output_grad
,
frame_size
,
batch_s
ize
,
active_node
);
}
else
{
}
else
{
detail
::
KeGruBackwardStateGrad
<
detail
::
KeGruBackwardStateGrad
<
detail
::
backward
::
gru_stateGrad
<
T
>
,
detail
::
backward
::
gru_stateGrad
<
T
>
,
/* is
B
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is
_b
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
_value
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
outputGrad
,
frameSize
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
batchS
ize
,
active_node
);
grad
.
output_grad
,
frame_size
,
batch_s
ize
,
active_node
);
}
}
if
(
value
.
prev
OutValue
&&
grad
.
prevOutG
rad
)
{
if
(
value
.
prev
_out_value
&&
grad
.
prev_out_g
rad
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
grad
.
gate
Grad
+
frameSize
*
2
,
frameSize
*
3
,
value
.
stateW
eight
,
grad
.
gate
_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_w
eight
,
frame
Size
,
0
,
grad
.
resetOutputGrad
,
frameS
ize
);
frame
_size
,
0
,
grad
.
reset_output_grad
,
frame_s
ize
);
if
(
grad
.
state
WeightG
rad
)
{
if
(
grad
.
state
_weight_g
rad
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
true
,
false
,
frameSize
,
frameSize
,
batchSize
,
1
,
context
,
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
value
.
resetOutputValue
,
frameSize
,
grad
.
gateGrad
+
frameSize
*
2
,
value
.
reset_output_value
,
frame_size
,
frameSize
*
3
,
1
,
grad
.
stateWeightGrad
,
frameSize
);
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
1
,
grad
.
state_weight_grad
,
frame_size
);
}
}
}
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruBackwardResetGrad
<
detail
::
KeGruBackwardResetGrad
<
detail
::
backward
::
gru_resetGrad
<
T
>
,
detail
::
backward
::
gru_resetGrad
<
T
>
,
/* is
B
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is
_b
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
_value
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
resetOutputGrad
,
frameSize
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
batchS
ize
,
active_gate
);
grad
.
reset_output_grad
,
frame_size
,
batch_s
ize
,
active_gate
);
}
else
{
}
else
{
detail
::
KeGruBackwardResetGrad
<
detail
::
KeGruBackwardResetGrad
<
detail
::
backward
::
gru_resetGrad
<
T
>
,
detail
::
backward
::
gru_resetGrad
<
T
>
,
/* is
B
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is
_b
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
_value
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
resetOutputGrad
,
frameSize
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
batchS
ize
,
active_gate
);
grad
.
reset_output_grad
,
frame_size
,
batch_s
ize
,
active_gate
);
}
}
if
(
grad
.
prev
OutGrad
&&
value
.
prevOutV
alue
)
{
if
(
grad
.
prev
_out_grad
&&
value
.
prev_out_v
alue
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
*
2
,
1
,
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
*
2
,
1
,
grad
.
gate
Grad
,
frameSize
*
3
,
value
.
gateWeight
,
frameS
ize
*
2
,
1
,
grad
.
gate
_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_s
ize
*
2
,
1
,
grad
.
prev
OutGrad
,
frameS
ize
);
grad
.
prev
_out_grad
,
frame_s
ize
);
if
(
grad
.
gate
WeightG
rad
)
{
if
(
grad
.
gate
_weight_g
rad
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
true
,
false
,
frame
Size
,
frameSize
*
2
,
batchS
ize
,
1
,
context
,
true
,
false
,
frame
_size
,
frame_size
*
2
,
batch_s
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
grad
.
gateGrad
,
frameS
ize
*
3
,
1
,
value
.
prev
_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_s
ize
*
3
,
1
,
grad
.
gate
WeightGrad
,
frameS
ize
*
2
);
grad
.
gate
_weight_grad
,
frame_s
ize
*
2
);
}
}
}
}
}
}
...
...
paddle/operators/math/gru_compute.h
浏览文件 @
3e552cdc
...
@@ -22,28 +22,28 @@ namespace math {
...
@@ -22,28 +22,28 @@ namespace math {
// TODO(guosheng): refine code style in gru_compute
// TODO(guosheng): refine code style in gru_compute
template
<
typename
T
>
template
<
typename
T
>
struct
hl_gru_value
{
struct
hl_gru_value
{
T
*
gate
W
eight
;
T
*
gate
_w
eight
;
T
*
state
W
eight
;
T
*
state
_w
eight
;
T
*
gate
V
alue
;
T
*
gate
_v
alue
;
T
*
reset
OutputV
alue
;
T
*
reset
_output_v
alue
;
T
*
output
V
alue
;
T
*
output
_v
alue
;
T
*
prev
OutV
alue
;
T
*
prev
_out_v
alue
;
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
hl_gru_grad
{
struct
hl_gru_grad
{
T
*
gate
WeightG
rad
;
T
*
gate
_weight_g
rad
;
T
*
state
WeightG
rad
;
T
*
state
_weight_g
rad
;
T
*
gate
G
rad
;
T
*
gate
_g
rad
;
T
*
reset
OutputG
rad
;
T
*
reset
_output_g
rad
;
T
*
output
G
rad
;
T
*
output
_g
rad
;
T
*
prev
OutG
rad
;
T
*
prev
_out_g
rad
;
};
};
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
struct
GRUUnitFunctor
{
struct
GRUUnitFunctor
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
int
frame
Size
,
int
batchS
ize
,
hl_gru_value
<
T
>
value
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
);
activation_mode_t
active_gate
);
};
};
...
@@ -51,8 +51,9 @@ struct GRUUnitFunctor {
...
@@ -51,8 +51,9 @@ struct GRUUnitFunctor {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
struct
GRUUnitGradFunctor
{
struct
GRUUnitGradFunctor
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
batchSize
,
activation_mode_t
active_node
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
);
activation_mode_t
active_gate
);
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录