diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index 9375ca20670858b6b92ff7057c159d80cc500f0d..35f0bdb9b31d472e8f2fbdc7b76a5aa4ac8175e3 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -70,10 +70,16 @@ bool VAddJitCode::init(int d) { return MayIUse(avx); } void VAddJitCode::generate() { int offset = 0; + if (with_relu_) { + vxorps(ymm_zero, ymm_zero, ymm_zero); + } for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { vmovups(ymm_src1, ptr[param1 + offset]); vmovups(ymm_src2, ptr[param2 + offset]); vaddps(ymm_dst, ymm_src1, ymm_src2); + if (with_relu_) { + vmaxps(ymm_dst, ymm_zero, ymm_dst); + } vmovups(ptr[param3 + offset], ymm_dst); offset += sizeof(float) * AVX_FLOAT_BLOCK; } @@ -82,6 +88,9 @@ void VAddJitCode::generate() { vmovups(xmm_src1, ptr[param1 + offset]); vmovups(xmm_src2, ptr[param2 + offset]); vaddps(xmm_dst, xmm_src1, xmm_src2); + if (with_relu_) { + vmaxps(xmm_dst, xmm_zero, xmm_dst); + } vmovups(ptr[param3 + offset], xmm_dst); offset += sizeof(float) * 4; rest -= 4; @@ -90,6 +99,9 @@ void VAddJitCode::generate() { vmovq(xmm_src1, ptr[param1 + offset]); vmovq(xmm_src2, ptr[param2 + offset]); vaddps(xmm_dst, xmm_src1, xmm_src2); + if (with_relu_) { + vmaxps(xmm_dst, xmm_zero, xmm_dst); + } vmovq(ptr[param3 + offset], xmm_dst); offset += sizeof(float) * 2; rest -= 2; @@ -98,6 +110,9 @@ void VAddJitCode::generate() { vmovss(xmm_src1, ptr[param1 + offset]); vmovss(xmm_src2, ptr[param2 + offset]); vaddss(xmm_dst, xmm_src1, xmm_src2); + if (with_relu_) { + vmaxps(xmm_dst, xmm_zero, xmm_dst); + } vmovss(ptr[param3 + offset], xmm_dst); } ret(); diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 0c4b75d0309bb0e90298150919983c04ece925c2..6bfed4b22d251b656145ba9d69142b862a82656a 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -46,35 +46,38 @@ class VMulJitCode : public JitCode { xmm_t xmm_src1 = xmm_t(0); xmm_t xmm_src2 = xmm_t(1); - xmm_t xmm_dst = xmm_t(2); + xmm_t xmm_dst = xmm_t(1); ymm_t ymm_src1 = ymm_t(0); ymm_t ymm_src2 = ymm_t(1); - ymm_t ymm_dst = ymm_t(2); + ymm_t ymm_dst = ymm_t(1); }; class VAddJitCode : public JitCode { public: DECLARE_JIT_CODE(VAddJitCode); - explicit VAddJitCode(int d, size_t code_size = 256 * 1024, + explicit VAddJitCode(int d, bool with_relu, size_t code_size = 256 * 1024, void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), num_(d) {} + : JitCode(code_size, code_ptr), num_(d), with_relu_(with_relu) {} static bool init(int d); void generate() override; private: int num_; + bool with_relu_; reg64_t param1{abi_param1}; reg64_t param2{abi_param2}; reg64_t param3{abi_param3}; xmm_t xmm_src1 = xmm_t(0); xmm_t xmm_src2 = xmm_t(1); - xmm_t xmm_dst = xmm_t(2); + xmm_t xmm_dst = xmm_t(1); + xmm_t xmm_zero = xmm_t(2); ymm_t ymm_src1 = ymm_t(0); ymm_t ymm_src2 = ymm_t(1); - ymm_t ymm_dst = ymm_t(2); + ymm_t ymm_dst = ymm_t(1); + ymm_t ymm_zero = ymm_t(2); }; } // namespace gen diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 7c3fb5de9bd43b2f01b22fa90b02a530e2180399..04e0b81d3e7c696ac2f5ee78db90fb3c89ab345d 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -75,22 +75,22 @@ class VAddKernel : public Kernel { }; template -class VScalKernel : public Kernel { +class VAddReluKernel : public Kernel { public: - virtual void Compute(const T a, const T *x, T *y) const = 0; - virtual void Compute(const T a, T *x) const = 0; + void (*Compute)(const T *, const T *, T *, int); }; template -class VAddBiasKernel : public Kernel { +class VScalKernel : public Kernel { public: virtual void Compute(const T a, const T *x, T *y) const = 0; + virtual void Compute(const T a, T *x) const = 0; }; template -class VAddReluKernel : public Kernel { +class VAddBiasKernel : public Kernel { public: - virtual void Compute(const T *x, const T *y, T *z) const = 0; + virtual void Compute(const T a, const T *x, T *y) const = 0; }; template diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 16eab62dda75ff8379d87483160714c2869cf141..b3ac33043b6a9ed9c9de00a860f01895c1179dec 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -46,6 +46,14 @@ void VAddRefer(const T* x, const T* y, T* z, int n) { } } +template +void VAddReluRefer(const T* x, const T* y, T* z, int n) { + for (int i = 0; i < n; ++i) { + z[i] = x[i] + y[i]; + z[i] = z[i] > 0 ? z[i] : 0; + } +} + #ifdef PADDLE_WITH_MKLML template void VMulMKL(const T* x, const T* y, T* z, int n); @@ -131,7 +139,7 @@ class VAddKernelImpl : public VAddKernel { explicit VAddKernelImpl(int d) : VAddKernel() { if (useJIT(d)) { size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VAddJitCode(d, sz > 4096 ? sz : 4096)); + jitcode_.reset(new gen::VAddJitCode(d, false, sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); return; @@ -164,10 +172,36 @@ bool VAddKernelImpl::useMKL(int d) { return true; } +/* VAddRelu JitKernel */ +template +class VAddReluKernelImpl : public VAddReluKernel { + public: + DECLARE_STATIC_FUNC; + explicit VAddReluKernelImpl(int d) : VAddReluKernel() { + if (useJIT(d)) { + size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + jitcode_.reset(new gen::VAddJitCode(d, true, sz > 4096 ? sz : 4096)); + this->Compute = + jitcode_->getCode(); + return; + } + this->Compute = VAddReluRefer; + } + + private: + std::unique_ptr jitcode_{nullptr}; +}; + +template <> +bool VAddReluKernelImpl::useJIT(int d) { + return gen::VAddJitCode::init(d); +} + #undef DECLARE_STATIC_FUNC REGISTER_JITKERNEL(vmul, VMulKernel); REGISTER_JITKERNEL(vadd, VAddKernel); +REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); /* VSCAL JitKernel */ template @@ -404,97 +438,9 @@ class VIdentityKernelImpl : public VIdentityKernel { void Compute(const T* x, T* y) const override {} }; -/* VAddRelu JitKernel */ -template -class VAddReluKernelImpl : public VAddReluKernel { - public: - explicit VAddReluKernelImpl(int d) : VAddReluKernel() { this->num_ = d; } - void Compute(const T* x, const T* y, T* z) const override { - for (int i = 0; i < this->num_; ++i) { - z[i] = x[i] + y[i]; - z[i] = z[i] > 0 ? z[i] : 0; - } - } -}; - -#define INTRI8_FLOAT(isa) \ - template <> \ - void VAddReluKernelImpl::Compute( \ - const float* x, const float* y, float* z) const { \ - __m256 tmpx = _mm256_loadu_ps(x); \ - __m256 tmpy = _mm256_loadu_ps(y); \ - tmpy = _mm256_add_ps(tmpx, tmpy); \ - tmpy = _mm256_max_ps(tmpy, _mm256_setzero_ps()); \ - _mm256_storeu_ps(z, tmpy); \ - } - -#define INTRI16_FLOAT(isa) \ - template <> \ - void VAddReluKernelImpl::Compute( \ - const float* x, const float* y, float* z) const { \ - __m256 zeros = _mm256_setzero_ps(); \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(y); \ - tmp0 = _mm256_add_ps(tmp0, tmp1); \ - tmp0 = _mm256_max_ps(tmp0, zeros); \ - tmp1 = _mm256_loadu_ps(x + 8); \ - __m256 tmp2 = _mm256_loadu_ps(y + 8); \ - tmp1 = _mm256_add_ps(tmp1, tmp2); \ - tmp1 = _mm256_max_ps(tmp1, zeros); \ - _mm256_storeu_ps(z, tmp0); \ - _mm256_storeu_ps(z + 8, tmp1); \ - } - -#define INTRI_COMMON_FLOAT(isa, block) \ - template <> \ - VAddReluKernelImpl::VAddReluKernelImpl(int d) \ - : VAddReluKernel() { \ - this->num_ = d; \ - this->end_ = d - d % AVX_FLOAT_BLOCK; \ - this->rest_ = d - this->end_; \ - } \ - template <> \ - void VAddReluKernelImpl::Compute( \ - const float* x, const float* y, float* z) const { \ - __m256 zeros = _mm256_setzero_ps(); \ - for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ - __m256 tmpx = _mm256_loadu_ps(x + i); \ - __m256 tmpy = _mm256_loadu_ps(y + i); \ - tmpy = _mm256_add_ps(tmpx, tmpy); \ - tmpy = _mm256_max_ps(tmpy, zeros); \ - _mm256_storeu_ps(z + i, tmpy); \ - } \ - for (int i = this->end_; i < this->num_; ++i) { \ - z[i] = x[i] + y[i]; \ - z[i] = z[i] > 0 ? z[i] : 0; \ - } \ - } - -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx); -INTRI16_FLOAT(jit::avx); -INTRI_COMMON_FLOAT(jit::avx, kGT16); -#endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); -INTRI16_FLOAT(jit::avx2); -INTRI_COMMON_FLOAT(jit::avx2, kGT16); -#endif -#ifdef __AVX512F__ -// TODO(TJ): refine avx512 -INTRI8_FLOAT(jit::avx512f); -INTRI16_FLOAT(jit::avx512f); -INTRI_COMMON_FLOAT(jit::avx512f, kGT16); -#endif - -#undef INTRI8_FLOAT -#undef INTRI16_FLOAT -#undef INTRI_COMMON_FLOAT - REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel); REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel); REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel); -REGISTER_JITKERNEL_DEPRECATED(vaddrelu, VAddReluKernel); REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel); } // namespace jitkernel diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index f9064d8b2f5a4000d816b7c563aad0eab15c1566..d990a0a98247e9e39080779fe151a6847dcd6e7c 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -757,7 +757,7 @@ TEST(JitKernel, vaddrelu) { auto tmkle = GetCurrentUS(); auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, y_data, ztgt_data); + ker->Compute(x_data, y_data, ztgt_data, d); } auto ttgte = GetCurrentUS(); VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat