Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3e552cdc
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3e552cdc
编写于
11月 29, 2017
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix gru_op related code style
上级
dcf3ffd9
变更
7
展开全部
显示空白变更内容
内联
并排
Showing
7 changed file
with
617 addition
and
599 deletion
+617
-599
paddle/operators/gru_op.h
paddle/operators/gru_op.h
+23
-23
paddle/operators/math/detail/gru_cpu_kernel.h
paddle/operators/math/detail/gru_cpu_kernel.h
+272
-268
paddle/operators/math/detail/gru_gpu_kernel.h
paddle/operators/math/detail/gru_gpu_kernel.h
+126
-126
paddle/operators/math/detail/gru_kernel.h
paddle/operators/math/detail/gru_kernel.h
+72
-63
paddle/operators/math/gru_compute.cc
paddle/operators/math/gru_compute.cc
+33
-31
paddle/operators/math/gru_compute.cu
paddle/operators/math/gru_compute.cu
+75
-73
paddle/operators/math/gru_compute.h
paddle/operators/math/gru_compute.h
+16
-15
未找到文件。
paddle/operators/gru_op.h
浏览文件 @
3e552cdc
...
@@ -71,8 +71,8 @@ class GRUKernel : public framework::OpKernel<T> {
...
@@ -71,8 +71,8 @@ class GRUKernel : public framework::OpKernel<T> {
int
frame_size
=
hidden_dims
[
1
];
int
frame_size
=
hidden_dims
[
1
];
math
::
hl_gru_value
<
T
>
gru_value
;
math
::
hl_gru_value
<
T
>
gru_value
;
gru_value
.
gate
W
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
gate
_w
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state
W
eight
=
gru_value
.
state
_w
eight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
Tensor
ordered_h0
;
Tensor
ordered_h0
;
const
size_t
*
order
=
batch_gate
->
lod
()[
2
].
data
();
const
size_t
*
order
=
batch_gate
->
lod
()[
2
].
data
();
...
@@ -82,9 +82,9 @@ class GRUKernel : public framework::OpKernel<T> {
...
@@ -82,9 +82,9 @@ class GRUKernel : public framework::OpKernel<T> {
// to reorder.
// to reorder.
ReorderInitState
<
Place
,
T
>
(
context
.
device_context
(),
*
h0
,
order
,
ReorderInitState
<
Place
,
T
>
(
context
.
device_context
(),
*
h0
,
order
,
&
ordered_h0
,
true
);
&
ordered_h0
,
true
);
gru_value
.
prev
OutV
alue
=
ordered_h0
.
data
<
T
>
();
gru_value
.
prev
_out_v
alue
=
ordered_h0
.
data
<
T
>
();
}
else
{
}
else
{
gru_value
.
prev
OutV
alue
=
nullptr
;
gru_value
.
prev
_out_v
alue
=
nullptr
;
}
}
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
...
@@ -96,14 +96,14 @@ class GRUKernel : public framework::OpKernel<T> {
...
@@ -96,14 +96,14 @@ class GRUKernel : public framework::OpKernel<T> {
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
gru_value
.
output
V
alue
=
hidden_t
.
data
<
T
>
();
gru_value
.
output
_v
alue
=
hidden_t
.
data
<
T
>
();
gru_value
.
gate
V
alue
=
gate_t
.
data
<
T
>
();
gru_value
.
gate
_v
alue
=
gate_t
.
data
<
T
>
();
gru_value
.
reset
OutputV
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
gru_value
.
reset
_output_v
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
math
::
GRUUnitFunctor
<
Place
,
T
>::
compute
(
math
::
GRUUnitFunctor
<
Place
,
T
>::
compute
(
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"activation"
)),
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"activation"
)),
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"gate_activation"
)));
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"gate_activation"
)));
gru_value
.
prev
OutValue
=
gru_value
.
outputV
alue
;
gru_value
.
prev
_out_value
=
gru_value
.
output_v
alue
;
}
}
math
::
Batch2LoDTensorFunctor
<
Place
,
T
>
to_seq
;
math
::
Batch2LoDTensorFunctor
<
Place
,
T
>
to_seq
;
...
@@ -169,20 +169,20 @@ class GRUGradKernel : public framework::OpKernel<T> {
...
@@ -169,20 +169,20 @@ class GRUGradKernel : public framework::OpKernel<T> {
to_batch
(
dev_ctx
,
*
hidden_grad
,
batch_hidden_grad
,
false
,
is_reverse
);
to_batch
(
dev_ctx
,
*
hidden_grad
,
batch_hidden_grad
,
false
,
is_reverse
);
math
::
hl_gru_value
<
T
>
gru_value
;
math
::
hl_gru_value
<
T
>
gru_value
;
gru_value
.
gate
W
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
gate
_w
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state
W
eight
=
gru_value
.
state
_w
eight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
math
::
hl_gru_grad
<
T
>
gru_grad
;
math
::
hl_gru_grad
<
T
>
gru_grad
;
if
(
weight_grad
)
{
if
(
weight_grad
)
{
gru_grad
.
gate
WeightG
rad
=
gru_grad
.
gate
_weight_g
rad
=
weight_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
weight_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
zero
(
dev_ctx
,
weight_grad
,
static_cast
<
T
>
(
0.0
));
zero
(
dev_ctx
,
weight_grad
,
static_cast
<
T
>
(
0.0
));
gru_grad
.
state
WeightG
rad
=
gru_grad
.
state
_weight_g
rad
=
weight_grad
->
data
<
T
>
()
+
2
*
frame_size
*
frame_size
;
weight_grad
->
data
<
T
>
()
+
2
*
frame_size
*
frame_size
;
}
else
{
}
else
{
gru_grad
.
gate
WeightG
rad
=
nullptr
;
gru_grad
.
gate
_weight_g
rad
=
nullptr
;
gru_grad
.
state
WeightG
rad
=
nullptr
;
gru_grad
.
state
_weight_g
rad
=
nullptr
;
}
}
auto
batch_starts
=
batch_hidden_grad
.
lod
()[
0
];
auto
batch_starts
=
batch_hidden_grad
.
lod
()[
0
];
...
@@ -193,27 +193,27 @@ class GRUGradKernel : public framework::OpKernel<T> {
...
@@ -193,27 +193,27 @@ class GRUGradKernel : public framework::OpKernel<T> {
int
cur_batch_size
=
bend
-
bstart
;
int
cur_batch_size
=
bend
-
bstart
;
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
gru_value
.
gate
V
alue
=
gate_t
.
data
<
T
>
();
gru_value
.
gate
_v
alue
=
gate_t
.
data
<
T
>
();
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
gru_value
.
reset
OutputV
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
gru_value
.
reset
_output_v
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
Tensor
hidden_grad_t
=
batch_hidden_grad
.
Slice
(
bstart
,
bend
);
Tensor
hidden_grad_t
=
batch_hidden_grad
.
Slice
(
bstart
,
bend
);
gru_grad
.
output
G
rad
=
hidden_grad_t
.
data
<
T
>
();
gru_grad
.
output
_g
rad
=
hidden_grad_t
.
data
<
T
>
();
Tensor
gate_grad_t
=
batch_gate_grad
.
Slice
(
bstart
,
bend
);
Tensor
gate_grad_t
=
batch_gate_grad
.
Slice
(
bstart
,
bend
);
gru_grad
.
gate
G
rad
=
gate_grad_t
.
data
<
T
>
();
gru_grad
.
gate
_g
rad
=
gate_grad_t
.
data
<
T
>
();
Tensor
reset_hidden_prev_grad_t
=
Tensor
reset_hidden_prev_grad_t
=
batch_reset_hidden_prev_grad
.
Slice
(
bstart
,
bend
);
batch_reset_hidden_prev_grad
.
Slice
(
bstart
,
bend
);
gru_grad
.
reset
OutputG
rad
=
reset_hidden_prev_grad_t
.
data
<
T
>
();
gru_grad
.
reset
_output_g
rad
=
reset_hidden_prev_grad_t
.
data
<
T
>
();
if
(
n
==
0
)
{
if
(
n
==
0
)
{
gru_value
.
prev
OutV
alue
=
h0
?
ordered_h0
.
data
<
T
>
()
:
nullptr
;
gru_value
.
prev
_out_v
alue
=
h0
?
ordered_h0
.
data
<
T
>
()
:
nullptr
;
gru_grad
.
prev
OutG
rad
=
gru_grad
.
prev
_out_g
rad
=
h0
&&
h0_grad
?
ordered_h0_grad
.
data
<
T
>
()
:
nullptr
;
h0
&&
h0_grad
?
ordered_h0_grad
.
data
<
T
>
()
:
nullptr
;
}
else
{
}
else
{
int
bstart_pre
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
bstart_pre
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
Tensor
hidden_prev_t
=
batch_hidden
->
Slice
(
bstart_pre
,
bstart
);
Tensor
hidden_prev_t
=
batch_hidden
->
Slice
(
bstart_pre
,
bstart
);
gru_value
.
prev
OutV
alue
=
hidden_prev_t
.
data
<
T
>
();
gru_value
.
prev
_out_v
alue
=
hidden_prev_t
.
data
<
T
>
();
Tensor
hidden_prev_grad_t
=
batch_hidden_grad
.
Slice
(
bstart_pre
,
bstart
);
Tensor
hidden_prev_grad_t
=
batch_hidden_grad
.
Slice
(
bstart_pre
,
bstart
);
gru_grad
.
prev
OutG
rad
=
hidden_prev_grad_t
.
data
<
T
>
();
gru_grad
.
prev
_out_g
rad
=
hidden_prev_grad_t
.
data
<
T
>
();
}
}
math
::
GRUUnitGradFunctor
<
Place
,
T
>::
compute
(
math
::
GRUUnitGradFunctor
<
Place
,
T
>::
compute
(
...
...
paddle/operators/math/detail/gru_cpu_kernel.h
浏览文件 @
3e552cdc
此差异已折叠。
点击以展开。
paddle/operators/math/detail/gru_gpu_kernel.h
浏览文件 @
3e552cdc
...
@@ -27,174 +27,174 @@ namespace math {
...
@@ -27,174 +27,174 @@ namespace math {
namespace
detail
{
namespace
detail
{
/*
/*
* threads(frame
PerBlock, batchPerB
lock)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
Blocks, batchB
locks)
* grid(frame
_blocks, batch_b
locks)
*/
*/
template
<
class
OpResetOutput
,
bool
is
B
atch
,
typename
T
>
template
<
class
OpResetOutput
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruForwardResetOutput
(
OpResetOutput
op
ResetO
utput
,
__global__
void
KeGruForwardResetOutput
(
OpResetOutput
op
_reset_o
utput
,
T
*
gate
Value
,
T
*
resetOutputV
alue
,
T
*
gate
_value
,
T
*
reset_output_v
alue
,
T
*
prev
OutputValue
,
int
frameS
ize
,
T
*
prev
_output_value
,
int
frame_s
ize
,
int
batch
S
ize
,
int
batch
_s
ize
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
I
dx
=
0
;
int
batch
_i
dx
=
0
;
if
(
is
B
atch
)
{
if
(
is
_b
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
reset
OutputValue
+=
batchIdx
*
frameS
ize
;
reset
_output_value
+=
batch_idx
*
frame_s
ize
;
}
}
T
r
PrevO
ut
=
0
;
T
r
_prev_o
ut
=
0
;
T
r
ValueResetO
utput
;
T
r
_value_reset_o
utput
;
T
r
ValueUpdateGate
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
_value_update_gate
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
ValueResetGate
=
gateValue
[
frameIdx
+
frameS
ize
*
1
];
T
r
_value_reset_gate
=
gate_value
[
frame_idx
+
frame_s
ize
*
1
];
if
(
prev
OutputV
alue
)
{
if
(
prev
_output_v
alue
)
{
if
(
is
Batch
)
prevOutputValue
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_output_value
+=
batch_idx
*
frame_s
ize
;
r
PrevOut
=
prevOutputValue
[
frameI
dx
];
r
_prev_out
=
prev_output_value
[
frame_i
dx
];
}
}
op
ResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevOut
,
rValueResetOutp
ut
,
op
_reset_output
(
r_value_update_gate
,
r_value_reset_gate
,
r_prev_o
ut
,
active_gate
);
r_value_reset_output
,
active_gate
);
gate
Value
[
frameIdx
+
frameSize
*
0
]
=
rValueUpdateG
ate
;
gate
_value
[
frame_idx
+
frame_size
*
0
]
=
r_value_update_g
ate
;
gate
Value
[
frameIdx
+
frameSize
*
1
]
=
rValueResetG
ate
;
gate
_value
[
frame_idx
+
frame_size
*
1
]
=
r_value_reset_g
ate
;
reset
OutputValue
[
frameIdx
]
=
rValueResetO
utput
;
reset
_output_value
[
frame_idx
]
=
r_value_reset_o
utput
;
}
}
/*
/*
* threads(frame
PerBlock, batchPerB
lock)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
Blocks, batchB
locks)
* grid(frame
_blocks, batch_b
locks)
*/
*/
template
<
class
OpFinalOutput
,
bool
is
B
atch
,
typename
T
>
template
<
class
OpFinalOutput
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruForwardFinalOutput
(
OpFinalOutput
op
FinalO
utput
,
__global__
void
KeGruForwardFinalOutput
(
OpFinalOutput
op
_final_o
utput
,
T
*
gate
Value
,
T
*
prevOutputV
alue
,
T
*
gate
_value
,
T
*
prev_output_v
alue
,
T
*
output
Value
,
int
frameS
ize
,
T
*
output
_value
,
int
frame_s
ize
,
int
batch
S
ize
,
int
batch
_s
ize
,
activation_mode_t
active_node
)
{
activation_mode_t
active_node
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
I
dx
=
0
;
int
batch
_i
dx
=
0
;
if
(
is
B
atch
)
{
if
(
is
_b
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
output
Value
+=
batchIdx
*
frameS
ize
;
output
_value
+=
batch_idx
*
frame_s
ize
;
}
}
T
r
O
utput
;
T
r
_o
utput
;
T
r
PrevO
ut
=
0
;
T
r
_prev_o
ut
=
0
;
T
r
ValueUpdateGate
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
_value_update_gate
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
ValueFrameState
=
gateValue
[
frameIdx
+
frameS
ize
*
2
];
T
r
_value_frame_state
=
gate_value
[
frame_idx
+
frame_s
ize
*
2
];
if
(
prev
OutputV
alue
)
{
if
(
prev
_output_v
alue
)
{
if
(
is
Batch
)
prevOutputValue
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_output_value
+=
batch_idx
*
frame_s
ize
;
r
PrevOut
=
prevOutputValue
[
frameI
dx
];
r
_prev_out
=
prev_output_value
[
frame_i
dx
];
}
}
op
FinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutp
ut
,
op
_final_output
(
r_value_update_gate
,
r_value_frame_state
,
r_prev_o
ut
,
active_node
);
r_output
,
active_node
);
gate
Value
[
frameIdx
+
frameSize
*
2
]
=
rValueFrameS
tate
;
gate
_value
[
frame_idx
+
frame_size
*
2
]
=
r_value_frame_s
tate
;
output
Value
[
frameIdx
]
=
rO
utput
;
output
_value
[
frame_idx
]
=
r_o
utput
;
}
}
/*
/*
* threads(frame
PerBlock, batchPerB
lock)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
Blocks, batchB
locks)
* grid(frame
_blocks, batch_b
locks)
*/
*/
template
<
class
OpStateGrad
,
bool
is
B
atch
,
typename
T
>
template
<
class
OpStateGrad
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruBackwardStateGrad
(
OpStateGrad
op
StateGrad
,
T
*
gateV
alue
,
__global__
void
KeGruBackwardStateGrad
(
OpStateGrad
op
_state_grad
,
T
*
gate_v
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
OutGrad
,
T
*
outputG
rad
,
T
*
prev
_out_grad
,
T
*
output_g
rad
,
int
frame
Size
,
int
batchS
ize
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
)
{
activation_mode_t
active_node
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
I
dx
=
0
;
int
batch
_i
dx
=
0
;
if
(
is
B
atch
)
{
if
(
is
_b
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
gate
Grad
+=
batchIdx
*
3
*
frameS
ize
;
gate
_grad
+=
batch_idx
*
3
*
frame_s
ize
;
output
Grad
+=
batchIdx
*
frameS
ize
;
output
_grad
+=
batch_idx
*
frame_s
ize
;
}
}
T
r
UpdateGateG
rad
;
T
r
_update_gate_g
rad
;
T
r
FrameStateG
rad
;
T
r
_frame_state_g
rad
;
T
r
PrevOutV
alue
=
0
;
T
r
_prev_out_v
alue
=
0
;
T
r
PrevOutG
rad
=
0
;
T
r
_prev_out_g
rad
=
0
;
T
r
UpdateGateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
_update_gate_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
FrameStateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
2
];
T
r
_frame_state_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
2
];
T
r
OutGrad
=
outputGrad
[
frameI
dx
];
T
r
_out_grad
=
output_grad
[
frame_i
dx
];
if
(
prev
OutValue
&&
prevOutG
rad
)
{
if
(
prev
_out_value
&&
prev_out_g
rad
)
{
if
(
is
Batch
)
prevOutValue
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_out_value
+=
batch_idx
*
frame_s
ize
;
r
PrevOutValue
=
prevOutValue
[
frameI
dx
];
r
_prev_out_value
=
prev_out_value
[
frame_i
dx
];
if
(
is
Batch
)
prevOutGrad
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_out_grad
+=
batch_idx
*
frame_s
ize
;
r
PrevOutGrad
=
prevOutGrad
[
frameI
dx
];
r
_prev_out_grad
=
prev_out_grad
[
frame_i
dx
];
}
}
op
StateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateV
alue
,
op
_state_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_frame_state_v
alue
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutG
rad
,
r_frame_state_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
active_node
);
r_out_grad
,
active_node
);
gate
Grad
[
frameIdx
+
frameSize
*
0
]
=
rUpdateGateG
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
0
]
=
r_update_gate_g
rad
;
gate
Grad
[
frameIdx
+
frameSize
*
2
]
=
rFrameStateG
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
2
]
=
r_frame_state_g
rad
;
if
(
prev
OutG
rad
)
{
if
(
prev
_out_g
rad
)
{
prev
OutGrad
[
frameIdx
]
=
rPrevOutG
rad
;
prev
_out_grad
[
frame_idx
]
=
r_prev_out_g
rad
;
}
}
}
}
/*
/*
* threads(frame
PerBlock, batchPerB
lock)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
Blocks, batchB
locks)
* grid(frame
_blocks, batch_b
locks)
*/
*/
template
<
class
OpResetGrad
,
bool
is
B
atch
,
typename
T
>
template
<
class
OpResetGrad
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruBackwardResetGrad
(
OpResetGrad
op
ResetGrad
,
T
*
gateV
alue
,
__global__
void
KeGruBackwardResetGrad
(
OpResetGrad
op
_reset_grad
,
T
*
gate_v
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
OutGrad
,
T
*
resetOutputG
rad
,
T
*
prev
_out_grad
,
T
*
reset_output_g
rad
,
int
frame
Size
,
int
batchS
ize
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
I
dx
=
0
;
int
batch
_i
dx
=
0
;
if
(
is
B
atch
)
{
if
(
is
_b
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
gate
Grad
+=
batchIdx
*
3
*
frameS
ize
;
gate
_grad
+=
batch_idx
*
3
*
frame_s
ize
;
reset
OutputGrad
+=
batchIdx
*
frameS
ize
;
reset
_output_grad
+=
batch_idx
*
frame_s
ize
;
}
}
T
r
ResetGateG
rad
;
T
r
_reset_gate_g
rad
;
T
r
PrevOutV
alue
=
0
;
T
r
_prev_out_v
alue
=
0
;
T
r
PrevOutG
rad
=
0
;
T
r
_prev_out_g
rad
=
0
;
T
r
ResetOutputG
rad
=
0
;
T
r
_reset_output_g
rad
=
0
;
T
r
UpdateGateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
_update_gate_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
UpdateGateGrad
=
gateGrad
[
frameIdx
+
frameS
ize
*
0
];
T
r
_update_gate_grad
=
gate_grad
[
frame_idx
+
frame_s
ize
*
0
];
T
r
ResetGateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
1
];
T
r
_reset_gate_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
1
];
if
(
prev
OutValue
&&
prevOutG
rad
)
{
if
(
prev
_out_value
&&
prev_out_g
rad
)
{
if
(
is
Batch
)
prevOutValue
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_out_value
+=
batch_idx
*
frame_s
ize
;
if
(
is
Batch
)
prevOutGrad
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_out_grad
+=
batch_idx
*
frame_s
ize
;
r
PrevOutValue
=
prevOutValue
[
frameI
dx
];
r
_prev_out_value
=
prev_out_value
[
frame_i
dx
];
r
PrevOutGrad
=
prevOutGrad
[
frameI
dx
];
r
_prev_out_grad
=
prev_out_grad
[
frame_i
dx
];
r
ResetOutputGrad
=
resetOutputGrad
[
frameI
dx
];
r
_reset_output_grad
=
reset_output_grad
[
frame_i
dx
];
}
}
op
ResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateV
alue
,
op
_reset_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_reset_gate_v
alue
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputG
rad
,
r_reset_gate_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
active_gate
);
r_reset_output_grad
,
active_gate
);
gate
Grad
[
frameIdx
+
frameSize
*
0
]
=
rUpdateGateG
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
0
]
=
r_update_gate_g
rad
;
gate
Grad
[
frameIdx
+
frameSize
*
1
]
=
rResetGateG
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
1
]
=
r_reset_gate_g
rad
;
if
(
prev
OutG
rad
)
{
if
(
prev
_out_g
rad
)
{
prev
OutGrad
[
frameIdx
]
=
rPrevOutG
rad
;
prev
_out_grad
[
frame_idx
]
=
r_prev_out_g
rad
;
}
}
}
}
}
// namespace detail
}
// namespace detail
...
...
paddle/operators/math/detail/gru_kernel.h
浏览文件 @
3e552cdc
...
@@ -28,23 +28,25 @@ namespace forward {
...
@@ -28,23 +28,25 @@ namespace forward {
template
<
typename
T
>
template
<
typename
T
>
class
gru_resetOutput
{
class
gru_resetOutput
{
public:
public:
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
valueResetGate
,
T
&
prevOut
,
HOSTDEVICE
void
operator
()(
T
&
value_update_gate
,
T
&
value_reset_gate
,
T
&
valueResetOutput
,
activation_mode_t
actGate
)
{
T
&
prev_out
,
T
&
value_reset_output
,
valueUpdateGate
=
activation
(
valueUpdateGate
,
actGate
);
activation_mode_t
act_gate
)
{
valueResetGate
=
activation
(
valueResetGate
,
actGate
);
value_update_gate
=
activation
(
value_update_gate
,
act_gate
);
valueResetOutput
=
prevOut
*
valueResetGate
;
value_reset_gate
=
activation
(
value_reset_gate
,
act_gate
);
value_reset_output
=
prev_out
*
value_reset_gate
;
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
static
const
bool
avx
=
false
;
static
const
bool
avx
=
false
;
#else
#else
static
const
bool
avx
=
true
;
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
valueResetGate
,
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
prevOut
,
__m256
&
valueResetOutput
,
__m256
&
value_reset_gate
,
__m256
&
prev_out
,
activation_mode_t
actGate
)
{
__m256
&
value_reset_output
,
valueUpdateGate
=
activation
(
valueUpdateGate
,
actGate
);
activation_mode_t
act_gate
)
{
valueResetGate
=
activation
(
valueResetGate
,
actGate
);
value_update_gate
=
activation
(
value_update_gate
,
act_gate
);
valueResetOutput
=
_mm256_mul_ps
(
prevOut
,
valueResetGate
);
value_reset_gate
=
activation
(
value_reset_gate
,
act_gate
);
value_reset_output
=
_mm256_mul_ps
(
prev_out
,
value_reset_gate
);
}
}
#endif
#endif
#endif
#endif
...
@@ -53,24 +55,26 @@ class gru_resetOutput {
...
@@ -53,24 +55,26 @@ class gru_resetOutput {
template
<
typename
T
>
template
<
typename
T
>
class
gru_finalOutput
{
class
gru_finalOutput
{
public:
public:
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
valueFrameState
,
T
&
prevOut
,
HOSTDEVICE
void
operator
()(
T
&
value_update_gate
,
T
&
value_frame_state
,
T
&
valueOutput
,
activation_mode_t
actInput
)
{
T
&
prev_out
,
T
&
value_output
,
valueFrameState
=
activation
(
valueFrameState
,
actInput
);
activation_mode_t
act_input
)
{
valueOutput
=
prevOut
-
(
valueUpdateGate
*
prevOut
)
+
value_frame_state
=
activation
(
value_frame_state
,
act_input
);
(
valueUpdateGate
*
valueFrameState
);
value_output
=
prev_out
-
(
value_update_gate
*
prev_out
)
+
(
value_update_gate
*
value_frame_state
);
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
static
const
bool
avx
=
false
;
static
const
bool
avx
=
false
;
#else
#else
static
const
bool
avx
=
true
;
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
valueFrameState
,
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
prevOut
,
__m256
&
valueOutput
,
__m256
&
value_frame_state
,
__m256
&
prev_out
,
activation_mode_t
actInput
)
{
__m256
&
value_output
,
valueFrameState
=
activation
(
valueFrameState
,
actInput
);
activation_mode_t
act_input
)
{
valueOutput
=
_mm256_add_ps
(
value_frame_state
=
activation
(
value_frame_state
,
act_input
);
_mm256_sub_ps
(
prevOut
,
_mm256_mul_ps
(
valueUpdateGate
,
prevOut
)),
value_output
=
_mm256_add_ps
(
_mm256_mul_ps
(
valueUpdateGate
,
valueFrameState
));
_mm256_sub_ps
(
prev_out
,
_mm256_mul_ps
(
value_update_gate
,
prev_out
)),
_mm256_mul_ps
(
value_update_gate
,
value_frame_state
));
}
}
#endif
#endif
#endif
#endif
...
@@ -82,34 +86,37 @@ namespace backward {
...
@@ -82,34 +86,37 @@ namespace backward {
template
<
typename
T
>
template
<
typename
T
>
class
gru_stateGrad
{
class
gru_stateGrad
{
public:
public:
HOSTDEVICE
void
operator
()(
T
&
value
UpdateGate
,
T
&
gradUpdateG
ate
,
HOSTDEVICE
void
operator
()(
T
&
value
_update_gate
,
T
&
grad_update_g
ate
,
T
&
value
FrameState
,
T
&
gradFrameS
tate
,
T
&
value
_frame_state
,
T
&
grad_frame_s
tate
,
T
&
value
PrevOut
,
T
&
gradPrevOut
,
T
&
gradOutp
ut
,
T
&
value
_prev_out
,
T
&
grad_prev_o
ut
,
activation_mode_t
actI
nput
)
{
T
&
grad_output
,
activation_mode_t
act_i
nput
)
{
grad
UpdateGate
=
(
gradOutput
*
valueFrameS
tate
);
grad
_update_gate
=
(
grad_output
*
value_frame_s
tate
);
grad
UpdateGate
-=
(
gradOutput
*
valuePrevO
ut
);
grad
_update_gate
-=
(
grad_output
*
value_prev_o
ut
);
grad
PrevOut
-=
(
gradOutput
*
valueUpdateG
ate
);
grad
_prev_out
-=
(
grad_output
*
value_update_g
ate
);
grad
PrevOut
+=
gradO
utput
;
grad
_prev_out
+=
grad_o
utput
;
grad
FrameState
=
grad
_frame_state
=
activation
(
grad_output
*
value_update_gate
,
activation
(
gradOutput
*
valueUpdateGate
,
valueFrameState
,
actI
nput
);
value_frame_state
,
act_i
nput
);
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
static
const
bool
avx
=
false
;
static
const
bool
avx
=
false
;
#else
#else
static
const
bool
avx
=
true
;
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
gradUpdateGate
,
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
valueFrameState
,
__m256
&
gradFrameState
,
__m256
&
grad_update_gate
,
__m256
&
valuePrevOut
,
__m256
&
gradPrevOut
,
__m256
&
value_frame_state
,
__m256
&
gradOutput
,
activation_mode_t
actInput
)
{
__m256
&
grad_frame_state
,
__m256
&
value_prev_out
,
gradUpdateGate
=
_mm256_mul_ps
(
gradOutput
,
valueFrameState
);
__m256
&
grad_prev_out
,
__m256
&
grad_output
,
gradUpdateGate
=
activation_mode_t
act_input
)
{
_mm256_sub_ps
(
gradUpdateGate
,
_mm256_mul_ps
(
gradOutput
,
valuePrevOut
));
grad_update_gate
=
_mm256_mul_ps
(
grad_output
,
value_frame_state
);
gradPrevOut
=
_mm256_add_ps
(
grad_update_gate
=
_mm256_sub_ps
(
_mm256_sub_ps
(
gradPrevOut
,
_mm256_mul_ps
(
gradOutput
,
valueUpdateGate
)),
grad_update_gate
,
_mm256_mul_ps
(
grad_output
,
value_prev_out
));
gradOutput
);
grad_prev_out
=
_mm256_add_ps
(
gradFrameState
=
activation
(
_mm256_mul_ps
(
gradOutput
,
valueUpdateGate
),
_mm256_sub_ps
(
grad_prev_out
,
valueFrameState
,
actInput
);
_mm256_mul_ps
(
grad_output
,
value_update_gate
)),
grad_output
);
grad_frame_state
=
activation
(
_mm256_mul_ps
(
grad_output
,
value_update_gate
),
value_frame_state
,
act_input
);
}
}
#endif
#endif
#endif
#endif
...
@@ -118,30 +125,32 @@ class gru_stateGrad {
...
@@ -118,30 +125,32 @@ class gru_stateGrad {
template
<
typename
T
>
template
<
typename
T
>
class
gru_resetGrad
{
class
gru_resetGrad
{
public:
public:
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
gradUpdateGate
,
HOSTDEVICE
void
operator
()(
T
&
value_update_gate
,
T
&
grad_update_gate
,
T
&
valueResetGate
,
T
&
gradResetGate
,
T
&
value_reset_gate
,
T
&
grad_reset_gate
,
T
&
valuePrevOut
,
T
&
gradPrevOut
,
T
&
value_prev_out
,
T
&
grad_prev_out
,
T
&
gradResetOutput
,
activation_mode_t
actGate
)
{
T
&
grad_reset_output
,
activation_mode_t
act_gate
)
{
gradResetGate
=
(
gradResetOutput
*
valuePrevOut
);
grad_reset_gate
=
(
grad_reset_output
*
value_prev_out
);
gradPrevOut
+=
(
gradResetOutput
*
valueResetGate
);
grad_prev_out
+=
(
grad_reset_output
*
value_reset_gate
);
gradUpdateGate
=
activation
(
gradUpdateGate
,
valueUpdateGate
,
actGate
);
grad_update_gate
=
gradResetGate
=
activation
(
gradResetGate
,
valueResetGate
,
actGate
);
activation
(
grad_update_gate
,
value_update_gate
,
act_gate
);
grad_reset_gate
=
activation
(
grad_reset_gate
,
value_reset_gate
,
act_gate
);
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
static
const
bool
avx
=
false
;
static
const
bool
avx
=
false
;
#else
#else
static
const
bool
avx
=
true
;
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
gradUpdateGate
,
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
valueResetGate
,
__m256
&
gradResetGate
,
__m256
&
grad_update_gate
,
__m256
&
value_reset_gate
,
__m256
&
valuePrevOut
,
__m256
&
gradPrevOut
,
__m256
&
grad_reset_gate
,
__m256
&
value_prev_out
,
__m256
&
gradResetOutput
,
__m256
&
grad_prev_out
,
__m256
&
grad_reset_output
,
activation_mode_t
actGate
)
{
activation_mode_t
act_gate
)
{
gradResetGate
=
_mm256_mul_ps
(
gradResetOutput
,
valuePrevOut
);
grad_reset_gate
=
_mm256_mul_ps
(
grad_reset_output
,
value_prev_out
);
gradPrevOut
=
_mm256_add_ps
(
gradPrevOut
,
grad_prev_out
=
_mm256_add_ps
(
_mm256_mul_ps
(
gradResetOutput
,
valueResetGate
));
grad_prev_out
,
_mm256_mul_ps
(
grad_reset_output
,
value_reset_gate
));
gradUpdateGate
=
activation
(
gradUpdateGate
,
valueUpdateGate
,
actGate
);
grad_update_gate
=
gradResetGate
=
activation
(
gradResetGate
,
valueResetGate
,
actGate
);
activation
(
grad_update_gate
,
value_update_gate
,
act_gate
);
grad_reset_gate
=
activation
(
grad_reset_gate
,
value_reset_gate
,
act_gate
);
}
}
#endif
#endif
#endif
#endif
...
...
paddle/operators/math/gru_compute.cc
浏览文件 @
3e552cdc
...
@@ -21,29 +21,29 @@ namespace math {
...
@@ -21,29 +21,29 @@ namespace math {
template
<
typename
T
>
template
<
typename
T
>
struct
GRUUnitFunctor
<
platform
::
CPUPlace
,
T
>
{
struct
GRUUnitFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
int
frame
Size
,
int
batchS
ize
,
hl_gru_value
<
T
>
value
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
#ifndef __NVCC__
#ifndef __NVCC__
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
*
2
,
frameS
ize
,
1
,
context
,
false
,
false
,
batch
_size
,
frame_size
*
2
,
frame_s
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
value
.
gateWeight
,
frameSize
*
2
,
1
,
value
.
prev
_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
value
.
gateValue
,
frameS
ize
*
3
);
1
,
value
.
gate_value
,
frame_s
ize
*
3
);
}
}
detail
::
forward_reset_output
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
,
detail
::
forward_reset_output
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
,
frame
Size
,
batchS
ize
,
active_gate
);
frame
_size
,
batch_s
ize
,
active_gate
);
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
context
,
false
,
false
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
value
.
reset
OutputValue
,
frameSize
,
value
.
stateWeight
,
frameSize
,
1
,
value
.
reset
_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
value
.
gateValue
+
frameSize
*
2
,
frameS
ize
*
3
);
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_s
ize
*
3
);
}
}
detail
::
forward_final_output
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
,
detail
::
forward_final_output
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
,
frame
Size
,
batchS
ize
,
active_node
);
frame
_size
,
batch_s
ize
,
active_node
);
#endif
#endif
}
}
};
};
...
@@ -51,41 +51,43 @@ struct GRUUnitFunctor<platform::CPUPlace, T> {
...
@@ -51,41 +51,43 @@ struct GRUUnitFunctor<platform::CPUPlace, T> {
template
<
typename
T
>
template
<
typename
T
>
struct
GRUUnitGradFunctor
<
platform
::
CPUPlace
,
T
>
{
struct
GRUUnitGradFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
batchSize
,
activation_mode_t
active_node
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
#ifndef __NVCC__
#ifndef __NVCC__
detail
::
backward_state_grad
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
,
detail
::
backward_state_grad
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
,
grad
,
frame
Size
,
batchS
ize
,
active_node
);
grad
,
frame
_size
,
batch_s
ize
,
active_node
);
if
(
value
.
prev
OutValue
&&
grad
.
prevOutG
rad
)
{
if
(
value
.
prev
_out_value
&&
grad
.
prev_out_g
rad
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
grad
.
gate
Grad
+
frameSize
*
2
,
frameSize
*
3
,
value
.
stateW
eight
,
grad
.
gate
_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_w
eight
,
frame
Size
,
0
,
grad
.
resetOutputGrad
,
frameS
ize
);
frame
_size
,
0
,
grad
.
reset_output_grad
,
frame_s
ize
);
if
(
grad
.
state
WeightG
rad
)
{
if
(
grad
.
state
_weight_g
rad
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
true
,
false
,
frameSize
,
frameSize
,
batchSize
,
1
,
context
,
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
value
.
resetOutputValue
,
frameSize
,
grad
.
gateGrad
+
frameSize
*
2
,
value
.
reset_output_value
,
frame_size
,
frameSize
*
3
,
1
,
grad
.
stateWeightGrad
,
frameSize
);
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
1
,
grad
.
state_weight_grad
,
frame_size
);
}
}
}
}
detail
::
backward_reset_grad
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
,
detail
::
backward_reset_grad
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
,
grad
,
frame
Size
,
batchS
ize
,
active_gate
);
grad
,
frame
_size
,
batch_s
ize
,
active_gate
);
if
(
grad
.
prev
OutGrad
&&
value
.
prevOutV
alue
)
{
if
(
grad
.
prev
_out_grad
&&
value
.
prev_out_v
alue
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
*
2
,
1
,
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
*
2
,
1
,
grad
.
gate
Grad
,
frameSize
*
3
,
value
.
gateWeight
,
frameS
ize
*
2
,
1
,
grad
.
gate
_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_s
ize
*
2
,
1
,
grad
.
prev
OutGrad
,
frameS
ize
);
grad
.
prev
_out_grad
,
frame_s
ize
);
if
(
grad
.
gate
WeightG
rad
)
{
if
(
grad
.
gate
_weight_g
rad
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
true
,
false
,
frame
Size
,
frameSize
*
2
,
batchS
ize
,
1
,
context
,
true
,
false
,
frame
_size
,
frame_size
*
2
,
batch_s
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
grad
.
gateGrad
,
frameS
ize
*
3
,
1
,
value
.
prev
_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_s
ize
*
3
,
1
,
grad
.
gate
WeightGrad
,
frameS
ize
*
2
);
grad
.
gate
_weight_grad
,
frame_s
ize
*
2
);
}
}
}
}
#endif
#endif
...
...
paddle/operators/math/gru_compute.cu
浏览文件 @
3e552cdc
...
@@ -21,66 +21,66 @@ namespace math {
...
@@ -21,66 +21,66 @@ namespace math {
template
<
typename
T
>
template
<
typename
T
>
struct
GRUUnitFunctor
<
platform
::
GPUPlace
,
T
>
{
struct
GRUUnitFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
int
frame
Size
,
int
batchS
ize
,
hl_gru_value
<
T
>
value
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
auto
stream
=
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
dim3
threads
;
dim3
threads
;
dim3
grid
;
dim3
grid
;
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
int
frame
PerBlock
=
frameSize
<=
1024
?
frameS
ize
:
1024
;
int
frame
_per_block
=
frame_size
<=
1024
?
frame_s
ize
:
1024
;
int
frame
Blocks
=
(
frameS
ize
+
1024
-
1
)
/
1024
;
int
frame
_blocks
=
(
frame_s
ize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame
PerB
lock
,
1
);
threads
=
dim3
(
frame
_per_b
lock
,
1
);
grid
=
dim3
(
frame
B
locks
,
1
);
grid
=
dim3
(
frame
_b
locks
,
1
);
}
else
{
}
else
{
threads
=
dim3
(
32
,
32
);
threads
=
dim3
(
32
,
32
);
grid
=
dim3
((
frame
Size
+
32
-
1
)
/
32
,
(
batchS
ize
+
32
-
1
)
/
32
);
grid
=
dim3
((
frame
_size
+
32
-
1
)
/
32
,
(
batch_s
ize
+
32
-
1
)
/
32
);
}
}
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
*
2
,
frameS
ize
,
1
,
context
,
false
,
false
,
batch
_size
,
frame_size
*
2
,
frame_s
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
value
.
gateWeight
,
frameSize
*
2
,
1
,
value
.
prev
_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
value
.
gateValue
,
frameS
ize
*
3
);
1
,
value
.
gate_value
,
frame_s
ize
*
3
);
}
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
/* is
B
atch= */
false
,
/* is
_b
atch= */
false
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
V
alue
,
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
reset
OutputValue
,
value
.
prevOutValue
,
frameSize
,
batchS
ize
,
value
.
reset
_output_value
,
value
.
prev_out_value
,
frame_s
ize
,
active_gate
);
batch_size
,
active_gate
);
}
else
{
}
else
{
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
/* is
B
atch= */
true
,
/* is
_b
atch= */
true
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
V
alue
,
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
reset
OutputValue
,
value
.
prevOutValue
,
frameSize
,
batchS
ize
,
value
.
reset
_output_value
,
value
.
prev_out_value
,
frame_s
ize
,
active_gate
);
batch_size
,
active_gate
);
}
}
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
context
,
false
,
false
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
value
.
reset
OutputValue
,
frameSize
,
value
.
stateWeight
,
frameSize
,
1
,
value
.
reset
_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
value
.
gateValue
+
frameSize
*
2
,
frameS
ize
*
3
);
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_s
ize
*
3
);
}
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
/* is
B
atch= */
false
,
/* is
_b
atch= */
false
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
V
alue
,
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
prev
OutValue
,
value
.
outputValue
,
frameSize
,
batchS
ize
,
value
.
prev
_out_value
,
value
.
output_value
,
frame_size
,
batch_s
ize
,
active_node
);
active_node
);
}
else
{
}
else
{
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
/* is
B
atch= */
true
,
/* is
_b
atch= */
true
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
V
alue
,
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
prev
OutValue
,
value
.
outputValue
,
frameSize
,
batchS
ize
,
value
.
prev
_out_value
,
value
.
output_value
,
frame_size
,
batch_s
ize
,
active_node
);
active_node
);
}
}
}
}
...
@@ -89,80 +89,82 @@ struct GRUUnitFunctor<platform::GPUPlace, T> {
...
@@ -89,80 +89,82 @@ struct GRUUnitFunctor<platform::GPUPlace, T> {
template
<
typename
T
>
template
<
typename
T
>
struct
GRUUnitGradFunctor
<
platform
::
GPUPlace
,
T
>
{
struct
GRUUnitGradFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
batchSize
,
activation_mode_t
active_node
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
auto
stream
=
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
dim3
threads
;
dim3
threads
;
dim3
grid
;
dim3
grid
;
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
int
frame
PerBlock
=
frameSize
<=
1024
?
frameS
ize
:
1024
;
int
frame
_per_block
=
frame_size
<=
1024
?
frame_s
ize
:
1024
;
int
frame
Blocks
=
(
frameS
ize
+
1024
-
1
)
/
1024
;
int
frame
_blocks
=
(
frame_s
ize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame
PerB
lock
,
1
);
threads
=
dim3
(
frame
_per_b
lock
,
1
);
grid
=
dim3
(
frame
B
locks
,
1
);
grid
=
dim3
(
frame
_b
locks
,
1
);
}
else
{
}
else
{
threads
=
dim3
(
32
,
32
);
threads
=
dim3
(
32
,
32
);
grid
=
dim3
((
frame
Size
+
32
-
1
)
/
32
,
(
batchS
ize
+
32
-
1
)
/
32
);
grid
=
dim3
((
frame
_size
+
32
-
1
)
/
32
,
(
batch_s
ize
+
32
-
1
)
/
32
);
}
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruBackwardStateGrad
<
detail
::
KeGruBackwardStateGrad
<
detail
::
backward
::
gru_stateGrad
<
T
>
,
detail
::
backward
::
gru_stateGrad
<
T
>
,
/* is
B
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is
_b
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
_value
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
outputGrad
,
frameSize
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
batchS
ize
,
active_node
);
grad
.
output_grad
,
frame_size
,
batch_s
ize
,
active_node
);
}
else
{
}
else
{
detail
::
KeGruBackwardStateGrad
<
detail
::
KeGruBackwardStateGrad
<
detail
::
backward
::
gru_stateGrad
<
T
>
,
detail
::
backward
::
gru_stateGrad
<
T
>
,
/* is
B
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is
_b
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
_value
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
outputGrad
,
frameSize
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
batchS
ize
,
active_node
);
grad
.
output_grad
,
frame_size
,
batch_s
ize
,
active_node
);
}
}
if
(
value
.
prev
OutValue
&&
grad
.
prevOutG
rad
)
{
if
(
value
.
prev
_out_value
&&
grad
.
prev_out_g
rad
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
grad
.
gate
Grad
+
frameSize
*
2
,
frameSize
*
3
,
value
.
stateW
eight
,
grad
.
gate
_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_w
eight
,
frame
Size
,
0
,
grad
.
resetOutputGrad
,
frameS
ize
);
frame
_size
,
0
,
grad
.
reset_output_grad
,
frame_s
ize
);
if
(
grad
.
state
WeightG
rad
)
{
if
(
grad
.
state
_weight_g
rad
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
true
,
false
,
frameSize
,
frameSize
,
batchSize
,
1
,
context
,
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
value
.
resetOutputValue
,
frameSize
,
grad
.
gateGrad
+
frameSize
*
2
,
value
.
reset_output_value
,
frame_size
,
frameSize
*
3
,
1
,
grad
.
stateWeightGrad
,
frameSize
);
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
1
,
grad
.
state_weight_grad
,
frame_size
);
}
}
}
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruBackwardResetGrad
<
detail
::
KeGruBackwardResetGrad
<
detail
::
backward
::
gru_resetGrad
<
T
>
,
detail
::
backward
::
gru_resetGrad
<
T
>
,
/* is
B
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is
_b
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
_value
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
resetOutputGrad
,
frameSize
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
batchS
ize
,
active_gate
);
grad
.
reset_output_grad
,
frame_size
,
batch_s
ize
,
active_gate
);
}
else
{
}
else
{
detail
::
KeGruBackwardResetGrad
<
detail
::
KeGruBackwardResetGrad
<
detail
::
backward
::
gru_resetGrad
<
T
>
,
detail
::
backward
::
gru_resetGrad
<
T
>
,
/* is
B
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is
_b
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
_value
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
resetOutputGrad
,
frameSize
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
batchS
ize
,
active_gate
);
grad
.
reset_output_grad
,
frame_size
,
batch_s
ize
,
active_gate
);
}
}
if
(
grad
.
prev
OutGrad
&&
value
.
prevOutV
alue
)
{
if
(
grad
.
prev
_out_grad
&&
value
.
prev_out_v
alue
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
*
2
,
1
,
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
*
2
,
1
,
grad
.
gate
Grad
,
frameSize
*
3
,
value
.
gateWeight
,
frameS
ize
*
2
,
1
,
grad
.
gate
_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_s
ize
*
2
,
1
,
grad
.
prev
OutGrad
,
frameS
ize
);
grad
.
prev
_out_grad
,
frame_s
ize
);
if
(
grad
.
gate
WeightG
rad
)
{
if
(
grad
.
gate
_weight_g
rad
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
true
,
false
,
frame
Size
,
frameSize
*
2
,
batchS
ize
,
1
,
context
,
true
,
false
,
frame
_size
,
frame_size
*
2
,
batch_s
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
grad
.
gateGrad
,
frameS
ize
*
3
,
1
,
value
.
prev
_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_s
ize
*
3
,
1
,
grad
.
gate
WeightGrad
,
frameS
ize
*
2
);
grad
.
gate
_weight_grad
,
frame_s
ize
*
2
);
}
}
}
}
}
}
...
...
paddle/operators/math/gru_compute.h
浏览文件 @
3e552cdc
...
@@ -22,28 +22,28 @@ namespace math {
...
@@ -22,28 +22,28 @@ namespace math {
// TODO(guosheng): refine code style in gru_compute
// TODO(guosheng): refine code style in gru_compute
template
<
typename
T
>
template
<
typename
T
>
struct
hl_gru_value
{
struct
hl_gru_value
{
T
*
gate
W
eight
;
T
*
gate
_w
eight
;
T
*
state
W
eight
;
T
*
state
_w
eight
;
T
*
gate
V
alue
;
T
*
gate
_v
alue
;
T
*
reset
OutputV
alue
;
T
*
reset
_output_v
alue
;
T
*
output
V
alue
;
T
*
output
_v
alue
;
T
*
prev
OutV
alue
;
T
*
prev
_out_v
alue
;
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
hl_gru_grad
{
struct
hl_gru_grad
{
T
*
gate
WeightG
rad
;
T
*
gate
_weight_g
rad
;
T
*
state
WeightG
rad
;
T
*
state
_weight_g
rad
;
T
*
gate
G
rad
;
T
*
gate
_g
rad
;
T
*
reset
OutputG
rad
;
T
*
reset
_output_g
rad
;
T
*
output
G
rad
;
T
*
output
_g
rad
;
T
*
prev
OutG
rad
;
T
*
prev
_out_g
rad
;
};
};
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
struct
GRUUnitFunctor
{
struct
GRUUnitFunctor
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
int
frame
Size
,
int
batchS
ize
,
hl_gru_value
<
T
>
value
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
);
activation_mode_t
active_gate
);
};
};
...
@@ -51,8 +51,9 @@ struct GRUUnitFunctor {
...
@@ -51,8 +51,9 @@ struct GRUUnitFunctor {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
struct
GRUUnitGradFunctor
{
struct
GRUUnitGradFunctor
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
batchSize
,
activation_mode_t
active_node
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
);
activation_mode_t
active_gate
);
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录