提交 3d928d4f 编写于 作者: T tensor-tang

refine and seepdup

上级 77fc42d2
...@@ -35,29 +35,6 @@ const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const { ...@@ -35,29 +35,6 @@ const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const {
return kers_.at(key); return kers_.at(key);
} }
#define DEFINE_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key) \
template <> \
const std::shared_ptr<ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>>(int d) { \
std::string key = #ker_key #dtype_key + std::to_string(d); \
if (kers_.find(key) == kers_.end()) { \
auto p = std::make_shared<ker_class<ker_dtype>>(d); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key)); \
}
#define REGISTER_BLAS_JITKERNEL(ker_key, ker_class) \
DEFINE_WITH_DTYPE(ker_key, ker_class, float, f); \
DEFINE_WITH_DTYPE(ker_key, ker_class, double, d)
REGISTER_BLAS_JITKERNEL(vmul, VMulKernel);
REGISTER_BLAS_JITKERNEL(vadd, VAddKernel);
#undef REGISTER_BLAS_JITKERNEL
#undef DEFINE_WITH_DTYPE
template <> template <>
const std::shared_ptr<LSTMKernel<float>> const std::shared_ptr<LSTMKernel<float>>
KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&, KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
......
...@@ -40,7 +40,7 @@ typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block; ...@@ -40,7 +40,7 @@ typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;
class Kernel { class Kernel {
public: public:
Kernel() {} Kernel() = default;
virtual ~Kernel() = default; virtual ~Kernel() = default;
private: private:
...@@ -66,15 +66,13 @@ class KernelPool { ...@@ -66,15 +66,13 @@ class KernelPool {
template <typename T> template <typename T>
class VMulKernel : public Kernel { class VMulKernel : public Kernel {
public: public:
explicit VMulKernel(int n); virtual void Compute(const int n, const T *x, const T *y, T *z) = 0;
void (*Compute)(const int n, const T *, const T *, T *);
}; };
template <typename T> template <typename T>
class VAddKernel : public Kernel { class VAddKernel : public Kernel {
public: public:
explicit VAddKernel(int n); virtual void Compute(const int n, const T *x, const T *y, T *z) = 0;
void (*Compute)(const int n, const T *, const T *, T *);
}; };
template <typename T> template <typename T>
......
...@@ -29,17 +29,21 @@ namespace jitkernel { ...@@ -29,17 +29,21 @@ namespace jitkernel {
namespace jit = platform::jit; namespace jit = platform::jit;
#define NEW_IMPL(src, t, isa, k) \
p = std::dynamic_pointer_cast<src<t>>( \
std::make_shared<src##Impl<t, isa, k>>())
#define SEARCH_BLOCK(src, t, isa) \ #define SEARCH_BLOCK(src, t, isa) \
if (d < AVX_FLOAT_BLOCK) { \ if (d < AVX_FLOAT_BLOCK) { \
Compute = src<t, isa, kLT8>; \ NEW_IMPL(src, t, isa, kLT8); \
} else if (d == AVX_FLOAT_BLOCK) { \ } else if (d == AVX_FLOAT_BLOCK) { \
Compute = src<t, isa, kEQ8>; \ NEW_IMPL(src, t, isa, kEQ8); \
} else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \ } else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \
Compute = src<t, isa, kGT8LT16>; \ NEW_IMPL(src, t, isa, kGT8LT16); \
} else if (d == AVX512_FLOAT_BLOCK) { \ } else if (d == AVX512_FLOAT_BLOCK) { \
Compute = src<t, isa, kEQ16>; \ NEW_IMPL(src, t, isa, kEQ16); \
} else { \ } else { \
Compute = src<t, isa, kGT16>; \ NEW_IMPL(src, t, isa, kGT16); \
} }
#define SEARCH_ISA_BLOCK(src, t) \ #define SEARCH_ISA_BLOCK(src, t) \
...@@ -53,6 +57,24 @@ namespace jit = platform::jit; ...@@ -53,6 +57,24 @@ namespace jit = platform::jit;
SEARCH_BLOCK(src, t, jit::isa_any); \ SEARCH_BLOCK(src, t, jit::isa_any); \
} }
#define DEFINE_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key) \
template <> \
const std::shared_ptr<ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>>(int d) { \
std::string key = #ker_key #dtype_key + std::to_string(d); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
SEARCH_ISA_BLOCK(ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key)); \
}
#define REGISTER_BLAS_JITKERNEL(ker_key, ker_class) \
DEFINE_WITH_DTYPE(ker_key, ker_class, float, f); \
DEFINE_WITH_DTYPE(ker_key, ker_class, double, d)
// do not include lt8, eq8, eq16 // do not include lt8, eq8, eq16
#define FOR_EACH_COMMON_BLOCK(macro_, isa) \ #define FOR_EACH_COMMON_BLOCK(macro_, isa) \
macro_(isa, kGT8LT16) macro_(isa, kGT16) macro_(isa, kGT8LT16) macro_(isa, kGT16)
...@@ -73,47 +95,39 @@ namespace jit = platform::jit; ...@@ -73,47 +95,39 @@ namespace jit = platform::jit;
FOR_EACH_ALL_BLOCK(macro_, jit::avx) \ FOR_EACH_ALL_BLOCK(macro_, jit::avx) \
FOR_EACH_ALL_BLOCK(macro_, jit::isa_any) FOR_EACH_ALL_BLOCK(macro_, jit::isa_any)
#define BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, ker_dtype) \
template <> \
ker_class<ker_dtype>::ker_class(int d) { \
SEARCH_ISA_BLOCK(ker_func, ker_dtype); \
}
#define BIND_KERNEL(ker_class, ker_func) \
BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, float); \
BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, double)
/* VMUL JitKernel */ /* VMUL JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block> template <typename T, platform::jit::cpu_isa_t isa, jit_block>
static void VMulCompute(const int n, const T* x, const T* y, T* z) { class VMulKernelImpl : public VMulKernel<T> {
public:
void Compute(const int n, const T* x, const T* y, T* z) override {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i]; z[i] = x[i] * y[i];
} }
} }
};
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#define VMUL_MKL_FLOAT(isa, block) \ #define VMUL_MKL_FLOAT(isa, block) \
template <> \ template <> \
void VMulCompute<float, isa, block>(const int n, const float* x, \ void VMulKernelImpl<float, isa, block>::Compute(const int n, const float* x, \
const float* y, float* z) { \ const float* y, float* z) { \
platform::dynload::vsMul(n, x, y, z); \ platform::dynload::vsMul(n, x, y, z); \
} }
#define VMUL_MKL_DOUBLE(isa, block) \ #define VMUL_MKL_DOUBLE(isa, block) \
template <> \ template <> \
void VMulCompute<double, isa, block>(const int n, const double* x, \ void VMulKernelImpl<double, isa, block>::Compute( \
const double* y, double* z) { \ const int n, const double* x, const double* y, double* z) { \
platform::dynload::vdMul(n, x, y, z); \ platform::dynload::vdMul(n, x, y, z); \
} }
FOR_EACH_ISA_COMMON_BLOCK(VMUL_MKL_FLOAT) FOR_EACH_ISA_COMMON_BLOCK(VMUL_MKL_FLOAT);
FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE) FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE);
#endif #endif
/// eq8
#define VMUL_INTRI8_FLOAT(isa) \ #define VMUL_INTRI8_FLOAT(isa) \
template <> \ template <> \
void VMulCompute<float, isa, kEQ8>(const int n, const float* x, \ void VMulKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
const float* y, float* z) { \ const float* y, float* z) { \
__m256 tmpx, tmpy; \ __m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \ tmpx = _mm256_loadu_ps(x); \
...@@ -126,48 +140,51 @@ FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE) ...@@ -126,48 +140,51 @@ FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE)
#ifdef __AVX__ #ifdef __AVX__
VMUL_INTRI8_FLOAT(jit::avx); VMUL_INTRI8_FLOAT(jit::avx);
#endif #endif
// avx2 > for > mkl
#ifdef __AVX2__ #ifdef __AVX2__
VMUL_INTRI8_FLOAT(jit::avx2) VMUL_INTRI8_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
VMUL_INTRI8_FLOAT(jit::avx512f);
#endif #endif
// TODO(TJ): test and complete avx512
// TODO(TJ): eq16 test and complete avx512
#undef VMUL_INTRI8_FLOAT #undef VMUL_INTRI8_FLOAT
#undef VMUL_MKL_FLOAT #undef VMUL_MKL_FLOAT
#undef VMUL_MKL_DOUBLE #undef VMUL_MKL_DOUBLE
/* VADD */ /* VADD JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block> template <typename T, platform::jit::cpu_isa_t isa, jit_block>
static void VAddCompute(const int n, const T* x, const T* y, T* z) { class VAddKernelImpl : public VAddKernel<T> {
public:
void Compute(const int n, const T* x, const T* y, T* z) override {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i]; z[i] = x[i] + y[i];
} }
} }
};
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#define VADD_MKL_FLOAT(isa, block) \ #define VADD_MKL_FLOAT(isa, block) \
template <> \ template <> \
void VAddCompute<float, isa, block>(const int n, const float* x, \ void VAddKernelImpl<float, isa, block>::Compute(const int n, const float* x, \
const float* y, float* z) { \ const float* y, float* z) { \
platform::dynload::vsAdd(n, x, y, z); \ platform::dynload::vsAdd(n, x, y, z); \
} }
#define VADD_MKL_DOUBLE(isa, block) \ #define VADD_MKL_DOUBLE(isa, block) \
template <> \ template <> \
void VAddCompute<double, isa, block>(const int n, const double* x, \ void VAddKernelImpl<double, isa, block>::Compute( \
const double* y, double* z) { \ const int n, const double* x, const double* y, double* z) { \
platform::dynload::vdAdd(n, x, y, z); \ platform::dynload::vdAdd(n, x, y, z); \
} }
FOR_EACH_ISA_COMMON_BLOCK(VADD_MKL_FLOAT) FOR_EACH_ISA_COMMON_BLOCK(VADD_MKL_FLOAT);
FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE) FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE);
#endif #endif
/// eq8
#define VADD_INTRI8_FLOAT(isa) \ #define VADD_INTRI8_FLOAT(isa) \
template <> \ template <> \
void VAddCompute<float, isa, kEQ8>(const int n, const float* x, \ void VAddKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
const float* y, float* z) { \ const float* y, float* z) { \
__m256 tmpx, tmpy; \ __m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \ tmpx = _mm256_loadu_ps(x); \
...@@ -175,30 +192,33 @@ FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE) ...@@ -175,30 +192,33 @@ FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE)
tmpx = _mm256_add_ps(tmpx, tmpy); \ tmpx = _mm256_add_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \ _mm256_storeu_ps(z, tmpx); \
} }
#ifdef __AVX__ #ifdef __AVX__
VADD_INTRI8_FLOAT(jit::avx) VADD_INTRI8_FLOAT(jit::avx);
#endif #endif
#ifdef __AVX2__ #ifdef __AVX2__
VADD_INTRI8_FLOAT(jit::avx2) VADD_INTRI8_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
VADD_INTRI8_FLOAT(jit::avx512f);
#endif #endif
// TODO(TJ): test and complete avx512 // TODO(TJ): eq16 test and complete avx512
#undef VADD_INTRI8_FLOAT #undef VADD_INTRI8_FLOAT
#undef VADD_MKL_FLOAT #undef VADD_MKL_FLOAT
#undef VADD_MKL_DOUBLE #undef VADD_MKL_DOUBLE
BIND_KERNEL(VMulKernel, VMulCompute); REGISTER_BLAS_JITKERNEL(vmul, VMulKernel);
BIND_KERNEL(VAddKernel, VAddCompute); REGISTER_BLAS_JITKERNEL(vadd, VAddKernel);
#undef BIND_KERNEL
#undef BIND_KERNEL_WITH_DTYPE
#undef FOR_EACH_ISA_ALL_BLOCK #undef FOR_EACH_ISA_ALL_BLOCK
#undef FOR_EACH_ALL_BLOCK #undef FOR_EACH_ALL_BLOCK
#undef FOR_EACH_ISA_COMMON_BLOCK #undef FOR_EACH_ISA_COMMON_BLOCK
#undef FOR_EACH_COMMON_BLOCK #undef FOR_EACH_COMMON_BLOCK
#undef REGISTER_BLAS_JITKERNEL
#undef DEFINE_WITH_DTYPE
#undef SEARCH_ISA_BLOCK #undef SEARCH_ISA_BLOCK
#undef SEARCH_BLOCK #undef SEARCH_BLOCK
#undef NEW_IMPL
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册