diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index e46f60f764ab9f1c292db339a5b38b976de5a11a..dd79949eca70edfc68fba52cc838b71c912a70ed 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -151,6 +151,132 @@ void ReluJitCode::generate() { } ret(); } + +bool VExpJitCode::init(int d) { + return MayIUse(avx) && d == 8; // only 8 yet +} + +#define ALIGN32 __attribute__((aligned(32))) +#define EXP_HIG 88.3762626647949f +#define EXP_LOW -88.3762626647949f +#define CEPHES_LOG2EF 1.44269504088896341 +#define CEPHES_EXP_C1 0.693359375 +#define CEPHES_EXP_C2 -2.12194440e-4 +#define CEPHES_EXP_P0 1.9875691500E-4 +#define CEPHES_EXP_P1 1.3981999507E-3 +#define CEPHES_EXP_P2 8.3334519073E-3 +#define CEPHES_EXP_P3 4.1665795894E-2 +#define CEPHES_EXP_P4 1.6666665459E-1 +#define CEPHES_EXP_P5 5.0000001201E-1 + +#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val + +#define OFFSET_EXP_0P5 1 * AVX_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_HIG 2 * AVX_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_LOW 3 * AVX_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_LOG2EF 4 * AVX_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_C1 5 * AVX_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_C2 6 * AVX_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P0 7 * AVX_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P1 8 * AVX_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P2 9 * AVX_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P3 10 * AVX_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P4 11 * AVX_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P5 12 * AVX_FLOAT_BLOCK * sizeof(float) + +static const float exp_float_consts[] ALIGN32 = { + REPEAT_8TIMES(1.f), REPEAT_8TIMES(0.5f), + REPEAT_8TIMES(EXP_HIG), REPEAT_8TIMES(EXP_LOW), + REPEAT_8TIMES(CEPHES_LOG2EF), REPEAT_8TIMES(CEPHES_EXP_C1), + REPEAT_8TIMES(CEPHES_EXP_C2), REPEAT_8TIMES(CEPHES_EXP_P0), + REPEAT_8TIMES(CEPHES_EXP_P1), REPEAT_8TIMES(CEPHES_EXP_P2), + REPEAT_8TIMES(CEPHES_EXP_P3), REPEAT_8TIMES(CEPHES_EXP_P4), + REPEAT_8TIMES(CEPHES_EXP_P5)}; + +static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)}; +static int g_tmp_mem[16] ALIGN32 = {0}; + +void VExpJitCode::generate() { + preCode(); + // push some? + // in: ymm0, out: ymm1 + // use ymm 0~5 (and ymm 14~15 if avx only) + int offset = 0; + vmovups(ymm_src, ptr[param1 + offset]); + mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]); + vminps(ymm_src, ymm_src, ymm_tmp); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]); + vmaxps(ymm_src, ymm_src, ymm_tmp); + // express exp(x) as exp(g + n*log(2)) + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]); + vmulps(ymm_fx, ymm_src, ymm_tmp); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]); + vaddps(ymm_fx, ymm_fx, ymm_tmp); + vroundps(ymm_fy, ymm_fx, 0x01); + // if greater, substract 1 + vcmpgtps(ymm_mask, ymm_fy, ymm_fx); + vmovaps(ymm_tmp, ptr[reg_ptr_global]); + vandps(ymm_mask, ymm_mask, ymm_tmp); + vsubps(ymm_fx, ymm_fy, ymm_mask); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]); + vmulps(ymm_fy, ymm_fx, ymm_tmp); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]); + vmulps(ymm_z, ymm_fx, ymm_tmp); // ymm_z use same with mask + vsubps(ymm_src, ymm_src, ymm_fy); + vsubps(ymm_src, ymm_src, ymm_z); + vmulps(ymm_z, ymm_src, ymm_src); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]); + vmulps(ymm_dst, ymm_src, ymm_tmp); + for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5; + i += (AVX_FLOAT_BLOCK * sizeof(float))) { + vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4 + vaddps(ymm_dst, ymm_dst, ymm_tmp); + vmulps(ymm_dst, ymm_dst, ymm_src); + } + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]); + vaddps(ymm_dst, ymm_dst, ymm_tmp); + vmulps(ymm_dst, ymm_dst, ymm_z); + vaddps(ymm_dst, ymm_dst, ymm_src); + vmovaps(ymm_tmp, ptr[reg_ptr_global]); + vaddps(ymm_dst, ymm_dst, ymm_tmp); + + // build 2^n + ymm_t ymm_int = ymm_fx; + vcvttps2dq(ymm_int, ymm_fx); + mov(reg_ptr_global, reinterpret_cast(exp_int_0x7f)); + vmovdqa(ymm_tmp, ptr[reg_ptr_global]); + if (MayIUse(avx2)) { + vpaddd(ymm_int, ymm_int, ymm_tmp); + vpslld(ymm_int, ymm_int, 23); + } else if (MayIUse(avx)) { + // use ymm_int, ymm_tmp and reg_ptr_global + xmm_t xtmp1 = xmm_t(ymm_int); // or magic number should equal the ymm_int + xmm_t xtmp2 = xmm_t(ymm_tmp); // or magic number should equal the ymm_tmp + mov(reg_ptr_global, reinterpret_cast(g_tmp_mem)); + vmovdqa(ptr[reg_ptr_global], ymm_int); + vmovdqa(ptr[reg_ptr_global + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp); + vpaddd(xtmp1, xtmp1, xtmp2); + vpslld(xtmp1, xtmp1, 23); + vmovdqa(ptr[reg_ptr_global], xtmp1); + // next 128bits + vmovdqa(xtmp1, ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)]); + vmovdqa(xtmp2, + ptr[reg_ptr_global + + (AVX_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]); + vpaddd(xtmp1, xtmp1, xtmp2); + vpslld(xtmp1, xtmp1, 23); + vmovdqa(ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)], xtmp1); + // load out + vmovdqa(ymm_int, ptr[reg_ptr_global]); + } + vmulps(ymm_dst, ymm_dst, ymm_int); + vmovups(ptr[param2 + offset], ymm_dst); + + // ret(); + postCode(); +} + } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 3c242870a24c5bb29d34d4b99406c5df8cec6763..984bd15a22a20cdb34207803dd55a0e2cf26c928 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -108,6 +108,30 @@ class ReluJitCode : public JitCode { ymm_t ymm_dst = ymm_t(1); }; +class VExpJitCode : public JitCode { + public: + DECLARE_JIT_CODE(VExpJitCode); + explicit VExpJitCode(int d, size_t code_size = 256 * 1024, + void* code_ptr = nullptr) + : JitCode(code_size, code_ptr), num_(d) {} + static bool init(int d); + void generate() override; + + private: + int num_; + reg64_t param1{abi_param1}; + reg64_t param2{abi_param2}; + + reg64_t reg_ptr_global = rax; + ymm_t ymm_src = ymm_t(0); + ymm_t ymm_dst = ymm_t(1); + ymm_t ymm_fx = ymm_t(2); + ymm_t ymm_fy = ymm_t(3); + ymm_t ymm_mask = ymm_t(4); + ymm_t ymm_z = ymm_t(4); + ymm_t ymm_tmp = ymm_t(5); +}; + } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index cd3a45e66773c89e45e80ab77ebd925abd6cbe53..a68d9c5d2ebdb7001da5a6060538bc3e12d18d0c 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -117,6 +117,7 @@ template class VExpKernel : public VActKernel { public: virtual void ComputeDeprecated(const T *x, T *y) const = 0; + void (*Compute)(const T *, T *, int); }; template diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index cf46a210afbd4903dc3841f27765c390f721c763..d96d5f15ea7cd984d9f84f0943094b6c9abfc045 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -25,10 +25,6 @@ limitations under the License. */ #include "paddle/fluid/platform/dynload/mklml.h" #endif -#ifdef __AVX__ -#include -#endif - namespace paddle { namespace operators { namespace math { @@ -128,18 +124,11 @@ void VScalMKL(const double* a, const double* x, double* y, int n) { #endif -#define DECLARE_STATIC_FUNC \ - static inline std::string name(int d) { \ - PADDLE_THROW("DType should be either float or double"); \ - } \ - static inline bool useJIT(int d) { return false; } \ - static inline bool useMKL(int d) { return false; } - /* VMUL JitKernel */ template class VMulKernelImpl : public VMulKernel { public: - DECLARE_STATIC_FUNC; + JITKERNEL_DECLARE_STATIC_FUNC; explicit VMulKernelImpl(int d) : VMulKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { @@ -191,7 +180,7 @@ bool VMulKernelImpl::useMKL(int d) { template class VAddKernelImpl : public VAddKernel { public: - DECLARE_STATIC_FUNC; + JITKERNEL_DECLARE_STATIC_FUNC; explicit VAddKernelImpl(int d) : VAddKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { @@ -241,7 +230,7 @@ bool VAddKernelImpl::useMKL(int d) { template class VAddReluKernelImpl : public VAddReluKernel { public: - DECLARE_STATIC_FUNC; + JITKERNEL_DECLARE_STATIC_FUNC; explicit VAddReluKernelImpl(int d) : VAddReluKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { @@ -273,7 +262,7 @@ bool VAddReluKernelImpl::useJIT(int d) { template class VScalKernelImpl : public VScalKernel { public: - DECLARE_STATIC_FUNC; + JITKERNEL_DECLARE_STATIC_FUNC; explicit VScalKernelImpl(int d) : VScalKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { @@ -322,7 +311,7 @@ bool VScalKernelImpl::useMKL(int d) { template class VAddBiasKernelImpl : public VAddBiasKernel { public: - DECLARE_STATIC_FUNC; + JITKERNEL_DECLARE_STATIC_FUNC; explicit VAddBiasKernelImpl(int d) : VAddBiasKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { @@ -355,14 +344,14 @@ bool VAddBiasKernelImpl::useJIT(int d) { template class VReluKernelImpl : public VReluKernel { public: - DECLARE_STATIC_FUNC; + JITKERNEL_DECLARE_STATIC_FUNC; explicit VReluKernelImpl(int d) : VReluKernel() { this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { - size_t sz = 96 /*init*/ + - d / AVX_FLOAT_BLOCK * 4 /* instructions*/ * - 8 /*everage byte for each instruction*/; + size_t sz = 96 /* init size */ + + d / AVX_FLOAT_BLOCK * 4 /* instructions */ * + 8 /* average bytes for each instruction */; jitcode_.reset(new gen::ReluJitCode(d, sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); return; @@ -388,8 +377,6 @@ bool VReluKernelImpl::useJIT(int d) { } #endif -#undef DECLARE_STATIC_FUNC - REGISTER_JITKERNEL(vmul, VMulKernel); REGISTER_JITKERNEL(vadd, VAddKernel); REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index 2ac9e1092362f60ea3d89da0c971a365b45f39ea..eae9648bdcd9e6cada50fd90e8ab358af36bfdbd 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -16,6 +16,11 @@ limitations under the License. */ #include // for exp #include #include "paddle/fluid/operators/math/jit_kernel_macro.h" + +#ifdef PADDLE_WITH_XBYAK +#include "paddle/fluid/operators/math/jit_code.h" +#endif + #ifdef PADDLE_WITH_MKLML #include "paddle/fluid/platform/dynload/mklml.h" #endif @@ -30,41 +35,84 @@ namespace math { namespace jitkernel { namespace jit = platform::jit; +// TODO(TJ): move refer codes to one file +template +void VExpRefer(const T* x, T* y, int n) { + for (int i = 0; i < n; ++i) { + y[i] = std::exp(x[i]); + } +} + +#ifdef PADDLE_WITH_MKLML +template +void VExpMKL(const T* x, T* y, int n); + +template <> +void VExpMKL(const float* x, float* y, int n) { + platform::dynload::vsExp(n, x, y); +} + +template <> +void VExpMKL(const double* x, double* y, int n) { + platform::dynload::vdExp(n, x, y); +} +#endif + /* VExp JitKernel */ -template +template class VExpKernelImpl : public VExpKernel { public: - explicit VExpKernelImpl(int d) : VExpKernel() { this->num_ = d; } - void ComputeDeprecated(const T* x, T* y) const override { - for (int i = 0; i < this->num_; ++i) { - y[i] = std::exp(x[i]); + JITKERNEL_DECLARE_STATIC_FUNC; + explicit VExpKernelImpl(int d) : VExpKernel() { + this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change + jitcode_.reset(new gen::VExpJitCode(d, sz > 4096 ? sz : 4096)); + this->Compute = jitcode_->getCode(); + return; } +#endif +#ifdef PADDLE_WITH_MKLML + if (useMKL(d)) { + this->Compute = VExpMKL; + return; + } +#endif + this->Compute = VExpRefer; } + void ComputeDeprecated(const T* x, T* y) const override { + VExpRefer(x, y, this->num_); + } +#ifdef PADDLE_WITH_XBYAK + + private: + std::unique_ptr jitcode_{nullptr}; +#endif }; +#ifdef PADDLE_WITH_XBYAK +template <> +bool VExpKernelImpl::useJIT(int d) { + return gen::VExpJitCode::init(d); +} +#endif + #ifdef PADDLE_WITH_MKLML -#define MKL_FLOAT(isa, block) \ - template <> \ - void VExpKernelImpl::ComputeDeprecated(const float* x, \ - float* y) const { \ - platform::dynload::vsExp(this->num_, x, y); \ - } +template <> +bool VExpKernelImpl::useMKL(int d) { + return d > 512; +} -#define MKL_DOUBLE(isa, block) \ - template <> \ - void VExpKernelImpl::ComputeDeprecated( \ - const double* x, double* y) const { \ - platform::dynload::vdExp(this->num_, x, y); \ - } -FOR_EACH_ISA(MKL_FLOAT, kLT8); -FOR_EACH_ISA(MKL_FLOAT, kGT8LT16); -FOR_EACH_ISA(MKL_FLOAT, kGT16); -FOR_EACH_ISA_BLOCK(MKL_DOUBLE); +template <> +bool VExpKernelImpl::useMKL(int d) { + return true; +} #endif -namespace detail { +REGISTER_JITKERNEL(vexp, VExpKernel); -#ifdef __AVX__ +namespace detail { #define ALIGN32 __attribute__((aligned(32))) @@ -195,7 +243,6 @@ __m256 ExpAVX(__m256 x) { y = _mm256_mul_ps(y, pow2n); return y; } -#endif #ifdef __AVX2__ __m256 ExpAVX2(__m256 x) { @@ -211,47 +258,6 @@ __m256 ExpAVX2(__m256 x) { } // namespace detail -#define INTRI8_FLOAT(isa, expisa) \ - template <> \ - void VExpKernelImpl::ComputeDeprecated(const float* x, \ - float* y) const { \ - __m256 tmp = _mm256_loadu_ps(x); \ - _mm256_storeu_ps(y, expisa(tmp)); \ - } - -#define INTRI16_FLOAT(isa, expisa) \ - template <> \ - void VExpKernelImpl::ComputeDeprecated(const float* x, \ - float* y) const { \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - tmp0 = expisa(tmp0); \ - tmp1 = expisa(tmp1); \ - _mm256_storeu_ps(y, tmp0); \ - _mm256_storeu_ps(y + 8, tmp1); \ - } - -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx, detail::ExpAVX); -INTRI16_FLOAT(jit::avx, detail::ExpAVX); -#endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2, detail::ExpAVX2); -INTRI16_FLOAT(jit::avx2, detail::ExpAVX2); -#endif -#ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2); -INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2); -#endif -// TODO(TJ): eq16 test and complete avx512 - -#undef INTRI8_FLOAT -#undef INTRI16_FLOAT -#undef MKL_FLOAT -#undef MKL_DOUBLE - -REGISTER_JITKERNEL_DEPRECATED(vexp, VExpKernel); - /* VSigmoid JitKernel */ template class VSigmoidKernelImpl : public VSigmoidKernel { diff --git a/paddle/fluid/operators/math/jit_kernel_macro.h b/paddle/fluid/operators/math/jit_kernel_macro.h index a8169ea48ae3eee5a8cba291be4496c4c6074221..e8bbc0cae57159a4369bad2c2798e9f67de46ffd 100644 --- a/paddle/fluid/operators/math/jit_kernel_macro.h +++ b/paddle/fluid/operators/math/jit_kernel_macro.h @@ -15,12 +15,20 @@ limitations under the License. */ #pragma once #include #include "paddle/fluid/platform/cpu_info.h" +#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace operators { namespace math { namespace jitkernel { +#define JITKERNEL_DECLARE_STATIC_FUNC \ + static inline std::string name(int d) { \ + PADDLE_THROW("DType should be either float or double"); \ + } \ + static inline bool useJIT(int d) { return false; } \ + static inline bool useMKL(int d) { return false; } + #define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \ template <> \ std::string ker_class##Impl::name(int d) { \ diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 5e1f91ffae03796be2817d0461900c2512938c77..db8e7b74c072f210572a092de08bc0f59ddbc596 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -181,7 +181,8 @@ TEST(JitKernel, vexp) { auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->ComputeDeprecated(x_data, ztgt_data); + // ker->ComputeDeprecated(x_data, ztgt_data); + ker->Compute(x_data, ztgt_data, d); } auto ttgte = GetCurrentUS();