Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1c8a0c4b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1c8a0c4b
编写于
10月 31, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine activation function pointer for LSTM operator.
上级
2c5d4c6d
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
279 addition
and
590 deletion
+279
-590
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+2
-1
paddle/operators/math/detail/CMakeLists.txt
paddle/operators/math/detail/CMakeLists.txt
+1
-3
paddle/operators/math/detail/activation_functions.h
paddle/operators/math/detail/activation_functions.h
+170
-0
paddle/operators/math/detail/avx_functions.cc
paddle/operators/math/detail/avx_functions.cc
+19
-3
paddle/operators/math/detail/hl_activation_functions.h
paddle/operators/math/detail/hl_activation_functions.h
+0
-188
paddle/operators/math/detail/hl_avx_functions.h
paddle/operators/math/detail/hl_avx_functions.h
+0
-32
paddle/operators/math/detail/hl_cpu_functions.cc
paddle/operators/math/detail/hl_cpu_functions.cc
+0
-89
paddle/operators/math/detail/hl_functions.h
paddle/operators/math/detail/hl_functions.h
+0
-71
paddle/operators/math/detail/hl_gpu_functions.h
paddle/operators/math/detail/hl_gpu_functions.h
+0
-93
paddle/operators/math/detail/lstm_cpu_kernel.h
paddle/operators/math/detail/lstm_cpu_kernel.h
+17
-11
paddle/operators/math/detail/lstm_gpu_kernel.h
paddle/operators/math/detail/lstm_gpu_kernel.h
+19
-11
paddle/operators/math/detail/lstm_kernel.h
paddle/operators/math/detail/lstm_kernel.h
+49
-86
python/paddle/v2/framework/tests/test_lstm_op.py
python/paddle/v2/framework/tests/test_lstm_op.py
+2
-2
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
1c8a0c4b
...
...
@@ -20,7 +20,8 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_library
(
attribute SRCS attribute.cc DEPS framework_proto
)
cc_test
(
program_desc_test SRCS program_desc_test.cc DEPS proto_desc
)
cc_test
(
program_desc_test SRCS program_desc_test.cc DEPS proto_desc
device_context
)
cc_library
(
op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute
)
cc_test
(
op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker
)
cc_library
(
op_info SRCS op_info.cc DEPS attribute framework_proto
)
...
...
paddle/operators/math/detail/CMakeLists.txt
浏览文件 @
1c8a0c4b
if
(
WITH_AVX
)
cc_library
(
activation_functions SRCS hl_cpu_functions.cc hl_avx_functions.cc
)
else
()
cc_library
(
activation_functions SRCS hl_cpu_functions.cc
)
cc_library
(
activation_functions SRCS avx_functions.cc
)
endif
()
paddle/operators/math/detail/activation_functions.h
0 → 100644
浏览文件 @
1c8a0c4b
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include "paddle/platform/hostdevice.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
detail
{
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
namespace
forward
{
template
<
typename
T
>
DEVICE
T
linear
(
const
T
a
)
{
return
a
;
}
template
<
typename
T
>
DEVICE
T
relu
(
const
T
a
)
{
return
a
>
static_cast
<
T
>
(
0.0
)
?
a
:
static_cast
<
T
>
(
0.0
);
}
template
<
typename
T
>
DEVICE
T
sigmoid
(
const
T
a
)
{
const
T
min
=
SIGMOID_THRESHOLD_MIN
;
const
T
max
=
SIGMOID_THRESHOLD_MAX
;
T
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
static_cast
<
T
>
(
1.0
)
/
(
static_cast
<
T
>
(
1.0
)
+
exp
(
-
tmp
));
}
template
<
typename
T
>
DEVICE
T
tanh
(
const
T
a
)
{
T
tmp
=
-
2.0
*
a
;
tmp
=
(
tmp
>
EXP_MAX_INPUT
)
?
EXP_MAX_INPUT
:
tmp
;
return
(
2.0
/
(
1.0
+
exp
(
tmp
)))
-
1.0
;
}
}
// namespace forward
namespace
backward
{
template
<
typename
T
>
DEVICE
T
linear
(
const
T
a
,
const
T
b
)
{
return
a
;
}
template
<
typename
T
>
DEVICE
T
relu
(
const
T
a
,
const
T
b
)
{
return
a
*
(
b
>
0.0
?
1.0
:
0.0
);
}
template
<
typename
T
>
DEVICE
T
sigmoid
(
const
T
a
,
const
T
b
)
{
return
a
*
b
*
(
1.0
-
b
);
}
template
<
typename
T
>
DEVICE
T
tanh
(
const
T
a
,
const
T
b
)
{
return
a
*
(
1.0
-
b
*
b
);
}
}
// namespace backward
template
<
typename
T
>
struct
Active
{
typedef
T
(
*
Act
)(
T
);
typedef
T
(
*
ActGrad
)(
T
,
T
);
};
static
DEVICE
Active
<
float
>::
Act
kActFloat
[]
=
{
&
forward
::
sigmoid
<
float
>
,
&
forward
::
relu
<
float
>
,
&
forward
::
tanh
<
float
>
,
&
forward
::
linear
<
float
>
};
static
DEVICE
Active
<
float
>::
ActGrad
kActGradFloat
[]
=
{
&
backward
::
sigmoid
<
float
>
,
&
backward
::
relu
<
float
>
,
&
backward
::
tanh
<
float
>
,
&
backward
::
linear
<
float
>
};
static
DEVICE
Active
<
double
>::
Act
kActDouble
[]
=
{
&
forward
::
sigmoid
<
double
>
,
&
forward
::
relu
<
double
>
,
&
forward
::
tanh
<
double
>
,
&
forward
::
linear
<
double
>
};
static
DEVICE
Active
<
double
>::
ActGrad
kActGradDouble
[]
=
{
&
backward
::
sigmoid
<
double
>
,
&
backward
::
relu
<
double
>
,
&
backward
::
tanh
<
double
>
,
&
backward
::
linear
<
double
>
};
namespace
forward
{
inline
DEVICE
float
activation
(
float
a
,
int
index
)
{
return
kActFloat
[
index
](
a
);
}
inline
DEVICE
double
activation
(
double
a
,
int
index
)
{
return
kActDouble
[
index
](
a
);
}
}
// namespace forward
namespace
backward
{
inline
DEVICE
float
activation
(
float
a
,
float
b
,
int
index
)
{
return
kActGradFloat
[
index
](
a
,
b
);
}
inline
DEVICE
double
activation
(
double
a
,
double
b
,
int
index
)
{
return
kActGradDouble
[
index
](
a
,
b
);
}
}
// namespace backward
#ifdef __AVX__
namespace
forward
{
namespace
avx
{
__m256
relu
(
const
__m256
a
);
__m256
sigmoid
(
const
__m256
a
);
__m256
tanh
(
const
__m256
a
);
__m256
linear
(
const
__m256
a
);
}
// namespace avx
}
// namespace forward
namespace
backward
{
namespace
avx
{
__m256
relu
(
const
__m256
a
,
const
__m256
b
);
__m256
sigmoid
(
const
__m256
a
,
const
__m256
b
);
__m256
tanh
(
const
__m256
a
,
const
__m256
b
);
__m256
linear
(
const
__m256
a
,
const
__m256
b
);
}
// namespace avx
}
// namespace backward
static
Active
<
__m256
>::
Act
kActAvx
[]
=
{
&
forward
::
avx
::
sigmoid
,
&
forward
::
avx
::
relu
,
&
forward
::
avx
::
tanh
,
&
forward
::
avx
::
linear
};
static
Active
<
__m256
>::
ActGrad
kActGradAvx
[]
=
{
&
backward
::
avx
::
sigmoid
,
&
backward
::
avx
::
relu
,
&
backward
::
avx
::
tanh
,
&
backward
::
avx
::
linear
};
namespace
forward
{
inline
__m256
activation
(
__m256
a
,
int
index
)
{
return
kActAvx
[
index
](
a
);
}
}
// namespace forward
namespace
backward
{
inline
__m256
activation
(
__m256
a
,
__m256
b
,
int
index
)
{
return
kActGradAvx
[
index
](
a
,
b
);
}
}
// namespace backward
#endif
}
// namespace detail
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/detail/
hl_
avx_functions.cc
→
paddle/operators/math/detail/avx_functions.cc
浏览文件 @
1c8a0c4b
...
...
@@ -13,14 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <immintrin.h>
#include "
hl
_functions.h"
#include "
paddle/operators/math/detail/activation
_functions.h"
// TODO(qingqing) refine this dependence
#include "paddle/cuda/src/avx_mathfun.h"
namespace
hppl
{
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
detail
{
__m256
exp
(
__m256
a
)
{
return
exp256_ps
(
a
);
}
namespace
forward
{
namespace
avx
{
__m256
relu
(
const
__m256
a
)
{
__m256
tmp
=
_mm256_set1_ps
(
0.0
f
);
return
_mm256_max_ps
(
a
,
tmp
);
...
...
@@ -50,6 +55,11 @@ __m256 tanh(const __m256 a) {
__m256
linear
(
const
__m256
a
)
{
return
a
;
}
}
// namespace avx
}
// namespace forward
namespace
backward
{
namespace
avx
{
__m256
relu
(
const
__m256
a
,
const
__m256
b
)
{
return
_mm256_mul_ps
(
a
,
_mm256_and_ps
(
_mm256_cmp_ps
(
b
,
_mm256_set1_ps
(
0.0
f
),
_CMP_GT_OS
),
...
...
@@ -67,4 +77,10 @@ __m256 tanh(const __m256 a, const __m256 b) {
}
__m256
linear
(
const
__m256
a
,
const
__m256
b
)
{
return
a
;
}
}
// namespace hppl
}
// namespace avx
}
// namespace backward
}
// namespace detail
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/detail/hl_activation_functions.h
已删除
100644 → 0
浏览文件 @
2c5d4c6d
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_ACTIVATION_FUNCTIONS_H_
#define HL_ACTIVATION_FUNCTIONS_H_
#include "hl_functions.h"
#include "paddle/operators/math/lstm_compute.h"
/**
* Active functions: sigmoid, relu, tanh and linear.
*/
#define FLOAT_ACTIVE_FUNCTION \
{ \
hppl::typef::sigmoid, hppl::typef::relu, hppl::typef::tanh, \
hppl::typef::linear \
}
#define DOUBLE_ACTIVE_FUNCTION \
{ \
hppl::typed::sigmoid, hppl::typed::relu, hppl::typed::tanh, \
hppl::typed::linear \
}
#define AVX_ACTIVE_FUNCTION \
{ hppl::sigmoid, hppl::relu, hppl::tanh, hppl::linear }
namespace
hppl
{
using
activation_mode_t
=
paddle
::
operators
::
math
::
activation_mode_t
;
/**
* Hppl supports sigmoid, relu, tanh, linear active functions
* for neural networks' forward and backward activation.
*/
template
<
class
T
>
class
Active
{
public:
typedef
T
(
*
forward
)(
T
);
typedef
T
(
*
backward
)(
T
,
T
);
};
template
<
typename
T
>
struct
ForwardActType
;
template
<
>
struct
ForwardActType
<
float
>
{
using
type
=
Active
<
float
>::
forward
;
};
template
<
>
struct
ForwardActType
<
double
>
{
using
type
=
Active
<
double
>::
forward
;
};
template
<
typename
T
>
struct
BackwardActType
;
template
<
>
struct
BackwardActType
<
float
>
{
using
type
=
Active
<
float
>::
backward
;
};
template
<
>
struct
BackwardActType
<
double
>
{
using
type
=
Active
<
double
>::
backward
;
};
#ifdef __NVCC__
namespace
gpu
{
static
__device__
Active
<
float
>::
forward
forward
[]
=
FLOAT_ACTIVE_FUNCTION
;
static
__device__
Active
<
float
>::
backward
backward
[]
=
FLOAT_ACTIVE_FUNCTION
;
static
__device__
Active
<
double
>::
forward
forward_d
[]
=
DOUBLE_ACTIVE_FUNCTION
;
static
__device__
Active
<
double
>::
backward
backward_d
[]
=
DOUBLE_ACTIVE_FUNCTION
;
template
<
typename
T
>
struct
ForwardAct
{
__device__
typename
ForwardActType
<
T
>::
type
operator
()(
activation_mode_t
type
);
};
template
<
>
struct
ForwardAct
<
float
>
{
__device__
ForwardActType
<
float
>::
type
operator
()(
activation_mode_t
type
)
{
return
forward
[
type
];
}
};
template
<
>
struct
ForwardAct
<
double
>
{
__device__
ForwardActType
<
double
>::
type
operator
()(
activation_mode_t
type
)
{
return
forward_d
[
type
];
}
};
template
<
typename
T
>
struct
BackwardAct
{
__device__
typename
BackwardActType
<
T
>::
type
operator
()(
activation_mode_t
type
);
};
template
<
>
struct
BackwardAct
<
float
>
{
__device__
BackwardActType
<
float
>::
type
operator
()(
activation_mode_t
type
)
{
return
backward
[
type
];
}
};
template
<
>
struct
BackwardAct
<
double
>
{
__device__
BackwardActType
<
double
>::
type
operator
()(
activation_mode_t
type
)
{
return
backward_d
[
type
];
}
};
}
// namespace gpu
#else
namespace
cpu
{
static
Active
<
float
>::
forward
forward
[]
=
FLOAT_ACTIVE_FUNCTION
;
static
Active
<
float
>::
backward
backward
[]
=
FLOAT_ACTIVE_FUNCTION
;
static
Active
<
double
>::
forward
forward_d
[]
=
DOUBLE_ACTIVE_FUNCTION
;
static
Active
<
double
>::
backward
backward_d
[]
=
DOUBLE_ACTIVE_FUNCTION
;
template
<
typename
T
>
struct
ForwardAct
{
typename
ForwardActType
<
T
>::
type
operator
()(
activation_mode_t
type
);
};
template
<
>
struct
ForwardAct
<
float
>
{
ForwardActType
<
float
>::
type
operator
()(
activation_mode_t
type
)
{
return
forward
[
type
];
}
};
template
<
>
struct
ForwardAct
<
double
>
{
ForwardActType
<
double
>::
type
operator
()(
activation_mode_t
type
)
{
return
forward_d
[
type
];
}
};
template
<
typename
T
>
struct
BackwardAct
{
typename
BackwardActType
<
T
>::
type
operator
()(
activation_mode_t
type
);
};
template
<
>
struct
BackwardAct
<
float
>
{
BackwardActType
<
float
>::
type
operator
()(
activation_mode_t
type
)
{
return
backward
[
type
];
}
};
template
<
>
struct
BackwardAct
<
double
>
{
BackwardActType
<
double
>::
type
operator
()(
activation_mode_t
type
)
{
return
backward_d
[
type
];
}
};
}
// namespace cpu
#ifdef __AVX__
namespace
avx
{
static
Active
<
__m256
>::
forward
forward
[]
=
AVX_ACTIVE_FUNCTION
;
static
Active
<
__m256
>::
backward
backward
[]
=
AVX_ACTIVE_FUNCTION
;
}
// namespace avx
#endif
#endif
}
// namespace hppl
#endif // HL_ACTIVATION_FUNCTIONS_H_
paddle/operators/math/detail/hl_avx_functions.h
已删除
100644 → 0
浏览文件 @
2c5d4c6d
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_AVX_FUNCTIONS_H_
#define HL_AVX_FUNCTIONS_H_
#include <immintrin.h>
namespace
hppl
{
__m256
relu
(
const
__m256
a
);
__m256
sigmoid
(
const
__m256
a
);
__m256
tanh
(
const
__m256
a
);
__m256
linear
(
const
__m256
a
);
__m256
relu
(
const
__m256
a
,
const
__m256
b
);
__m256
sigmoid
(
const
__m256
a
,
const
__m256
b
);
__m256
tanh
(
const
__m256
a
,
const
__m256
b
);
__m256
linear
(
const
__m256
a
,
const
__m256
b
);
}
// namespace hppl
#endif // HL_AVX_FUNCTIONS_H_
paddle/operators/math/detail/hl_cpu_functions.cc
已删除
100644 → 0
浏览文件 @
2c5d4c6d
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include "hl_functions.h"
namespace
hppl
{
namespace
typef
{
float
relu
(
const
float
a
)
{
return
a
>
static_cast
<
float
>
(
0.0
)
?
a
:
static_cast
<
float
>
(
0.0
);
}
float
sigmoid
(
const
float
a
)
{
const
float
min
=
SIGMOID_THRESHOLD_MIN
;
const
float
max
=
SIGMOID_THRESHOLD_MAX
;
float
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
static_cast
<
float
>
(
1.0
)
/
(
static_cast
<
float
>
(
1.0
)
+
exp
(
-
tmp
));
}
float
tanh
(
const
float
a
)
{
float
tmp
=
-
2.0
*
a
;
tmp
=
(
tmp
>
EXP_MAX_INPUT
)
?
EXP_MAX_INPUT
:
tmp
;
return
(
2.0
/
(
1.0
+
exp
(
tmp
)))
-
1.0
;
}
float
linear
(
const
float
a
)
{
return
a
;
}
float
relu
(
const
float
a
,
const
float
b
)
{
return
a
*
(
b
>
0.0
?
1.0
:
0.0
);
}
float
sigmoid
(
const
float
a
,
const
float
b
)
{
return
a
*
b
*
(
static_cast
<
float
>
(
1
)
-
b
);
}
float
tanh
(
const
float
a
,
const
float
b
)
{
return
a
*
(
static_cast
<
float
>
(
1
)
-
b
*
b
);
}
float
linear
(
const
float
a
,
const
float
b
)
{
return
a
;
}
}
// namespace typef
namespace
typed
{
double
relu
(
const
double
a
)
{
return
a
>
static_cast
<
double
>
(
0.0
)
?
a
:
static_cast
<
double
>
(
0.0
);
}
double
sigmoid
(
const
double
a
)
{
const
double
min
=
SIGMOID_THRESHOLD_MIN
;
const
double
max
=
SIGMOID_THRESHOLD_MAX
;
double
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
static_cast
<
double
>
(
1.0
)
/
(
static_cast
<
double
>
(
1.0
)
+
exp
(
-
tmp
));
}
double
tanh
(
const
double
a
)
{
double
tmp
=
-
2.0
*
a
;
tmp
=
(
tmp
>
EXP_MAX_INPUT
)
?
EXP_MAX_INPUT
:
tmp
;
return
(
2.0
/
(
1.0
+
exp
(
tmp
)))
-
1.0
;
}
double
linear
(
const
double
a
)
{
return
a
;
}
double
relu
(
const
double
a
,
const
double
b
)
{
return
a
*
(
b
>
0.0
?
1.0
:
0.0
);
}
double
sigmoid
(
const
double
a
,
const
double
b
)
{
return
a
*
b
*
(
static_cast
<
double
>
(
1
)
-
b
);
}
double
tanh
(
const
double
a
,
const
double
b
)
{
return
a
*
(
static_cast
<
double
>
(
1
)
-
b
*
b
);
}
double
linear
(
const
double
a
,
const
double
b
)
{
return
a
;
}
}
// namespace typed
}
// namespace hppl
paddle/operators/math/detail/hl_functions.h
已删除
100644 → 0
浏览文件 @
2c5d4c6d
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_FUNCTIONS_H_
#define HL_FUNCTIONS_H_
/**
* sigmoid threshold maximum
*/
#define SIGMOID_THRESHOLD_MIN -40.0
/**
* sigmoid threshold minimum
*/
#define SIGMOID_THRESHOLD_MAX 13.0
/**
* The maximum input value for exp, used to avoid overflow problem.
* currently only used for tanh function.
*/
#define EXP_MAX_INPUT 40.0
#ifndef __NVCC__
namespace
hppl
{
namespace
typef
{
float
relu
(
const
float
a
);
float
sigmoid
(
const
float
a
);
float
tanh
(
const
float
a
);
float
linear
(
const
float
a
);
float
relu
(
const
float
a
,
const
float
b
);
float
sigmoid
(
const
float
a
,
const
float
b
);
float
tanh
(
const
float
a
,
const
float
b
);
float
linear
(
const
float
a
,
const
float
b
);
}
// namespace typef
namespace
typed
{
double
relu
(
const
double
a
);
double
sigmoid
(
const
double
a
);
double
tanh
(
const
double
a
);
double
linear
(
const
double
a
);
double
relu
(
const
double
a
,
const
double
b
);
double
sigmoid
(
const
double
a
,
const
double
b
);
double
tanh
(
const
double
a
,
const
double
b
);
double
linear
(
const
double
a
,
const
double
b
);
}
// namespace typed
}
// namespace hppl
#ifdef __AVX__
#include "hl_avx_functions.h"
#endif
#else
#include "hl_gpu_functions.h"
#endif
#endif // HL_FUNCTIONS_H_
paddle/operators/math/detail/hl_gpu_functions.h
已删除
100644 → 0
浏览文件 @
2c5d4c6d
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_GPU_FUNCTIONS_CUH_
#define HL_GPU_FUNCTIONS_CUH_
#include "hl_base.h"
namespace
hppl
{
namespace
typef
{
__device__
static
float
relu
(
const
float
a
)
{
return
a
>
0.0
f
?
a
:
0.0
f
;
}
__device__
static
float
sigmoid
(
const
float
a
)
{
const
float
min
=
SIGMOID_THRESHOLD_MIN
;
const
float
max
=
SIGMOID_THRESHOLD_MAX
;
float
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
__fdividef
(
1.0
f
,
1.0
f
+
__expf
(
-
tmp
));
}
__device__
static
float
tanh
(
const
float
a
)
{
float
tmp
=
-
2.0
*
a
;
tmp
=
(
tmp
>
EXP_MAX_INPUT
)
?
EXP_MAX_INPUT
:
tmp
;
return
__fdividef
(
2.0
f
,
(
1.0
f
+
__expf
(
-
2.0
f
*
tmp
)))
-
1.0
f
;
}
__device__
static
float
linear
(
const
float
a
)
{
return
a
;
}
__device__
static
float
relu
(
const
float
a
,
const
float
b
)
{
return
a
*
(
b
>
0.0
f
?
1.0
f
:
0.0
f
);
}
__device__
static
float
sigmoid
(
const
float
a
,
const
float
b
)
{
return
a
*
b
*
(
1.0
f
-
b
);
}
__device__
static
float
tanh
(
const
float
a
,
const
float
b
)
{
return
a
*
(
1.0
f
-
b
*
b
);
}
__device__
static
float
linear
(
const
float
a
,
const
float
b
)
{
return
a
;
}
}
// namespace typef
namespace
typed
{
__device__
static
double
relu
(
const
double
a
)
{
return
a
>
0.0
?
a
:
0.0
;
}
__device__
static
double
sigmoid
(
const
double
a
)
{
const
double
min
=
SIGMOID_THRESHOLD_MIN
;
const
double
max
=
SIGMOID_THRESHOLD_MAX
;
double
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
1.0
/
(
1.0
+
exp
(
-
tmp
));
}
__device__
static
double
tanh
(
const
double
a
)
{
double
tmp
=
-
2.0
*
a
;
tmp
=
(
tmp
>
EXP_MAX_INPUT
)
?
EXP_MAX_INPUT
:
tmp
;
return
(
2.0
/
(
1.0
+
exp
(
-
2.0
*
a
)))
-
1.0
;
}
__device__
static
double
linear
(
const
double
a
)
{
return
a
;
}
__device__
static
double
relu
(
const
double
a
,
const
double
b
)
{
return
a
*
(
b
>
0.0
?
1.0
:
0.0
);
}
__device__
static
double
sigmoid
(
const
double
a
,
const
double
b
)
{
return
a
*
b
*
(
1
-
b
);
}
__device__
static
double
tanh
(
const
double
a
,
const
double
b
)
{
return
a
*
(
1.0
-
b
*
b
);
}
__device__
static
double
linear
(
const
double
a
,
const
double
b
)
{
return
a
;
}
}
// namespace typef
}
// namespace hppl
#endif // HL_GPU_FUNCTIONS_CUH_
paddle/operators/math/detail/lstm_cpu_kernel.h
浏览文件 @
1c8a0c4b
...
...
@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/
hl_
activation_functions.h"
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/lstm_compute.h"
namespace
paddle
{
...
...
@@ -26,7 +26,10 @@ namespace detail {
template
<
class
T
,
class
Op
>
void
naive_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frameSize
)
{
int
frameSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
T
rValueIn
;
T
rValueIg
;
T
rValueFg
;
...
...
@@ -58,7 +61,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
}
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rPrevState
,
rState
,
rStateAtv
,
rOut
,
rCheckI
,
rCheckF
,
rCheckO
);
rOut
,
rCheckI
,
rCheckF
,
rCheckO
,
active_node
,
active_gate
,
active_state
);
valueIn
[
i
]
=
rValueIn
;
valueIg
[
i
]
=
rValueIg
;
...
...
@@ -72,7 +75,10 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
void
naive_lstm_backward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frameSize
)
{
LstmMetaGrad
<
T
>
grad
,
int
frameSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
T
rValueIn
;
T
rValueIg
;
T
rValueFg
;
...
...
@@ -122,7 +128,7 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rGradIn
,
rGradIg
,
rGradFg
,
rGradOg
,
rPrevState
,
rPrevStateGrad
,
rState
,
rStateGrad
,
rStateAtv
,
rOutputGrad
,
rCheckI
,
rCheckF
,
rCheckO
,
rCheckIGrad
,
rCheckFGrad
,
rCheckOGrad
);
rCheckOGrad
,
active_node
,
active_gate
,
active_state
);
gradIn
[
i
]
=
rGradIn
;
gradIg
[
i
]
=
rGradIg
;
...
...
@@ -176,8 +182,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize,
}
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rPrevState
,
rState
,
rStateAtv
,
rOut
,
rCheckI
,
rCheckF
,
rCheckO
,
hppl
::
avx
::
forward
[
active_node
],
hppl
::
avx
::
forward
[
active_gate
],
hppl
::
avx
::
forward
[
active_state
]);
rOut
,
rCheckI
,
rCheckF
,
rCheckO
,
active_node
,
active_gate
,
active_state
);
valueIn
[
i
]
=
rValueIn
;
valueIg
[
i
]
=
rValueIg
;
...
...
@@ -246,8 +251,7 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rGradIn
,
rGradIg
,
rGradFg
,
rGradOg
,
rPrevState
,
rPrevStateGrad
,
rState
,
rStateGrad
,
rStateAtv
,
rOutputGrad
,
rCheckI
,
rCheckF
,
rCheckO
,
rCheckIGrad
,
rCheckFGrad
,
rCheckOGrad
,
hppl
::
avx
::
backward
[
active_node
],
hppl
::
avx
::
backward
[
active_gate
],
hppl
::
avx
::
backward
[
active_state
]);
rCheckOGrad
,
active_node
,
active_gate
,
active_state
);
gradIn
[
i
]
=
rGradIn
;
gradIg
[
i
]
=
rGradIg
;
...
...
@@ -274,7 +278,8 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize,
avx_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frameSize
,
active_node
,
active_gate
,
active_state
);
}
else
{
naive_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frameSize
);
naive_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frameSize
,
active_node
,
active_gate
,
active_state
);
}
}
...
...
@@ -287,7 +292,8 @@ void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
avx_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frameSize
,
active_node
,
active_gate
,
active_state
);
}
else
{
naive_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frameSize
);
naive_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frameSize
,
active_node
,
active_gate
,
active_state
);
}
}
...
...
paddle/operators/math/detail/lstm_gpu_kernel.h
浏览文件 @
1c8a0c4b
...
...
@@ -13,13 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/device_context.h"
#include <
glog/logging.h
>
#include <
type_traits
>
namespace
paddle
{
namespace
operators
{
...
...
@@ -32,7 +31,9 @@ namespace detail {
*/
template
<
class
T
,
class
Op
,
bool
isBatch
>
__global__
void
KeLstmForward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frameSize
,
int
batchSize
)
{
int
batchSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
const
int
frameIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frameIdx
>=
frameSize
)
return
;
...
...
@@ -69,7 +70,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
}
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rPrevState
,
rState
,
rStateAtv
,
rOut
,
rCheckI
,
rCheckF
,
rCheckO
);
rOut
,
rCheckI
,
rCheckF
,
rCheckO
,
active_node
,
active_gate
,
active_state
);
value
.
gateValue
[
frameIdx
]
=
rValueIn
;
value
.
gateValue
[
frameIdx
+
frameSize
]
=
rValueIg
;
...
...
@@ -88,7 +89,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
template
<
class
T
,
class
Op
,
bool
isBatch
>
__global__
void
KeLstmBackward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frameSize
,
int
batchSize
)
{
int
batchSize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
const
int
frameIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frameIdx
>=
frameSize
)
return
;
...
...
@@ -141,7 +144,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rGradIn
,
rGradIg
,
rGradFg
,
rGradOg
,
rPrevState
,
rPrevStateGrad
,
rState
,
rStateGrad
,
rStateAtv
,
rOutputGrad
,
rCheckI
,
rCheckF
,
rCheckO
,
rCheckIGrad
,
rCheckFGrad
,
rCheckOGrad
);
rCheckI
,
rCheckF
,
rCheckO
,
rCheckIGrad
,
rCheckFGrad
,
rCheckOGrad
,
active_node
,
active_gate
,
active_state
);
grad
.
gateGrad
[
frameIdx
]
=
rGradIn
;
grad
.
gateGrad
[
frameIdx
+
frameSize
]
=
rGradIg
;
...
...
@@ -197,11 +201,13 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
if
(
batchSize
==
1
)
{
KeLstmForward
<
T
,
Op
,
/* isBatch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
frameSize
,
batchSize
);
op
,
value
,
frameSize
,
batchSize
,
active_node
,
active_gate
,
active_state
);
}
else
{
KeLstmForward
<
T
,
Op
,
/* isBatch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
frameSize
,
batchSize
);
op
,
value
,
frameSize
,
batchSize
,
active_node
,
active_gate
,
active_state
);
}
}
...
...
@@ -230,11 +236,13 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
if
(
batchSize
==
1
)
{
KeLstmBackward
<
T
,
Op
,
/* isBatch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
grad
,
frameSize
,
batchSize
);
op
,
value
,
grad
,
frameSize
,
batchSize
,
active_node
,
active_gate
,
active_state
);
}
else
{
KeLstmBackward
<
T
,
Op
,
/* isBatch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
grad
,
frameSize
,
batchSize
);
op
,
value
,
grad
,
frameSize
,
batchSize
,
active_node
,
active_gate
,
active_state
);
}
}
...
...
paddle/operators/math/detail/lstm_kernel.h
浏览文件 @
1c8a0c4b
...
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/detail/
hl_
activation_functions.h"
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/platform/hostdevice.h"
#include <type_traits>
...
...
@@ -24,45 +24,22 @@ namespace detail {
namespace
forward
{
template
<
typename
T
>
DEVICE
inline
T
sigmoid
(
const
T
a
)
{
const
T
min
=
SIGMOID_THRESHOLD_MIN
;
const
T
max
=
SIGMOID_THRESHOLD_MAX
;
T
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
static_cast
<
T
>
(
1.0
)
/
(
static_cast
<
T
>
(
1.0
)
+
exp
(
-
tmp
));
}
template
<
typename
T
>
DEVICE
inline
T
tanh
(
const
T
a
)
{
T
tmp
=
-
2.0
*
a
;
tmp
=
(
tmp
>
EXP_MAX_INPUT
)
?
EXP_MAX_INPUT
:
tmp
;
return
(
2.0
/
(
1.0
+
exp
(
tmp
)))
-
1.0
;
}
template
<
class
T
>
class
lstm
{
public:
HOSTDEVICE
void
operator
()(
T
&
valueIn
,
T
&
valueIg
,
T
&
valueFg
,
T
&
valueOg
,
T
&
prevState
,
T
&
state
,
T
&
stateAtv
,
T
&
output
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
)
{
#if 0
// TODO(qingqing) support to activation speficed by users
valueIn = actInput(valueIn);
valueIg = actGate(valueIg + prevState * checkI);
valueFg = actGate(valueFg + prevState * checkF);
state = valueIn * valueIg + prevState * valueFg;
valueOg = actGate(valueOg + state * checkO);
stateAtv = actState(state);
output = valueOg * stateAtv;
#else
valueIn
=
tanh
<
T
>
(
valueIn
);
valueIg
=
sigmoid
<
T
>
(
valueIg
+
prevState
*
checkI
);
valueFg
=
sigmoid
<
T
>
(
valueFg
+
prevState
*
checkF
);
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
valueIn
=
activation
(
valueIn
,
active_node
);
valueIg
=
activation
(
valueIg
+
prevState
*
checkI
,
active_gate
);
valueFg
=
activation
(
valueFg
+
prevState
*
checkF
,
active_gate
);
state
=
valueIn
*
valueIg
+
prevState
*
valueFg
;
valueOg
=
sigmoid
<
T
>
(
valueOg
+
state
*
checkO
);
stateAtv
=
tanh
<
T
>
(
state
);
valueOg
=
activation
(
valueOg
+
state
*
checkO
,
active_gate
);
stateAtv
=
activation
(
state
,
active_
state
);
output
=
valueOg
*
stateAtv
;
#endif
}
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
...
...
@@ -75,16 +52,19 @@ class lstm {
__m256
&
valueOg
,
__m256
&
prevState
,
__m256
&
state
,
__m256
&
stateAtv
,
__m256
&
output
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
hppl
::
Active
<
__m256
>::
forward
actInput
,
hppl
::
Active
<
__m256
>::
forward
actGate
,
hppl
::
Active
<
__m256
>::
forward
actState
)
{
valueIn
=
actInput
(
valueIn
);
valueIg
=
actGate
(
_mm256_add_ps
(
valueIg
,
_mm256_mul_ps
(
prevState
,
checkI
)));
valueFg
=
actGate
(
_mm256_add_ps
(
valueFg
,
_mm256_mul_ps
(
prevState
,
checkF
)));
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
valueIn
=
activation
(
valueIn
,
active_node
);
valueIg
=
activation
(
_mm256_add_ps
(
valueIg
,
_mm256_mul_ps
(
prevState
,
checkI
)),
active_gate
);
valueFg
=
activation
(
_mm256_add_ps
(
valueFg
,
_mm256_mul_ps
(
prevState
,
checkF
)),
active_gate
);
state
=
_mm256_add_ps
(
_mm256_mul_ps
(
valueIn
,
valueIg
),
_mm256_mul_ps
(
prevState
,
valueFg
));
valueOg
=
actGate
(
_mm256_add_ps
(
valueOg
,
_mm256_mul_ps
(
state
,
checkO
)));
stateAtv
=
actState
(
state
);
valueOg
=
activation
(
_mm256_add_ps
(
valueOg
,
_mm256_mul_ps
(
state
,
checkO
)),
active_gate
);
stateAtv
=
activation
(
state
,
active_state
);
output
=
_mm256_mul_ps
(
valueOg
,
stateAtv
);
}
#endif
...
...
@@ -95,16 +75,6 @@ class lstm {
namespace
backward
{
template
<
typename
T
>
DEVICE
inline
T
sigmoid
(
const
T
a
,
const
T
b
)
{
return
a
*
b
*
(
1.0
-
b
);
}
template
<
typename
T
>
DEVICE
inline
T
tanh
(
const
T
a
,
const
T
b
)
{
return
a
*
(
1.0
-
b
*
b
);
}
template
<
class
T
>
class
lstm
{
public:
...
...
@@ -113,29 +83,20 @@ class lstm {
T
&
prevState
,
T
&
prevStateGrad
,
T
&
state
,
T
&
stateGrad
,
T
&
stateAtv
,
T
&
outputGrad
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
T
&
checkIGrad
,
T
&
checkFGrad
,
T
&
checkOGrad
)
{
#if 0
// TODO(qingqing) support to activation speficed by users
gradOg = actGate(outputGrad * stateAtv, valueOg);
stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
gradIn = actInput(stateGrad * valueIg, valueIn);
gradIg = actGate(stateGrad * valueIn, valueIg);
gradFg = actGate(stateGrad * prevState, valueFg);
T
&
checkFGrad
,
T
&
checkOGrad
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
gradOg
=
activation
(
outputGrad
*
stateAtv
,
valueOg
,
active_gate
);
stateGrad
+=
activation
(
outputGrad
*
valueOg
,
stateAtv
,
active_state
)
+
gradOg
*
checkO
;
gradIn
=
activation
(
stateGrad
*
valueIg
,
valueIn
,
active_node
);
gradIg
=
activation
(
stateGrad
*
valueIn
,
valueIg
,
active_gate
);
gradFg
=
activation
(
stateGrad
*
prevState
,
valueFg
,
active_gate
);
prevStateGrad
=
gradIg
*
checkI
+
gradFg
*
checkF
+
stateGrad
*
valueFg
;
checkIGrad
=
gradIg
*
prevState
;
checkFGrad
=
gradFg
*
prevState
;
checkOGrad
=
gradOg
*
state
;
#else
gradOg
=
sigmoid
<
T
>
(
outputGrad
*
stateAtv
,
valueOg
);
stateGrad
+=
tanh
<
T
>
(
outputGrad
*
valueOg
,
stateAtv
)
+
gradOg
*
checkO
;
gradIn
=
tanh
<
T
>
(
stateGrad
*
valueIg
,
valueIn
);
gradIg
=
sigmoid
<
T
>
(
stateGrad
*
valueIn
,
valueIg
);
gradFg
=
sigmoid
<
T
>
(
stateGrad
*
prevState
,
valueFg
);
prevStateGrad
=
gradIg
*
checkI
+
gradFg
*
checkF
+
stateGrad
*
valueFg
;
checkIGrad
=
gradIg
*
prevState
;
checkFGrad
=
gradFg
*
prevState
;
checkOGrad
=
gradOg
*
state
;
#endif
}
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
...
...
@@ -143,24 +104,26 @@ class lstm {
#else
// Only float support AVX optimization
static
const
bool
avx
=
std
::
is_same
<
T
,
float
>::
value
;
HOSTDEVICE
void
operator
()(
__m256
&
valueIn
,
__m256
&
valueIg
,
__m256
&
valueFg
,
__m256
&
valueOg
,
__m256
&
gradIn
,
__m256
&
gradIg
,
__m256
&
gradFg
,
__m256
&
gradOg
,
__m256
&
prevState
,
__m256
&
prevStateGrad
,
__m256
&
state
,
__m256
&
stateGrad
,
__m256
&
stateAtv
,
__m256
&
outputGrad
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
__m256
&
checkIGrad
,
__m256
&
checkFGrad
,
__m256
&
checkOGrad
,
hppl
::
Active
<
__m256
>::
backward
actInput
,
hppl
::
Active
<
__m256
>::
backward
actGate
,
hppl
::
Active
<
__m256
>::
backward
actState
)
{
gradOg
=
actGate
(
_mm256_mul_ps
(
outputGrad
,
stateAtv
),
valueOg
);
HOSTDEVICE
void
operator
()(
__m256
&
valueIn
,
__m256
&
valueIg
,
__m256
&
valueFg
,
__m256
&
valueOg
,
__m256
&
gradIn
,
__m256
&
gradIg
,
__m256
&
gradFg
,
__m256
&
gradOg
,
__m256
&
prevState
,
__m256
&
prevStateGrad
,
__m256
&
state
,
__m256
&
stateGrad
,
__m256
&
stateAtv
,
__m256
&
outputGrad
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
__m256
&
checkIGrad
,
__m256
&
checkFGrad
,
__m256
&
checkOGrad
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
gradOg
=
activation
(
_mm256_mul_ps
(
outputGrad
,
stateAtv
),
valueOg
,
active_gate
);
stateGrad
=
_mm256_add_ps
(
actState
(
_mm256_mul_ps
(
outputGrad
,
valueOg
),
stateAtv
),
stateGrad
);
activation
(
_mm256_mul_ps
(
outputGrad
,
valueOg
),
stateAtv
,
active_state
),
stateGrad
);
stateGrad
=
_mm256_add_ps
(
_mm256_mul_ps
(
gradOg
,
checkO
),
stateGrad
);
gradIn
=
actInput
(
_mm256_mul_ps
(
stateGrad
,
valueIg
),
valueIn
);
gradIg
=
actGate
(
_mm256_mul_ps
(
stateGrad
,
valueIn
),
valueIg
);
gradFg
=
actGate
(
_mm256_mul_ps
(
stateGrad
,
prevState
),
valueFg
);
gradIn
=
activation
(
_mm256_mul_ps
(
stateGrad
,
valueIg
),
valueIn
,
active_node
);
gradIg
=
activation
(
_mm256_mul_ps
(
stateGrad
,
valueIn
),
valueIg
,
active_gate
);
gradFg
=
activation
(
_mm256_mul_ps
(
stateGrad
,
prevState
),
valueFg
,
active_gate
);
prevStateGrad
=
_mm256_add_ps
(
_mm256_mul_ps
(
gradIg
,
checkI
),
_mm256_mul_ps
(
gradFg
,
checkF
));
prevStateGrad
=
...
...
python/paddle/v2/framework/tests/test_lstm_op.py
浏览文件 @
1c8a0c4b
...
...
@@ -157,7 +157,7 @@ class TestLstmOp(OpTest):
}
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
atol
=
1e-8
)
#TODO(qingqing) add more unit testing case
def
test_check_grad
(
self
):
...
...
@@ -167,7 +167,7 @@ class TestLstmOp(OpTest):
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
(
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
check_grad
(
[
'Input'
,
'Weight'
,
'Bias'
],
[
'Hidden'
],
max_relative_error
=
0.02
)
[
'Input'
,
'Weight'
,
'Bias'
],
[
'Hidden'
],
max_relative_error
=
5e-4
)
class
TestLstmOpHasNoInitial
(
TestLstmOp
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录