diff --git a/paddle/fluid/operators/math/jit_kernel.cc b/paddle/fluid/operators/math/jit_kernel.cc index 8859c0f7d8f62fdfb2704cc80df01425073a9ece..b87715538fe1e78f387c7adb833be77fbd40f0fe 100644 --- a/paddle/fluid/operators/math/jit_kernel.cc +++ b/paddle/fluid/operators/math/jit_kernel.cc @@ -35,29 +35,6 @@ const std::shared_ptr KernelPool::Get(const std::string& key) const { return kers_.at(key); } -#define DEFINE_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key) \ - template <> \ - const std::shared_ptr> \ - KernelPool::Get>(int d) { \ - std::string key = #ker_key #dtype_key + std::to_string(d); \ - if (kers_.find(key) == kers_.end()) { \ - auto p = std::make_shared>(d); \ - kers_.insert({key, std::dynamic_pointer_cast(p)}); \ - return p; \ - } \ - return std::dynamic_pointer_cast>(kers_.at(key)); \ - } - -#define REGISTER_BLAS_JITKERNEL(ker_key, ker_class) \ - DEFINE_WITH_DTYPE(ker_key, ker_class, float, f); \ - DEFINE_WITH_DTYPE(ker_key, ker_class, double, d) - -REGISTER_BLAS_JITKERNEL(vmul, VMulKernel); -REGISTER_BLAS_JITKERNEL(vadd, VAddKernel); - -#undef REGISTER_BLAS_JITKERNEL -#undef DEFINE_WITH_DTYPE - template <> const std::shared_ptr> KernelPool::Get, int, const std::string&, const std::string&, diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 610f6714041066dc000a5560f4c925ff7227b1bb..3e75fd1137d7770f35b54fc42d89af6e9cbe7aca 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -40,7 +40,7 @@ typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block; class Kernel { public: - Kernel() {} + Kernel() = default; virtual ~Kernel() = default; private: @@ -66,15 +66,13 @@ class KernelPool { template class VMulKernel : public Kernel { public: - explicit VMulKernel(int n); - void (*Compute)(const int n, const T *, const T *, T *); + virtual void Compute(const int n, const T *x, const T *y, T *z) = 0; }; template class VAddKernel : public Kernel { public: - explicit VAddKernel(int n); - void (*Compute)(const int n, const T *, const T *, T *); + virtual void Compute(const int n, const T *x, const T *y, T *z) = 0; }; template diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index f4962bf313cd6089479a74bc547854fa819eb240..7710525717e33497db77b9b7da8e846602dc8501 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -29,17 +29,21 @@ namespace jitkernel { namespace jit = platform::jit; +#define NEW_IMPL(src, t, isa, k) \ + p = std::dynamic_pointer_cast>( \ + std::make_shared>()) + #define SEARCH_BLOCK(src, t, isa) \ if (d < AVX_FLOAT_BLOCK) { \ - Compute = src; \ + NEW_IMPL(src, t, isa, kLT8); \ } else if (d == AVX_FLOAT_BLOCK) { \ - Compute = src; \ + NEW_IMPL(src, t, isa, kEQ8); \ } else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \ - Compute = src; \ + NEW_IMPL(src, t, isa, kGT8LT16); \ } else if (d == AVX512_FLOAT_BLOCK) { \ - Compute = src; \ + NEW_IMPL(src, t, isa, kEQ16); \ } else { \ - Compute = src; \ + NEW_IMPL(src, t, isa, kGT16); \ } #define SEARCH_ISA_BLOCK(src, t) \ @@ -53,6 +57,24 @@ namespace jit = platform::jit; SEARCH_BLOCK(src, t, jit::isa_any); \ } +#define DEFINE_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key) \ + template <> \ + const std::shared_ptr> \ + KernelPool::Get>(int d) { \ + std::string key = #ker_key #dtype_key + std::to_string(d); \ + if (kers_.find(key) == kers_.end()) { \ + std::shared_ptr> p; \ + SEARCH_ISA_BLOCK(ker_class, ker_dtype); \ + kers_.insert({key, std::dynamic_pointer_cast(p)}); \ + return p; \ + } \ + return std::dynamic_pointer_cast>(kers_.at(key)); \ + } + +#define REGISTER_BLAS_JITKERNEL(ker_key, ker_class) \ + DEFINE_WITH_DTYPE(ker_key, ker_class, float, f); \ + DEFINE_WITH_DTYPE(ker_key, ker_class, double, d) + // do not include lt8, eq8, eq16 #define FOR_EACH_COMMON_BLOCK(macro_, isa) \ macro_(isa, kGT8LT16) macro_(isa, kGT16) @@ -73,132 +95,130 @@ namespace jit = platform::jit; FOR_EACH_ALL_BLOCK(macro_, jit::avx) \ FOR_EACH_ALL_BLOCK(macro_, jit::isa_any) -#define BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, ker_dtype) \ - template <> \ - ker_class::ker_class(int d) { \ - SEARCH_ISA_BLOCK(ker_func, ker_dtype); \ - } - -#define BIND_KERNEL(ker_class, ker_func) \ - BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, float); \ - BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, double) - /* VMUL JitKernel */ template -static void VMulCompute(const int n, const T* x, const T* y, T* z) { - for (int i = 0; i < n; ++i) { - z[i] = x[i] * y[i]; +class VMulKernelImpl : public VMulKernel { + public: + void Compute(const int n, const T* x, const T* y, T* z) override { + for (int i = 0; i < n; ++i) { + z[i] = x[i] * y[i]; + } } -} +}; #ifdef PADDLE_WITH_MKLML -#define VMUL_MKL_FLOAT(isa, block) \ - template <> \ - void VMulCompute(const int n, const float* x, \ - const float* y, float* z) { \ - platform::dynload::vsMul(n, x, y, z); \ +#define VMUL_MKL_FLOAT(isa, block) \ + template <> \ + void VMulKernelImpl::Compute(const int n, const float* x, \ + const float* y, float* z) { \ + platform::dynload::vsMul(n, x, y, z); \ } -#define VMUL_MKL_DOUBLE(isa, block) \ - template <> \ - void VMulCompute(const int n, const double* x, \ - const double* y, double* z) { \ - platform::dynload::vdMul(n, x, y, z); \ +#define VMUL_MKL_DOUBLE(isa, block) \ + template <> \ + void VMulKernelImpl::Compute( \ + const int n, const double* x, const double* y, double* z) { \ + platform::dynload::vdMul(n, x, y, z); \ } -FOR_EACH_ISA_COMMON_BLOCK(VMUL_MKL_FLOAT) -FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE) +FOR_EACH_ISA_COMMON_BLOCK(VMUL_MKL_FLOAT); +FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE); #endif -/// eq8 -#define VMUL_INTRI8_FLOAT(isa) \ - template <> \ - void VMulCompute(const int n, const float* x, \ - const float* y, float* z) { \ - __m256 tmpx, tmpy; \ - tmpx = _mm256_loadu_ps(x); \ - tmpy = _mm256_loadu_ps(y); \ - tmpx = _mm256_mul_ps(tmpx, tmpy); \ - _mm256_storeu_ps(z, tmpx); \ +#define VMUL_INTRI8_FLOAT(isa) \ + template <> \ + void VMulKernelImpl::Compute(const int n, const float* x, \ + const float* y, float* z) { \ + __m256 tmpx, tmpy; \ + tmpx = _mm256_loadu_ps(x); \ + tmpy = _mm256_loadu_ps(y); \ + tmpx = _mm256_mul_ps(tmpx, tmpy); \ + _mm256_storeu_ps(z, tmpx); \ } // avx > for > mkl #ifdef __AVX__ VMUL_INTRI8_FLOAT(jit::avx); #endif - -// avx2 > for > mkl #ifdef __AVX2__ -VMUL_INTRI8_FLOAT(jit::avx2) +VMUL_INTRI8_FLOAT(jit::avx2); +#endif +#ifdef __AVX512F__ +VMUL_INTRI8_FLOAT(jit::avx512f); #endif -// TODO(TJ): test and complete avx512 +// TODO(TJ): eq16 test and complete avx512 #undef VMUL_INTRI8_FLOAT #undef VMUL_MKL_FLOAT #undef VMUL_MKL_DOUBLE -/* VADD */ +/* VADD JitKernel */ template -static void VAddCompute(const int n, const T* x, const T* y, T* z) { - for (int i = 0; i < n; ++i) { - z[i] = x[i] + y[i]; +class VAddKernelImpl : public VAddKernel { + public: + void Compute(const int n, const T* x, const T* y, T* z) override { + for (int i = 0; i < n; ++i) { + z[i] = x[i] + y[i]; + } } -} +}; #ifdef PADDLE_WITH_MKLML -#define VADD_MKL_FLOAT(isa, block) \ - template <> \ - void VAddCompute(const int n, const float* x, \ - const float* y, float* z) { \ - platform::dynload::vsAdd(n, x, y, z); \ +#define VADD_MKL_FLOAT(isa, block) \ + template <> \ + void VAddKernelImpl::Compute(const int n, const float* x, \ + const float* y, float* z) { \ + platform::dynload::vsAdd(n, x, y, z); \ } -#define VADD_MKL_DOUBLE(isa, block) \ - template <> \ - void VAddCompute(const int n, const double* x, \ - const double* y, double* z) { \ - platform::dynload::vdAdd(n, x, y, z); \ +#define VADD_MKL_DOUBLE(isa, block) \ + template <> \ + void VAddKernelImpl::Compute( \ + const int n, const double* x, const double* y, double* z) { \ + platform::dynload::vdAdd(n, x, y, z); \ } -FOR_EACH_ISA_COMMON_BLOCK(VADD_MKL_FLOAT) -FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE) +FOR_EACH_ISA_COMMON_BLOCK(VADD_MKL_FLOAT); +FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE); #endif -/// eq8 -#define VADD_INTRI8_FLOAT(isa) \ - template <> \ - void VAddCompute(const int n, const float* x, \ - const float* y, float* z) { \ - __m256 tmpx, tmpy; \ - tmpx = _mm256_loadu_ps(x); \ - tmpy = _mm256_loadu_ps(y); \ - tmpx = _mm256_add_ps(tmpx, tmpy); \ - _mm256_storeu_ps(z, tmpx); \ +#define VADD_INTRI8_FLOAT(isa) \ + template <> \ + void VAddKernelImpl::Compute(const int n, const float* x, \ + const float* y, float* z) { \ + __m256 tmpx, tmpy; \ + tmpx = _mm256_loadu_ps(x); \ + tmpy = _mm256_loadu_ps(y); \ + tmpx = _mm256_add_ps(tmpx, tmpy); \ + _mm256_storeu_ps(z, tmpx); \ } - #ifdef __AVX__ -VADD_INTRI8_FLOAT(jit::avx) +VADD_INTRI8_FLOAT(jit::avx); #endif #ifdef __AVX2__ -VADD_INTRI8_FLOAT(jit::avx2) +VADD_INTRI8_FLOAT(jit::avx2); +#endif +#ifdef __AVX512F__ +VADD_INTRI8_FLOAT(jit::avx512f); #endif -// TODO(TJ): test and complete avx512 +// TODO(TJ): eq16 test and complete avx512 #undef VADD_INTRI8_FLOAT #undef VADD_MKL_FLOAT #undef VADD_MKL_DOUBLE -BIND_KERNEL(VMulKernel, VMulCompute); -BIND_KERNEL(VAddKernel, VAddCompute); +REGISTER_BLAS_JITKERNEL(vmul, VMulKernel); +REGISTER_BLAS_JITKERNEL(vadd, VAddKernel); -#undef BIND_KERNEL -#undef BIND_KERNEL_WITH_DTYPE #undef FOR_EACH_ISA_ALL_BLOCK #undef FOR_EACH_ALL_BLOCK #undef FOR_EACH_ISA_COMMON_BLOCK #undef FOR_EACH_COMMON_BLOCK +#undef REGISTER_BLAS_JITKERNEL +#undef DEFINE_WITH_DTYPE #undef SEARCH_ISA_BLOCK #undef SEARCH_BLOCK +#undef NEW_IMPL } // namespace jitkernel } // namespace math