提交 2937314d 编写于 作者: T tensor-tang

refine vmul and test

上级 6c986e12
...@@ -110,12 +110,6 @@ FOR_EACH_ISA_COMMON_BLOCK(VMUL_MKL_FLOAT) ...@@ -110,12 +110,6 @@ FOR_EACH_ISA_COMMON_BLOCK(VMUL_MKL_FLOAT)
FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE) FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE)
#endif #endif
/// lt8
#ifdef PADDLE_WITH_MKLML
VMUL_MKL_FLOAT(jit::avx2, kLT8)
VMUL_MKL_FLOAT(jit::avx512f, kLT8)
#endif
/// eq8 /// eq8
#define VMUL_INTRI8_FLOAT(isa) \ #define VMUL_INTRI8_FLOAT(isa) \
template <> \ template <> \
...@@ -128,28 +122,17 @@ VMUL_MKL_FLOAT(jit::avx512f, kLT8) ...@@ -128,28 +122,17 @@ VMUL_MKL_FLOAT(jit::avx512f, kLT8)
_mm256_storeu_ps(z, tmpx); \ _mm256_storeu_ps(z, tmpx); \
} }
// mkl > avx > for, ">" means better // avx > for > mkl
#ifdef PADDLE_WITH_MKLML #ifdef __AVX__
VMUL_MKL_FLOAT(jit::avx, kEQ8);
#elif defined __AVX__
VMUL_INTRI8_FLOAT(jit::avx); VMUL_INTRI8_FLOAT(jit::avx);
#endif #endif
// avx2 > mkl > for
// avx2 > for > mkl
#ifdef __AVX2__ #ifdef __AVX2__
VMUL_INTRI8_FLOAT(jit::avx2) VMUL_INTRI8_FLOAT(jit::avx2)
#elif defined PADDLE_WITH_MKLML
VMUL_MKL_FLOAT(jit::avx2, kEQ8)
#endif #endif
// TODO(TJ): test and complete avx512 // TODO(TJ): test and complete avx512
/// eq16
#ifdef PADDLE_WITH_MKLML
// TODO(TJ): test and complete me
VMUL_MKL_FLOAT(jit::avx, kEQ16)
VMUL_MKL_FLOAT(jit::avx2, kEQ16)
VMUL_MKL_FLOAT(jit::avx512f, kEQ16)
#endif
#undef VMUL_INTRI8_FLOAT #undef VMUL_INTRI8_FLOAT
#undef VMUL_MKL_FLOAT #undef VMUL_MKL_FLOAT
#undef VMUL_MKL_DOUBLE #undef VMUL_MKL_DOUBLE
...@@ -181,13 +164,6 @@ FOR_EACH_ISA_COMMON_BLOCK(VADD_MKL_FLOAT) ...@@ -181,13 +164,6 @@ FOR_EACH_ISA_COMMON_BLOCK(VADD_MKL_FLOAT)
FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE) FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE)
#endif #endif
/// lt8
#ifdef PADDLE_WITH_MKLML
VADD_MKL_FLOAT(jit::avx, kLT8)
VADD_MKL_FLOAT(jit::avx2, kLT8)
VADD_MKL_FLOAT(jit::avx512f, kLT8)
#endif
/// eq8 /// eq8
#define VADD_INTRI8_FLOAT(isa) \ #define VADD_INTRI8_FLOAT(isa) \
template <> \ template <> \
...@@ -200,28 +176,14 @@ VADD_MKL_FLOAT(jit::avx512f, kLT8) ...@@ -200,28 +176,14 @@ VADD_MKL_FLOAT(jit::avx512f, kLT8)
_mm256_storeu_ps(z, tmpx); \ _mm256_storeu_ps(z, tmpx); \
} }
// mkl > avx > for, ">" means better #ifdef __AVX__
#ifdef PADDLE_USE_MKLML
VADD_MKL_FLOAT(jit::avx, kEQ8)
#elif defined __AVX__
VADD_INTRI8_FLOAT(jit::avx) VADD_INTRI8_FLOAT(jit::avx)
#endif #endif
// avx2 > mkl > for
#ifdef __AVX2__ #ifdef __AVX2__
VADD_INTRI8_FLOAT(jit::avx2) VADD_INTRI8_FLOAT(jit::avx2)
#elif defined PADDLE_WITH_MKLML
VADD_MKL_FLOAT(jit::avx2, kEQ8)
#endif #endif
// TODO(TJ): test and complete avx512 // TODO(TJ): test and complete avx512
/// eq16
#ifdef PADDLE_WITH_MKLML
// TODO(TJ): test and complete me
VADD_MKL_FLOAT(jit::avx, kEQ16)
VADD_MKL_FLOAT(jit::avx2, kEQ16)
VADD_MKL_FLOAT(jit::avx512f, kEQ16)
#endif
#undef VADD_INTRI8_FLOAT #undef VADD_INTRI8_FLOAT
#undef VADD_MKL_FLOAT #undef VADD_MKL_FLOAT
#undef VADD_MKL_DOUBLE #undef VADD_MKL_DOUBLE
......
...@@ -48,8 +48,14 @@ void RandomVec(const int n, T* a) { ...@@ -48,8 +48,14 @@ void RandomVec(const int n, T* a) {
constexpr int repeat = 20000; constexpr int repeat = 20000;
// Scalar reference implementation of element-wise vector multiply:
// writes z[i] = x[i] * y[i] for i in [0, n). Used as the ground truth
// the jit/intrinsic/MKL kernels are benchmarked and checked against.
void vmul_ref(const int n, const float* x, const float* y, float* z) {
  int idx = 0;
  while (idx < n) {
    z[idx] = x[idx] * y[idx];
    ++idx;
  }
}
#if defined __AVX__ || defined __AVX2__ #if defined __AVX__ || defined __AVX2__
void vmul_intri(const int n, const float* x, const float* y, float* z) { void vmul_intri8(const int n, const float* x, const float* y, float* z) {
__m256 tmpx, tmpy; __m256 tmpx, tmpy;
tmpx = _mm256_loadu_ps(x); tmpx = _mm256_loadu_ps(x);
tmpy = _mm256_loadu_ps(y); tmpy = _mm256_loadu_ps(y);
...@@ -58,15 +64,15 @@ void vmul_intri(const int n, const float* x, const float* y, float* z) { ...@@ -58,15 +64,15 @@ void vmul_intri(const int n, const float* x, const float* y, float* z) {
} }
#endif #endif
void vmul_ref(const int n, const float* x, const float* y, float* z) { #ifdef PADDLE_WITH_MKLML
for (int i = 0; i < n; ++i) { void vmul_mkl(const int n, const float* x, const float* y, float* z) {
z[i] = x[i] * y[i]; paddle::platform::dynload::vsMul(n, x, y, z);
}
} }
#endif
TEST(JitKernel, vmul) { TEST(JitKernel, vmul) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
for (int d : {7, 8, 15, 16, 30, 256}) { for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d); std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data()); RandomVec<float>(d, x.data());
...@@ -79,41 +85,44 @@ TEST(JitKernel, vmul) { ...@@ -79,41 +85,44 @@ TEST(JitKernel, vmul) {
float* ztgt_data = ztgt.data(); float* ztgt_data = ztgt.data();
float* zref_data = zref.data(); float* zref_data = zref.data();
#ifdef PADDLE_WITH_MKLML auto trefs = GetCurrentUS();
auto s0 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
paddle::platform::dynload::vsMul(d, x_data, y_data, zref_data); vmul_ref(d, x_data, y_data, zref_data);
} }
#endif auto trefe = GetCurrentUS();
auto st = GetCurrentUS(); #ifdef PADDLE_WITH_MKLML
for (int i = 0; i < repeat; ++i) { auto tmkls = GetCurrentUS();
ker->Compute(d, x_data, y_data, ztgt_data);
}
auto mt = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vmul_ref(d, x_data, y_data, zref_data); vmul_mkl(d, x_data, y_data, zref_data);
} }
auto et = GetCurrentUS(); auto tmkle = GetCurrentUS();
#endif
#if defined __AVX__ || defined __AVX2__ #if defined __AVX__ || defined __AVX2__
if (d == 8) { if (d == 8) {
auto si0 = GetCurrentUS(); auto si0 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vmul_intri(d, x_data, y_data, zref_data); vmul_intri8(d, x_data, y_data, zref_data);
} }
auto si1 = GetCurrentUS(); auto si1 = GetCurrentUS();
VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat; VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat;
} }
#endif #endif
VLOG(3) << "Vec size " << d << ": refer takes: " << (et - mt) / repeat auto ttgts = GetCurrentUS();
<< " us, tgt takes: " << (mt - st) / repeat for (int i = 0; i < repeat; ++i) {
ker->Compute(d, x_data, y_data, ztgt_data);
}
auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
<< " us, mkl takes: " << (st - s0) / repeat << " us"; << " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, "
#else #else
<< " us"; << " us, "
#endif #endif
<< "tgt takes: " << (ttgte - ttgts) / repeat;
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册