From 5e64244f250376666814816fc333c614cc8c085d Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 8 Nov 2018 07:32:39 +0000 Subject: [PATCH] add vaddbias jitcode test=develop --- paddle/fluid/operators/math/jit_code.h | 12 ++- paddle/fluid/operators/math/jit_kernel.h | 4 +- .../fluid/operators/math/jit_kernel_blas.cc | 84 ++++++++----------- paddle/fluid/operators/math/jit_kernel_exp.cc | 12 +-- .../fluid/operators/math/jit_kernel_test.cc | 10 +-- 5 files changed, 62 insertions(+), 60 deletions(-) diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 939d9897e..aaedb0ae1 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -31,16 +31,26 @@ using Label = Xbyak::Label; typedef enum { mul = 0, add } operand_type; -// function: vec = Operand(vec(scalar), vec(scalar)) (maybe with relu) +// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) class VXXJitCode : public JitCode { public: const char* name() const override { std::string base = "VXXJitCode"; + if (scalar_index_ == 1) { + base += "_Scalar"; + } else { + base += "_Vec"; + } if (type_ == operand_type::mul) { base += "_Mul"; } else if (type_ == operand_type::add) { base += "_Add"; } + if (scalar_index_ == 2) { + base += "_Scalar"; + } else { + base += "_Vec"; + } base += (with_relu_ ? "_Relu" : ""); return base.c_str(); } diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 6ee651b98..e9b259282 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -83,13 +83,15 @@ class VAddReluKernel : public Kernel { template class VScalKernel : public Kernel { public: + // y = a.*x void (*Compute)(const T *, const T *, T *, int); }; template class VAddBiasKernel : public Kernel { public: - virtual void Compute(const T a, const T *x, T *y) const = 0; + // y = a.+x + void (*Compute)(const T *, const T *, T *, int); }; template diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 1f468a7fe..d5e45cf7f 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -60,6 +60,13 @@ void VScalRefer(const T* a, const T* x, T* y, int n) { } } +template +void VAddBiasRefer(const T* a, const T* x, T* y, int n) { + for (int i = 0; i < n; ++i) { + y[i] = a[0] + x[i]; + } +} + #ifdef PADDLE_WITH_MKLML template void VMulMKL(const T* x, const T* y, T* z, int n); @@ -300,62 +307,46 @@ bool VScalKernelImpl::useMKL(int d) { } #endif -#undef DECLARE_STATIC_FUNC - -REGISTER_JITKERNEL(vmul, VMulKernel); -REGISTER_JITKERNEL(vadd, VAddKernel); -REGISTER_JITKERNEL(vscal, VScalKernel); -REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); - /* VAddBias JitKernel */ -template +template class VAddBiasKernelImpl : public VAddBiasKernel { public: - explicit VAddBiasKernelImpl(int d) : VAddBiasKernel() { this->num_ = d; } - void Compute(const T a, const T* x, T* y) const override { - for (int i = 0; i < this->num_; ++i) { - y[i] = x[i] + a; + DECLARE_STATIC_FUNC; + explicit VAddBiasKernelImpl(int d) : VAddBiasKernel() { +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false, + sz > 4096 ? sz : 4096)); + this->Compute = + jitcode_->getCode(); + return; } - } -}; - -#define INTRI8_FLOAT(isa) \ - template <> \ - void VAddBiasKernelImpl::Compute( \ - const float a, const float* x, float* y) const { \ - __m256 tmp = _mm256_loadu_ps(x); \ - tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a)); \ - _mm256_storeu_ps(y, tmp); \ - } +#endif -#define INTRI16_FLOAT(isa) \ - template <> \ - void VAddBiasKernelImpl::Compute( \ - const float a, const float* x, float* y) const { \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a)); \ - tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a)); \ - _mm256_storeu_ps(y, tmp0); \ - _mm256_storeu_ps(y + 8, tmp1); \ + this->Compute = VAddBiasRefer; } +#ifdef PADDLE_WITH_XBYAK -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx); -INTRI16_FLOAT(jit::avx); -#endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); -INTRI16_FLOAT(jit::avx2); + private: + std::unique_ptr jitcode_{nullptr}; #endif -#ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f); -INTRI16_FLOAT(jit::avx512f); +}; + +#ifdef PADDLE_WITH_XBYAK +template <> +bool VAddBiasKernelImpl::useJIT(int d) { + return gen::VXXJitCode::init(d, 1); +} #endif -// TODO(TJ): eq16 test and complete avx512 -#undef INTRI8_FLOAT -#undef INTRI16_FLOAT +#undef DECLARE_STATIC_FUNC + +REGISTER_JITKERNEL(vmul, VMulKernel); +REGISTER_JITKERNEL(vadd, VAddKernel); +REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); +REGISTER_JITKERNEL(vscal, VScalKernel); +REGISTER_JITKERNEL(vaddbias, VAddBiasKernel); /* VRelu JitKernel */ template @@ -466,7 +457,6 @@ class VIdentityKernelImpl : public VIdentityKernel { void Compute(const T* x, T* y) const override {} }; -REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel); REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel); REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel); diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index 5df17c11b..fd507808c 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -409,11 +409,11 @@ class VTanhKernelImpl : public VTanhKernel { vaddbias_ = KernelPool::Instance().template Get>(d); } void Compute(const T* x, T* y) const override { - const T a = static_cast(2); + const T a = static_cast(2), b = static_cast(-1); vscal_->Compute(&a, x, y, this->num_); vsigmoid_->Compute(y, y); vscal_->Compute(&a, y, y, this->num_); - vaddbias_->Compute(static_cast(-1), y, y); + vaddbias_->Compute(&b, y, y, this->num_); } private: @@ -473,11 +473,11 @@ class VTanhKernelImpl : public VTanhKernel { _mm256_storeu_ps(y, tmp); \ x += AVX_FLOAT_BLOCK; \ y += AVX_FLOAT_BLOCK; \ - const float a = 2.f; \ + const float a = 2.f, b = -1.f; \ vscal_->Compute(&a, x, y, this->num_); \ vsigmoid_->Compute(y, y); \ vscal_->Compute(&a, y, y, this->num_); \ - vaddbias_->Compute(-1.f, y, y); \ + vaddbias_->Compute(&b, y, y, this->num_); \ } #define INTRI_GT16_FLOAT(isa, expisa) \ @@ -504,11 +504,11 @@ class VTanhKernelImpl : public VTanhKernel { } \ x += this->end_; \ y += this->end_; \ - const float a = 2.f; \ + const float a = 2.f, b = -1.f; \ vscal_->Compute(&a, x, y, this->num_); \ vsigmoid_->Compute(y, y); \ vscal_->Compute(&a, y, y, this->num_); \ - vaddbias_->Compute(-1.f, y, y); \ + vaddbias_->Compute(&b, y, y, this->num_); \ } #ifndef __WIN32 diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 04a199faa..596bd3b2d 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -128,7 +128,7 @@ TEST(JitKernel, vaddbias) { auto trefe = GetCurrentUS(); auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(a, x_data, ztgt_data); + ker->Compute(&a, x_data, ztgt_data, d); } auto ttgte = GetCurrentUS(); @@ -281,11 +281,11 @@ void vtanh_better( const paddle::operators::math::jitkernel::VAddBiasKernel>& vaddbias, const int n, const float* x, float* y) { - const float tmp1 = 2.f; - vscal->Compute(&tmp1, x, y, n); + const float a = 2.f, b = -1.f; + vscal->Compute(&a, x, y, n); vsigmoid->Compute(y, y); - vscal->Compute(&tmp1, y, y, n); - vaddbias->Compute(-1.f, y, y); + vscal->Compute(&a, y, y, n); + vaddbias->Compute(&b, y, y, n); } TEST(JitKernel, vtanh) { -- GitLab