提交 6c986e12 编写于 作者: T tensor-tang

fix macro and add vmul unit test

上级 8c69764d
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/jit_kernel.h"
#include <string> #include <string>
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h" #include "paddle/fluid/platform/dynload/mklml.h"
#endif #endif
...@@ -62,7 +61,7 @@ namespace jit = platform::jit; ...@@ -62,7 +61,7 @@ namespace jit = platform::jit;
FOR_EACH_COMMON_BLOCK(macro_, jit::avx512f) \ FOR_EACH_COMMON_BLOCK(macro_, jit::avx512f) \
FOR_EACH_COMMON_BLOCK(macro_, jit::avx2) \ FOR_EACH_COMMON_BLOCK(macro_, jit::avx2) \
FOR_EACH_COMMON_BLOCK(macro_, jit::avx) \ FOR_EACH_COMMON_BLOCK(macro_, jit::avx) \
FOR_EACH_COMMON_BLOCK(macro_, jit::any) FOR_EACH_COMMON_BLOCK(macro_, jit::isa_any)
#define FOR_EACH_ALL_BLOCK(macro_, isa) \ #define FOR_EACH_ALL_BLOCK(macro_, isa) \
macro_(isa, kLT8) macro_(isa, kEQ8) macro_(isa, kGT8LT16) macro_(isa, kEQ16) \ macro_(isa, kLT8) macro_(isa, kEQ8) macro_(isa, kGT8LT16) macro_(isa, kEQ16) \
...@@ -72,7 +71,7 @@ namespace jit = platform::jit; ...@@ -72,7 +71,7 @@ namespace jit = platform::jit;
FOR_EACH_ALL_BLOCK(macro_, jit::avx512f) \ FOR_EACH_ALL_BLOCK(macro_, jit::avx512f) \
FOR_EACH_ALL_BLOCK(macro_, jit::avx2) \ FOR_EACH_ALL_BLOCK(macro_, jit::avx2) \
FOR_EACH_ALL_BLOCK(macro_, jit::avx) \ FOR_EACH_ALL_BLOCK(macro_, jit::avx) \
FOR_EACH_ALL_BLOCK(macro_, jit::any) FOR_EACH_ALL_BLOCK(macro_, jit::isa_any)
#define BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, ker_dtype) \ #define BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, ker_dtype) \
template <> \ template <> \
...@@ -92,7 +91,7 @@ static void VMulCompute(const int n, const T* x, const T* y, T* z) { ...@@ -92,7 +91,7 @@ static void VMulCompute(const int n, const T* x, const T* y, T* z) {
} }
} }
#ifdef PADDLE_USE_MKLML #ifdef PADDLE_WITH_MKLML
#define VMUL_MKL_FLOAT(isa, block) \ #define VMUL_MKL_FLOAT(isa, block) \
template <> \ template <> \
void VMulCompute<float, isa, block>(const int n, const float* x, \ void VMulCompute<float, isa, block>(const int n, const float* x, \
...@@ -103,7 +102,7 @@ static void VMulCompute(const int n, const T* x, const T* y, T* z) { ...@@ -103,7 +102,7 @@ static void VMulCompute(const int n, const T* x, const T* y, T* z) {
#define VMUL_MKL_DOUBLE(isa, block) \ #define VMUL_MKL_DOUBLE(isa, block) \
template <> \ template <> \
void VMulCompute<double, isa, block>(const int n, const double* x, \ void VMulCompute<double, isa, block>(const int n, const double* x, \
const double* y, float* z) { \ const double* y, double* z) { \
platform::dynload::vdMul(n, x, y, z); \ platform::dynload::vdMul(n, x, y, z); \
} }
...@@ -112,7 +111,7 @@ FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE) ...@@ -112,7 +111,7 @@ FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE)
#endif #endif
/// lt8 /// lt8
#ifdef PADDLE_USE_MKLML #ifdef PADDLE_WITH_MKLML
VMUL_MKL_FLOAT(jit::avx2, kLT8) VMUL_MKL_FLOAT(jit::avx2, kLT8)
VMUL_MKL_FLOAT(jit::avx512f, kLT8) VMUL_MKL_FLOAT(jit::avx512f, kLT8)
#endif #endif
...@@ -130,21 +129,21 @@ VMUL_MKL_FLOAT(jit::avx512f, kLT8) ...@@ -130,21 +129,21 @@ VMUL_MKL_FLOAT(jit::avx512f, kLT8)
} }
// mkl > avx > for, ">" means better // mkl > avx > for, ">" means better
#ifdef PADDLE_USE_MKLML #ifdef PADDLE_WITH_MKLML
VMUL_MKL_FLOAT(jit::avx, kEQ8) VMUL_MKL_FLOAT(jit::avx, kEQ8);
#elif defined __AVX__ #elif defined __AVX__
VMUL_INTRI8_FLOAT(jit::avx) VMUL_INTRI8_FLOAT(jit::avx);
#endif #endif
// avx2 > mkl > for // avx2 > mkl > for
#ifdef __AVX2__ #ifdef __AVX2__
VMUL_INTRI8_FLOAT(jit::avx2) VMUL_INTRI8_FLOAT(jit::avx2)
#elif defined PADDLE_USE_MKLML #elif defined PADDLE_WITH_MKLML
VMUL_MKL_FLOAT(jit::avx2, kEQ8) VMUL_MKL_FLOAT(jit::avx2, kEQ8)
#endif #endif
// TODO(TJ): test and complete avx512 // TODO(TJ): test and complete avx512
/// eq16 /// eq16
#ifdef PADDLE_USE_MKLML #ifdef PADDLE_WITH_MKLML
// TODO(TJ): test and complete me // TODO(TJ): test and complete me
VMUL_MKL_FLOAT(jit::avx, kEQ16) VMUL_MKL_FLOAT(jit::avx, kEQ16)
VMUL_MKL_FLOAT(jit::avx2, kEQ16) VMUL_MKL_FLOAT(jit::avx2, kEQ16)
...@@ -163,7 +162,7 @@ static void VAddCompute(const int n, const T* x, const T* y, T* z) { ...@@ -163,7 +162,7 @@ static void VAddCompute(const int n, const T* x, const T* y, T* z) {
} }
} }
#ifdef PADDLE_USE_MKLML #ifdef PADDLE_WITH_MKLML
#define VADD_MKL_FLOAT(isa, block) \ #define VADD_MKL_FLOAT(isa, block) \
template <> \ template <> \
void VAddCompute<float, isa, block>(const int n, const float* x, \ void VAddCompute<float, isa, block>(const int n, const float* x, \
...@@ -174,7 +173,7 @@ static void VAddCompute(const int n, const T* x, const T* y, T* z) { ...@@ -174,7 +173,7 @@ static void VAddCompute(const int n, const T* x, const T* y, T* z) {
#define VADD_MKL_DOUBLE(isa, block) \ #define VADD_MKL_DOUBLE(isa, block) \
template <> \ template <> \
void VAddCompute<double, isa, block>(const int n, const double* x, \ void VAddCompute<double, isa, block>(const int n, const double* x, \
const double* y, float* z) { \ const double* y, double* z) { \
platform::dynload::vdAdd(n, x, y, z); \ platform::dynload::vdAdd(n, x, y, z); \
} }
...@@ -183,7 +182,7 @@ FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE) ...@@ -183,7 +182,7 @@ FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE)
#endif #endif
/// lt8 /// lt8
#ifdef PADDLE_USE_MKLML #ifdef PADDLE_WITH_MKLML
VADD_MKL_FLOAT(jit::avx, kLT8) VADD_MKL_FLOAT(jit::avx, kLT8)
VADD_MKL_FLOAT(jit::avx2, kLT8) VADD_MKL_FLOAT(jit::avx2, kLT8)
VADD_MKL_FLOAT(jit::avx512f, kLT8) VADD_MKL_FLOAT(jit::avx512f, kLT8)
...@@ -210,13 +209,13 @@ VADD_INTRI8_FLOAT(jit::avx) ...@@ -210,13 +209,13 @@ VADD_INTRI8_FLOAT(jit::avx)
// avx2 > mkl > for // avx2 > mkl > for
#ifdef __AVX2__ #ifdef __AVX2__
VADD_INTRI8_FLOAT(jit::avx2) VADD_INTRI8_FLOAT(jit::avx2)
#elif defined PADDLE_USE_MKLML #elif defined PADDLE_WITH_MKLML
VADD_MKL_FLOAT(jit::avx2, kEQ8) VADD_MKL_FLOAT(jit::avx2, kEQ8)
#endif #endif
// TODO(TJ): test and complete avx512 // TODO(TJ): test and complete avx512
/// eq16 /// eq16
#ifdef PADDLE_USE_MKLML #ifdef PADDLE_WITH_MKLML
// TODO(TJ): test and complete me // TODO(TJ): test and complete me
VADD_MKL_FLOAT(jit::avx, kEQ16) VADD_MKL_FLOAT(jit::avx, kEQ16)
VADD_MKL_FLOAT(jit::avx2, kEQ16) VADD_MKL_FLOAT(jit::avx2, kEQ16)
......
...@@ -20,6 +20,14 @@ limitations under the License. */ ...@@ -20,6 +20,14 @@ limitations under the License. */
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef __AVX__
#include <immintrin.h>
#endif
inline double GetCurrentUS() { inline double GetCurrentUS() {
struct timeval time; struct timeval time;
gettimeofday(&time, NULL); gettimeofday(&time, NULL);
...@@ -38,17 +46,26 @@ void RandomVec(const int n, T* a) { ...@@ -38,17 +46,26 @@ void RandomVec(const int n, T* a) {
} }
} }
constexpr int repeat = 10000; constexpr int repeat = 20000;
TEST(JitKernel, vmul) { #if defined __AVX__ || defined __AVX2__
// vmul_intri: AVX intrinsic baseline used by the vmul test for timing.
// Loads one __m256 vector from x and one from y (unaligned loads), multiplies
// them element-wise, and stores the single resulting vector to z.
// The parameter n is not read by the body; the work is always exactly one
// 256-bit vector, and the test in this file only calls it when d == 8.
namespace jit = paddle::operators::math::jitkernel; void vmul_intri(const int n, const float* x, const float* y, float* z) {
__m256 tmpx, tmpy;
// Unaligned loads: x and y come from std::vector<float>::data() in the test.
tmpx = _mm256_loadu_ps(x);
tmpy = _mm256_loadu_ps(y);
// tmpx is reused to hold the product before it is written out.
tmpx = _mm256_mul_ps(tmpx, tmpy);
_mm256_storeu_ps(z, tmpx);
}
#endif
// vmul_ref: scalar reference implementation of element-wise multiply.
// For every i in [0, n) it writes z[i] = x[i] * y[i]; when n <= 0 the loop
// body never runs and z is left untouched.
// The test uses its output (zref) as the expected value that the jit kernel
// result (ztgt) is compared against with EXPECT_NEAR.
auto ref = [](const int n, const float* x, const float* y, float* z) { void vmul_ref(const int n, const float* x, const float* y, float* z) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i]; z[i] = x[i] * y[i];
} }
}; }
TEST(JitKernel, vmul) {
namespace jit = paddle::operators::math::jitkernel;
for (int d : {7, 8, 15, 16, 30, 256}) { for (int d : {7, 8, 15, 16, 30, 256}) {
std::vector<float> x(d), y(d); std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
...@@ -61,18 +78,42 @@ TEST(JitKernel, vmul) { ...@@ -61,18 +78,42 @@ TEST(JitKernel, vmul) {
const float* y_data = y.data(); const float* y_data = y.data();
float* ztgt_data = ztgt.data(); float* ztgt_data = ztgt.data();
float* zref_data = zref.data(); float* zref_data = zref.data();
#ifdef PADDLE_WITH_MKLML
auto s0 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
paddle::platform::dynload::vsMul(d, x_data, y_data, zref_data);
}
#endif
auto st = GetCurrentUS(); auto st = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->Compute(d, x_data, y_data, ztgt_data); ker->Compute(d, x_data, y_data, ztgt_data);
} }
auto mt = GetCurrentUS(); auto mt = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ref(d, x_data, y_data, zref_data); vmul_ref(d, x_data, y_data, zref_data);
} }
auto et = GetCurrentUS(); auto et = GetCurrentUS();
#if defined __AVX__ || defined __AVX2__
if (d == 8) {
auto si0 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vmul_intri(d, x_data, y_data, zref_data);
}
auto si1 = GetCurrentUS();
VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat;
}
#endif
VLOG(3) << "Vec size " << d << ": refer takes: " << (et - mt) / repeat VLOG(3) << "Vec size " << d << ": refer takes: " << (et - mt) / repeat
<< " us, tgt takes: " << (mt - st) / repeat; << " us, tgt takes: " << (mt - st) / repeat
#ifdef PADDLE_WITH_MKLML
<< " us, mkl takes: " << (st - s0) / repeat << " us";
#else
<< " us";
#endif
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册