Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
34aac18c
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
34aac18c
编写于
10月 22, 2017
作者:
Q
qingqing01
提交者:
GitHub
10月 22, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3 from reyoung/pr/4929
Several Enhancement
上级
694bc64a
65906ef1
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
102 addition
and
97 deletion
+102
-97
paddle/operators/lstm_op.cc
paddle/operators/lstm_op.cc
+8
-8
paddle/operators/lstm_op.h
paddle/operators/lstm_op.h
+9
-9
paddle/operators/math/detail/lstm_kernel.h
paddle/operators/math/detail/lstm_kernel.h
+42
-41
paddle/operators/math/lstm_compute.cc
paddle/operators/math/lstm_compute.cc
+5
-4
paddle/operators/math/lstm_compute.cu
paddle/operators/math/lstm_compute.cu
+5
-4
paddle/operators/math/lstm_compute.h
paddle/operators/math/lstm_compute.h
+5
-4
paddle/operators/math/sequence2batch.cc
paddle/operators/math/sequence2batch.cc
+0
-2
paddle/operators/math/sequence2batch.cu
paddle/operators/math/sequence2batch.cu
+1
-1
paddle/operators/math/sequence2batch.h
paddle/operators/math/sequence2batch.h
+27
-24
未找到文件。
paddle/operators/lstm_op.cc
浏览文件 @
34aac18c
...
...
@@ -68,7 +68,7 @@ class LSTMOp : public framework::OperatorWithKernel {
}
else
{
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
4
*
frame_size
,
"The second dimension of Input(Bias) should be "
"4 * %d if diable peepholes connection"
,
"4 * %d if di
s
able peepholes connection"
,
frame_size
);
}
ctx
->
SetOutputDim
(
"Hidden"
,
{
x_dims
[
0
],
frame_size
});
...
...
@@ -86,7 +86,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"Input"
,
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTen
os
r is a matrix with shape (T X 4D), where, T is the "
"this LoDTen
so
r is a matrix with shape (T X 4D), where, T is the "
"total time steps in this mini-batch, D is the hidden size."
);
AddInput
(
"H0"
,
"(Tensor, optional) the initial hidden state is an optional "
...
...
@@ -112,7 +112,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
" - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}."
);
AddOutput
(
"BatchGate"
,
"(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate after
n
the nonlinear computation. This "
"and output gate after the nonlinear computation. This "
"LoDTensor has the same shape with the reorganized input, which "
"was also be called batch input. The LoD size is 2. The first "
"LoD is the batch offsets and the second LoD contains the "
...
...
@@ -135,18 +135,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"gateActivation"
,
"(string, defa
lu
t: sigmoid)"
"(string, defa
ul
t: sigmoid)"
"The activation for input gate, forget gate and output "
"gate, `sigmoid` by defa
lu
t."
)
"gate, `sigmoid` by defa
ul
t."
)
.
SetDefault
(
"sigmoid"
);
AddAttr
<
std
::
string
>
(
"cellActivation"
,
"(string, defa
lu
t: tanh)"
"(string, defa
ul
t: tanh)"
"The activation for cell output, `tanh` by defalut."
)
.
SetDefault
(
"tanh"
);
AddAttr
<
std
::
string
>
(
"candidateActivation"
,
"(string, defa
lu
t: tanh)"
"(string, defa
ul
t: tanh)"
"The activation for candidate hidden state, "
"`tanh` by defa
lu
t."
)
"`tanh` by defa
ul
t."
)
.
SetDefault
(
"tanh"
);
AddComment
(
R"DOC(Long-Short Term Memory (LSTM) Operator
...
...
paddle/operators/lstm_op.h
浏览文件 @
34aac18c
...
...
@@ -52,7 +52,7 @@ class LSTMKernel : public framework::OpKernel<T> {
to_batch
(
ctx
.
device_context
(),
*
input
,
*
batch_gate
,
is_reverse
);
auto
in_dims
=
input
->
dims
();
int
frame_size
=
in_dims
[
1
]
/
4
;
int
frame_size
=
static_cast
<
int
>
(
in_dims
[
1
]
/
4
)
;
framework
::
DDim
dims
({
in_dims
[
0
],
frame_size
});
if
(
bias
)
{
...
...
@@ -70,7 +70,7 @@ class LSTMKernel : public framework::OpKernel<T> {
math
::
LstmMetaValue
<
T
>
lstm_value
;
T
*
bias_data
=
const_cast
<
T
*>
(
bias
->
data
<
T
>
());
// the code sty
p
le in LstmMetaValue will be updated later.
// the code style in LstmMetaValue will be updated later.
lstm_value
.
checkIg
=
bias_data
+
4
*
frame_size
;
lstm_value
.
checkFg
=
lstm_value
.
checkIg
+
frame_size
;
lstm_value
.
checkOg
=
lstm_value
.
checkFg
+
frame_size
;
...
...
@@ -83,15 +83,15 @@ class LSTMKernel : public framework::OpKernel<T> {
framework
::
LoDTensor
batch_cell_pre_act
;
batch_cell_pre_act
.
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
auto
batch_lod
=
batch_gate
->
lod
()[
0
];
int
num_batch
=
batch_lod
.
size
()
-
1
;
auto
&
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
auto
gate_act
=
ctx
.
Attr
<
std
::
string
>
(
"gateActivation"
);
auto
cell_act
=
ctx
.
Attr
<
std
::
string
>
(
"cellActivation"
);
auto
cand_act
=
ctx
.
Attr
<
std
::
string
>
(
"candidateActivation"
);
for
(
in
t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
bstart
=
batch_lod
[
n
]
;
int
bend
=
batch_lod
[
n
+
1
]
;
for
(
size_
t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
])
;
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
])
;
Tensor
gate_t
=
batch_gate
->
Slice
<
T
>
(
bstart
,
bend
);
Tensor
out_t
=
batch_out
.
Slice
<
T
>
(
bstart
,
bend
);
...
...
@@ -101,14 +101,14 @@ class LSTMKernel : public framework::OpKernel<T> {
int
cur_batch_size
=
bend
-
bstart
;
if
(
n
!=
0
)
{
int
pre_h_start
=
batch_lod
[
n
-
1
]
;
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
])
;
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_hidden_t
=
batch_out
.
Slice
<
T
>
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
Place
,
T
>
(
ctx
.
device_context
(),
pre_hidden_t
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
// else if : support the initial hidden and cell
// else if :
FIXME
support the initial hidden and cell
lstm_value
.
gateValue
=
gate_t
.
data
<
T
>
();
lstm_value
.
outputValue
=
out_t
.
data
<
T
>
();
...
...
paddle/operators/math/detail/lstm_kernel.h
浏览文件 @
34aac18c
...
...
@@ -13,12 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/platform/hostdevice.h"
#ifdef __CUDA_ARCH__
#define INLINE __device__ inline
#else
#define INLINE inline
#endif
#include <type_traits>
namespace
paddle
{
namespace
operators
{
...
...
@@ -30,12 +27,12 @@ namespace forward {
template
<
class
T
>
class
lstm
{
public:
INLIN
E
void
operator
()(
T
&
valueIn
,
T
&
valueIg
,
T
&
valueFg
,
T
&
valueOg
,
T
&
prevState
,
T
&
state
,
T
&
stateAtv
,
T
&
output
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
typename
hppl
::
ForwardActType
<
T
>::
type
actInput
,
typename
hppl
::
ForwardActType
<
T
>::
type
actGate
,
typename
hppl
::
ForwardActType
<
T
>::
type
actState
)
{
HOSTDEVIC
E
void
operator
()(
T
&
valueIn
,
T
&
valueIg
,
T
&
valueFg
,
T
&
valueOg
,
T
&
prevState
,
T
&
state
,
T
&
stateAtv
,
T
&
output
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
typename
hppl
::
ForwardActType
<
T
>::
type
actInput
,
typename
hppl
::
ForwardActType
<
T
>::
type
actGate
,
typename
hppl
::
ForwardActType
<
T
>::
type
actState
)
{
valueIn
=
actInput
(
valueIn
);
valueIg
=
actGate
(
valueIg
+
prevState
*
checkI
);
valueFg
=
actGate
(
valueFg
+
prevState
*
checkF
);
...
...
@@ -45,17 +42,19 @@ class lstm {
output
=
valueOg
*
stateAtv
;
}
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
// If not compiled with AVX instructs. Disable AVX by default
static
const
bool
avx
=
false
;
#else
static
const
bool
avx
=
true
;
INLINE
void
operator
()(
__m256
&
valueIn
,
__m256
&
valueIg
,
__m256
&
valueFg
,
__m256
&
valueOg
,
__m256
&
prevState
,
__m256
&
state
,
__m256
&
stateAtv
,
__m256
&
output
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
hppl
::
Active
<
__m256
>::
forward
actInput
,
hppl
::
Active
<
__m256
>::
forward
actGate
,
hppl
::
Active
<
__m256
>::
forward
actState
)
{
// Only float support AVX optimization
static
const
bool
avx
=
std
::
is_same
<
T
,
float
>::
value
;
HOSTDEVICE
void
operator
()(
__m256
&
valueIn
,
__m256
&
valueIg
,
__m256
&
valueFg
,
__m256
&
valueOg
,
__m256
&
prevState
,
__m256
&
state
,
__m256
&
stateAtv
,
__m256
&
output
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
hppl
::
Active
<
__m256
>::
forward
actInput
,
hppl
::
Active
<
__m256
>::
forward
actGate
,
hppl
::
Active
<
__m256
>::
forward
actState
)
{
valueIn
=
actInput
(
valueIn
);
valueIg
=
actGate
(
_mm256_add_ps
(
valueIg
,
_mm256_mul_ps
(
prevState
,
checkI
)));
valueFg
=
actGate
(
_mm256_add_ps
(
valueFg
,
_mm256_mul_ps
(
prevState
,
checkF
)));
...
...
@@ -76,14 +75,15 @@ namespace backward {
template
<
class
T
>
class
lstm
{
public:
INLINE
void
operator
()(
T
&
valueIn
,
T
&
valueIg
,
T
&
valueFg
,
T
&
valueOg
,
T
&
gradIn
,
T
&
gradIg
,
T
&
gradFg
,
T
&
gradOg
,
T
&
prevState
,
T
&
prevStateGrad
,
T
&
state
,
T
&
stateGrad
,
T
&
stateAtv
,
T
&
outputGrad
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
T
&
checkIGrad
,
T
&
checkFGrad
,
T
&
checkOGrad
,
typename
hppl
::
BackwardActType
<
T
>::
type
actInput
,
typename
hppl
::
BackwardActType
<
T
>::
type
actGate
,
typename
hppl
::
BackwardActType
<
T
>::
type
actState
)
{
HOSTDEVICE
void
operator
()(
T
&
valueIn
,
T
&
valueIg
,
T
&
valueFg
,
T
&
valueOg
,
T
&
gradIn
,
T
&
gradIg
,
T
&
gradFg
,
T
&
gradOg
,
T
&
prevState
,
T
&
prevStateGrad
,
T
&
state
,
T
&
stateGrad
,
T
&
stateAtv
,
T
&
outputGrad
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
T
&
checkIGrad
,
T
&
checkFGrad
,
T
&
checkOGrad
,
typename
hppl
::
BackwardActType
<
T
>::
type
actInput
,
typename
hppl
::
BackwardActType
<
T
>::
type
actGate
,
typename
hppl
::
BackwardActType
<
T
>::
type
actState
)
{
gradOg
=
actGate
(
outputGrad
*
stateAtv
,
valueOg
);
stateGrad
+=
actState
(
outputGrad
*
valueOg
,
stateAtv
)
+
gradOg
*
checkO
;
gradIn
=
actInput
(
stateGrad
*
valueIg
,
valueIn
);
...
...
@@ -95,21 +95,22 @@ class lstm {
checkOGrad
=
gradOg
*
state
;
}
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
// If not compiled with AVX instructs. Disable AVX by default
static
const
bool
avx
=
false
;
#else
static
const
bool
avx
=
true
;
INLINE
void
operator
()(
__m256
&
valueIn
,
__m256
&
valueIg
,
__m256
&
valueFg
,
__m256
&
valueOg
,
__m256
&
gradIn
,
__m256
&
gradIg
,
__m256
&
gradFg
,
__m256
&
gradOg
,
__m256
&
prevState
,
__m256
&
prevStateGrad
,
__m256
&
state
,
__m256
&
stateGrad
,
__m256
&
stateAtv
,
__m256
&
outputGrad
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
__m256
&
checkIGrad
,
__m256
&
checkFGrad
,
__m256
&
checkOGrad
,
hppl
::
Active
<
__m256
>::
backward
actInput
,
hppl
::
Active
<
__m256
>::
backward
actGate
,
hppl
::
Active
<
__m256
>::
backward
actState
)
{
// Only float support AVX optimization
static
const
bool
avx
=
std
::
is_same
<
T
,
float
>::
value
;
HOSTDEVICE
void
operator
()(
__m256
&
valueIn
,
__m256
&
valueIg
,
__m256
&
valueFg
,
__m256
&
valueOg
,
__m256
&
gradIn
,
__m256
&
gradIg
,
__m256
&
gradFg
,
__m256
&
gradOg
,
__m256
&
prevState
,
__m256
&
prevStateGrad
,
__m256
&
state
,
__m256
&
stateGrad
,
__m256
&
stateAtv
,
__m256
&
outputGrad
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
__m256
&
checkIGrad
,
__m256
&
checkFGrad
,
__m256
&
checkOGrad
,
hppl
::
Active
<
__m256
>::
backward
actInput
,
hppl
::
Active
<
__m256
>::
backward
actGate
,
hppl
::
Active
<
__m256
>::
backward
actState
)
{
gradOg
=
actGate
(
_mm256_mul_ps
(
outputGrad
,
stateAtv
),
valueOg
);
stateGrad
=
_mm256_add_ps
(
actState
(
_mm256_mul_ps
(
outputGrad
,
valueOg
),
stateAtv
),
stateGrad
);
...
...
paddle/operators/math/lstm_compute.cc
浏览文件 @
34aac18c
...
...
@@ -24,8 +24,8 @@ template <class T>
struct
LstmUnitFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
)
{
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
detail
::
cpu_lstm_forward
(
detail
::
forward
::
lstm
<
T
>
(),
value
,
frame_size
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
...
...
@@ -45,8 +45,9 @@ template <class T>
struct
LstmUnitGradFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
)
{
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
detail
::
cpu_lstm_backward
(
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
frame_size
,
ActiveType
(
cand_act
),
...
...
paddle/operators/math/lstm_compute.cu
浏览文件 @
34aac18c
...
...
@@ -24,8 +24,8 @@ template <class T>
struct
LstmUnitFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
)
{
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
)
{
detail
::
gpu_lstm_forward
<
T
>
(
context
,
detail
::
forward
::
lstm
<
T
>
(),
value
,
frame_size
,
batch_size
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
...
...
@@ -36,8 +36,9 @@ template <class T>
struct
LstmUnitGradFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
)
{
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
)
{
detail
::
gpu_lstm_backward
(
context
,
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
frame_size
,
batch_size
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
...
...
paddle/operators/math/lstm_compute.h
浏览文件 @
34aac18c
...
...
@@ -72,8 +72,8 @@ class LstmUnitFunctor {
public:
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
);
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
);
};
template
<
typename
Place
,
typename
T
>
...
...
@@ -81,8 +81,9 @@ class LstmUnitGradFunctor {
public:
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
);
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
);
};
}
// namespace math
...
...
paddle/operators/math/sequence2batch.cc
浏览文件 @
34aac18c
...
...
@@ -51,8 +51,6 @@ class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
template
class
CopyMatrixRowsFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
CopyMatrixRowsFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
LoDTensor2BatchFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
LoDTensor2BatchFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
Batch2LoDTensorFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
Batch2LoDTensorFunctor
<
platform
::
CPUPlace
,
double
>;
...
...
paddle/operators/math/sequence2batch.cu
浏览文件 @
34aac18c
...
...
@@ -21,7 +21,7 @@ namespace math {
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
>
__global__
void
CopyMatrixRowsKernel
(
const
T
*
src
,
T
*
dst
,
const
size_t
*
index
,
int64_t
height
,
int64_t
width
,
const
bool
is_src_index
)
{
bool
is_src_index
)
{
int
idx
=
threadIdx
.
x
;
int
idy
=
threadIdx
.
y
;
int
id
=
blockIdx
.
x
+
idy
*
GridDimX
;
...
...
paddle/operators/math/sequence2batch.h
浏览文件 @
34aac18c
...
...
@@ -31,33 +31,33 @@ class CopyMatrixRowsFunctor {
// The indexed rows are based on the input index.
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
src
,
const
size_t
*
index
,
framework
::
LoDTensor
&
dst
,
const
bool
is_src_index
);
framework
::
LoDTensor
&
dst
,
bool
is_src_index
);
};
template
<
typename
Place
,
typename
T
>
class
LoDTensor2BatchFunctor
{
// Calculate the length of each sequence and
// sort sequence index by the length.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
//
struct
SeqInfo
{
SeqInfo
(
int
start
,
int
length
,
int
seq_idx
)
:
start
(
start
),
length
(
length
),
seq_idx
(
seq_idx
)
{}
int
start
;
int
length
;
int
seq_idx
;
};
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
lod_tensor
,
framework
::
LoDTensor
&
batch
,
const
bool
is_reverse
)
const
{
framework
::
LoDTensor
&
batch
,
bool
is_reverse
)
const
{
auto
lods
=
lod_tensor
.
lod
();
PADDLE_ENFORCE_EQ
(
lods
.
size
(),
1UL
,
"Only support one level sequence now."
);
auto
lod
=
lods
[
0
];
// Calculate the length of each sequence and
// sort sequence index by the length.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
//
struct
SeqInfo
{
SeqInfo
(
int
start
,
int
length
,
int
seq_idx
)
:
start
(
start
),
length
(
length
),
seq_idx
(
seq_idx
)
{}
int
start
;
int
length
;
int
seq_idx
;
};
std
::
vector
<
SeqInfo
>
seq_info
;
for
(
size_t
seq_id
=
0
;
seq_id
<
lod
.
size
()
-
1
;
++
seq_id
)
{
int
length
=
lod
[
seq_id
+
1
]
-
lod
[
seq_id
];
...
...
@@ -75,31 +75,34 @@ class LoDTensor2BatchFunctor {
// batchIndex = {b0, b1, b2, b3, b4}
// b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
// batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
// batch_start_positions[0] = len(b0)
// batch_start_positions[1] = len(b0) + len(b1)
// batch_start_positions[2] = len(b0) + len(b1) + len(b2)
// ...
// seq2batch_idx[12] = {4, 0, 9,
// 5, 1, 10,
// 6, 2, 11,
// 7, 3,
// 8}
// The batch number represents batch size after rearranging the
// input LodTensor. It is also the maximum length of input sequence.
paddle
::
framework
::
LoD
batch_lods
;
batch_lods
.
push
_back
(
std
::
vector
<
size_t
>
{
0
});
batch_lods
.
push
_back
(
std
::
vector
<
size_t
>
{
0
});
batch_lods
.
emplace
_back
(
std
::
vector
<
size_t
>
{
0
});
batch_lods
.
emplace
_back
(
std
::
vector
<
size_t
>
{
0
});
// batch_lods[0] is the start positions for batch LoDTensor
int
num_batch
=
(
size_t
)
seq_info
[
0
].
length
;
batch_lods
[
0
].
resize
(
num_batch
+
1
);
int
num_batch
=
seq_info
[
0
].
length
;
batch_lods
[
0
].
resize
(
static_cast
<
size_t
>
(
num_batch
+
1
)
);
// batch_lods[1] is the raw index in the input LoDTensor
auto
dims
=
lod_tensor
.
dims
();
batch_lods
[
1
].
resize
(
dims
[
0
]
);
batch_lods
[
1
].
resize
(
static_cast
<
size_t
>
(
dims
[
0
])
);
size_t
*
batch_starts
=
batch_lods
[
0
].
data
();
size_t
*
seq2batch_idx
=
batch_lods
[
1
].
data
();
batch_starts
[
0
]
=
0
;
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
batch_id
=
batch_starts
[
n
]
;
auto
batch_id
=
static_cast
<
int
>
(
batch_starts
[
n
])
;
for
(
size_t
i
=
0
;
i
<
seq_info
.
size
();
++
i
)
{
size_t
seq_len
=
seq_info
[
i
].
length
;
int
start
=
seq_info
[
i
].
start
;
...
...
@@ -114,7 +117,7 @@ class LoDTensor2BatchFunctor {
break
;
}
}
batch_starts
[
n
+
1
]
=
batch_id
;
batch_starts
[
n
+
1
]
=
static_cast
<
size_t
>
(
batch_id
)
;
}
batch
.
set_lod
(
batch_lods
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录