From 03e11f3fc9dc53d4925f341f15bec0f2393da80a Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 8 Nov 2018 06:01:25 +0000 Subject: [PATCH] add vscal jitcode --- paddle/fluid/operators/math/jit_code.cc | 35 +++++ paddle/fluid/operators/math/jit_code.h | 30 +++- paddle/fluid/operators/math/jit_kernel.h | 3 +- .../fluid/operators/math/jit_kernel_blas.cc | 143 +++++++++--------- paddle/fluid/operators/math/jit_kernel_exp.cc | 15 +- .../fluid/operators/math/jit_kernel_test.cc | 9 +- 6 files changed, 150 insertions(+), 85 deletions(-) diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index a92e5d351..f85349780 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -96,6 +96,41 @@ void VVVJitCode::generate() { } ret(); } + +bool VScalJitCode::init(int d) { return MayIUse(avx); } + +void VScalJitCode::generate() { + int offset = 0; + vbroadcastss(ymm_src1, ptr[param1]); + for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { + vmovups(ymm_src2, ptr[param2 + offset]); + vmulps(ymm_dst, ymm_src1, ymm_src2); + vmovups(ptr[param3 + offset], ymm_dst); + offset += sizeof(float) * AVX_FLOAT_BLOCK; + } + int rest = num_ % AVX_FLOAT_BLOCK; + if (rest >= 4) { + vmovups(xmm_src2, ptr[param2 + offset]); + vmulps(xmm_dst, xmm_src1, xmm_src2); + vmovups(ptr[param3 + offset], xmm_dst); + offset += sizeof(float) * 4; + rest -= 4; + } + if (rest >= 2) { + vmovq(xmm_src2, ptr[param2 + offset]); + vmulps(xmm_dst, xmm_src1, xmm_src2); + vmovq(ptr[param3 + offset], xmm_dst); + offset += sizeof(float) * 2; + rest -= 2; + } + if (rest > 0) { + vmovss(xmm_src2, ptr[param2 + offset]); + vmulss(xmm_dst, xmm_src1, xmm_src2); + vmovss(ptr[param3 + offset], xmm_dst); + } + ret(); +} + } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 73692ebc6..d87831c57 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -29,9 +29,9 @@ using ymm_t = const Xbyak::Ymm; using zmm_t = const Xbyak::Zmm; using Label = Xbyak::Label; -// function: vec = Operand(vec, vec) (maybe with relu) typedef enum { mul = 0, add } operand_type; +// function: vec = Operand(vec, vec) (maybe with relu) class VVVJitCode : public JitCode { public: const char* name() const override { @@ -41,7 +41,7 @@ class VVVJitCode : public JitCode { } else if (type_ == operand_type::add) { base += "_Add"; } - base += (with_relu_ ? "_relu" : ""); + base += (with_relu_ ? "_Relu" : ""); return base.c_str(); } explicit VVVJitCode(int d, operand_type type, bool with_relu, @@ -72,6 +72,32 @@ class VVVJitCode : public JitCode { ymm_t ymm_zero = ymm_t(2); }; +class VScalJitCode : public JitCode { + public: + DECLARE_JIT_CODE(VScalJitCode); + explicit VScalJitCode(int d, size_t code_size = 256 * 1024, + void* code_ptr = nullptr) + : JitCode(code_size, code_ptr), num_(d) {} + static bool init(int d); + void generate() override; + + private: + int num_; + reg64_t param1{abi_param1}; + reg64_t param2{abi_param2}; + reg64_t param3{abi_param3}; + + xmm_t xmm_src1 = xmm_t(0); + xmm_t xmm_src2 = xmm_t(1); + xmm_t xmm_dst = xmm_t(1); + xmm_t xmm_zero = xmm_t(2); + + ymm_t ymm_src1 = ymm_t(0); + ymm_t ymm_src2 = ymm_t(1); + ymm_t ymm_dst = ymm_t(1); + ymm_t ymm_zero = ymm_t(2); +}; + } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 04e0b81d3..6ee651b98 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -83,8 +83,7 @@ class VAddReluKernel : public Kernel { template class VScalKernel : public Kernel { public: - virtual void Compute(const T a, const T *x, T *y) const = 0; - virtual void Compute(const T a, T *x) const = 0; + void (*Compute)(const T *, const T *, T *, int); }; template diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index f976953a2..a9537ab09 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -57,6 +57,13 @@ void VAddReluRefer(const T* x, const T* y, T* z, int n) { } } +template +void VScalRefer(const T* a, const T* x, T* y, int n) { + for (int i = 0; i < n; ++i) { + y[i] = a[0] * x[i]; + } +} + #ifdef PADDLE_WITH_MKLML template void VMulMKL(const T* x, const T* y, T* z, int n); @@ -83,6 +90,28 @@ template <> void VAddMKL(const double* x, const double* y, double* z, int n) { platform::dynload::vdAdd(n, x, y, z); } + +template +void VScalMKL(const T* a, const T* x, T* y, int n); + +template <> +void VScalMKL(const float* a, const float* x, float* y, int n) { + if (x == y) { + platform::dynload::cblas_sscal(n, *a, y, 1); + } else { + VScalRefer(a, x, y, n); + } +} + +template <> +void VScalMKL(const double* a, const double* x, double* y, int n) { + if (x == y) { + platform::dynload::cblas_dscal(n, *a, y, 1); + } else { + VScalRefer(a, x, y, n); + } +} + #endif #define DECLARE_STATIC_FUNC \ @@ -226,87 +255,60 @@ bool VAddReluKernelImpl::useJIT(int d) { } #endif -#undef DECLARE_STATIC_FUNC - -REGISTER_JITKERNEL(vmul, VMulKernel); -REGISTER_JITKERNEL(vadd, VAddKernel); -REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); - -/* VSCAL JitKernel */ -template +/* VScal JitKernel */ +template class VScalKernelImpl : public VScalKernel { public: - explicit VScalKernelImpl(int d) : VScalKernel() { this->num_ = d; } - void Compute(const T a, const T* x, T* y) const override { - for (int i = 0; i < this->num_; ++i) { - y[i] = a * x[i]; - } - } - void Compute(const T a, T* x) const override { - for (int i = 0; i < this->num_; ++i) { - x[i] = a * x[i]; + DECLARE_STATIC_FUNC; + explicit VScalKernelImpl(int d) : VScalKernel() { +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + jitcode_.reset(new gen::VScalJitCode(d, sz > 4096 ? sz : 4096)); + this->Compute = + jitcode_->getCode(); + return; } - } -}; - +#endif #ifdef PADDLE_WITH_MKLML -#define MKL_FLOAT(isa, block) \ - template <> \ - void VScalKernelImpl::Compute(const float a, float* x) \ - const { \ - platform::dynload::cblas_sscal(this->num_, a, x, 1); \ - } - -#define MKL_DOUBLE(isa, block) \ - template <> \ - void VScalKernelImpl::Compute(const double a, double* x) \ - const { \ - platform::dynload::cblas_dscal(this->num_, a, x, 1); \ - } - -FOR_EACH_ISA(MKL_FLOAT, kGT16); -FOR_EACH_ISA_BLOCK(MKL_DOUBLE); + if (useMKL(d)) { + this->Compute = VScalMKL; + return; + } #endif - -#define INTRI8_FLOAT(isa) \ - template <> \ - void VScalKernelImpl::Compute( \ - const float a, const float* x, float* y) const { \ - __m256 tmp; \ - __m256 scalar = _mm256_set1_ps(a); \ - tmp = _mm256_loadu_ps(x); \ - tmp = _mm256_mul_ps(tmp, scalar); \ - _mm256_storeu_ps(y, tmp); \ - } -#define INTRI8_INPLACE_FLOAT(isa) \ - template <> \ - void VScalKernelImpl::Compute(const float a, float* x) \ - const { \ - __m256 tmp; \ - __m256 scalar = _mm256_set1_ps(a); \ - tmp = _mm256_loadu_ps(x); \ - tmp = _mm256_mul_ps(tmp, scalar); \ - _mm256_storeu_ps(x, tmp); \ + this->Compute = VScalRefer; } +#ifdef PADDLE_WITH_XBYAK -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx); -INTRI8_INPLACE_FLOAT(jit::avx); + private: + std::unique_ptr jitcode_{nullptr}; #endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); -INTRI8_INPLACE_FLOAT(jit::avx2); +}; + +#ifdef PADDLE_WITH_XBYAK +template <> +bool VScalKernelImpl::useJIT(int d) { + return gen::VScalJitCode::init(d); +} #endif -#ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f); -INTRI8_INPLACE_FLOAT(jit::avx512f); + +#ifdef PADDLE_WITH_MKLML +template <> +bool VScalKernelImpl::useMKL(int d) { + return d > 512; +} +template <> +bool VScalKernelImpl::useMKL(int d) { + return true; +} #endif -// TODO(TJ): eq16 test and complete avx512 -#undef INTRI8_FLOAT -#undef INTRI8_INPLACE_FLOAT -#undef MKL_FLOAT -#undef MKL_DOUBLE +#undef DECLARE_STATIC_FUNC + +REGISTER_JITKERNEL(vmul, VMulKernel); +REGISTER_JITKERNEL(vadd, VAddKernel); +REGISTER_JITKERNEL(vscal, VScalKernel); +REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); /* VAddBias JitKernel */ template @@ -467,7 +469,6 @@ class VIdentityKernelImpl : public VIdentityKernel { void Compute(const T* x, T* y) const override {} }; -REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel); REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel); REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel); REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel); diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index d7c177e67..07a77086d 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -409,9 +409,10 @@ class VTanhKernelImpl : public VTanhKernel { vaddbias_ = KernelPool::Instance().template Get>(d); } void Compute(const T* x, T* y) const override { - vscal_->Compute(static_cast(2), x, y); + const T a = static_cast(2); + vscal_->Compute(&a, x, y, this->num_); vsigmoid_->Compute(y, y); - vscal_->Compute(static_cast(2), y); + vscal_->Compute(&a, y, y, this->num_); vaddbias_->Compute(static_cast(-1), y, y); } @@ -472,9 +473,10 @@ class VTanhKernelImpl : public VTanhKernel { _mm256_storeu_ps(y, tmp); \ x += AVX_FLOAT_BLOCK; \ y += AVX_FLOAT_BLOCK; \ - vscal_->Compute(2.f, x, y); \ + const float a = 2.f; \ + vscal_->Compute(&a, x, y, this->num_); \ vsigmoid_->Compute(y, y); \ - vscal_->Compute(2.f, y); \ + vscal_->Compute(&a, y, y, this->num_); \ vaddbias_->Compute(-1.f, y, y); \ } @@ -502,9 +504,10 @@ class VTanhKernelImpl : public VTanhKernel { } \ x += this->end_; \ y += this->end_; \ - vscal_->Compute(2.f, x, y); \ + const float a = 2.f; \ + vscal_->Compute(&a, x, y, this->num_); \ vsigmoid_->Compute(y, y); \ - vscal_->Compute(2.f, y); \ + vscal_->Compute(&a, y, y, this->num_); \ vaddbias_->Compute(-1.f, y, y); \ } diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 9a1942469..04a199faa 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -281,9 +281,10 @@ void vtanh_better( const paddle::operators::math::jitkernel::VAddBiasKernel>& vaddbias, const int n, const float* x, float* y) { - vscal->Compute(2.f, x, y); + const float tmp1 = 2.f; + vscal->Compute(&tmp1, x, y, n); vsigmoid->Compute(y, y); - vscal->Compute(2.f, y); + vscal->Compute(&tmp1, y, y, n); vaddbias->Compute(-1.f, y, y); } @@ -531,12 +532,12 @@ TEST(JitKernel, vscal) { auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(a, x_data, ztgt_data); + ker->Compute(&a, x_data, ztgt_data, d); } auto ttgte = GetCurrentUS(); auto ttgts1 = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(a, y_data); + ker->Compute(&a, y_data, y_data, d); } auto ttgte1 = GetCurrentUS(); VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat -- GitLab