diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index eaf5fd0a87ff13c5a79b49ab030cad8ddd7e7263..8a247da4503caf002371f264f471b9731a4c7aca 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -29,7 +29,6 @@ namespace jitkernel { #define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MAX 13.0 #define EXP_MAX_INPUT 40.0 - #define AVX_FLOAT_BLOCK 8 #define AVX2_FLOAT_BLOCK 8 #define AVX512_FLOAT_BLOCK 16 @@ -40,8 +39,9 @@ class Kernel { public: Kernel() = default; virtual ~Kernel() = default; - - private: + int num_{0}; + int end_{0}; + int rest_{0}; DISABLE_COPY_AND_ASSIGN(Kernel); }; @@ -95,13 +95,13 @@ class VExpKernel : public Kernel { template class VSigmoidKernel : public Kernel { public: - virtual void Compute(const int n, const T *x, T *y) const = 0; + virtual void Compute(const T *x, T *y) const = 0; }; template class VTanhKernel : public Kernel { public: - virtual void Compute(const int n, const T *x, T *y) const = 0; + virtual void Compute(const T *x, T *y) const = 0; }; template diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index da0a71be28eb1ee72e42e94ce914bb89aa13c951..ca4c4f4a42a8aea583c29b4f3f24b0e70b13f2fa 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -113,17 +113,18 @@ template class VSigmoidKernelImpl : public VSigmoidKernel { public: explicit VSigmoidKernelImpl(int d) : VSigmoidKernel() { + this->num_ = d; vexp_ = KernelPool::Instance().template Get>(d); } - void Compute(const int n, const T* x, T* y) const override { + void Compute(const T* x, T* y) const override { const T min = SIGMOID_THRESHOLD_MIN; const T max = SIGMOID_THRESHOLD_MAX; - for (int i = 0; i < n; ++i) { + for (int i = 0; i < this->num_; ++i) { y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); y[i] = static_cast(0) - y[i]; } - vexp_->Compute(n, y, y); - for (int i = 0; i < n; ++i) { + vexp_->Compute(this->num_, y, y); + for (int i = 0; i < this->num_; ++i) { y[i] = static_cast(1) / (static_cast(1) + y[i]); } } @@ -140,76 +141,89 @@ class VSigmoidKernelImpl : public VSigmoidKernel { tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \ tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp) -#define INTRI8_FLOAT(isa) \ - template <> \ - void VSigmoidKernelImpl::Compute( \ - const int n, const float* x, float* y) const { \ - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ - __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_SIGMOID(tmp, min, max); \ - _mm256_storeu_ps(y, tmp); \ +#define INTRI8_FLOAT(isa) \ + template <> \ + void VSigmoidKernelImpl::Compute(const float* x, float* y) \ + const { \ + __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ + __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ + __m256 tmp = _mm256_loadu_ps(x); \ + INTRI_SIGMOID(tmp, min, max); \ + _mm256_storeu_ps(y, tmp); \ } -#define INTRI16_FLOAT(isa) \ - template <> \ - void VSigmoidKernelImpl::Compute( \ - const int n, const float* x, float* y) const { \ - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - INTRI_SIGMOID(tmp0, min, max); \ - INTRI_SIGMOID(tmp1, min, max); \ - _mm256_storeu_ps(y, tmp0); \ - _mm256_storeu_ps(y + 8, tmp1); \ +#define INTRI16_FLOAT(isa) \ + template <> \ + void VSigmoidKernelImpl::Compute(const float* x, \ + float* y) const { \ + __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ + __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ + __m256 tmp0 = _mm256_loadu_ps(x); \ + __m256 tmp1 = _mm256_loadu_ps(x + 8); \ + INTRI_SIGMOID(tmp0, min, max); \ + INTRI_SIGMOID(tmp1, min, max); \ + _mm256_storeu_ps(y, tmp0); \ + _mm256_storeu_ps(y + 8, tmp1); \ } -#define INTRI_GT8LT16_FLOAT(isa) \ - template <> \ - void VSigmoidKernelImpl::Compute( \ - const int n, const float* x, float* y) const { \ - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ - __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_SIGMOID(tmp, min, max); \ - _mm256_storeu_ps(y, tmp); \ - const float min_ = SIGMOID_THRESHOLD_MIN; \ - const float max_ = SIGMOID_THRESHOLD_MAX; \ - for (int i = AVX_FLOAT_BLOCK; i < n; ++i) { \ - y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \ - y[i] = 0.f - y[i]; \ - } \ - vexp_->Compute(n - AVX_FLOAT_BLOCK, y + AVX_FLOAT_BLOCK, \ - y + AVX_FLOAT_BLOCK); \ - for (int i = AVX_FLOAT_BLOCK; i < n; ++i) { \ - y[i] = 1.f / (1.f + y[i]); \ - } \ +#define INTRI_GT8LT16_FLOAT(isa) \ + template <> \ + VSigmoidKernelImpl::VSigmoidKernelImpl(int d) \ + : VSigmoidKernel() { \ + this->num_ = d; \ + this->end_ = AVX_FLOAT_BLOCK; \ + this->rest_ = d - this->end_; \ + vexp_ = KernelPool::Instance().template Get>(d); \ + } \ + template <> \ + void VSigmoidKernelImpl::Compute(const float* x, \ + float* y) const { \ + __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ + __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ + __m256 tmp = _mm256_loadu_ps(x); \ + INTRI_SIGMOID(tmp, min, max); \ + _mm256_storeu_ps(y, tmp); \ + const float min_ = SIGMOID_THRESHOLD_MIN; \ + const float max_ = SIGMOID_THRESHOLD_MAX; \ + for (int i = this->end_; i < this->num_; ++i) { \ + y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \ + y[i] = 0.f - y[i]; \ + } \ + vexp_->Compute(this->rest_, y + this->end_, y + this->end_); \ + for (int i = this->end_; i < this->num_; ++i) { \ + y[i] = 1.f / (1.f + y[i]); \ + } \ } -#define INTRI_GT16_FLOAT(isa) \ - template <> \ - void VSigmoidKernelImpl::Compute( \ - const int n, const float* x, float* y) const { \ - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ - const int rest = n % AVX_FLOAT_BLOCK; \ - const int end = n - rest; \ - for (int i = 0; i < end; i += AVX_FLOAT_BLOCK) { \ - __m256 tmp = _mm256_loadu_ps(x + i); \ - INTRI_SIGMOID(tmp, min, max); \ - _mm256_storeu_ps(y + i, tmp); \ - } \ - const float min_ = SIGMOID_THRESHOLD_MIN; \ - const float max_ = SIGMOID_THRESHOLD_MAX; \ - for (int i = end; i < n; ++i) { \ - y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \ - y[i] = 0.f - y[i]; \ - } \ - vexp_->Compute(rest, y + end, y + end); \ - for (int i = end; i < n; ++i) { \ - y[i] = 1.f / (1.f + y[i]); \ - } \ +#define INTRI_GT16_FLOAT(isa) \ + template <> \ + VSigmoidKernelImpl::VSigmoidKernelImpl(int d) \ + : VSigmoidKernel() { \ + this->num_ = d; \ + this->rest_ = d % AVX_FLOAT_BLOCK; \ + this->end_ = d - this->rest_; \ + vexp_ = KernelPool::Instance().template Get>(d); \ + } \ + template <> \ + void VSigmoidKernelImpl::Compute(const float* x, \ + float* y) const { \ + __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ + __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ + for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ + __m256 tmp = _mm256_loadu_ps(x + i); \ + INTRI_SIGMOID(tmp, min, max); \ + _mm256_storeu_ps(y + i, tmp); \ + } \ + const float min_ = SIGMOID_THRESHOLD_MIN; \ + const float max_ = SIGMOID_THRESHOLD_MAX; \ + for (int i = this->end_; i < this->num_; ++i) { \ + y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \ + y[i] = 0.f - y[i]; \ + } \ + vexp_->Compute(this->rest_, y + this->end_, y + this->end_); \ + for (int i = this->end_; i < this->num_; ++i) { \ + y[i] = 1.f / (1.f + y[i]); \ + } \ } #ifdef __AVX__ @@ -249,15 +263,16 @@ template class VTanhKernelImpl : public VTanhKernel { public: explicit VTanhKernelImpl(int d) : VTanhKernel() { + this->num_ = d; vscal_ = KernelPool::Instance().template Get>(d); vsigmoid_ = KernelPool::Instance().template Get>(d); vaddbias_ = KernelPool::Instance().template Get>(d); } - void Compute(const int n, const T* x, T* y) const override { - vscal_->Compute(n, static_cast(2), x, y); - vsigmoid_->Compute(n, y, y); - vscal_->Compute(n, static_cast(2), y); - vaddbias_->Compute(n, static_cast(-1), y, y); + void Compute(const T* x, T* y) const override { + vscal_->Compute(this->num_, static_cast(2), x, y); + vsigmoid_->Compute(y, y); + vscal_->Compute(this->num_, static_cast(2), y); + vaddbias_->Compute(this->num_, static_cast(-1), y, y); } private: @@ -274,60 +289,83 @@ class VTanhKernelImpl : public VTanhKernel { tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \ tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f)) -#define INTRI8_FLOAT(isa) \ - template <> \ - void VTanhKernelImpl::Compute(const int n, const float* x, \ - float* y) const { \ - __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_VTANH(tmp); \ - _mm256_storeu_ps(y, tmp); \ +#define INTRI8_FLOAT(isa) \ + template <> \ + void VTanhKernelImpl::Compute(const float* x, float* y) \ + const { \ + __m256 tmp = _mm256_loadu_ps(x); \ + INTRI_VTANH(tmp); \ + _mm256_storeu_ps(y, tmp); \ } -#define INTRI16_FLOAT(isa) \ - template <> \ - void VTanhKernelImpl::Compute( \ - const int n, const float* x, float* y) const { \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - INTRI_VTANH(tmp0); \ - INTRI_VTANH(tmp1); \ - _mm256_storeu_ps(y, tmp0); \ - _mm256_storeu_ps(y + 8, tmp1); \ +#define INTRI16_FLOAT(isa) \ + template <> \ + void VTanhKernelImpl::Compute(const float* x, float* y) \ + const { \ + __m256 tmp0 = _mm256_loadu_ps(x); \ + __m256 tmp1 = _mm256_loadu_ps(x + 8); \ + INTRI_VTANH(tmp0); \ + INTRI_VTANH(tmp1); \ + _mm256_storeu_ps(y, tmp0); \ + _mm256_storeu_ps(y + 8, tmp1); \ } -#define INTRI_GT8LT16_FLOAT(isa) \ - template <> \ - void VTanhKernelImpl::Compute( \ - const int n, const float* x, float* y) const { \ - __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_VTANH(tmp); \ - _mm256_storeu_ps(y, tmp); \ - x += AVX_FLOAT_BLOCK; \ - y += AVX_FLOAT_BLOCK; \ - const int rest = n - AVX_FLOAT_BLOCK; \ - vscal_->Compute(rest, 2.f, x, y); \ - vsigmoid_->Compute(rest, y, y); \ - vscal_->Compute(rest, 2.f, y); \ - vaddbias_->Compute(rest, -1.f, y, y); \ +#define INTRI_GT8LT16_FLOAT(isa) \ + template <> \ + VTanhKernelImpl::VTanhKernelImpl(int d) \ + : VTanhKernel() { \ + this->num_ = d; \ + this->end_ = AVX_FLOAT_BLOCK; \ + this->rest_ = d - this->end_; \ + vscal_ = \ + KernelPool::Instance().template Get>(this->rest_); \ + vsigmoid_ = KernelPool::Instance().template Get>( \ + this->rest_); \ + vaddbias_ = KernelPool::Instance().template Get>( \ + this->rest_); \ + } \ + template <> \ + void VTanhKernelImpl::Compute(const float* x, \ + float* y) const { \ + __m256 tmp = _mm256_loadu_ps(x); \ + INTRI_VTANH(tmp); \ + _mm256_storeu_ps(y, tmp); \ + x += AVX_FLOAT_BLOCK; \ + y += AVX_FLOAT_BLOCK; \ + vscal_->Compute(this->rest_, 2.f, x, y); \ + vsigmoid_->Compute(y, y); \ + vscal_->Compute(this->rest_, 2.f, y); \ + vaddbias_->Compute(this->rest_, -1.f, y, y); \ } -#define INTRI_GT16_FLOAT(isa) \ - template <> \ - void VTanhKernelImpl::Compute( \ - const int n, const float* x, float* y) const { \ - const int rest = n % AVX_FLOAT_BLOCK; \ - const int end = n - rest; \ - for (int i = 0; i < end; i += AVX_FLOAT_BLOCK) { \ - __m256 tmp = _mm256_loadu_ps(x + i); \ - INTRI_VTANH(tmp); \ - _mm256_storeu_ps(y + i, tmp); \ - } \ - x += end; \ - y += end; \ - vscal_->Compute(rest, 2.f, x, y); \ - vsigmoid_->Compute(rest, y, y); \ - vscal_->Compute(rest, 2.f, y); \ - vaddbias_->Compute(rest, -1.f, y, y); \ +#define INTRI_GT16_FLOAT(isa) \ + template <> \ + VTanhKernelImpl::VTanhKernelImpl(int d) \ + : VTanhKernel() { \ + this->num_ = d; \ + this->rest_ = d % AVX_FLOAT_BLOCK; \ + this->end_ = d - this->rest_; \ + vscal_ = \ + KernelPool::Instance().template Get>(this->rest_); \ + vsigmoid_ = KernelPool::Instance().template Get>( \ + this->rest_); \ + vaddbias_ = KernelPool::Instance().template Get>( \ + this->rest_); \ + } \ + template <> \ + void VTanhKernelImpl::Compute(const float* x, float* y) \ + const { \ + for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ + __m256 tmp = _mm256_loadu_ps(x + i); \ + INTRI_VTANH(tmp); \ + _mm256_storeu_ps(y + i, tmp); \ + } \ + x += this->end_; \ + y += this->end_; \ + vscal_->Compute(this->rest_, 2.f, x, y); \ + vsigmoid_->Compute(y, y); \ + vscal_->Compute(this->rest_, 2.f, y); \ + vaddbias_->Compute(this->rest_, -1.f, y, y); \ } #ifdef __AVX__ diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 3aadc6ef44b1ea7459b635a6332b4cfbd991a573..290605749f5d3d171b9c6023c73ee6a5869a6e12 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -195,7 +195,7 @@ TEST(JitKernel, vsigmoid) { auto trefe = GetCurrentUS(); auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(d, x_data, ztgt_data); + ker->Compute(x_data, ztgt_data); } auto ttgte = GetCurrentUS(); @@ -227,7 +227,7 @@ void vtanh_better( vaddbias, const int n, const float* x, float* y) { vscal->Compute(n, 2.f, x, y); - vsigmoid->Compute(n, y, y); + vsigmoid->Compute(y, y); vscal->Compute(n, 2.f, y); vaddbias->Compute(n, -1.f, y, y); } @@ -261,7 +261,7 @@ TEST(JitKernel, vtanh) { auto trefe = GetCurrentUS(); auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(d, x_data, ztgt_data); + ker->Compute(x_data, ztgt_data); } auto ttgte = GetCurrentUS();