Commit 2937314d, authored by T tensor-tang

refine vmul and test

Parent commit: 6c986e12
......@@ -110,12 +110,6 @@ FOR_EACH_ISA_COMMON_BLOCK(VMUL_MKL_FLOAT)
FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE)
#endif
/// lt8
#ifdef PADDLE_WITH_MKLML
VMUL_MKL_FLOAT(jit::avx2, kLT8)
VMUL_MKL_FLOAT(jit::avx512f, kLT8)
#endif
/// eq8
#define VMUL_INTRI8_FLOAT(isa) \
template <> \
......@@ -128,28 +122,17 @@ VMUL_MKL_FLOAT(jit::avx512f, kLT8)
_mm256_storeu_ps(z, tmpx); \
}
// mkl > avx > for, ">" means better
#ifdef PADDLE_WITH_MKLML
VMUL_MKL_FLOAT(jit::avx, kEQ8);
#elif defined __AVX__
// avx > for > mkl
#ifdef __AVX__
VMUL_INTRI8_FLOAT(jit::avx);
#endif
// avx2 > mkl > for
// avx2 > for > mkl
#ifdef __AVX2__
VMUL_INTRI8_FLOAT(jit::avx2)
#elif defined PADDLE_WITH_MKLML
VMUL_MKL_FLOAT(jit::avx2, kEQ8)
#endif
// TODO(TJ): test and complete avx512
/// eq16
#ifdef PADDLE_WITH_MKLML
// TODO(TJ): test and complete me
VMUL_MKL_FLOAT(jit::avx, kEQ16)
VMUL_MKL_FLOAT(jit::avx2, kEQ16)
VMUL_MKL_FLOAT(jit::avx512f, kEQ16)
#endif
#undef VMUL_INTRI8_FLOAT
#undef VMUL_MKL_FLOAT
#undef VMUL_MKL_DOUBLE
......@@ -181,13 +164,6 @@ FOR_EACH_ISA_COMMON_BLOCK(VADD_MKL_FLOAT)
FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE)
#endif
/// lt8
#ifdef PADDLE_WITH_MKLML
VADD_MKL_FLOAT(jit::avx, kLT8)
VADD_MKL_FLOAT(jit::avx2, kLT8)
VADD_MKL_FLOAT(jit::avx512f, kLT8)
#endif
/// eq8
#define VADD_INTRI8_FLOAT(isa) \
template <> \
......@@ -200,28 +176,14 @@ VADD_MKL_FLOAT(jit::avx512f, kLT8)
_mm256_storeu_ps(z, tmpx); \
}
// mkl > avx > for, ">" means better
#ifdef PADDLE_USE_MKLML
VADD_MKL_FLOAT(jit::avx, kEQ8)
#elif defined __AVX__
#ifdef __AVX__
VADD_INTRI8_FLOAT(jit::avx)
#endif
// avx2 > mkl > for
#ifdef __AVX2__
VADD_INTRI8_FLOAT(jit::avx2)
#elif defined PADDLE_WITH_MKLML
VADD_MKL_FLOAT(jit::avx2, kEQ8)
#endif
// TODO(TJ): test and complete avx512
/// eq16
#ifdef PADDLE_WITH_MKLML
// TODO(TJ): test and complete me
VADD_MKL_FLOAT(jit::avx, kEQ16)
VADD_MKL_FLOAT(jit::avx2, kEQ16)
VADD_MKL_FLOAT(jit::avx512f, kEQ16)
#endif
#undef VADD_INTRI8_FLOAT
#undef VADD_MKL_FLOAT
#undef VADD_MKL_DOUBLE
......
......@@ -48,8 +48,14 @@ void RandomVec(const int n, T* a) {
constexpr int repeat = 20000;
// Reference (scalar) implementation of element-wise vector multiply:
// z[i] = x[i] * y[i] for every i in [0, n). Serves as the correctness
// baseline the JIT/intrinsic kernels are compared against.
void vmul_ref(const int n, const float* x, const float* y, float* z) {
  int idx = 0;
  while (idx < n) {
    z[idx] = x[idx] * y[idx];
    ++idx;
  }
}
#if defined __AVX__ || defined __AVX2__
void vmul_intri(const int n, const float* x, const float* y, float* z) {
void vmul_intri8(const int n, const float* x, const float* y, float* z) {
__m256 tmpx, tmpy;
tmpx = _mm256_loadu_ps(x);
tmpy = _mm256_loadu_ps(y);
......@@ -58,15 +64,15 @@ void vmul_intri(const int n, const float* x, const float* y, float* z) {
}
#endif
void vmul_ref(const int n, const float* x, const float* y, float* z) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i];
}
#ifdef PADDLE_WITH_MKLML
// MKL-backed element-wise multiply: z = x .* y via the dynamically loaded
// vsMul routine. Compiled only when MKLML is available; used purely as a
// timing baseline against the JIT kernel in the benchmark below.
void vmul_mkl(const int n, const float* x, const float* y, float* z) {
paddle::platform::dynload::vsMul(n, x, y, z);
}
#endif
TEST(JitKernel, vmul) {
namespace jit = paddle::operators::math::jitkernel;
for (int d : {7, 8, 15, 16, 30, 256}) {
for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data());
......@@ -79,41 +85,44 @@ TEST(JitKernel, vmul) {
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
#ifdef PADDLE_WITH_MKLML
auto s0 = GetCurrentUS();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
paddle::platform::dynload::vsMul(d, x_data, y_data, zref_data);
vmul_ref(d, x_data, y_data, zref_data);
}
#endif
auto trefe = GetCurrentUS();
auto st = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(d, x_data, y_data, ztgt_data);
}
auto mt = GetCurrentUS();
#ifdef PADDLE_WITH_MKLML
auto tmkls = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vmul_ref(d, x_data, y_data, zref_data);
vmul_mkl(d, x_data, y_data, zref_data);
}
auto et = GetCurrentUS();
auto tmkle = GetCurrentUS();
#endif
#if defined __AVX__ || defined __AVX2__
if (d == 8) {
auto si0 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vmul_intri(d, x_data, y_data, zref_data);
vmul_intri8(d, x_data, y_data, zref_data);
}
auto si1 = GetCurrentUS();
VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat;
}
#endif
VLOG(3) << "Vec size " << d << ": refer takes: " << (et - mt) / repeat
<< " us, tgt takes: " << (mt - st) / repeat
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(d, x_data, y_data, ztgt_data);
}
auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
#ifdef PADDLE_WITH_MKLML
<< " us, mkl takes: " << (st - s0) / repeat << " us";
<< " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, "
#else
<< " us";
<< " us, "
#endif
<< "tgt takes: " << (ttgte - ttgts) / repeat;
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册