diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc
index 6b3eecfbd11471b5d95dcb10c91acc536f78cb85..e46f60f764ab9f1c292db339a5b38b976de5a11a 100644
--- a/paddle/fluid/operators/math/jit_code.cc
+++ b/paddle/fluid/operators/math/jit_code.cc
@@ -118,6 +118,39 @@ void VXXJitCode::generate() {
   ret();
 }
 
+bool ReluJitCode::init(int d) { return MayIUse(avx); }
+
+void ReluJitCode::generate() {
+  int offset = 0;
+  vxorps(ymm_zero, ymm_zero, ymm_zero);
+  for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
+    vmovups(ymm_src, ptr[param1 + offset]);
+    vmaxps(ymm_dst, ymm_zero, ymm_src);
+    vmovups(ptr[param2 + offset], ymm_dst);
+    offset += sizeof(float) * AVX_FLOAT_BLOCK;
+  }
+  int rest = num_ % AVX_FLOAT_BLOCK;
+  if (rest >= 4) {
+    vmovups(xmm_src, ptr[param1 + offset]);
+    vmaxps(xmm_dst, xmm_zero, xmm_src);
+    vmovups(ptr[param2 + offset], xmm_dst);
+    offset += sizeof(float) * 4;
+    rest -= 4;
+  }
+  if (rest >= 2) {
+    vmovups(xmm_src, ptr[param1 + offset]);
+    vmaxps(xmm_dst, xmm_zero, xmm_src);
+    vmovq(ptr[param2 + offset], xmm_dst);
+    offset += sizeof(float) * 2;
+    rest -= 2;
+  }
+  if (rest > 0) {
+    vmovups(xmm_src, ptr[param1 + offset]);
+    vmaxps(xmm_dst, xmm_zero, xmm_src);
+    vmovss(ptr[param2 + offset], xmm_dst);
+  }
+  ret();
+}
 }  // namespace gen
 }  // namespace jitkernel
 }  // namespace math
diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h
index aaedb0ae10323eeddfba9512d9e47c7a22320610..3c242870a24c5bb29d34d4b99406c5df8cec6763 100644
--- a/paddle/fluid/operators/math/jit_code.h
+++ b/paddle/fluid/operators/math/jit_code.h
@@ -85,6 +85,29 @@ class VXXJitCode : public JitCode {
   ymm_t ymm_zero = ymm_t(3);
 };
 
+class ReluJitCode : public JitCode {
+ public:
+  DECLARE_JIT_CODE(ReluJitCode);
+  explicit ReluJitCode(int d, size_t code_size = 256 * 1024,
+                       void* code_ptr = nullptr)
+      : JitCode(code_size, code_ptr), num_(d) {}
+  static bool init(int d);
+  void generate() override;
+
+ private:
+  int num_;
+  reg64_t param1{abi_param1};
+  reg64_t param2{abi_param2};
+
+  xmm_t xmm_zero = xmm_t(0);
+  xmm_t xmm_src = xmm_t(1);
+  xmm_t xmm_dst = xmm_t(1);
+
+  ymm_t ymm_zero = ymm_t(0);
+  ymm_t ymm_src = ymm_t(1);
+  ymm_t ymm_dst = ymm_t(1);
+};
+
 }  // namespace gen
 }  // namespace jitkernel
 }  // namespace math
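Illustrative aside (not part of the patch): for readers who do not speak xbyak, the instruction stream that `ReluJitCode::generate()` emits corresponds to the C++ sketch below, assuming `AVX_FLOAT_BLOCK == 8`; the function name is hypothetical. It shows the same tail split as the generated code: full 8-float ymm blocks, then a 4-float xmm block, then a 2-float `vmovq` store, then a final single-float `vmovss` store.

```cpp
#include <immintrin.h>

// Hypothetical C++ equivalent of the code ReluJitCode::generate() emits.
void relu_jit_equivalent(const float* x, float* y, int n) {
  const __m256 zero8 = _mm256_setzero_ps();  // vxorps ymm_zero
  int i = 0;
  for (; i + 8 <= n; i += 8) {  // full 8-float ymm blocks
    _mm256_storeu_ps(y + i, _mm256_max_ps(zero8, _mm256_loadu_ps(x + i)));
  }
  const __m128 zero4 = _mm_setzero_ps();
  int rest = n - i;
  if (rest >= 4) {  // one 4-float xmm block
    _mm_storeu_ps(y + i, _mm_max_ps(zero4, _mm_loadu_ps(x + i)));
    i += 4;
    rest -= 4;
  }
  if (rest >= 2) {  // 2-float tail; like the JIT code, a full xmm is loaded
    __m128 t = _mm_max_ps(zero4, _mm_loadu_ps(x + i));   // over-reads 2 floats
    _mm_storel_pi(reinterpret_cast<__m64*>(y + i), t);   // vmovq: low 64 bits
    i += 2;
    rest -= 2;
  }
  if (rest > 0) {  // final single float (vmovss store)
    _mm_store_ss(y + i, _mm_max_ss(zero4, _mm_load_ss(x + i)));
  }
}
```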
diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h
index e9b259282cd00cc2afc46634423ec09590bf5dd3..cd3a45e66773c89e45e80ab77ebd925abd6cbe53 100644
--- a/paddle/fluid/operators/math/jit_kernel.h
+++ b/paddle/fluid/operators/math/jit_kernel.h
@@ -97,37 +97,38 @@ class VAddBiasKernel : public Kernel {
 template <typename T>
 class VActKernel : public Kernel {
  public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
 };
 
 template <typename T>
 class VReluKernel : public VActKernel<T> {
  public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
+  void (*Compute)(const T *, T *, int);
 };
 
 template <typename T>
 class VIdentityKernel : public VActKernel<T> {
  public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
 };
 
 template <typename T>
 class VExpKernel : public VActKernel<T> {
  public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
 };
 
 template <typename T>
 class VSigmoidKernel : public VActKernel<T> {
  public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
 };
 
 template <typename T>
 class VTanhKernel : public VActKernel<T> {
  public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
 };
 
 template <typename T>
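Illustrative aside (not part of the patch): the upshot of this header change is that `VReluKernel<T>::Compute` is no longer a virtual method but a function pointer taking an explicit length, bound once per kernel instance, while `ComputeDeprecated` keeps the old virtual interface alive during the migration. A hedged sketch of a call site, with `x_data`, `y_data`, `d`, and the function name as placeholders:

```cpp
namespace jit = paddle::operators::math::jitkernel;

// Hypothetical call site showing the before/after calling convention.
void run_relu(const float* x_data, float* y_data, int d) {
  const auto& ker = jit::KernelPool::Instance().Get<jit::VReluKernel<float>>(d);
  // Before this patch: ker->Compute(x_data, y_data);  (length fixed at Get time)
  ker->Compute(x_data, y_data, d);  // new function-pointer path (JIT or VReluRefer)
  // ker->ComputeDeprecated(x_data, y_data);  // migration shim, reference loop
}
```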
diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc
index c4bfbcf925a2bbdc39f8468049c58e126d3eba1b..cf46a210afbd4903dc3841f27765c390f721c763 100644
--- a/paddle/fluid/operators/math/jit_kernel_blas.cc
+++ b/paddle/fluid/operators/math/jit_kernel_blas.cc
@@ -71,6 +71,13 @@ void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
   }
 }
 
+template <typename T>
+void VReluRefer(const T* x, T* y, int n) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = x[i] > 0 ? x[i] : 0;
+  }
+}
+
 #ifdef PADDLE_WITH_MKLML
 template <typename T>
 void VMulMKL(const T* x, const T* y, T* z, int n);
@@ -344,124 +351,60 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
 }
 #endif
 
-#undef DECLARE_STATIC_FUNC
-
-REGISTER_JITKERNEL(vmul, VMulKernel);
-REGISTER_JITKERNEL(vadd, VAddKernel);
-REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
-REGISTER_JITKERNEL(vscal, VScalKernel);
-REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
-
 /* VRelu JitKernel */
-template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+template <typename T>
 class VReluKernelImpl : public VReluKernel<T> {
  public:
-  explicit VReluKernelImpl(int d) : VReluKernel<T>() { this->num_ = d; }
-  void Compute(const T* x, T* y) const override {
-    for (int i = 0; i < this->num_; ++i) {
-      y[i] = x[i] > 0 ? x[i] : 0;
+  DECLARE_STATIC_FUNC;
+  explicit VReluKernelImpl(int d) : VReluKernel<T>() {
+    this->num_ = d;  // TODO(TJ): remove me when ComputeDeprecated done
+#ifdef PADDLE_WITH_XBYAK
+    if (useJIT(d)) {
+      size_t sz = 96 /*init*/ +
+                  d / AVX_FLOAT_BLOCK * 4 /* instructions*/ *
+                      8 /*everage byte for each instruction*/;
+      jitcode_.reset(new gen::ReluJitCode(d, sz > 4096 ? sz : 4096));
+      this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
+      return;
     }
-  }
-};
-
-#define INTRI8_FLOAT(isa) \
-  template <> \
-  void VReluKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
-      const { \
-    __m256 tmp = _mm256_loadu_ps(x); \
-    tmp = _mm256_max_ps(tmp, _mm256_setzero_ps()); \
-    _mm256_storeu_ps(y, tmp); \
-  }
-
-#define INTRI16_FLOAT(isa) \
-  template <> \
-  void VReluKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
-      const { \
-    __m256 zeros = _mm256_setzero_ps(); \
-    __m256 tmp0 = _mm256_loadu_ps(x); \
-    __m256 tmp1 = _mm256_loadu_ps(x + 8); \
-    tmp0 = _mm256_max_ps(tmp0, zeros); \
-    tmp1 = _mm256_max_ps(tmp1, zeros); \
-    _mm256_storeu_ps(y, tmp0); \
-    _mm256_storeu_ps(y + 8, tmp1); \
-  }
+#endif
 
-#define INTRI_GT8LT16_FLOAT(isa) \
-  template <> \
-  VReluKernelImpl<float, isa, kGT8LT16>::VReluKernelImpl(int d) \
-      : VReluKernel<float>() { \
-    this->num_ = d; \
-    this->end_ = AVX_FLOAT_BLOCK; \
-    this->rest_ = d - AVX_FLOAT_BLOCK; \
-  } \
-  template <> \
-  void VReluKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
-                                                      float* y) const { \
-    __m256 zeros = _mm256_setzero_ps(); \
-    __m256 tmp0 = _mm256_loadu_ps(x); \
-    __m256 tmp1 = _mm256_loadu_ps(x + this->rest_); \
-    tmp0 = _mm256_max_ps(tmp0, zeros); \
-    tmp1 = _mm256_max_ps(tmp1, zeros); \
-    _mm256_storeu_ps(y, tmp0); \
-    _mm256_storeu_ps(y + this->rest_, tmp1); \
+    this->Compute = VReluRefer<T>;
   }
-
-#define INTRI_GT16_FLOAT(isa) \
-  template <> \
-  VReluKernelImpl<float, isa, kGT16>::VReluKernelImpl(int d) \
-      : VReluKernel<float>() { \
-    this->num_ = d; \
-    this->end_ = d - d % AVX_FLOAT_BLOCK; \
-    this->rest_ = d - AVX_FLOAT_BLOCK; \
-  } \
-  template <> \
-  void VReluKernelImpl<float, isa, kGT16>::Compute(const float* x, float* y) \
-      const { \
-    __m256 zeros = _mm256_setzero_ps(); \
-    for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
-      __m256 tmp = _mm256_loadu_ps(x + i); \
-      tmp = _mm256_max_ps(tmp, zeros); \
-      _mm256_storeu_ps(y + i, tmp); \
-    } \
-    __m256 tmp = _mm256_loadu_ps(x + this->rest_); \
-    tmp = _mm256_max_ps(tmp, zeros); \
-    _mm256_storeu_ps(y + this->rest_, tmp); \
+  void ComputeDeprecated(const T* x, T* y) const override {
+    VReluRefer(x, y, this->num_);
   }
+#ifdef PADDLE_WITH_XBYAK
 
-#ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
-INTRI16_FLOAT(jit::avx);
-INTRI_GT8LT16_FLOAT(jit::avx);
-INTRI_GT16_FLOAT(jit::avx);
-#endif
-#ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
-INTRI16_FLOAT(jit::avx2);
-INTRI_GT8LT16_FLOAT(jit::avx2);
+ private:
+  std::unique_ptr<gen::ReluJitCode> jitcode_{nullptr};
 #endif
-#ifdef __AVX512F__
-// TODO(TJ): refine avx512
-INTRI8_FLOAT(jit::avx512f);
-INTRI16_FLOAT(jit::avx512f);
-INTRI_GT8LT16_FLOAT(jit::avx512f);
-INTRI_GT16_FLOAT(jit::avx512f);
+};
+
+#ifdef PADDLE_WITH_XBYAK
+template <>
+bool VReluKernelImpl<float>::useJIT(int d) {
+  return gen::ReluJitCode::init(d);
+}
 #endif
 
-#undef INTRI8_FLOAT
-#undef INTRI16_FLOAT
-#undef INTRI_GT8LT16_FLOAT
-#undef INTRI_GT16_FLOAT
+#undef DECLARE_STATIC_FUNC
+
+REGISTER_JITKERNEL(vmul, VMulKernel);
+REGISTER_JITKERNEL(vadd, VAddKernel);
+REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
+REGISTER_JITKERNEL(vscal, VScalKernel);
+REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
+REGISTER_JITKERNEL(vrelu, VReluKernel);
 
 /* An empty JitKernel */
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 class VIdentityKernelImpl : public VIdentityKernel<T> {
  public:
   explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() { this->num_ = d; }
-  void Compute(const T* x, T* y) const override {}
+  void ComputeDeprecated(const T* x, T* y) const override {}
 };
 
-REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
 REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
 
 }  // namespace jitkernel
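Illustrative aside (not part of the patch): the size hint passed to `gen::ReluJitCode` above budgets roughly 96 bytes of prologue plus 4 instructions of about 8 bytes for each 8-float block, then clamps to one 4 KB page minimum. Restated as a standalone helper with a hypothetical name:

```cpp
#include <cstddef>

// Restatement of the buffer-size estimate used in VReluKernelImpl's ctor.
// Example: d = 1024 -> 1024 / 8 = 128 blocks, 4 instructions per block,
// ~8 bytes each -> 96 + 128 * 4 * 8 = 4192 bytes, so 4192 is requested;
// any estimate below 4096 is rounded up to the 4 KB floor.
size_t EstimateReluCodeSize(int d) {
  constexpr int kAvxFloatBlock = 8;  // floats per 256-bit ymm register
  size_t sz = 96 + static_cast<size_t>(d) / kAvxFloatBlock * 4 * 8;
  return sz > 4096 ? sz : 4096;
}
```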
diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc
index c55e54a13f539014c0f582436ca1a105d0b0fedd..2ac9e1092362f60ea3d89da0c971a365b45f39ea 100644
--- a/paddle/fluid/operators/math/jit_kernel_exp.cc
+++ b/paddle/fluid/operators/math/jit_kernel_exp.cc
@@ -35,7 +35,7 @@ template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 class VExpKernelImpl : public VExpKernel<T> {
  public:
   explicit VExpKernelImpl(int d) : VExpKernel<T>() { this->num_ = d; }
-  void Compute(const T* x, T* y) const override {
+  void ComputeDeprecated(const T* x, T* y) const override {
     for (int i = 0; i < this->num_; ++i) {
       y[i] = std::exp(x[i]);
     }
@@ -43,18 +43,18 @@ class VExpKernelImpl : public VExpKernel<T> {
 };
 
 #ifdef PADDLE_WITH_MKLML
-#define MKL_FLOAT(isa, block) \
-  template <> \
-  void VExpKernelImpl<float, isa, block>::Compute(const float* x, float* y) \
-      const { \
-    platform::dynload::vsExp(this->num_, x, y); \
+#define MKL_FLOAT(isa, block) \
+  template <> \
+  void VExpKernelImpl<float, isa, block>::ComputeDeprecated(const float* x, \
+                                                            float* y) const { \
+    platform::dynload::vsExp(this->num_, x, y); \
   }
 
-#define MKL_DOUBLE(isa, block) \
-  template <> \
-  void VExpKernelImpl<double, isa, block>::Compute(const double* x, double* y) \
-      const { \
-    platform::dynload::vdExp(this->num_, x, y); \
+#define MKL_DOUBLE(isa, block) \
+  template <> \
+  void VExpKernelImpl<double, isa, block>::ComputeDeprecated( \
+      const double* x, double* y) const { \
+    platform::dynload::vdExp(this->num_, x, y); \
   }
 FOR_EACH_ISA(MKL_FLOAT, kLT8);
 FOR_EACH_ISA(MKL_FLOAT, kGT8LT16);
@@ -211,24 +211,24 @@ __m256 ExpAVX2(__m256 x) {
 
 }  // namespace detail
 
-#define INTRI8_FLOAT(isa, expisa) \
-  template <> \
-  void VExpKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
-      const { \
-    __m256 tmp = _mm256_loadu_ps(x); \
-    _mm256_storeu_ps(y, expisa(tmp)); \
+#define INTRI8_FLOAT(isa, expisa) \
+  template <> \
+  void VExpKernelImpl<float, isa, kEQ8>::ComputeDeprecated(const float* x, \
+                                                           float* y) const { \
+    __m256 tmp = _mm256_loadu_ps(x); \
+    _mm256_storeu_ps(y, expisa(tmp)); \
   }
 
-#define INTRI16_FLOAT(isa, expisa) \
-  template <> \
-  void VExpKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
-      const { \
-    __m256 tmp0 = _mm256_loadu_ps(x); \
-    __m256 tmp1 = _mm256_loadu_ps(x + 8); \
-    tmp0 = expisa(tmp0); \
-    tmp1 = expisa(tmp1); \
-    _mm256_storeu_ps(y, tmp0); \
-    _mm256_storeu_ps(y + 8, tmp1); \
+#define INTRI16_FLOAT(isa, expisa) \
+  template <> \
+  void VExpKernelImpl<float, isa, kEQ16>::ComputeDeprecated(const float* x, \
+                                                            float* y) const { \
+    __m256 tmp0 = _mm256_loadu_ps(x); \
+    __m256 tmp1 = _mm256_loadu_ps(x + 8); \
+    tmp0 = expisa(tmp0); \
+    tmp1 = expisa(tmp1); \
+    _mm256_storeu_ps(y, tmp0); \
+    _mm256_storeu_ps(y + 8, tmp1); \
   }
 
 #ifdef __AVX__
@@ -260,14 +260,14 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
     this->num_ = d;
     vexp_ = KernelPool::Instance().template Get<VExpKernel<T>>(d);
   }
-  void Compute(const T* x, T* y) const override {
+  void ComputeDeprecated(const T* x, T* y) const override {
     const T min = SIGMOID_THRESHOLD_MIN;
     const T max = SIGMOID_THRESHOLD_MAX;
     for (int i = 0; i < this->num_; ++i) {
       y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
       y[i] = static_cast<T>(0) - y[i];
     }
-    vexp_->Compute(y, y);
+    vexp_->ComputeDeprecated(y, y);
     for (int i = 0; i < this->num_; ++i) {
       y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
     }
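Illustrative aside (not part of the patch): the scalar path of `VSigmoidKernelImpl` above clips the input to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX], negates it, exponentiates in place via the exp kernel, and takes the reciprocal; that is, it evaluates y = 1 / (1 + e^(-clip(x))). A minimal standalone restatement:

```cpp
#include <algorithm>
#include <cmath>

// Scalar restatement of the VSigmoid composition shown above.
void sigmoid_ref(const float* x, float* y, int n, float lo, float hi) {
  for (int i = 0; i < n; ++i) {
    float v = std::min(std::max(x[i], lo), hi);  // clip to [lo, hi]
    y[i] = 1.f / (1.f + std::exp(-v));           // 1 / (1 + e^-v)
  }
}
```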
@@ -285,30 +285,30 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
   tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
   tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)
 
-#define INTRI8_FLOAT(isa, expisa) \
-  template <> \
-  void VSigmoidKernelImpl<float, isa, kEQ8>::Compute(const float* x, \
-                                                     float* y) const { \
-    /* TODO(TJ): try to use static const*/ \
-    __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
-    __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
-    __m256 tmp = _mm256_loadu_ps(x); \
-    INTRI_SIGMOID(tmp, min, max, expisa); \
-    _mm256_storeu_ps(y, tmp); \
+#define INTRI8_FLOAT(isa, expisa) \
+  template <> \
+  void VSigmoidKernelImpl<float, isa, kEQ8>::ComputeDeprecated( \
+      const float* x, float* y) const { \
+    /* TODO(TJ): try to use static const*/ \
+    __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
+    __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
+    __m256 tmp = _mm256_loadu_ps(x); \
+    INTRI_SIGMOID(tmp, min, max, expisa); \
+    _mm256_storeu_ps(y, tmp); \
   }
 
-#define INTRI16_FLOAT(isa, expisa) \
-  template <> \
-  void VSigmoidKernelImpl<float, isa, kEQ16>::Compute(const float* x, \
-                                                      float* y) const { \
-    __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
-    __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
-    __m256 tmp0 = _mm256_loadu_ps(x); \
-    __m256 tmp1 = _mm256_loadu_ps(x + 8); \
-    INTRI_SIGMOID(tmp0, min, max, expisa); \
-    INTRI_SIGMOID(tmp1, min, max, expisa); \
-    _mm256_storeu_ps(y, tmp0); \
-    _mm256_storeu_ps(y + 8, tmp1); \
+#define INTRI16_FLOAT(isa, expisa) \
+  template <> \
+  void VSigmoidKernelImpl<float, isa, kEQ16>::ComputeDeprecated( \
+      const float* x, float* y) const { \
+    __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
+    __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
+    __m256 tmp0 = _mm256_loadu_ps(x); \
+    __m256 tmp1 = _mm256_loadu_ps(x + 8); \
+    INTRI_SIGMOID(tmp0, min, max, expisa); \
+    INTRI_SIGMOID(tmp1, min, max, expisa); \
+    _mm256_storeu_ps(y, tmp0); \
+    _mm256_storeu_ps(y + 8, tmp1); \
   }
 
 #define INTRI_GT8LT16_FLOAT(isa, expisa) \
@@ -322,8 +322,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
         KernelPool::Instance().template Get<VExpKernel<float>>(this->rest_); \
   } \
   template <> \
-  void VSigmoidKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
-                                                         float* y) const { \
+  void VSigmoidKernelImpl<float, isa, kGT8LT16>::ComputeDeprecated( \
+      const float* x, float* y) const { \
     __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
     __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
     __m256 tmp = _mm256_loadu_ps(x); \
@@ -335,7 +335,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
       y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \
       y[i] = 0.f - y[i]; \
     } \
-    vexp_->Compute(y + this->end_, y + this->end_); \
+    vexp_->ComputeDeprecated(y + this->end_, y + this->end_); \
     for (int i = this->end_; i < this->num_; ++i) { \
       y[i] = 1.f / (1.f + y[i]); \
     } \
@@ -352,8 +352,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
         KernelPool::Instance().template Get<VExpKernel<float>>(this->rest_); \
   } \
   template <> \
-  void VSigmoidKernelImpl<float, isa, kGT16>::Compute(const float* x, \
-                                                      float* y) const { \
+  void VSigmoidKernelImpl<float, isa, kGT16>::ComputeDeprecated( \
+      const float* x, float* y) const { \
     __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
     __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
     for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
@@ -367,7 +367,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
       y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \
       y[i] = 0.f - y[i]; \
     } \
-    vexp_->Compute(y + this->end_, y + this->end_); \
+    vexp_->ComputeDeprecated(y + this->end_, y + this->end_); \
     for (int i = this->end_; i < this->num_; ++i) { \
       y[i] = 1.f / (1.f + y[i]); \
     } \
@@ -408,10 +408,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
     vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<T>>(d);
     vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
   }
-  void Compute(const T* x, T* y) const override {
+  void ComputeDeprecated(const T* x, T* y) const override {
     const T a = static_cast<T>(2), b = static_cast<T>(-1);
     vscal_->Compute(&a, x, y, this->num_);
-    vsigmoid_->Compute(y, y);
+    vsigmoid_->ComputeDeprecated(y, y);
     vscal_->Compute(&a, y, y, this->num_);
     vaddbias_->Compute(&b, y, y, this->num_);
   }
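Illustrative aside (not part of the patch): `VTanhKernelImpl::ComputeDeprecated` above leans on the identity tanh(x) = 2·sigmoid(2x) − 1: scale by a = 2, sigmoid in place, scale by 2 again, then add the bias b = −1. A scalar sanity check of that composition:

```cpp
#include <cmath>

// tanh via the sigmoid identity used by VTanhKernelImpl:
//   tanh(x) = 2 * sigmoid(2 * x) - 1
float tanh_via_sigmoid(float x) {
  float s = 1.f / (1.f + std::exp(-2.f * x));  // sigmoid(2x)
  return 2.f * s - 1.f;                        // agrees with std::tanh(x)
}
```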
@@ -430,25 +430,25 @@ class VTanhKernelImpl : public VTanhKernel<T> {
   tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \
   tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))
 
-#define INTRI8_FLOAT(isa, expisa) \
-  template <> \
-  void VTanhKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
-      const { \
-    __m256 tmp = _mm256_loadu_ps(x); \
-    INTRI_VTANH(tmp, expisa); \
-    _mm256_storeu_ps(y, tmp); \
+#define INTRI8_FLOAT(isa, expisa) \
+  template <> \
+  void VTanhKernelImpl<float, isa, kEQ8>::ComputeDeprecated(const float* x, \
+                                                            float* y) const { \
+    __m256 tmp = _mm256_loadu_ps(x); \
+    INTRI_VTANH(tmp, expisa); \
+    _mm256_storeu_ps(y, tmp); \
   }
 
-#define INTRI16_FLOAT(isa, expisa) \
-  template <> \
-  void VTanhKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
-      const { \
-    __m256 tmp0 = _mm256_loadu_ps(x); \
-    __m256 tmp1 = _mm256_loadu_ps(x + 8); \
-    INTRI_VTANH(tmp0, expisa); \
-    INTRI_VTANH(tmp1, expisa); \
-    _mm256_storeu_ps(y, tmp0); \
-    _mm256_storeu_ps(y + 8, tmp1); \
+#define INTRI16_FLOAT(isa, expisa) \
+  template <> \
+  void VTanhKernelImpl<float, isa, kEQ16>::ComputeDeprecated(const float* x, \
+                                                             float* y) const { \
+    __m256 tmp0 = _mm256_loadu_ps(x); \
+    __m256 tmp1 = _mm256_loadu_ps(x + 8); \
+    INTRI_VTANH(tmp0, expisa); \
+    INTRI_VTANH(tmp1, expisa); \
+    _mm256_storeu_ps(y, tmp0); \
+    _mm256_storeu_ps(y + 8, tmp1); \
   }
 
 #define INTRI_GT8LT16_FLOAT(isa, expisa) \
@@ -466,8 +466,8 @@ class VTanhKernelImpl : public VTanhKernel<T> {
                                                             this->rest_); \
   } \
   template <> \
-  void VTanhKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
-                                                      float* y) const { \
+  void VTanhKernelImpl<float, isa, kGT8LT16>::ComputeDeprecated( \
+      const float* x, float* y) const { \
    __m256 tmp = _mm256_loadu_ps(x); \
     INTRI_VTANH(tmp, expisa); \
     _mm256_storeu_ps(y, tmp); \
@@ -475,40 +475,40 @@ class VTanhKernelImpl : public VTanhKernel<T> {
     y += AVX_FLOAT_BLOCK; \
     const float a = 2.f, b = -1.f; \
     vscal_->Compute(&a, x, y, this->num_); \
-    vsigmoid_->Compute(y, y); \
+    vsigmoid_->ComputeDeprecated(y, y); \
     vscal_->Compute(&a, y, y, this->num_); \
     vaddbias_->Compute(&b, y, y, this->num_); \
   }
 
-#define INTRI_GT16_FLOAT(isa, expisa) \
-  template <> \
-  VTanhKernelImpl<float, isa, kGT16>::VTanhKernelImpl(int d) \
-      : VTanhKernel<float>() { \
-    this->num_ = d; \
-    this->rest_ = d % AVX_FLOAT_BLOCK; \
-    this->end_ = d - this->rest_; \
-    vscal_ = \
-        KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_); \
-    vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>( \
-        this->rest_); \
-    vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>( \
-        this->rest_); \
-  } \
-  template <> \
-  void VTanhKernelImpl<float, isa, kGT16>::Compute(const float* x, float* y) \
-      const { \
-    for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
-      __m256 tmp = _mm256_loadu_ps(x + i); \
-      INTRI_VTANH(tmp, expisa); \
-      _mm256_storeu_ps(y + i, tmp); \
-    } \
-    x += this->end_; \
-    y += this->end_; \
-    const float a = 2.f, b = -1.f; \
-    vscal_->Compute(&a, x, y, this->num_); \
-    vsigmoid_->Compute(y, y); \
-    vscal_->Compute(&a, y, y, this->num_); \
-    vaddbias_->Compute(&b, y, y, this->num_); \
+#define INTRI_GT16_FLOAT(isa, expisa) \
+  template <> \
+  VTanhKernelImpl<float, isa, kGT16>::VTanhKernelImpl(int d) \
+      : VTanhKernel<float>() { \
+    this->num_ = d; \
+    this->rest_ = d % AVX_FLOAT_BLOCK; \
+    this->end_ = d - this->rest_; \
+    vscal_ = \
+        KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_); \
+    vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>( \
+        this->rest_); \
+    vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>( \
+        this->rest_); \
+  } \
+  template <> \
+  void VTanhKernelImpl<float, isa, kGT16>::ComputeDeprecated(const float* x, \
+                                                             float* y) const { \
+    for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
+      __m256 tmp = _mm256_loadu_ps(x + i); \
+      INTRI_VTANH(tmp, expisa); \
+      _mm256_storeu_ps(y + i, tmp); \
+    } \
+    x += this->end_; \
+    y += this->end_; \
+    const float a = 2.f, b = -1.f; \
+    vscal_->Compute(&a, x, y, this->num_); \
+    vsigmoid_->ComputeDeprecated(y, y); \
+    vscal_->Compute(&a, y, y, this->num_); \
+    vaddbias_->Compute(&b, y, y, this->num_); \
   }
 
 #ifdef __AVX__
diff --git a/paddle/fluid/operators/math/jit_kernel_rnn.cc b/paddle/fluid/operators/math/jit_kernel_rnn.cc
index ba3e917377cf12192a068a9d71238442e12d5e5e..926221f0a75c461e275a72f16b4339ae28a8e988 100644
--- a/paddle/fluid/operators/math/jit_kernel_rnn.cc
+++ b/paddle/fluid/operators/math/jit_kernel_rnn.cc
@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> {
   void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
                    T* checked) const override {
     // gates: W_ch, W_ih, W_fh, W_oh
-    act_gate_d3_->Compute(gates + d_, gates + d_);
+    act_gate_d3_->ComputeDeprecated(gates + d_, gates + d_);
 
     /* C_t = C_t-1 * fgated + cand_gated * igated */
-    act_cand_d_->Compute(gates, gates);
+    act_cand_d_->ComputeDeprecated(gates, gates);
     vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
     vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
     vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
 
     /* H_t = act_cell(C_t) * ogated */
-    act_cell_d_->Compute(ct, gates + d2_);
+    act_cell_d_->ComputeDeprecated(ct, gates + d2_);
     vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
   }
   void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
     /* C_t = igated * cgated*/
-    act_gate_d_->Compute(gates + d_, gates + d_);
-    act_cand_d_->Compute(gates, gates);
+    act_gate_d_->ComputeDeprecated(gates + d_, gates + d_);
+    act_cand_d_->ComputeDeprecated(gates, gates);
     vmul_d_->Compute(gates, gates + d_, ct, d_);
 
     /* H_t = act_cell(C_t) * ogated */
-    act_gate_d_->Compute(gates + d3_, gates + d3_);
-    act_cell_d_->Compute(ct, gates + d2_);
+    act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
+    act_cell_d_->ComputeDeprecated(ct, gates + d2_);
     vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
   }
 
@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
     vmul_d_->Compute(wp_data, ct_1, checked, d_);
     vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
     vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
-    act_gate_d2_->Compute(gates + d_, gates + d_);
+    act_gate_d2_->ComputeDeprecated(gates + d_, gates + d_);
 
     /* C_t = C_t-1 * fgated + cand_gated * igated*/
-    act_cand_d_->Compute(gates, gates);
+    act_cand_d_->ComputeDeprecated(gates, gates);
     vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
     vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
     vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
 
     /* get ogated*/
     vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
     vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
-    act_gate_d_->Compute(gates + d3_, gates + d3_);
+    act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
 
     /* H_t = act_cell(C_t) * ogated */
-    act_cell_d_->Compute(ct, gates + d2_);
+    act_cell_d_->ComputeDeprecated(ct, gates + d2_);
     vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
   }
 
   void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
     /* C_t = igated * cgated*/
-    act_gate_d_->Compute(gates + d_, gates + d_);
-    act_cand_d_->Compute(gates, gates);
+    act_gate_d_->ComputeDeprecated(gates + d_, gates + d_);
+    act_cand_d_->ComputeDeprecated(gates, gates);
     vmul_d_->Compute(gates, gates + d_, ct, d_);
 
     /* get outgated, put W_oc * C_t on igated */
     vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
     vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
 
     /* H_t = act_cell(C_t) * ogated */
-    act_gate_d_->Compute(gates + d3_, gates + d3_);
-    act_cell_d_->Compute(ct, gates + d2_);
+    act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
+    act_cell_d_->ComputeDeprecated(ct, gates + d2_);
     vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
   }
 
@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> {
   }
 
   void ComputeH1(T* gates, T* ht) const override {
-    act_gate_d_->Compute(gates, gates);
-    act_state_d_->Compute(gates + d2_, gates + d2_);
+    act_gate_d_->ComputeDeprecated(gates, gates);
+    act_state_d_->ComputeDeprecated(gates + d2_, gates + d2_);
     vmul_d_->Compute(gates, gates + d2_, ht, d_);
   }
 
   void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
     // W: {W_update, W_reset; W_state}
-    act_gate_d2_->Compute(gates, gates);
+    act_gate_d2_->ComputeDeprecated(gates, gates);
     vmul_d_->Compute(ht_1, gates + d_, ht, d_);
   }
 
   void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
     T* y = gates + d2_;
-    act_state_d_->Compute(y, y);
+    act_state_d_->ComputeDeprecated(y, y);
     // out = zt*ht~ + (1-zt)*ht_1
     for (int i = 0; i < d_; ++i) {
       ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
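Illustrative aside (not part of the patch): every change in jit_kernel_rnn.cc is a mechanical rename of activation calls to `ComputeDeprecated`; the recurrences themselves are untouched. For orientation, the LSTM step that `ComputeCtHt` implements is, element-wise, C_t = f ⊙ C_{t-1} + i ⊙ c~ and H_t = o ⊙ act_cell(C_t), and GRU part 2 is h_t = z ⊙ h~ + (1 − z) ⊙ h_{t-1}. A scalar sketch of the LSTM step, with hypothetical array names and act_cell assumed to be tanh:

```cpp
#include <cmath>

// Scalar form of the LSTM step behind ComputeCtHt (gates already passed
// through act_gate = sigmoid and act_cand = tanh); all arrays hold d elements.
void lstm_step_ref(const float* igate, const float* fgate, const float* ogate,
                   const float* cand, const float* ct_1, float* ct, float* ht,
                   int d) {
  for (int k = 0; k < d; ++k) {
    ct[k] = fgate[k] * ct_1[k] + igate[k] * cand[k];  // C_t = f*C_{t-1} + i*c~
    ht[k] = ogate[k] * std::tanh(ct[k]);              // H_t = o * act_cell(C_t)
  }
}
```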
diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc
index 7dc3e600b564d95b46070ff4436b2d0de2f3e105..5e1f91ffae03796be2817d0461900c2512938c77 100644
--- a/paddle/fluid/operators/math/jit_kernel_test.cc
+++ b/paddle/fluid/operators/math/jit_kernel_test.cc
@@ -92,7 +92,7 @@ TEST(JitKernel, vrelu) {
 #endif
   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(x_data, ztgt_data);
+    ker->Compute(x_data, ztgt_data, d);
   }
   auto ttgte = GetCurrentUS();
   VLOG(30) << "Vec size " << d
@@ -181,7 +181,7 @@ TEST(JitKernel, vexp) {
 
   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(x_data, ztgt_data);
+    ker->ComputeDeprecated(x_data, ztgt_data);
   }
   auto ttgte = GetCurrentUS();
 
@@ -222,7 +222,7 @@ void vsigmoid_better(
     y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
     y[i] = 0.f - y[i];
   }
-  vexp->Compute(y, y);
+  vexp->ComputeDeprecated(y, y);
   for (int i = 0; i < n; ++i) {
     y[i] = 1.f / (1.f + y[i]);
   }
@@ -253,7 +253,7 @@ TEST(JitKernel, vsigmoid) {
   auto trefe = GetCurrentUS();
   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(x_data, ztgt_data);
+    ker->ComputeDeprecated(x_data, ztgt_data);
   }
   auto ttgte = GetCurrentUS();
 
@@ -287,7 +287,7 @@ void vtanh_better(
     const int n, const float* x, float* y) {
   const float a = 2.f, b = -1.f;
   vscal->Compute(&a, x, y, n);
-  vsigmoid->Compute(y, y);
+  vsigmoid->ComputeDeprecated(y, y);
   vscal->Compute(&a, y, y, n);
   vaddbias->Compute(&b, y, y, n);
 }
@@ -321,7 +321,7 @@ TEST(JitKernel, vtanh) {
   auto trefe = GetCurrentUS();
   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(x_data, ztgt_data);
+    ker->ComputeDeprecated(x_data, ztgt_data);
  }
   auto ttgte = GetCurrentUS();
 
@@ -344,8 +344,8 @@ void lstm_ctht_ref(
     const std::shared_ptr<
         const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp_1,
     const int d, float* gates, const float* ct_1, float* ct, float* ht) {
-  vsigmoid_3d->Compute(gates + d, gates + d);
-  vtanh_d->Compute(gates, gates);
+  vsigmoid_3d->ComputeDeprecated(gates + d, gates + d);
+  vtanh_d->ComputeDeprecated(gates, gates);
   const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3;
   const float min = SIGMOID_THRESHOLD_MIN;
   const float max = SIGMOID_THRESHOLD_MAX;
@@ -355,7 +355,7 @@ void lstm_ctht_ref(
     // H_t = act_cell(C_t) * ogated
     float tmp = ct[k] * 2;
     tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
-    vexp_1->Compute(&tmp, &tmp);
+    vexp_1->ComputeDeprecated(&tmp, &tmp);
     tmp = 2.f / (1.f + tmp) - 1.f;
     ht[k] = tmp * o[k];
   }
@@ -373,13 +373,13 @@ void lstm_ctht_better(
     const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd_d,
     const int d, float* gates, const float* ct_1, float* ct, float* ht) {
   int d2 = d * 2;
-  vsigmoid_3d->Compute(gates + d, gates + d);
-  vtanh_d->Compute(gates, gates);
+  vsigmoid_3d->ComputeDeprecated(gates + d, gates + d);
+  vtanh_d->ComputeDeprecated(gates, gates);
   vmul_d->Compute(gates, gates + d, gates + d, d);
   vmul_d->Compute(ct_1, gates + d2, gates + d2, d);
   vadd_d->Compute(gates + d, gates + d2, ct, d);
   /* H_t = act_cell(C_t) * ogated */
-  vtanh_d->Compute(ct, gates + d2);
+  vtanh_d->ComputeDeprecated(ct, gates + d2);
   vmul_d->Compute(gates + d2, gates + d * 3, ht, d);
 }
 
@@ -736,7 +736,7 @@ void vaddrelu_better(
     const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
     const float* x, const float* y, float* z, int d) {
   vadd->Compute(x, y, z, d);
-  vrelu->Compute(z, z);
+  vrelu->ComputeDeprecated(z, z);
 }
 
 TEST(JitKernel, vaddrelu) {
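Illustrative aside (not part of the patch): `vaddrelu_better` above composes the two kernels, an elementwise add followed by ReLU through the deprecated interface, which is equivalent to this scalar reference (hypothetical name):

```cpp
// Scalar reference for the vadd + vrelu composition in vaddrelu_better:
//   z[i] = max(x[i] + y[i], 0)
void vaddrelu_ref(const float* x, const float* y, float* z, int d) {
  for (int i = 0; i < d; ++i) {
    float sum = x[i] + y[i];
    z[i] = sum > 0.f ? sum : 0.f;
  }
}
```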