提交 b3c63f40 编写于 作者: T tensor-tang

add vscal and unit test

上级 0987f2b4
...@@ -75,6 +75,13 @@ class VAddKernel : public Kernel { ...@@ -75,6 +75,13 @@ class VAddKernel : public Kernel {
virtual void Compute(const int n, const T *x, const T *y, T *z) = 0; virtual void Compute(const int n, const T *x, const T *y, T *z) = 0;
}; };
template <typename T>
class VScalKernel : public Kernel {
public:
virtual void Compute(const int n, const T a, const T *x, T *y) = 0;
virtual void Compute(const int n, const T a, T *x) = 0;
};
template <typename T> template <typename T>
class LSTMKernel : public Kernel { class LSTMKernel : public Kernel {
public: public:
......
...@@ -206,8 +206,84 @@ VADD_INTRI8_FLOAT(jit::avx512f); ...@@ -206,8 +206,84 @@ VADD_INTRI8_FLOAT(jit::avx512f);
#undef VADD_MKL_FLOAT #undef VADD_MKL_FLOAT
#undef VADD_MKL_DOUBLE #undef VADD_MKL_DOUBLE
/* VSCAL JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VScalKernelImpl : public VScalKernel<T> {
public:
void Compute(const int n, const T a, const T* x, T* y) override {
for (int i = 0; i < n; ++i) {
y[i] = a * x[i];
}
}
void Compute(const int n, const T a, T* x) override {
for (int i = 0; i < n; ++i) {
x[i] = a * x[i];
}
}
};
#ifdef PADDLE_WITH_MKLML
#define VSCAL_MKL_FLOAT(isa, block) \
template <> \
void VScalKernelImpl<float, isa, block>::Compute(const int n, const float a, \
float* x) { \
platform::dynload::cblas_sscal(n, a, x, 1); \
}
#define VSCAL_MKL_DOUBLE(isa, block) \
template <> \
void VScalKernelImpl<double, isa, block>::Compute( \
const int n, const double a, double* x) { \
platform::dynload::cblas_dscal(n, a, x, 1); \
}
FOR_EACH_ISA(VSCAL_MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(VSCAL_MKL_DOUBLE);
#endif
#define VSCAL_INTRI8(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a, \
const float* x, float* y) { \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(y, tmp); \
}
#define VSCAL_INTRI8_INPLACE(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a, \
float* x) { \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(x, tmp); \
}
#ifdef __AVX__
VSCAL_INTRI8(jit::avx);
VSCAL_INTRI8_INPLACE(jit::avx);
#endif
#ifdef __AVX2__
VSCAL_INTRI8(jit::avx2);
VSCAL_INTRI8_INPLACE(jit::avx2);
#endif
#ifdef __AVX512F__
VSCAL_INTRI8(jit::avx512f);
VSCAL_INTRI8_INPLACE(jit::avx512f);
#endif
// TODO(TJ): eq16 test and complete avx512
#undef VSCAL_INTRI8
#undef VSCAL_INTRI8_INPLACE
#undef VSCAL_MKL_FLOAT
#undef VSCAL_MKL_DOUBLE
REGISTER_BLAS_JITKERNEL(vmul, VMulKernel); REGISTER_BLAS_JITKERNEL(vmul, VMulKernel);
REGISTER_BLAS_JITKERNEL(vadd, VAddKernel); REGISTER_BLAS_JITKERNEL(vadd, VAddKernel);
REGISTER_BLAS_JITKERNEL(vscal, VScalKernel);
#undef FOR_EACH_ISA #undef FOR_EACH_ISA
#undef FOR_EACH_BLOCK #undef FOR_EACH_BLOCK
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/jit_kernel.h"
#include <sys/time.h> #include <sys/time.h>
#include <cstring>
#include <string> #include <string>
#include <vector> #include <vector>
#include "gflags/gflags.h" #include "gflags/gflags.h"
...@@ -28,6 +29,8 @@ limitations under the License. */ ...@@ -28,6 +29,8 @@ limitations under the License. */
#include <immintrin.h> #include <immintrin.h>
#endif #endif
constexpr int repeat = 20000;
inline double GetCurrentUS() { inline double GetCurrentUS() {
struct timeval time; struct timeval time;
gettimeofday(&time, NULL); gettimeofday(&time, NULL);
...@@ -46,7 +49,113 @@ void RandomVec(const int n, T* a) { ...@@ -46,7 +49,113 @@ void RandomVec(const int n, T* a) {
} }
} }
constexpr int repeat = 20000; void vscal_ref(const int n, const float a, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = a * x[i];
}
}
void vscal_inp_ref(const int n, const float a, float* x) {
for (int i = 0; i < n; ++i) {
x[i] = a * x[i];
}
}
#if defined __AVX__ || defined __AVX2__
void vscal_intri8(const int n, const float a, const float* x, float* y) {
__m256 tmp;
__m256 scalar = _mm256_set1_ps(a);
tmp = _mm256_loadu_ps(x);
tmp = _mm256_mul_ps(tmp, scalar);
_mm256_storeu_ps(y, tmp);
}
void vscal_inp_intri8(const int n, const float a, float* x) {
__m256 tmp;
__m256 scalar = _mm256_set1_ps(a);
tmp = _mm256_loadu_ps(x);
tmp = _mm256_mul_ps(tmp, scalar);
_mm256_storeu_ps(x, tmp);
}
#endif
#ifdef PADDLE_WITH_MKLML
void vscal_inp_mkl(const int n, const float a, float* x) {
paddle::platform::dynload::cblas_sscal(n, a, x, 1);
}
#endif
TEST(JitKernel, vscal) {
namespace jit = paddle::operators::math::jitkernel;
for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data());
std::memcpy(y.data(), x.data(), sizeof(float) * d);
float a = 2.f;
const auto& ker =
jit::KernelPool::Instance().template Get<jit::VScalKernel<float>>(d);
const float* x_data = x.data();
float* y_data = y.data();
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vscal_ref(d, a, x_data, zref_data);
}
auto trefe = GetCurrentUS();
auto trefs1 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vscal_inp_ref(d, a, y_data);
}
auto trefe1 = GetCurrentUS();
#ifdef PADDLE_WITH_MKLML
auto tmkls = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vscal_inp_mkl(d, a, y_data);
}
auto tmkle = GetCurrentUS();
#endif
#if defined __AVX__ || defined __AVX2__
if (d == 8) {
auto si0 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vscal_intri8(d, a, x_data, zref_data);
}
auto si1 = GetCurrentUS();
auto si2 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vscal_inp_intri8(d, a, y_data);
}
auto si3 = GetCurrentUS();
VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat
<< " us, inplace: " << (si3 - si2) / repeat;
}
#endif
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(d, a, x_data, ztgt_data);
}
auto ttgte = GetCurrentUS();
auto ttgts1 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(d, a, y_data);
}
auto ttgte1 = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
<< " us, inplace takes: " << (trefe1 - trefs1) / repeat
#ifdef PADDLE_WITH_MKLML
<< " us, mkl inplace takes: " << (tmkle - tmkls) / repeat << " us, "
#else
<< " us, "
#endif
<< "tgt takes: " << (ttgte - ttgts) / repeat
<< "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat;
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
}
}
}
void vmul_ref(const int n, const float* x, const float* y, float* z) { void vmul_ref(const int n, const float* x, const float* y, float* z) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册