diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc
index ac368c9d0d027e14635ed83982a716a80458db37..0433cfc23eb11b7589c7041addc1d43aea2686c1 100644
--- a/paddle/fluid/operators/math/jit_code.cc
+++ b/paddle/fluid/operators/math/jit_code.cc
@@ -168,24 +168,26 @@ void ReluJitCode::generate() {
 #define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
 
 #define OFFSET_EXP_ONE 0 * AVX_FLOAT_BLOCK * sizeof(float)
-#define OFFSET_EXP_0P5 1 * AVX_FLOAT_BLOCK * sizeof(float)
-#define OFFSET_EXP_HIG 2 * AVX_FLOAT_BLOCK * sizeof(float)
-#define OFFSET_EXP_LOW 3 * AVX_FLOAT_BLOCK * sizeof(float)
-#define OFFSET_EXP_LOG2EF 4 * AVX_FLOAT_BLOCK * sizeof(float)
-#define OFFSET_EXP_C1 5 * AVX_FLOAT_BLOCK * sizeof(float)
-#define OFFSET_EXP_C2 6 * AVX_FLOAT_BLOCK * sizeof(float)
-#define OFFSET_EXP_P0 7 * AVX_FLOAT_BLOCK * sizeof(float)
-#define OFFSET_EXP_P1 8 * AVX_FLOAT_BLOCK * sizeof(float)
-#define OFFSET_EXP_P2 9 * AVX_FLOAT_BLOCK * sizeof(float)
-#define OFFSET_EXP_P3 10 * AVX_FLOAT_BLOCK * sizeof(float)
-#define OFFSET_EXP_P4 11 * AVX_FLOAT_BLOCK * sizeof(float)
-#define OFFSET_EXP_P5 12 * AVX_FLOAT_BLOCK * sizeof(float)
-#define OFFSET_EXP_MAX_INPUT 13 * AVX_FLOAT_BLOCK * sizeof(float)
-#define OFFSET_SIGMOID_MAX 14 * AVX_FLOAT_BLOCK * sizeof(float)
-#define OFFSET_SIGMOID_MIN 15 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_TWO 1 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_0P5 2 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_HIG 3 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_LOW 4 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_LOG2EF 5 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_C1 6 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_C2 7 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_P0 8 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_P1 9 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_P2 10 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_P3 11 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_P4 12 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_P5 13 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_MAX_INPUT 14 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_SIGMOID_MAX 15 * AVX_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_SIGMOID_MIN 16 * AVX_FLOAT_BLOCK * sizeof(float)
 
 static const float exp_float_consts[] ALIGN32 = {
     REPEAT_8TIMES(1.f),
+    REPEAT_8TIMES(2.f),
     REPEAT_8TIMES(0.5f),
     REPEAT_8TIMES(EXP_HIG),
     REPEAT_8TIMES(EXP_LOW),
@@ -216,6 +218,7 @@ void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
   ymm_t ymm_fy = ymm_t(3);
   ymm_t ymm_mask = ymm_t(4);
   ymm_t ymm_tmp = ymm_t(5);
+  assert(ymm_src.getIdx() != ymm_dst.getIdx());  // TODO(TJ): use enforce
   push(reg_ptr_global);
   mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
   vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
@@ -327,6 +330,40 @@ void VSigmoidJitCode::generate() {
   ret();
 }
 
+bool VTanhJitCode::init(int d) {
+  return MayIUse(avx) && d == 8;  // only 8 yet
+}
+
+void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
+  // y = 2 / (1 + e^(-2x)) - 1
+  // use ymm2, ymm3
+  reg64_t reg_ptr_global = rax;
+  ymm_t ymm_tmp = ymm_t(2);
+  ymm_t ymm_zero = ymm_t(3);
+  push(reg_ptr_global);
+  mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
+  vxorps(ymm_zero, ymm_zero, ymm_zero);
+  vsubps(ymm_tmp, ymm_zero, ymm_tmp);
+  vmulps(ymm_src, ymm_src, ymm_tmp);
+  exp_ymm(ymm_src, ymm_dst);
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
+  vaddps(ymm_dst, ymm_dst, ymm_tmp);
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
+  vdivps(ymm_dst, ymm_tmp, ymm_dst);
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
+  vsubps(ymm_dst, ymm_dst, ymm_tmp);
+  pop(reg_ptr_global);
+}
+
+void VTanhJitCode::generate() {
+  int offset = 0;
+  vmovups(ymm_src, ptr[param1 + offset]);
+  vtanh_ymm(ymm_src, ymm_dst);
+  vmovups(ptr[param2 + offset], ymm_dst);
+  ret();
+}
+
 }  // namespace gen
 }  // namespace jitkernel
 }  // namespace math
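
For reference, the vtanh_ymm sequence above is easiest to audit against its scalar equivalent. The sketch below is illustrative only and not part of the patch; std::exp stands in for what exp_ymm computes per float lane:

    #include <cmath>

    // Scalar mirror of vtanh_ymm, one lane at a time
    // (std::exp stands in for exp_ymm).
    float vtanh_scalar(float src) {
      float tmp = 2.0f;           // vmovaps ymm_tmp <- [OFFSET_EXP_TWO]
      tmp = 0.0f - tmp;           // vxorps + vsubps: tmp = -2
      src = src * tmp;            // vmulps: src = -2x
      float dst = std::exp(src);  // exp_ymm: dst = e^(-2x)
      dst = dst + 1.0f;           // vaddps [OFFSET_EXP_ONE]: 1 + e^(-2x)
      dst = 2.0f / dst;           // vdivps [OFFSET_EXP_TWO]: 2 / (1 + e^(-2x))
      return dst - 1.0f;          // vsubps [OFFSET_EXP_ONE]: tanh(x)
    }
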
diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h
index df9d7fd051cf7341d19a08fecab5e646133a8cb9..685ab8750ed84d610685e9b4490881e3c864c795 100644
--- a/paddle/fluid/operators/math/jit_code.h
+++ b/paddle/fluid/operators/math/jit_code.h
@@ -149,6 +149,26 @@ class VSigmoidJitCode : public VExpJitCode {
   ymm_t ymm_dst = ymm_t(1);
 };
 
+class VTanhJitCode : public VExpJitCode {
+ public:
+  DECLARE_JIT_CODE(VTanhJitCode);
+  explicit VTanhJitCode(int d, size_t code_size = 256 * 1024,
+                        void* code_ptr = nullptr)
+      : VExpJitCode(d, code_size, code_ptr), num_(d) {}
+  static bool init(int d);
+  void generate() override;
+
+  // compute tanh with ymm
+  void vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst);
+
+ private:
+  int num_;
+  reg64_t param1{abi_param1};
+  reg64_t param2{abi_param2};
+  ymm_t ymm_src = ymm_t(0);
+  ymm_t ymm_dst = ymm_t(1);
+};
+
 }  // namespace gen
 }  // namespace jitkernel
 }  // namespace math
diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h
index 205d47be425052870c5df0cfd4549a7fece26543..1d443bdbe2bae4be7919423c6ae29f3af5010557 100644
--- a/paddle/fluid/operators/math/jit_kernel.h
+++ b/paddle/fluid/operators/math/jit_kernel.h
@@ -132,6 +132,7 @@ template <typename T>
 class VTanhKernel : public VActKernel<T> {
  public:
   virtual void ComputeDeprecated(const T *x, T *y) const = 0;
+  void (*Compute)(const T *, T *, int);
 };
 
 template <typename T>
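
The jit_kernel.h hunk above adds Compute as a plain function pointer next to the virtual ComputeDeprecated. A sketch of the resulting call site, following the usage in jit_kernel_test.cc (the namespace alias is an assumption borrowed from that test):

    #include "paddle/fluid/operators/math/jit_kernel.h"

    namespace jit = paddle::operators::math::jitkernel;

    // Compute is a data member holding a function pointer, so the length is
    // passed explicitly; ComputeDeprecated stays as the virtual fallback
    // during the migration.
    void RunTanh(const float* x, float* y, int d) {
      const auto& ker =
          jit::KernelPool::Instance().template Get<jit::VTanhKernel<float>>(d);
      ker->Compute(x, y, d);
    }
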
diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc
index 4e5fd6de637d7eb0a883ffcb58f654a4b7d47d95..f0431be5816bfc156c0b200d1f3461065927b2cc 100644
--- a/paddle/fluid/operators/math/jit_kernel_exp.cc
+++ b/paddle/fluid/operators/math/jit_kernel_exp.cc
@@ -45,6 +45,7 @@ void VExpRefer(const T* x, T* y, int n) {
 
 template <typename T>
 void VSigmoidRefer(const T* x, T* y, int n) {
+  // y = 1 / (1 + e^-x)
   const T min = SIGMOID_THRESHOLD_MIN;
   const T max = SIGMOID_THRESHOLD_MAX;
   for (int i = 0; i < n; ++i) {
@@ -53,6 +54,18 @@
   }
 }
 
+template <typename T>
+void VTanhRefer(const T* x, T* y, int n) {
+  // y = 2 * sigmoid(2x) - 1
+  for (int i = 0; i < n; ++i) {
+    y[i] = static_cast<T>(2) * x[i];
+  }
+  VSigmoidRefer(y, y, n);
+  for (int i = 0; i < n; ++i) {
+    y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
+  }
+}
+
 #ifdef PADDLE_WITH_MKLML
 template <typename T>
 void VExpMKL(const T* x, T* y, int n);
@@ -80,6 +93,17 @@ void VSigmoidMKL(const T* x, T* y, int n) {
     y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
   }
 }
+
+template <typename T>
+void VTanhMKL(const T* x, T* y, int n) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = static_cast<T>(2) * x[i];
+  }
+  VSigmoidMKL(y, y, n);
+  for (int i = 0; i < n; ++i) {
+    y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
+  }
+}
 #endif
 
 /* VExp JitKernel */
@@ -189,8 +213,63 @@ bool VSigmoidKernelImpl<double>::useMKL(int d) {
 }
 #endif
 
+/* VTanh JitKernel */
+template <typename T>
+class VTanhKernelImpl : public VTanhKernel<T> {
+ public:
+  JITKERNEL_DECLARE_STATIC_FUNC;
+  explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
+    this->num_ = d;  // TODO(TJ): remove me when ComputeDeprecated done
+#ifdef PADDLE_WITH_XBYAK
+    if (useJIT(d)) {
+      size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;  // should change
+      jitcode_.reset(new gen::VTanhJitCode(d, sz > 4096 ? sz : 4096));
+      this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
+      return;
+    }
+#endif
+
+#ifdef PADDLE_WITH_MKLML
+    // prefer MKL when available; it is strictly better than refer here
+    if (useMKL(d)) {
+      this->Compute = VTanhMKL<T>;
+      return;
+    }
+#endif
+    this->Compute = VTanhRefer<T>;
+  }
+  void ComputeDeprecated(const T* x, T* y) const override {
+    VTanhRefer(x, y, this->num_);
+  }
+#ifdef PADDLE_WITH_XBYAK
+
+ private:
+  std::unique_ptr<gen::VTanhJitCode> jitcode_{nullptr};
+#endif
+};
+
+#ifdef PADDLE_WITH_XBYAK
+template <>
+bool VTanhKernelImpl<float>::useJIT(int d) {
+  return gen::VTanhJitCode::init(d);
+}
+#endif
+
+#ifdef PADDLE_WITH_MKLML
+template <>
+bool VTanhKernelImpl<float>::useMKL(int d) {
+  return d > 512;
+}
+
+template <>
+bool VTanhKernelImpl<double>::useMKL(int d) {
+  return true;
+}
+#endif
+
 REGISTER_JITKERNEL(vexp, VExpKernel);
 REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel);
+REGISTER_JITKERNEL(vtanh, VTanhKernel);
 
 namespace detail {
 
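
The constructor above encodes a fixed priority: JIT code when the width is supported, MKL when profitable, refer otherwise. A self-contained sketch of that dispatch pattern (all names and bodies here are hypothetical stand-ins, with std::tanh in place of the real paths):

    #include <cmath>

    using TanhFn = void (*)(const float*, float*, int);

    // Hypothetical stand-ins for the three real implementations.
    static void TanhRefer(const float* x, float* y, int n) {
      for (int i = 0; i < n; ++i) y[i] = std::tanh(x[i]);
    }
    static void TanhMKL(const float* x, float* y, int n) { TanhRefer(x, y, n); }
    static void TanhJit(const float* x, float* y, int n) { TanhRefer(x, y, n); }

    // Mirrors the constructor's priority order: JIT, then MKL, then refer.
    static TanhFn PickTanh(int d, bool jit_ok, bool mkl_ok) {
      if (jit_ok && d == 8) return TanhJit;   // cf. gen::VTanhJitCode::init(d)
      if (mkl_ok && d > 512) return TanhMKL;  // cf. useMKL(d) for float
      return TanhRefer;
    }

Storing the chosen implementation as a plain function pointer keeps the hot path to a single indirect call with no per-call branching.
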
@@ -337,156 +416,6 @@ __m256 ExpAVX2(__m256 x) {
 #endif
 
 }  // namespace detail
-
-#define INTRI_SIGMOID(tmp, min, max, expisa)      \
-  tmp = _mm256_max_ps(tmp, min);                  \
-  tmp = _mm256_min_ps(tmp, max);                  \
-  tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \
-  tmp = expisa(tmp);                              \
-  tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
-  tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)
-#undef INTRI_VSIGMOID
-
-/* VTanh JitKernel */
-template <typename T, platform::jit::cpu_isa_t isa, jit_block>
-class VTanhKernelImpl : public VTanhKernel<T> {
- public:
-  explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
-    this->num_ = d;
-    vscal_ = KernelPool::Instance().template Get<VScalKernel<T>>(d);
-    vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<T>>(d);
-    vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
-  }
-  void ComputeDeprecated(const T* x, T* y) const override {
-    const T a = static_cast<T>(2), b = static_cast<T>(-1);
-    vscal_->Compute(&a, x, y, this->num_);
-    vsigmoid_->ComputeDeprecated(y, y);
-    vscal_->Compute(&a, y, y, this->num_);
-    vaddbias_->Compute(&b, y, y, this->num_);
-  }
-
- private:
-  std::shared_ptr<const VScalKernel<T>> vscal_;
-  std::shared_ptr<const VSigmoidKernel<T>> vsigmoid_;
-  std::shared_ptr<const VAddBiasKernel<T>> vaddbias_;
-};
-
-#define INTRI_VTANH(tmp, expisa)                           \
-  tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp);         \
-  tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \
-  tmp = expisa(tmp);                                       \
-  tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);          \
-  tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp);          \
-  tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))
-
-#define INTRI8_FLOAT(isa, expisa)                                            \
-  template <>                                                                \
-  void VTanhKernelImpl<float, isa, kEQ8>::ComputeDeprecated(const float* x,  \
-                                                            float* y) const {\
-    __m256 tmp = _mm256_loadu_ps(x);                                         \
-    INTRI_VTANH(tmp, expisa);                                                \
-    _mm256_storeu_ps(y, tmp);                                                \
-  }
-
-#define INTRI16_FLOAT(isa, expisa)                                           \
-  template <>                                                                \
-  void VTanhKernelImpl<float, isa, kEQ16>::ComputeDeprecated(const float* x, \
-                                                             float* y) const {\
-    __m256 tmp0 = _mm256_loadu_ps(x);                                        \
-    __m256 tmp1 = _mm256_loadu_ps(x + 8);                                    \
-    INTRI_VTANH(tmp0, expisa);                                               \
-    INTRI_VTANH(tmp1, expisa);                                               \
-    _mm256_storeu_ps(y, tmp0);                                               \
-    _mm256_storeu_ps(y + 8, tmp1);                                           \
-  }
-
-#define INTRI_GT8LT16_FLOAT(isa, expisa)                                     \
-  template <>                                                                \
-  VTanhKernelImpl<float, isa, kGT8LT16>::VTanhKernelImpl(int d)              \
-      : VTanhKernel<float>() {                                               \
-    this->num_ = d;                                                          \
-    this->end_ = AVX_FLOAT_BLOCK;                                            \
-    this->rest_ = d - this->end_;                                            \
-    vscal_ =                                                                 \
-        KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_);\
-    vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>(  \
-        this->rest_);                                                        \
-    vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>(  \
-        this->rest_);                                                        \
-  }                                                                          \
-  template <>                                                                \
-  void VTanhKernelImpl<float, isa, kGT8LT16>::ComputeDeprecated(             \
-      const float* x, float* y) const {                                      \
-    __m256 tmp = _mm256_loadu_ps(x);                                         \
-    INTRI_VTANH(tmp, expisa);                                                \
-    _mm256_storeu_ps(y, tmp);                                                \
-    x += AVX_FLOAT_BLOCK;                                                    \
-    y += AVX_FLOAT_BLOCK;                                                    \
-    const float a = 2.f, b = -1.f;                                           \
-    vscal_->Compute(&a, x, y, this->num_);                                   \
-    vsigmoid_->ComputeDeprecated(y, y);                                      \
-    vscal_->Compute(&a, y, y, this->num_);                                   \
-    vaddbias_->Compute(&b, y, y, this->num_);                                \
-  }
-
-#define INTRI_GT16_FLOAT(isa, expisa)                                        \
-  template <>                                                                \
-  VTanhKernelImpl<float, isa, kGT16>::VTanhKernelImpl(int d)                 \
-      : VTanhKernel<float>() {                                               \
-    this->num_ = d;                                                          \
-    this->rest_ = d % AVX_FLOAT_BLOCK;                                       \
-    this->end_ = d - this->rest_;                                            \
-    vscal_ =                                                                 \
-        KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_);\
-    vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>(  \
-        this->rest_);                                                        \
-    vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>(  \
-        this->rest_);                                                        \
-  }                                                                          \
-  template <>                                                                \
-  void VTanhKernelImpl<float, isa, kGT16>::ComputeDeprecated(const float* x, \
-                                                             float* y) const {\
-    for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) {                  \
-      __m256 tmp = _mm256_loadu_ps(x + i);                                   \
-      INTRI_VTANH(tmp, expisa);                                              \
-      _mm256_storeu_ps(y + i, tmp);                                          \
-    }                                                                        \
-    x += this->end_;                                                         \
-    y += this->end_;                                                         \
-    const float a = 2.f, b = -1.f;                                           \
-    vscal_->Compute(&a, x, y, this->num_);                                   \
-    vsigmoid_->ComputeDeprecated(y, y);                                      \
-    vscal_->Compute(&a, y, y, this->num_);                                   \
-    vaddbias_->Compute(&b, y, y, this->num_);                                \
-  }
-
-#ifdef __AVX__
-INTRI8_FLOAT(jit::avx, detail::ExpAVX);
-INTRI16_FLOAT(jit::avx, detail::ExpAVX);
-INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
-INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
-#endif
-#ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
-INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
-// maybe use avx at gt8lt16 and gt16
-#endif
-#ifdef __AVX512F__
-INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
-INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
-// maybe use avx at gt8lt16 and gt16
-#endif
-
-#undef INTRI8_FLOAT
-#undef INTRI16_FLOAT
-#undef INTRI_GT8LT16_FLOAT
-#undef INTRI_GT16_FLOAT
-#undef INTRI_VTANH
-
-REGISTER_JITKERNEL_DEPRECATED(vtanh, VTanhKernel);
-
-#undef JITKERNEL_NEW_ACT_IMPL
-
 }  // namespace jitkernel
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc
index 29c4dcc357ab13550a6cc76b2b8c7663ced18771..2f9dbc585efb47664803f2da30688bd8aa68300a 100644
--- a/paddle/fluid/operators/math/jit_kernel_test.cc
+++ b/paddle/fluid/operators/math/jit_kernel_test.cc
@@ -322,7 +322,7 @@ TEST(JitKernel, vtanh) {
   auto trefe = GetCurrentUS();
 
   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->ComputeDeprecated(x_data, ztgt_data);
+    ker->Compute(x_data, ztgt_data, d);
  }
   auto ttgte = GetCurrentUS();
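
The test now times the function-pointer path. The pattern can be reproduced standalone; in this sketch GetCurrentUS is re-implemented with std::chrono (the real helper lives in the Paddle test), and compute may be the JIT, MKL, or refer path:

    #include <chrono>
    #include <vector>

    // std::chrono stand-in for the test helper GetCurrentUS().
    static double GetCurrentUS() {
      auto d = std::chrono::steady_clock::now().time_since_epoch();
      return std::chrono::duration<double, std::micro>(d).count();
    }

    // Times compute(x, y, n) the way the updated test does and returns the
    // average microseconds per call.
    template <typename F>
    double BenchTanh(F compute, int d, int repeat = 2000) {
      std::vector<float> x(d, 0.5f), y(d);
      auto ts = GetCurrentUS();
      for (int i = 0; i < repeat; ++i) {
        compute(x.data(), y.data(), d);  // new signature carries the length
      }
      auto te = GetCurrentUS();
      return (te - ts) / repeat;
    }
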