Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
b87eabae
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b87eabae
编写于
10月 18, 2017
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add GRU Operator
上级
7d653c41
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
2008 addition
and
9 deletion
+2008
-9
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+3
-1
paddle/operators/gru_op.cc
paddle/operators/gru_op.cc
+213
-0
paddle/operators/gru_op.cu
paddle/operators/gru_op.cu
+23
-0
paddle/operators/gru_op.h
paddle/operators/gru_op.h
+258
-0
paddle/operators/math/CMakeLists.txt
paddle/operators/math/CMakeLists.txt
+2
-0
paddle/operators/math/detail/gru_cpu_kernel.h
paddle/operators/math/detail/gru_cpu_kernel.h
+428
-0
paddle/operators/math/detail/gru_gpu_kernel.h
paddle/operators/math/detail/gru_gpu_kernel.h
+207
-0
paddle/operators/math/detail/gru_kernel.h
paddle/operators/math/detail/gru_kernel.h
+191
-0
paddle/operators/math/gru_compute.cc
paddle/operators/math/gru_compute.cc
+102
-0
paddle/operators/math/gru_compute.cu
paddle/operators/math/gru_compute.cu
+178
-0
paddle/operators/math/gru_compute.h
paddle/operators/math/gru_compute.h
+82
-0
paddle/operators/math/sequence2batch.h
paddle/operators/math/sequence2batch.h
+138
-8
python/paddle/v2/framework/tests/test_gru_op.py
python/paddle/v2/framework/tests/test_gru_op.py
+183
-0
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
b87eabae
...
...
@@ -116,7 +116,8 @@ set(DEPS_OPS
sum_op
pool_op
pool_with_index_op
lstm_op
)
lstm_op
gru_op
)
op_library
(
recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
...
...
@@ -128,6 +129,7 @@ op_library(sum_op DEPS net_op)
op_library
(
pool_op DEPS pooling
)
op_library
(
pool_with_index_op DEPS pooling
)
op_library
(
lstm_op DEPS sequence2batch lstm_compute
)
op_library
(
gru_op DEPS sequence2batch gru_compute
)
list
(
REMOVE_ITEM GENERAL_OPS
${
DEPS_OPS
}
)
foreach
(
src
${
GENERAL_OPS
}
)
...
...
paddle/operators/gru_op.cc
0 → 100644
浏览文件 @
b87eabae
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/gru_op.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
Tensor
;
class
GRUOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(%s) of GRUOp should not be null."
,
"Input"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Weight"
),
"Input(%s) of GRUOp should not be null."
,
"Weight"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchGate"
),
"Output(%s) of GRUOp should not be null."
,
"BatchGate"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchResetHiddenPrev"
),
"Output(%s) of GRUOp should not be null."
,
"BatchResetHiddenPrev"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchHidden"
),
"Output(%s) of GRUOp should not be null."
,
"BatchHidden"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Hidden"
),
"Output(%s) of GRUOp should not be null."
,
"Hidden"
);
auto
input_dims
=
ctx
->
GetInputDim
(
"Input"
);
auto
weight_dims
=
ctx
->
GetInputDim
(
"Weight"
);
int
input_size
=
input_dims
[
1
];
int
frame_size
=
weight_dims
[
0
];
PADDLE_ENFORCE_EQ
(
input_size
,
frame_size
*
3
,
"The input_size must be 3 times of frame_size in GRUOp."
);
PADDLE_ENFORCE_EQ
(
weight_dims
[
1
],
frame_size
*
3
,
"The shape of Weight matrix must be [frame_size, frame_size * 3]."
);
auto
h0
=
Input
(
"H0"
);
if
(
h0
!=
framework
::
kEmptyVarName
)
{
auto
h0_dims
=
ctx
->
GetInputDim
(
"H0"
);
PADDLE_ENFORCE_EQ
(
h0_dims
[
1
],
frame_size
,
"The width of H0 must be equal to frame_size."
);
}
auto
bias
=
Input
(
"Bias"
);
if
(
bias
!=
framework
::
kEmptyVarName
)
{
auto
bias_dims
=
ctx
->
GetInputDim
(
"Bias"
);
int
bias_height
=
bias_dims
[
0
];
int
bias_width
=
bias_dims
[
1
];
PADDLE_ENFORCE_EQ
(
bias_height
,
1
,
"The shape of Bias must be [1, frame_size * 3]."
);
PADDLE_ENFORCE_EQ
(
bias_width
,
frame_size
*
3
,
"The shape of Bias must be [1, frame_size * 3]."
);
}
ctx
->
SetOutputDim
(
"BatchGate"
,
input_dims
);
ctx
->
SetOutputDim
(
"BatchResetHiddenPrev"
,
{
input_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"BatchHidden"
,
{
input_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Hidden"
,
{
input_dims
[
0
],
frame_size
});
// ctx->ShareLoD("Input", "Gate");
// ctx->ShareLoD("Input", "ResetHiddenPrev");
ctx
->
ShareLoD
(
"Input"
,
"Hidden"
);
}
};
class
GRUOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
GRUOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"Input"
,
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTenosr is a matrix with shape (T X 3D), where, T is the "
"total time steps in this mini-batch, D is the hidden size."
);
AddInput
(
"H0"
,
"(Tensor, optional) the initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size, D is the hidden size."
);
AddInput
(
"Weight"
,
"(Tensor) Weight matrix with shape [hidden_size, hidden_size * 3]. "
"The elements continuous in memory can be divided into two parts. "
"The first part are weights of the update gate and reset gate "
"with shape [hidden_size, hidden_size * 2], and the second part are "
"weights of output candidate with shape [hidden_size, hidden_size]"
);
AddInput
(
"Bias"
,
"(Tensor) Bias vector with shape [1, hidden_size * 3] concating "
"bias of the update gate, reset gate and output candidate."
);
AddOutput
(
"BatchGate"
,
"(LoDTensor) the update gata, reset gate and output candidate "
"lod tensor of GRU operator. "
"The shape and lod is the same with the `Input`."
)
.
AsIntermediate
();
AddOutput
(
"BatchResetHiddenPrev"
,
"(LoDTensor) the reseted hidden state lod tensor of GRU operator. "
"The shape and lod is the same with the `Input`."
)
.
AsIntermediate
();
AddOutput
(
"BatchHidden"
,
"(LoDTensor) the reseted hidden state lod tensor of GRU operator. "
"The shape and lod is the same with the `Input`."
)
.
AsIntermediate
();
AddOutput
(
"Hidden"
,
"(LoDTensor) the hidden state lod tensor of GRU operator. "
"The shape and lod is the same with the `Input`."
);
AddAttr
<
std
::
string
>
(
"activation"
,
"(string, default tanh) "
"The activation type used for output candidate {h}_t."
)
.
SetDefault
(
"tanh"
);
AddAttr
<
std
::
string
>
(
"gate_activation"
,
"(string, default sigmoid) "
"The activation type used in update gate and reset gate."
)
.
SetDefault
(
"sigmoid"
);
AddAttr
<
bool
>
(
"is_reverse"
,
"(bool, defalut: False) "
"whether to compute reversed GRU."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
GRUOp implements part calculations of the GRU unit as following:
\f[
update \ gate: u_t = actGate(xu_t + W_u * hidden_prev + bias_u) \\
reset \ gate: r_t = actGate(xr_t + W_r * hidden_prev + bias_r) \\
output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, hidden_prev) + bias_c) \\
output: h_t = dot((1-u_t), hidden_prev) + dot(u_t, {h}_t)
\f]
The rest of GRU unit can be completed by using FCOp's output as the input of GRUOp.
)DOC"
);
}
};
class
GRUGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(%s) of GRUGradOp should not be null."
,
"Input"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Weight"
),
"Input(%s) of GRUGradOp should not be null."
,
"Weight"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BatchGate"
),
"Input(%s) of GRUGradOp should not be null."
,
"BatchGate"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BatchResetHiddenPrev"
),
"Input(%s) of GRUGradOp should not be null."
,
"BatchResetHiddenPrev"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BatchHidden"
),
"Input(%s) of GRUOp should not be null."
,
"BatchHidden"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Hidden"
),
"Input(%s) of GRUGradOp should not be null."
,
"Hidden"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Hidden"
)),
"Input(%s@GRAD) of GRUGradOp should not be null."
,
"Hidden"
);
auto
input_dims
=
ctx
->
GetInputDim
(
"Input"
);
auto
weight_dims
=
ctx
->
GetInputDim
(
"Weight"
);
int
input_size
=
input_dims
[
1
];
int
frame_size
=
weight_dims
[
0
];
int
weight_height
=
weight_dims
[
0
];
int
weight_width
=
weight_dims
[
1
];
PADDLE_ENFORCE_EQ
(
input_size
,
frame_size
*
3
,
"The input_size must be 3 times of frame_size in GRUOp."
);
PADDLE_ENFORCE_EQ
(
weight_height
,
frame_size
,
"The shape of Weight matrix must be [frame_size, frame_size * 3]."
);
PADDLE_ENFORCE_EQ
(
weight_width
,
frame_size
*
3
,
"The shape of Weight matrix must be [frame_size, frame_size * 3]."
);
auto
h0
=
Input
(
"H0"
);
if
(
h0
!=
framework
::
kEmptyVarName
)
{
auto
h0_dims
=
ctx
->
GetInputDim
(
"H0"
);
PADDLE_ENFORCE_EQ
(
h0_dims
[
1
],
frame_size
,
"The width of H0 must be equal to frame_size."
);
auto
h0_grad_name
=
framework
::
GradVarName
(
"H0"
);
if
(
ctx
->
HasOutput
(
h0_grad_name
))
ctx
->
SetOutputDim
(
h0_grad_name
,
h0_dims
);
}
auto
bias
=
Input
(
"Bias"
);
if
(
bias
!=
framework
::
kEmptyVarName
)
{
auto
bias_dims
=
ctx
->
GetInputDim
(
"Bias"
);
int
bias_height
=
bias_dims
[
0
];
int
bias_width
=
bias_dims
[
1
];
PADDLE_ENFORCE_EQ
(
bias_height
,
1
,
"The shape of Bias must be [1, frame_size * 3]."
);
PADDLE_ENFORCE_EQ
(
bias_width
,
frame_size
*
3
,
"The shape of Bias must be [1, frame_size * 3]."
);
auto
bias_grad_name
=
framework
::
GradVarName
(
"Bias"
);
if
(
ctx
->
HasOutput
(
bias_grad_name
))
ctx
->
SetOutputDim
(
bias_grad_name
,
bias_dims
);
}
auto
input_grad_name
=
framework
::
GradVarName
(
"Input"
);
if
(
ctx
->
HasOutput
(
input_grad_name
))
ctx
->
SetOutputDim
(
input_grad_name
,
input_dims
);
auto
weight_grad_name
=
framework
::
GradVarName
(
"Weight"
);
if
(
ctx
->
HasOutput
(
weight_grad_name
))
ctx
->
SetOutputDim
(
weight_grad_name
,
weight_dims
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
gru
,
ops
::
GRUOp
,
ops
::
GRUOpMaker
,
gru_grad
,
ops
::
GRUGradOp
);
REGISTER_OP_CPU_KERNEL
(
gru
,
ops
::
GRUKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
GRUKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
gru_grad
,
ops
::
GRUGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
GRUGradKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
paddle/operators/gru_op.cu
0 → 100644
浏览文件 @
b87eabae
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/gru_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
gru
,
ops
::
GRUKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
,
ops
::
GRUKernel
<
paddle
::
platform
::
GPUPlace
,
double
>
);
REGISTER_OP_GPU_KERNEL
(
gru_grad
,
ops
::
GRUGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
,
ops
::
GRUGradKernel
<
paddle
::
platform
::
GPUPlace
,
double
>
);
paddle/operators/gru_op.h
0 → 100644
浏览文件 @
b87eabae
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
Place
,
typename
T
>
class
GRUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
BatchCompute
(
const
framework
::
ExecutionContext
&
context
)
const
{
auto
*
input
=
context
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
h0
=
context
.
Input
<
Tensor
>
(
"H0"
);
const
T
*
h0_data
=
h0
?
h0
->
data
<
T
>
()
:
nullptr
;
auto
*
weight
=
context
.
Input
<
Tensor
>
(
"Weight"
);
const
T
*
weight_data
=
weight
->
data
<
T
>
();
auto
*
bias
=
context
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
batch_gate
=
context
.
Output
<
LoDTensor
>
(
"BatchGate"
);
batch_gate
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
batch_reset_hidden_prev
=
context
.
Output
<
LoDTensor
>
(
"BatchResetHiddenPrev"
);
batch_reset_hidden_prev
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
batch_hidden
=
context
.
Output
<
LoDTensor
>
(
"BatchHidden"
);
batch_hidden
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
hidden
=
context
.
Output
<
LoDTensor
>
(
"Hidden"
);
hidden
->
mutable_data
<
T
>
(
context
.
GetPlace
());
// context.ShareLoD("Input", "Gate");
// context.ShareLoD("Input", "ResetHiddenPrev");
context
.
ShareLoD
(
"Input"
,
"Hidden"
);
// auto gate_dims = gate->dims();
auto
hidden_dims
=
hidden
->
dims
();
// LoDTensor batch_gate, batch_reset_hidden_prev, batch_hidden;
// batch_gate.mutable_data<T>(gate_dims, context.GetPlace());
// batch_reset_hidden_prev.mutable_data<T>(hidden_dims, context.GetPlace());
// batch_hidden.mutable_data<T>(hidden_dims, context.GetPlace());
bool
is_reverse
=
context
.
Attr
<
bool
>
(
"is_reverse"
);
math
::
LoDTensor2BatchFunctor
<
Place
,
T
>
to_batch
;
// to_batch(context.device_context(), *input, batch_gate, is_reverse);
to_batch
(
context
.
device_context
(),
*
input
,
*
batch_gate
,
is_reverse
);
int
frame_size
=
hidden_dims
[
1
];
int
batch_size
=
hidden_dims
[
0
];
// auto g = EigenMatrix<T>::From(batch_gate);
auto
g
=
EigenMatrix
<
T
>::
From
(
*
batch_gate
);
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
if
(
bias
)
{
auto
b
=
EigenMatrix
<
T
>::
From
(
*
bias
);
g
.
device
(
place
)
=
g
+
b
.
reshape
(
Eigen
::
array
<
int
,
2
>
({{
1
,
frame_size
*
3
}}))
.
broadcast
(
Eigen
::
array
<
int
,
2
>
({{
batch_size
,
1
}}));
}
math
::
hl_gru_value
<
T
>
gru_value
;
gru_value
.
gateWeight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
stateWeight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
gru_value
.
prevOutValue
=
const_cast
<
T
*>
(
h0_data
);
// auto batch_starts = batch_gate.lod()[0];
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
// for (auto i = batch_gate->lod()[1].begin(); i !=
// batch_gate->lod()[1].end(); ++i)
// std::cout << static_cast<int>(*i) << ' ';
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
int
cur_batch_size
=
bend
-
bstart
;
// Tensor gate_t = batch_gate.Slice(bstart, bend);
// Tensor reset_hidden_prev_t = batch_reset_hidden_prev.Slice(bstart,
// bend);
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
gru_value
.
outputValue
=
hidden_t
.
data
<
T
>
();
gru_value
.
gateValue
=
gate_t
.
data
<
T
>
();
gru_value
.
resetOutputValue
=
reset_hidden_prev_t
.
data
<
T
>
();
math
::
GRUUnitFunctor
<
Place
,
T
>::
compute
(
context
.
device_context
(),
gru_value
,
frame_size
,
cur_batch_size
,
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"activation"
)),
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"gate_activation"
)));
gru_value
.
prevOutValue
=
gru_value
.
outputValue
;
}
math
::
Batch2LoDTensorFunctor
<
Place
,
T
>
to_seq
;
// batch_gate.set_lod(batch_gate.lod());
// to_seq(context.device_context(), batch_gate, *gate);
// batch_reset_hidden_prev.set_lod(batch_gate.lod());
// to_seq(context.device_context(), batch_reset_hidden_prev,
// *reset_hidden_prev);
// batch_hidden.set_lod(batch_gate.lod());
// to_seq(context.device_context(), batch_hidden, *hidden);
batch_hidden
->
set_lod
(
batch_gate
->
lod
());
to_seq
(
context
.
device_context
(),
*
batch_hidden
,
*
hidden
);
}
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
BatchCompute
(
context
);
}
};
template
<
typename
Place
,
typename
T
>
class
GRUGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
BatchCompute
(
const
framework
::
ExecutionContext
&
context
)
const
{
auto
*
h0
=
context
.
Input
<
Tensor
>
(
"H0"
);
const
T
*
h0_data
=
h0
?
h0
->
data
<
T
>
()
:
nullptr
;
auto
*
weight
=
context
.
Input
<
Tensor
>
(
"Weight"
);
const
T
*
weight_data
=
weight
->
data
<
T
>
();
auto
*
batch_gate
=
context
.
Input
<
LoDTensor
>
(
"BatchGate"
);
auto
*
batch_reset_hidden_prev
=
context
.
Input
<
LoDTensor
>
(
"BatchResetHiddenPrev"
);
auto
*
batch_hidden
=
context
.
Input
<
LoDTensor
>
(
"BatchHidden"
);
auto
*
hidden
=
context
.
Input
<
LoDTensor
>
(
"Hidden"
);
auto
*
hidden_grad
=
context
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Hidden"
));
auto
*
input_grad
=
context
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"Input"
));
auto
*
h0_grad
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"H0"
));
auto
*
weight_grad
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Weight"
));
auto
*
bias_grad
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Bias"
));
auto
gate_dims
=
batch_gate
->
dims
();
auto
hidden_dims
=
hidden
->
dims
();
int
frame_size
=
hidden_dims
[
1
];
math
::
LoDTensor2BatchFunctor
<
Place
,
T
>
to_batch
;
LoDTensor
batch_hidden_grad
,
batch_gate_grad
,
batch_reset_hidden_prev_grad
;
batch_hidden_grad
.
mutable_data
<
T
>
(
hidden_dims
,
context
.
GetPlace
());
batch_gate_grad
.
mutable_data
<
T
>
(
gate_dims
,
context
.
GetPlace
());
batch_reset_hidden_prev_grad
.
mutable_data
<
T
>
(
hidden_dims
,
context
.
GetPlace
());
math
::
SetConstant
<
Place
,
T
>
zero
;
zero
(
context
.
device_context
(),
&
batch_hidden_grad
,
static_cast
<
T
>
(
0.0
));
zero
(
context
.
device_context
(),
&
batch_gate_grad
,
static_cast
<
T
>
(
0.0
));
zero
(
context
.
device_context
(),
&
batch_reset_hidden_prev_grad
,
static_cast
<
T
>
(
0.0
));
// batch_hidden.set_lod(batch_gate->lod());
bool
is_reverse
=
context
.
Attr
<
bool
>
(
"is_reverse"
);
batch_hidden_grad
.
set_lod
(
batch_hidden
->
lod
());
// context.ShareLoD(framework::GradVarName("Hidden"),
// framework::GradVarName("Input"));
to_batch
(
context
.
device_context
(),
*
hidden_grad
,
batch_hidden_grad
,
is_reverse
,
false
);
math
::
hl_gru_value
<
T
>
gru_value
;
gru_value
.
gateWeight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
stateWeight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
math
::
hl_gru_grad
<
T
>
gru_grad
;
if
(
weight_grad
)
{
gru_grad
.
gateWeightGrad
=
weight_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
zero
(
context
.
device_context
(),
weight_grad
,
static_cast
<
T
>
(
0.0
));
gru_grad
.
stateWeightGrad
=
weight_grad
->
data
<
T
>
()
+
2
*
frame_size
*
frame_size
;
}
else
{
gru_grad
.
gateWeightGrad
=
nullptr
;
gru_grad
.
stateWeightGrad
=
nullptr
;
}
auto
batch_starts
=
batch_hidden_grad
.
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
for
(
int
n
=
static_cast
<
int
>
(
num_batch
)
-
1
;
n
>=
0
;
n
--
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
int
cur_batch_size
=
bend
-
bstart
;
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
gru_value
.
gateValue
=
gate_t
.
data
<
T
>
();
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
gru_value
.
resetOutputValue
=
reset_hidden_prev_t
.
data
<
T
>
();
Tensor
hidden_grad_t
=
batch_hidden_grad
.
Slice
(
bstart
,
bend
);
gru_grad
.
outputGrad
=
hidden_grad_t
.
data
<
T
>
();
Tensor
gate_grad_t
=
batch_gate_grad
.
Slice
(
bstart
,
bend
);
gru_grad
.
gateGrad
=
gate_grad_t
.
data
<
T
>
();
Tensor
reset_hidden_prev_grad_t
=
batch_reset_hidden_prev_grad
.
Slice
(
bstart
,
bend
);
gru_grad
.
resetOutputGrad
=
reset_hidden_prev_grad_t
.
data
<
T
>
();
if
(
n
==
0
)
{
gru_value
.
prevOutValue
=
const_cast
<
T
*>
(
h0_data
);
if
(
h0_grad
)
{
T
*
h0_grad_data
=
h0_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
zero
(
context
.
device_context
(),
h0_grad
,
static_cast
<
T
>
(
0.0
));
gru_grad
.
prevOutGrad
=
h0_grad_data
;
}
else
{
gru_grad
.
prevOutGrad
=
nullptr
;
}
}
else
{
int
bstart_pre
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
Tensor
hidden_prev_t
=
batch_hidden
->
Slice
(
bstart_pre
,
bstart
);
gru_value
.
prevOutValue
=
hidden_prev_t
.
data
<
T
>
();
Tensor
hidden_prev_grad_t
=
batch_hidden_grad
.
Slice
(
bstart_pre
,
bstart
);
gru_grad
.
prevOutGrad
=
hidden_prev_grad_t
.
data
<
T
>
();
}
math
::
GRUUnitGradFunctor
<
Place
,
T
>::
compute
(
context
.
device_context
(),
gru_value
,
gru_grad
,
frame_size
,
cur_batch_size
,
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"activation"
)),
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"gate_activation"
)));
}
if
(
input_grad
)
{
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
math
::
Batch2LoDTensorFunctor
<
Place
,
T
>
to_seq
;
batch_gate_grad
.
set_lod
(
batch_gate
->
lod
());
to_seq
(
context
.
device_context
(),
batch_gate_grad
,
*
input_grad
);
}
if
(
bias_grad
)
{
bias_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
d_b
=
EigenMatrix
<
T
>::
From
(
*
bias_grad
);
auto
d_g
=
EigenMatrix
<
T
>::
From
(
batch_gate_grad
);
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
d_b
.
device
(
place
)
=
d_g
.
sum
(
Eigen
::
array
<
int
,
1
>
({{
0
}}));
}
}
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
BatchCompute
(
context
);
}
};
}
// namespace operators
}
// namespace paddle
paddle/operators/math/CMakeLists.txt
浏览文件 @
b87eabae
...
...
@@ -11,6 +11,7 @@ if(WITH_GPU)
nv_library
(
vol2col SRCS vol2col.cc vol2col.cu DEPS device_context
)
nv_library
(
sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context
)
nv_library
(
lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions
)
nv_library
(
gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions
)
else
()
cc_library
(
math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator
)
cc_library
(
selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function
)
...
...
@@ -20,6 +21,7 @@ else()
cc_library
(
vol2col SRCS vol2col.cc DEPS device_context
)
cc_library
(
sequence2batch SRCS sequence2batch.cc DEPS device_context
)
cc_library
(
lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions
)
cc_library
(
gru_compute SRCS gru_compute.cc DEPS device_context activation_functions
)
endif
()
cc_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
...
...
paddle/operators/math/detail/gru_cpu_kernel.h
0 → 100644
浏览文件 @
b87eabae
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/operators/math/gru_compute.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
detail
{
#ifndef __NVCC__
template
<
class
OpResetOutput
,
typename
T
>
void
hl_naive_gru_forward_reset_output
(
OpResetOutput
opResetOutput
,
T
*
gateValue
,
T
*
resetOutputValue
,
T
*
prevOutputValue
,
int
frameSize
,
activation_mode_t
active_gate
)
{
T
rValueUpdateGate
;
T
rValueResetGate
;
T
rValueResetOutput
;
T
rPrevOut
=
0
;
T
*
updateGate
=
gateValue
;
T
*
resetGate
=
gateValue
+
frameSize
;
for
(
int
i
=
0
;
i
<
frameSize
;
i
++
)
{
rValueUpdateGate
=
updateGate
[
i
];
rValueResetGate
=
resetGate
[
i
];
if
(
prevOutputValue
)
{
rPrevOut
=
prevOutputValue
[
i
];
}
hppl
::
cpu
::
ForwardAct
<
T
>
act
;
opResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevOut
,
rValueResetOutput
,
act
(
active_gate
));
updateGate
[
i
]
=
rValueUpdateGate
;
resetGate
[
i
]
=
rValueResetGate
;
resetOutputValue
[
i
]
=
rValueResetOutput
;
}
}
template
<
class
OpFinalOutput
,
typename
T
>
void
hl_naive_gru_forward_final_output
(
OpFinalOutput
opFinalOutput
,
T
*
gateValue
,
T
*
prevOutputValue
,
T
*
outputValue
,
int
frameSize
,
activation_mode_t
active_node
)
{
T
rValueUpdateGate
;
T
rValueFrameState
;
T
rPrevOut
=
0
;
T
rOutput
;
T
*
updateGate
=
gateValue
;
T
*
frameState
=
gateValue
+
frameSize
*
2
;
for
(
int
i
=
0
;
i
<
frameSize
;
i
++
)
{
rValueUpdateGate
=
updateGate
[
i
];
rValueFrameState
=
frameState
[
i
];
if
(
prevOutputValue
)
{
rPrevOut
=
prevOutputValue
[
i
];
}
hppl
::
cpu
::
ForwardAct
<
T
>
act
;
opFinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutput
,
act
(
active_node
));
frameState
[
i
]
=
rValueFrameState
;
outputValue
[
i
]
=
rOutput
;
}
}
template
<
class
OpResetOutput
,
typename
T
>
void
hl_avx_gru_forward_reset_output
(
OpResetOutput
opResetOutput
,
T
*
gateValue
,
T
*
resetOutputValue
,
T
*
prevOutputValue
,
int
frameSize
,
activation_mode_t
active_gate
)
{
#ifdef __AVX__
__m256
rValueUpdateGate
;
__m256
rValueResetGate
;
__m256
rValueResetOutput
;
__m256
rPrevOut
=
_mm256_set1_ps
(
0.0
f
);
__m256
*
updateGate
=
(
__m256
*
)
gateValue
;
__m256
*
resetGate
=
(
__m256
*
)(
gateValue
+
frameSize
);
for
(
int
i
=
0
;
i
<
frameSize
/
8
;
i
++
)
{
rValueUpdateGate
=
updateGate
[
i
];
rValueResetGate
=
resetGate
[
i
];
if
(
prevOutputValue
)
{
rPrevOut
=
((
__m256
*
)
prevOutputValue
)[
i
];
}
opResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevOut
,
rValueResetOutput
,
hppl
::
avx
::
forward
[
active_gate
]);
updateGate
[
i
]
=
rValueUpdateGate
;
resetGate
[
i
]
=
rValueResetGate
;
((
__m256
*
)
resetOutputValue
)[
i
]
=
rValueResetOutput
;
}
#endif
}
template
<
class
OpFinalOutput
,
typename
T
>
void
hl_avx_gru_forward_final_output
(
OpFinalOutput
opFinalOutput
,
T
*
gateValue
,
T
*
prevOutputValue
,
T
*
outputValue
,
int
frameSize
,
activation_mode_t
active_node
)
{
#ifdef __AVX__
__m256
rValueUpdateGate
;
__m256
rValueFrameState
;
__m256
rPrevOut
=
_mm256_set1_ps
(
0.0
f
);
__m256
rOutput
;
__m256
*
updateGate
=
(
__m256
*
)
gateValue
;
__m256
*
frameState
=
(
__m256
*
)(
gateValue
+
frameSize
*
2
);
for
(
int
i
=
0
;
i
<
frameSize
/
8
;
i
++
)
{
rValueUpdateGate
=
updateGate
[
i
];
rValueFrameState
=
frameState
[
i
];
if
(
prevOutputValue
)
{
rPrevOut
=
((
__m256
*
)
prevOutputValue
)[
i
];
}
opFinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutput
,
hppl
::
avx
::
forward
[
active_node
]);
frameState
[
i
]
=
rValueFrameState
;
((
__m256
*
)
outputValue
)[
i
]
=
rOutput
;
}
#endif
}
template
<
class
OpResetOutput
,
typename
T
>
inline
void
forward_reset_output
(
OpResetOutput
opResetOutput
,
hl_gru_value
<
T
>
value
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_gate
)
{
for
(
int
b
=
0
;
b
<
batchSize
;
b
++
)
{
if
(
OpResetOutput
::
avx
&&
!
(
frameSize
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
hl_avx_gru_forward_reset_output
(
opResetOutput
,
value
.
gateValue
,
value
.
resetOutputValue
,
value
.
prevOutValue
,
frameSize
,
active_gate
);
}
else
{
hl_naive_gru_forward_reset_output
(
opResetOutput
,
value
.
gateValue
,
value
.
resetOutputValue
,
value
.
prevOutValue
,
frameSize
,
active_gate
);
}
value
.
gateValue
+=
frameSize
*
3
;
value
.
resetOutputValue
+=
frameSize
;
if
(
value
.
prevOutValue
)
{
value
.
prevOutValue
+=
frameSize
;
}
}
}
template
<
class
OpFinalOutput
,
typename
T
>
inline
void
forward_final_output
(
OpFinalOutput
opFinalOutput
,
hl_gru_value
<
T
>
value
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
)
{
for
(
int
b
=
0
;
b
<
batchSize
;
b
++
)
{
if
(
OpFinalOutput
::
avx
&&
!
(
frameSize
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
hl_avx_gru_forward_final_output
(
opFinalOutput
,
value
.
gateValue
,
value
.
prevOutValue
,
value
.
outputValue
,
frameSize
,
active_node
);
}
else
{
hl_naive_gru_forward_final_output
(
opFinalOutput
,
value
.
gateValue
,
value
.
prevOutValue
,
value
.
outputValue
,
frameSize
,
active_node
);
}
value
.
gateValue
+=
frameSize
*
3
;
value
.
outputValue
+=
frameSize
;
if
(
value
.
prevOutValue
)
{
value
.
prevOutValue
+=
frameSize
;
}
}
}
template
<
class
OpStateGrad
,
typename
T
>
void
hl_naive_gru_backward_state_grad
(
OpStateGrad
opStateGrad
,
T
*
gateValue
,
T
*
gateGrad
,
T
*
prevOutValue
,
T
*
prevOutGrad
,
T
*
outputGrad
,
int
frameSize
,
activation_mode_t
active_node
)
{
T
rUpdateGateValue
;
T
rUpdateGateGrad
;
T
rFrameStateValue
;
T
rFrameStateGrad
;
T
rOutGrad
;
T
rPrevOutValue
=
0
;
T
rPrevOutGrad
=
0
;
T
*
updateGateValue
=
gateValue
;
T
*
updateGateGrad
=
gateGrad
;
T
*
frameStateValue
=
gateValue
+
frameSize
*
2
;
T
*
frameStateGrad
=
gateGrad
+
frameSize
*
2
;
for
(
int
i
=
0
;
i
<
frameSize
;
i
++
)
{
rUpdateGateValue
=
updateGateValue
[
i
];
rFrameStateValue
=
frameStateValue
[
i
];
rOutGrad
=
outputGrad
[
i
];
if
(
prevOutValue
)
{
rPrevOutValue
=
prevOutValue
[
i
];
}
if
(
prevOutGrad
)
{
rPrevOutGrad
=
prevOutGrad
[
i
];
}
hppl
::
cpu
::
BackwardAct
<
T
>
act
;
opStateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateValue
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutGrad
,
act
(
active_node
));
updateGateGrad
[
i
]
=
rUpdateGateGrad
;
frameStateGrad
[
i
]
=
rFrameStateGrad
;
if
(
prevOutGrad
)
{
prevOutGrad
[
i
]
=
rPrevOutGrad
;
}
}
}
template
<
class
OpResetGrad
,
typename
T
>
void
hl_naive_gru_backward_reset_grad
(
OpResetGrad
opResetGrad
,
T
*
gateValue
,
T
*
gateGrad
,
T
*
prevOutValue
,
T
*
prevOutGrad
,
T
*
resetOutputGrad
,
int
frameSize
,
activation_mode_t
active_gate
)
{
T
rUpdateGateValue
;
T
rUpdateGateGrad
;
T
rResetGateValue
;
T
rResetGateGrad
;
T
rResetOutputGrad
=
0
;
T
rPrevOutValue
=
0
;
T
rPrevOutGrad
=
0
;
T
*
updateGateValue
=
gateValue
;
T
*
updateGateGrad
=
gateGrad
;
T
*
resetGateValue
=
gateValue
+
frameSize
;
T
*
resetGateGrad
=
gateGrad
+
frameSize
;
for
(
int
i
=
0
;
i
<
frameSize
;
i
++
)
{
rUpdateGateValue
=
updateGateValue
[
i
];
rUpdateGateGrad
=
updateGateGrad
[
i
];
rResetGateValue
=
resetGateValue
[
i
];
if
(
prevOutValue
&&
prevOutGrad
)
{
rResetOutputGrad
=
resetOutputGrad
[
i
];
}
if
(
prevOutValue
)
{
rPrevOutValue
=
prevOutValue
[
i
];
}
if
(
prevOutGrad
)
{
rPrevOutGrad
=
prevOutGrad
[
i
];
}
hppl
::
cpu
::
BackwardAct
<
T
>
act
;
opResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateValue
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputGrad
,
act
(
active_gate
));
updateGateGrad
[
i
]
=
rUpdateGateGrad
;
resetGateGrad
[
i
]
=
rResetGateGrad
;
if
(
prevOutGrad
)
{
prevOutGrad
[
i
]
=
rPrevOutGrad
;
}
}
}
template
<
class
OpStateGrad
,
typename
T
>
void
hl_avx_gru_backward_state_grad
(
OpStateGrad
opStateGrad
,
T
*
gateValue
,
T
*
gateGrad
,
T
*
prevOutValue
,
T
*
prevOutGrad
,
T
*
outputGrad
,
int
frameSize
,
activation_mode_t
active_node
)
{
#ifdef __AVX__
__m256
rUpdateGateValue
;
__m256
rUpdateGateGrad
;
__m256
rFrameStateValue
;
__m256
rFrameStateGrad
;
__m256
rOutGrad
;
__m256
rPrevOutValue
=
_mm256_set1_ps
(
0.0
f
);
__m256
rPrevOutGrad
=
_mm256_set1_ps
(
0.0
f
);
__m256
*
updateGateValue
=
(
__m256
*
)
gateValue
;
__m256
*
updateGateGrad
=
(
__m256
*
)
gateGrad
;
__m256
*
frameStateValue
=
(
__m256
*
)(
gateValue
+
frameSize
*
2
);
__m256
*
frameStateGrad
=
(
__m256
*
)(
gateGrad
+
frameSize
*
2
);
for
(
int
i
=
0
;
i
<
frameSize
/
8
;
i
++
)
{
rUpdateGateValue
=
updateGateValue
[
i
];
rFrameStateValue
=
frameStateValue
[
i
];
rOutGrad
=
((
__m256
*
)
outputGrad
)[
i
];
if
(
prevOutValue
)
{
rPrevOutValue
=
((
__m256
*
)
prevOutValue
)[
i
];
}
if
(
prevOutGrad
)
{
rPrevOutGrad
=
((
__m256
*
)
prevOutGrad
)[
i
];
}
opStateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateValue
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutGrad
,
hppl
::
avx
::
backward
[
active_node
]);
updateGateGrad
[
i
]
=
rUpdateGateGrad
;
frameStateGrad
[
i
]
=
rFrameStateGrad
;
if
(
prevOutGrad
)
{
((
__m256
*
)
prevOutGrad
)[
i
]
=
rPrevOutGrad
;
}
}
#endif
}
template
<
class
OpResetGrad
,
typename
T
>
void
hl_avx_gru_backward_reset_grad
(
OpResetGrad
opResetGrad
,
T
*
gateValue
,
T
*
gateGrad
,
T
*
prevOutValue
,
T
*
prevOutGrad
,
T
*
resetOutputGrad
,
int
frameSize
,
activation_mode_t
active_gate
)
{
#ifdef __AVX__
__m256
rUpdateGateValue
;
__m256
rUpdateGateGrad
;
__m256
rResetGateValue
;
__m256
rResetGateGrad
;
__m256
rResetOutputGrad
=
_mm256_set1_ps
(
0.0
f
);
__m256
rPrevOutValue
=
_mm256_set1_ps
(
0.0
f
);
__m256
rPrevOutGrad
=
_mm256_set1_ps
(
0.0
f
);
__m256
*
updateGateValue
=
(
__m256
*
)
gateValue
;
__m256
*
updateGateGrad
=
(
__m256
*
)
gateGrad
;
__m256
*
resetGateValue
=
(
__m256
*
)(
gateValue
+
frameSize
);
__m256
*
resetGateGrad
=
(
__m256
*
)(
gateGrad
+
frameSize
);
for
(
int
i
=
0
;
i
<
frameSize
/
8
;
i
++
)
{
rUpdateGateValue
=
updateGateValue
[
i
];
rUpdateGateGrad
=
updateGateGrad
[
i
];
rResetGateValue
=
resetGateValue
[
i
];
if
(
prevOutValue
&&
prevOutGrad
)
{
rResetOutputGrad
=
((
__m256
*
)
resetOutputGrad
)[
i
];
}
if
(
prevOutValue
)
{
rPrevOutValue
=
((
__m256
*
)
prevOutValue
)[
i
];
}
if
(
prevOutGrad
)
{
rPrevOutGrad
=
((
__m256
*
)
prevOutGrad
)[
i
];
}
opResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateValue
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputGrad
,
hppl
::
avx
::
backward
[
active_gate
]);
updateGateGrad
[
i
]
=
rUpdateGateGrad
;
resetGateGrad
[
i
]
=
rResetGateGrad
;
if
(
prevOutGrad
)
{
((
__m256
*
)
prevOutGrad
)[
i
]
=
rPrevOutGrad
;
}
}
#endif
}
template
<
class
OpStateGrad
,
typename
T
>
inline
void
backward_state_grad
(
OpStateGrad
opStateGrad
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
)
{
for
(
int
b
=
0
;
b
<
batchSize
;
b
++
)
{
if
(
OpStateGrad
::
avx
&&
!
(
frameSize
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
hl_avx_gru_backward_state_grad
(
opStateGrad
,
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
outputGrad
,
frameSize
,
active_node
);
}
else
{
hl_naive_gru_backward_state_grad
(
opStateGrad
,
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
outputGrad
,
frameSize
,
active_node
);
}
value
.
gateValue
+=
frameSize
*
3
;
if
(
value
.
prevOutValue
)
{
value
.
prevOutValue
+=
frameSize
;
}
grad
.
gateGrad
+=
frameSize
*
3
;
grad
.
outputGrad
+=
frameSize
;
if
(
grad
.
prevOutGrad
)
{
grad
.
prevOutGrad
+=
frameSize
;
}
}
}
template
<
class
OpResetGrad
,
typename
T
>
inline
void
backward_reset_grad
(
OpResetGrad
opResetGrad
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_gate
)
{
for
(
int
b
=
0
;
b
<
batchSize
;
b
++
)
{
if
(
OpResetGrad
::
avx
&&
!
(
frameSize
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
hl_avx_gru_backward_reset_grad
(
opResetGrad
,
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
resetOutputGrad
,
frameSize
,
active_gate
);
}
else
{
hl_naive_gru_backward_reset_grad
(
opResetGrad
,
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
resetOutputGrad
,
frameSize
,
active_gate
);
}
value
.
gateValue
+=
frameSize
*
3
;
if
(
value
.
prevOutValue
)
{
value
.
prevOutValue
+=
frameSize
;
}
grad
.
gateGrad
+=
frameSize
*
3
;
grad
.
resetOutputGrad
+=
frameSize
;
if
(
grad
.
prevOutGrad
)
{
grad
.
prevOutGrad
+=
frameSize
;
}
}
}
#endif
}
// namespace detail
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/detail/gru_gpu_kernel.h
0 → 100644
浏览文件 @
b87eabae
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/operators/math/gru_compute.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/device_context.h"
#include <glog/logging.h>
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
detail
{
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template
<
class
OpResetOutput
,
bool
isBatch
,
typename
T
>
__global__
void
KeGruForwardResetOutput
(
OpResetOutput
opResetOutput
,
T
*
gateValue
,
T
*
resetOutputValue
,
T
*
prevOutputValue
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_gate
)
{
const
int
frameIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frameIdx
>=
frameSize
)
return
;
int
batchIdx
=
0
;
if
(
isBatch
)
{
batchIdx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
batchIdx
>=
batchSize
)
return
;
gateValue
+=
batchIdx
*
3
*
frameSize
;
resetOutputValue
+=
batchIdx
*
frameSize
;
}
T
rPrevOut
=
0
;
T
rValueResetOutput
;
T
rValueUpdateGate
=
gateValue
[
frameIdx
+
frameSize
*
0
];
T
rValueResetGate
=
gateValue
[
frameIdx
+
frameSize
*
1
];
if
(
prevOutputValue
)
{
if
(
isBatch
)
prevOutputValue
+=
batchIdx
*
frameSize
;
rPrevOut
=
prevOutputValue
[
frameIdx
];
}
hppl
::
gpu
::
ForwardAct
<
T
>
act
;
opResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevOut
,
rValueResetOutput
,
act
(
active_gate
));
gateValue
[
frameIdx
+
frameSize
*
0
]
=
rValueUpdateGate
;
gateValue
[
frameIdx
+
frameSize
*
1
]
=
rValueResetGate
;
resetOutputValue
[
frameIdx
]
=
rValueResetOutput
;
}
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template
<
class
OpFinalOutput
,
bool
isBatch
,
typename
T
>
__global__
void
KeGruForwardFinalOutput
(
OpFinalOutput
opFinalOutput
,
T
*
gateValue
,
T
*
prevOutputValue
,
T
*
outputValue
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
)
{
const
int
frameIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frameIdx
>=
frameSize
)
return
;
int
batchIdx
=
0
;
if
(
isBatch
)
{
batchIdx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
batchIdx
>=
batchSize
)
return
;
gateValue
+=
batchIdx
*
3
*
frameSize
;
outputValue
+=
batchIdx
*
frameSize
;
}
T
rOutput
;
T
rPrevOut
=
0
;
T
rValueUpdateGate
=
gateValue
[
frameIdx
+
frameSize
*
0
];
T
rValueFrameState
=
gateValue
[
frameIdx
+
frameSize
*
2
];
if
(
prevOutputValue
)
{
if
(
isBatch
)
prevOutputValue
+=
batchIdx
*
frameSize
;
rPrevOut
=
prevOutputValue
[
frameIdx
];
}
hppl
::
gpu
::
ForwardAct
<
T
>
act
;
opFinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutput
,
act
(
active_node
));
gateValue
[
frameIdx
+
frameSize
*
2
]
=
rValueFrameState
;
outputValue
[
frameIdx
]
=
rOutput
;
}
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template
<
class
OpStateGrad
,
bool
isBatch
,
typename
T
>
__global__
void
KeGruBackwardStateGrad
(
OpStateGrad
opStateGrad
,
T
*
gateValue
,
T
*
gateGrad
,
T
*
prevOutValue
,
T
*
prevOutGrad
,
T
*
outputGrad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
)
{
const
int
frameIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frameIdx
>=
frameSize
)
return
;
int
batchIdx
=
0
;
if
(
isBatch
)
{
batchIdx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
batchIdx
>=
batchSize
)
return
;
gateValue
+=
batchIdx
*
3
*
frameSize
;
gateGrad
+=
batchIdx
*
3
*
frameSize
;
outputGrad
+=
batchIdx
*
frameSize
;
}
T
rUpdateGateGrad
;
T
rFrameStateGrad
;
T
rPrevOutValue
=
0
;
T
rPrevOutGrad
=
0
;
T
rUpdateGateValue
=
gateValue
[
frameIdx
+
frameSize
*
0
];
T
rFrameStateValue
=
gateValue
[
frameIdx
+
frameSize
*
2
];
T
rOutGrad
=
outputGrad
[
frameIdx
];
if
(
prevOutValue
&&
prevOutGrad
)
{
if
(
isBatch
)
prevOutValue
+=
batchIdx
*
frameSize
;
rPrevOutValue
=
prevOutValue
[
frameIdx
];
if
(
isBatch
)
prevOutGrad
+=
batchIdx
*
frameSize
;
rPrevOutGrad
=
prevOutGrad
[
frameIdx
];
}
hppl
::
gpu
::
BackwardAct
<
T
>
act
;
opStateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateValue
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutGrad
,
act
(
active_node
));
gateGrad
[
frameIdx
+
frameSize
*
0
]
=
rUpdateGateGrad
;
gateGrad
[
frameIdx
+
frameSize
*
2
]
=
rFrameStateGrad
;
if
(
prevOutGrad
)
{
prevOutGrad
[
frameIdx
]
=
rPrevOutGrad
;
}
}
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template
<
class
OpResetGrad
,
bool
isBatch
,
typename
T
>
__global__
void
KeGruBackwardResetGrad
(
OpResetGrad
opResetGrad
,
T
*
gateValue
,
T
*
gateGrad
,
T
*
prevOutValue
,
T
*
prevOutGrad
,
T
*
resetOutputGrad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_gate
)
{
const
int
frameIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frameIdx
>=
frameSize
)
return
;
int
batchIdx
=
0
;
if
(
isBatch
)
{
batchIdx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
batchIdx
>=
batchSize
)
return
;
gateValue
+=
batchIdx
*
3
*
frameSize
;
gateGrad
+=
batchIdx
*
3
*
frameSize
;
resetOutputGrad
+=
batchIdx
*
frameSize
;
}
T
rResetGateGrad
;
T
rPrevOutValue
=
0
;
T
rPrevOutGrad
=
0
;
T
rResetOutputGrad
=
0
;
T
rUpdateGateValue
=
gateValue
[
frameIdx
+
frameSize
*
0
];
T
rUpdateGateGrad
=
gateGrad
[
frameIdx
+
frameSize
*
0
];
T
rResetGateValue
=
gateValue
[
frameIdx
+
frameSize
*
1
];
if
(
prevOutValue
&&
prevOutGrad
)
{
if
(
isBatch
)
prevOutValue
+=
batchIdx
*
frameSize
;
if
(
isBatch
)
prevOutGrad
+=
batchIdx
*
frameSize
;
rPrevOutValue
=
prevOutValue
[
frameIdx
];
rPrevOutGrad
=
prevOutGrad
[
frameIdx
];
rResetOutputGrad
=
resetOutputGrad
[
frameIdx
];
}
hppl
::
gpu
::
BackwardAct
<
T
>
act
;
opResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateValue
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputGrad
,
act
(
active_gate
));
gateGrad
[
frameIdx
+
frameSize
*
0
]
=
rUpdateGateGrad
;
gateGrad
[
frameIdx
+
frameSize
*
1
]
=
rResetGateGrad
;
if
(
prevOutGrad
)
{
prevOutGrad
[
frameIdx
]
=
rPrevOutGrad
;
}
}
}
// namespace detail
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/detail/gru_kernel.h
0 → 100644
浏览文件 @
b87eabae
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/platform/hostdevice.h"
#include <type_traits>
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
detail
{
namespace
forward
{
template
<
typename
T
>
class
gru_resetOutput
{
public:
/**
* @param[in,out] valueUpdateGate update gate
* @param[in,out] valueResetGate reset gate
* @param[in] prevOut previous output
* @param[out] valueResetOutput intermediate value for frame state
* @param[in] actGate forward function of gate
*/
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
valueResetGate
,
T
&
prevOut
,
T
&
valueResetOutput
,
typename
hppl
::
Active
<
T
>::
forward
actGate
)
{
valueUpdateGate
=
actGate
(
valueUpdateGate
);
valueResetGate
=
actGate
(
valueResetGate
);
valueResetOutput
=
prevOut
*
valueResetGate
;
}
#ifndef __NVCC__
#ifndef __AVX__
static
const
bool
avx
=
false
;
#else
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
valueResetGate
,
__m256
&
prevOut
,
__m256
&
valueResetOutput
,
typename
hppl
::
Active
<
__m256
>::
forward
actGate
)
{
valueUpdateGate
=
actGate
(
valueUpdateGate
);
valueResetGate
=
actGate
(
valueResetGate
);
valueResetOutput
=
_mm256_mul_ps
(
prevOut
,
valueResetGate
);
}
#endif
#endif
};
template
<
typename
T
>
class
gru_finalOutput
{
public:
/**
* @param[in] valueUpdateGate update gate
* @param[in,out] valueFrameState frame state ({\tilde{h}_t})
* @param[in] prevOut previous output
* @param[out] valueOutput output
* @param[in] actInput forward function of node
*/
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
valueFrameState
,
T
&
prevOut
,
T
&
valueOutput
,
typename
hppl
::
Active
<
T
>::
forward
actInput
)
{
valueFrameState
=
actInput
(
valueFrameState
);
valueOutput
=
prevOut
-
(
valueUpdateGate
*
prevOut
)
+
(
valueUpdateGate
*
valueFrameState
);
}
#ifndef __NVCC__
#ifndef __AVX__
static
const
bool
avx
=
false
;
#else
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
valueFrameState
,
__m256
&
prevOut
,
__m256
&
valueOutput
,
typename
hppl
::
Active
<
__m256
>::
forward
actInput
)
{
valueFrameState
=
actInput
(
valueFrameState
);
valueOutput
=
_mm256_add_ps
(
_mm256_sub_ps
(
prevOut
,
_mm256_mul_ps
(
valueUpdateGate
,
prevOut
)),
_mm256_mul_ps
(
valueUpdateGate
,
valueFrameState
));
}
#endif
#endif
};
}
// namespace forward
namespace
backward
{
template
<
typename
T
>
class
gru_stateGrad
{
public:
/**
* @param[in] valueUpdateGate update gate value
* @param[out] gradUpdateGate update gate grad
* @param[in] valueFrameState frame state value
* @param[out] gradFrameState frame state grad
* @param[in] valuePrevOut previous output value
* @param[in,out] gradPrevOut previous output grad
* @param[in] gradOutput output grad
* @param[in] actInput backward function of frame state
*/
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
gradUpdateGate
,
T
&
valueFrameState
,
T
&
gradFrameState
,
T
&
valuePrevOut
,
T
&
gradPrevOut
,
T
&
gradOutput
,
typename
hppl
::
Active
<
T
>::
backward
actInput
)
{
gradUpdateGate
=
(
gradOutput
*
valueFrameState
);
gradUpdateGate
-=
(
gradOutput
*
valuePrevOut
);
gradPrevOut
-=
(
gradOutput
*
valueUpdateGate
);
gradPrevOut
+=
gradOutput
;
gradFrameState
=
actInput
(
gradOutput
*
valueUpdateGate
,
valueFrameState
);
}
#ifndef __NVCC__
#ifndef __AVX__
static
const
bool
avx
=
false
;
#else
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
gradUpdateGate
,
__m256
&
valueFrameState
,
__m256
&
gradFrameState
,
__m256
&
valuePrevOut
,
__m256
&
gradPrevOut
,
__m256
&
gradOutput
,
typename
hppl
::
Active
<
__m256
>::
backward
actInput
)
{
gradUpdateGate
=
_mm256_mul_ps
(
gradOutput
,
valueFrameState
);
gradUpdateGate
=
_mm256_sub_ps
(
gradUpdateGate
,
_mm256_mul_ps
(
gradOutput
,
valuePrevOut
));
gradPrevOut
=
_mm256_add_ps
(
_mm256_sub_ps
(
gradPrevOut
,
_mm256_mul_ps
(
gradOutput
,
valueUpdateGate
)),
gradOutput
);
gradFrameState
=
actInput
(
_mm256_mul_ps
(
gradOutput
,
valueUpdateGate
),
valueFrameState
);
}
#endif
#endif
};
template
<
typename
T
>
class
gru_resetGrad
{
public:
/**
* @param[in] valueUpdateGate update gate value
* @param[in,out] gradUpdateGate update gate grad
* @param[in] valueResetGate reset gate value
* @param[out] gradResetGate reset gate grad
* @param[in] valuePrevOut previous output value
* @param[in,out] gradPrevOut previous output grad
* @param[in] gradResetOutput reset output grad (temp val)
* @param[in] actGate backward function of gate
*/
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
gradUpdateGate
,
T
&
valueResetGate
,
T
&
gradResetGate
,
T
&
valuePrevOut
,
T
&
gradPrevOut
,
T
&
gradResetOutput
,
typename
hppl
::
Active
<
T
>::
backward
actGate
)
{
gradResetGate
=
(
gradResetOutput
*
valuePrevOut
);
gradPrevOut
+=
(
gradResetOutput
*
valueResetGate
);
gradUpdateGate
=
actGate
(
gradUpdateGate
,
valueUpdateGate
);
gradResetGate
=
actGate
(
gradResetGate
,
valueResetGate
);
}
#ifndef __NVCC__
#ifndef __AVX__
static
const
bool
avx
=
false
;
#else
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
gradUpdateGate
,
__m256
&
valueResetGate
,
__m256
&
gradResetGate
,
__m256
&
valuePrevOut
,
__m256
&
gradPrevOut
,
__m256
&
gradResetOutput
,
typename
hppl
::
Active
<
__m256
>::
backward
actGate
)
{
gradResetGate
=
_mm256_mul_ps
(
gradResetOutput
,
valuePrevOut
);
gradPrevOut
=
_mm256_add_ps
(
gradPrevOut
,
_mm256_mul_ps
(
gradResetOutput
,
valueResetGate
));
gradUpdateGate
=
actGate
(
gradUpdateGate
,
valueUpdateGate
);
gradResetGate
=
actGate
(
gradResetGate
,
valueResetGate
);
}
#endif
#endif
};
}
// namespace backward
}
// namespace detail
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/gru_compute.cc
0 → 100644
浏览文件 @
b87eabae
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/operators/math/detail/gru_kernel.h"
#include "paddle/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
struct
GRUUnitFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
#ifndef __NVCC__
if
(
value
.
prevOutValue
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
false
,
batchSize
,
frameSize
*
2
,
frameSize
,
1
,
value
.
prevOutValue
,
frameSize
,
value
.
gateWeight
,
frameSize
*
2
,
1
,
value
.
gateValue
,
frameSize
*
3
);
}
detail
::
forward_reset_output
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
,
frameSize
,
batchSize
,
active_gate
);
if
(
value
.
prevOutValue
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
false
,
batchSize
,
frameSize
,
frameSize
,
1
,
value
.
resetOutputValue
,
frameSize
,
value
.
stateWeight
,
frameSize
,
1
,
value
.
gateValue
+
frameSize
*
2
,
frameSize
*
3
);
}
detail
::
forward_final_output
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
,
frameSize
,
batchSize
,
active_node
);
#endif
}
};
template
<
typename
T
>
struct
GRUUnitGradFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
#ifndef __NVCC__
detail
::
backward_state_grad
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
,
grad
,
frameSize
,
batchSize
,
active_node
);
if
(
value
.
prevOutValue
&&
grad
.
prevOutGrad
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
true
,
batchSize
,
frameSize
,
frameSize
,
1
,
grad
.
gateGrad
+
frameSize
*
2
,
frameSize
*
3
,
value
.
stateWeight
,
frameSize
,
0
,
grad
.
resetOutputGrad
,
frameSize
);
if
(
grad
.
stateWeightGrad
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
true
,
false
,
frameSize
,
frameSize
,
batchSize
,
1
,
value
.
resetOutputValue
,
frameSize
,
grad
.
gateGrad
+
frameSize
*
2
,
frameSize
*
3
,
1
,
grad
.
stateWeightGrad
,
frameSize
);
}
}
detail
::
backward_reset_grad
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
,
grad
,
frameSize
,
batchSize
,
active_gate
);
if
(
grad
.
prevOutGrad
&&
value
.
prevOutValue
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
true
,
batchSize
,
frameSize
,
frameSize
*
2
,
1
,
grad
.
gateGrad
,
frameSize
*
3
,
value
.
gateWeight
,
frameSize
*
2
,
1
,
grad
.
prevOutGrad
,
frameSize
);
if
(
grad
.
gateWeightGrad
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
true
,
false
,
frameSize
,
frameSize
*
2
,
batchSize
,
1
,
value
.
prevOutValue
,
frameSize
,
grad
.
gateGrad
,
frameSize
*
3
,
1
,
grad
.
gateWeightGrad
,
frameSize
*
2
);
}
}
#endif
}
};
template
struct
GRUUnitFunctor
<
platform
::
CPUPlace
,
float
>;
template
struct
GRUUnitFunctor
<
platform
::
CPUPlace
,
double
>;
template
struct
GRUUnitGradFunctor
<
platform
::
CPUPlace
,
float
>;
template
struct
GRUUnitGradFunctor
<
platform
::
CPUPlace
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/gru_compute.cu
0 → 100644
浏览文件 @
b87eabae
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/detail/gru_gpu_kernel.h"
#include "paddle/operators/math/detail/gru_kernel.h"
#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
struct
GRUUnitFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
dim3
threads
;
dim3
grid
;
if
(
batchSize
==
1
)
{
int
framePerBlock
=
frameSize
<=
1024
?
frameSize
:
1024
;
int
frameBlocks
=
(
frameSize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
framePerBlock
,
1
);
grid
=
dim3
(
frameBlocks
,
1
);
}
else
{
threads
=
dim3
(
32
,
32
);
grid
=
dim3
((
frameSize
+
32
-
1
)
/
32
,
(
batchSize
+
32
-
1
)
/
32
);
}
if
(
value
.
prevOutValue
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
false
,
batchSize
,
frameSize
*
2
,
frameSize
,
1
,
value
.
prevOutValue
,
frameSize
,
value
.
gateWeight
,
frameSize
*
2
,
1
,
value
.
gateValue
,
frameSize
*
3
);
}
if
(
batchSize
==
1
)
{
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
/* isBatch= */
false
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gateValue
,
value
.
resetOutputValue
,
value
.
prevOutValue
,
frameSize
,
batchSize
,
active_gate
);
}
else
{
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
/* isBatch= */
true
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gateValue
,
value
.
resetOutputValue
,
value
.
prevOutValue
,
frameSize
,
batchSize
,
active_gate
);
}
if
(
value
.
prevOutValue
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
false
,
batchSize
,
frameSize
,
frameSize
,
1
,
value
.
resetOutputValue
,
frameSize
,
value
.
stateWeight
,
frameSize
,
1
,
value
.
gateValue
+
frameSize
*
2
,
frameSize
*
3
);
}
if
(
batchSize
==
1
)
{
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
/* isBatch= */
false
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gateValue
,
value
.
prevOutValue
,
value
.
outputValue
,
frameSize
,
batchSize
,
active_node
);
}
else
{
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
/* isBatch= */
true
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gateValue
,
value
.
prevOutValue
,
value
.
outputValue
,
frameSize
,
batchSize
,
active_node
);
}
}
};
template
<
typename
T
>
struct
GRUUnitGradFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
dim3
threads
;
dim3
grid
;
if
(
batchSize
==
1
)
{
int
framePerBlock
=
frameSize
<=
1024
?
frameSize
:
1024
;
int
frameBlocks
=
(
frameSize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
framePerBlock
,
1
);
grid
=
dim3
(
frameBlocks
,
1
);
}
else
{
threads
=
dim3
(
32
,
32
);
grid
=
dim3
((
frameSize
+
32
-
1
)
/
32
,
(
batchSize
+
32
-
1
)
/
32
);
}
if
(
batchSize
==
1
)
{
detail
::
KeGruBackwardStateGrad
<
detail
::
backward
::
gru_stateGrad
<
T
>
,
/* isBatch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
outputGrad
,
frameSize
,
batchSize
,
active_node
);
}
else
{
detail
::
KeGruBackwardStateGrad
<
detail
::
backward
::
gru_stateGrad
<
T
>
,
/* isBatch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
outputGrad
,
frameSize
,
batchSize
,
active_node
);
}
if
(
value
.
prevOutValue
&&
grad
.
prevOutGrad
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
true
,
batchSize
,
frameSize
,
frameSize
,
1
,
grad
.
gateGrad
+
frameSize
*
2
,
frameSize
*
3
,
value
.
stateWeight
,
frameSize
,
0
,
grad
.
resetOutputGrad
,
frameSize
);
if
(
grad
.
stateWeightGrad
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
true
,
false
,
frameSize
,
frameSize
,
batchSize
,
1
,
value
.
resetOutputValue
,
frameSize
,
grad
.
gateGrad
+
frameSize
*
2
,
frameSize
*
3
,
1
,
grad
.
stateWeightGrad
,
frameSize
);
}
}
if
(
batchSize
==
1
)
{
detail
::
KeGruBackwardResetGrad
<
detail
::
backward
::
gru_resetGrad
<
T
>
,
/* isBatch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
resetOutputGrad
,
frameSize
,
batchSize
,
active_gate
);
}
else
{
detail
::
KeGruBackwardResetGrad
<
detail
::
backward
::
gru_resetGrad
<
T
>
,
/* isBatch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
resetOutputGrad
,
frameSize
,
batchSize
,
active_gate
);
}
if
(
grad
.
prevOutGrad
&&
value
.
prevOutValue
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
true
,
batchSize
,
frameSize
,
frameSize
*
2
,
1
,
grad
.
gateGrad
,
frameSize
*
3
,
value
.
gateWeight
,
frameSize
*
2
,
1
,
grad
.
prevOutGrad
,
frameSize
);
if
(
grad
.
gateWeightGrad
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
true
,
false
,
frameSize
,
frameSize
*
2
,
batchSize
,
1
,
value
.
prevOutValue
,
frameSize
,
grad
.
gateGrad
,
frameSize
*
3
,
1
,
grad
.
gateWeightGrad
,
frameSize
*
2
);
}
}
}
};
template
struct
GRUUnitFunctor
<
platform
::
GPUPlace
,
float
>;
template
struct
GRUUnitFunctor
<
platform
::
GPUPlace
,
double
>;
template
struct
GRUUnitGradFunctor
<
platform
::
GPUPlace
,
float
>;
template
struct
GRUUnitGradFunctor
<
platform
::
GPUPlace
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
\ No newline at end of file
paddle/operators/math/gru_compute.h
0 → 100644
浏览文件 @
b87eabae
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
// typedef enum {
// HL_ACTIVATION_SIGMOID = 0,
// HL_ACTIVATION_RELU = 1,
// HL_ACTIVATION_TANH = 2,
// HL_ACTIVATION_LINEAR = 3,
// HL_ACTIVATION_END
// } activation_mode_t;
// inline activation_mode_t ActiveType(const std::string &type) {
// if (type == "sigmoid") {
// return HL_ACTIVATION_SIGMOID;
// } else if (type == "relu") {
// return HL_ACTIVATION_RELU;
// } else if (type == "tanh") {
// return HL_ACTIVATION_TANH;
// } else if (type == "linear" || type == "") {
// return HL_ACTIVATION_LINEAR;
// } else {
// PADDLE_THROW("Do not support activation type.");
// }
// }
template
<
typename
T
>
struct
hl_gru_value
{
T
*
gateWeight
;
T
*
stateWeight
;
T
*
gateValue
;
T
*
resetOutputValue
;
T
*
outputValue
;
T
*
prevOutValue
;
};
template
<
typename
T
>
struct
hl_gru_grad
{
T
*
gateWeightGrad
;
T
*
stateWeightGrad
;
T
*
gateGrad
;
T
*
resetOutputGrad
;
T
*
outputGrad
;
T
*
prevOutGrad
;
};
template
<
typename
Place
,
typename
T
>
struct
GRUUnitFunctor
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
);
};
template
<
typename
Place
,
typename
T
>
struct
GRUUnitGradFunctor
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
);
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/sequence2batch.h
浏览文件 @
b87eabae
...
...
@@ -21,6 +21,128 @@ namespace paddle {
namespace
operators
{
namespace
math
{
// template <typename Place, typename T>
// class CopyMatrixRowsFunctor {
// public:
// // If is_src_index is true,
// // copy the indexed rows of input src to the output dst.
// // If is_src_index is false,
// // copy the input src to the indexed rows of output dst.
// // The indexed rows are based on the input index.
// void operator()(const platform::DeviceContext& context,
// const framework::LoDTensor& src, const size_t* index,
// framework::LoDTensor& dst, bool is_src_index);
// };
// template <typename Place, typename T>
// class LoDTensor2BatchFunctor {
// // Calculate the length of each sequence and
// // sort sequence index by the length.
// // example: sequences = {s0, s1, s2}
// // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// // seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
// //
// struct SeqInfo {
// SeqInfo(int start, int length, int seq_idx)
// : start(start), length(length), seq_idx(seq_idx) {}
// int start;
// int length;
// int seq_idx;
// };
// public:
// void operator()(const platform::DeviceContext& context,
// const framework::LoDTensor& lod_tensor,
// framework::LoDTensor& batch, bool is_reverse) const {
// auto lods = lod_tensor.lod();
// PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence
// now.");
// auto lod = lods[0];
// std::vector<SeqInfo> seq_info;
// for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
// int length = lod[seq_id + 1] - lod[seq_id];
// seq_info.emplace_back(lod[seq_id], length, seq_id);
// }
// std::sort(seq_info.begin(), seq_info.end(),
// [](SeqInfo a, SeqInfo b) { return a.length > b.length; });
// // calculate the start position of each batch
// // (numBatch equal the maxLength of sequences)
// // example: sequences = {s0, s1, s2}
// // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// // num_batch = 5,
// // batchIndex = {b0, b1, b2, b3, b4}
// // b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
// // batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
// // batch_start_positions[0] = len(b0)
// // batch_start_positions[1] = len(b0) + len(b1)
// // batch_start_positions[2] = len(b0) + len(b1) + len(b2)
// // ...
// // seq2batch_idx[12] = {4, 0, 9,
// // 5, 1, 10,
// // 6, 2, 11,
// // 7, 3,
// // 8}
// // The batch number represents batch size after rearranging the
// // input LodTensor. It is also the maximum length of input sequence.
// paddle::framework::LoD batch_lods;
// batch_lods.emplace_back(std::vector<size_t>{0});
// batch_lods.emplace_back(std::vector<size_t>{0});
// // batch_lods[0] is the start positions for batch LoDTensor
// int num_batch = seq_info[0].length;
// batch_lods[0].resize(static_cast<size_t>(num_batch + 1));
// // batch_lods[1] is the raw index in the input LoDTensor
// auto dims = lod_tensor.dims();
// batch_lods[1].resize(static_cast<size_t>(dims[0]));
// size_t* batch_starts = batch_lods[0].data();
// size_t* seq2batch_idx = batch_lods[1].data();
// batch_starts[0] = 0;
// for (size_t n = 0; n < num_batch; n++) {
// auto batch_id = static_cast<int>(batch_starts[n]);
// for (size_t i = 0; i < seq_info.size(); ++i) {
// size_t seq_len = seq_info[i].length;
// int start = seq_info[i].start;
// if (n < seq_len) {
// seq2batch_idx[batch_id] =
// is_reverse ? start + seq_len - 1 - n : start + n;
// batch_id++;
// } else {
// break;
// }
// }
// batch_starts[n + 1] = static_cast<size_t>(batch_id);
// }
// batch.set_lod(batch_lods);
// CopyMatrixRowsFunctor<Place, T> to_batch;
// to_batch(context, lod_tensor, seq2batch_idx, batch, true);
// }
// };
// template <typename Place, typename T>
// class Batch2LoDTensorFunctor {
// public:
// void operator()(const platform::DeviceContext& context,
// const framework::LoDTensor& batch,
// framework::LoDTensor& lod_tensor) const {
// auto in_lod = batch.lod();
// PADDLE_ENFORCE_EQ(in_lod.size(), 2UL,
// "The LoD size of input `batch` should be 2.");
// auto out_lod = lod_tensor.lod()[0];
// auto num = out_lod[out_lod.size() - 1];
// PADDLE_ENFORCE_EQ(num, lod_tensor.dims()[0]);
// PADDLE_ENFORCE_EQ(num, in_lod[1].size());
// PADDLE_ENFORCE_EQ(num, batch.dims()[0]);
// CopyMatrixRowsFunctor<Place, T> to_seq;
// size_t* index = in_lod[1].data();
// to_seq(context, batch, index, lod_tensor, false);
// }
// };
template
<
typename
Place
,
typename
T
>
class
CopyMatrixRowsFunctor
{
public:
...
...
@@ -53,7 +175,18 @@ class LoDTensor2BatchFunctor {
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
lod_tensor
,
framework
::
LoDTensor
&
batch
,
bool
is_reverse
)
const
{
framework
::
LoDTensor
&
batch
,
bool
is_reverse
=
false
,
bool
is_cal_batch_lod
=
true
)
const
{
if
(
!
is_cal_batch_lod
)
{
auto
lods
=
batch
.
lod
();
PADDLE_ENFORCE_EQ
(
lods
.
size
(),
2UL
);
PADDLE_ENFORCE_EQ
(
lods
[
1
].
size
(),
static_cast
<
size_t
>
(
lod_tensor
.
dims
()[
0
]));
CopyMatrixRowsFunctor
<
Place
,
T
>
to_batch
;
to_batch
(
context
,
lod_tensor
,
lods
[
1
].
data
(),
batch
,
true
);
return
;
}
auto
lods
=
lod_tensor
.
lod
();
PADDLE_ENFORCE_EQ
(
lods
.
size
(),
1UL
,
"Only support one level sequence now."
);
auto
lod
=
lods
[
0
];
...
...
@@ -101,10 +234,10 @@ class LoDTensor2BatchFunctor {
size_t
*
batch_starts
=
batch_lods
[
0
].
data
();
size_t
*
seq2batch_idx
=
batch_lods
[
1
].
data
();
batch_starts
[
0
]
=
0
;
for
(
size_
t
n
=
0
;
n
<
num_batch
;
n
++
)
{
for
(
in
t
n
=
0
;
n
<
num_batch
;
n
++
)
{
auto
batch_id
=
static_cast
<
int
>
(
batch_starts
[
n
]);
for
(
size_t
i
=
0
;
i
<
seq_info
.
size
();
++
i
)
{
size_
t
seq_len
=
seq_info
[
i
].
length
;
in
t
seq_len
=
seq_info
[
i
].
length
;
int
start
=
seq_info
[
i
].
start
;
if
(
n
<
seq_len
)
{
seq2batch_idx
[
batch_id
]
=
...
...
@@ -132,11 +265,8 @@ class Batch2LoDTensorFunctor {
auto
in_lod
=
batch
.
lod
();
PADDLE_ENFORCE_EQ
(
in_lod
.
size
(),
2UL
,
"The LoD size of input `batch` should be 2."
);
auto
out_lod
=
lod_tensor
.
lod
()[
0
];
auto
num
=
out_lod
[
out_lod
.
size
()
-
1
];
PADDLE_ENFORCE_EQ
(
num
,
lod_tensor
.
dims
()[
0
]);
PADDLE_ENFORCE_EQ
(
num
,
in_lod
[
1
].
size
());
PADDLE_ENFORCE_EQ
(
num
,
batch
.
dims
()[
0
]);
PADDLE_ENFORCE_EQ
(
in_lod
[
1
].
size
(),
static_cast
<
size_t
>
(
lod_tensor
.
dims
()[
0
]));
CopyMatrixRowsFunctor
<
Place
,
T
>
to_seq
;
size_t
*
index
=
in_lod
[
1
].
data
();
to_seq
(
context
,
batch
,
index
,
lod_tensor
,
false
);
...
...
python/paddle/v2/framework/tests/test_gru_op.py
0 → 100644
浏览文件 @
b87eabae
import
unittest
import
numpy
as
np
import
math
from
op_test
import
OpTest
SIGMOID_THRESHOLD_MIN
=
-
40.0
SIGMOID_THRESHOLD_MAX
=
13.0
EXP_MAX_INPUT
=
40.0
def
identity
(
x
):
return
x
def
sigmoid
(
x
):
y
=
np
.
copy
(
x
)
y
[
x
<
SIGMOID_THRESHOLD_MIN
]
=
SIGMOID_THRESHOLD_MIN
y
[
x
>
SIGMOID_THRESHOLD_MAX
]
=
SIGMOID_THRESHOLD_MAX
return
1.
/
(
1.
+
np
.
exp
(
-
y
))
def
tanh
(
x
):
y
=
-
2.
*
x
y
[
y
>
EXP_MAX_INPUT
]
=
EXP_MAX_INPUT
return
(
2.
/
(
1.
+
np
.
exp
(
y
)))
-
1.
def
relu
(
x
):
return
np
.
maximum
(
x
,
0
)
class
TestGRUOp
(
OpTest
):
batch_size
=
9
frame_size
=
5
activate
=
{
'identity'
:
identity
,
'sigmoid'
:
sigmoid
,
'tanh'
:
tanh
,
'relu'
:
relu
}
@
staticmethod
def
seq_to_batch
(
lod
,
is_reverse
):
idx_in_seq_list
=
[]
seq_starts
=
lod
[
0
]
seq_lens
=
[]
for
i
in
range
(
len
(
seq_starts
)
-
1
):
seq_lens
.
append
(
seq_starts
[
i
+
1
]
-
seq_starts
[
i
])
sorted_seqs
=
sorted
(
range
(
len
(
seq_lens
)),
lambda
x
,
y
:
seq_lens
[
y
]
-
seq_lens
[
x
])
num_batch
=
seq_lens
[
sorted_seqs
[
0
]]
for
batch_idx
in
range
(
num_batch
):
idx_in_seq
=
[]
for
i
in
range
(
len
(
seq_lens
)):
if
seq_lens
[
sorted_seqs
[
i
]]
<=
batch_idx
:
break
idx
=
(
seq_starts
[
sorted_seqs
[
i
]
+
1
]
-
1
-
batch_idx
)
if
is_reverse
else
(
seq_starts
[
sorted_seqs
[
i
]]
+
batch_idx
)
idx_in_seq
.
append
(
idx
)
idx_in_seq_list
.
append
(
idx_in_seq
)
return
idx_in_seq_list
def
gru_step
(
self
,
x
,
h_p
,
w
,
b
):
print
x
.
shape
,
h_p
.
shape
,
w
.
shape
,
b
.
shape
batch_size
=
x
.
shape
[
0
]
frame_size
=
w
.
shape
[
0
]
g
=
x
+
np
.
tile
(
b
,
(
batch_size
,
1
))
w_u_r
=
w
.
flatten
()[:
frame_size
*
frame_size
*
2
].
reshape
(
(
frame_size
,
frame_size
*
2
))
u_r
=
self
.
activate
[
self
.
attrs
[
'gate_activation'
]](
np
.
dot
(
h_p
,
w_u_r
)
+
g
[:,
:
frame_size
*
2
])
u
=
u_r
[:,
:
frame_size
]
r
=
u_r
[:,
frame_size
:
frame_size
*
2
]
r_h_p
=
r
*
h_p
w_c
=
w
.
flatten
()[
frame_size
*
frame_size
*
2
:].
reshape
(
(
frame_size
,
frame_size
))
c
=
self
.
activate
[
self
.
attrs
[
'activation'
]](
np
.
dot
(
r_h_p
,
w_c
)
+
g
[:,
frame_size
*
2
:])
g
=
np
.
hstack
((
u_r
,
c
))
h
=
u
*
c
+
(
1
-
u
)
*
h_p
return
g
,
r_h_p
,
h
def
gru
(
self
):
input
,
lod
=
self
.
inputs
[
'Input'
]
w
=
self
.
inputs
[
'Weight'
]
b
=
self
.
inputs
[
'Bias'
]
if
self
.
inputs
.
has_key
(
'Bias'
)
else
np
.
zeros
(
(
1
,
self
.
frame_size
*
3
))
batch_gate
=
self
.
outputs
[
'BatchGate'
]
batch_reset_hidden_prev
=
self
.
outputs
[
'BatchResetHiddenPrev'
]
batch_hidden
=
self
.
outputs
[
'BatchHidden'
]
hidden
=
self
.
outputs
[
'Hidden'
]
idx_in_seq_list
=
self
.
idx_in_seq_list
h_p
=
self
.
inputs
[
'H0'
]
if
self
.
inputs
.
has_key
(
'H0'
)
else
np
.
zeros
(
(
len
(
idx_in_seq_list
[
0
]),
self
.
frame_size
))
num_batch
=
len
(
idx_in_seq_list
)
end_idx
=
0
for
batch_idx
in
range
(
num_batch
):
print
idx_in_seq_list
[
batch_idx
]
x
=
input
[
idx_in_seq_list
[
batch_idx
]]
g
,
r_h_p
,
h
=
self
.
gru_step
(
x
,
h_p
,
w
,
b
)
if
batch_idx
<
(
num_batch
-
1
):
h_p
=
h
[:
len
(
idx_in_seq_list
[
batch_idx
+
1
])]
start_idx
=
end_idx
end_idx
=
start_idx
+
len
(
idx_in_seq_list
[
batch_idx
])
batch_gate
[
start_idx
:
end_idx
]
=
g
batch_reset_hidden_prev
[
start_idx
:
end_idx
]
=
r_h_p
batch_hidden
[
start_idx
:
end_idx
]
=
h
hidden
[
idx_in_seq_list
[
batch_idx
]]
=
h
return
batch_gate
,
batch_reset_hidden_prev
,
hidden
def
set_data
(
self
):
lod
=
[[
0
,
2
,
6
,
9
]]
#[[0, 1, 2, 3]]
self
.
idx_in_seq_list
=
self
.
seq_to_batch
(
lod
,
self
.
is_reverse
)
print
self
.
idx_in_seq_list
batch_size
=
self
.
batch_size
frame_size
=
self
.
frame_size
input
=
np
.
random
.
rand
(
batch_size
,
frame_size
*
3
).
astype
(
'float64'
)
h0
=
np
.
random
.
rand
(
len
(
self
.
idx_in_seq_list
[
0
]),
frame_size
).
astype
(
'float64'
)
weight
=
np
.
random
.
rand
(
frame_size
,
frame_size
*
3
).
astype
(
'float64'
)
bias
=
np
.
random
.
rand
(
1
,
frame_size
*
3
).
astype
(
'float64'
)
self
.
inputs
=
{
'Input'
:
(
input
,
lod
),
'H0'
:
h0
,
'Weight'
:
weight
,
'Bias'
:
bias
}
self
.
outputs
=
{
'BatchGate'
:
np
.
zeros
(
(
batch_size
,
frame_size
*
3
),
dtype
=
'float64'
),
'BatchResetHiddenPrev'
:
np
.
zeros
(
(
batch_size
,
frame_size
),
dtype
=
'float64'
),
'BatchHidden'
:
np
.
zeros
(
(
batch_size
,
frame_size
),
dtype
=
'float64'
),
'Hidden'
:
np
.
zeros
(
(
batch_size
,
frame_size
),
dtype
=
'float64'
)
}
def
set_confs
(
self
):
self
.
is_reverse
=
False
self
.
attrs
=
{
'activation'
:
'tanh'
,
'gate_activation'
:
'sigmoid'
,
'is_reverse'
:
self
.
is_reverse
}
def
setUp
(
self
):
self
.
op_type
=
"gru"
self
.
set_confs
()
self
.
set_data
()
self
.
gru
()
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'Input'
,
'H0'
,
'Weight'
,
'Bias'
],
[
'Hidden'
])
class
TestGRUOpNoInitial
(
TestGRUOp
):
def
set_data
(
self
):
super
(
TestGRUOpNoInitial
,
self
).
set_data
()
self
.
inputs
.
pop
(
'H0'
)
def
test_check_grad
(
self
):
self
.
check_grad
([
'Input'
,
'Weight'
,
'Bias'
],
[
'Hidden'
])
class
TestGRUOpReverse
(
TestGRUOp
):
def
set_confs
(
self
):
self
.
is_reverse
=
True
self
.
attrs
=
{
'activation'
:
'identity'
,
'gate_activation'
:
'sigmoid'
,
'is_reverse'
:
self
.
is_reverse
}
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录