From 7cb19a5976a8c23c34cdea6d86bf3ce7c3c3cc79 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 19 Oct 2018 00:48:43 +0800 Subject: [PATCH] fuse elementwise_add and relu --- paddle/fluid/operators/math/fc_compute.h | 24 +++-- paddle/fluid/operators/math/jit_kernel.h | 6 ++ .../fluid/operators/math/jit_kernel_blas.cc | 91 +++++++++++++++++++ 3 files changed, 112 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/math/fc_compute.h b/paddle/fluid/operators/math/fc_compute.h index 1f5a49c0ab5..2d7e877a77b 100644 --- a/paddle/fluid/operators/math/fc_compute.h +++ b/paddle/fluid/operators/math/fc_compute.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): add deps DECLARE_int32(paddle_num_threads); @@ -30,20 +31,25 @@ inline void FCCompute(const BlasT& blas, const int M, if (B == NULL) { return; } + if (relu) { + const auto& vaddrelu = jitkernel::KernelPool::Instance() + .template Get>(N); + for (int i = 0; i < M; i++) { + T* dst = Y + i * N; + vaddrelu->Compute(B, dst, dst); + } + } else { + const auto& vadd = jitkernel::KernelPool::Instance() + .template Get>(N); #ifdef PADDLE_WITH_MKLML #pragma omp parallel for if (FLAGS_paddle_num_threads > 1) #endif - for (int i = 0; i < M; i++) { - blas.AXPY(N, static_cast(1), B, Y + i * N); + for (int i = 0; i < M; i++) { + T* dst = Y + i * N; + vadd->Compute(B, dst, dst); + } } - - if (!relu) { - return; - } - - // TODO(TJ): fuse relu - LOG(FATAL) << "Not implemented!"; } } // namespace math diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index b4dfda6db76..e91e4e8e5ad 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -86,6 +86,12 @@ class VAddBiasKernel : public Kernel { virtual void Compute(const T a, const T *x, T *y) const = 0; }; +template +class VAddReluKernel : public Kernel { + public: + virtual void Compute(const T *x, const T *y, T *z) const = 0; +}; + template class VActKernel : public Kernel { public: diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 0f9ea533fcc..a486a0ca804 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -378,11 +378,102 @@ class VIdentityKernelImpl : public VIdentityKernel { void Compute(const T* x, T* y) const override {} }; +/* VAddRelu JitKernel */ +template +class VAddReluKernelImpl : public VAddReluKernel { + public: + explicit VAddReluKernelImpl(int d) : VAddReluKernel() { this->num_ = d; } + void Compute(const T* x, const T* y, T* z) const override { + for (int i = 0; i < this->num_; ++i) { + z[i] = x[i] + y[i]; + z[i] = z[i] > 0 ? z[i] : 0; + } + } +}; + +#define INTRI8_FLOAT(isa) \ + template <> \ + void VAddReluKernelImpl::Compute( \ + const float* x, const float* y, float* z) const { \ + __m256 tmpx = _mm256_loadu_ps(x); \ + __m256 tmpy = _mm256_loadu_ps(y); \ + tmpy = _mm256_add_ps(tmpx, tmpy); \ + tmpy = _mm256_max_ps(tmpy, _mm256_setzero_ps()); \ + _mm256_storeu_ps(z, tmpy); \ + } + +#define INTRI16_FLOAT(isa) \ + template <> \ + void VAddReluKernelImpl::Compute( \ + const float* x, const float* y, float* z) const { \ + __m256 zeros = _mm256_setzero_ps(); \ + __m256 tmp0 = _mm256_loadu_ps(x); \ + __m256 tmp1 = _mm256_loadu_ps(y); \ + tmp0 = _mm256_add_ps(tmp0, tmp1); \ + tmp0 = _mm256_max_ps(tmp0, zeros); \ + tmp1 = _mm256_loadu_ps(x + 8); \ + __m256 tmp2 = _mm256_loadu_ps(y + 8); \ + tmp1 = _mm256_add_ps(tmp1, tmp2); \ + tmp1 = _mm256_max_ps(tmp1, zeros); \ + _mm256_storeu_ps(z, tmp0); \ + _mm256_storeu_ps(z + 8, tmp1); \ + } + +#define INTRI_COMMON_FLOAT(isa, block) \ + template <> \ + VAddReluKernelImpl::VAddReluKernelImpl(int d) \ + : VAddReluKernel() { \ + this->num_ = d; \ + this->end_ = d - d % AVX_FLOAT_BLOCK; \ + this->rest_ = d - this->end_; \ + } \ + template <> \ + void VAddReluKernelImpl::Compute( \ + const float* x, const float* y, float* z) const { \ + __m256 zeros = _mm256_setzero_ps(); \ + for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ + __m256 tmpx = _mm256_loadu_ps(x + i); \ + __m256 tmpy = _mm256_loadu_ps(y + i); \ + tmpy = _mm256_add_ps(tmpx, tmpy); \ + tmpy = _mm256_max_ps(tmpy, zeros); \ + _mm256_storeu_ps(z + i, tmpy); \ + } \ + for (int i = this->end_; i < this->num_; ++i) { \ + z[i] = x[i] + y[i]; \ + z[i] = z[i] > 0 ? z[i] : 0; \ + } \ + } + +#ifdef __AVX__ +INTRI8_FLOAT(jit::avx); +INTRI16_FLOAT(jit::avx); +INTRI_COMMON_FLOAT(jit::avx, kGT8LT16); +INTRI_COMMON_FLOAT(jit::avx, kGT16); +#endif +#ifdef __AVX2__ +INTRI8_FLOAT(jit::avx2); +INTRI16_FLOAT(jit::avx2); +INTRI_COMMON_FLOAT(jit::avx2, kGT8LT16); +INTRI_COMMON_FLOAT(jit::avx2, kGT16); +#endif +#ifdef __AVX512F__ +// TODO(TJ): refine avx512 +INTRI8_FLOAT(jit::avx512f); +INTRI16_FLOAT(jit::avx512f); +INTRI_COMMON_FLOAT(jit::avx512f, kGT8LT16); +INTRI_COMMON_FLOAT(jit::avx512f, kGT16); +#endif + +#undef INTRI8_FLOAT +#undef INTRI16_FLOAT +#undef INTRI_COMMON_FLOAT + REGISTER_JITKERNEL(vmul, VMulKernel); REGISTER_JITKERNEL(vadd, VAddKernel); REGISTER_JITKERNEL(vscal, VScalKernel); REGISTER_JITKERNEL(vaddb, VAddBiasKernel); REGISTER_JITKERNEL(vrelu, VReluKernel); +REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); REGISTER_JITKERNEL(videntity, VIdentityKernel); } // namespace jitkernel -- GitLab