diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc
index a92e5d351e71a55bca2845ce275780950d096031..6b3eecfbd11471b5d95dcb10c91acc536f78cb85 100644
--- a/paddle/fluid/operators/math/jit_code.cc
+++ b/paddle/fluid/operators/math/jit_code.cc
@@ -24,21 +24,30 @@ namespace gen {
 
 using namespace platform::jit;  // NOLINT
 
-bool VVVJitCode::init(int d) {
+bool VXXJitCode::init(int d, int scalar_index) {
   // It's not necessary to use avx512 since it would slow down the frequency
   // and this kernel is not compute bound.
-  return MayIUse(avx);
+  return MayIUse(avx) && scalar_index >= 0 && scalar_index <= 2;
 }
 
-void VVVJitCode::generate() {
+void VXXJitCode::generate() {
   // do not need push stack, and do not need save avx512reg if do not use avx512
   int offset = 0;
   if (with_relu_) {
     vxorps(ymm_zero, ymm_zero, ymm_zero);
   }
+  if (scalar_index_ == 1) {
+    vbroadcastss(ymm_src1, ptr[param1]);
+  } else if (scalar_index_ == 2) {
+    vbroadcastss(ymm_src2, ptr[param2]);
+  }
   for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
-    vmovups(ymm_src1, ptr[param1 + offset]);
-    vmovups(ymm_src2, ptr[param2 + offset]);
+    if (scalar_index_ != 1) {
+      vmovups(ymm_src1, ptr[param1 + offset]);
+    }
+    if (scalar_index_ != 2) {
+      vmovups(ymm_src2, ptr[param2 + offset]);
+    }
     if (type_ == operand_type::mul) {
       vmulps(ymm_dst, ymm_src1, ymm_src2);
     } else if (type_ == operand_type::add) {
@@ -52,8 +61,12 @@ void VVVJitCode::generate() {
   }
   int rest = num_ % AVX_FLOAT_BLOCK;
   if (rest >= 4) {
-    vmovups(xmm_src1, ptr[param1 + offset]);
-    vmovups(xmm_src2, ptr[param2 + offset]);
+    if (scalar_index_ != 1) {
+      vmovups(xmm_src1, ptr[param1 + offset]);
+    }
+    if (scalar_index_ != 2) {
+      vmovups(xmm_src2, ptr[param2 + offset]);
+    }
     if (type_ == operand_type::mul) {
       vmulps(xmm_dst, xmm_src1, xmm_src2);
     } else if (type_ == operand_type::add) {
@@ -67,8 +80,12 @@ void VVVJitCode::generate() {
     rest -= 4;
   }
   if (rest >= 2) {
-    vmovq(xmm_src1, ptr[param1 + offset]);
-    vmovq(xmm_src2, ptr[param2 + offset]);
+    if (scalar_index_ != 1) {
+      vmovups(xmm_src1, ptr[param1 + offset]);
+    }
+    if (scalar_index_ != 2) {
+      vmovups(xmm_src2, ptr[param2 + offset]);
+    }
     if (type_ == operand_type::mul) {
       vmulps(xmm_dst, xmm_src1, xmm_src2);
     } else if (type_ == operand_type::add) {
@@ -82,8 +99,12 @@ void VVVJitCode::generate() {
     rest -= 2;
   }
   if (rest > 0) {
-    vmovss(xmm_src1, ptr[param1 + offset]);
-    vmovss(xmm_src2, ptr[param2 + offset]);
+    if (scalar_index_ != 1) {
+      vmovups(xmm_src1, ptr[param1 + offset]);
+    }
+    if (scalar_index_ != 2) {
+      vmovups(xmm_src2, ptr[param2 + offset]);
+    }
     if (type_ == operand_type::mul) {
       vmulss(xmm_dst, xmm_src1, xmm_src2);
     } else if (type_ == operand_type::add) {
@@ -96,6 +117,7 @@ void VVVJitCode::generate() {
   }
   ret();
 }
+
 }  // namespace gen
 }  // namespace jitkernel
 }  // namespace math
diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h
index 73692ebc67c71f6190f2d18bd50071a28a35d4c9..aaedb0ae10323eeddfba9512d9e47c7a22320610 100644
--- a/paddle/fluid/operators/math/jit_code.h
+++ b/paddle/fluid/operators/math/jit_code.h
@@ -29,33 +29,46 @@ using ymm_t = const Xbyak::Ymm;
 using zmm_t = const Xbyak::Zmm;
 using Label = Xbyak::Label;
 
-// function: vec = Operand(vec, vec) (maybe with relu)
 typedef enum { mul = 0, add } operand_type;
 
-class VVVJitCode : public JitCode {
+// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
+class VXXJitCode : public JitCode {
 public:
   const char* name() const override {
-    std::string base = "VVVJitCode";
+    std::string base = "VXXJitCode";
+    if (scalar_index_ == 1) {
+      base += "_Scalar";
+    } else {
+      base += "_Vec";
+    }
     if (type_ == operand_type::mul) {
       base += "_Mul";
     } else if (type_ == operand_type::add) {
       base += "_Add";
     }
-    base += (with_relu_ ? "_relu" : "");
+    if (scalar_index_ == 2) {
+      base += "_Scalar";
+    } else {
+      base += "_Vec";
+    }
+    base += (with_relu_ ? "_Relu" : "");
     return base.c_str();
   }
-  explicit VVVJitCode(int d, operand_type type, bool with_relu,
-                      size_t code_size = 256 * 1024, void* code_ptr = nullptr)
+  explicit VXXJitCode(int d, operand_type type, int scalar_index,
+                      bool with_relu, size_t code_size = 256 * 1024,
+                      void* code_ptr = nullptr)
       : JitCode(code_size, code_ptr),
         num_(d),
         type_(type),
+        scalar_index_(scalar_index),
        with_relu_(with_relu) {}
-  static bool init(int d);
+  static bool init(int d, int scalar_index = 0);
   void generate() override;
 
 private:
   int num_;
   operand_type type_;
+  int scalar_index_;
   bool with_relu_;
   reg64_t param1{abi_param1};
   reg64_t param2{abi_param2};
@@ -63,13 +76,13 @@ class VVVJitCode : public JitCode {
 
   xmm_t xmm_src1 = xmm_t(0);
   xmm_t xmm_src2 = xmm_t(1);
-  xmm_t xmm_dst = xmm_t(1);
-  xmm_t xmm_zero = xmm_t(2);
+  xmm_t xmm_dst = xmm_t(2);
+  xmm_t xmm_zero = xmm_t(3);
 
   ymm_t ymm_src1 = ymm_t(0);
   ymm_t ymm_src2 = ymm_t(1);
-  ymm_t ymm_dst = ymm_t(1);
-  ymm_t ymm_zero = ymm_t(2);
+  ymm_t ymm_dst = ymm_t(2);
+  ymm_t ymm_zero = ymm_t(3);
 };
 
 }  // namespace gen
diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h
index 04e0b81d3e7c696ac2f5ee78db90fb3c89ab345d..e9b259282cd00cc2afc46634423ec09590bf5dd3 100644
--- a/paddle/fluid/operators/math/jit_kernel.h
+++ b/paddle/fluid/operators/math/jit_kernel.h
@@ -83,14 +83,15 @@ class VAddReluKernel : public Kernel {
 template <typename T>
 class VScalKernel : public Kernel {
  public:
-  virtual void Compute(const T a, const T *x, T *y) const = 0;
-  virtual void Compute(const T a, T *x) const = 0;
+  // y = a.*x
+  void (*Compute)(const T *, const T *, T *, int);
 };
 
 template <typename T>
 class VAddBiasKernel : public Kernel {
  public:
-  virtual void Compute(const T a, const T *x, T *y) const = 0;
+  // y = a.+x
+  void (*Compute)(const T *, const T *, T *, int);
 };
 
 template <typename T>
diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc
index f976953a245e424e6cb26bbf1cff2f120f84c133..c4bfbcf925a2bbdc39f8468049c58e126d3eba1b 100644
--- a/paddle/fluid/operators/math/jit_kernel_blas.cc
+++ b/paddle/fluid/operators/math/jit_kernel_blas.cc
@@ -57,6 +57,20 @@ void VAddReluRefer(const T* x, const T* y, T* z, int n) {
   }
 }
 
+template <typename T>
+void VScalRefer(const T* a, const T* x, T* y, int n) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = a[0] * x[i];
+  }
+}
+
+template <typename T>
+void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = a[0] + x[i];
+  }
+}
+
 #ifdef PADDLE_WITH_MKLML
 template <typename T>
 void VMulMKL(const T* x, const T* y, T* z, int n);
@@ -83,6 +97,28 @@ template <>
 void VAddMKL<double>(const double* x, const double* y, double* z, int n) {
   platform::dynload::vdAdd(n, x, y, z);
 }
+
+template <typename T>
+void VScalMKL(const T* a, const T* x, T* y, int n);
+
+template <>
+void VScalMKL<float>(const float* a, const float* x, float* y, int n) {
+  if (x == y) {
+    platform::dynload::cblas_sscal(n, *a, y, 1);
+  } else {
+    VScalRefer(a, x, y, n);
+  }
+}
+
+template <>
+void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
+  if (x == y) {
+    platform::dynload::cblas_dscal(n, *a, y, 1);
+  } else {
+    VScalRefer(a, x, y, n);
+  }
+}
+
 #endif
 
 #define DECLARE_STATIC_FUNC                                            \
@@ -102,7 +138,7 @@ class VMulKernelImpl : public VMulKernel<T> {
     if (useJIT(d)) {
       // roughly estimate the size of code
       size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
-      jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::mul, false,
+      jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false,
                                          sz > 4096 ? sz : 4096));
       this->Compute =
           jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
@@ -121,14 +157,14 @@
 
 #ifdef PADDLE_WITH_XBYAK
 
  private:
-  std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
+  std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
 #endif
 };
 
 #ifdef PADDLE_WITH_XBYAK
 template <>
 bool VMulKernelImpl<float>::useJIT(int d) {
-  return gen::VVVJitCode::init(d);
+  return gen::VXXJitCode::init(d);
 }
 #endif
@@ -153,7 +189,7 @@ class VAddKernelImpl : public VAddKernel<T> {
 #ifdef PADDLE_WITH_XBYAK
     if (useJIT(d)) {
       size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
-      jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, false,
+      jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false,
                                          sz > 4096 ? sz : 4096));
       this->Compute =
           jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
@@ -171,14 +207,14 @@
 
 #ifdef PADDLE_WITH_XBYAK
 
  private:
-  std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
+  std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
 #endif
 };
 
 #ifdef PADDLE_WITH_XBYAK
 template <>
 bool VAddKernelImpl<float>::useJIT(int d) {
-  return gen::VVVJitCode::init(d);
+  return gen::VXXJitCode::init(d);
 }
 #endif
@@ -203,7 +239,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
 #ifdef PADDLE_WITH_XBYAK
     if (useJIT(d)) {
      size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
-      jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, true,
+      jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true,
                                          sz > 4096 ? sz : 4096));
       this->Compute =
           jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
@@ -215,148 +251,106 @@
 
 #ifdef PADDLE_WITH_XBYAK
 
  private:
-  std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
+  std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
 #endif
 };
 
 #ifdef PADDLE_WITH_XBYAK
 template <>
 bool VAddReluKernelImpl<float>::useJIT(int d) {
-  return gen::VVVJitCode::init(d);
+  return gen::VXXJitCode::init(d);
 }
 #endif
 
-#undef DECLARE_STATIC_FUNC
-
-REGISTER_JITKERNEL(vmul, VMulKernel);
-REGISTER_JITKERNEL(vadd, VAddKernel);
-REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
-
-/* VSCAL JitKernel */
-template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+/* VScal JitKernel */
+template <typename T>
 class VScalKernelImpl : public VScalKernel<T> {
  public:
-  explicit VScalKernelImpl(int d) : VScalKernel<T>() { this->num_ = d; }
-  void Compute(const T a, const T* x, T* y) const override {
-    for (int i = 0; i < this->num_; ++i) {
-      y[i] = a * x[i];
-    }
-  }
-  void Compute(const T a, T* x) const override {
-    for (int i = 0; i < this->num_; ++i) {
-      x[i] = a * x[i];
+  DECLARE_STATIC_FUNC;
+  explicit VScalKernelImpl(int d) : VScalKernel<T>() {
+#ifdef PADDLE_WITH_XBYAK
+    if (useJIT(d)) {
+      size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
+      jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false,
+                                         sz > 4096 ? sz : 4096));
+      this->Compute =
+          jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
+      return;
     }
-  }
-};
-
+#endif
 #ifdef PADDLE_WITH_MKLML
-#define MKL_FLOAT(isa, block)                                               \
-  template <>                                                               \
-  void VScalKernelImpl<float, isa, block>::Compute(const float a, float* x) \
-      const {                                                               \
-    platform::dynload::cblas_sscal(this->num_, a, x, 1);                    \
-  }
-
-#define MKL_DOUBLE(isa, block)                                                 \
-  template <>                                                                  \
-  void VScalKernelImpl<double, isa, block>::Compute(const double a, double* x) \
-      const {                                                                  \
-    platform::dynload::cblas_dscal(this->num_, a, x, 1);                       \
-  }
-
-FOR_EACH_ISA(MKL_FLOAT, kGT16);
-FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
+    if (useMKL(d)) {
+      this->Compute = VScalMKL<T>;
+      return;
+    }
 #endif
-
-#define INTRI8_FLOAT(isa)                              \
-  template <>                                          \
-  void VScalKernelImpl<float, isa, kEQ8>::Compute(     \
-      const float a, const float* x, float* y) const { \
-    __m256 tmp;                                        \
-    __m256 scalar = _mm256_set1_ps(a);                 \
-    tmp = _mm256_loadu_ps(x);                          \
-    tmp = _mm256_mul_ps(tmp, scalar);                  \
-    _mm256_storeu_ps(y, tmp);                          \
-  }
-#define INTRI8_INPLACE_FLOAT(isa)                                          \
-  template <>                                                              \
-  void VScalKernelImpl<float, isa, kEQ8>::Compute(const float a, float* x) \
-      const {                                                              \
-    __m256 tmp;                                                            \
-    __m256 scalar = _mm256_set1_ps(a);                                     \
-    tmp = _mm256_loadu_ps(x);                                              \
-    tmp = _mm256_mul_ps(tmp, scalar);                                      \
-    _mm256_storeu_ps(x, tmp);                                              \
+    this->Compute = VScalRefer<T>;
   }
+#ifdef PADDLE_WITH_XBYAK
 
-#ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
-INTRI8_INPLACE_FLOAT(jit::avx);
-#endif
-#ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
-INTRI8_INPLACE_FLOAT(jit::avx2);
+ private:
+  std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
 #endif
-#ifdef __AVX512F__
-INTRI8_FLOAT(jit::avx512f);
-INTRI8_INPLACE_FLOAT(jit::avx512f);
+};
+
+#ifdef PADDLE_WITH_XBYAK
+template <>
+bool VScalKernelImpl<float>::useJIT(int d) {
+  return gen::VXXJitCode::init(d, 1);
+}
 #endif
-// TODO(TJ): eq16 test and complete avx512
-#undef INTRI8_FLOAT
-#undef INTRI8_INPLACE_FLOAT
-#undef MKL_FLOAT
-#undef MKL_DOUBLE
+#ifdef PADDLE_WITH_MKLML
+template <>
+bool VScalKernelImpl<float>::useMKL(int d) {
+  return d > 512;
+}
+template <>
+bool VScalKernelImpl<double>::useMKL(int d) {
+  return true;
+}
+#endif
 
 /* VAddBias JitKernel */
-template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+template <typename T>
 class VAddBiasKernelImpl : public VAddBiasKernel<T> {
  public:
-  explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() { this->num_ = d; }
-  void Compute(const T a, const T* x, T* y) const override {
-    for (int i = 0; i < this->num_; ++i) {
-      y[i] = x[i] + a;
+  DECLARE_STATIC_FUNC;
+  explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() {
+#ifdef PADDLE_WITH_XBYAK
+    if (useJIT(d)) {
+      size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
+      jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false,
+                                         sz > 4096 ? sz : 4096));
+      this->Compute =
+          jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
+      return;
     }
-  }
-};
-
-#define INTRI8_FLOAT(isa)                              \
-  template <>                                          \
-  void VAddBiasKernelImpl<float, isa, kEQ8>::Compute(  \
-      const float a, const float* x, float* y) const { \
-    __m256 tmp = _mm256_loadu_ps(x);                   \
-    tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a));       \
-    _mm256_storeu_ps(y, tmp);                          \
-  }
+#endif
 
-#define INTRI16_FLOAT(isa)                             \
-  template <>                                          \
-  void VAddBiasKernelImpl<float, isa, kEQ16>::Compute( \
-      const float a, const float* x, float* y) const { \
-    __m256 tmp0 = _mm256_loadu_ps(x);                  \
-    __m256 tmp1 = _mm256_loadu_ps(x + 8);              \
-    tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a));     \
-    tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a));     \
-    _mm256_storeu_ps(y, tmp0);                         \
-    _mm256_storeu_ps(y + 8, tmp1);                     \
+    this->Compute = VAddBiasRefer<T>;
   }
+#ifdef PADDLE_WITH_XBYAK
 
-#ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
-INTRI16_FLOAT(jit::avx);
-#endif
-#ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
-INTRI16_FLOAT(jit::avx2);
+ private:
+  std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
 #endif
-#ifdef __AVX512F__
-INTRI8_FLOAT(jit::avx512f);
-INTRI16_FLOAT(jit::avx512f);
+};
+
+#ifdef PADDLE_WITH_XBYAK
+template <>
+bool VAddBiasKernelImpl<float>::useJIT(int d) {
+  return gen::VXXJitCode::init(d, 1);
+}
 #endif
-// TODO(TJ): eq16 test and complete avx512
-#undef INTRI8_FLOAT
-#undef INTRI16_FLOAT
+
+#undef DECLARE_STATIC_FUNC
+
+REGISTER_JITKERNEL(vmul, VMulKernel);
+REGISTER_JITKERNEL(vadd, VAddKernel);
+REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
+REGISTER_JITKERNEL(vscal, VScalKernel);
+REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
 
 /* VRelu JitKernel */
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
@@ -467,8 +461,6 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
   void Compute(const T* x, T* y) const override {}
 };
 
-REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel);
-REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel);
 REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
 REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
 
diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc
index d7c177e6782e19e199542e10e1d62587ee0df4cf..c55e54a13f539014c0f582436ca1a105d0b0fedd 100644
--- a/paddle/fluid/operators/math/jit_kernel_exp.cc
+++ b/paddle/fluid/operators/math/jit_kernel_exp.cc
@@ -409,10 +409,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
     vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
   }
   void Compute(const T* x, T* y) const override {
-    vscal_->Compute(static_cast<T>(2), x, y);
+    const T a = static_cast<T>(2), b = static_cast<T>(-1);
+    vscal_->Compute(&a, x, y, this->num_);
     vsigmoid_->Compute(y, y);
-    vscal_->Compute(static_cast<T>(2), y);
-    vaddbias_->Compute(static_cast<T>(-1), y, y);
+    vscal_->Compute(&a, y, y, this->num_);
+    vaddbias_->Compute(&b, y, y, this->num_);
   }
 
  private:
@@ -472,10 +473,11 @@
     _mm256_storeu_ps(y, tmp);                          \
     x += AVX_FLOAT_BLOCK;                              \
     y += AVX_FLOAT_BLOCK;                              \
-    vscal_->Compute(2.f, x, y);                        \
+    const float a = 2.f, b = -1.f;                     \
+    vscal_->Compute(&a, x, y, this->num_);             \
     vsigmoid_->Compute(y, y);                          \
-    vscal_->Compute(2.f, y);                           \
-    vaddbias_->Compute(-1.f, y, y);                    \
+    vscal_->Compute(&a, y, y, this->num_);             \
+    vaddbias_->Compute(&b, y, y, this->num_);          \
   }
 
 #define INTRI_GT16_FLOAT(isa, expisa)                  \
@@ -502,10 +504,11 @@
     }                                                  \
     x += this->end_;                                   \
     y += this->end_;                                   \
-    vscal_->Compute(2.f, x, y);                        \
+    const float a = 2.f, b = -1.f;                     \
+    vscal_->Compute(&a, x, y, this->num_);             \
     vsigmoid_->Compute(y, y);                          \
-    vscal_->Compute(2.f, y);                           \
-    vaddbias_->Compute(-1.f, y, y);                    \
+    vscal_->Compute(&a, y, y, this->num_);             \
+    vaddbias_->Compute(&b, y, y, this->num_);          \
   }
 
 #ifdef __AVX__
diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc
index 9a19424691fad70c161ca6036c5cdfd3b2b22ada..596bd3b2d324131c30fce7439460226574f0a190 100644
--- a/paddle/fluid/operators/math/jit_kernel_test.cc
+++ b/paddle/fluid/operators/math/jit_kernel_test.cc
@@ -128,7 +128,7 @@ TEST(JitKernel, vaddbias) {
   auto trefe = GetCurrentUS();
   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(a, x_data, ztgt_data);
+    ker->Compute(&a, x_data, ztgt_data, d);
   }
   auto ttgte = GetCurrentUS();
 
@@ -281,10 +281,11 @@ void vtanh_better(
         const paddle::operators::math::jitkernel::VAddBiasKernel<float>>&
         vaddbias,
     const int n, const float* x, float* y) {
-  vscal->Compute(2.f, x, y);
+  const float a = 2.f, b = -1.f;
+  vscal->Compute(&a, x, y, n);
   vsigmoid->Compute(y, y);
-  vscal->Compute(2.f, y);
-  vaddbias->Compute(-1.f, y, y);
+  vscal->Compute(&a, y, y, n);
+  vaddbias->Compute(&b, y, y, n);
 }
 
 TEST(JitKernel, vtanh) {
@@ -531,12 +532,12 @@ TEST(JitKernel, vscal) {
 
   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(a, x_data, ztgt_data);
+    ker->Compute(&a, x_data, ztgt_data, d);
   }
   auto ttgte = GetCurrentUS();
   auto ttgts1 = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(a, y_data);
+    ker->Compute(&a, y_data, y_data, d);
   }
   auto ttgte1 = GetCurrentUS();
   VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
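
Usage note (a sketch, not part of the patch): after this change VScalKernel and VAddBiasKernel expose Compute as a plain function pointer with the uniform (const T*, const T*, T*, int) signature, so the scalar is passed by address and the old in-place overload is expressed by passing the same pointer for input and output. A minimal C++ sketch under those assumptions; the helper name, the length d, and the values are illustrative, while the pool access pattern mirrors jit_kernel_test.cc above:

    #include "paddle/fluid/operators/math/jit_kernel.h"

    namespace jk = paddle::operators::math::jitkernel;

    void vscal_usage_sketch() {
      constexpr int d = 32;  // illustrative vector length
      float x[d], y[d];
      for (int i = 0; i < d; ++i) x[i] = static_cast<float>(i);

      const float a = 0.5f;  // the scalar operand now travels by pointer
      const auto& ker =
          jk::KernelPool::Instance().template Get<jk::VScalKernel<float>>(d);
      ker->Compute(&a, x, y, d);  // out-of-place: y = a .* x
      ker->Compute(&a, y, y, d);  // in-place: replaces the old Compute(a, x) overload
    }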