提交 0987f2b4 编写于 作者: T tensor-tang

add vadd unit test

上级 3d928d4f
...@@ -75,25 +75,24 @@ namespace jit = platform::jit; ...@@ -75,25 +75,24 @@ namespace jit = platform::jit;
DEFINE_WITH_DTYPE(ker_key, ker_class, float, f); \ DEFINE_WITH_DTYPE(ker_key, ker_class, float, f); \
DEFINE_WITH_DTYPE(ker_key, ker_class, double, d) DEFINE_WITH_DTYPE(ker_key, ker_class, double, d)
// do not include lt8, eq8, eq16 #define FOR_EACH_ISA(macro_, block) \
#define FOR_EACH_COMMON_BLOCK(macro_, isa) \ macro_(jit::avx512f, block); \
macro_(isa, kGT8LT16) macro_(isa, kGT16) macro_(jit::avx2, block); \
macro_(jit::avx, block); \
#define FOR_EACH_ISA_COMMON_BLOCK(macro_) \ macro_(jit::isa_any, block)
FOR_EACH_COMMON_BLOCK(macro_, jit::avx512f) \
FOR_EACH_COMMON_BLOCK(macro_, jit::avx2) \ #define FOR_EACH_BLOCK(macro_, isa) \
FOR_EACH_COMMON_BLOCK(macro_, jit::avx) \ macro_(isa, kLT8); \
FOR_EACH_COMMON_BLOCK(macro_, jit::isa_any) macro_(isa, kEQ8); \
macro_(isa, kGT8LT16); \
#define FOR_EACH_ALL_BLOCK(macro_, isa) \ macro_(isa, kEQ16); \
macro_(isa, kLT8) macro_(isa, kEQ8) macro_(isa, kGT8LT16) macro_(isa, kEQ16) \ macro_(isa, kGT16)
macro_(isa, kGT16)
#define FOR_EACH_ISA_BLOCK(macro_) \
#define FOR_EACH_ISA_ALL_BLOCK(macro_) \ FOR_EACH_BLOCK(macro_, jit::avx512f); \
FOR_EACH_ALL_BLOCK(macro_, jit::avx512f) \ FOR_EACH_BLOCK(macro_, jit::avx2); \
FOR_EACH_ALL_BLOCK(macro_, jit::avx2) \ FOR_EACH_BLOCK(macro_, jit::avx); \
FOR_EACH_ALL_BLOCK(macro_, jit::avx) \ FOR_EACH_BLOCK(macro_, jit::isa_any)
FOR_EACH_ALL_BLOCK(macro_, jit::isa_any)
/* VMUL JitKernel */ /* VMUL JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block> template <typename T, platform::jit::cpu_isa_t isa, jit_block>
...@@ -121,8 +120,8 @@ class VMulKernelImpl : public VMulKernel<T> { ...@@ -121,8 +120,8 @@ class VMulKernelImpl : public VMulKernel<T> {
platform::dynload::vdMul(n, x, y, z); \ platform::dynload::vdMul(n, x, y, z); \
} }
FOR_EACH_ISA_COMMON_BLOCK(VMUL_MKL_FLOAT); FOR_EACH_ISA(VMUL_MKL_FLOAT, kGT16);
FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE); FOR_EACH_ISA_BLOCK(VMUL_MKL_DOUBLE);
#endif #endif
#define VMUL_INTRI8_FLOAT(isa) \ #define VMUL_INTRI8_FLOAT(isa) \
...@@ -178,8 +177,8 @@ class VAddKernelImpl : public VAddKernel<T> { ...@@ -178,8 +177,8 @@ class VAddKernelImpl : public VAddKernel<T> {
platform::dynload::vdAdd(n, x, y, z); \ platform::dynload::vdAdd(n, x, y, z); \
} }
FOR_EACH_ISA_COMMON_BLOCK(VADD_MKL_FLOAT); FOR_EACH_ISA(VADD_MKL_FLOAT, kGT16);
FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE); FOR_EACH_ISA_BLOCK(VADD_MKL_DOUBLE);
#endif #endif
#define VADD_INTRI8_FLOAT(isa) \ #define VADD_INTRI8_FLOAT(isa) \
...@@ -210,10 +209,9 @@ VADD_INTRI8_FLOAT(jit::avx512f); ...@@ -210,10 +209,9 @@ VADD_INTRI8_FLOAT(jit::avx512f);
REGISTER_BLAS_JITKERNEL(vmul, VMulKernel); REGISTER_BLAS_JITKERNEL(vmul, VMulKernel);
REGISTER_BLAS_JITKERNEL(vadd, VAddKernel); REGISTER_BLAS_JITKERNEL(vadd, VAddKernel);
#undef FOR_EACH_ISA_ALL_BLOCK #undef FOR_EACH_ISA
#undef FOR_EACH_ALL_BLOCK #undef FOR_EACH_BLOCK
#undef FOR_EACH_ISA_COMMON_BLOCK #undef FOR_EACH_ISA_BLOCK
#undef FOR_EACH_COMMON_BLOCK
#undef REGISTER_BLAS_JITKERNEL #undef REGISTER_BLAS_JITKERNEL
#undef DEFINE_WITH_DTYPE #undef DEFINE_WITH_DTYPE
#undef SEARCH_ISA_BLOCK #undef SEARCH_ISA_BLOCK
......
...@@ -79,12 +79,10 @@ TEST(JitKernel, vmul) { ...@@ -79,12 +79,10 @@ TEST(JitKernel, vmul) {
RandomVec<float>(d, y.data()); RandomVec<float>(d, y.data());
const auto& ker = const auto& ker =
jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(d); jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(d);
const float* x_data = x.data(); const float* x_data = x.data();
const float* y_data = y.data(); const float* y_data = y.data();
float* ztgt_data = ztgt.data(); float* ztgt_data = ztgt.data();
float* zref_data = zref.data(); float* zref_data = zref.data();
auto trefs = GetCurrentUS(); auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vmul_ref(d, x_data, y_data, zref_data); vmul_ref(d, x_data, y_data, zref_data);
...@@ -129,6 +127,85 @@ TEST(JitKernel, vmul) { ...@@ -129,6 +127,85 @@ TEST(JitKernel, vmul) {
} }
} }
// Scalar reference implementation of elementwise vector addition:
// z[i] = x[i] + y[i] for i in [0, n). Used as ground truth and as the
// timing baseline in the benchmark below.
void vadd_ref(const int n, const float* x, const float* y, float* z) {
  for (int idx = 0; idx != n; ++idx) {
    z[idx] = y[idx] + x[idx];
  }
}
#if defined __AVX__ || defined __AVX2__
// AVX baseline: adds exactly eight packed floats with unaligned
// loads/stores. `n` matches the common kernel signature but is not read;
// the benchmark only dispatches here when d == 8.
void vadd_intri8(const int n, const float* x, const float* y, float* z) {
  const __m256 vx = _mm256_loadu_ps(x);
  const __m256 vy = _mm256_loadu_ps(y);
  _mm256_storeu_ps(z, _mm256_add_ps(vx, vy));
}
#endif
#ifdef PADDLE_WITH_MKLML
// MKL baseline: forwards to the dynamically loaded single-precision
// vector-add routine (vsAdd) so the benchmark can compare it against
// the JIT kernel and the scalar reference.
void vadd_mkl(const int n, const float* x, const float* y, float* z) {
  paddle::platform::dynload::vsAdd(n, x, y, z);
}
#endif
// Benchmarks and verifies the VAdd JIT kernel: for each size d it times the
// scalar reference, the MKL path (if built), the AVX intrinsic path (d == 8
// only), and the pooled JIT kernel, then checks the JIT output elementwise.
TEST(JitKernel, vadd) {
  namespace jit = paddle::operators::math::jitkernel;
  // Sizes straddle the SIMD block boundaries (below/at/above 8 and 16).
  for (int d : {7, 8, 15, 16, 30, 256, 512}) {
    std::vector<float> x(d), y(d);
    std::vector<float> zref(d), ztgt(d);
    RandomVec<float>(d, x.data());
    RandomVec<float>(d, y.data());
    const auto& ker =
        jit::KernelPool::Instance().template Get<jit::VAddKernel<float>>(d);
    const float* x_data = x.data();
    const float* y_data = y.data();
    float* ztgt_data = ztgt.data();
    float* zref_data = zref.data();
    // Time the scalar reference.
    auto trefs = GetCurrentUS();
    for (int i = 0; i < repeat; ++i) {
      vadd_ref(d, x_data, y_data, zref_data);
    }
    auto trefe = GetCurrentUS();
#ifdef PADDLE_WITH_MKLML
    // Time the MKL baseline. NOTE(review): it also writes zref_data, so the
    // final check compares against the last writer — harmless here because
    // every path computes the same sums.
    auto tmkls = GetCurrentUS();
    for (int i = 0; i < repeat; ++i) {
      vadd_mkl(d, x_data, y_data, zref_data);
    }
    auto tmkle = GetCurrentUS();
#endif
#if defined __AVX__ || defined __AVX2__
    // The hand-written intrinsic only handles exactly 8 elements.
    if (d == 8) {
      auto si0 = GetCurrentUS();
      for (int i = 0; i < repeat; ++i) {
        vadd_intri8(d, x_data, y_data, zref_data);
      }
      auto si1 = GetCurrentUS();
      VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat;
    }
#endif
    // Time the JIT kernel under test.
    auto ttgts = GetCurrentUS();
    for (int i = 0; i < repeat; ++i) {
      ker->Compute(d, x_data, y_data, ztgt_data);
    }
    auto ttgte = GetCurrentUS();
    VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
#ifdef PADDLE_WITH_MKLML
            << " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, "
#else
            << " us, "
#endif
            << "tgt takes: " << (ttgte - ttgts) / repeat;
    // Correctness: JIT output must match the reference within tolerance.
    for (int i = 0; i < d; ++i) {
      EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
    }
  }
}
TEST(JitKernel, pool) { TEST(JitKernel, pool) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
const int frame_size = 4; const int frame_size = 4;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册