Commit 6f78fd7d — PaddlePaddle / Paddle

fuse fc in gru

Authored on Aug 16, 2018 by tensor-tang
Parent: 300180cc

Showing 1 changed file with 130 additions and 101 deletions:
paddle/fluid/operators/fusion_gru_op.cc  (+130 -101)
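In short, this change folds the input-projection FC into the GRU operator itself: instead of a separate FC op feeding pre-computed gate inputs of shape (T x 3D) into the GRU, the fused op takes the raw sequence X (T x M) together with the FC weight WeightX (M x 3D) and the recurrent weight WeightH (D x 3D), and performs the X * WeightX (+ Bias) projection internally via math::FCCompute before running the batched GRU. A rough stand-in for that projection, assuming plain row-major float buffers (this is not Paddle's FCCompute, only an illustration):

// Illustrative stand-in for the fused input projection (not Paddle's
// math::FCCompute): gates = x * wx + bias, with x: T x M (row-major),
// wx: M x 3D, bias: 1 x 3D (optional), gates: T x 3D.
void ProjectInput(const float* x, const float* wx, const float* bias,
                  float* gates, int T, int M, int D3) {
  for (int t = 0; t < T; ++t) {
    for (int j = 0; j < D3; ++j) {
      float acc = bias ? bias[j] : 0.f;  // the fused op folds the FC bias here
      for (int k = 0; k < M; ++k) acc += x[t * M + k] * wx[k * D3 + j];
      gates[t * D3 + j] = acc;
    }
  }
}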
@@ -15,8 +15,11 @@ limitations under the License. */
 #include "paddle/fluid/operators/fusion_gru_op.h"
 #include <string>
 #include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
+#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
+#include "paddle/fluid/operators/math/detail/gru_kernel.h"
+#include "paddle/fluid/operators/math/fc_compute.h"
 #include "paddle/fluid/operators/math/gru_compute.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
@@ -25,47 +28,69 @@ namespace paddle {
 namespace operators {
 
 void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("Input"),
-                 "Input(%s) of GRUOp should not be null.", "Input");
-  PADDLE_ENFORCE(ctx->HasInput("Weight"),
-                 "Input(%s) of GRUOp should not be null.", "Weight");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
-                 "Output(%s) of GRUOp should not be null.", "BatchGate");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
-                 "Output(%s) of GRUOp should not be null.",
-                 "BatchResetHiddenPrev");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
-                 "Output(%s) of GRUOp should not be null.", "BatchHidden");
-  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                 "Output(%s) of GRUOp should not be null.", "Hidden");
-  auto input_dims = ctx->GetInputDim("Input");
-  auto weight_dims = ctx->GetInputDim("Weight");
-  int input_size = input_dims[1];
-  int frame_size = weight_dims[0];
-  PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
-                    "The input_size must be 3 times of frame_size in GRUOp.");
-  PADDLE_ENFORCE_EQ(
-      weight_dims[1], frame_size * 3,
-      "The shape of Weight matrix must be [frame_size, frame_size * 3].");
+  PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("WeightX"),
+                 "Input(WeightX) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("WeightH"),
+                 "Input(WeightH) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
+                 "Output(BatchedGate) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
+                 "Output(BatchResetHiddenPrev) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
+                 "Output(BatchedHidden) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
+                 "Output(Hidden) of GRU should not be null.");
+  auto x_dims = ctx->GetInputDim("X");
+  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
+  auto wx_dims = ctx->GetInputDim("WeightX");
+  PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
+                    "The rank of Input(WeightX) should be 2.");
+  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
+                    "The first dimension of Input(WeightX) "
+                    "should be %d.",
+                    x_dims[1]);
+  int frame_size = wx_dims[1] / 3;
+  auto wh_dims = ctx->GetInputDim("WeightH");
+  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
+                    "The rank of Input(WeightH) should be 2.");
+  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
+                    "The first dimension of Input(WeightH) "
+                    "should be %d.",
+                    frame_size);
+  PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size,
+                    "The second dimension of Input(WeightH) "
+                    "should be 3 * %d.",
+                    frame_size);
   if (ctx->HasInput("H0")) {
     auto h0_dims = ctx->GetInputDim("H0");
     PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
                       "The width of H0 must be equal to frame_size.");
   }
   if (ctx->HasInput("Bias")) {
-    auto bias_dims = ctx->GetInputDim("Bias");
-    int bias_height = bias_dims[0];
-    int bias_width = bias_dims[1];
-    PADDLE_ENFORCE_EQ(bias_height, 1,
-                      "The shape of Bias must be [1, frame_size * 3].");
-    PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
+    auto b_dims = ctx->GetInputDim("Bias");
+    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
+    PADDLE_ENFORCE_EQ(b_dims[0], 1,
+                      "The first dimension of Input(Bias) should be 1.");
+    PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3,
                       "The shape of Bias must be [1, frame_size * 3].");
   }
-  ctx->SetOutputDim("BatchGate", input_dims);
-  ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
-  ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
-  ctx->SetOutputDim("Hidden", {input_dims[0], frame_size});
-  ctx->ShareLoD("Input", "Hidden");
+  framework::DDim out_dims({x_dims[0], frame_size});
+  ctx->SetOutputDim("Hidden", out_dims);
+  ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
+  ctx->SetOutputDim("BatchedHidden", out_dims);
+  ctx->SetOutputDim("BatchResetHiddenPrev", out_dims);
+  ctx->ShareLoD("X", "Hidden");
+  int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
+  ctx->ShareLoD("X", "XX");
 }
 
 framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
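For reference, the shape contract enforced above, on illustrative sizes (T, M, D below are assumptions for the example, not values from the commit): X is T x M, WeightX is M x 3D, WeightH is D x 3D, Bias is 1 x 3D, Hidden comes out as T x D, and XX is sized T x min(M, 3D) so it can hold whichever intermediate is narrower. A small standalone check:

// Toy walk-through of the shape contract enforced by InferShape above.
// T, M and D are made-up example sizes, not values from the commit.
#include <cstdio>

int main() {
  const int T = 6;   // total time steps in the mini-batch
  const int M = 16;  // dim size of x
  const int D = 8;   // hidden (frame) size

  const int x_dims[2] = {T, M};
  const int wx_dims[2] = {M, 3 * D};  // WeightX: the FC weight
  const int wh_dims[2] = {D, 3 * D};  // WeightH: the recurrent weight

  // What InferShape derives from those inputs.
  const int frame_size = wx_dims[1] / 3;  // == D, also wh_dims[0]
  const int xx_width =
      x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];  // min(M, 3D)

  std::printf("frame_size=%d Hidden=%dx%d BatchedGate=%dx%d XX=%dx%d\n",
              frame_size, x_dims[0], frame_size, x_dims[0], wx_dims[1],
              x_dims[0], xx_width);
  (void)wh_dims;
  return 0;
}

With these sizes the program prints frame_size=8, Hidden=6x8, BatchedGate=6x24 and XX=6x16.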
@@ -76,53 +101,38 @@ framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
 }
 
 void FusionGRUOpMaker::Make() {
-  AddInput("Input",
-           "(LoDTensor) The first input is a LodTensor, which supports "
+  AddInput("X",
+           "(LoDTensor) the input is a LodTensor, which support "
            "variable-time length input sequence. The underlying tensor in "
-           "this LoDTenosr is a matrix with shape (T X 3D), where, T is the "
-           "total time steps in this mini-batch, D is the hidden size.");
+           "this LoDTensor is a matrix with shape (T X M), where T is the "
+           "total time steps in this mini-batch, M is the dim size of x.");
   AddInput("H0",
            "(Tensor, optional) The initial hidden state is an optional "
            "input. This is a tensor with shape (N x D), where N is the "
            "batch size, D is the hidden size.")
       .AsDispensable();
-  AddInput(
-      "Weight",
-      "(Tensor) The learnable hidden-hidden weight matrix with shape "
-      "(D x 3D), where D is the hidden size. The elements continuous in "
-      "memory can be divided into two parts. The first part are weights of "
-      "the update gate and reset gate with shape (D x 2D), and the second "
-      "part are weights of output candidate with shape (D x D).");
+  AddInput("WeightX",
+           "(Tensor) The FC weight with shape (M x 3D),"
+           "where M is the dim size of x, D is the hidden size. ");
+  AddInput("WeightH",
+           "(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. ");
   AddInput("Bias",
-           "(Tensor, optional) Bias vector with shape (1 x 3D) concating "
-           "bias of the update gate, reset gate and output candidate.")
+           "(Tensor, optional) (1 x 3D)."
+           "Almost same as GRUOp."
+           "Note: if have FC bias it should be added on this bias.")
       .AsDispensable();
-  AddOutput("BatchGate",
-            "(LoDTensor) To compute with batches, sequence data will be "
-            "reorganized into several successive batches each containing "
-            "data from the same time step. The LoDTensor BatchGate contains "
-            "the update gate, reset gate and output candidate values "
-            "organized in batches. The LoD size is 2. The first LoD contains "
-            "the batch offsets and the second LoD contains the indexes in "
-            "the raw sequence data.")
+  AddOutput("XX",
+            "(LoDTensor) the result after X * WeightX (size is T x 4D)"
+            " or batched_X (size is T x M), this will be automatically chosen,"
+            " where T is the total time steps in this mini-batch,"
+            " D is the hidden size, M is the dim size of x input.")
       .AsIntermediate();
-  AddOutput("BatchResetHiddenPrev",
-            "(LoDTensor) The reseted hidden state LoDTensor organized in batches. "
-            "This LoDTensor is a matrix with shape (T X D) and has the same LoD "
-            "with `BatchGate`.")
+  AddOutput("BatchedGate", "(LoDTensor) Same as GRUOp").AsIntermediate();
+  AddOutput("BatchResetHiddenPrev", "(LoDTensor) (T x 3D) Same as GRUOp.")
      .AsIntermediate();
-  AddOutput("BatchHidden",
-            "(LoDTensor) The hidden state LoDTensor organized in batches. "
-            "This LoDTensor is a matrix with shape (T X D) and has the same LoD "
-            "with `BatchGate`.")
+  AddOutput("BatchedHidden", "(LoDTensor) (T X D) Same as GRUOp.")
      .AsIntermediate();
-  AddOutput("Hidden",
-            "(LoDTensor) the hidden state LoDTensor organized in sequences. "
-            "This LoDTensor is a matrix with shape (T X D) and has the same LoD "
-            "with `BatchGate`.");
+  AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
   AddAttr<std::string>("activation",
                        "(string, default tanh) "
                        "The activation type used for output candidate {h}_t.")
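One practical consequence of the Bias note above: with the FC fused, the kernel applies a single Bias (1 x 3D) during the projection, so a bias that used to live on the separate FC op has to be summed into the GRU bias beforehand. One way to honor that note, as a hedged sketch (the names here are illustrative, not a Paddle API):

// Fold an existing FC bias into the GRU bias so the fused op sees a single
// 1 x 3D Bias input. Both vectors are assumed to have 3 * D entries.
#include <cstddef>
#include <vector>

std::vector<float> FoldBias(const std::vector<float>& fc_bias,
                            const std::vector<float>& gru_bias) {
  std::vector<float> fused(gru_bias.size());
  for (std::size_t i = 0; i < fused.size(); ++i) {
    fused[i] = fc_bias[i] + gru_bias[i];  // elementwise sum of the two biases
  }
  return fused;
}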
@@ -156,52 +166,71 @@ inline void ReorderInitState(const DeviceContext& ctx,
 template <typename DeviceContext, typename T>
 class FusionGRUKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* x = context.Input<LoDTensor>("X");
-    auto* h = context.Input<LoDTensor>("H");
-    auto* h0 = context.Input<Tensor>("H0");
-    auto* x_weight = context.Input<Tensor>("XWeight");     // x_dim*3D
-    auto* gate_weight = context.Input<Tensor>("HWeight");  // D*3D
-    auto* bias = context.Input<Tensor>("Bias");            // 1*3D
-    auto hidden_dims = hidden->dims();
-    bool is_reverse = context.Attr<bool>("is_reverse");
-    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
-    if (bias) {
-      math::RowwiseAdd<DeviceContext, T> add_bias;
-      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
-    }
-    int frame_size = hidden_dims[1];
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* wx = ctx.Input<Tensor>("WeightX");
+    auto* wh = ctx.Input<Tensor>("WeightH");
+    auto* bias = ctx.Input<Tensor>("Bias");
+    auto* h0 = ctx.Input<Tensor>("H0");
+
+    auto* xx = ctx.Output<LoDTensor>("XX");
+    auto* batched_gate = ctx.Output<LoDTensor>("BatchedGate");
+    auto* batch_reset_hidden_prev =
+        ctx.Output<LoDTensor>("BatchResetHiddenPrev");
+    auto* batch_hidden = ctx.Output<LoDTensor>("BatchedHidden");
+    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
+    bool is_reverse = ctx.Attr<bool>("is_reverse");
+
+    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
+    T* batched_gate_data = batched_gate->mutable_data<T>(ctx.GetPlace());
+    batch_reset_hidden_prev->mutable_data<T>(ctx.GetPlace());
+    batch_hidden->mutable_data<T>(ctx.GetPlace());
+    hidden_out->mutable_data<T>(ctx.GetPlace());
+
+    const T* x_data = x->data<T>();
+    const T* wx_data = wx->data<T>();
+    const T* wh_data = wh->data<T>();
+    auto x_dims = x->dims();
+    auto wx_dims = wx->dims();
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    if (x_dims[1] > wx_dims[1]) {
+      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
+                                        x_data, wx_data, xx_data,
+                                        bias ? bias->data<T>() : NULL);
+      to_batch(dev_ctx, *xx, batched_gate, true, is_reverse);
+    } else {
+      to_batch(dev_ctx, *x, xx, true, is_reverse);
+      batched_gate->set_lod(xx->lod());
+      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
+                                        xx_data, wx_data, batched_gate_data,
+                                        bias ? bias->data<T>() : NULL);
+    }
+
+    int frame_size = static_cast<int>(wx_dims[1] / 3);
     math::GRUMetaValue<T> gru_value;
-    gru_value.gate_weight = const_cast<T*>(weight_data);
+    gru_value.gate_weight = const_cast<T*>(wh_data);
     gru_value.state_weight =
-        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
+        const_cast<T*>(wh_data + 2 * frame_size * frame_size);
     Tensor ordered_h0;
-    framework::Vector<size_t> order(batch_gate->lod()[2]);
+    framework::Vector<size_t> order(batched_gate->lod()[2]);
     if (h0) {
       // Since the batch computing for GRU reorders the input sequences
       // according to their length. The initialized cell state also needs
       // to reorder.
       ReorderInitState<DeviceContext, T>(
-          context.template device_context<DeviceContext>(), *h0, order,
+          ctx.template device_context<DeviceContext>(), *h0, order,
          &ordered_h0, true);
       gru_value.prev_out_value = ordered_h0.data<T>();
     } else {
       gru_value.prev_out_value = nullptr;
     }
-    auto batch_starts = batch_gate->lod()[0];
+    auto batch_starts = batched_gate->lod()[0];
     size_t seq_len = batch_starts.size() - 1;
-    auto active_node = math::detail::GetActivationType(
-        context.Attr<std::string>("activation"));
+    auto active_node =
+        math::detail::GetActivationType(ctx.Attr<std::string>("activation"));
     auto active_gate = math::detail::GetActivationType(
-        context.Attr<std::string>("gate_activation"));
+        ctx.Attr<std::string>("gate_activation"));
 #ifdef PADDLE_WITH_MKLML
     // use MKL packed to speedup GEMM
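The branch on x_dims[1] > wx_dims[1] above decides whether the projection runs before or after the LoDTensor-to-batch reordering: whichever intermediate is narrower (the projected gates, T x 3D, or the raw input, T x M) is what gets materialized in XX, and batched_gate always ends up holding the projected gates in batch order. The two orderings agree because the FC acts row by row, so permuting rows commutes with the projection. A small sketch of that equivalence, with a naive matmul standing in for FCCompute and an arbitrary row permutation standing in for the batch reorder (bias omitted; none of this is Paddle API):

// Demonstrates that project-then-permute equals permute-then-project for a
// row-wise FC, which is why the kernel may run FCCompute on either side of
// the batch reordering. Sizes and the permutation are illustrative only.
#include <cassert>
#include <vector>

static void NaiveFC(const std::vector<float>& in, const std::vector<float>& w,
                    std::vector<float>& out, int rows, int in_w, int out_w) {
  for (int r = 0; r < rows; ++r)
    for (int j = 0; j < out_w; ++j) {
      float acc = 0.f;
      for (int k = 0; k < in_w; ++k) acc += in[r * in_w + k] * w[k * out_w + j];
      out[r * out_w + j] = acc;
    }
}

static std::vector<float> PermuteRows(const std::vector<float>& in,
                                      const std::vector<int>& order,
                                      int width) {
  std::vector<float> out(in.size());
  for (std::size_t r = 0; r < order.size(); ++r)
    for (int j = 0; j < width; ++j) out[r * width + j] = in[order[r] * width + j];
  return out;
}

int main() {
  const int T = 4, M = 3, D3 = 6;  // toy sizes: x is T x M, wx is M x 3D
  std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  std::vector<float> wx(M * D3);
  for (int i = 0; i < M * D3; ++i) wx[i] = 0.1f * static_cast<float>(i);
  std::vector<int> order = {2, 0, 3, 1};  // stand-in for the batch reorder

  // Project first, then reorder (the x_dims[1] > wx_dims[1] branch)...
  std::vector<float> xx(T * D3);
  NaiveFC(x, wx, xx, T, M, D3);
  std::vector<float> a = PermuteRows(xx, order, D3);

  // ...or reorder first, then project (the else branch).
  std::vector<float> xb = PermuteRows(x, order, M);
  std::vector<float> b(T * D3);
  NaiveFC(xb, wx, b, T, M, D3);

  for (int i = 0; i < T * D3; ++i) assert(a[i] == b[i]);
  return 0;
}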
@@ -226,7 +255,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       int bend = static_cast<int>(batch_starts[n + 1]);
       int cur_batch_size = bend - bstart;
 
-      Tensor gate_t = batch_gate->Slice(bstart, bend);
+      Tensor gate_t = batched_gate->Slice(bstart, bend);
       Tensor reset_hidden_prev_t =
           batch_reset_hidden_prev->Slice(bstart, bend);
       Tensor hidden_t = batch_hidden->Slice(bstart, bend);
@@ -269,7 +298,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       int bend = static_cast<int>(batch_starts[n + 1]);
       int cur_batch_size = bend - bstart;
 
-      Tensor gate_t = batch_gate->Slice(bstart, bend);
+      Tensor gate_t = batched_gate->Slice(bstart, bend);
       Tensor reset_hidden_prev_t =
           batch_reset_hidden_prev->Slice(bstart, bend);
       Tensor hidden_t = batch_hidden->Slice(bstart, bend);
@@ -287,8 +316,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     }
 #endif
     math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
-    batch_hidden->set_lod(batch_gate->lod());
-    to_seq(dev_ctx, *batch_hidden, hidden);
+    batch_hidden->set_lod(batched_gate->lod());
+    to_seq(dev_ctx, *batch_hidden, hidden_out);
   }
 };
@@ -300,4 +329,4 @@ REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OP_CPU_KERNEL(
     fusion_gru, ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::GRUKernel<paddle::platform::CPUDeviceContext, double>);
+    ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, double>);