diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 24cf2aaf0bad28952a3b580ef43fe400d3a3cf8f..32944ae82c80f0628c63fc1f9ac32cd9dbb59208 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -82,6 +82,12 @@ class VScalKernel : public Kernel { virtual void Compute(const int n, const T a, T *x) const = 0; }; +template +class VAddBiasKernel : public Kernel { + public: + virtual void Compute(const int n, const T a, const T *x, T *y) const = 0; +}; + template class VExpKernel : public Kernel { public: diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 30761c0430d1920bee32e07cc57e2e6b5851b28e..d0ee97a43c6db4584193121e979a20e273acf099 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -216,9 +216,61 @@ INTRI8_INPLACE_FLOAT(jit::avx512f); #undef MKL_FLOAT #undef MKL_DOUBLE +/* VAddBias JitKernel */ +template +class VAddBiasKernelImpl : public VAddBiasKernel { + public: + void Compute(const int n, const T a, const T* x, T* y) const override { + for (int i = 0; i < n; ++i) { + y[i] = x[i] + a; + } + } +}; + +#define INTRI8_FLOAT(isa) \ + template <> \ + void VAddBiasKernelImpl::Compute( \ + const int n, const float a, const float* x, float* y) const { \ + __m256 tmp = _mm256_loadu_ps(x); \ + tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a)); \ + _mm256_storeu_ps(y, tmp); \ + } + +#define INTRI16_FLOAT(isa) \ + template <> \ + void VAddBiasKernelImpl::Compute( \ + const int n, const float a, const float* x, float* y) const { \ + __m256 tmp0 = _mm256_loadu_ps(x); \ + __m256 tmp1 = _mm256_loadu_ps(x + 8); \ + tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a)); \ + tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a)); \ + _mm256_storeu_ps(y, tmp0); \ + _mm256_storeu_ps(y + 8, tmp1); \ + } + +#ifdef __AVX__ +INTRI8_FLOAT(jit::avx); +INTRI16_FLOAT(jit::avx); +#endif +#ifdef __AVX2__ +INTRI8_FLOAT(jit::avx2); +INTRI16_FLOAT(jit::avx2); +#endif +#ifdef __AVX512F__ +INTRI8_FLOAT(jit::avx512f); +INTRI16_FLOAT(jit::avx512f); +#endif +// TODO(TJ): eq16 test and complete avx512 + +#undef INTRI8_FLOAT +#undef INTRI16_FLOAT +#undef INTRI_GT8LT16_FLOAT +#undef INTRI_GT16_FLOAT + REGISTER_JITKERNEL(vmul, VMulKernel); REGISTER_JITKERNEL(vadd, VAddKernel); REGISTER_JITKERNEL(vscal, VScalKernel); +REGISTER_JITKERNEL(vaddb, VAddBiasKernel); } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index 99527d02244418845351d1b946bd1d5110e45638..0717c2aeebfc113439368c59351b07c8311587cd 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -221,16 +221,15 @@ INTRI_GT16_FLOAT(jit::avx); #ifdef __AVX2__ INTRI8_FLOAT(jit::avx2); INTRI16_FLOAT(jit::avx2); -INTRI_GT8LT16_FLOAT(jit::avx2); -INTRI_GT16_FLOAT(jit::avx2); +// INTRI_GT8LT16_FLOAT(jit::avx2); +// INTRI_GT16_FLOAT(jit::avx2); #endif #ifdef __AVX512F__ INTRI8_FLOAT(jit::avx512f); INTRI16_FLOAT(jit::avx512f); -INTRI_GT8LT16_FLOAT(jit::avx512f); -INTRI_GT16_FLOAT(jit::avx512f); +// INTRI_GT8LT16_FLOAT(jit::avx512f); +// INTRI_GT16_FLOAT(jit::avx512f); #endif -// TODO(TJ): eq16 test and complete avx512 #undef INTRI8_FLOAT #undef INTRI16_FLOAT diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 3db9a0b5eb60be081872471d714b5235ce7a981d..7c4178714126cb98ee76fb2f011296ee9813a4f0 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -48,6 +48,43 @@ void RandomVec(const int n, T* a, const T lower = static_cast(-20.f), } } +void vaddbias_ref(const int n, const float a, const float* x, float* y) { + for (int i = 0; i < n; ++i) { + y[i] = x[i] + a; + } +} + +TEST(JitKernel, vaddbias) { + namespace jit = paddle::operators::math::jitkernel; + for (int d : {7, 8, 15, 16, 30, 64, 100, 128, 256}) { + std::vector x(d); + std::vector zref(d), ztgt(d); + RandomVec(d, x.data(), -2.f, 2.f); + const auto& ker = + jit::KernelPool::Instance().template Get>(d); + const float a = 2.f; + const float* x_data = x.data(); + float* ztgt_data = ztgt.data(); + float* zref_data = zref.data(); + auto trefs = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vaddbias_ref(d, a, x_data, zref_data); + } + auto trefe = GetCurrentUS(); + auto ttgts = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ker->Compute(d, a, x_data, ztgt_data); + } + auto ttgte = GetCurrentUS(); + + VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat + << " us, tgt takes: " << (ttgte - ttgts) / repeat; + for (int i = 0; i < d; ++i) { + EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); + } + } +} + void vexp_ref(const int n, const float* x, float* y) { for (int i = 0; i < n; ++i) { y[i] = std::exp(x[i]); @@ -135,7 +172,7 @@ void vsigmoid_better( TEST(JitKernel, vsigmoid) { namespace jit = paddle::operators::math::jitkernel; - for (int d : {7, 8, 15, 16, 30, 128}) { + for (int d : {7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { std::vector x(d); std::vector zref(d), ztgt(d); RandomVec(d, x.data(), -2.f, 2.f);