Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1e06a32a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1e06a32a
编写于
11月 14, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add vexp jitcode of size 8
test=develop
上级
23544096
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
241 addition
and
88 deletion
+241
-88
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+126
-0
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+24
-0
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+1
-0
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+9
-22
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+71
-65
paddle/fluid/operators/math/jit_kernel_macro.h
paddle/fluid/operators/math/jit_kernel_macro.h
+8
-0
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+2
-1
未找到文件。
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
1e06a32a
...
@@ -151,6 +151,132 @@ void ReluJitCode::generate() {
...
@@ -151,6 +151,132 @@ void ReluJitCode::generate() {
}
}
ret
();
ret
();
}
}
bool
VExpJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
==
8
;
// only 8 yet
}
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define CEPHES_LOG2EF 1.44269504088896341
#define CEPHES_EXP_C1 0.693359375
#define CEPHES_EXP_C2 -2.12194440e-4
#define CEPHES_EXP_P0 1.9875691500E-4
#define CEPHES_EXP_P1 1.3981999507E-3
#define CEPHES_EXP_P2 8.3334519073E-3
#define CEPHES_EXP_P3 4.1665795894E-2
#define CEPHES_EXP_P4 1.6666665459E-1
#define CEPHES_EXP_P5 5.0000001201E-1
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_0P5 1 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 2 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 3 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 4 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 5 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 6 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 7 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 8 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 9 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 10 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 11 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 12 * AVX_FLOAT_BLOCK * sizeof(float)
static
const
float
exp_float_consts
[]
ALIGN32
=
{
REPEAT_8TIMES
(
1.
f
),
REPEAT_8TIMES
(
0.5
f
),
REPEAT_8TIMES
(
EXP_HIG
),
REPEAT_8TIMES
(
EXP_LOW
),
REPEAT_8TIMES
(
CEPHES_LOG2EF
),
REPEAT_8TIMES
(
CEPHES_EXP_C1
),
REPEAT_8TIMES
(
CEPHES_EXP_C2
),
REPEAT_8TIMES
(
CEPHES_EXP_P0
),
REPEAT_8TIMES
(
CEPHES_EXP_P1
),
REPEAT_8TIMES
(
CEPHES_EXP_P2
),
REPEAT_8TIMES
(
CEPHES_EXP_P3
),
REPEAT_8TIMES
(
CEPHES_EXP_P4
),
REPEAT_8TIMES
(
CEPHES_EXP_P5
)};
static
const
int
exp_int_0x7f
[]
ALIGN32
=
{
REPEAT_8TIMES
(
0x7f
)};
static
int
g_tmp_mem
[
16
]
ALIGN32
=
{
0
};
void
VExpJitCode
::
generate
()
{
preCode
();
// push some?
// in: ymm0, out: ymm1
// use ymm 0~5 (and ymm 14~15 if avx only)
int
offset
=
0
;
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
vminps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOW
]);
vmaxps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
// express exp(x) as exp(g + n*log(2))
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOG2EF
]);
vmulps
(
ymm_fx
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_0P5
]);
vaddps
(
ymm_fx
,
ymm_fx
,
ymm_tmp
);
vroundps
(
ymm_fy
,
ymm_fx
,
0x01
);
// if greater, substract 1
vcmpgtps
(
ymm_mask
,
ymm_fy
,
ymm_fx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
vandps
(
ymm_mask
,
ymm_mask
,
ymm_tmp
);
vsubps
(
ymm_fx
,
ymm_fy
,
ymm_mask
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C1
]);
vmulps
(
ymm_fy
,
ymm_fx
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C2
]);
vmulps
(
ymm_z
,
ymm_fx
,
ymm_tmp
);
// ymm_z use same with mask
vsubps
(
ymm_src
,
ymm_src
,
ymm_fy
);
vsubps
(
ymm_src
,
ymm_src
,
ymm_z
);
vmulps
(
ymm_z
,
ymm_src
,
ymm_src
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P0
]);
vmulps
(
ymm_dst
,
ymm_src
,
ymm_tmp
);
for
(
size_t
i
=
OFFSET_EXP_P1
;
i
<
OFFSET_EXP_P5
;
i
+=
(
AVX_FLOAT_BLOCK
*
sizeof
(
float
)))
{
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
i
]);
// P1~P4
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_src
);
}
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P5
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_z
);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_src
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
// build 2^n
ymm_t
ymm_int
=
ymm_fx
;
vcvttps2dq
(
ymm_int
,
ymm_fx
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_int_0x7f
));
vmovdqa
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
if
(
MayIUse
(
avx2
))
{
vpaddd
(
ymm_int
,
ymm_int
,
ymm_tmp
);
vpslld
(
ymm_int
,
ymm_int
,
23
);
}
else
if
(
MayIUse
(
avx
))
{
// use ymm_int, ymm_tmp and reg_ptr_global
xmm_t
xtmp1
=
xmm_t
(
ymm_int
);
// or magic number should equal the ymm_int
xmm_t
xtmp2
=
xmm_t
(
ymm_tmp
);
// or magic number should equal the ymm_tmp
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
g_tmp_mem
));
vmovdqa
(
ptr
[
reg_ptr_global
],
ymm_int
);
vmovdqa
(
ptr
[
reg_ptr_global
+
AVX_FLOAT_BLOCK
*
sizeof
(
float
)],
ymm_tmp
);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_global
],
xtmp1
);
// next 128bits
vmovdqa
(
xtmp1
,
ptr
[
reg_ptr_global
+
4
/*xmm float block*/
*
sizeof
(
float
)]);
vmovdqa
(
xtmp2
,
ptr
[
reg_ptr_global
+
(
AVX_FLOAT_BLOCK
+
4
/*xmm float block*/
)
*
sizeof
(
float
)]);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_global
+
4
/*xmm float block*/
*
sizeof
(
float
)],
xtmp1
);
// load out
vmovdqa
(
ymm_int
,
ptr
[
reg_ptr_global
]);
}
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_int
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
// ret();
postCode
();
}
}
// namespace gen
}
// namespace gen
}
// namespace jitkernel
}
// namespace jitkernel
}
// namespace math
}
// namespace math
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
1e06a32a
...
@@ -108,6 +108,30 @@ class ReluJitCode : public JitCode {
...
@@ -108,6 +108,30 @@ class ReluJitCode : public JitCode {
ymm_t
ymm_dst
=
ymm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
};
class
VExpJitCode
:
public
JitCode
{
public:
DECLARE_JIT_CODE
(
VExpJitCode
);
explicit
VExpJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
void
generate
()
override
;
private:
int
num_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
reg64_t
reg_ptr_global
=
rax
;
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
ymm_t
ymm_fx
=
ymm_t
(
2
);
ymm_t
ymm_fy
=
ymm_t
(
3
);
ymm_t
ymm_mask
=
ymm_t
(
4
);
ymm_t
ymm_z
=
ymm_t
(
4
);
ymm_t
ymm_tmp
=
ymm_t
(
5
);
};
}
// namespace gen
}
// namespace gen
}
// namespace jitkernel
}
// namespace jitkernel
}
// namespace math
}
// namespace math
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
1e06a32a
...
@@ -117,6 +117,7 @@ template <typename T>
...
@@ -117,6 +117,7 @@ template <typename T>
class
VExpKernel
:
public
VActKernel
<
T
>
{
class
VExpKernel
:
public
VActKernel
<
T
>
{
public:
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
void
(
*
Compute
)(
const
T
*
,
T
*
,
int
);
};
};
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
1e06a32a
...
@@ -25,10 +25,6 @@ limitations under the License. */
...
@@ -25,10 +25,6 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/mklml.h"
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
math
{
...
@@ -128,18 +124,11 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
...
@@ -128,18 +124,11 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
#endif
#endif
#define DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
/* VMUL JitKernel */
/* VMUL JitKernel */
template
<
typename
T
>
template
<
typename
T
>
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
public:
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VMulKernelImpl
(
int
d
)
:
VMulKernel
<
T
>
()
{
explicit
VMulKernelImpl
(
int
d
)
:
VMulKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
if
(
useJIT
(
d
))
{
...
@@ -191,7 +180,7 @@ bool VMulKernelImpl<double>::useMKL(int d) {
...
@@ -191,7 +180,7 @@ bool VMulKernelImpl<double>::useMKL(int d) {
template
<
typename
T
>
template
<
typename
T
>
class
VAddKernelImpl
:
public
VAddKernel
<
T
>
{
class
VAddKernelImpl
:
public
VAddKernel
<
T
>
{
public:
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VAddKernelImpl
(
int
d
)
:
VAddKernel
<
T
>
()
{
explicit
VAddKernelImpl
(
int
d
)
:
VAddKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
if
(
useJIT
(
d
))
{
...
@@ -241,7 +230,7 @@ bool VAddKernelImpl<double>::useMKL(int d) {
...
@@ -241,7 +230,7 @@ bool VAddKernelImpl<double>::useMKL(int d) {
template
<
typename
T
>
template
<
typename
T
>
class
VAddReluKernelImpl
:
public
VAddReluKernel
<
T
>
{
class
VAddReluKernelImpl
:
public
VAddReluKernel
<
T
>
{
public:
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VAddReluKernelImpl
(
int
d
)
:
VAddReluKernel
<
T
>
()
{
explicit
VAddReluKernelImpl
(
int
d
)
:
VAddReluKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
if
(
useJIT
(
d
))
{
...
@@ -273,7 +262,7 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
...
@@ -273,7 +262,7 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
template
<
typename
T
>
template
<
typename
T
>
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
public:
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VScalKernelImpl
(
int
d
)
:
VScalKernel
<
T
>
()
{
explicit
VScalKernelImpl
(
int
d
)
:
VScalKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
if
(
useJIT
(
d
))
{
...
@@ -322,7 +311,7 @@ bool VScalKernelImpl<double>::useMKL(int d) {
...
@@ -322,7 +311,7 @@ bool VScalKernelImpl<double>::useMKL(int d) {
template
<
typename
T
>
template
<
typename
T
>
class
VAddBiasKernelImpl
:
public
VAddBiasKernel
<
T
>
{
class
VAddBiasKernelImpl
:
public
VAddBiasKernel
<
T
>
{
public:
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VAddBiasKernelImpl
(
int
d
)
:
VAddBiasKernel
<
T
>
()
{
explicit
VAddBiasKernelImpl
(
int
d
)
:
VAddBiasKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
if
(
useJIT
(
d
))
{
...
@@ -355,14 +344,14 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
...
@@ -355,14 +344,14 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
template
<
typename
T
>
template
<
typename
T
>
class
VReluKernelImpl
:
public
VReluKernel
<
T
>
{
class
VReluKernelImpl
:
public
VReluKernel
<
T
>
{
public:
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VReluKernelImpl
(
int
d
)
:
VReluKernel
<
T
>
()
{
explicit
VReluKernelImpl
(
int
d
)
:
VReluKernel
<
T
>
()
{
this
->
num_
=
d
;
// TODO(TJ): remove me when ComputeDeprecated done
this
->
num_
=
d
;
// TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
if
(
useJIT
(
d
))
{
size_t
sz
=
96
/*
init
*/
+
size_t
sz
=
96
/*
init size
*/
+
d
/
AVX_FLOAT_BLOCK
*
4
/* instructions*/
*
d
/
AVX_FLOAT_BLOCK
*
4
/* instructions
*/
*
8
/*
everage byte for each instruction
*/
;
8
/*
average bytes for each instruction
*/
;
jitcode_
.
reset
(
new
gen
::
ReluJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
jitcode_
.
reset
(
new
gen
::
ReluJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
return
;
...
@@ -388,8 +377,6 @@ bool VReluKernelImpl<float>::useJIT(int d) {
...
@@ -388,8 +377,6 @@ bool VReluKernelImpl<float>::useJIT(int d) {
}
}
#endif
#endif
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL
(
vaddrelu
,
VAddReluKernel
);
REGISTER_JITKERNEL
(
vaddrelu
,
VAddReluKernel
);
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
1e06a32a
...
@@ -16,6 +16,11 @@ limitations under the License. */
...
@@ -16,6 +16,11 @@ limitations under the License. */
#include <cmath> // for exp
#include <cmath> // for exp
#include <string>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#endif
...
@@ -30,41 +35,84 @@ namespace math {
...
@@ -30,41 +35,84 @@ namespace math {
namespace
jitkernel
{
namespace
jitkernel
{
namespace
jit
=
platform
::
jit
;
namespace
jit
=
platform
::
jit
;
// TODO(TJ): move refer codes to one file
template
<
typename
T
>
void
VExpRefer
(
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
std
::
exp
(
x
[
i
]);
}
}
#ifdef PADDLE_WITH_MKLML
template
<
typename
T
>
void
VExpMKL
(
const
T
*
x
,
T
*
y
,
int
n
);
template
<
>
void
VExpMKL
<
float
>
(
const
float
*
x
,
float
*
y
,
int
n
)
{
platform
::
dynload
::
vsExp
(
n
,
x
,
y
);
}
template
<
>
void
VExpMKL
<
double
>
(
const
double
*
x
,
double
*
y
,
int
n
)
{
platform
::
dynload
::
vdExp
(
n
,
x
,
y
);
}
#endif
/* VExp JitKernel */
/* VExp JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
>
class
VExpKernelImpl
:
public
VExpKernel
<
T
>
{
class
VExpKernelImpl
:
public
VExpKernel
<
T
>
{
public:
public:
explicit
VExpKernelImpl
(
int
d
)
:
VExpKernel
<
T
>
()
{
this
->
num_
=
d
;
}
JITKERNEL_DECLARE_STATIC_FUNC
;
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
override
{
explicit
VExpKernelImpl
(
int
d
)
:
VExpKernel
<
T
>
()
{
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
this
->
num_
=
d
;
// TODO(TJ): remove me when ComputeDeprecated done
y
[
i
]
=
std
::
exp
(
x
[
i
]);
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
jitcode_
.
reset
(
new
gen
::
VExpJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
}
#endif
#ifdef PADDLE_WITH_MKLML
if
(
useMKL
(
d
))
{
this
->
Compute
=
VExpMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
VExpRefer
<
T
>
;
}
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
override
{
VExpRefer
(
x
,
y
,
this
->
num_
);
}
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VExpJitCode
>
jitcode_
{
nullptr
};
#endif
};
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VExpKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VExpJitCode
::
init
(
d
);
}
#endif
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \
template
<
>
template <> \
bool
VExpKernelImpl
<
float
>::
useMKL
(
int
d
)
{
void VExpKernelImpl<float, isa, block>::ComputeDeprecated(const float* x, \
return
d
>
512
;
float* y) const { \
}
platform::dynload::vsExp(this->num_, x, y); \
}
#define MKL_DOUBLE(isa, block) \
template
<
>
template <> \
bool
VExpKernelImpl
<
double
>::
useMKL
(
int
d
)
{
void VExpKernelImpl<double, isa, block>::ComputeDeprecated( \
return
true
;
const double* x, double* y) const { \
}
platform::dynload::vdExp(this->num_, x, y); \
}
FOR_EACH_ISA
(
MKL_FLOAT
,
kLT8
);
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT8LT16
);
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT16
);
FOR_EACH_ISA_BLOCK
(
MKL_DOUBLE
);
#endif
#endif
namespace
detail
{
REGISTER_JITKERNEL
(
vexp
,
VExpKernel
);
#ifdef __AVX__
namespace
detail
{
#define ALIGN32 __attribute__((aligned(32)))
#define ALIGN32 __attribute__((aligned(32)))
...
@@ -195,7 +243,6 @@ __m256 ExpAVX(__m256 x) {
...
@@ -195,7 +243,6 @@ __m256 ExpAVX(__m256 x) {
y
=
_mm256_mul_ps
(
y
,
pow2n
);
y
=
_mm256_mul_ps
(
y
,
pow2n
);
return
y
;
return
y
;
}
}
#endif
#ifdef __AVX2__
#ifdef __AVX2__
__m256
ExpAVX2
(
__m256
x
)
{
__m256
ExpAVX2
(
__m256
x
)
{
...
@@ -211,47 +258,6 @@ __m256 ExpAVX2(__m256 x) {
...
@@ -211,47 +258,6 @@ __m256 ExpAVX2(__m256 x) {
}
// namespace detail
}
// namespace detail
#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VExpKernelImpl<float, isa, kEQ8>::ComputeDeprecated(const float* x, \
float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
_mm256_storeu_ps(y, expisa(tmp)); \
}
#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VExpKernelImpl<float, isa, kEQ16>::ComputeDeprecated(const float* x, \
float* y) const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = expisa(tmp0); \
tmp1 = expisa(tmp1); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
INTRI16_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
#endif
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
,
detail
::
ExpAVX2
);
INTRI16_FLOAT
(
jit
::
avx2
,
detail
::
ExpAVX2
);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
INTRI16_FLOAT
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
#endif
// TODO(TJ): eq16 test and complete avx512
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef MKL_FLOAT
#undef MKL_DOUBLE
REGISTER_JITKERNEL_DEPRECATED
(
vexp
,
VExpKernel
);
/* VSigmoid JitKernel */
/* VSigmoid JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VSigmoidKernelImpl
:
public
VSigmoidKernel
<
T
>
{
class
VSigmoidKernelImpl
:
public
VSigmoidKernel
<
T
>
{
...
...
paddle/fluid/operators/math/jit_kernel_macro.h
浏览文件 @
1e06a32a
...
@@ -15,12 +15,20 @@ limitations under the License. */
...
@@ -15,12 +15,20 @@ limitations under the License. */
#pragma once
#pragma once
#include <string>
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
math
{
namespace
jitkernel
{
namespace
jitkernel
{
#define JITKERNEL_DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
template <> \
template <> \
std::string ker_class##Impl<float>::name(int d) { \
std::string ker_class##Impl<float>::name(int d) { \
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
1e06a32a
...
@@ -181,7 +181,8 @@ TEST(JitKernel, vexp) {
...
@@ -181,7 +181,8 @@ TEST(JitKernel, vexp) {
auto
ttgts
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
ComputeDeprecated
(
x_data
,
ztgt_data
);
// ker->ComputeDeprecated(x_data, ztgt_data);
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
}
auto
ttgte
=
GetCurrentUS
();
auto
ttgte
=
GetCurrentUS
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录