diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 3e75fd1137d7770f35b54fc42d89af6e9cbe7aca..9cb15f9bdb20a7c93d72156ddbb5c920f8c0fe8d 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -75,6 +75,13 @@ class VAddKernel : public Kernel { virtual void Compute(const int n, const T *x, const T *y, T *z) = 0; }; +template +class VScalKernel : public Kernel { + public: + virtual void Compute(const int n, const T a, const T *x, T *y) = 0; + virtual void Compute(const int n, const T a, T *x) = 0; +}; + template class LSTMKernel : public Kernel { public: diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 15f8bf7145cf9db24be2173b6eb8d134bd13ad89..0ec9ac10c810d8884724aabec20e4b4423b709ab 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -206,8 +206,84 @@ VADD_INTRI8_FLOAT(jit::avx512f); #undef VADD_MKL_FLOAT #undef VADD_MKL_DOUBLE +/* VSCAL JitKernel */ +template +class VScalKernelImpl : public VScalKernel { + public: + void Compute(const int n, const T a, const T* x, T* y) override { + for (int i = 0; i < n; ++i) { + y[i] = a * x[i]; + } + } + void Compute(const int n, const T a, T* x) override { + for (int i = 0; i < n; ++i) { + x[i] = a * x[i]; + } + } +}; + +#ifdef PADDLE_WITH_MKLML +#define VSCAL_MKL_FLOAT(isa, block) \ + template <> \ + void VScalKernelImpl::Compute(const int n, const float a, \ + float* x) { \ + platform::dynload::cblas_sscal(n, a, x, 1); \ + } + +#define VSCAL_MKL_DOUBLE(isa, block) \ + template <> \ + void VScalKernelImpl::Compute( \ + const int n, const double a, double* x) { \ + platform::dynload::cblas_dscal(n, a, x, 1); \ + } + +FOR_EACH_ISA(VSCAL_MKL_FLOAT, kGT16); +FOR_EACH_ISA_BLOCK(VSCAL_MKL_DOUBLE); +#endif + +#define VSCAL_INTRI8(isa) \ + template <> \ + void VScalKernelImpl::Compute(const int n, const float a, \ + const float* x, float* y) { \ + __m256 tmp; \ + __m256 scalar = _mm256_set1_ps(a); \ + tmp = _mm256_loadu_ps(x); \ + tmp = _mm256_mul_ps(tmp, scalar); \ + _mm256_storeu_ps(y, tmp); \ + } +#define VSCAL_INTRI8_INPLACE(isa) \ + template <> \ + void VScalKernelImpl::Compute(const int n, const float a, \ + float* x) { \ + __m256 tmp; \ + __m256 scalar = _mm256_set1_ps(a); \ + tmp = _mm256_loadu_ps(x); \ + tmp = _mm256_mul_ps(tmp, scalar); \ + _mm256_storeu_ps(x, tmp); \ + } + +#ifdef __AVX__ +VSCAL_INTRI8(jit::avx); +VSCAL_INTRI8_INPLACE(jit::avx); +#endif +#ifdef __AVX2__ +VSCAL_INTRI8(jit::avx2); +VSCAL_INTRI8_INPLACE(jit::avx2); +#endif +#ifdef __AVX512F__ +VSCAL_INTRI8(jit::avx512f); +VSCAL_INTRI8_INPLACE(jit::avx512f); +#endif +// TODO(TJ): eq16 test and complete avx512 + +#undef VSCAL_INTRI8 +#undef VSCAL_INTRI8_INPLACE +#undef VSCAL_MKL_FLOAT +#undef VSCAL_MKL_DOUBLE + REGISTER_BLAS_JITKERNEL(vmul, VMulKernel); REGISTER_BLAS_JITKERNEL(vadd, VAddKernel); +REGISTER_BLAS_JITKERNEL(vscal, VScalKernel); #undef FOR_EACH_ISA #undef FOR_EACH_BLOCK diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 88437a050bf68583aef7900d176d0c1889d1aca1..ccd687d587d410d1649e77dd9478100b3494122d 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/jit_kernel.h" #include +#include #include #include #include "gflags/gflags.h" @@ -28,6 +29,8 @@ limitations under the License. */ #include #endif +constexpr int repeat = 20000; + inline double GetCurrentUS() { struct timeval time; gettimeofday(&time, NULL); @@ -46,7 +49,113 @@ void RandomVec(const int n, T* a) { } } -constexpr int repeat = 20000; +void vscal_ref(const int n, const float a, const float* x, float* y) { + for (int i = 0; i < n; ++i) { + y[i] = a * x[i]; + } +} +void vscal_inp_ref(const int n, const float a, float* x) { + for (int i = 0; i < n; ++i) { + x[i] = a * x[i]; + } +} +#if defined __AVX__ || defined __AVX2__ +void vscal_intri8(const int n, const float a, const float* x, float* y) { + __m256 tmp; + __m256 scalar = _mm256_set1_ps(a); + tmp = _mm256_loadu_ps(x); + tmp = _mm256_mul_ps(tmp, scalar); + _mm256_storeu_ps(y, tmp); +} +void vscal_inp_intri8(const int n, const float a, float* x) { + __m256 tmp; + __m256 scalar = _mm256_set1_ps(a); + tmp = _mm256_loadu_ps(x); + tmp = _mm256_mul_ps(tmp, scalar); + _mm256_storeu_ps(x, tmp); +} +#endif + +#ifdef PADDLE_WITH_MKLML +void vscal_inp_mkl(const int n, const float a, float* x) { + paddle::platform::dynload::cblas_sscal(n, a, x, 1); +} +#endif + +TEST(JitKernel, vscal) { + namespace jit = paddle::operators::math::jitkernel; + for (int d : {7, 8, 15, 16, 30, 256, 512}) { + std::vector x(d), y(d); + std::vector zref(d), ztgt(d); + RandomVec(d, x.data()); + std::memcpy(y.data(), x.data(), sizeof(float) * d); + float a = 2.f; + const auto& ker = + jit::KernelPool::Instance().template Get>(d); + const float* x_data = x.data(); + float* y_data = y.data(); + float* ztgt_data = ztgt.data(); + float* zref_data = zref.data(); + auto trefs = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vscal_ref(d, a, x_data, zref_data); + } + auto trefe = GetCurrentUS(); + auto trefs1 = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vscal_inp_ref(d, a, y_data); + } + auto trefe1 = GetCurrentUS(); + +#ifdef PADDLE_WITH_MKLML + auto tmkls = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vscal_inp_mkl(d, a, y_data); + } + auto tmkle = GetCurrentUS(); +#endif + +#if defined __AVX__ || defined __AVX2__ + if (d == 8) { + auto si0 = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vscal_intri8(d, a, x_data, zref_data); + } + auto si1 = GetCurrentUS(); + auto si2 = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vscal_inp_intri8(d, a, y_data); + } + auto si3 = GetCurrentUS(); + VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat + << " us, inplace: " << (si3 - si2) / repeat; + } +#endif + + auto ttgts = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ker->Compute(d, a, x_data, ztgt_data); + } + auto ttgte = GetCurrentUS(); + auto ttgts1 = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ker->Compute(d, a, y_data); + } + auto ttgte1 = GetCurrentUS(); + VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat + << " us, inplace takes: " << (trefe1 - trefs1) / repeat +#ifdef PADDLE_WITH_MKLML + << " us, mkl inplace takes: " << (tmkle - tmkls) / repeat << " us, " +#else + << " us, " +#endif + << "tgt takes: " << (ttgte - ttgts) / repeat + << "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat; + for (int i = 0; i < d; ++i) { + EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); + } + } +} void vmul_ref(const int n, const float* x, const float* y, float* z) { for (int i = 0; i < n; ++i) {