提交 dee5d35c 编写于 作者: T tensor-tang

refine vmul

上级 92031968
...@@ -125,10 +125,8 @@ inline void vec_scal<float, platform::jit::avx2>(const int n, const float a, ...@@ -125,10 +125,8 @@ inline void vec_scal<float, platform::jit::avx2>(const int n, const float a,
} }
template <> template <>
inline void vec_scal<float, platform::jit::avx512_common>(const int n, inline void vec_scal<float, platform::jit::avx512f>(const int n, const float a,
const float a, const float* x, float* y) {
const float* x,
float* y) {
// TODO(TJ): enable me // TODO(TJ): enable me
vec_scal<float, platform::jit::avx2>(n, a, x, y); vec_scal<float, platform::jit::avx2>(n, a, x, y);
} }
...@@ -181,7 +179,7 @@ inline void vec_bias_sub<float, platform::jit::avx2>(const int n, const float a, ...@@ -181,7 +179,7 @@ inline void vec_bias_sub<float, platform::jit::avx2>(const int n, const float a,
} }
template <> template <>
inline void vec_bias_sub<float, platform::jit::avx512_common>(const int n, inline void vec_bias_sub<float, platform::jit::avx512f>(const int n,
const float a, const float a,
const float* x, const float* x,
float* y) { float* y) {
...@@ -242,7 +240,7 @@ inline void vec_cross<float, platform::jit::avx2>(const int n, const float* x, ...@@ -242,7 +240,7 @@ inline void vec_cross<float, platform::jit::avx2>(const int n, const float* x,
} }
template <> template <>
inline void vec_cross<float, platform::jit::avx512_common>( inline void vec_cross<float, platform::jit::avx512f>(
const int n, const float* x, const float* y, const float* z, float* out) { const int n, const float* x, const float* y, const float* z, float* out) {
// TODO(TJ): enable me // TODO(TJ): enable me
vec_cross<float, platform::jit::avx>(n, x, y, z, out); vec_cross<float, platform::jit::avx>(n, x, y, z, out);
...@@ -296,7 +294,7 @@ inline void vec_add_bias<float, platform::jit::avx2>(const int n, const float a, ...@@ -296,7 +294,7 @@ inline void vec_add_bias<float, platform::jit::avx2>(const int n, const float a,
} }
template <> template <>
inline void vec_add_bias<float, platform::jit::avx512_common>(const int n, inline void vec_add_bias<float, platform::jit::avx512f>(const int n,
const float a, const float a,
const float* x, const float* x,
float* y) { float* y) {
...@@ -390,7 +388,7 @@ inline void vec_sigmoid<float, platform::jit::avx2>(const int n, const float* x, ...@@ -390,7 +388,7 @@ inline void vec_sigmoid<float, platform::jit::avx2>(const int n, const float* x,
} }
template <> template <>
inline void vec_sigmoid<float, platform::jit::avx512_common>(const int n, inline void vec_sigmoid<float, platform::jit::avx512f>(const int n,
const float* x, const float* x,
float* y) { float* y) {
// TODO(TJ): enable me // TODO(TJ): enable me
...@@ -454,8 +452,7 @@ inline void vec_relu<float, platform::jit::avx2>(const int n, const float* x, ...@@ -454,8 +452,7 @@ inline void vec_relu<float, platform::jit::avx2>(const int n, const float* x,
} }
template <> template <>
inline void vec_relu<float, platform::jit::avx512_common>(const int n, inline void vec_relu<float, platform::jit::avx512f>(const int n, const float* x,
const float* x,
float* y) { float* y) {
// TODO(TJ): enable me // TODO(TJ): enable me
vec_relu<float, platform::jit::avx2>(n, x, y); vec_relu<float, platform::jit::avx2>(n, x, y);
......
...@@ -110,7 +110,7 @@ TEST(CpuVecTest, sigmoid) { ...@@ -110,7 +110,7 @@ TEST(CpuVecTest, sigmoid) {
TestAndBench<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>); TestAndBench<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>); TestAndBench<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>);
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>); TestAndBench<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>);
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx512_common>, TestAndBench<float>(sz, vec_sigmoid<float, jit::avx512f>,
ref_sigmoid<float>); ref_sigmoid<float>);
} }
TestAndBench<double>(30, vec_sigmoid<double>, ref_sigmoid<double>); TestAndBench<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
...@@ -123,8 +123,7 @@ TEST(CpuVecTest, tanh) { ...@@ -123,8 +123,7 @@ TEST(CpuVecTest, tanh) {
TestAndBench<float>(sz, vec_tanh<float>, ref_tanh<float>); TestAndBench<float>(sz, vec_tanh<float>, ref_tanh<float>);
TestAndBench<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>); TestAndBench<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>);
TestAndBench<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>); TestAndBench<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>);
TestAndBench<float>(sz, vec_tanh<float, jit::avx512_common>, TestAndBench<float>(sz, vec_tanh<float, jit::avx512f>, ref_tanh<float>);
ref_tanh<float>);
} }
TestAndBench<double>(30, vec_tanh<double>, ref_tanh<double>); TestAndBench<double>(30, vec_tanh<double>, ref_tanh<double>);
} }
...@@ -136,8 +135,7 @@ TEST(CpuVecTest, relu) { ...@@ -136,8 +135,7 @@ TEST(CpuVecTest, relu) {
TestAndBench<float>(sz, vec_relu<float>, ref_relu<float>); TestAndBench<float>(sz, vec_relu<float>, ref_relu<float>);
TestAndBench<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>); TestAndBench<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>);
TestAndBench<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>); TestAndBench<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>);
TestAndBench<float>(sz, vec_relu<float, jit::avx512_common>, TestAndBench<float>(sz, vec_relu<float, jit::avx512f>, ref_relu<float>);
ref_relu<float>);
} }
TestAndBench<double>(30, vec_relu<double>, ref_relu<double>); TestAndBench<double>(30, vec_relu<double>, ref_relu<double>);
} }
...@@ -170,7 +168,7 @@ TEST(CpuVecTest, inplace_sigmoid) { ...@@ -170,7 +168,7 @@ TEST(CpuVecTest, inplace_sigmoid) {
TestInplace<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>); TestInplace<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
TestInplace<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>); TestInplace<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>);
TestInplace<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>); TestInplace<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>);
TestInplace<float>(sz, vec_sigmoid<float, jit::avx512_common>, TestInplace<float>(sz, vec_sigmoid<float, jit::avx512f>,
ref_sigmoid<float>); ref_sigmoid<float>);
} }
TestInplace<double>(30, vec_sigmoid<double>, ref_sigmoid<double>); TestInplace<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
...@@ -183,8 +181,7 @@ TEST(CpuVecTest, inplace_tanh) { ...@@ -183,8 +181,7 @@ TEST(CpuVecTest, inplace_tanh) {
TestInplace<float>(sz, vec_tanh<float>, ref_tanh<float>); TestInplace<float>(sz, vec_tanh<float>, ref_tanh<float>);
TestInplace<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>); TestInplace<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>);
TestInplace<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>); TestInplace<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>);
TestInplace<float>(sz, vec_tanh<float, jit::avx512_common>, TestInplace<float>(sz, vec_tanh<float, jit::avx512f>, ref_tanh<float>);
ref_tanh<float>);
} }
TestInplace<double>(30, vec_tanh<double>, ref_tanh<double>); TestInplace<double>(30, vec_tanh<double>, ref_tanh<double>);
} }
...@@ -196,8 +193,7 @@ TEST(CpuVecTest, inplace_relu) { ...@@ -196,8 +193,7 @@ TEST(CpuVecTest, inplace_relu) {
TestInplace<float>(sz, vec_relu<float>, ref_relu<float>); TestInplace<float>(sz, vec_relu<float>, ref_relu<float>);
TestInplace<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>); TestInplace<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>);
TestInplace<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>); TestInplace<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>);
TestInplace<float>(sz, vec_relu<float, jit::avx512_common>, TestInplace<float>(sz, vec_relu<float, jit::avx512f>, ref_relu<float>);
ref_relu<float>);
} }
TestInplace<double>(30, vec_relu<double>, ref_relu<double>); TestInplace<double>(30, vec_relu<double>, ref_relu<double>);
} }
...@@ -41,6 +41,8 @@ KernelPool& KernelPool::Instance() { ...@@ -41,6 +41,8 @@ KernelPool& KernelPool::Instance() {
Compute = src<t, isa, kLT8>; \ Compute = src<t, isa, kLT8>; \
} else if (d == AVX_FLOAT_BLOCK) { \ } else if (d == AVX_FLOAT_BLOCK) { \
Compute = src<t, isa, kEQ8>; \ Compute = src<t, isa, kEQ8>; \
} else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \
Compute = src<t, isa, kGT8LT16>; \
} else if (d == AVX512_FLOAT_BLOCK) { \ } else if (d == AVX512_FLOAT_BLOCK) { \
Compute = src<t, isa, kEQ16>; \ Compute = src<t, isa, kEQ16>; \
} else { \ } else { \
...@@ -48,8 +50,8 @@ KernelPool& KernelPool::Instance() { ...@@ -48,8 +50,8 @@ KernelPool& KernelPool::Instance() {
} }
#define SEARCH_ISA_BLOCK(src, t) \ #define SEARCH_ISA_BLOCK(src, t) \
if (jit::MayIUse(jit::avx512_common)) { \ if (jit::MayIUse(jit::avx512f)) { \
SEARCH_BLOCK(src, t, jit::avx512_common); \ SEARCH_BLOCK(src, t, jit::avx512f); \
} else if (jit::MayIUse(jit::avx2)) { \ } else if (jit::MayIUse(jit::avx2)) { \
SEARCH_BLOCK(src, t, jit::avx2); \ SEARCH_BLOCK(src, t, jit::avx2); \
} else if (jit::MayIUse(jit::avx)) { \ } else if (jit::MayIUse(jit::avx)) { \
...@@ -58,11 +60,12 @@ KernelPool& KernelPool::Instance() { ...@@ -58,11 +60,12 @@ KernelPool& KernelPool::Instance() {
SEARCH_BLOCK(src, t, jit::isa_any); \ SEARCH_BLOCK(src, t, jit::isa_any); \
} }
#define FOR_EACH_BLOCK(macro_, isa) \ // do not include lt8, eq8, eq16
macro_(isa, kLT8) macro_(isa, kEQ8) macro_(isa, kEQ16) macro_(isa, kGT16) #define FOR_EACH_COMMON_BLOCK(macro_, isa) \
macro_(isa, kGT8LT16) macro_(isa, kGT16)
#define FOR_EACH_ISA_BLOCK(macro_) \ #define FOR_EACH_ISA_COMMON_BLOCK(macro_) \
FOR_EACH_BLOCK(macro_, jit::avx512_common) \ FOR_EACH_BLOCK(macro_, jit::avx512f) \
FOR_EACH_BLOCK(macro_, jit::avx2) \ FOR_EACH_BLOCK(macro_, jit::avx2) \
FOR_EACH_BLOCK(macro_, jit::avx) \ FOR_EACH_BLOCK(macro_, jit::avx) \
FOR_EACH_BLOCK(macro_, jit::any) FOR_EACH_BLOCK(macro_, jit::any)
...@@ -80,23 +83,55 @@ static void VMulCompute(const int n, const T* x, const T* y, T* z) { ...@@ -80,23 +83,55 @@ static void VMulCompute(const int n, const T* x, const T* y, T* z) {
#ifdef PADDLE_USE_MKLML #ifdef PADDLE_USE_MKLML
#define DEFINE_VMUL_COMPUTE_FLOAT(isa, block) \ #define DEFINE_VMUL_COMPUTE_FLOAT(isa, block) \
template <> \ template <> \
static void VMulCompute<float, isa, block>(const int n, const float* x, \ void VMulCompute<float, isa, block>(const int n, const float* x, \
const float* y, float* z) { \ const float* y, float* z) { \
platform::dynload::vsMul(n, x, y, z); \ platform::dynload::vsMul(n, x, y, z); \
} }
#define DEFINE_VMUL_COMPUTE_DOUBLE(isa, block) \ #define DEFINE_VMUL_COMPUTE_DOUBLE(isa, block) \
template <> \ template <> \
static void VMulCompute<double, isa, block>(const int n, const double* x, \ void VMulCompute<double, isa, block>(const int n, const double* x, \
const double* y, float* z) { \ const double* y, float* z) { \
platform::dynload::vdMul(n, x, y, z); \ platform::dynload::vdMul(n, x, y, z); \
} }
FOR_EACH_ISA_BLOCK(DEFINE_VMUL_COMPUTE_FLOAT) FOR_EACH_ISA_COMMON_BLOCK(DEFINE_VMUL_COMPUTE_FLOAT)
FOR_EACH_ISA_BLOCK(DEFINE_VMUL_COMPUTE_DOUBLE) FOR_EACH_ISA_COMMON_BLOCK(DEFINE_VMUL_COMPUTE_DOUBLE)
// TODO(TJ): add EQ8 DEFINE_VMUL_COMPUTE_FLOAT(jit::avx, kLT8)
DEFINE_VMUL_COMPUTE_FLOAT(jit::avx, kEQ16)
#endif #endif
// mkl > avx > for, ">" means better
#ifdef PADDLE_USE_MKLML
DEFINE_VMUL_COMPUTE_FLOAT(jit::avx, kEQ8)
#elif defined __AVX__
template <>
void VMulCompute<float, jit::avx, kEQ8>(const int n, const float* x,
const float* y, float* z) {
__m256 tmpx, tmpy;
tmpx = _mm256_loadu_ps(x);
tmpy = _mm256_loadu_ps(y);
tmpx = _mm256_mul_ps(tmpx, tmpy);
_mm256_storeu_ps(z, tmpx);
}
#endif
// avx2 > mkl > for
#ifdef __AVX2__
template <>
void VMulCompute<float, jit::avx2, kEQ8>(const int n, const float* x,
const float* y, float* z) {
__m256 tmpx, tmpy;
tmpx = _mm256_loadu_ps(x);
tmpy = _mm256_loadu_ps(y);
tmpx = _mm256_mul_ps(tmpx, tmpy);
_mm256_storeu_ps(z, tmpx);
}
#elif defined PADDLE_USE_MKLML
DEFINE_VMUL_COMPUTE_FLOAT(jit::avx2, kEQ8)
#endif
// TODO(TJ): test and complete avx512
#undef DEFINE_VMUL_COMPUTE_FLOAT #undef DEFINE_VMUL_COMPUTE_FLOAT
#undef DEFINE_VMUL_COMPUTE_DOUBLE #undef DEFINE_VMUL_COMPUTE_DOUBLE
#undef VMUL_ANY #undef VMUL_ANY
...@@ -142,8 +177,8 @@ LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str, ...@@ -142,8 +177,8 @@ LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str,
: Kernel(), d_(d) { : Kernel(), d_(d) {
d2_ = d * 2; d2_ = d * 2;
d3_ = d * 3; d3_ = d * 3;
if (platform::jit::MayIUse(platform::jit::avx512_common)) { if (platform::jit::MayIUse(platform::jit::avx512f)) {
math::VecActivations<float, platform::jit::avx512_common> act_functor; math::VecActivations<float, platform::jit::avx512f> act_functor;
act_gate_ = act_functor(act_gate_str); act_gate_ = act_functor(act_gate_str);
act_cell_ = act_functor(act_cell_str); act_cell_ = act_functor(act_cell_str);
act_cand_ = act_functor(act_cand_str); act_cand_ = act_functor(act_cand_str);
......
...@@ -36,7 +36,7 @@ namespace jitkernel { ...@@ -36,7 +36,7 @@ namespace jitkernel {
#define AVX512_FLOAT_BLOCK 16 #define AVX512_FLOAT_BLOCK 16
#define AVX512_DOUBLE_BLOCK 8 #define AVX512_DOUBLE_BLOCK 8
typedef enum { kLT8, kEQ8, kEQ16, kGT16 } jit_block; typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;
class Kernel { class Kernel {
public: public:
......
...@@ -128,7 +128,7 @@ bool MayIUse(const cpu_isa_t cpu_isa) { ...@@ -128,7 +128,7 @@ bool MayIUse(const cpu_isa_t cpu_isa) {
return cpu.has(Cpu::tAVX); return cpu.has(Cpu::tAVX);
case avx2: case avx2:
return cpu.has(Cpu::tAVX2); return cpu.has(Cpu::tAVX2);
case avx512_common: case avx512f:
return cpu.has(Cpu::tAVX512F); return cpu.has(Cpu::tAVX512F);
case avx512_core: case avx512_core:
return true && cpu.has(Cpu::tAVX512F) && cpu.has(Cpu::tAVX512BW) && return true && cpu.has(Cpu::tAVX512F) && cpu.has(Cpu::tAVX512BW) &&
......
...@@ -43,7 +43,7 @@ typedef enum { ...@@ -43,7 +43,7 @@ typedef enum {
sse42, sse42,
avx, avx,
avx2, avx2,
avx512_common, avx512f,
avx512_core, avx512_core,
avx512_core_vnni, avx512_core_vnni,
avx512_mic, avx512_mic,
......
...@@ -116,7 +116,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) { ...@@ -116,7 +116,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
platform::SetNumThreads(FLAGS_paddle_num_threads); platform::SetNumThreads(FLAGS_paddle_num_threads);
#endif #endif
if (platform::jit::MayIUse(platform::jit::avx512_common)) { if (platform::jit::MayIUse(platform::jit::avx512f)) {
#ifndef __AVX512F__ #ifndef __AVX512F__
LOG(WARNING) << "AVX512F is available, Please re-compile on local machine"; LOG(WARNING) << "AVX512F is available, Please re-compile on local machine";
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册