diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index 0d94a639b4a344e2c2ab39cd4485818619fe7618..ac368c9d0d027e14635ed83982a716a80458db37 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -152,10 +152,6 @@ void ReluJitCode::generate() { ret(); } -bool VExpJitCode::init(int d) { - return MayIUse(avx) && d == 8; // only 8 yet -} - #define ALIGN32 __attribute__((aligned(32))) #define EXP_HIG 88.3762626647949f #define EXP_LOW -88.3762626647949f @@ -171,6 +167,7 @@ bool VExpJitCode::init(int d) { #define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val +#define OFFSET_EXP_ONE 0 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_0P5 1 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_HIG 2 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_LOW 3 * AVX_FLOAT_BLOCK * sizeof(float) @@ -183,24 +180,43 @@ bool VExpJitCode::init(int d) { #define OFFSET_EXP_P3 10 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_P4 11 * AVX_FLOAT_BLOCK * sizeof(float) #define OFFSET_EXP_P5 12 * AVX_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_MAX_INPUT 13 * AVX_FLOAT_BLOCK * sizeof(float) +#define OFFSET_SIGMOID_MAX 14 * AVX_FLOAT_BLOCK * sizeof(float) +#define OFFSET_SIGMOID_MIN 15 * AVX_FLOAT_BLOCK * sizeof(float) static const float exp_float_consts[] ALIGN32 = { - REPEAT_8TIMES(1.f), REPEAT_8TIMES(0.5f), - REPEAT_8TIMES(EXP_HIG), REPEAT_8TIMES(EXP_LOW), - REPEAT_8TIMES(CEPHES_LOG2EF), REPEAT_8TIMES(CEPHES_EXP_C1), - REPEAT_8TIMES(CEPHES_EXP_C2), REPEAT_8TIMES(CEPHES_EXP_P0), - REPEAT_8TIMES(CEPHES_EXP_P1), REPEAT_8TIMES(CEPHES_EXP_P2), - REPEAT_8TIMES(CEPHES_EXP_P3), REPEAT_8TIMES(CEPHES_EXP_P4), - REPEAT_8TIMES(CEPHES_EXP_P5)}; + REPEAT_8TIMES(1.f), + REPEAT_8TIMES(0.5f), + REPEAT_8TIMES(EXP_HIG), + REPEAT_8TIMES(EXP_LOW), + REPEAT_8TIMES(CEPHES_LOG2EF), + REPEAT_8TIMES(CEPHES_EXP_C1), + REPEAT_8TIMES(CEPHES_EXP_C2), + REPEAT_8TIMES(CEPHES_EXP_P0), + REPEAT_8TIMES(CEPHES_EXP_P1), + REPEAT_8TIMES(CEPHES_EXP_P2), + REPEAT_8TIMES(CEPHES_EXP_P3), + REPEAT_8TIMES(CEPHES_EXP_P4), + REPEAT_8TIMES(CEPHES_EXP_P5), + REPEAT_8TIMES(EXP_MAX_INPUT), + REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX), + REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)}; static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)}; static int g_tmp_mem[16] ALIGN32 = {0}; -void VExpJitCode::generate() { - // in: ymm0, out: ymm1 - // use ymm 0~5, rax - int offset = 0; - vmovups(ymm_src, ptr[param1 + offset]); +bool VExpJitCode::init(int d) { + return MayIUse(avx) && d == 8; // only 8 yet +} + +void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) { + // use reg rax and ymm 2~5 + reg64_t reg_ptr_global = rax; + ymm_t ymm_fx = ymm_t(2); + ymm_t ymm_fy = ymm_t(3); + ymm_t ymm_mask = ymm_t(4); + ymm_t ymm_tmp = ymm_t(5); + push(reg_ptr_global); mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]); vminps(ymm_src, ymm_src, ymm_tmp); @@ -269,8 +285,45 @@ void VExpJitCode::generate() { vmovdqa(ymm_int, ptr[reg_ptr_tmp]); } vmulps(ymm_dst, ymm_dst, ymm_int); + pop(reg_ptr_global); +} + +void VExpJitCode::generate() { + int offset = 0; + vmovups(ymm_src, ptr[param1 + offset]); + exp_ymm(ymm_src, ymm_dst); vmovups(ptr[param2 + offset], ymm_dst); + ret(); +} + +bool VSigmoidJitCode::init(int d) { + return MayIUse(avx) && d == 8; // only 8 yet +} +void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) { + // use ymm2 + reg64_t reg_ptr_global = rax; + ymm_t ymm_tmp = ymm_t(2); + push(reg_ptr_global); + mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]); + vminps(ymm_src, ymm_src, ymm_tmp); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]); + vmaxps(ymm_src, ymm_src, ymm_tmp); + vxorps(ymm_tmp, ymm_tmp, ymm_tmp); + vsubps(ymm_src, ymm_tmp, ymm_src); + exp_ymm(ymm_src, ymm_dst); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); + vaddps(ymm_dst, ymm_dst, ymm_tmp); + vdivps(ymm_dst, ymm_tmp, ymm_dst); + pop(reg_ptr_global); +} + +void VSigmoidJitCode::generate() { + int offset = 0; + vmovups(ymm_src, ptr[param1 + offset]); + sigmoid_ymm(ymm_src, ymm_dst); + vmovups(ptr[param2 + offset], ymm_dst); ret(); } diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 8296de9b72d725fc6d6021b000f31fa41d09e7b0..df9d7fd051cf7341d19a08fecab5e646133a8cb9 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -117,18 +117,36 @@ class VExpJitCode : public JitCode { static bool init(int d); void generate() override; + protected: + // compute exp with ymm + void exp_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst); + private: int num_; reg64_t param1{abi_param1}; reg64_t param2{abi_param2}; + ymm_t ymm_src = ymm_t(0); + ymm_t ymm_dst = ymm_t(1); +}; - reg64_t reg_ptr_global = rax; +class VSigmoidJitCode : public VExpJitCode { + public: + DECLARE_JIT_CODE(VSigmoidJitCode); + explicit VSigmoidJitCode(int d, size_t code_size = 256 * 1024, + void* code_ptr = nullptr) + : VExpJitCode(d, code_size, code_ptr), num_(d) {} + static bool init(int d); + void generate() override; + + // compute sigmoid with ymm + void sigmoid_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst); + + private: + int num_; + reg64_t param1{abi_param1}; + reg64_t param2{abi_param2}; ymm_t ymm_src = ymm_t(0); ymm_t ymm_dst = ymm_t(1); - ymm_t ymm_fx = ymm_t(2); - ymm_t ymm_fy = ymm_t(3); - ymm_t ymm_mask = ymm_t(4); - ymm_t ymm_tmp = ymm_t(5); }; } // namespace gen diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index a68d9c5d2ebdb7001da5a6060538bc3e12d18d0c..205d47be425052870c5df0cfd4549a7fece26543 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -29,6 +29,7 @@ namespace jitkernel { #define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MAX 13.0 #define EXP_MAX_INPUT 40.0 +// TODO(TJ): change AVX_FLOAT_BLOCK to YMM_FLOAT_BLOCK #define AVX_FLOAT_BLOCK 8 #define AVX2_FLOAT_BLOCK 8 #define AVX512_FLOAT_BLOCK 16 @@ -124,6 +125,7 @@ template class VSigmoidKernel : public VActKernel { public: virtual void ComputeDeprecated(const T *x, T *y) const = 0; + void (*Compute)(const T *, T *, int); }; template diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index eae9648bdcd9e6cada50fd90e8ab358af36bfdbd..4e5fd6de637d7eb0a883ffcb58f654a4b7d47d95 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -43,6 +43,16 @@ void VExpRefer(const T* x, T* y, int n) { } } +template +void VSigmoidRefer(const T* x, T* y, int n) { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int i = 0; i < n; ++i) { + T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); + y[i] = static_cast(1) / (static_cast(1) + std::exp(-tmp)); + } +} + #ifdef PADDLE_WITH_MKLML template void VExpMKL(const T* x, T* y, int n); @@ -56,6 +66,20 @@ template <> void VExpMKL(const double* x, double* y, int n) { platform::dynload::vdExp(n, x, y); } + +template +void VSigmoidMKL(const T* x, T* y, int n) { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int i = 0; i < n; ++i) { + y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); + y[i] = static_cast(0) - y[i]; + } + VExpMKL(y, y, n); + for (int i = 0; i < n; ++i) { + y[i] = static_cast(1) / (static_cast(1) + y[i]); + } +} #endif /* VExp JitKernel */ @@ -108,9 +132,65 @@ template <> bool VExpKernelImpl::useMKL(int d) { return true; } + +#endif + +/* VSigmoid JitKernel */ +template +class VSigmoidKernelImpl : public VSigmoidKernel { + public: + JITKERNEL_DECLARE_STATIC_FUNC; + explicit VSigmoidKernelImpl(int d) : VSigmoidKernel() { + this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change + jitcode_.reset(new gen::VSigmoidJitCode(d, sz > 4096 ? sz : 4096)); + this->Compute = jitcode_->getCode(); + return; + } +#endif + +#ifdef PADDLE_WITH_MKLML + // strictly it's a better impl with MKL, then is refer + if (useMKL(d)) { + this->Compute = VSigmoidMKL; + return; + } +#endif + this->Compute = VSigmoidRefer; + } + void ComputeDeprecated(const T* x, T* y) const override { + VSigmoidRefer(x, y, this->num_); + } +#ifdef PADDLE_WITH_XBYAK + + private: + std::unique_ptr jitcode_{nullptr}; +#endif +}; + +#ifdef PADDLE_WITH_XBYAK +template <> +bool VSigmoidKernelImpl::useJIT(int d) { + return gen::VSigmoidJitCode::init(d); +} +#endif + +#ifdef PADDLE_WITH_MKLML +template <> +bool VSigmoidKernelImpl::useMKL(int d) { + return d > 512; +} + +template <> +bool VSigmoidKernelImpl::useMKL(int d) { + return true; +} #endif REGISTER_JITKERNEL(vexp, VExpKernel); +REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel); namespace detail { @@ -258,31 +338,6 @@ __m256 ExpAVX2(__m256 x) { } // namespace detail -/* VSigmoid JitKernel */ -template -class VSigmoidKernelImpl : public VSigmoidKernel { - public: - explicit VSigmoidKernelImpl(int d) : VSigmoidKernel() { - this->num_ = d; - vexp_ = KernelPool::Instance().template Get>(d); - } - void ComputeDeprecated(const T* x, T* y) const override { - const T min = SIGMOID_THRESHOLD_MIN; - const T max = SIGMOID_THRESHOLD_MAX; - for (int i = 0; i < this->num_; ++i) { - y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); - y[i] = static_cast(0) - y[i]; - } - vexp_->ComputeDeprecated(y, y); - for (int i = 0; i < this->num_; ++i) { - y[i] = static_cast(1) / (static_cast(1) + y[i]); - } - } - - private: - std::shared_ptr> vexp_; -}; - #define INTRI_SIGMOID(tmp, min, max, expisa) \ tmp = _mm256_max_ps(tmp, min); \ tmp = _mm256_min_ps(tmp, max); \ @@ -290,120 +345,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel { tmp = expisa(tmp); \ tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \ tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp) - -#define INTRI8_FLOAT(isa, expisa) \ - template <> \ - void VSigmoidKernelImpl::ComputeDeprecated( \ - const float* x, float* y) const { \ - /* TODO(TJ): try to use static const*/ \ - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ - __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_SIGMOID(tmp, min, max, expisa); \ - _mm256_storeu_ps(y, tmp); \ - } - -#define INTRI16_FLOAT(isa, expisa) \ - template <> \ - void VSigmoidKernelImpl::ComputeDeprecated( \ - const float* x, float* y) const { \ - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - INTRI_SIGMOID(tmp0, min, max, expisa); \ - INTRI_SIGMOID(tmp1, min, max, expisa); \ - _mm256_storeu_ps(y, tmp0); \ - _mm256_storeu_ps(y + 8, tmp1); \ - } - -#define INTRI_GT8LT16_FLOAT(isa, expisa) \ - template <> \ - VSigmoidKernelImpl::VSigmoidKernelImpl(int d) \ - : VSigmoidKernel() { \ - this->num_ = d; \ - this->end_ = AVX_FLOAT_BLOCK; \ - this->rest_ = d - this->end_; \ - vexp_ = \ - KernelPool::Instance().template Get>(this->rest_); \ - } \ - template <> \ - void VSigmoidKernelImpl::ComputeDeprecated( \ - const float* x, float* y) const { \ - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ - __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_SIGMOID(tmp, min, max, expisa); \ - _mm256_storeu_ps(y, tmp); \ - const float min_ = SIGMOID_THRESHOLD_MIN; \ - const float max_ = SIGMOID_THRESHOLD_MAX; \ - for (int i = this->end_; i < this->num_; ++i) { \ - y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \ - y[i] = 0.f - y[i]; \ - } \ - vexp_->ComputeDeprecated(y + this->end_, y + this->end_); \ - for (int i = this->end_; i < this->num_; ++i) { \ - y[i] = 1.f / (1.f + y[i]); \ - } \ - } - -#define INTRI_GT16_FLOAT(isa, expisa) \ - template <> \ - VSigmoidKernelImpl::VSigmoidKernelImpl(int d) \ - : VSigmoidKernel() { \ - this->num_ = d; \ - this->rest_ = d % AVX_FLOAT_BLOCK; \ - this->end_ = d - this->rest_; \ - vexp_ = \ - KernelPool::Instance().template Get>(this->rest_); \ - } \ - template <> \ - void VSigmoidKernelImpl::ComputeDeprecated( \ - const float* x, float* y) const { \ - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ - for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ - __m256 tmp = _mm256_loadu_ps(x + i); \ - INTRI_SIGMOID(tmp, min, max, expisa); \ - _mm256_storeu_ps(y + i, tmp); \ - } \ - const float min_ = SIGMOID_THRESHOLD_MIN; \ - const float max_ = SIGMOID_THRESHOLD_MAX; \ - for (int i = this->end_; i < this->num_; ++i) { \ - y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \ - y[i] = 0.f - y[i]; \ - } \ - vexp_->ComputeDeprecated(y + this->end_, y + this->end_); \ - for (int i = this->end_; i < this->num_; ++i) { \ - y[i] = 1.f / (1.f + y[i]); \ - } \ - } - -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx, detail::ExpAVX); -INTRI16_FLOAT(jit::avx, detail::ExpAVX); -INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX); -INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX); -#endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2, detail::ExpAVX2); -INTRI16_FLOAT(jit::avx2, detail::ExpAVX2); -// maybe use avx at gt8lt16 and gt16 -#endif -#ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2); -INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2); -// maybe use avx2 at gt8lt16 and gt16 -#endif - -#undef INTRI8_FLOAT -#undef INTRI16_FLOAT -#undef INTRI_GT8LT16_FLOAT -#undef INTRI_GT16_FLOAT #undef INTRI_VSIGMOID -REGISTER_JITKERNEL_DEPRECATED(vsigmoid, VSigmoidKernel); - /* VTanh JitKernel */ template class VTanhKernelImpl : public VTanhKernel { diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index db8e7b74c072f210572a092de08bc0f59ddbc596..29c4dcc357ab13550a6cc76b2b8c7663ced18771 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -223,7 +223,7 @@ void vsigmoid_better( y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); y[i] = 0.f - y[i]; } - vexp->ComputeDeprecated(y, y); + vexp->Compute(y, y, n); for (int i = 0; i < n; ++i) { y[i] = 1.f / (1.f + y[i]); } @@ -254,7 +254,7 @@ TEST(JitKernel, vsigmoid) { auto trefe = GetCurrentUS(); auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->ComputeDeprecated(x_data, ztgt_data); + ker->Compute(x_data, ztgt_data, d); } auto ttgte = GetCurrentUS(); @@ -288,7 +288,7 @@ void vtanh_better( const int n, const float* x, float* y) { const float a = 2.f, b = -1.f; vscal->Compute(&a, x, y, n); - vsigmoid->ComputeDeprecated(y, y); + vsigmoid->Compute(y, y, n); vscal->Compute(&a, y, y, n); vaddbias->Compute(&b, y, y, n); }