Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3cace737
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3cace737
编写于
10月 16, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add lstm implementation.
上级
9106a4bb
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
1436 addition
and
22 deletion
+1436
-22
paddle/operators/lstm_op.cc
paddle/operators/lstm_op.cc
+44
-10
paddle/operators/lstm_op.h
paddle/operators/lstm_op.h
+24
-11
paddle/operators/math/detail/hl_activation_functions.h
paddle/operators/math/detail/hl_activation_functions.h
+64
-0
paddle/operators/math/detail/hl_avx_functions.cc
paddle/operators/math/detail/hl_avx_functions.cc
+68
-0
paddle/operators/math/detail/hl_avx_functions.h
paddle/operators/math/detail/hl_avx_functions.h
+32
-0
paddle/operators/math/detail/hl_cpu_functions.cc
paddle/operators/math/detail/hl_cpu_functions.cc
+44
-0
paddle/operators/math/detail/hl_functions.h
paddle/operators/math/detail/hl_functions.h
+63
-0
paddle/operators/math/detail/hl_gpu_functions.h
paddle/operators/math/detail/hl_gpu_functions.h
+80
-0
paddle/operators/math/detail/lstm_cpu_kernel.h
paddle/operators/math/detail/lstm_cpu_kernel.h
+306
-0
paddle/operators/math/detail/lstm_gpu_kernel.h
paddle/operators/math/detail/lstm_gpu_kernel.h
+244
-0
paddle/operators/math/detail/lstm_kernel.h
paddle/operators/math/detail/lstm_kernel.h
+138
-0
paddle/operators/math/lstm_compute.cc
paddle/operators/math/lstm_compute.cc
+73
-0
paddle/operators/math/lstm_compute.cu
paddle/operators/math/lstm_compute.cu
+73
-0
paddle/operators/math/lstm_compute.h
paddle/operators/math/lstm_compute.h
+87
-0
paddle/operators/math/sequence2batch.cc
paddle/operators/math/sequence2batch.cc
+31
-0
paddle/operators/math/sequence2batch.cu
paddle/operators/math/sequence2batch.cu
+47
-0
paddle/operators/math/sequence2batch.h
paddle/operators/math/sequence2batch.h
+18
-1
未找到文件。
paddle/operators/lstm_op.cc
浏览文件 @
3cace737
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/lstm_
unit_
op.h"
#include "paddle/operators/lstm_op.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -44,8 +44,36 @@ class LSTMOp : public framework::OperatorWithKernel {
"should be the same."
);
}
int
frame_size
=
x_dims
[
1
];
auto
w_dims
=
ctx
->
GetInputDim
(
"Weight"
);
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
2
,
"The rank of Input(Weight) should be 2."
);
PADDLE_ENFORCE_EQ
(
w_dims
[
0
],
frame_size
,
"The first dimension of Input(Weight) "
"should be %d."
,
frame_size
);
PADDLE_ENFORCE_EQ
(
w_dims
[
1
],
4
*
frame_size
,
"The second dimension of Input(Weight) "
"should be 4 * %d."
,
frame_size
);
auto
b_dims
=
ctx
->
GetInputDim
(
"Bias"
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"The rank of Input(Bias) should be 2."
);
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"The first dimension of Input(Bias) should be 1."
);
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
))
{
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
7
*
frame_size
,
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection"
,
frame_size
);
}
else
{
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
4
*
frame_size
,
"The second dimension of Input(Bias) should be "
"4 * %d if diable peepholes connection"
,
frame_size
);
}
ctx
->
SetOutputDim
(
"Hidden"
,
x_dims
);
ctx
->
SetOutputDim
(
"Cell"
,
x_dims
);
ctx
->
SetOutputDim
(
"Hidden"
,
x_dims
);
ctx
->
ShareLoD
(
"Input"
,
"Hidden"
);
ctx
->
ShareLoD
(
"Input"
,
"Cell"
);
}
...
...
@@ -82,6 +110,8 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"2. `use_peepholes = True` "
" - The shape is (1 x 7*D). "
" - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}."
);
AddOutput
(
"Batch"
,
"(LoDTensor) save the reorganized input as batch info. "
)
.
AsIntermediate
();
AddOutput
(
"Hidden"
,
"(LoDTensor) the hidden state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`."
);
...
...
@@ -92,6 +122,10 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections."
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"is_reverse"
,
"(bool, defalut: False) "
"whether to compute reversed LSTM."
)
.
SetDefault
(
true
);
AddAttr
<
std
::
string
>
(
"gate_activation"
,
"(string, defalut: sigmoid)"
...
...
paddle/operators/lstm_op.h
浏览文件 @
3cace737
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
...
...
@@ -25,7 +24,21 @@ using framework::Tensor;
template
<
typename
Place
,
typename
T
>
class
LSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{}
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
);
auto
*
batch_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Batch"
);
auto
*
bias_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Bias"
);
bool
is_reverse
=
ctx
.
Attr
<
bool
>
(
"is_reverse"
);
LoDTensor2BatchFunctor
<
Place
,
T
>
to_batch
(
ctx
.
device_context
(),
input_t
,
batch_t
,
is_reverse
);
auto
in_dims
=
input_t
->
dims
();
int
frame_size
=
in_dims
[
1
];
if
(
bias_t
)
{
auto
b
=
EigenMatrix
<
T
>::
From
(
*
bias
);
}
}
};
template
<
typename
Place
,
typename
T
>
...
...
paddle/operators/math/detail/hl_activation_functions.h
0 → 100644
浏览文件 @
3cace737
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_ACTIVATION_FUNCTIONS_H_
#define HL_ACTIVATION_FUNCTIONS_H_
#include "hl_functions.h"
/**
* Active functions: sigmoid, relu, tanh and linear.
*/
#define HPPL_ACTIVE_FUNCTION \
{ hppl::sigmoid, hppl::relu, hppl::tanh, hppl::linear }
namespace
hppl
{
/**
* Hppl supports sigmoid, relu, tanh, linear active functions
* for neural networks' forward and backward activation.
*/
template
<
class
T
>
class
Active
{
public:
typedef
T
(
*
forward
)(
T
);
typedef
T
(
*
backward
)(
T
,
T
);
};
#ifdef __NVCC__
namespace
gpu
{
static
__device__
Active
<
float
>::
forward
forward
[]
=
HPPL_ACTIVE_FUNCTION
;
static
__device__
Active
<
float
>::
backward
backward
[]
=
HPPL_ACTIVE_FUNCTION
;
static
__device__
Active
<
double
>::
forward
forward
[]
=
HPPL_ACTIVE_FUNCTION
;
static
__device__
Active
<
double
>::
backward
backward
[]
=
HPPL_ACTIVE_FUNCTION
;
}
// namespace gpu
#else
namespace
cpu
{
static
Active
<
float
>::
forward
forward
[]
=
HPPL_ACTIVE_FUNCTION
;
static
Active
<
float
>::
backward
backward
[]
=
HPPL_ACTIVE_FUNCTION
;
static
Active
<
double
>::
forward
forward
[]
=
HPPL_ACTIVE_FUNCTION
;
static
Active
<
double
>::
backward
backward
[]
=
HPPL_ACTIVE_FUNCTION
;
}
// namespace cpu
#ifdef __AVX__
namespace
avx
{
static
Active
<
__m256
>::
forward
forward
[]
=
HPPL_ACTIVE_FUNCTION
;
static
Active
<
__m256
>::
backward
backward
[]
=
HPPL_ACTIVE_FUNCTION
;
}
// namespace avx
#endif
#endif
}
// namespace hppl
#endif // HL_ACTIVATION_FUNCTIONS_H_
paddle/operators/math/detail/hl_avx_functions.cc
0 → 100644
浏览文件 @
3cace737
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <immintrin.h>
#include "hl_functions.h"
namespace
hppl
{
extern
__m256
exp
(
__m256
a
);
__m256
relu
(
const
__m256
a
)
{
__m256
tmp
=
_mm256_set1_ps
(
0.0
f
);
return
_mm256_max_ps
(
a
,
tmp
);
}
__m256
sigmoid
(
const
__m256
a
)
{
__m256
max
=
_mm256_set1_ps
(
SIGMOID_THRESHOLD_MAX
);
__m256
min
=
_mm256_set1_ps
(
SIGMOID_THRESHOLD_MIN
);
__m256
tmp
=
_mm256_max_ps
(
a
,
min
);
tmp
=
_mm256_min_ps
(
tmp
,
max
);
tmp
=
_mm256_sub_ps
(
_mm256_set1_ps
(
0.0
f
),
tmp
);
tmp
=
exp
(
tmp
);
tmp
=
_mm256_add_ps
(
_mm256_set1_ps
(
1.0
f
),
tmp
);
tmp
=
_mm256_div_ps
(
_mm256_set1_ps
(
1.0
f
),
tmp
);
return
tmp
;
}
__m256
tanh
(
const
__m256
a
)
{
__m256
max
=
_mm256_set1_ps
(
EXP_MAX_INPUT
);
__m256
tmp
=
_mm256_mul_ps
(
_mm256_set1_ps
(
-
2.0
f
),
a
);
tmp
=
_mm256_min_ps
(
tmp
,
max
);
tmp
=
exp
(
tmp
);
return
_mm256_sub_ps
(
_mm256_div_ps
(
_mm256_set1_ps
(
2.0
f
),
_mm256_add_ps
(
_mm256_set1_ps
(
1.0
f
),
tmp
)),
_mm256_set1_ps
(
1.0
f
));
}
__m256
linear
(
const
__m256
a
)
{
return
a
;
}
__m256
relu
(
const
__m256
a
,
const
__m256
b
)
{
return
_mm256_mul_ps
(
a
,
_mm256_and_ps
(
_mm256_cmp_ps
(
b
,
_mm256_set1_ps
(
0.0
f
),
_CMP_GT_OS
),
_mm256_set1_ps
(
1.0
f
)));
}
__m256
sigmoid
(
const
__m256
a
,
const
__m256
b
)
{
return
_mm256_mul_ps
(
_mm256_mul_ps
(
a
,
b
),
_mm256_sub_ps
(
_mm256_set1_ps
(
1.0
f
),
b
));
}
__m256
tanh
(
const
__m256
a
,
const
__m256
b
)
{
return
_mm256_mul_ps
(
a
,
_mm256_sub_ps
(
_mm256_set1_ps
(
1.0
f
),
_mm256_mul_ps
(
b
,
b
)));
}
__m256
linear
(
const
__m256
a
,
const
__m256
b
)
{
return
a
;
}
}
// namespace hppl
paddle/operators/math/detail/hl_avx_functions.h
0 → 100644
浏览文件 @
3cace737
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_AVX_FUNCTIONS_H_
#define HL_AVX_FUNCTIONS_H_
#include <immintrin.h>
namespace
hppl
{
__m256
relu
(
const
__m256
a
);
__m256
sigmoid
(
const
__m256
a
);
__m256
tanh
(
const
__m256
a
);
__m256
linear
(
const
__m256
a
);
__m256
relu
(
const
__m256
a
,
const
__m256
b
);
__m256
sigmoid
(
const
__m256
a
,
const
__m256
b
);
__m256
tanh
(
const
__m256
a
,
const
__m256
b
);
__m256
linear
(
const
__m256
a
,
const
__m256
b
);
}
// namespace hppl
#endif // HL_AVX_FUNCTIONS_H_
paddle/operators/math/detail/hl_cpu_functions.cc
0 → 100644
浏览文件 @
3cace737
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include "/paddle/operators/math/detail/hl_functions.h"
namespace
hppl
{
real
relu
(
const
real
a
)
{
return
a
>
0.0
f
?
a
:
0.0
f
;
}
real
sigmoid
(
const
real
a
)
{
const
real
min
=
SIGMOID_THRESHOLD_MIN
;
const
real
max
=
SIGMOID_THRESHOLD_MAX
;
real
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
1.0
/
(
1.0
+
exp
(
-
tmp
));
}
real
tanh
(
const
real
a
)
{
real
tmp
=
-
2.0
*
a
;
tmp
=
(
tmp
>
EXP_MAX_INPUT
)
?
EXP_MAX_INPUT
:
tmp
;
return
(
2.0
/
(
1.0
+
exp
(
tmp
)))
-
1.0
;
}
real
linear
(
const
real
a
)
{
return
a
;
}
real
relu
(
const
real
a
,
const
real
b
)
{
return
a
*
(
b
>
0.0
f
?
1.0
f
:
0.0
f
);
}
real
sigmoid
(
const
real
a
,
const
real
b
)
{
return
a
*
b
*
(
1
-
b
);
}
real
tanh
(
const
real
a
,
const
real
b
)
{
return
a
*
(
1.0
f
-
b
*
b
);
}
real
linear
(
const
real
a
,
const
real
b
)
{
return
a
;
}
}
// namespace hppl
paddle/operators/math/detail/hl_functions.h
0 → 100644
浏览文件 @
3cace737
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_FUNCTIONS_H_
#define HL_FUNCTIONS_H_
/**
* sigmoid threshold maximum
*/
#define SIGMOID_THRESHOLD_MIN -40.0
/**
* sigmoid threshold minimum
*/
#define SIGMOID_THRESHOLD_MAX 13.0
#ifndef __NVCC__
namespace
hppl
{
/*
* forward activation
*/
template
<
typename
T
>
T
relu
(
const
T
a
);
template
<
typename
T
>
T
sigmoid
(
const
T
a
);
template
<
typename
T
>
T
tanh
(
const
T
a
);
template
<
typename
T
>
T
linear
(
const
T
a
);
/*
* backward activation
*/
template
<
typename
T
>
T
relu
(
const
T
a
,
const
T
b
);
template
<
typename
T
>
T
sigmoid
(
const
T
a
,
const
T
b
);
template
<
typename
T
>
T
tanh
(
const
T
a
,
const
T
b
);
template
<
typename
T
>
T
linear
(
const
T
a
,
const
T
b
);
}
// namespace hppl
#ifdef __AVX__
#include "hl_avx_functions.h"
#endif
#else
#include "hl_gpu_functions.h"
#endif
#endif // HL_FUNCTIONS_H_
paddle/operators/math/detail/hl_gpu_functions.h
0 → 100644
浏览文件 @
3cace737
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_GPU_FUNCTIONS_CUH_
#define HL_GPU_FUNCTIONS_CUH_
#include "hl_base.h"
namespace
hppl
{
template
<
typename
T
>
__device__
static
T
relu
(
const
T
a
)
{
return
a
>
0.0
f
?
a
:
0.0
f
;
}
template
<
>
__device__
static
float
sigmoid
(
const
float
a
)
{
const
float
min
=
SIGMOID_THRESHOLD_MIN
;
const
float
max
=
SIGMOID_THRESHOLD_MAX
;
float
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
__fdividef
(
1.0
f
,
1.0
f
+
__expf
(
-
tmp
));
}
template
<
>
__device__
static
double
sigmoid
(
const
double
a
)
{
const
double
min
=
SIGMOID_THRESHOLD_MIN
;
const
double
max
=
SIGMOID_THRESHOLD_MAX
;
double
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
1.0
/
(
1.0
+
exp
(
-
tmp
));
}
template
<
>
__device__
static
float
tanh
(
const
float
a
)
{
return
__fdividef
(
2.0
f
,
(
1.0
f
+
__expf
(
-
2.0
f
*
a
)))
-
1.0
f
;
}
template
<
>
__device__
static
double
tanh
(
const
double
a
)
{
return
(
2.0
/
(
1.0
+
exp
(
-
2.0
*
a
)))
-
1.0
;
}
template
<
typename
T
>
__device__
static
T
linear
(
const
T
a
)
{
return
a
;
}
template
<
typename
T
>
__device__
static
T
relu
(
const
T
a
,
const
T
b
)
{
return
a
*
(
b
>
0.0
f
?
1.0
f
:
0.0
f
);
}
template
<
typename
T
>
__device__
static
T
sigmoid
(
const
T
a
,
const
T
b
)
{
return
a
*
b
*
(
1
-
b
);
}
template
<
typename
T
>
__device__
static
T
tanh
(
const
T
a
,
const
T
b
)
{
return
a
*
(
1.0
f
-
b
*
b
);
}
template
<
typename
T
>
__device__
static
T
linear
(
const
T
a
,
const
T
b
)
{
return
a
;
}
}
// namespace hppl
#endif // HL_GPU_FUNCTIONS_CUH_
paddle/operators/math/detail/lstm_cpu_kernel.h
0 → 100644
浏览文件 @
3cace737
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/math/lstm_compute.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
detail
{
#ifndef __NVCC__
template
<
class
T
,
class
Op
>
void
naive_lstm_forward_one_sequence
(
Op
op
,
lstm_value
value
,
int
frameSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
T
rValueIn
;
T
rValueIg
;
T
rValueFg
;
T
rValueOg
;
T
rCheckI
;
T
rCheckF
;
T
rCheckO
;
T
rState
;
T
rPrevState
=
0
;
T
rStateAtv
;
T
rOut
;
T
*
valueIn
=
value
.
gateValue
;
T
*
valueIg
=
value
.
gateValue
+
frameSize
;
T
*
valueFg
=
value
.
gateValue
+
frameSize
*
2
;
T
*
valueOg
=
value
.
gateValue
+
frameSize
*
3
;
for
(
int
i
=
0
;
i
<
frameSize
;
i
++
)
{
rValueIn
=
valueIn
[
i
];
rValueIg
=
valueIg
[
i
];
rValueFg
=
valueFg
[
i
];
rValueOg
=
valueOg
[
i
];
rCheckI
=
value
.
checkIg
[
i
];
rCheckF
=
value
.
checkFg
[
i
];
rCheckO
=
value
.
checkOg
[
i
];
if
(
value
.
prevStateValue
)
{
rPrevState
=
value
.
prevStateValue
[
i
];
}
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rPrevState
,
rState
,
rStateAtv
,
rOut
,
rCheckI
,
rCheckF
,
rCheckO
,
hppl
::
cpu
::
forward
[
active_node
],
hppl
::
cpu
::
forward
[
active_gate
],
hppl
::
cpu
::
forward
[
active_state
]);
valueIn
[
i
]
=
rValueIn
;
valueIg
[
i
]
=
rValueIg
;
valueFg
[
i
]
=
rValueFg
;
valueOg
[
i
]
=
rValueOg
;
value
.
stateValue
[
i
]
=
rState
;
value
.
stateActiveValue
[
i
]
=
rStateAtv
;
value
.
outputValue
[
i
]
=
rOut
;
}
}
template
<
class
T
,
class
Op
>
void
naive_lstm_backward_one_sequence
(
Op
op
,
lstm_value
value
,
lstm_grad
grad
,
int
frameSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
T
rValueIn
;
T
rValueIg
;
T
rValueFg
;
T
rValueOg
;
T
rGradIn
;
T
rGradIg
;
T
rGradFg
;
T
rGradOg
;
T
rPrevState
=
0
;
T
rPrevStateGrad
;
T
rState
;
T
rStateGrad
;
T
rStateAtv
;
T
rOutputGrad
;
T
rCheckI
;
T
rCheckF
;
T
rCheckO
;
T
rCheckIGrad
;
T
rCheckFGrad
;
T
rCheckOGrad
;
T
*
valueIn
=
value
.
gateValue
;
T
*
valueIg
=
value
.
gateValue
+
frameSize
;
T
*
valueFg
=
value
.
gateValue
+
frameSize
*
2
;
T
*
valueOg
=
value
.
gateValue
+
frameSize
*
3
;
T
*
gradIn
=
grad
.
gateGrad
;
T
*
gradIg
=
grad
.
gateGrad
+
frameSize
;
T
*
gradFg
=
grad
.
gateGrad
+
frameSize
*
2
;
T
*
gradOg
=
grad
.
gateGrad
+
frameSize
*
3
;
for
(
int
i
=
0
;
i
<
frameSize
;
i
++
)
{
rValueIn
=
valueIn
[
i
];
rValueIg
=
valueIg
[
i
];
rValueFg
=
valueFg
[
i
];
rValueOg
=
valueOg
[
i
];
rCheckI
=
value
.
checkIg
[
i
];
rCheckF
=
value
.
checkFg
[
i
];
rCheckO
=
value
.
checkOg
[
i
];
rState
=
value
.
stateValue
[
i
];
rStateAtv
=
value
.
stateActiveValue
[
i
];
rOutputGrad
=
grad
.
outputGrad
[
i
];
rStateGrad
=
grad
.
stateGrad
[
i
];
if
(
value
.
prevStateValue
)
{
rPrevState
=
value
.
prevStateValue
[
i
];
}
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rGradIn
,
rGradIg
,
rGradFg
,
rGradOg
,
rPrevState
,
rPrevStateGrad
,
rState
,
rStateGrad
,
rStateAtv
,
rOutputGrad
,
rCheckI
,
rCheckF
,
rCheckO
,
rCheckIGrad
,
rCheckFGrad
,
rCheckOGrad
,
hppl
::
cpu
::
backward
[
active_node
],
hppl
::
cpu
::
backward
[
active_gate
],
hppl
::
cpu
::
backward
[
active_state
]);
gradIn
[
i
]
=
rGradIn
;
gradIg
[
i
]
=
rGradIg
;
gradFg
[
i
]
=
rGradFg
;
gradOg
[
i
]
=
rGradOg
;
grad
.
stateGrad
[
i
]
=
rStateGrad
;
if
(
grad
.
prevStateGrad
)
grad
.
prevStateGrad
[
i
]
=
rPrevStateGrad
;
if
(
value
.
prevStateValue
)
{
if
(
grad
.
checkIgGrad
)
grad
.
checkIgGrad
[
i
]
+=
rCheckIGrad
;
if
(
grad
.
checkFgGrad
)
grad
.
checkFgGrad
[
i
]
+=
rCheckFGrad
;
}
if
(
grad
.
checkOgGrad
)
grad
.
checkOgGrad
[
i
]
+=
rCheckOGrad
;
}
}
template
<
class
Op
>
void
avx_lstm_forward_one_sequence
(
Op
op
,
lstm_value
value
,
int
frameSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
#ifdef __AVX__
__m256
rValueIn
;
__m256
rValueIg
;
__m256
rValueFg
;
__m256
rValueOg
;
__m256
rCheckI
;
__m256
rCheckF
;
__m256
rCheckO
;
__m256
rState
;
__m256
rPrevState
=
_mm256_set1_ps
(
0.0
f
);
__m256
rStateAtv
;
__m256
rOut
;
__m256
*
valueIn
=
(
__m256
*
)
value
.
gateValue
;
__m256
*
valueIg
=
(
__m256
*
)(
value
.
gateValue
+
frameSize
);
__m256
*
valueFg
=
(
__m256
*
)(
value
.
gateValue
+
frameSize
*
2
);
__m256
*
valueOg
=
(
__m256
*
)(
value
.
gateValue
+
frameSize
*
3
);
for
(
int
i
=
0
;
i
<
frameSize
/
8
;
i
++
)
{
rValueIn
=
valueIn
[
i
];
rValueIg
=
valueIg
[
i
];
rValueFg
=
valueFg
[
i
];
rValueOg
=
valueOg
[
i
];
rCheckI
=
((
__m256
*
)
value
.
checkIg
)[
i
];
rCheckF
=
((
__m256
*
)
value
.
checkFg
)[
i
];
rCheckO
=
((
__m256
*
)
value
.
checkOg
)[
i
];
if
(
value
.
prevStateValue
)
{
rPrevState
=
((
__m256
*
)
value
.
prevStateValue
)[
i
];
}
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rPrevState
,
rState
,
rStateAtv
,
rOut
,
rCheckI
,
rCheckF
,
rCheckO
,
hppl
::
avx
::
forward
[
active_node
],
hppl
::
avx
::
forward
[
active_gate
],
hppl
::
avx
::
forward
[
active_state
]);
valueIn
[
i
]
=
rValueIn
;
valueIg
[
i
]
=
rValueIg
;
valueFg
[
i
]
=
rValueFg
;
valueOg
[
i
]
=
rValueOg
;
((
__m256
*
)
value
.
stateValue
)[
i
]
=
rState
;
((
__m256
*
)
value
.
stateActiveValue
)[
i
]
=
rStateAtv
;
((
__m256
*
)
value
.
outputValue
)[
i
]
=
rOut
;
}
#endif
}
template
<
class
Op
>
void
avx_lstm_backward_one_sequence
(
Op
op
,
lstm_value
value
,
lstm_grad
grad
,
int
frameSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
#ifdef __AVX__
__m256
rValueIn
;
__m256
rValueIg
;
__m256
rValueFg
;
__m256
rValueOg
;
__m256
rGradIn
;
__m256
rGradIg
;
__m256
rGradFg
;
__m256
rGradOg
;
__m256
rPrevState
=
_mm256_set1_ps
(
0.0
f
);
__m256
rPrevStateGrad
;
__m256
rStateGrad
;
__m256
rState
;
__m256
rStateAtv
;
__m256
rOutputGrad
;
__m256
rCheckI
;
__m256
rCheckF
;
__m256
rCheckO
;
__m256
rCheckIGrad
;
__m256
rCheckFGrad
;
__m256
rCheckOGrad
;
__m256
*
valueIn
=
(
__m256
*
)
value
.
gateValue
;
__m256
*
valueIg
=
(
__m256
*
)(
value
.
gateValue
+
frameSize
);
__m256
*
valueFg
=
(
__m256
*
)(
value
.
gateValue
+
frameSize
*
2
);
__m256
*
valueOg
=
(
__m256
*
)(
value
.
gateValue
+
frameSize
*
3
);
__m256
*
gradIn
=
(
__m256
*
)
grad
.
gateGrad
;
__m256
*
gradIg
=
(
__m256
*
)(
grad
.
gateGrad
+
frameSize
);
__m256
*
gradFg
=
(
__m256
*
)(
grad
.
gateGrad
+
frameSize
*
2
);
__m256
*
gradOg
=
(
__m256
*
)(
grad
.
gateGrad
+
frameSize
*
3
);
for
(
int
i
=
0
;
i
<
frameSize
/
8
;
i
++
)
{
rValueIn
=
valueIn
[
i
];
rValueIg
=
valueIg
[
i
];
rValueFg
=
valueFg
[
i
];
rValueOg
=
valueOg
[
i
];
rCheckI
=
((
__m256
*
)
value
.
checkIg
)[
i
];
rCheckF
=
((
__m256
*
)
value
.
checkFg
)[
i
];
rCheckO
=
((
__m256
*
)
value
.
checkOg
)[
i
];
rState
=
((
__m256
*
)
value
.
stateValue
)[
i
];
rStateAtv
=
((
__m256
*
)
value
.
stateActiveValue
)[
i
];
rOutputGrad
=
((
__m256
*
)
grad
.
outputGrad
)[
i
];
rStateGrad
=
((
__m256
*
)
grad
.
stateGrad
)[
i
];
if
(
value
.
prevStateValue
)
{
rPrevState
=
((
__m256
*
)
value
.
prevStateValue
)[
i
];
}
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rGradIn
,
rGradIg
,
rGradFg
,
rGradOg
,
rPrevState
,
rPrevStateGrad
,
rState
,
rStateGrad
,
rStateAtv
,
rOutputGrad
,
rCheckI
,
rCheckF
,
rCheckO
,
rCheckIGrad
,
rCheckFGrad
,
rCheckOGrad
,
hppl
::
avx
::
backward
[
active_node
],
hppl
::
avx
::
backward
[
active_gate
],
hppl
::
avx
::
backward
[
active_state
]);
gradIn
[
i
]
=
rGradIn
;
gradIg
[
i
]
=
rGradIg
;
gradFg
[
i
]
=
rGradFg
;
gradOg
[
i
]
=
rGradOg
;
((
__m256
*
)
grad
.
stateGrad
)[
i
]
=
rStateGrad
;
if
(
grad
.
prevStateGrad
)
((
__m256
*
)
grad
.
prevStateGrad
)[
i
]
=
rPrevStateGrad
;
if
(
value
.
prevStateValue
)
{
if
(
grad
.
checkIgGrad
)
((
__m256
*
)
grad
.
checkIgGrad
)[
i
]
+=
rCheckIGrad
;
if
(
grad
.
checkFgGrad
)
((
__m256
*
)
grad
.
checkFgGrad
)[
i
]
+=
rCheckFGrad
;
}
if
(
grad
.
checkOgGrad
)
((
__m256
*
)
grad
.
checkOgGrad
)[
i
]
+=
rCheckOGrad
;
}
#endif
}
template
<
class
T
,
class
Op
>
void
cpu_lstm_forward
(
Op
op
,
lstm_value
value
,
int
frameSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
if
(
Op
::
avx
&&
!
(
frameSize
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
avx_lstm_forward_one_sequence
(
op
,
value
,
frameSize
,
active_node
,
active_gate
,
active_state
);
}
else
{
naive_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frameSize
,
active_node
,
active_gate
,
active_state
);
}
}
template
<
class
T
,
class
Op
>
void
cpu_lstm_backward
(
Op
op
,
lstm_value
value
,
lstm_grad
grad
,
int
frameSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
if
(
Op
::
avx
&&
!
(
frameSize
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
avx_lstm_backward_one_sequence
(
op
,
value
,
grad
,
frameSize
,
active_node
,
active_gate
,
active_state
);
}
else
{
naive_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frameSize
,
active_node
,
active_gate
,
active_state
);
}
}
#endif
}
// namespace detail
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/detail/lstm_gpu_kernel.h
0 → 100644
浏览文件 @
3cace737
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/math/detail/lstm_kernel.h"
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/platform/cuda_helper.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
detail
{
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template
<
class
T
,
class
Op
,
bool
isBatch
>
__global__
void
KeLstmForward
(
Op
op
,
lstm_value
value
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
const
int
frameIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frameIdx
>=
frameSize
)
return
;
int
batchIdx
=
0
;
if
(
isBatch
)
{
batchIdx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
batchIdx
>=
batchSize
)
return
;
value
.
gateValue
+=
batchIdx
*
frameSize
*
4
;
value
.
outputValue
+=
batchIdx
*
frameSize
;
value
.
stateValue
+=
batchIdx
*
frameSize
;
value
.
stateActiveValue
+=
batchIdx
*
frameSize
;
}
T
rState
;
T
rPrevState
=
0
;
T
rStateAtv
;
T
rOut
;
T
rValueIn
;
T
rValueIg
;
T
rValueFg
;
T
rValueOg
;
T
rCheckI
=
value
.
checkIg
[
frameIdx
];
T
rCheckF
=
value
.
checkFg
[
frameIdx
];
T
rCheckO
=
value
.
checkOg
[
frameIdx
];
rValueIn
=
value
.
gateValue
[
frameIdx
];
rValueIg
=
value
.
gateValue
[
frameIdx
+
frameSize
];
rValueFg
=
value
.
gateValue
[
frameIdx
+
frameSize
*
2
];
rValueOg
=
value
.
gateValue
[
frameIdx
+
frameSize
*
3
];
if
(
value
.
prevStateValue
)
{
if
(
isBatch
)
value
.
prevStateValue
+=
batchIdx
*
frameSize
;
rPrevState
=
value
.
prevStateValue
[
frameIdx
];
}
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rPrevState
,
rState
,
rStateAtv
,
rOut
,
rCheckI
,
rCheckF
,
rCheckO
,
hppl
::
gpu
::
forward
[
active_node
],
hppl
::
gpu
::
forward
[
active_gate
],
hppl
::
gpu
::
forward
[
active_state
]);
value
.
gateValue
[
frameIdx
]
=
rValueIn
;
value
.
gateValue
[
frameIdx
+
frameSize
]
=
rValueIg
;
value
.
gateValue
[
frameIdx
+
frameSize
*
2
]
=
rValueFg
;
value
.
gateValue
[
frameIdx
+
frameSize
*
3
]
=
rValueOg
;
value
.
stateValue
[
frameIdx
]
=
rState
;
value
.
stateActiveValue
[
frameIdx
]
=
rStateAtv
;
value
.
outputValue
[
frameIdx
]
=
rOut
;
}
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template
<
class
T
,
class
Op
,
bool
isBatch
>
__global__
void
KeLstmBackward
(
Op
op
,
lstm_value
value
,
lstm_grad
grad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
const
int
frameIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frameIdx
>=
frameSize
)
return
;
int
batchIdx
=
0
;
if
(
isBatch
)
{
batchIdx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
batchIdx
>=
batchSize
)
return
;
value
.
gateValue
+=
batchIdx
*
frameSize
*
4
;
value
.
stateValue
+=
batchIdx
*
frameSize
;
value
.
stateActiveValue
+=
batchIdx
*
frameSize
;
grad
.
gateGrad
+=
batchIdx
*
frameSize
*
4
;
grad
.
stateGrad
+=
batchIdx
*
frameSize
;
grad
.
outputGrad
+=
batchIdx
*
frameSize
;
}
T
rValueIn
;
T
rValueIg
;
T
rValueFg
;
T
rValueOg
;
T
rGradIn
;
T
rGradIg
;
T
rGradFg
;
T
rGradOg
;
T
rPrevState
=
0
;
T
rPrevStateGrad
;
T
rState
;
T
rStateGrad
;
T
rStateAtv
;
T
rOutputGrad
;
T
rCheckI
=
value
.
checkIg
[
frameIdx
];
T
rCheckF
=
value
.
checkFg
[
frameIdx
];
T
rCheckO
=
value
.
checkOg
[
frameIdx
];
T
rCheckIGrad
;
T
rCheckFGrad
;
T
rCheckOGrad
;
rValueIn
=
value
.
gateValue
[
frameIdx
];
rValueIg
=
value
.
gateValue
[
frameIdx
+
frameSize
];
rValueFg
=
value
.
gateValue
[
frameIdx
+
frameSize
*
2
];
rValueOg
=
value
.
gateValue
[
frameIdx
+
frameSize
*
3
];
rState
=
value
.
stateValue
[
frameIdx
];
rStateAtv
=
value
.
stateActiveValue
[
frameIdx
];
rOutputGrad
=
grad
.
outputGrad
[
frameIdx
];
rStateGrad
=
grad
.
stateGrad
[
frameIdx
];
if
(
value
.
prevStateValue
)
{
if
(
isBatch
)
value
.
prevStateValue
+=
batchIdx
*
frameSize
;
rPrevState
=
value
.
prevStateValue
[
frameIdx
];
}
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rGradIn
,
rGradIg
,
rGradFg
,
rGradOg
,
rPrevState
,
rPrevStateGrad
,
rState
,
rStateGrad
,
rStateAtv
,
rOutputGrad
,
rCheckI
,
rCheckF
,
rCheckO
,
rCheckIGrad
,
rCheckFGrad
,
rCheckOGrad
,
hppl
::
gpu
::
backward
[
active_node
],
hppl
::
gpu
::
backward
[
active_gate
],
hppl
::
gpu
::
backward
[
active_state
]);
grad
.
gateGrad
[
frameIdx
]
=
rGradIn
;
grad
.
gateGrad
[
frameIdx
+
frameSize
]
=
rGradIg
;
grad
.
gateGrad
[
frameIdx
+
frameSize
*
2
]
=
rGradFg
;
grad
.
gateGrad
[
frameIdx
+
frameSize
*
3
]
=
rGradOg
;
grad
.
stateGrad
[
frameIdx
]
=
rStateGrad
;
if
(
grad
.
prevStateGrad
)
{
if
(
isBatch
)
grad
.
prevStateGrad
+=
batchIdx
*
frameSize
;
grad
.
prevStateGrad
[
frameIdx
]
=
rPrevStateGrad
;
}
if
(
isBatch
)
{
if
(
value
.
prevStateValue
)
{
if
(
grad
.
checkIgGrad
)
paddle
::
platform
::
CudaAtomicAdd
(
grad
.
checkIgGrad
+
frameIdx
,
rCheckIGrad
);
if
(
grad
.
checkFgGrad
)
paddle
::
platform
::
CudaAtomicAdd
(
grad
.
checkFgGrad
+
frameIdx
,
rCheckFGrad
);
}
if
(
grad
.
checkOgGrad
)
paddle
::
platform
::
CudaAtomicAdd
(
grad
.
checkOgGrad
+
frameIdx
,
rCheckOGrad
);
}
else
{
if
(
value
.
prevStateValue
)
{
if
(
grad
.
checkIgGrad
)
grad
.
checkIgGrad
[
frameIdx
]
+=
rCheckIGrad
;
if
(
grad
.
checkFgGrad
)
grad
.
checkFgGrad
[
frameIdx
]
+=
rCheckFGrad
;
}
if
(
grad
.
checkOgGrad
)
grad
.
checkOgGrad
[
frameIdx
]
+=
rCheckOGrad
;
}
}
template
<
class
T
,
class
Op
>
void
gpu_lstm_forward
(
Op
op
,
lstm_value
value
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
dim3
threads
;
dim3
grid
;
if
(
batchSize
==
1
)
{
int
framePerBlock
=
frameSize
<=
1024
?
frameSize
:
1024
;
int
frameBlocks
=
(
frameSize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
framePerBlock
,
1
);
grid
=
dim3
(
frameBlocks
,
1
);
}
else
{
/* framePerBlock = 32 batchPerBlock = 32 */
threads
=
dim3
(
32
,
32
);
grid
=
dim3
((
frameSize
+
32
-
1
)
/
32
,
(
batchSize
+
32
-
1
)
/
32
);
}
if
(
batchSize
==
1
)
{
KeLstmForward
<
T
,
Op
,
/* isBatch= */
false
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
op
,
value
,
frameSize
,
batchSize
,
active_node
,
active_gate
,
active_state
);
}
else
{
KeLstmForward
<
T
,
Op
,
/* isBatch= */
true
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
op
,
value
,
frameSize
,
batchSize
,
active_node
,
active_gate
,
active_state
);
}
}
template
<
class
T
,
class
Op
>
void
gpu_lstm_backward
(
Op
op
,
lstm_value
value
,
lstm_grad
grad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
dim3
threads
;
dim3
grid
;
if
(
batchSize
==
1
)
{
int
framePerBlock
=
frameSize
<=
1024
?
frameSize
:
1024
;
int
frameBlocks
=
(
frameSize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
framePerBlock
,
1
);
grid
=
dim3
(
frameBlocks
,
1
);
}
else
{
/* framePerBlock = 32 batchPerBlock = 32 */
threads
=
dim3
(
32
,
32
);
grid
=
dim3
((
frameSize
+
32
-
1
)
/
32
,
(
batchSize
+
32
-
1
)
/
32
);
}
if
(
batchSize
==
1
)
{
KeLstmBackward
<
T
,
Op
,
/* isBatch= */
false
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
op
,
value
,
grad
,
frameSize
,
batchSize
,
active_node
,
active_gate
,
active_state
);
}
else
{
KeLstmBackward
<
T
,
Op
,
/* isBatch= */
true
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
op
,
value
,
grad
,
frameSize
,
batchSize
,
active_node
,
active_gate
,
active_state
);
}
}
}
// namespace detail
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/detail/lstm_kernel.h
0 → 100644
浏览文件 @
3cace737
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_activation_functions.h"
#ifdef __CUDA_ARCH__
#define INLINE __device__ inline
#else
#define INLINE inline
#endif
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
detail
{
namespace
forward
{
template
<
class
T
>
class
lstm
{
public:
INLINE
void
operator
()(
T
&
valueIn
,
T
&
valueIg
,
T
&
valueFg
,
T
&
valueOg
,
T
&
prevState
,
T
&
state
,
T
&
stateAtv
,
T
&
output
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
Active
<
T
>::
forward
actInput
,
Active
<
T
>::
forward
actGate
,
Active
<
T
>::
forward
actState
)
{
valueIn
=
actInput
(
valueIn
);
valueIg
=
actGate
(
valueIg
+
prevState
*
checkI
);
valueFg
=
actGate
(
valueFg
+
prevState
*
checkF
);
state
=
valueIn
*
valueIg
+
prevState
*
valueFg
;
valueOg
=
actGate
(
valueOg
+
state
*
checkO
);
stateAtv
=
actState
(
state
);
output
=
valueOg
*
stateAtv
;
}
#ifndef __NVCC__
#ifndef __AVX__
static
const
bool
avx
=
false
;
#else
static
const
bool
avx
=
true
;
INLINE
void
operator
()(
__m256
&
valueIn
,
__m256
&
valueIg
,
__m256
&
valueFg
,
__m256
&
valueOg
,
__m256
&
prevState
,
__m256
&
state
,
__m256
&
stateAtv
,
__m256
&
output
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
Active
<
__m256
>::
forward
actInput
,
Active
<
__m256
>::
forward
actGate
,
Active
<
__m256
>::
forward
actState
)
{
valueIn
=
actInput
(
valueIn
);
valueIg
=
actGate
(
_mm256_add_ps
(
valueIg
,
_mm256_mul_ps
(
prevState
,
checkI
)));
valueFg
=
actGate
(
_mm256_add_ps
(
valueFg
,
_mm256_mul_ps
(
prevState
,
checkF
)));
state
=
_mm256_add_ps
(
_mm256_mul_ps
(
valueIn
,
valueIg
),
_mm256_mul_ps
(
prevState
,
valueFg
));
valueOg
=
actGate
(
_mm256_add_ps
(
valueOg
,
_mm256_mul_ps
(
state
,
checkO
)));
stateAtv
=
actState
(
state
);
output
=
_mm256_mul_ps
(
valueOg
,
stateAtv
);
}
#endif
#endif
};
}
// namespace forward
namespace
backward
{
template
<
class
T
>
class
lstm
{
public:
INLINE
void
operator
()(
T
&
valueIn
,
T
&
valueIg
,
T
&
valueFg
,
T
&
valueOg
,
T
&
gradIn
,
T
&
gradIg
,
T
&
gradFg
,
T
&
gradOg
,
T
&
prevState
,
T
&
prevStateGrad
,
T
&
state
,
T
&
stateGrad
,
T
&
stateAtv
,
T
&
outputGrad
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
T
&
checkIGrad
,
T
&
checkFGrad
,
T
&
checkOGrad
,
Active
<
T
>::
backward
actInput
,
Active
<
T
>::
backward
actGate
,
Active
<
T
>::
backward
actState
)
{
gradOg
=
actGate
(
outputGrad
*
stateAtv
,
valueOg
);
stateGrad
+=
actState
(
outputGrad
*
valueOg
,
stateAtv
)
+
gradOg
*
checkO
;
gradIn
=
actInput
(
stateGrad
*
valueIg
,
valueIn
);
gradIg
=
actGate
(
stateGrad
*
valueIn
,
valueIg
);
gradFg
=
actGate
(
stateGrad
*
prevState
,
valueFg
);
prevStateGrad
=
gradIg
*
checkI
+
gradFg
*
checkF
+
stateGrad
*
valueFg
;
checkIGrad
=
gradIg
*
prevState
;
checkFGrad
=
gradFg
*
prevState
;
checkOGrad
=
gradOg
*
state
;
}
#ifndef __NVCC__
#ifndef __AVX__
static
const
bool
avx
=
false
;
#else
static
const
bool
avx
=
true
;
INLINE
void
operator
()(
__m256
&
valueIn
,
__m256
&
valueIg
,
__m256
&
valueFg
,
__m256
&
valueOg
,
__m256
&
gradIn
,
__m256
&
gradIg
,
__m256
&
gradFg
,
__m256
&
gradOg
,
__m256
&
prevState
,
__m256
&
prevStateGrad
,
__m256
&
state
,
__m256
&
stateGrad
,
__m256
&
stateAtv
,
__m256
&
outputGrad
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
__m256
&
checkIGrad
,
__m256
&
checkFGrad
,
__m256
&
checkOGrad
,
Active
<
__m256
>::
backward
actInput
,
Active
<
__m256
>::
backward
actGate
,
Active
<
__m256
>::
backward
actState
)
{
gradOg
=
actGate
(
_mm256_mul_ps
(
outputGrad
,
stateAtv
),
valueOg
);
stateGrad
=
_mm256_add_ps
(
actState
(
_mm256_mul_ps
(
outputGrad
,
valueOg
),
stateAtv
),
stateGrad
);
stateGrad
=
_mm256_add_ps
(
_mm256_mul_ps
(
gradOg
,
checkO
),
stateGrad
);
gradIn
=
actInput
(
_mm256_mul_ps
(
stateGrad
,
valueIg
),
valueIn
);
gradIg
=
actGate
(
_mm256_mul_ps
(
stateGrad
,
valueIn
),
valueIg
);
gradFg
=
actGate
(
_mm256_mul_ps
(
stateGrad
,
prevState
),
valueFg
);
prevStateGrad
=
_mm256_add_ps
(
_mm256_mul_ps
(
gradIg
,
checkI
),
_mm256_mul_ps
(
gradFg
,
checkF
));
prevStateGrad
=
_mm256_add_ps
(
_mm256_mul_ps
(
stateGrad
,
valueFg
),
prevStateGrad
);
checkIGrad
=
_mm256_mul_ps
(
gradIg
,
prevState
);
checkFGrad
=
_mm256_mul_ps
(
gradFg
,
prevState
);
checkOGrad
=
_mm256_mul_ps
(
gradOg
,
state
);
}
#endif
#endif
};
}
// namespace backward
}
// namespace detail
}
// namespace math
}
// namespace operators
}
// namespace paddle
#endif
/* HL_LSTM_OPS_CUH_ */
paddle/operators/math/lstm_compute.cc
0 → 100644
浏览文件 @
3cace737
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "LstmCompute.h"
#include "paddle/operators/math/detail/lstm_cpu_kernel.h"
#include "paddle/operators/math/detail/lstm_kernel.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
class
T
>
struct
LstmUnitFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
lstm_value
value
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
detail
::
cpu_lstm_forward
(
detail
::
forward
::
lstm
<
T
>
(),
value
,
frameSize
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
value
.
gateValue
+=
frameSize
*
4
;
value
.
stateValue
+=
frameSize
;
value
.
stateActiveValue
+=
frameSize
;
value
.
outputValue
+=
frameSize
;
if
(
value
.
prevStateValue
)
{
value
.
prevStateValue
+=
frameSize
;
}
}
}
};
template
<
class
T
>
struct
LstmUnitGradFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
lstm_value
value
,
lstm_grad
grad
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
)
{
for
(
int
b
=
0
;
b
<
batchSize
;
b
++
)
{
detail
::
cpu_lstm_backward
(
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
frameSize
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
value
.
gateValue
+=
frameSize
*
4
;
value
.
stateValue
+=
frameSize
;
value
.
stateActiveValue
+=
frameSize
;
value
.
outputValue
+=
frameSize
;
if
(
value
.
prevStateValue
)
{
value
.
prevStateValue
+=
frameSize
;
}
grad
.
gateGrad
+=
frameSize
*
4
;
grad
.
stateGrad
+=
frameSize
;
grad
.
stateActiveGrad
+=
frameSize
;
grad
.
outputGrad
+=
frameSize
;
if
(
grad
.
prevStateGrad
)
{
grad
.
prevStateGrad
+=
frameSize
;
}
}
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/lstm_compute.cu
0 → 100644
浏览文件 @
3cace737
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "LstmCompute.h"
#include "paddle/operators/math/detail/lstm_cpu_kernel.h"
#include "paddle/operators/math/detail/lstm_kernel.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
class
T
>
struct
LstmUnitFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
lstm_value
value
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
detail
::
gpu_lstm_forward
(
detail
::
forward
::
lstm
<
T
>
(),
value
,
frameSize
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
value
.
gateValue
+=
frameSize
*
4
;
value
.
stateValue
+=
frameSize
;
value
.
stateActiveValue
+=
frameSize
;
value
.
outputValue
+=
frameSize
;
if
(
value
.
prevStateValue
)
{
value
.
prevStateValue
+=
frameSize
;
}
}
}
};
template
<
class
T
>
struct
LstmUnitGradFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
lstm_value
value
,
lstm_grad
grad
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
)
{
for
(
int
b
=
0
;
b
<
batchSize
;
b
++
)
{
detail
::
gpu_lstm_backward
(
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
frameSize
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
value
.
gateValue
+=
frameSize
*
4
;
value
.
stateValue
+=
frameSize
;
value
.
stateActiveValue
+=
frameSize
;
value
.
outputValue
+=
frameSize
;
if
(
value
.
prevStateValue
)
{
value
.
prevStateValue
+=
frameSize
;
}
grad
.
gateGrad
+=
frameSize
*
4
;
grad
.
stateGrad
+=
frameSize
;
grad
.
stateActiveGrad
+=
frameSize
;
grad
.
outputGrad
+=
frameSize
;
if
(
grad
.
prevStateGrad
)
{
grad
.
prevStateGrad
+=
frameSize
;
}
}
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/lstm_compute.h
0 → 100644
浏览文件 @
3cace737
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/macros.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
typedef
enum
{
HL_ACTIVATION_SIGMOID
=
0
,
HL_ACTIVATION_RELU
=
1
,
HL_ACTIVATION_TANH
=
2
,
HL_ACTIVATION_LINEAR
=
3
,
HL_ACTIVATION_END
}
activation_mode_t
;
template
<
T
>
struct
lstm_value
{
real
*
gateValue
;
real
*
prevStateValue
;
real
*
stateValue
;
real
*
stateActiveValue
;
real
*
outputValue
;
real
*
checkIg
;
real
*
checkFg
;
real
*
checkOg
;
};
template
<
T
>
struct
lstm_grad
{
real
*
gateGrad
;
real
*
prevStateGrad
;
real
*
stateGrad
;
real
*
stateActiveGrad
;
real
*
outputGrad
;
real
*
checkIgGrad
;
real
*
checkFgGrad
;
real
*
checkOgGrad
;
};
activation_mode_t
ActiveType
(
const
std
::
string
&
type
)
{
if
(
type
==
"sigmoid"
)
{
return
HL_ACTIVATION_SIGMOID
;
}
else
if
(
type
==
"relu"
)
{
return
HL_ACTIVATION_RELU
;
}
else
if
(
type
==
"tanh"
)
{
return
HL_ACTIVATION_TANH
;
}
else
if
(
type
==
"linear"
||
type
==
""
)
{
return
HL_ACTIVATION_LINEAR
;
}
else
{
PADDLE_THROW
(
"Do not support activation type."
);
}
}
template
<
typename
Place
,
typename
T
>
class
LstmUnitFunctor
{
public:
static
void
compute
(
lstm_value
value
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
);
};
template
<
typename
Place
,
typename
T
>
class
LstmUnitGradFunctor
{
public:
static
void
compute
(
lstm_value
value
,
lstm_grad
grad
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
);
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/sequence2batch.cc
浏览文件 @
3cace737
...
...
@@ -18,6 +18,37 @@ namespace paddle {
namespace
operators
{
namespace
math
{
template
<
typename
T
>
class
CopyMatrixRowsFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
src
,
const
size_t
*
index
,
framework
::
Tensor
&
dst
,
bool
is_src_index
)
{
auto
src_dims
=
src
.
dims
();
auto
dst_dims
=
dst
.
dims
();
PADDLE_ENFORCE
(
src_dims
.
size
(),
2
,
"The src must be matrix with rank 2."
);
PADDLE_ENFORCE
(
dst_dims
.
size
(),
2
,
"The dst must be matrix with rank 2."
);
PADDLE_ENFORCE_EQ
(
src_dims
[
1
],
dst_dims
[
1
],
"The width of src and dst must be same."
);
auto
height
=
dst_dims
[
0
];
auto
width
=
dst_dims
[
1
];
auto
*
src_data
=
src
.
data
<
T
>
();
auto
*
dst_data
=
dst
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
height
;
++
i
)
{
if
(
is_src_index
)
{
memcpy
(
dst_data
+
i
*
width
,
src_data
+
index
[
i
]
*
width
,
width
*
sizeof
(
T
));
}
else
{
memcpy
(
dst_data
+
index
[
i
]
*
width
,
src_data
+
i
*
width
,
width
*
sizeof
(
T
));
}
}
}
};
template
class
CopyMatrixRowsFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
CopyMatrixRowsFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
LoDTensor2BatchFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
Batch2LoDTensor2Functor
<
platform
::
CPUPlace
,
float
>;
...
...
paddle/operators/math/sequence2batch.cu
浏览文件 @
3cace737
...
...
@@ -18,6 +18,53 @@ namespace paddle {
namespace
operators
{
namespace
math
{
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
>
__global__
void
CopyMatrixRowsKernel
(
const
T
*
src
,
T
*
dst
,
const
int
*
index
,
int
height
,
int
width
,
const
bool
is_src_index
)
{
int
idx
=
threadIdx
.
x
;
int
idy
=
threadIdx
.
y
;
int
id
=
blockIdx
.
x
+
idy
*
GridDimX
;
while
(
id
<
height
)
{
int
src_idx
=
is_src_index
?
index
[
id
]
:
id
;
int
dst_idx
=
is_src_index
?
id
:
index
[
id
];
T
*
src_data
=
src
+
src_idx
*
width
;
T
*
dst_data
=
dst
+
dst_idx
*
width
;
for
(
int
i
=
idx
;
i
<
width
;
i
+=
BlockDimX
)
{
dst_data
[
i
]
=
src_data
[
i
];
}
id
+=
BlockDimY
*
GridDimX
;
}
}
template
<
typename
T
>
class
CopyMatrixRowsFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
src
,
const
size_t
*
index
,
framework
::
Tensor
&
dst
,
bool
is_src_index
)
{
auto
src_dims
=
src
.
dims
();
auto
dst_dims
=
dst
.
dims
();
PADDLE_ENFORCE
(
src_dims
.
size
(),
2
,
"The src must be matrix with rank 2."
);
PADDLE_ENFORCE
(
dst_dims
.
size
(),
2
,
"The dst must be matrix with rank 2."
);
PADDLE_ENFORCE_EQ
(
src_dims
[
1
],
dst_dims
[
1
],
"The width of src and dst must be same."
);
auto
height
=
dst_dims
[
0
];
auto
width
=
dst_dims
[
1
];
auto
*
src_data
=
src
.
data
<
T
>
();
auto
*
dst_data
=
dst
.
data
<
T
>
();
dim3
threads
(
128
,
8
);
dim3
grid
(
8
,
1
);
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
);
CopyMatrixRowsKernel
<
T
,
128
,
8
,
8
><<<
grid
,
threads
,
0
,
stream
>>>
(
src_data
,
dst_data
,
index
,
height
,
width
);
}
};
template
class
CopyMatrixRowsFunctor
<
platform
::
GPUPlace
,
float
>;
template
class
CopyMatrixRowsFunctor
<
platform
::
GPUPlace
,
double
>;
template
class
LoDTensor2BatchFunctor
<
platform
::
GPUPlace
,
float
>;
template
class
Batch2LoDTensor2Functor
<
platform
::
GPUPlace
,
float
>;
...
...
paddle/operators/math/sequence2batch.h
浏览文件 @
3cace737
...
...
@@ -16,6 +16,19 @@ namespace paddle {
namespace
operators
{
namespace
math
{
template
<
typename
Place
,
typename
T
>
class
CopyMatrixRowsFunctor
{
public:
// If is_src_index is true,
// copy the indexed rows of input src to the output dst.
// If is_src_index is false,
// copy the input src to the indexed rows of output dst.
// The indexed rows are based on the input index.
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
src
,
const
size_t
*
index
,
framework
::
Tensor
&
dst
,
const
bool
is_src_index
);
};
template
<
typename
Place
,
typename
T
>
class
LoDTensor2BatchFunctor
{
public:
...
...
@@ -97,8 +110,11 @@ class LoDTensor2BatchFunctor {
}
batch_starts
[
n
+
1
]
=
batch_id
;
}
CopyMatrixRowsFunctor
<
Place
,
T
>
to_batch
;
to_batch
(
context
,
lod_tensor
,
batch
,
true
);
}
}
}
;
template
<
typename
Place
,
typename
T
>
class
Batch2LoDTensor2Functor
{
...
...
@@ -107,6 +123,7 @@ class Batch2LoDTensor2Functor {
const
framework
::
LoDTensor
&
batch
,
framework
::
LoDTensor
&
lod_tensor
,
const
bool
is_reverse
)
const
;
};
}
// namespace math
}
// namespace operators
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录