diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h
index 32944ae82c80f0628c63fc1f9ac32cd9dbb59208..eaf5fd0a87ff13c5a79b49ab030cad8ddd7e7263 100644
--- a/paddle/fluid/operators/math/jit_kernel.h
+++ b/paddle/fluid/operators/math/jit_kernel.h
@@ -28,13 +28,11 @@ namespace jitkernel {
 
 #define SIGMOID_THRESHOLD_MIN -40.0
 #define SIGMOID_THRESHOLD_MAX 13.0
+#define EXP_MAX_INPUT 40.0
 #define AVX_FLOAT_BLOCK 8
-#define AVX_DOUBLE_BLOCK 4
 #define AVX2_FLOAT_BLOCK 8
-#define AVX2_DOUBLE_BLOCK 4
 #define AVX512_FLOAT_BLOCK 16
-#define AVX512_DOUBLE_BLOCK 8
 
 typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;
 
 class Kernel {
diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc
index 0717c2aeebfc113439368c59351b07c8311587cd..da0a71be28eb1ee72e42e94ce914bb89aa13c951 100644
--- a/paddle/fluid/operators/math/jit_kernel_exp.cc
+++ b/paddle/fluid/operators/math/jit_kernel_exp.cc
@@ -235,6 +235,7 @@ INTRI16_FLOAT(jit::avx512f);
 #undef INTRI16_FLOAT
 #undef INTRI_GT8LT16_FLOAT
 #undef INTRI_GT16_FLOAT
+#undef INTRI_VSIGMOID
 
 #define JITKERNEL_NEW_ACT_IMPL(ker, dtype, isa, k)        \
   p = std::dynamic_pointer_cast<ker<dtype>>(              \
       std::make_shared<ker##Impl<dtype, isa, k>>(d))
@@ -243,6 +244,118 @@ INTRI16_FLOAT(jit::avx512f);
 REGISTER_JITKERNEL_ARGS(vsigmoid, VSigmoidKernel, JITKERNEL_DECLARE,
                         JITKERNEL_KEY, JITKERNEL_NEW_ACT_IMPL);
 
+/* VTanh JitKernel */
+template <typename T, jit::cpu_isa_t isa, jit_block>
+class VTanhKernelImpl : public VTanhKernel<T> {
+ public:
+  explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
+    vscal_ = KernelPool::Instance().template Get<VScalKernel<T>>(d);
+    vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<T>>(d);
+    vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
+  }
+  void Compute(const int n, const T* x, T* y) const override {
+    vscal_->Compute(n, static_cast<T>(2), x, y);
+    vsigmoid_->Compute(n, y, y);
+    vscal_->Compute(n, static_cast<T>(2), y);
+    vaddbias_->Compute(n, static_cast<T>(-1), y, y);
+  }
+
+ private:
+  std::shared_ptr<const VScalKernel<T>> vscal_;
+  std::shared_ptr<const VSigmoidKernel<T>> vsigmoid_;
+  std::shared_ptr<const VAddBiasKernel<T>> vaddbias_;
+};
+
+#define INTRI_VTANH(tmp)                                   \
+  tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp);         \
+  tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \
+  tmp = detail::Exp(tmp);                                  \
+  tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);          \
+  tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp);          \
+  tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))
+
+#define INTRI8_FLOAT(isa)                                           \
+  template <>                                                       \
+  void VTanhKernelImpl<float, isa, kEQ8>::Compute(const int n,      \
+                                                  const float* x,   \
+                                                  float* y) const { \
+    __m256 tmp = _mm256_loadu_ps(x);                                \
+    INTRI_VTANH(tmp);                                               \
+    _mm256_storeu_ps(y, tmp);                                       \
+  }
+
+#define INTRI16_FLOAT(isa)                           \
+  template <>                                        \
+  void VTanhKernelImpl<float, isa, kEQ16>::Compute(  \
+      const int n, const float* x, float* y) const { \
+    __m256 tmp0 = _mm256_loadu_ps(x);                \
+    __m256 tmp1 = _mm256_loadu_ps(x + 8);            \
+    INTRI_VTANH(tmp0);                               \
+    INTRI_VTANH(tmp1);                               \
+    _mm256_storeu_ps(y, tmp0);                       \
+    _mm256_storeu_ps(y + 8, tmp1);                   \
+  }
+
+#define INTRI_GT8LT16_FLOAT(isa)                       \
+  template <>                                          \
+  void VTanhKernelImpl<float, isa, kGT8LT16>::Compute( \
+      const int n, const float* x, float* y) const {   \
+    __m256 tmp = _mm256_loadu_ps(x);                   \
+    INTRI_VTANH(tmp);                                  \
+    _mm256_storeu_ps(y, tmp);                          \
+    x += AVX_FLOAT_BLOCK;                              \
+    y += AVX_FLOAT_BLOCK;                              \
+    const int rest = n - AVX_FLOAT_BLOCK;              \
+    vscal_->Compute(rest, 2.f, x, y);                  \
+    vsigmoid_->Compute(rest, y, y);                    \
+    vscal_->Compute(rest, 2.f, y);                     \
+    vaddbias_->Compute(rest, -1.f, y, y);              \
+  }
+
+#define INTRI_GT16_FLOAT(isa)                        \
+  template <>                                        \
+  void VTanhKernelImpl<float, isa, kGT16>::Compute(  \
+      const int n, const float* x, float* y) const { \
+    const int rest = n % AVX_FLOAT_BLOCK;            \
+    const int end = n - rest;                        \
+    for (int i = 0; i < end; i += AVX_FLOAT_BLOCK) { \
+      __m256 tmp = _mm256_loadu_ps(x + i);           \
+      INTRI_VTANH(tmp);                              \
+      _mm256_storeu_ps(y + i, tmp);                  \
+    }                                                \
+    x += end;                                        \
+    y += end;                                        \
+    vscal_->Compute(rest, 2.f, x, y);                \
+    vsigmoid_->Compute(rest, y, y);                  \
+    vscal_->Compute(rest, 2.f, y);                   \
+    vaddbias_->Compute(rest, -1.f, y, y);            \
+  }
+
+#ifdef __AVX__
+INTRI8_FLOAT(jit::avx);
+INTRI16_FLOAT(jit::avx);
+INTRI_GT8LT16_FLOAT(jit::avx);
+INTRI_GT16_FLOAT(jit::avx);
+#endif
+#ifdef __AVX2__
+INTRI8_FLOAT(jit::avx2);
+INTRI16_FLOAT(jit::avx2);
+// maybe use avx at gt8lt16 and gt16
+#endif
+#ifdef __AVX512F__
+INTRI8_FLOAT(jit::avx512f);
+INTRI16_FLOAT(jit::avx512f);
+// maybe use avx at gt8lt16 and gt16
+#endif
+
+#undef INTRI8_FLOAT
+#undef INTRI16_FLOAT
+#undef INTRI_GT8LT16_FLOAT
+#undef INTRI_GT16_FLOAT
+#undef INTRI_VTANH
+
+REGISTER_JITKERNEL_ARGS(vtanh, VTanhKernel, JITKERNEL_DECLARE, JITKERNEL_KEY,
+                        JITKERNEL_NEW_ACT_IMPL);
+
 #undef JITKERNEL_NEW_ACT_IMPL
 
 }  // namespace jitkernel
diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc
index 7c4178714126cb98ee76fb2f011296ee9813a4f0..3aadc6ef44b1ea7459b635a6332b4cfbd991a573 100644
--- a/paddle/fluid/operators/math/jit_kernel_test.cc
+++ b/paddle/fluid/operators/math/jit_kernel_test.cc
@@ -208,6 +208,72 @@ TEST(JitKernel, vsigmoid) {
   }
 }
 
+inline float _tanh(float x) { return 2.f * _sigmoid(2.f * x) - 1.f; }
+
+void vtanh_ref(const int n, const float* x, float* y) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = _tanh(x[i]);
+  }
+}
+
+void vtanh_better(
+    const std::shared_ptr<
+        const paddle::operators::math::jitkernel::VScalKernel<float>>& vscal,
+    const std::shared_ptr<
+        const paddle::operators::math::jitkernel::VSigmoidKernel<float>>&
+        vsigmoid,
+    const std::shared_ptr<
+        const paddle::operators::math::jitkernel::VAddBiasKernel<float>>&
+        vaddbias,
+    const int n, const float* x, float* y) {
+  vscal->Compute(n, 2.f, x, y);
+  vsigmoid->Compute(n, y, y);
+  vscal->Compute(n, 2.f, y);
+  vaddbias->Compute(n, -1.f, y, y);
+}
+
+TEST(JitKernel, vtanh) {
+  namespace jit = paddle::operators::math::jitkernel;
+  for (int d : {7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
+    std::vector<float> x(d);
+    std::vector<float> zref(d), ztgt(d);
+    RandomVec<float>(d, x.data(), -2.f, 2.f);
+    const auto& ker =
+        jit::KernelPool::Instance().template Get<jit::VTanhKernel<float>>(d);
+    const auto& vscal =
+        jit::KernelPool::Instance().template Get<jit::VScalKernel<float>>(d);
+    const auto& vsigmoid =
+        jit::KernelPool::Instance().template Get<jit::VSigmoidKernel<float>>(d);
+    const auto& vaddbias =
+        jit::KernelPool::Instance().template Get<jit::VAddBiasKernel<float>>(d);
+    const float* x_data = x.data();
+    float* ztgt_data = ztgt.data();
+    float* zref_data = zref.data();
+    auto tmkls = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      vtanh_better(vscal, vsigmoid, vaddbias, d, x_data, zref_data);
+    }
+    auto tmkle = GetCurrentUS();
+    auto trefs = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      vtanh_ref(d, x_data, zref_data);
+    }
+    auto trefe = GetCurrentUS();
+    auto ttgts = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      ker->Compute(d, x_data, ztgt_data);
+    }
+    auto ttgte = GetCurrentUS();
+
+    VLOG(3) << "Vec size " << d
+            << ": refer takes: " << (trefe - trefs) / repeat
+            << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat
+            << " us, tgt takes: " << (ttgte - ttgts) / repeat;
+    for (int i = 0; i < d; ++i) {
+      EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
+    }
+  }
+}
+
 void vscal_ref(const int n, const float a, const float* x, float* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = a * x[i];
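
For reference, the generic `VTanhKernelImpl::Compute` path above never calls a tanh primitive: the chain vscal(2) → vsigmoid → vscal(2) → vaddbias(-1) evaluates y = 2·sigmoid(2x) − 1, which equals tanh(x). A minimal scalar sketch of that composition (function names here are illustrative, not part of the patch; the clamp constants are SIGMOID_THRESHOLD_MIN/MAX from jit_kernel.h):

```cpp
#include <algorithm>
#include <cmath>

// Scalar model of the composed VTanh path (illustrative names).
// sigmoid_ref clamps its input the way the VSigmoid kernel does,
// using SIGMOID_THRESHOLD_MIN (-40) and SIGMOID_THRESHOLD_MAX (13).
static float sigmoid_ref(float x) {
  x = std::min(std::max(x, -40.0f), 13.0f);
  return 1.0f / (1.0f + std::exp(-x));
}

// vscal(2, x, y); vsigmoid(y); vscal(2, y); vaddbias(-1, y):
// y = 2 * sigmoid(2 * x) - 1 == tanh(x)
static float tanh_composed(float x) {
  return 2.0f * sigmoid_ref(2.0f * x) - 1.0f;
}
```

This is also exactly what the test's `vtanh_better` helper does with the real vscal/vsigmoid/vaddbias kernels, which is why it can serve as a timing baseline against the fused intrinsic path.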
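The AVX fast path (INTRI_VTANH) fuses the same math into one register-resident sequence: t = min(-2x, EXP_MAX_INPUT), then y = 2 / (1 + exp(t)) − 1. Clamping at EXP_MAX_INPUT (the new 40.0 constant in jit_kernel.h) keeps `exp` from overflowing float for large negative inputs, where the result simply saturates to −1. A per-lane scalar sketch, again with illustrative names:

```cpp
#include <algorithm>
#include <cmath>

// Per-lane equivalent of INTRI_VTANH (illustrative name). t = -2x is
// clamped at EXP_MAX_INPUT so exp(t) cannot overflow; for x large and
// positive, exp(t) -> 0 and the result -> 1; for x large and negative,
// the clamp makes the result saturate to -1.
static float vtanh_lane(float x) {
  float t = std::min(-2.0f * x, 40.0f);  // 40.0f == EXP_MAX_INPUT
  return 2.0f / (1.0f + std::exp(t)) - 1.0f;
}
```

The kGT8LT16 and kGT16 specializations run this per 8-float AVX block and hand any tail of fewer than eight elements back to the composed vscal/vsigmoid/vaddbias chain.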