Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
dee5d35c
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dee5d35c
编写于
9月 26, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine vmul
上级
92031968
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
100 addition
and
72 deletion
+100
-72
paddle/fluid/operators/math/cpu_vec.h
paddle/fluid/operators/math/cpu_vec.h
+16
-19
paddle/fluid/operators/math/cpu_vec_test.cc
paddle/fluid/operators/math/cpu_vec_test.cc
+6
-10
paddle/fluid/operators/math/jit_kernel.cc
paddle/fluid/operators/math/jit_kernel.cc
+74
-39
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+1
-1
paddle/fluid/platform/cpu_info.cc
paddle/fluid/platform/cpu_info.cc
+1
-1
paddle/fluid/platform/cpu_info.h
paddle/fluid/platform/cpu_info.h
+1
-1
paddle/fluid/platform/init.cc
paddle/fluid/platform/init.cc
+1
-1
未找到文件。
paddle/fluid/operators/math/cpu_vec.h
浏览文件 @
dee5d35c
...
@@ -125,10 +125,8 @@ inline void vec_scal<float, platform::jit::avx2>(const int n, const float a,
...
@@ -125,10 +125,8 @@ inline void vec_scal<float, platform::jit::avx2>(const int n, const float a,
}
}
template
<
>
template
<
>
inline
void
vec_scal
<
float
,
platform
::
jit
::
avx512_common
>
(
const
int
n
,
inline
void
vec_scal
<
float
,
platform
::
jit
::
avx512f
>
(
const
int
n
,
const
float
a
,
const
float
a
,
const
float
*
x
,
float
*
y
)
{
const
float
*
x
,
float
*
y
)
{
// TODO(TJ): enable me
// TODO(TJ): enable me
vec_scal
<
float
,
platform
::
jit
::
avx2
>
(
n
,
a
,
x
,
y
);
vec_scal
<
float
,
platform
::
jit
::
avx2
>
(
n
,
a
,
x
,
y
);
}
}
...
@@ -181,10 +179,10 @@ inline void vec_bias_sub<float, platform::jit::avx2>(const int n, const float a,
...
@@ -181,10 +179,10 @@ inline void vec_bias_sub<float, platform::jit::avx2>(const int n, const float a,
}
}
template
<
>
template
<
>
inline
void
vec_bias_sub
<
float
,
platform
::
jit
::
avx512
_common
>
(
const
int
n
,
inline
void
vec_bias_sub
<
float
,
platform
::
jit
::
avx512
f
>
(
const
int
n
,
const
float
a
,
const
float
a
,
const
float
*
x
,
const
float
*
x
,
float
*
y
)
{
float
*
y
)
{
// TODO(TJ): enable me
// TODO(TJ): enable me
vec_bias_sub
<
float
,
platform
::
jit
::
avx2
>
(
n
,
a
,
x
,
y
);
vec_bias_sub
<
float
,
platform
::
jit
::
avx2
>
(
n
,
a
,
x
,
y
);
}
}
...
@@ -242,7 +240,7 @@ inline void vec_cross<float, platform::jit::avx2>(const int n, const float* x,
...
@@ -242,7 +240,7 @@ inline void vec_cross<float, platform::jit::avx2>(const int n, const float* x,
}
}
template
<
>
template
<
>
inline
void
vec_cross
<
float
,
platform
::
jit
::
avx512
_common
>
(
inline
void
vec_cross
<
float
,
platform
::
jit
::
avx512
f
>
(
const
int
n
,
const
float
*
x
,
const
float
*
y
,
const
float
*
z
,
float
*
out
)
{
const
int
n
,
const
float
*
x
,
const
float
*
y
,
const
float
*
z
,
float
*
out
)
{
// TODO(TJ): enable me
// TODO(TJ): enable me
vec_cross
<
float
,
platform
::
jit
::
avx
>
(
n
,
x
,
y
,
z
,
out
);
vec_cross
<
float
,
platform
::
jit
::
avx
>
(
n
,
x
,
y
,
z
,
out
);
...
@@ -296,10 +294,10 @@ inline void vec_add_bias<float, platform::jit::avx2>(const int n, const float a,
...
@@ -296,10 +294,10 @@ inline void vec_add_bias<float, platform::jit::avx2>(const int n, const float a,
}
}
template
<
>
template
<
>
inline
void
vec_add_bias
<
float
,
platform
::
jit
::
avx512
_common
>
(
const
int
n
,
inline
void
vec_add_bias
<
float
,
platform
::
jit
::
avx512
f
>
(
const
int
n
,
const
float
a
,
const
float
a
,
const
float
*
x
,
const
float
*
x
,
float
*
y
)
{
float
*
y
)
{
// TODO(TJ): enable me
// TODO(TJ): enable me
vec_add_bias
<
float
,
platform
::
jit
::
avx2
>
(
n
,
a
,
x
,
y
);
vec_add_bias
<
float
,
platform
::
jit
::
avx2
>
(
n
,
a
,
x
,
y
);
}
}
...
@@ -390,9 +388,9 @@ inline void vec_sigmoid<float, platform::jit::avx2>(const int n, const float* x,
...
@@ -390,9 +388,9 @@ inline void vec_sigmoid<float, platform::jit::avx2>(const int n, const float* x,
}
}
template
<
>
template
<
>
inline
void
vec_sigmoid
<
float
,
platform
::
jit
::
avx512
_common
>
(
const
int
n
,
inline
void
vec_sigmoid
<
float
,
platform
::
jit
::
avx512
f
>
(
const
int
n
,
const
float
*
x
,
const
float
*
x
,
float
*
y
)
{
float
*
y
)
{
// TODO(TJ): enable me
// TODO(TJ): enable me
vec_sigmoid
<
float
,
platform
::
jit
::
avx2
>
(
n
,
x
,
y
);
vec_sigmoid
<
float
,
platform
::
jit
::
avx2
>
(
n
,
x
,
y
);
}
}
...
@@ -454,9 +452,8 @@ inline void vec_relu<float, platform::jit::avx2>(const int n, const float* x,
...
@@ -454,9 +452,8 @@ inline void vec_relu<float, platform::jit::avx2>(const int n, const float* x,
}
}
template
<
>
template
<
>
inline
void
vec_relu
<
float
,
platform
::
jit
::
avx512_common
>
(
const
int
n
,
inline
void
vec_relu
<
float
,
platform
::
jit
::
avx512f
>
(
const
int
n
,
const
float
*
x
,
const
float
*
x
,
float
*
y
)
{
float
*
y
)
{
// TODO(TJ): enable me
// TODO(TJ): enable me
vec_relu
<
float
,
platform
::
jit
::
avx2
>
(
n
,
x
,
y
);
vec_relu
<
float
,
platform
::
jit
::
avx2
>
(
n
,
x
,
y
);
}
}
...
...
paddle/fluid/operators/math/cpu_vec_test.cc
浏览文件 @
dee5d35c
...
@@ -110,7 +110,7 @@ TEST(CpuVecTest, sigmoid) {
...
@@ -110,7 +110,7 @@ TEST(CpuVecTest, sigmoid) {
TestAndBench
<
float
>
(
sz
,
vec_sigmoid
<
float
>
,
ref_sigmoid
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_sigmoid
<
float
>
,
ref_sigmoid
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_sigmoid
<
float
,
jit
::
avx
>
,
ref_sigmoid
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_sigmoid
<
float
,
jit
::
avx
>
,
ref_sigmoid
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_sigmoid
<
float
,
jit
::
avx2
>
,
ref_sigmoid
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_sigmoid
<
float
,
jit
::
avx2
>
,
ref_sigmoid
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_sigmoid
<
float
,
jit
::
avx512
_common
>
,
TestAndBench
<
float
>
(
sz
,
vec_sigmoid
<
float
,
jit
::
avx512
f
>
,
ref_sigmoid
<
float
>
);
ref_sigmoid
<
float
>
);
}
}
TestAndBench
<
double
>
(
30
,
vec_sigmoid
<
double
>
,
ref_sigmoid
<
double
>
);
TestAndBench
<
double
>
(
30
,
vec_sigmoid
<
double
>
,
ref_sigmoid
<
double
>
);
...
@@ -123,8 +123,7 @@ TEST(CpuVecTest, tanh) {
...
@@ -123,8 +123,7 @@ TEST(CpuVecTest, tanh) {
TestAndBench
<
float
>
(
sz
,
vec_tanh
<
float
>
,
ref_tanh
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_tanh
<
float
>
,
ref_tanh
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_tanh
<
float
,
jit
::
avx
>
,
ref_tanh
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_tanh
<
float
,
jit
::
avx
>
,
ref_tanh
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_tanh
<
float
,
jit
::
avx2
>
,
ref_tanh
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_tanh
<
float
,
jit
::
avx2
>
,
ref_tanh
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_tanh
<
float
,
jit
::
avx512_common
>
,
TestAndBench
<
float
>
(
sz
,
vec_tanh
<
float
,
jit
::
avx512f
>
,
ref_tanh
<
float
>
);
ref_tanh
<
float
>
);
}
}
TestAndBench
<
double
>
(
30
,
vec_tanh
<
double
>
,
ref_tanh
<
double
>
);
TestAndBench
<
double
>
(
30
,
vec_tanh
<
double
>
,
ref_tanh
<
double
>
);
}
}
...
@@ -136,8 +135,7 @@ TEST(CpuVecTest, relu) {
...
@@ -136,8 +135,7 @@ TEST(CpuVecTest, relu) {
TestAndBench
<
float
>
(
sz
,
vec_relu
<
float
>
,
ref_relu
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_relu
<
float
>
,
ref_relu
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_relu
<
float
,
jit
::
avx
>
,
ref_relu
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_relu
<
float
,
jit
::
avx
>
,
ref_relu
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_relu
<
float
,
jit
::
avx2
>
,
ref_relu
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_relu
<
float
,
jit
::
avx2
>
,
ref_relu
<
float
>
);
TestAndBench
<
float
>
(
sz
,
vec_relu
<
float
,
jit
::
avx512_common
>
,
TestAndBench
<
float
>
(
sz
,
vec_relu
<
float
,
jit
::
avx512f
>
,
ref_relu
<
float
>
);
ref_relu
<
float
>
);
}
}
TestAndBench
<
double
>
(
30
,
vec_relu
<
double
>
,
ref_relu
<
double
>
);
TestAndBench
<
double
>
(
30
,
vec_relu
<
double
>
,
ref_relu
<
double
>
);
}
}
...
@@ -170,7 +168,7 @@ TEST(CpuVecTest, inplace_sigmoid) {
...
@@ -170,7 +168,7 @@ TEST(CpuVecTest, inplace_sigmoid) {
TestInplace
<
float
>
(
sz
,
vec_sigmoid
<
float
>
,
ref_sigmoid
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_sigmoid
<
float
>
,
ref_sigmoid
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_sigmoid
<
float
,
jit
::
avx
>
,
ref_sigmoid
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_sigmoid
<
float
,
jit
::
avx
>
,
ref_sigmoid
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_sigmoid
<
float
,
jit
::
avx2
>
,
ref_sigmoid
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_sigmoid
<
float
,
jit
::
avx2
>
,
ref_sigmoid
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_sigmoid
<
float
,
jit
::
avx512
_common
>
,
TestInplace
<
float
>
(
sz
,
vec_sigmoid
<
float
,
jit
::
avx512
f
>
,
ref_sigmoid
<
float
>
);
ref_sigmoid
<
float
>
);
}
}
TestInplace
<
double
>
(
30
,
vec_sigmoid
<
double
>
,
ref_sigmoid
<
double
>
);
TestInplace
<
double
>
(
30
,
vec_sigmoid
<
double
>
,
ref_sigmoid
<
double
>
);
...
@@ -183,8 +181,7 @@ TEST(CpuVecTest, inplace_tanh) {
...
@@ -183,8 +181,7 @@ TEST(CpuVecTest, inplace_tanh) {
TestInplace
<
float
>
(
sz
,
vec_tanh
<
float
>
,
ref_tanh
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_tanh
<
float
>
,
ref_tanh
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_tanh
<
float
,
jit
::
avx
>
,
ref_tanh
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_tanh
<
float
,
jit
::
avx
>
,
ref_tanh
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_tanh
<
float
,
jit
::
avx2
>
,
ref_tanh
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_tanh
<
float
,
jit
::
avx2
>
,
ref_tanh
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_tanh
<
float
,
jit
::
avx512_common
>
,
TestInplace
<
float
>
(
sz
,
vec_tanh
<
float
,
jit
::
avx512f
>
,
ref_tanh
<
float
>
);
ref_tanh
<
float
>
);
}
}
TestInplace
<
double
>
(
30
,
vec_tanh
<
double
>
,
ref_tanh
<
double
>
);
TestInplace
<
double
>
(
30
,
vec_tanh
<
double
>
,
ref_tanh
<
double
>
);
}
}
...
@@ -196,8 +193,7 @@ TEST(CpuVecTest, inplace_relu) {
...
@@ -196,8 +193,7 @@ TEST(CpuVecTest, inplace_relu) {
TestInplace
<
float
>
(
sz
,
vec_relu
<
float
>
,
ref_relu
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_relu
<
float
>
,
ref_relu
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_relu
<
float
,
jit
::
avx
>
,
ref_relu
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_relu
<
float
,
jit
::
avx
>
,
ref_relu
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_relu
<
float
,
jit
::
avx2
>
,
ref_relu
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_relu
<
float
,
jit
::
avx2
>
,
ref_relu
<
float
>
);
TestInplace
<
float
>
(
sz
,
vec_relu
<
float
,
jit
::
avx512_common
>
,
TestInplace
<
float
>
(
sz
,
vec_relu
<
float
,
jit
::
avx512f
>
,
ref_relu
<
float
>
);
ref_relu
<
float
>
);
}
}
TestInplace
<
double
>
(
30
,
vec_relu
<
double
>
,
ref_relu
<
double
>
);
TestInplace
<
double
>
(
30
,
vec_relu
<
double
>
,
ref_relu
<
double
>
);
}
}
paddle/fluid/operators/math/jit_kernel.cc
浏览文件 @
dee5d35c
...
@@ -36,35 +36,38 @@ KernelPool& KernelPool::Instance() {
...
@@ -36,35 +36,38 @@ KernelPool& KernelPool::Instance() {
static
KernelPool
g_jit_kernels
;
static
KernelPool
g_jit_kernels
;
return
g_jit_kernels
;
return
g_jit_kernels
;
}
}
#define SEARCH_BLOCK(src, t, isa) \
#define SEARCH_BLOCK(src, t, isa) \
if (d < AVX_FLOAT_BLOCK) { \
if (d < AVX_FLOAT_BLOCK) { \
Compute = src<t, isa, kLT8>; \
Compute = src<t, isa, kLT8>; \
} else if (d == AVX_FLOAT_BLOCK) { \
} else if (d == AVX_FLOAT_BLOCK) { \
Compute = src<t, isa, kEQ8>; \
Compute = src<t, isa, kEQ8>; \
} else if (d == AVX512_FLOAT_BLOCK) { \
} else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \
Compute = src<t, isa, kEQ16>; \
Compute = src<t, isa, kGT8LT16>; \
} else { \
} else if (d == AVX512_FLOAT_BLOCK) { \
Compute = src<t, isa, kGT16>; \
Compute = src<t, isa, kEQ16>; \
} else { \
Compute = src<t, isa, kGT16>; \
}
}
#define SEARCH_ISA_BLOCK(src, t)
\
#define SEARCH_ISA_BLOCK(src, t) \
if (jit::MayIUse(jit::avx512
_common
)) { \
if (jit::MayIUse(jit::avx512
f
)) { \
SEARCH_BLOCK(src, t, jit::avx512
_common
); \
SEARCH_BLOCK(src, t, jit::avx512
f
); \
} else if (jit::MayIUse(jit::avx2)) {
\
} else if (jit::MayIUse(jit::avx2)) { \
SEARCH_BLOCK(src, t, jit::avx2);
\
SEARCH_BLOCK(src, t, jit::avx2); \
} else if (jit::MayIUse(jit::avx)) {
\
} else if (jit::MayIUse(jit::avx)) { \
SEARCH_BLOCK(src, t, jit::avx);
\
SEARCH_BLOCK(src, t, jit::avx); \
} else {
\
} else { \
SEARCH_BLOCK(src, t, jit::isa_any);
\
SEARCH_BLOCK(src, t, jit::isa_any); \
}
}
#define FOR_EACH_BLOCK(macro_, isa) \
// do not include lt8, eq8, eq16
macro_(isa, kLT8) macro_(isa, kEQ8) macro_(isa, kEQ16) macro_(isa, kGT16)
#define FOR_EACH_COMMON_BLOCK(macro_, isa) \
macro_(isa, kGT8LT16) macro_(isa, kGT16)
#define FOR_EACH_ISA_
BLOCK(macro_)
\
#define FOR_EACH_ISA_
COMMON_BLOCK(macro_)
\
FOR_EACH_BLOCK(macro_, jit::avx512
_common)
\
FOR_EACH_BLOCK(macro_, jit::avx512
f)
\
FOR_EACH_BLOCK(macro_, jit::avx2)
\
FOR_EACH_BLOCK(macro_, jit::avx2) \
FOR_EACH_BLOCK(macro_, jit::avx)
\
FOR_EACH_BLOCK(macro_, jit::avx) \
FOR_EACH_BLOCK(macro_, jit::any)
FOR_EACH_BLOCK(macro_, jit::any)
#define VMUL_ANY \
#define VMUL_ANY \
...
@@ -78,24 +81,56 @@ static void VMulCompute(const int n, const T* x, const T* y, T* z) {
...
@@ -78,24 +81,56 @@ static void VMulCompute(const int n, const T* x, const T* y, T* z) {
}
}
#ifdef PADDLE_USE_MKLML
#ifdef PADDLE_USE_MKLML
#define DEFINE_VMUL_COMPUTE_FLOAT(isa, block)
\
#define DEFINE_VMUL_COMPUTE_FLOAT(isa, block) \
template <>
\
template <> \
static
void VMulCompute<float, isa, block>(const int n, const float* x, \
void VMulCompute<float, isa, block>(const int n, const float* x, \
const float* y, float* z) { \
const float* y, float* z) { \
platform::dynload::vsMul(n, x, y, z);
\
platform::dynload::vsMul(n, x, y, z); \
}
}
#define DEFINE_VMUL_COMPUTE_DOUBLE(isa, block)
\
#define DEFINE_VMUL_COMPUTE_DOUBLE(isa, block) \
template <>
\
template <> \
static
void VMulCompute<double, isa, block>(const int n, const double* x, \
void VMulCompute<double, isa, block>(const int n, const double* x, \
const double* y, float* z) { \
const double* y, float* z) { \
platform::dynload::vdMul(n, x, y, z);
\
platform::dynload::vdMul(n, x, y, z); \
}
}
FOR_EACH_ISA_BLOCK
(
DEFINE_VMUL_COMPUTE_FLOAT
)
FOR_EACH_ISA_COMMON_BLOCK
(
DEFINE_VMUL_COMPUTE_FLOAT
)
FOR_EACH_ISA_BLOCK
(
DEFINE_VMUL_COMPUTE_DOUBLE
)
FOR_EACH_ISA_COMMON_BLOCK
(
DEFINE_VMUL_COMPUTE_DOUBLE
)
// TODO(TJ): add EQ8
DEFINE_VMUL_COMPUTE_FLOAT
(
jit
::
avx
,
kLT8
)
DEFINE_VMUL_COMPUTE_FLOAT
(
jit
::
avx
,
kEQ16
)
#endif
// mkl > avx > for, ">" means better
#ifdef PADDLE_USE_MKLML
DEFINE_VMUL_COMPUTE_FLOAT
(
jit
::
avx
,
kEQ8
)
#elif defined __AVX__
template
<
>
void
VMulCompute
<
float
,
jit
::
avx
,
kEQ8
>
(
const
int
n
,
const
float
*
x
,
const
float
*
y
,
float
*
z
)
{
__m256
tmpx
,
tmpy
;
tmpx
=
_mm256_loadu_ps
(
x
);
tmpy
=
_mm256_loadu_ps
(
y
);
tmpx
=
_mm256_mul_ps
(
tmpx
,
tmpy
);
_mm256_storeu_ps
(
z
,
tmpx
);
}
#endif
// avx2 > mkl > for
#ifdef __AVX2__
template
<
>
void
VMulCompute
<
float
,
jit
::
avx2
,
kEQ8
>
(
const
int
n
,
const
float
*
x
,
const
float
*
y
,
float
*
z
)
{
__m256
tmpx
,
tmpy
;
tmpx
=
_mm256_loadu_ps
(
x
);
tmpy
=
_mm256_loadu_ps
(
y
);
tmpx
=
_mm256_mul_ps
(
tmpx
,
tmpy
);
_mm256_storeu_ps
(
z
,
tmpx
);
}
#elif defined PADDLE_USE_MKLML
DEFINE_VMUL_COMPUTE_FLOAT
(
jit
::
avx2
,
kEQ8
)
#endif
#endif
// TODO(TJ): test and complete avx512
#undef DEFINE_VMUL_COMPUTE_FLOAT
#undef DEFINE_VMUL_COMPUTE_FLOAT
#undef DEFINE_VMUL_COMPUTE_DOUBLE
#undef DEFINE_VMUL_COMPUTE_DOUBLE
...
@@ -142,8 +177,8 @@ LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str,
...
@@ -142,8 +177,8 @@ LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str,
:
Kernel
(),
d_
(
d
)
{
:
Kernel
(),
d_
(
d
)
{
d2_
=
d
*
2
;
d2_
=
d
*
2
;
d3_
=
d
*
3
;
d3_
=
d
*
3
;
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx512
_common
))
{
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx512
f
))
{
math
::
VecActivations
<
float
,
platform
::
jit
::
avx512
_common
>
act_functor
;
math
::
VecActivations
<
float
,
platform
::
jit
::
avx512
f
>
act_functor
;
act_gate_
=
act_functor
(
act_gate_str
);
act_gate_
=
act_functor
(
act_gate_str
);
act_cell_
=
act_functor
(
act_cell_str
);
act_cell_
=
act_functor
(
act_cell_str
);
act_cand_
=
act_functor
(
act_cand_str
);
act_cand_
=
act_functor
(
act_cand_str
);
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
dee5d35c
...
@@ -36,7 +36,7 @@ namespace jitkernel {
...
@@ -36,7 +36,7 @@ namespace jitkernel {
#define AVX512_FLOAT_BLOCK 16
#define AVX512_FLOAT_BLOCK 16
#define AVX512_DOUBLE_BLOCK 8
#define AVX512_DOUBLE_BLOCK 8
typedef
enum
{
kLT8
,
kEQ8
,
kEQ16
,
kGT16
}
jit_block
;
typedef
enum
{
kLT8
,
kEQ8
,
k
GT8LT16
,
k
EQ16
,
kGT16
}
jit_block
;
class
Kernel
{
class
Kernel
{
public:
public:
...
...
paddle/fluid/platform/cpu_info.cc
浏览文件 @
dee5d35c
...
@@ -128,7 +128,7 @@ bool MayIUse(const cpu_isa_t cpu_isa) {
...
@@ -128,7 +128,7 @@ bool MayIUse(const cpu_isa_t cpu_isa) {
return
cpu
.
has
(
Cpu
::
tAVX
);
return
cpu
.
has
(
Cpu
::
tAVX
);
case
avx2
:
case
avx2
:
return
cpu
.
has
(
Cpu
::
tAVX2
);
return
cpu
.
has
(
Cpu
::
tAVX2
);
case
avx512
_common
:
case
avx512
f
:
return
cpu
.
has
(
Cpu
::
tAVX512F
);
return
cpu
.
has
(
Cpu
::
tAVX512F
);
case
avx512_core
:
case
avx512_core
:
return
true
&&
cpu
.
has
(
Cpu
::
tAVX512F
)
&&
cpu
.
has
(
Cpu
::
tAVX512BW
)
&&
return
true
&&
cpu
.
has
(
Cpu
::
tAVX512F
)
&&
cpu
.
has
(
Cpu
::
tAVX512BW
)
&&
...
...
paddle/fluid/platform/cpu_info.h
浏览文件 @
dee5d35c
...
@@ -43,7 +43,7 @@ typedef enum {
...
@@ -43,7 +43,7 @@ typedef enum {
sse42
,
sse42
,
avx
,
avx
,
avx2
,
avx2
,
avx512
_common
,
avx512
f
,
avx512_core
,
avx512_core
,
avx512_core_vnni
,
avx512_core_vnni
,
avx512_mic
,
avx512_mic
,
...
...
paddle/fluid/platform/init.cc
浏览文件 @
dee5d35c
...
@@ -116,7 +116,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
...
@@ -116,7 +116,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
platform
::
SetNumThreads
(
FLAGS_paddle_num_threads
);
platform
::
SetNumThreads
(
FLAGS_paddle_num_threads
);
#endif
#endif
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx512
_common
))
{
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx512
f
))
{
#ifndef __AVX512F__
#ifndef __AVX512F__
LOG
(
WARNING
)
<<
"AVX512F is available, Please re-compile on local machine"
;
LOG
(
WARNING
)
<<
"AVX512F is available, Please re-compile on local machine"
;
#endif
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录