提交 6c986e12 编写于 作者: T tensor-tang

fix macro and add vmul unit test

上级 8c69764d
......@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
......@@ -62,7 +61,7 @@ namespace jit = platform::jit;
FOR_EACH_COMMON_BLOCK(macro_, jit::avx512f) \
FOR_EACH_COMMON_BLOCK(macro_, jit::avx2) \
FOR_EACH_COMMON_BLOCK(macro_, jit::avx) \
FOR_EACH_COMMON_BLOCK(macro_, jit::any)
FOR_EACH_COMMON_BLOCK(macro_, jit::isa_any)
#define FOR_EACH_ALL_BLOCK(macro_, isa) \
macro_(isa, kLT8) macro_(isa, kEQ8) macro_(isa, kGT8LT16) macro_(isa, kEQ16) \
......@@ -72,7 +71,7 @@ namespace jit = platform::jit;
FOR_EACH_ALL_BLOCK(macro_, jit::avx512f) \
FOR_EACH_ALL_BLOCK(macro_, jit::avx2) \
FOR_EACH_ALL_BLOCK(macro_, jit::avx) \
FOR_EACH_ALL_BLOCK(macro_, jit::any)
FOR_EACH_ALL_BLOCK(macro_, jit::isa_any)
#define BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, ker_dtype) \
template <> \
......@@ -92,7 +91,7 @@ static void VMulCompute(const int n, const T* x, const T* y, T* z) {
}
}
#ifdef PADDLE_USE_MKLML
#ifdef PADDLE_WITH_MKLML
#define VMUL_MKL_FLOAT(isa, block) \
template <> \
void VMulCompute<float, isa, block>(const int n, const float* x, \
......@@ -103,7 +102,7 @@ static void VMulCompute(const int n, const T* x, const T* y, T* z) {
#define VMUL_MKL_DOUBLE(isa, block) \
template <> \
void VMulCompute<double, isa, block>(const int n, const double* x, \
const double* y, float* z) { \
const double* y, double* z) { \
platform::dynload::vdMul(n, x, y, z); \
}
......@@ -112,7 +111,7 @@ FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE)
#endif
/// lt8
#ifdef PADDLE_USE_MKLML
#ifdef PADDLE_WITH_MKLML
VMUL_MKL_FLOAT(jit::avx2, kLT8)
VMUL_MKL_FLOAT(jit::avx512f, kLT8)
#endif
......@@ -130,21 +129,21 @@ VMUL_MKL_FLOAT(jit::avx512f, kLT8)
}
// mkl > avx > for, ">" means better
#ifdef PADDLE_USE_MKLML
VMUL_MKL_FLOAT(jit::avx, kEQ8)
#ifdef PADDLE_WITH_MKLML
VMUL_MKL_FLOAT(jit::avx, kEQ8);
#elif defined __AVX__
VMUL_INTRI8_FLOAT(jit::avx)
VMUL_INTRI8_FLOAT(jit::avx);
#endif
// avx2 > mkl > for
#ifdef __AVX2__
VMUL_INTRI8_FLOAT(jit::avx2)
#elif defined PADDLE_USE_MKLML
#elif defined PADDLE_WITH_MKLML
VMUL_MKL_FLOAT(jit::avx2, kEQ8)
#endif
// TODO(TJ): test and complete avx512
/// eq16
#ifdef PADDLE_USE_MKLML
#ifdef PADDLE_WITH_MKLML
// TODO(TJ): test and complete me
VMUL_MKL_FLOAT(jit::avx, kEQ16)
VMUL_MKL_FLOAT(jit::avx2, kEQ16)
......@@ -163,7 +162,7 @@ static void VAddCompute(const int n, const T* x, const T* y, T* z) {
}
}
#ifdef PADDLE_USE_MKLML
#ifdef PADDLE_WITH_MKLML
#define VADD_MKL_FLOAT(isa, block) \
template <> \
void VAddCompute<float, isa, block>(const int n, const float* x, \
......@@ -174,7 +173,7 @@ static void VAddCompute(const int n, const T* x, const T* y, T* z) {
#define VADD_MKL_DOUBLE(isa, block) \
template <> \
void VAddCompute<double, isa, block>(const int n, const double* x, \
const double* y, float* z) { \
const double* y, double* z) { \
platform::dynload::vdAdd(n, x, y, z); \
}
......@@ -183,7 +182,7 @@ FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE)
#endif
/// lt8
#ifdef PADDLE_USE_MKLML
#ifdef PADDLE_WITH_MKLML
VADD_MKL_FLOAT(jit::avx, kLT8)
VADD_MKL_FLOAT(jit::avx2, kLT8)
VADD_MKL_FLOAT(jit::avx512f, kLT8)
......@@ -210,13 +209,13 @@ VADD_INTRI8_FLOAT(jit::avx)
// avx2 > mkl > for
#ifdef __AVX2__
VADD_INTRI8_FLOAT(jit::avx2)
#elif defined PADDLE_USE_MKLML
#elif defined PADDLE_WITH_MKLML
VADD_MKL_FLOAT(jit::avx2, kEQ8)
#endif
// TODO(TJ): test and complete avx512
/// eq16
#ifdef PADDLE_USE_MKLML
#ifdef PADDLE_WITH_MKLML
// TODO(TJ): test and complete me
VADD_MKL_FLOAT(jit::avx, kEQ16)
VADD_MKL_FLOAT(jit::avx2, kEQ16)
......
......@@ -20,6 +20,14 @@ limitations under the License. */
#include "glog/logging.h"
#include "gtest/gtest.h"
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef __AVX__
#include <immintrin.h>
#endif
inline double GetCurrentUS() {
struct timeval time;
gettimeofday(&time, NULL);
......@@ -38,17 +46,26 @@ void RandomVec(const int n, T* a) {
}
}
constexpr int repeat = 10000;
constexpr int repeat = 20000;
TEST(JitKernel, vmul) {
namespace jit = paddle::operators::math::jitkernel;
#if defined __AVX__ || defined __AVX2__
// Intrinsic baseline for the benchmark: multiplies one 256-bit vector
// (8 packed floats) loaded from x with one loaded from y and stores the
// products to z. Unaligned load/store intrinsics are used, so the pointers
// need no particular alignment.
// NOTE: n is never read; exactly 8 elements are processed regardless of its
// value, so every buffer must hold at least 8 floats.
void vmul_intri(const int n, const float* x, const float* y, float* z) {
  const __m256 lhs = _mm256_loadu_ps(x);
  const __m256 rhs = _mm256_loadu_ps(y);
  _mm256_storeu_ps(z, _mm256_mul_ps(lhs, rhs));
}
#endif
// Scalar reference kernel: stores x[i] * y[i] into z[i] for each i in [0, n),
// in ascending index order. A non-positive n leaves z untouched.
auto ref = [](const int n, const float* x, const float* y, float* z) {
  int idx = 0;
  while (idx < n) {
    z[idx] = x[idx] * y[idx];
    ++idx;
  }
};
// Plain scalar element-wise multiply: z[i] = x[i] * y[i] for i in [0, n),
// visited in ascending order. The test uses its output as the expected values
// that the kernel result is compared against. A non-positive n is a no-op.
void vmul_ref(const int n, const float* x, const float* y, float* z) {
  for (int remaining = n; remaining > 0; --remaining) {
    *z++ = *x++ * *y++;
  }
}
TEST(JitKernel, vmul) {
namespace jit = paddle::operators::math::jitkernel;
for (int d : {7, 8, 15, 16, 30, 256}) {
std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d);
......@@ -61,18 +78,42 @@ TEST(JitKernel, vmul) {
const float* y_data = y.data();
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
#ifdef PADDLE_WITH_MKLML
auto s0 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
paddle::platform::dynload::vsMul(d, x_data, y_data, zref_data);
}
#endif
auto st = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(d, x_data, y_data, ztgt_data);
}
auto mt = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ref(d, x_data, y_data, zref_data);
vmul_ref(d, x_data, y_data, zref_data);
}
auto et = GetCurrentUS();
#if defined __AVX__ || defined __AVX2__
if (d == 8) {
auto si0 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vmul_intri(d, x_data, y_data, zref_data);
}
auto si1 = GetCurrentUS();
VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat;
}
#endif
VLOG(3) << "Vec size " << d << ": refer takes: " << (et - mt) / repeat
<< " us, tgt takes: " << (mt - st) / repeat;
<< " us, tgt takes: " << (mt - st) / repeat
#ifdef PADDLE_WITH_MKLML
<< " us, mkl takes: " << (st - s0) / repeat << " us";
#else
<< " us";
#endif
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册