diff --git a/paddle/fluid/operators/math/jit_kernel.cc b/paddle/fluid/operators/math/jit_kernel.cc index b87715538fe1e78f387c7adb833be77fbd40f0fe..18a58cbea7cb7fc5b3c0d98bc134acba205ec38d 100644 --- a/paddle/fluid/operators/math/jit_kernel.cc +++ b/paddle/fluid/operators/math/jit_kernel.cc @@ -28,7 +28,7 @@ KernelPool& KernelPool::Instance() { return g_jit_kernels; } -const std::shared_ptr KernelPool::Get(const std::string& key) const { +std::shared_ptr KernelPool::Get(const std::string& key) const { if (kers_.find(key) == kers_.end()) { return nullptr; } @@ -36,7 +36,7 @@ const std::shared_ptr KernelPool::Get(const std::string& key) const { } template <> -const std::shared_ptr> +std::shared_ptr> KernelPool::Get, int, const std::string&, const std::string&, const std::string&>(int d, const std::string& act_gate, const std::string& act_cand, @@ -49,7 +49,7 @@ KernelPool::Get, int, const std::string&, const std::string&, kers_.insert({key, std::dynamic_pointer_cast(p)}); return p; } - return std::dynamic_pointer_cast>(kers_.at(key)); + return std::dynamic_pointer_cast>(kers_.at(key)); } } // namespace jitkernel diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 0a16a878558309ddb600531b29627a829857d8da..24cf2aaf0bad28952a3b580ef43fe400d3a3cf8f 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -52,13 +52,13 @@ class KernelPool { static KernelPool &Instance(); template - const std::shared_ptr Get(ARGS... args); + std::shared_ptr Get(ARGS... args); - const std::shared_ptr Get(const std::string &key) const; + std::shared_ptr Get(const std::string &key) const; private: KernelPool() = default; - std::unordered_map> kers_; + std::unordered_map> kers_; DISABLE_COPY_AND_ASSIGN(KernelPool); }; @@ -66,26 +66,38 @@ class KernelPool { template class VMulKernel : public Kernel { public: - virtual void Compute(const int n, const T *x, const T *y, T *z) = 0; + virtual void Compute(const int n, const T *x, const T *y, T *z) const = 0; }; template class VAddKernel : public Kernel { public: - virtual void Compute(const int n, const T *x, const T *y, T *z) = 0; + virtual void Compute(const int n, const T *x, const T *y, T *z) const = 0; }; template class VScalKernel : public Kernel { public: - virtual void Compute(const int n, const T a, const T *x, T *y) = 0; - virtual void Compute(const int n, const T a, T *x) = 0; + virtual void Compute(const int n, const T a, const T *x, T *y) const = 0; + virtual void Compute(const int n, const T a, T *x) const = 0; }; template class VExpKernel : public Kernel { public: - virtual void Compute(const int n, const T *x, T *y) = 0; + virtual void Compute(const int n, const T *x, T *y) const = 0; +}; + +template +class VSigmoidKernel : public Kernel { + public: + virtual void Compute(const int n, const T *x, T *y) const = 0; +}; + +template +class VTanhKernel : public Kernel { + public: + virtual void Compute(const int n, const T *x, T *y) const = 0; }; template diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index a08d53f496391b704d2a6a48619918d572820311..30761c0430d1920bee32e07cc57e2e6b5851b28e 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -34,7 +34,7 @@ namespace jit = platform::jit; template class VMulKernelImpl : public VMulKernel { public: - void Compute(const int n, const T* x, const T* y, T* z) override { + void Compute(const int n, const T* x, const T* y, T* z) const override { for (int i = 0; i < n; ++i) { z[i] = x[i] * y[i]; } @@ -42,33 +42,33 @@ class VMulKernelImpl : public VMulKernel { }; #ifdef PADDLE_WITH_MKLML -#define MKL_FLOAT(isa, block) \ - template <> \ - void VMulKernelImpl::Compute(const int n, const float* x, \ - const float* y, float* z) { \ - platform::dynload::vsMul(n, x, y, z); \ +#define MKL_FLOAT(isa, block) \ + template <> \ + void VMulKernelImpl::Compute( \ + const int n, const float* x, const float* y, float* z) const { \ + platform::dynload::vsMul(n, x, y, z); \ } -#define MKL_DOUBLE(isa, block) \ - template <> \ - void VMulKernelImpl::Compute( \ - const int n, const double* x, const double* y, double* z) { \ - platform::dynload::vdMul(n, x, y, z); \ +#define MKL_DOUBLE(isa, block) \ + template <> \ + void VMulKernelImpl::Compute( \ + const int n, const double* x, const double* y, double* z) const { \ + platform::dynload::vdMul(n, x, y, z); \ } FOR_EACH_ISA(MKL_FLOAT, kGT16); FOR_EACH_ISA_BLOCK(MKL_DOUBLE); #endif -#define INTRI8_FLOAT(isa) \ - template <> \ - void VMulKernelImpl::Compute(const int n, const float* x, \ - const float* y, float* z) { \ - __m256 tmpx, tmpy; \ - tmpx = _mm256_loadu_ps(x); \ - tmpy = _mm256_loadu_ps(y); \ - tmpx = _mm256_mul_ps(tmpx, tmpy); \ - _mm256_storeu_ps(z, tmpx); \ +#define INTRI8_FLOAT(isa) \ + template <> \ + void VMulKernelImpl::Compute( \ + const int n, const float* x, const float* y, float* z) const { \ + __m256 tmpx, tmpy; \ + tmpx = _mm256_loadu_ps(x); \ + tmpy = _mm256_loadu_ps(y); \ + tmpx = _mm256_mul_ps(tmpx, tmpy); \ + _mm256_storeu_ps(z, tmpx); \ } // avx > for > mkl @@ -90,7 +90,7 @@ INTRI8_FLOAT(jit::avx512f); template class VAddKernelImpl : public VAddKernel { public: - void Compute(const int n, const T* x, const T* y, T* z) override { + void Compute(const int n, const T* x, const T* y, T* z) const override { for (int i = 0; i < n; ++i) { z[i] = x[i] + y[i]; } @@ -98,33 +98,33 @@ class VAddKernelImpl : public VAddKernel { }; #ifdef PADDLE_WITH_MKLML -#define MKL_FLOAT(isa, block) \ - template <> \ - void VAddKernelImpl::Compute(const int n, const float* x, \ - const float* y, float* z) { \ - platform::dynload::vsAdd(n, x, y, z); \ +#define MKL_FLOAT(isa, block) \ + template <> \ + void VAddKernelImpl::Compute( \ + const int n, const float* x, const float* y, float* z) const { \ + platform::dynload::vsAdd(n, x, y, z); \ } -#define MKL_DOUBLE(isa, block) \ - template <> \ - void VAddKernelImpl::Compute( \ - const int n, const double* x, const double* y, double* z) { \ - platform::dynload::vdAdd(n, x, y, z); \ +#define MKL_DOUBLE(isa, block) \ + template <> \ + void VAddKernelImpl::Compute( \ + const int n, const double* x, const double* y, double* z) const { \ + platform::dynload::vdAdd(n, x, y, z); \ } FOR_EACH_ISA(MKL_FLOAT, kGT16); FOR_EACH_ISA_BLOCK(MKL_DOUBLE); #endif -#define INTRI8_FLOAT(isa) \ - template <> \ - void VAddKernelImpl::Compute(const int n, const float* x, \ - const float* y, float* z) { \ - __m256 tmpx, tmpy; \ - tmpx = _mm256_loadu_ps(x); \ - tmpy = _mm256_loadu_ps(y); \ - tmpx = _mm256_add_ps(tmpx, tmpy); \ - _mm256_storeu_ps(z, tmpx); \ +#define INTRI8_FLOAT(isa) \ + template <> \ + void VAddKernelImpl::Compute( \ + const int n, const float* x, const float* y, float* z) const { \ + __m256 tmpx, tmpy; \ + tmpx = _mm256_loadu_ps(x); \ + tmpy = _mm256_loadu_ps(y); \ + tmpx = _mm256_add_ps(tmpx, tmpy); \ + _mm256_storeu_ps(z, tmpx); \ } #ifdef __AVX__ INTRI8_FLOAT(jit::avx); @@ -145,12 +145,12 @@ INTRI8_FLOAT(jit::avx512f); template class VScalKernelImpl : public VScalKernel { public: - void Compute(const int n, const T a, const T* x, T* y) override { + void Compute(const int n, const T a, const T* x, T* y) const override { for (int i = 0; i < n; ++i) { y[i] = a * x[i]; } } - void Compute(const int n, const T a, T* x) override { + void Compute(const int n, const T a, T* x) const override { for (int i = 0; i < n; ++i) { x[i] = a * x[i]; } @@ -161,35 +161,35 @@ class VScalKernelImpl : public VScalKernel { #define MKL_FLOAT(isa, block) \ template <> \ void VScalKernelImpl::Compute(const int n, const float a, \ - float* x) { \ + float* x) const { \ platform::dynload::cblas_sscal(n, a, x, 1); \ } -#define MKL_DOUBLE(isa, block) \ - template <> \ - void VScalKernelImpl::Compute( \ - const int n, const double a, double* x) { \ - platform::dynload::cblas_dscal(n, a, x, 1); \ +#define MKL_DOUBLE(isa, block) \ + template <> \ + void VScalKernelImpl::Compute( \ + const int n, const double a, double* x) const { \ + platform::dynload::cblas_dscal(n, a, x, 1); \ } FOR_EACH_ISA(MKL_FLOAT, kGT16); FOR_EACH_ISA_BLOCK(MKL_DOUBLE); #endif -#define INTRI8_FLOAT(isa) \ - template <> \ - void VScalKernelImpl::Compute(const int n, const float a, \ - const float* x, float* y) { \ - __m256 tmp; \ - __m256 scalar = _mm256_set1_ps(a); \ - tmp = _mm256_loadu_ps(x); \ - tmp = _mm256_mul_ps(tmp, scalar); \ - _mm256_storeu_ps(y, tmp); \ +#define INTRI8_FLOAT(isa) \ + template <> \ + void VScalKernelImpl::Compute( \ + const int n, const float a, const float* x, float* y) const { \ + __m256 tmp; \ + __m256 scalar = _mm256_set1_ps(a); \ + tmp = _mm256_loadu_ps(x); \ + tmp = _mm256_mul_ps(tmp, scalar); \ + _mm256_storeu_ps(y, tmp); \ } #define INTRI8_INPLACE_FLOAT(isa) \ template <> \ void VScalKernelImpl::Compute(const int n, const float a, \ - float* x) { \ + float* x) const { \ __m256 tmp; \ __m256 scalar = _mm256_set1_ps(a); \ tmp = _mm256_loadu_ps(x); \ diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index 5f04ba97be056878d94a74cdb133b2eb59b35f3d..0c736cd2d07c092242620fbacde8bf55ecf26a8f 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -34,14 +34,13 @@ __m256 Exp(__m256 a); #endif namespace jitkernel { - namespace jit = platform::jit; /* VExp JitKernel */ template class VExpKernelImpl : public VExpKernel { public: - void Compute(const int n, const T* x, T* y) override { + void Compute(const int n, const T* x, T* y) const override { for (int i = 0; i < n; ++i) { y[i] = std::exp(x[i]); } @@ -52,15 +51,15 @@ class VExpKernelImpl : public VExpKernel { #define MKL_FLOAT(isa, block) \ template <> \ void VExpKernelImpl::Compute(const int n, const float* x, \ - float* y) { \ + float* y) const { \ platform::dynload::vsExp(n, x, y); \ } -#define MKL_DOUBLE(isa, block) \ - template <> \ - void VExpKernelImpl::Compute( \ - const int n, const double* x, double* y) { \ - platform::dynload::vdExp(n, x, y); \ +#define MKL_DOUBLE(isa, block) \ + template <> \ + void VExpKernelImpl::Compute( \ + const int n, const double* x, double* y) const { \ + platform::dynload::vdExp(n, x, y); \ } FOR_EACH_ISA(MKL_FLOAT, kLT8); FOR_EACH_ISA(MKL_FLOAT, kGT8LT16); @@ -71,7 +70,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE); #define INTRI8_FLOAT(isa) \ template <> \ void VExpKernelImpl::Compute(const int n, const float* x, \ - float* y) { \ + float* y) const { \ __m256 tmp = _mm256_loadu_ps(x); \ _mm256_storeu_ps(y, detail::Exp(tmp)); \ } @@ -79,7 +78,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE); #define INTRI16_FLOAT(isa) \ template <> \ void VExpKernelImpl::Compute(const int n, const float* x, \ - float* y) { \ + float* y) const { \ __m256 tmp0 = _mm256_loadu_ps(x); \ __m256 tmp1 = _mm256_loadu_ps(x + 8); \ tmp0 = detail::Exp(tmp0); \ @@ -109,6 +108,38 @@ INTRI16_FLOAT(jit::avx512f); REGISTER_JITKERNEL(vexp, VExpKernel); +/* VSigmoid JitKernel */ +template +class VSigmoidKernelImpl : public VSigmoidKernel { + public: + explicit VSigmoidKernelImpl(int d) : VSigmoidKernel() { + vexp_ = KernelPool::Instance().template Get>(d); + } + void Compute(const int n, const T* x, T* y) const override { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int i = 0; i < n; ++i) { + y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); + y[i] = static_cast(0) - y[i]; + } + vexp_->Compute(n, y, y); + for (int i = 0; i < n; ++i) { + y[i] = static_cast(1) / (static_cast(1) + y[i]); + } + } + + private: + std::shared_ptr> vexp_; +}; + +#define JITKERNEL_NEW_ACT_IMPL(ker, dtype, isa, k) \ + p = std::dynamic_pointer_cast>( \ + std::make_shared>(d)) + +REGISTER_JITKERNEL_ARGS(vsigmoid, VSigmoidKernel, JITKERNEL_DECLARE, + JITKERNEL_KEY, JITKERNEL_NEW_ACT_IMPL); + +#undef JITKERNEL_NEW_ACT_IMPL } // namespace jitkernel } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/jit_kernel_macro.h b/paddle/fluid/operators/math/jit_kernel_macro.h index 239583f301880e2fe5dfb964dda0bfe8b70a53e1..2b63c695243d3f57ebea63ea8d54fbfb30ecac8e 100644 --- a/paddle/fluid/operators/math/jit_kernel_macro.h +++ b/paddle/fluid/operators/math/jit_kernel_macro.h @@ -23,51 +23,68 @@ namespace jitkernel { namespace jit = platform::jit; -#define NEW_JITKERNEL_IMPL(src, t, isa, k) \ - p = std::dynamic_pointer_cast>( \ - std::make_shared>()) - -#define SEARCH_BLOCK(src, t, isa) \ +#define SEARCH_BLOCK(macro_, ker, dtype, isa) \ if (d < AVX_FLOAT_BLOCK) { \ - NEW_JITKERNEL_IMPL(src, t, isa, kLT8); \ + macro_(ker, dtype, isa, kLT8); \ } else if (d == AVX_FLOAT_BLOCK) { \ - NEW_JITKERNEL_IMPL(src, t, isa, kEQ8); \ + macro_(ker, dtype, isa, kEQ8); \ } else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \ - NEW_JITKERNEL_IMPL(src, t, isa, kGT8LT16); \ + macro_(ker, dtype, isa, kGT8LT16); \ } else if (d == AVX512_FLOAT_BLOCK) { \ - NEW_JITKERNEL_IMPL(src, t, isa, kEQ16); \ + macro_(ker, dtype, isa, kEQ16); \ } else { \ - NEW_JITKERNEL_IMPL(src, t, isa, kGT16); \ + macro_(ker, dtype, isa, kGT16); \ } -#define SEARCH_ISA_BLOCK(src, t) \ - if (jit::MayIUse(jit::avx512f)) { \ - SEARCH_BLOCK(src, t, jit::avx512f); \ - } else if (jit::MayIUse(jit::avx2)) { \ - SEARCH_BLOCK(src, t, jit::avx2); \ - } else if (jit::MayIUse(jit::avx)) { \ - SEARCH_BLOCK(src, t, jit::avx); \ - } else { \ - SEARCH_BLOCK(src, t, jit::isa_any); \ +#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \ + if (jit::MayIUse(jit::avx512f)) { \ + SEARCH_BLOCK(macro_, ker, dtype, jit::avx512f); \ + } else if (jit::MayIUse(jit::avx2)) { \ + SEARCH_BLOCK(macro_, ker, dtype, jit::avx2); \ + } else if (jit::MayIUse(jit::avx)) { \ + SEARCH_BLOCK(macro_, ker, dtype, jit::avx); \ + } else { \ + SEARCH_BLOCK(macro_, ker, dtype, jit::isa_any); \ } -#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key) \ - template <> \ - const std::shared_ptr> \ - KernelPool::Get>(int d) { \ - std::string key = #ker_key #dtype_key + std::to_string(d); \ - if (kers_.find(key) == kers_.end()) { \ - std::shared_ptr> p; \ - SEARCH_ISA_BLOCK(ker_class, ker_dtype); \ - kers_.insert({key, std::dynamic_pointer_cast(p)}); \ - return p; \ - } \ - return std::dynamic_pointer_cast>(kers_.at(key)); \ +#define JITKERNEL_DECLARE(ker_class, ker_dtype) \ + template <> \ + std::shared_ptr> \ + KernelPool::Get, int>(int d) + +#define JITKERNEL_KEY(ker_key, dtype_key) \ + #ker_key #dtype_key + std::to_string(d) + +#define JITKERNEL_NEW_IMPL(ker, dtype, isa, k) \ + p = std::dynamic_pointer_cast>( \ + std::make_shared>()) + +#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key, \ + marco_declare, macro_key, macro_impl) \ + marco_declare(ker_class, ker_dtype) { \ + std::string key = macro_key(ker_key, dtype_key); \ + if (kers_.find(key) == kers_.end()) { \ + std::shared_ptr> p; \ + SEARCH_ISA_BLOCK(macro_impl, ker_class, ker_dtype); \ + kers_.insert({key, std::dynamic_pointer_cast(p)}); \ + return p; \ + } \ + return std::dynamic_pointer_cast>( \ + kers_.at(key)); \ } -#define REGISTER_JITKERNEL(ker_key, ker_class) \ - JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f); \ - JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d) +#define REGISTER_JITKERNEL(ker_key, ker_class) \ + JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, JITKERNEL_DECLARE, \ + JITKERNEL_KEY, JITKERNEL_NEW_IMPL); \ + JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, JITKERNEL_DECLARE, \ + JITKERNEL_KEY, JITKERNEL_NEW_IMPL) + +#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_declare, macro_key, \ + macro_impl) \ + JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, marco_declare, macro_key, \ + macro_impl); \ + JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, marco_declare, \ + macro_key, macro_impl) #define FOR_EACH_ISA(macro_, block) \ macro_(jit::avx512f, block); \ diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index a23d5fff04e56c1823b5eb80eeb8c36a48f2ab2e..2495712cb7ae6dca49783cb2bff7a3a4435e9ac1 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -388,16 +388,16 @@ TEST(JitKernel, pool) { const auto& pvmul_f = jit::KernelPool::Instance().template Get>(4); - EXPECT_TRUE(std::dynamic_pointer_cast(plstm2) != - std::dynamic_pointer_cast(pvmul_f)); + EXPECT_TRUE(std::dynamic_pointer_cast(plstm2) != + std::dynamic_pointer_cast(pvmul_f)); const auto& pvmul_d = jit::KernelPool::Instance().template Get>(4); - EXPECT_TRUE(std::dynamic_pointer_cast(pvmul_f) != - std::dynamic_pointer_cast(pvmul_d)); + EXPECT_TRUE(std::dynamic_pointer_cast(pvmul_f) != + std::dynamic_pointer_cast(pvmul_d)); const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulf4"); - EXPECT_TRUE(pvmul_f == pvmul_from_key); + EXPECT_EQ(pvmul_f, pvmul_from_key); const auto& pvmul_from_key2 = jit::KernelPool::Instance().Get("vmulf5"); EXPECT_TRUE(pvmul_from_key2 == nullptr); }