Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
18c322c2
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
18c322c2
编写于
8月 06, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
seperate cpu and gpu implementations for gru kernel compute
上级
54c95e49
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
225 addition
and
126 deletion
+225
-126
paddle/fluid/operators/gru_op.cc
paddle/fluid/operators/gru_op.cc
+135
-3
paddle/fluid/operators/gru_op.cu.cc
paddle/fluid/operators/gru_op.cu.cc
+90
-0
paddle/fluid/operators/gru_op.h
paddle/fluid/operators/gru_op.h
+0
-123
未找到文件。
paddle/fluid/operators/gru_op.cc
浏览文件 @
18c322c2
...
@@ -211,6 +211,139 @@ class GRUGradOp : public framework::OperatorWithKernel {
...
@@ -211,6 +211,139 @@ class GRUGradOp : public framework::OperatorWithKernel {
}
}
};
};
template
<
typename
T
>
class
GRUCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
BatchCompute
(
const
framework
::
ExecutionContext
&
context
)
const
{
using
DeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
auto
*
input
=
context
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
h0
=
context
.
Input
<
Tensor
>
(
"H0"
);
auto
*
weight
=
context
.
Input
<
Tensor
>
(
"Weight"
);
const
T
*
weight_data
=
weight
->
data
<
T
>
();
auto
*
bias
=
context
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
batch_gate
=
context
.
Output
<
LoDTensor
>
(
"BatchGate"
);
batch_gate
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
batch_reset_hidden_prev
=
context
.
Output
<
LoDTensor
>
(
"BatchResetHiddenPrev"
);
batch_reset_hidden_prev
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
batch_hidden
=
context
.
Output
<
LoDTensor
>
(
"BatchHidden"
);
batch_hidden
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
hidden
=
context
.
Output
<
LoDTensor
>
(
"Hidden"
);
hidden
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
hidden_dims
=
hidden
->
dims
();
bool
is_reverse
=
context
.
Attr
<
bool
>
(
"is_reverse"
);
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
to_batch
(
dev_ctx
,
*
input
,
batch_gate
,
true
,
is_reverse
);
if
(
bias
)
{
math
::
RowwiseAdd
<
DeviceContext
,
T
>
add_bias
;
add_bias
(
dev_ctx
,
*
batch_gate
,
*
bias
,
batch_gate
);
}
int
frame_size
=
hidden_dims
[
1
];
math
::
GRUMetaValue
<
T
>
gru_value
;
gru_value
.
gate_weight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state_weight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
Tensor
ordered_h0
;
framework
::
Vector
<
size_t
>
order
(
batch_gate
->
lod
()[
2
]);
if
(
h0
)
{
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
h0
,
order
,
&
ordered_h0
,
true
);
gru_value
.
prev_out_value
=
ordered_h0
.
data
<
T
>
();
}
else
{
gru_value
.
prev_out_value
=
nullptr
;
}
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
auto
active_node
=
math
::
detail
::
GetActivationType
(
context
.
Attr
<
std
::
string
>
(
"activation"
));
auto
active_gate
=
math
::
detail
::
GetActivationType
(
context
.
Attr
<
std
::
string
>
(
"gate_activation"
));
#ifdef PADDLE_WITH_MKLML
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
// TODO(TJ): make a class
T
*
packed_gate
=
blas
.
GEMM_ALLOC
(
CblasBMatrix
,
1
/*height of C*/
,
frame_size
*
2
/*width of weight*/
,
frame_size
/*height of height*/
);
PADDLE_ENFORCE
(
packed_gate
);
blas
.
GEMM_PACK
(
CblasBMatrix
,
CblasNoTrans
,
1
/*cur bs?*/
,
frame_size
*
2
,
frame_size
,
T
(
1.0
),
gru_value
.
gate_weight
,
frame_size
*
2
,
packed_gate
);
T
*
packed_state
=
blas
.
GEMM_ALLOC
(
CblasBMatrix
,
1
/*height of C*/
,
frame_size
/*width of weight*/
,
frame_size
/*height of height*/
);
PADDLE_ENFORCE
(
packed_state
);
blas
.
GEMM_PACK
(
CblasBMatrix
,
CblasNoTrans
,
1
/*cur bs?*/
,
frame_size
,
frame_size
,
T
(
1.0
),
gru_value
.
state_weight
,
frame_size
,
packed_state
);
#endif
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
int
cur_batch_size
=
bend
-
bstart
;
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
gru_value
.
output_value
=
hidden_t
.
data
<
T
>
();
gru_value
.
gate_value
=
gate_t
.
data
<
T
>
();
gru_value
.
reset_output_value
=
reset_hidden_prev_t
.
data
<
T
>
();
#ifdef PADDLE_WITH_MKLML
if
(
gru_value
.
prev_out_value
)
{
blas
.
GEMM_COMPUTE
(
CblasNoTrans
,
CblasPacked
,
cur_batch_size
,
frame_size
*
2
,
frame_size
,
gru_value
.
prev_out_value
,
frame_size
,
packed_gate
,
frame_size
*
2
,
T
(
1
),
gru_value
.
gate_value
,
frame_size
*
3
);
}
math
::
detail
::
forward_reset_output
(
math
::
detail
::
forward
::
gru_resetOutput
<
T
>
(),
gru_value
,
frame_size
,
cur_batch_size
,
active_gate
);
if
(
gru_value
.
prev_out_value
)
{
blas
.
GEMM_COMPUTE
(
CblasNoTrans
,
CblasPacked
,
cur_batch_size
,
frame_size
,
frame_size
,
gru_value
.
reset_output_value
,
frame_size
,
packed_state
,
frame_size
,
T
(
1
),
gru_value
.
gate_value
+
frame_size
*
2
,
frame_size
*
3
);
}
math
::
detail
::
forward_final_output
(
math
::
detail
::
forward
::
gru_finalOutput
<
T
>
(),
gru_value
,
frame_size
,
cur_batch_size
,
active_node
);
#else
math
::
GRUUnitFunctor
<
DeviceContext
,
T
>::
compute
(
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
active_node
,
active_gate
);
#endif
gru_value
.
prev_out_value
=
gru_value
.
output_value
;
}
#ifdef PADDLE_WITH_MKLML
blas
.
GEMM_FREE
(
packed_gate
);
blas
.
GEMM_FREE
(
packed_state
);
#endif
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
batch_hidden
->
set_lod
(
batch_gate
->
lod
());
to_seq
(
dev_ctx
,
*
batch_hidden
,
hidden
);
}
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
BatchCompute
(
context
);
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
@@ -218,9 +351,8 @@ namespace ops = paddle::operators;
...
@@ -218,9 +351,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
gru
,
ops
::
GRUOp
,
ops
::
GRUOpMaker
,
REGISTER_OPERATOR
(
gru
,
ops
::
GRUOp
,
ops
::
GRUOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
gru_grad
,
ops
::
GRUGradOp
);
REGISTER_OPERATOR
(
gru_grad
,
ops
::
GRUGradOp
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
gru
,
ops
::
GRUCPUKernel
<
float
>
,
gru
,
ops
::
GRUKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
GRUCPUKernel
<
double
>
);
ops
::
GRUKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
gru_grad
,
ops
::
GRUGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
gru_grad
,
ops
::
GRUGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
GRUGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
ops
::
GRUGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/gru_op.cu.cc
浏览文件 @
18c322c2
...
@@ -14,6 +14,96 @@ limitations under the License. */
...
@@ -14,6 +14,96 @@ limitations under the License. */
#include "paddle/fluid/operators/gru_op.h"
#include "paddle/fluid/operators/gru_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
GRUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
BatchCompute
(
const
framework
::
ExecutionContext
&
context
)
const
{
auto
*
input
=
context
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
h0
=
context
.
Input
<
Tensor
>
(
"H0"
);
auto
*
weight
=
context
.
Input
<
Tensor
>
(
"Weight"
);
const
T
*
weight_data
=
weight
->
data
<
T
>
();
auto
*
bias
=
context
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
batch_gate
=
context
.
Output
<
LoDTensor
>
(
"BatchGate"
);
batch_gate
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
batch_reset_hidden_prev
=
context
.
Output
<
LoDTensor
>
(
"BatchResetHiddenPrev"
);
batch_reset_hidden_prev
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
batch_hidden
=
context
.
Output
<
LoDTensor
>
(
"BatchHidden"
);
batch_hidden
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
hidden
=
context
.
Output
<
LoDTensor
>
(
"Hidden"
);
hidden
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
hidden_dims
=
hidden
->
dims
();
bool
is_reverse
=
context
.
Attr
<
bool
>
(
"is_reverse"
);
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
to_batch
(
dev_ctx
,
*
input
,
batch_gate
,
true
,
is_reverse
);
if
(
bias
)
{
math
::
RowwiseAdd
<
DeviceContext
,
T
>
add_bias
;
add_bias
(
dev_ctx
,
*
batch_gate
,
*
bias
,
batch_gate
);
}
int
frame_size
=
hidden_dims
[
1
];
math
::
GRUMetaValue
<
T
>
gru_value
;
gru_value
.
gate_weight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state_weight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
Tensor
ordered_h0
;
framework
::
Vector
<
size_t
>
order
(
batch_gate
->
lod
()[
2
]);
if
(
h0
)
{
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
h0
,
order
,
&
ordered_h0
,
true
);
gru_value
.
prev_out_value
=
ordered_h0
.
data
<
T
>
();
}
else
{
gru_value
.
prev_out_value
=
nullptr
;
}
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
auto
active_node
=
math
::
detail
::
GetActivationType
(
context
.
Attr
<
std
::
string
>
(
"activation"
));
auto
active_gate
=
math
::
detail
::
GetActivationType
(
context
.
Attr
<
std
::
string
>
(
"gate_activation"
));
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
int
cur_batch_size
=
bend
-
bstart
;
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
gru_value
.
output_value
=
hidden_t
.
data
<
T
>
();
gru_value
.
gate_value
=
gate_t
.
data
<
T
>
();
gru_value
.
reset_output_value
=
reset_hidden_prev_t
.
data
<
T
>
();
math
::
GRUUnitFunctor
<
DeviceContext
,
T
>::
compute
(
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
active_node
,
active_gate
);
gru_value
.
prev_out_value
=
gru_value
.
output_value
;
}
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
batch_hidden
->
set_lod
(
batch_gate
->
lod
());
to_seq
(
dev_ctx
,
*
batch_hidden
,
hidden
);
}
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
BatchCompute
(
context
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
gru
,
ops
::
GRUKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
gru
,
ops
::
GRUKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
...
...
paddle/fluid/operators/gru_op.h
浏览文件 @
18c322c2
...
@@ -40,129 +40,6 @@ inline void ReorderInitState(const DeviceContext& ctx,
...
@@ -40,129 +40,6 @@ inline void ReorderInitState(const DeviceContext& ctx,
row_shuffle
(
ctx
,
src
,
index_lod
,
dst
,
indexed_src
);
row_shuffle
(
ctx
,
src
,
index_lod
,
dst
,
indexed_src
);
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
GRUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
BatchCompute
(
const
framework
::
ExecutionContext
&
context
)
const
{
auto
*
input
=
context
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
h0
=
context
.
Input
<
Tensor
>
(
"H0"
);
auto
*
weight
=
context
.
Input
<
Tensor
>
(
"Weight"
);
const
T
*
weight_data
=
weight
->
data
<
T
>
();
auto
*
bias
=
context
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
batch_gate
=
context
.
Output
<
LoDTensor
>
(
"BatchGate"
);
batch_gate
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
batch_reset_hidden_prev
=
context
.
Output
<
LoDTensor
>
(
"BatchResetHiddenPrev"
);
batch_reset_hidden_prev
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
batch_hidden
=
context
.
Output
<
LoDTensor
>
(
"BatchHidden"
);
batch_hidden
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
hidden
=
context
.
Output
<
LoDTensor
>
(
"Hidden"
);
hidden
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
hidden_dims
=
hidden
->
dims
();
bool
is_reverse
=
context
.
Attr
<
bool
>
(
"is_reverse"
);
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
to_batch
(
dev_ctx
,
*
input
,
batch_gate
,
true
,
is_reverse
);
if
(
bias
)
{
math
::
RowwiseAdd
<
DeviceContext
,
T
>
add_bias
;
add_bias
(
dev_ctx
,
*
batch_gate
,
*
bias
,
batch_gate
);
}
int
frame_size
=
hidden_dims
[
1
];
math
::
GRUMetaValue
<
T
>
gru_value
;
gru_value
.
gate_weight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state_weight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
Tensor
ordered_h0
;
framework
::
Vector
<
size_t
>
order
(
batch_gate
->
lod
()[
2
]);
if
(
h0
)
{
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
h0
,
order
,
&
ordered_h0
,
true
);
gru_value
.
prev_out_value
=
ordered_h0
.
data
<
T
>
();
}
else
{
gru_value
.
prev_out_value
=
nullptr
;
}
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
auto
active_node
=
math
::
detail
::
GetActivationType
(
context
.
Attr
<
std
::
string
>
(
"activation"
));
auto
active_gate
=
math
::
detail
::
GetActivationType
(
context
.
Attr
<
std
::
string
>
(
"gate_activation"
));
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
// TODO(TJ): make a class, make one pack
T
*
packed_gate
=
blas
.
GEMM_ALLOC
(
CblasBMatrix
,
1
/*height of C*/
,
frame_size
*
2
/*width of weight*/
,
frame_size
/*height of height*/
);
PADDLE_ENFORCE
(
packed_gate
);
blas
.
GEMM_PACK
(
CblasBMatrix
,
CblasNoTrans
,
1
/*cur bs?*/
,
frame_size
*
2
,
frame_size
,
T
(
1.0
),
gru_value
.
gate_weight
,
frame_size
*
2
,
packed_gate
);
T
*
packed_state
=
blas
.
GEMM_ALLOC
(
CblasBMatrix
,
1
/*height of C*/
,
frame_size
/*width of weight*/
,
frame_size
/*height of height*/
);
PADDLE_ENFORCE
(
packed_state
);
blas
.
GEMM_PACK
(
CblasBMatrix
,
CblasNoTrans
,
1
/*cur bs?*/
,
frame_size
,
frame_size
,
T
(
1.0
),
gru_value
.
state_weight
,
frame_size
,
packed_state
);
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
int
cur_batch_size
=
bend
-
bstart
;
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
gru_value
.
output_value
=
hidden_t
.
data
<
T
>
();
gru_value
.
gate_value
=
gate_t
.
data
<
T
>
();
gru_value
.
reset_output_value
=
reset_hidden_prev_t
.
data
<
T
>
();
if
(
gru_value
.
prev_out_value
)
{
blas
.
GEMM_COMPUTE
(
CblasNoTrans
,
CblasPacked
,
cur_batch_size
,
frame_size
*
2
,
frame_size
,
gru_value
.
prev_out_value
,
frame_size
,
packed_gate
,
frame_size
*
2
,
T
(
1
),
gru_value
.
gate_value
,
frame_size
*
3
);
}
math
::
detail
::
forward_reset_output
(
math
::
detail
::
forward
::
gru_resetOutput
<
T
>
(),
gru_value
,
frame_size
,
cur_batch_size
,
active_gate
);
if
(
gru_value
.
prev_out_value
)
{
blas
.
GEMM_COMPUTE
(
CblasNoTrans
,
CblasPacked
,
cur_batch_size
,
frame_size
,
frame_size
,
gru_value
.
reset_output_value
,
frame_size
,
packed_state
,
frame_size
,
T
(
1
),
gru_value
.
gate_value
+
frame_size
*
2
,
frame_size
*
3
);
}
math
::
detail
::
forward_final_output
(
math
::
detail
::
forward
::
gru_finalOutput
<
T
>
(),
gru_value
,
frame_size
,
cur_batch_size
,
active_node
);
gru_value
.
prev_out_value
=
gru_value
.
output_value
;
}
blas
.
GEMM_FREE
(
packed_gate
);
blas
.
GEMM_FREE
(
packed_state
);
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
batch_hidden
->
set_lod
(
batch_gate
->
lod
());
to_seq
(
dev_ctx
,
*
batch_hidden
,
hidden
);
}
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
BatchCompute
(
context
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
GRUGradKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GRUGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录