s920243400 / PaddleDetection
Forked from PaddlePaddle / PaddleDetection (in sync with the upstream project)
Commit 6f78fd7d
Authored Aug 16, 2018 by tensor-tang

fuse fc in gru

Parent 300180cc
Changes 1
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
130 addition
and
101 deletion
+130
-101
paddle/fluid/operators/fusion_gru_op.cc  (+130, -101)
@@ -15,8 +15,11 @@ limitations under the License. */
 #include "paddle/fluid/operators/fusion_gru_op.h"
 #include <string>
 #include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
+#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
+#include "paddle/fluid/operators/math/detail/gru_kernel.h"
+#include "paddle/fluid/operators/math/fc_compute.h"
 #include "paddle/fluid/operators/math/gru_compute.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
@@ -25,47 +28,69 @@ namespace paddle {
 namespace operators {
 
 void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("Input"),
-                 "Input(%s) of GRUOp should not be null.", "Input");
-  PADDLE_ENFORCE(ctx->HasInput("Weight"),
-                 "Input(%s) of GRUOp should not be null.", "Weight");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
-                 "Output(%s) of GRUOp should not be null.", "BatchGate");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
-                 "Output(%s) of GRUOp should not be null.",
-                 "BatchResetHiddenPrev");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
-                 "Output(%s) of GRUOp should not be null.", "BatchHidden");
-  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                 "Output(%s) of GRUOp should not be null.", "Hidden");
-  auto input_dims = ctx->GetInputDim("Input");
-  auto weight_dims = ctx->GetInputDim("Weight");
-  int input_size = input_dims[1];
-  int frame_size = weight_dims[0];
-  PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
-                    "The input_size must be 3 times of frame_size in GRUOp.");
-  PADDLE_ENFORCE_EQ(
-      weight_dims[1], frame_size * 3,
-      "The shape of Weight matrix must be [frame_size, frame_size * 3].");
+  PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("WeightX"),
+                 "Input(WeightX) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("WeightH"),
+                 "Input(WeightH) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
+                 "Output(BatchedGate) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
+                 "Output(BatchResetHiddenPrev) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
+                 "Output(BatchedHidden) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
+                 "Output(Hidden) of GRU should not be null.");
+
+  auto x_dims = ctx->GetInputDim("X");
+  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
+
+  auto wx_dims = ctx->GetInputDim("WeightX");
+  PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
+                    "The rank of Input(WeightX) should be 2.");
+  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
+                    "The first dimension of Input(WeightX) "
+                    "should be %d.",
+                    x_dims[1]);
+
+  int frame_size = wx_dims[1] / 3;
+  auto wh_dims = ctx->GetInputDim("WeightH");
+  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
+                    "The rank of Input(WeightH) should be 2.");
+  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
+                    "The first dimension of Input(WeightH) "
+                    "should be %d.",
+                    frame_size);
+  PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size,
+                    "The second dimension of Input(WeightH) "
+                    "should be 3 * %d.",
+                    frame_size);
+
   if (ctx->HasInput("H0")) {
     auto h0_dims = ctx->GetInputDim("H0");
     PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
                       "The width of H0 must be equal to frame_size.");
   }
   if (ctx->HasInput("Bias")) {
-    auto bias_dims = ctx->GetInputDim("Bias");
-    int bias_height = bias_dims[0];
-    int bias_width = bias_dims[1];
-    PADDLE_ENFORCE_EQ(bias_height, 1,
-                      "The shape of Bias must be [1, frame_size * 3].");
-    PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
+    auto b_dims = ctx->GetInputDim("Bias");
+    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
+    PADDLE_ENFORCE_EQ(b_dims[0], 1,
+                      "The first dimension of Input(Bias) should be 1.");
+    PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3,
                       "The shape of Bias must be [1, frame_size * 3].");
   }
-  ctx->SetOutputDim("BatchGate", input_dims);
-  ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
-  ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
-  ctx->SetOutputDim("Hidden", {input_dims[0], frame_size});
-  ctx->ShareLoD("Input", "Hidden");
+  framework::DDim out_dims({x_dims[0], frame_size});
+  ctx->SetOutputDim("Hidden", out_dims);
+  ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
+  ctx->SetOutputDim("BatchedHidden", out_dims);
+  ctx->SetOutputDim("BatchResetHiddenPrev", out_dims);
+  ctx->ShareLoD("X", "Hidden");
+
+  int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
+  ctx->ShareLoD("X", "XX");
 }
 
 framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
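As a quick sanity check of the shape bookkeeping the new InferShape enforces, here is a minimal standalone sketch (hypothetical sizes M, D, and T, not taken from the commit): WeightX is M x 3D, frame_size is recovered as wx_dims[1] / 3, and XX is sized T x min(M, 3D) so it can hold either the projected input or the batched raw input.

#include <algorithm>
#include <cstdio>

// Hypothetical example sizes (not from the commit): M = dim size of x,
// D = hidden size, T = total time steps in the mini-batch.
int main() {
  int T = 100, M = 16, D = 32;
  int wx_cols = 3 * D;                  // WeightX is M x 3D
  int frame_size = wx_cols / 3;         // recovered exactly as InferShape does
  int xx_width = std::min(M, wx_cols);  // XX is T x min(M, 3D)
  std::printf("frame_size = %d, XX shape = %d x %d\n", frame_size, T, xx_width);
  return 0;
}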
@@ -76,53 +101,38 @@ framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
 }
 
 void FusionGRUOpMaker::Make() {
-  AddInput("Input",
-           "(LoDTensor) The first input is a LodTensor, which supports "
+  AddInput("X",
+           "(LoDTensor) the input is a LodTensor, which support "
            "variable-time length input sequence. The underlying tensor in "
-           "this LoDTenosr is a matrix with shape (T X 3D), where, T is the "
-           "total time steps in this mini-batch, D is the hidden size.");
+           "this LoDTensor is a matrix with shape (T X M), where T is the "
+           "total time steps in this mini-batch, M is the dim size of x.");
   AddInput("H0",
            "(Tensor, optional) The initial hidden state is an optional "
            "input. This is a tensor with shape (N x D), where N is the "
            "batch size, D is the hidden size.")
       .AsDispensable();
-  AddInput("Weight",
-           "(Tensor) The learnable hidden-hidden weight matrix with shape "
-           "(D x 3D), where D is the hidden size. The elements continuous in "
-           "memory can be divided into two parts. The first part are weights of "
-           "the update gate and reset gate with shape (D x 2D), and the second "
-           "part are weights of output candidate with shape (D x D).");
+  AddInput("WeightX",
+           "(Tensor) The FC weight with shape (M x 3D),"
+           "where M is the dim size of x, D is the hidden size. ");
+  AddInput("WeightH",
+           "(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. ");
   AddInput("Bias",
-           "(Tensor, optional) Bias vector with shape (1 x 3D) concating "
-           "bias of the update gate, reset gate and output candidate.")
+           "(Tensor, optional) (1 x 3D)."
+           "Almost same as GRUOp."
+           "Note: if have FC bias it should be added on this bias.")
       .AsDispensable();
-  AddOutput("BatchGate",
-            "(LoDTensor) To compute with batches, sequence data will be "
-            "reorganized into several successive batches each containing "
-            "data from the same time step. The LoDTensor BatchGate contains "
-            "the update gate, reset gate and output candidate values "
-            "organized in batches. The LoD size is 2. The first LoD contains "
-            "the batch offsets and the second LoD contains the indexes in "
-            "the raw sequence data.")
+  AddOutput("XX",
+            "(LoDTensor) the result after X * WeightX (size is T x 4D)"
+            " or batched_X (size is T x M), this will be automatically chosen,"
+            " where T is the total time steps in this mini-batch,"
+            " D is the hidden size, M is the dim size of x input.")
       .AsIntermediate();
+  AddOutput("BatchedGate", "(LoDTensor) Same as GRUOp").AsIntermediate();
   AddOutput("BatchResetHiddenPrev",
-            "(LoDTensor) The reseted hidden state LoDTensor organized in batches. "
-            "This LoDTensor is a matrix with shape (T X D) and has the same LoD "
-            "with `BatchGate`.")
+            "(LoDTensor) (T x 3D) Same as GRUOp.")
       .AsIntermediate();
-  AddOutput("BatchHidden",
-            "(LoDTensor) The hidden state LoDTensor organized in batches. "
-            "This LoDTensor is a matrix with shape (T X D) and has the same LoD "
-            "with `BatchGate`.")
+  AddOutput("BatchedHidden", "(LoDTensor) (T X D) Same as GRUOp.")
       .AsIntermediate();
-  AddOutput("Hidden",
-            "(LoDTensor) the hidden state LoDTensor organized in sequences. "
-            "This LoDTensor is a matrix with shape (T X D) and has the same LoD "
-            "with `BatchGate`.");
+  AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
   AddAttr<std::string>("activation",
                        "(string, default tanh) "
                        "The activation type used for output candidate {h}_t.")
@@ -156,52 +166,71 @@ inline void ReorderInitState(const DeviceContext& ctx,
 template <typename DeviceContext, typename T>
 class FusionGRUKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* x = context.Input<LoDTensor>("X");
-    auto* h = context.Input<LoDTensor>("H");
-    auto* h0 = context.Input<Tensor>("H0");
-    auto* x_weight = context.Input<Tensor>("XWeight");     // x_dim*3D
-    auto* gate_weight = context.Input<Tensor>("HWeight");  // D*3D
-    auto* bias = context.Input<Tensor>("Bias");            // 1*3D
-    auto hidden_dims = hidden->dims();
-    bool is_reverse = context.Attr<bool>("is_reverse");
-    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
-    if (bias) {
-      math::RowwiseAdd<DeviceContext, T> add_bias;
-      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
-    }
-    int frame_size = hidden_dims[1];
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* wx = ctx.Input<Tensor>("WeightX");
+    auto* wh = ctx.Input<Tensor>("WeightH");
+    auto* bias = ctx.Input<Tensor>("Bias");
+    auto* h0 = ctx.Input<Tensor>("H0");
+
+    auto* xx = ctx.Output<LoDTensor>("XX");
+    auto* batched_gate = ctx.Output<LoDTensor>("BatchedGate");
+    auto* batch_reset_hidden_prev = ctx.Output<LoDTensor>("BatchResetHiddenPrev");
+    auto* batch_hidden = ctx.Output<LoDTensor>("BatchedHidden");
+    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
+    bool is_reverse = ctx.Attr<bool>("is_reverse");
+
+    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
+    T* batched_gate_data = batched_gate->mutable_data<T>(ctx.GetPlace());
+    batch_reset_hidden_prev->mutable_data<T>(ctx.GetPlace());
+    batch_hidden->mutable_data<T>(ctx.GetPlace());
+    hidden_out->mutable_data<T>(ctx.GetPlace());
+
+    const T* x_data = x->data<T>();
+    const T* wx_data = wx->data<T>();
+    const T* wh_data = wh->data<T>();
+    auto x_dims = x->dims();
+    auto wx_dims = wx->dims();
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    if (x_dims[1] > wx_dims[1]) {
+      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
+                                        x_data, wx_data, xx_data,
+                                        bias ? bias->data<T>() : NULL);
+      to_batch(dev_ctx, *xx, batched_gate, true, is_reverse);
+    } else {
+      to_batch(dev_ctx, *x, xx, true, is_reverse);
+      batched_gate->set_lod(xx->lod());
+      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
+                                        xx_data, wx_data, batched_gate_data,
+                                        bias ? bias->data<T>() : NULL);
+    }
+
+    int frame_size = static_cast<int>(wx_dims[1] / 3);
     math::GRUMetaValue<T> gru_value;
-    gru_value.gate_weight = const_cast<T*>(weight_data);
+    gru_value.gate_weight = const_cast<T*>(wh_data);
     gru_value.state_weight =
-        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
+        const_cast<T*>(wh_data + 2 * frame_size * frame_size);
     Tensor ordered_h0;
-    framework::Vector<size_t> order(batch_gate->lod()[2]);
+
+    framework::Vector<size_t> order(batched_gate->lod()[2]);
     if (h0) {
       // Since the batch computing for GRU reorders the input sequences
       // according to their length. The initialized cell state also needs
       // to reorder.
       ReorderInitState<DeviceContext, T>(
-          context.template device_context<DeviceContext>(), *h0, order,
+          ctx.template device_context<DeviceContext>(), *h0, order,
           &ordered_h0, true);
       gru_value.prev_out_value = ordered_h0.data<T>();
     } else {
       gru_value.prev_out_value = nullptr;
     }
-    auto batch_starts = batch_gate->lod()[0];
+    auto batch_starts = batched_gate->lod()[0];
     size_t seq_len = batch_starts.size() - 1;
     auto active_node = math::detail::GetActivationType(
-        context.Attr<std::string>("activation"));
+        ctx.Attr<std::string>("activation"));
     auto active_gate = math::detail::GetActivationType(
-        context.Attr<std::string>("gate_activation"));
+        ctx.Attr<std::string>("gate_activation"));
 #ifdef PADDLE_WITH_MKLML
     // use MKL packed to speedup GEMM
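The new Compute picks between "project then batch" and "batch then project" based on x_dims[1] > wx_dims[1], presumably so the sequence-to-batch reorder always copies the smaller of the two buffers. A hedged back-of-the-envelope sketch with hypothetical sizes (not from the commit):

#include <cstdio>

// Hypothetical sizes, only to illustrate the data volumes behind the branch:
// projecting first reorders T x 3D values, batching first reorders T x M.
int main() {
  long T = 1000, M = 256, D = 64;
  long project_then_batch = T * 3 * D;  // reorder xx = X * WeightX
  long batch_then_project = T * M;      // reorder the raw input X
  std::printf("reorder %ld values on the M > 3D path, %ld values otherwise\n",
              project_then_batch, batch_then_project);
  return 0;
}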
@@ -226,7 +255,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       int bend = static_cast<int>(batch_starts[n + 1]);
       int cur_batch_size = bend - bstart;
 
-      Tensor gate_t = batch_gate->Slice(bstart, bend);
+      Tensor gate_t = batched_gate->Slice(bstart, bend);
       Tensor reset_hidden_prev_t =
           batch_reset_hidden_prev->Slice(bstart, bend);
       Tensor hidden_t = batch_hidden->Slice(bstart, bend);
@@ -269,7 +298,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       int bend = static_cast<int>(batch_starts[n + 1]);
       int cur_batch_size = bend - bstart;
 
-      Tensor gate_t = batch_gate->Slice(bstart, bend);
+      Tensor gate_t = batched_gate->Slice(bstart, bend);
       Tensor reset_hidden_prev_t =
           batch_reset_hidden_prev->Slice(bstart, bend);
       Tensor hidden_t = batch_hidden->Slice(bstart, bend);
@@ -287,8 +316,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     }
 #endif
 
     math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
-    batch_hidden->set_lod(batch_gate->lod());
-    to_seq(dev_ctx, *batch_hidden, hidden);
+    batch_hidden->set_lod(batched_gate->lod());
+    to_seq(dev_ctx, *batch_hidden, hidden_out);
   }
 };
@@ -300,4 +329,4 @@ REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OP_CPU_KERNEL(
     fusion_gru, ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::GRUKernel<paddle::platform::CPUDeviceContext, double>);
+    ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, double>);
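Putting the pieces together, the commit replaces the per-step input projection with one large GEMM over the whole sequence (the "FC" that gets fused into the GRU), so the per-step loop only needs the hidden-to-hidden products. Below is a hedged, self-contained reference sketch of that idea in plain C++. It is not the Paddle kernel: it ignores LoD batching and sequence reversal, uses the (D x 2D | D x D) WeightH layout implied by gate_weight/state_weight, and picks one common GRU update convention that may differ in detail from the math::detail kernels.

#include <algorithm>
#include <cmath>
#include <vector>

// c (m x n) += a (m x k) * b (k x n), all row-major.
static void gemm(const float* a, const float* b, float* c, int m, int n, int k) {
  for (int i = 0; i < m; ++i)
    for (int p = 0; p < k; ++p)
      for (int j = 0; j < n; ++j) c[i * n + j] += a[i * k + p] * b[p * n + j];
}

// Reference fused GRU over one sequence (hypothetical helper, not Paddle API).
// x: T x M, wx: M x 3D, bias: 1 x 3D,
// wh: a (D x 2D) update/reset block followed by a (D x D) candidate block.
// Gate layout per step: [update (D) | reset (D) | candidate (D)].
std::vector<float> FusedGRURef(const std::vector<float>& x,
                               const std::vector<float>& wx,
                               const std::vector<float>& wh,
                               const std::vector<float>& bias,
                               int T, int M, int D) {
  // Fused FC: one GEMM for the whole sequence instead of T small ones.
  std::vector<float> xx(T * 3 * D);
  for (int t = 0; t < T; ++t)
    std::copy(bias.begin(), bias.end(), xx.begin() + t * 3 * D);
  gemm(x.data(), wx.data(), xx.data(), T, 3 * D, M);

  auto sigmoid = [](float v) { return 1.f / (1.f + std::exp(-v)); };
  std::vector<float> h(D, 0.f), out(T * D), gates(3 * D);
  for (int t = 0; t < T; ++t) {
    std::copy(xx.begin() + t * 3 * D, xx.begin() + (t + 1) * 3 * D, gates.begin());
    // Hidden-to-hidden part of update/reset gates: h * gate_weight (D x 2D).
    for (int d = 0; d < D; ++d)
      for (int j = 0; j < 2 * D; ++j) gates[j] += h[d] * wh[d * 2 * D + j];
    for (int j = 0; j < 2 * D; ++j) gates[j] = sigmoid(gates[j]);
    // Candidate: (reset * h) * state_weight (the trailing D x D block).
    for (int d = 0; d < D; ++d) {
      float rh = gates[D + d] * h[d];
      for (int j = 0; j < D; ++j) gates[2 * D + j] += rh * wh[2 * D * D + d * D + j];
    }
    for (int j = 0; j < D; ++j) {
      float c = std::tanh(gates[2 * D + j]);
      float u = gates[j];
      h[j] = u * h[j] + (1.f - u) * c;  // one common convention; the kernel may differ
      out[t * D + j] = h[j];
    }
  }
  return out;
}

The win of the fused form is that a single (T x M) by (M x 3D) GEMM keeps the BLAS kernel busy, instead of issuing T small matrix products inside the recurrence.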