Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
1e06a32a
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1e06a32a
编写于
11月 14, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add vexp jitcode of size 8
test=develop
上级
23544096
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
241 addition
and
88 deletion
+241
-88
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+126
-0
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+24
-0
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+1
-0
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+9
-22
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+71
-65
paddle/fluid/operators/math/jit_kernel_macro.h
paddle/fluid/operators/math/jit_kernel_macro.h
+8
-0
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+2
-1
未找到文件。
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
1e06a32a
...
...
@@ -151,6 +151,132 @@ void ReluJitCode::generate() {
}
ret
();
}
bool
VExpJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
==
8
;
// only 8 yet
}
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define CEPHES_LOG2EF 1.44269504088896341
#define CEPHES_EXP_C1 0.693359375
#define CEPHES_EXP_C2 -2.12194440e-4
#define CEPHES_EXP_P0 1.9875691500E-4
#define CEPHES_EXP_P1 1.3981999507E-3
#define CEPHES_EXP_P2 8.3334519073E-3
#define CEPHES_EXP_P3 4.1665795894E-2
#define CEPHES_EXP_P4 1.6666665459E-1
#define CEPHES_EXP_P5 5.0000001201E-1
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_0P5 1 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 2 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 3 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 4 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 5 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 6 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 7 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 8 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 9 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 10 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 11 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 12 * AVX_FLOAT_BLOCK * sizeof(float)
static
const
float
exp_float_consts
[]
ALIGN32
=
{
REPEAT_8TIMES
(
1.
f
),
REPEAT_8TIMES
(
0.5
f
),
REPEAT_8TIMES
(
EXP_HIG
),
REPEAT_8TIMES
(
EXP_LOW
),
REPEAT_8TIMES
(
CEPHES_LOG2EF
),
REPEAT_8TIMES
(
CEPHES_EXP_C1
),
REPEAT_8TIMES
(
CEPHES_EXP_C2
),
REPEAT_8TIMES
(
CEPHES_EXP_P0
),
REPEAT_8TIMES
(
CEPHES_EXP_P1
),
REPEAT_8TIMES
(
CEPHES_EXP_P2
),
REPEAT_8TIMES
(
CEPHES_EXP_P3
),
REPEAT_8TIMES
(
CEPHES_EXP_P4
),
REPEAT_8TIMES
(
CEPHES_EXP_P5
)};
static
const
int
exp_int_0x7f
[]
ALIGN32
=
{
REPEAT_8TIMES
(
0x7f
)};
static
int
g_tmp_mem
[
16
]
ALIGN32
=
{
0
};
void
VExpJitCode
::
generate
()
{
preCode
();
// push some?
// in: ymm0, out: ymm1
// use ymm 0~5 (and ymm 14~15 if avx only)
int
offset
=
0
;
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
vminps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOW
]);
vmaxps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
// express exp(x) as exp(g + n*log(2))
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOG2EF
]);
vmulps
(
ymm_fx
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_0P5
]);
vaddps
(
ymm_fx
,
ymm_fx
,
ymm_tmp
);
vroundps
(
ymm_fy
,
ymm_fx
,
0x01
);
// if greater, substract 1
vcmpgtps
(
ymm_mask
,
ymm_fy
,
ymm_fx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
vandps
(
ymm_mask
,
ymm_mask
,
ymm_tmp
);
vsubps
(
ymm_fx
,
ymm_fy
,
ymm_mask
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C1
]);
vmulps
(
ymm_fy
,
ymm_fx
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C2
]);
vmulps
(
ymm_z
,
ymm_fx
,
ymm_tmp
);
// ymm_z use same with mask
vsubps
(
ymm_src
,
ymm_src
,
ymm_fy
);
vsubps
(
ymm_src
,
ymm_src
,
ymm_z
);
vmulps
(
ymm_z
,
ymm_src
,
ymm_src
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P0
]);
vmulps
(
ymm_dst
,
ymm_src
,
ymm_tmp
);
for
(
size_t
i
=
OFFSET_EXP_P1
;
i
<
OFFSET_EXP_P5
;
i
+=
(
AVX_FLOAT_BLOCK
*
sizeof
(
float
)))
{
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
i
]);
// P1~P4
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_src
);
}
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P5
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_z
);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_src
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
// build 2^n
ymm_t
ymm_int
=
ymm_fx
;
vcvttps2dq
(
ymm_int
,
ymm_fx
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_int_0x7f
));
vmovdqa
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
if
(
MayIUse
(
avx2
))
{
vpaddd
(
ymm_int
,
ymm_int
,
ymm_tmp
);
vpslld
(
ymm_int
,
ymm_int
,
23
);
}
else
if
(
MayIUse
(
avx
))
{
// use ymm_int, ymm_tmp and reg_ptr_global
xmm_t
xtmp1
=
xmm_t
(
ymm_int
);
// or magic number should equal the ymm_int
xmm_t
xtmp2
=
xmm_t
(
ymm_tmp
);
// or magic number should equal the ymm_tmp
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
g_tmp_mem
));
vmovdqa
(
ptr
[
reg_ptr_global
],
ymm_int
);
vmovdqa
(
ptr
[
reg_ptr_global
+
AVX_FLOAT_BLOCK
*
sizeof
(
float
)],
ymm_tmp
);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_global
],
xtmp1
);
// next 128bits
vmovdqa
(
xtmp1
,
ptr
[
reg_ptr_global
+
4
/*xmm float block*/
*
sizeof
(
float
)]);
vmovdqa
(
xtmp2
,
ptr
[
reg_ptr_global
+
(
AVX_FLOAT_BLOCK
+
4
/*xmm float block*/
)
*
sizeof
(
float
)]);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_global
+
4
/*xmm float block*/
*
sizeof
(
float
)],
xtmp1
);
// load out
vmovdqa
(
ymm_int
,
ptr
[
reg_ptr_global
]);
}
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_int
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
// ret();
postCode
();
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
1e06a32a
...
...
@@ -108,6 +108,30 @@ class ReluJitCode : public JitCode {
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
class
VExpJitCode
:
public
JitCode
{
public:
DECLARE_JIT_CODE
(
VExpJitCode
);
explicit
VExpJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
void
generate
()
override
;
private:
int
num_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
reg64_t
reg_ptr_global
=
rax
;
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
ymm_t
ymm_fx
=
ymm_t
(
2
);
ymm_t
ymm_fy
=
ymm_t
(
3
);
ymm_t
ymm_mask
=
ymm_t
(
4
);
ymm_t
ymm_z
=
ymm_t
(
4
);
ymm_t
ymm_tmp
=
ymm_t
(
5
);
};
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
1e06a32a
...
...
@@ -117,6 +117,7 @@ template <typename T>
class
VExpKernel
:
public
VActKernel
<
T
>
{
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
void
(
*
Compute
)(
const
T
*
,
T
*
,
int
);
};
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
1e06a32a
...
...
@@ -25,10 +25,6 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace
paddle
{
namespace
operators
{
namespace
math
{
...
...
@@ -128,18 +124,11 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
#endif
#define DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
/* VMUL JitKernel */
template
<
typename
T
>
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VMulKernelImpl
(
int
d
)
:
VMulKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
...
...
@@ -191,7 +180,7 @@ bool VMulKernelImpl<double>::useMKL(int d) {
template
<
typename
T
>
class
VAddKernelImpl
:
public
VAddKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VAddKernelImpl
(
int
d
)
:
VAddKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
...
...
@@ -241,7 +230,7 @@ bool VAddKernelImpl<double>::useMKL(int d) {
template
<
typename
T
>
class
VAddReluKernelImpl
:
public
VAddReluKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VAddReluKernelImpl
(
int
d
)
:
VAddReluKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
...
...
@@ -273,7 +262,7 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
template
<
typename
T
>
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VScalKernelImpl
(
int
d
)
:
VScalKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
...
...
@@ -322,7 +311,7 @@ bool VScalKernelImpl<double>::useMKL(int d) {
template
<
typename
T
>
class
VAddBiasKernelImpl
:
public
VAddBiasKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VAddBiasKernelImpl
(
int
d
)
:
VAddBiasKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
...
...
@@ -355,14 +344,14 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
template
<
typename
T
>
class
VReluKernelImpl
:
public
VReluKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VReluKernelImpl
(
int
d
)
:
VReluKernel
<
T
>
()
{
this
->
num_
=
d
;
// TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
/*
init
*/
+
d
/
AVX_FLOAT_BLOCK
*
4
/* instructions*/
*
8
/*
everage byte for each instruction
*/
;
size_t
sz
=
96
/*
init size
*/
+
d
/
AVX_FLOAT_BLOCK
*
4
/* instructions
*/
*
8
/*
average bytes for each instruction
*/
;
jitcode_
.
reset
(
new
gen
::
ReluJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
...
...
@@ -388,8 +377,6 @@ bool VReluKernelImpl<float>::useJIT(int d) {
}
#endif
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL
(
vaddrelu
,
VAddReluKernel
);
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
1e06a32a
...
...
@@ -16,6 +16,11 @@ limitations under the License. */
#include <cmath> // for exp
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
...
...
@@ -30,41 +35,84 @@ namespace math {
namespace
jitkernel
{
namespace
jit
=
platform
::
jit
;
// TODO(TJ): move refer codes to one file
template
<
typename
T
>
void
VExpRefer
(
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
std
::
exp
(
x
[
i
]);
}
}
#ifdef PADDLE_WITH_MKLML
template
<
typename
T
>
void
VExpMKL
(
const
T
*
x
,
T
*
y
,
int
n
);
template
<
>
void
VExpMKL
<
float
>
(
const
float
*
x
,
float
*
y
,
int
n
)
{
platform
::
dynload
::
vsExp
(
n
,
x
,
y
);
}
template
<
>
void
VExpMKL
<
double
>
(
const
double
*
x
,
double
*
y
,
int
n
)
{
platform
::
dynload
::
vdExp
(
n
,
x
,
y
);
}
#endif
/* VExp JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
>
class
VExpKernelImpl
:
public
VExpKernel
<
T
>
{
public:
explicit
VExpKernelImpl
(
int
d
)
:
VExpKernel
<
T
>
()
{
this
->
num_
=
d
;
}
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
override
{
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
y
[
i
]
=
std
::
exp
(
x
[
i
]);
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VExpKernelImpl
(
int
d
)
:
VExpKernel
<
T
>
()
{
this
->
num_
=
d
;
// TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
jitcode_
.
reset
(
new
gen
::
VExpJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
#ifdef PADDLE_WITH_MKLML
if
(
useMKL
(
d
))
{
this
->
Compute
=
VExpMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
VExpRefer
<
T
>
;
}
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
override
{
VExpRefer
(
x
,
y
,
this
->
num_
);
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VExpJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VExpKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VExpJitCode
::
init
(
d
);
}
#endif
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \
template <> \
void VExpKernelImpl<float, isa, block>::ComputeDeprecated(const float* x, \
float* y) const { \
platform::dynload::vsExp(this->num_, x, y); \
}
template
<
>
bool
VExpKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VExpKernelImpl<double, isa, block>::ComputeDeprecated( \
const double* x, double* y) const { \
platform::dynload::vdExp(this->num_, x, y); \
}
FOR_EACH_ISA
(
MKL_FLOAT
,
kLT8
);
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT8LT16
);
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT16
);
FOR_EACH_ISA_BLOCK
(
MKL_DOUBLE
);
template
<
>
bool
VExpKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
namespace
detail
{
REGISTER_JITKERNEL
(
vexp
,
VExpKernel
);
#ifdef __AVX__
namespace
detail
{
#define ALIGN32 __attribute__((aligned(32)))
...
...
@@ -195,7 +243,6 @@ __m256 ExpAVX(__m256 x) {
y
=
_mm256_mul_ps
(
y
,
pow2n
);
return
y
;
}
#endif
#ifdef __AVX2__
__m256
ExpAVX2
(
__m256
x
)
{
...
...
@@ -211,47 +258,6 @@ __m256 ExpAVX2(__m256 x) {
}
// namespace detail
#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VExpKernelImpl<float, isa, kEQ8>::ComputeDeprecated(const float* x, \
float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
_mm256_storeu_ps(y, expisa(tmp)); \
}
#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VExpKernelImpl<float, isa, kEQ16>::ComputeDeprecated(const float* x, \
float* y) const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = expisa(tmp0); \
tmp1 = expisa(tmp1); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
INTRI16_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
#endif
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
,
detail
::
ExpAVX2
);
INTRI16_FLOAT
(
jit
::
avx2
,
detail
::
ExpAVX2
);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
INTRI16_FLOAT
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
#endif
// TODO(TJ): eq16 test and complete avx512
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef MKL_FLOAT
#undef MKL_DOUBLE
REGISTER_JITKERNEL_DEPRECATED
(
vexp
,
VExpKernel
);
/* VSigmoid JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VSigmoidKernelImpl
:
public
VSigmoidKernel
<
T
>
{
...
...
paddle/fluid/operators/math/jit_kernel_macro.h
浏览文件 @
1e06a32a
...
...
@@ -15,12 +15,20 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
#define JITKERNEL_DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(int d) { \
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
1e06a32a
...
...
@@ -181,7 +181,8 @@ TEST(JitKernel, vexp) {
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
ComputeDeprecated
(
x_data
,
ztgt_data
);
// ker->ComputeDeprecated(x_data, ztgt_data);
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录