提交 55e44761 编写于 作者: T tensor-tang

refine code and init vsigmoid

上级 2d0ff6a3
...@@ -28,7 +28,7 @@ KernelPool& KernelPool::Instance() { ...@@ -28,7 +28,7 @@ KernelPool& KernelPool::Instance() {
return g_jit_kernels; return g_jit_kernels;
} }
const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const { std::shared_ptr<const Kernel> KernelPool::Get(const std::string& key) const {
if (kers_.find(key) == kers_.end()) { if (kers_.find(key) == kers_.end()) {
return nullptr; return nullptr;
} }
...@@ -36,7 +36,7 @@ const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const { ...@@ -36,7 +36,7 @@ const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const {
} }
template <> template <>
const std::shared_ptr<LSTMKernel<float>> std::shared_ptr<const LSTMKernel<float>>
KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&, KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
const std::string&>(int d, const std::string& act_gate, const std::string&>(int d, const std::string& act_gate,
const std::string& act_cand, const std::string& act_cand,
...@@ -49,7 +49,7 @@ KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&, ...@@ -49,7 +49,7 @@ KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
return p; return p;
} }
return std::dynamic_pointer_cast<LSTMKernel<float>>(kers_.at(key)); return std::dynamic_pointer_cast<const LSTMKernel<float>>(kers_.at(key));
} }
} // namespace jitkernel } // namespace jitkernel
......
...@@ -52,13 +52,13 @@ class KernelPool { ...@@ -52,13 +52,13 @@ class KernelPool {
static KernelPool &Instance(); static KernelPool &Instance();
template <typename Ker, typename... ARGS> template <typename Ker, typename... ARGS>
const std::shared_ptr<Ker> Get(ARGS... args); std::shared_ptr<const Ker> Get(ARGS... args);
const std::shared_ptr<Kernel> Get(const std::string &key) const; std::shared_ptr<const Kernel> Get(const std::string &key) const;
private: private:
KernelPool() = default; KernelPool() = default;
std::unordered_map<std::string, std::shared_ptr<Kernel>> kers_; std::unordered_map<std::string, std::shared_ptr<const Kernel>> kers_;
DISABLE_COPY_AND_ASSIGN(KernelPool); DISABLE_COPY_AND_ASSIGN(KernelPool);
}; };
...@@ -66,26 +66,38 @@ class KernelPool { ...@@ -66,26 +66,38 @@ class KernelPool {
template <typename T> template <typename T>
class VMulKernel : public Kernel { class VMulKernel : public Kernel {
public: public:
virtual void Compute(const int n, const T *x, const T *y, T *z) = 0; virtual void Compute(const int n, const T *x, const T *y, T *z) const = 0;
}; };
template <typename T> template <typename T>
class VAddKernel : public Kernel { class VAddKernel : public Kernel {
public: public:
virtual void Compute(const int n, const T *x, const T *y, T *z) = 0; virtual void Compute(const int n, const T *x, const T *y, T *z) const = 0;
}; };
template <typename T> template <typename T>
class VScalKernel : public Kernel { class VScalKernel : public Kernel {
public: public:
virtual void Compute(const int n, const T a, const T *x, T *y) = 0; virtual void Compute(const int n, const T a, const T *x, T *y) const = 0;
virtual void Compute(const int n, const T a, T *x) = 0; virtual void Compute(const int n, const T a, T *x) const = 0;
}; };
template <typename T> template <typename T>
class VExpKernel : public Kernel { class VExpKernel : public Kernel {
public: public:
virtual void Compute(const int n, const T *x, T *y) = 0; virtual void Compute(const int n, const T *x, T *y) const = 0;
};
template <typename T>
class VSigmoidKernel : public Kernel {
public:
virtual void Compute(const int n, const T *x, T *y) const = 0;
};
template <typename T>
class VTanhKernel : public Kernel {
public:
virtual void Compute(const int n, const T *x, T *y) const = 0;
}; };
template <typename T> template <typename T>
......
...@@ -34,7 +34,7 @@ namespace jit = platform::jit; ...@@ -34,7 +34,7 @@ namespace jit = platform::jit;
template <typename T, platform::jit::cpu_isa_t isa, jit_block> template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VMulKernelImpl : public VMulKernel<T> { class VMulKernelImpl : public VMulKernel<T> {
public: public:
void Compute(const int n, const T* x, const T* y, T* z) override { void Compute(const int n, const T* x, const T* y, T* z) const override {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i]; z[i] = x[i] * y[i];
} }
...@@ -42,33 +42,33 @@ class VMulKernelImpl : public VMulKernel<T> { ...@@ -42,33 +42,33 @@ class VMulKernelImpl : public VMulKernel<T> {
}; };
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \ #define MKL_FLOAT(isa, block) \
template <> \ template <> \
void VMulKernelImpl<float, isa, block>::Compute(const int n, const float* x, \ void VMulKernelImpl<float, isa, block>::Compute( \
const float* y, float* z) { \ const int n, const float* x, const float* y, float* z) const { \
platform::dynload::vsMul(n, x, y, z); \ platform::dynload::vsMul(n, x, y, z); \
} }
#define MKL_DOUBLE(isa, block) \ #define MKL_DOUBLE(isa, block) \
template <> \ template <> \
void VMulKernelImpl<double, isa, block>::Compute( \ void VMulKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, const double* y, double* z) { \ const int n, const double* x, const double* y, double* z) const { \
platform::dynload::vdMul(n, x, y, z); \ platform::dynload::vdMul(n, x, y, z); \
} }
FOR_EACH_ISA(MKL_FLOAT, kGT16); FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE); FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#endif #endif
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
void VMulKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \ void VMulKernelImpl<float, isa, kEQ8>::Compute( \
const float* y, float* z) { \ const int n, const float* x, const float* y, float* z) const { \
__m256 tmpx, tmpy; \ __m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \ tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \ tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_mul_ps(tmpx, tmpy); \ tmpx = _mm256_mul_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \ _mm256_storeu_ps(z, tmpx); \
} }
// avx > for > mkl // avx > for > mkl
...@@ -90,7 +90,7 @@ INTRI8_FLOAT(jit::avx512f); ...@@ -90,7 +90,7 @@ INTRI8_FLOAT(jit::avx512f);
template <typename T, platform::jit::cpu_isa_t isa, jit_block> template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VAddKernelImpl : public VAddKernel<T> { class VAddKernelImpl : public VAddKernel<T> {
public: public:
void Compute(const int n, const T* x, const T* y, T* z) override { void Compute(const int n, const T* x, const T* y, T* z) const override {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i]; z[i] = x[i] + y[i];
} }
...@@ -98,33 +98,33 @@ class VAddKernelImpl : public VAddKernel<T> { ...@@ -98,33 +98,33 @@ class VAddKernelImpl : public VAddKernel<T> {
}; };
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \ #define MKL_FLOAT(isa, block) \
template <> \ template <> \
void VAddKernelImpl<float, isa, block>::Compute(const int n, const float* x, \ void VAddKernelImpl<float, isa, block>::Compute( \
const float* y, float* z) { \ const int n, const float* x, const float* y, float* z) const { \
platform::dynload::vsAdd(n, x, y, z); \ platform::dynload::vsAdd(n, x, y, z); \
} }
#define MKL_DOUBLE(isa, block) \ #define MKL_DOUBLE(isa, block) \
template <> \ template <> \
void VAddKernelImpl<double, isa, block>::Compute( \ void VAddKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, const double* y, double* z) { \ const int n, const double* x, const double* y, double* z) const { \
platform::dynload::vdAdd(n, x, y, z); \ platform::dynload::vdAdd(n, x, y, z); \
} }
FOR_EACH_ISA(MKL_FLOAT, kGT16); FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE); FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#endif #endif
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
void VAddKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \ void VAddKernelImpl<float, isa, kEQ8>::Compute( \
const float* y, float* z) { \ const int n, const float* x, const float* y, float* z) const { \
__m256 tmpx, tmpy; \ __m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \ tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \ tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_add_ps(tmpx, tmpy); \ tmpx = _mm256_add_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \ _mm256_storeu_ps(z, tmpx); \
} }
#ifdef __AVX__ #ifdef __AVX__
INTRI8_FLOAT(jit::avx); INTRI8_FLOAT(jit::avx);
...@@ -145,12 +145,12 @@ INTRI8_FLOAT(jit::avx512f); ...@@ -145,12 +145,12 @@ INTRI8_FLOAT(jit::avx512f);
template <typename T, platform::jit::cpu_isa_t isa, jit_block> template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VScalKernelImpl : public VScalKernel<T> { class VScalKernelImpl : public VScalKernel<T> {
public: public:
void Compute(const int n, const T a, const T* x, T* y) override { void Compute(const int n, const T a, const T* x, T* y) const override {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
y[i] = a * x[i]; y[i] = a * x[i];
} }
} }
void Compute(const int n, const T a, T* x) override { void Compute(const int n, const T a, T* x) const override {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
x[i] = a * x[i]; x[i] = a * x[i];
} }
...@@ -161,35 +161,35 @@ class VScalKernelImpl : public VScalKernel<T> { ...@@ -161,35 +161,35 @@ class VScalKernelImpl : public VScalKernel<T> {
#define MKL_FLOAT(isa, block) \ #define MKL_FLOAT(isa, block) \
template <> \ template <> \
void VScalKernelImpl<float, isa, block>::Compute(const int n, const float a, \ void VScalKernelImpl<float, isa, block>::Compute(const int n, const float a, \
float* x) { \ float* x) const { \
platform::dynload::cblas_sscal(n, a, x, 1); \ platform::dynload::cblas_sscal(n, a, x, 1); \
} }
#define MKL_DOUBLE(isa, block) \ #define MKL_DOUBLE(isa, block) \
template <> \ template <> \
void VScalKernelImpl<double, isa, block>::Compute( \ void VScalKernelImpl<double, isa, block>::Compute( \
const int n, const double a, double* x) { \ const int n, const double a, double* x) const { \
platform::dynload::cblas_dscal(n, a, x, 1); \ platform::dynload::cblas_dscal(n, a, x, 1); \
} }
FOR_EACH_ISA(MKL_FLOAT, kGT16); FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE); FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#endif #endif
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a, \ void VScalKernelImpl<float, isa, kEQ8>::Compute( \
const float* x, float* y) { \ const int n, const float a, const float* x, float* y) const { \
__m256 tmp; \ __m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \ __m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \ tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \ tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(y, tmp); \ _mm256_storeu_ps(y, tmp); \
} }
#define INTRI8_INPLACE_FLOAT(isa) \ #define INTRI8_INPLACE_FLOAT(isa) \
template <> \ template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a, \ void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a, \
float* x) { \ float* x) const { \
__m256 tmp; \ __m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \ __m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \ tmp = _mm256_loadu_ps(x); \
......
...@@ -34,14 +34,13 @@ __m256 Exp(__m256 a); ...@@ -34,14 +34,13 @@ __m256 Exp(__m256 a);
#endif #endif
namespace jitkernel { namespace jitkernel {
namespace jit = platform::jit; namespace jit = platform::jit;
/* VExp JitKernel */ /* VExp JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block> template <typename T, jit::cpu_isa_t isa, jit_block>
class VExpKernelImpl : public VExpKernel<T> { class VExpKernelImpl : public VExpKernel<T> {
public: public:
void Compute(const int n, const T* x, T* y) override { void Compute(const int n, const T* x, T* y) const override {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]); y[i] = std::exp(x[i]);
} }
...@@ -52,15 +51,15 @@ class VExpKernelImpl : public VExpKernel<T> { ...@@ -52,15 +51,15 @@ class VExpKernelImpl : public VExpKernel<T> {
#define MKL_FLOAT(isa, block) \ #define MKL_FLOAT(isa, block) \
template <> \ template <> \
void VExpKernelImpl<float, isa, block>::Compute(const int n, const float* x, \ void VExpKernelImpl<float, isa, block>::Compute(const int n, const float* x, \
float* y) { \ float* y) const { \
platform::dynload::vsExp(n, x, y); \ platform::dynload::vsExp(n, x, y); \
} }
#define MKL_DOUBLE(isa, block) \ #define MKL_DOUBLE(isa, block) \
template <> \ template <> \
void VExpKernelImpl<double, isa, block>::Compute( \ void VExpKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, double* y) { \ const int n, const double* x, double* y) const { \
platform::dynload::vdExp(n, x, y); \ platform::dynload::vdExp(n, x, y); \
} }
FOR_EACH_ISA(MKL_FLOAT, kLT8); FOR_EACH_ISA(MKL_FLOAT, kLT8);
FOR_EACH_ISA(MKL_FLOAT, kGT8LT16); FOR_EACH_ISA(MKL_FLOAT, kGT8LT16);
...@@ -71,7 +70,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE); ...@@ -71,7 +70,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
void VExpKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \ void VExpKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
float* y) { \ float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \ __m256 tmp = _mm256_loadu_ps(x); \
_mm256_storeu_ps(y, detail::Exp(tmp)); \ _mm256_storeu_ps(y, detail::Exp(tmp)); \
} }
...@@ -79,7 +78,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE); ...@@ -79,7 +78,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#define INTRI16_FLOAT(isa) \ #define INTRI16_FLOAT(isa) \
template <> \ template <> \
void VExpKernelImpl<float, isa, kEQ16>::Compute(const int n, const float* x, \ void VExpKernelImpl<float, isa, kEQ16>::Compute(const int n, const float* x, \
float* y) { \ float* y) const { \
__m256 tmp0 = _mm256_loadu_ps(x); \ __m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \ __m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = detail::Exp(tmp0); \ tmp0 = detail::Exp(tmp0); \
...@@ -109,6 +108,38 @@ INTRI16_FLOAT(jit::avx512f); ...@@ -109,6 +108,38 @@ INTRI16_FLOAT(jit::avx512f);
REGISTER_JITKERNEL(vexp, VExpKernel); REGISTER_JITKERNEL(vexp, VExpKernel);
/* VSigmoid JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class VSigmoidKernelImpl : public VSigmoidKernel<T> {
public:
explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() {
vexp_ = KernelPool::Instance().template Get<VExpKernel<T>>(d);
}
void Compute(const int n, const T* x, T* y) const override {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i];
}
vexp_->Compute(n, y, y);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
}
}
private:
std::shared_ptr<const VExpKernel<T>> vexp_;
};
#define JITKERNEL_NEW_ACT_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(d))
REGISTER_JITKERNEL_ARGS(vsigmoid, VSigmoidKernel, JITKERNEL_DECLARE,
JITKERNEL_KEY, JITKERNEL_NEW_ACT_IMPL);
#undef JITKERNEL_NEW_ACT_IMPL
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -23,51 +23,68 @@ namespace jitkernel { ...@@ -23,51 +23,68 @@ namespace jitkernel {
namespace jit = platform::jit; namespace jit = platform::jit;
#define NEW_JITKERNEL_IMPL(src, t, isa, k) \ #define SEARCH_BLOCK(macro_, ker, dtype, isa) \
p = std::dynamic_pointer_cast<src<t>>( \
std::make_shared<src##Impl<t, isa, k>>())
#define SEARCH_BLOCK(src, t, isa) \
if (d < AVX_FLOAT_BLOCK) { \ if (d < AVX_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kLT8); \ macro_(ker, dtype, isa, kLT8); \
} else if (d == AVX_FLOAT_BLOCK) { \ } else if (d == AVX_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kEQ8); \ macro_(ker, dtype, isa, kEQ8); \
} else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \ } else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kGT8LT16); \ macro_(ker, dtype, isa, kGT8LT16); \
} else if (d == AVX512_FLOAT_BLOCK) { \ } else if (d == AVX512_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kEQ16); \ macro_(ker, dtype, isa, kEQ16); \
} else { \ } else { \
NEW_JITKERNEL_IMPL(src, t, isa, kGT16); \ macro_(ker, dtype, isa, kGT16); \
} }
#define SEARCH_ISA_BLOCK(src, t) \ #define SEARCH_ISA_BLOCK(macro_, ker, dtype) \
if (jit::MayIUse(jit::avx512f)) { \ if (jit::MayIUse(jit::avx512f)) { \
SEARCH_BLOCK(src, t, jit::avx512f); \ SEARCH_BLOCK(macro_, ker, dtype, jit::avx512f); \
} else if (jit::MayIUse(jit::avx2)) { \ } else if (jit::MayIUse(jit::avx2)) { \
SEARCH_BLOCK(src, t, jit::avx2); \ SEARCH_BLOCK(macro_, ker, dtype, jit::avx2); \
} else if (jit::MayIUse(jit::avx)) { \ } else if (jit::MayIUse(jit::avx)) { \
SEARCH_BLOCK(src, t, jit::avx); \ SEARCH_BLOCK(macro_, ker, dtype, jit::avx); \
} else { \ } else { \
SEARCH_BLOCK(src, t, jit::isa_any); \ SEARCH_BLOCK(macro_, ker, dtype, jit::isa_any); \
} }
#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key) \ #define JITKERNEL_DECLARE(ker_class, ker_dtype) \
template <> \ template <> \
const std::shared_ptr<ker_class<ker_dtype>> \ std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>>(int d) { \ KernelPool::Get<ker_class<ker_dtype>, int>(int d)
std::string key = #ker_key #dtype_key + std::to_string(d); \
if (kers_.find(key) == kers_.end()) { \ #define JITKERNEL_KEY(ker_key, dtype_key) \
std::shared_ptr<ker_class<ker_dtype>> p; \ #ker_key #dtype_key + std::to_string(d)
SEARCH_ISA_BLOCK(ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \ #define JITKERNEL_NEW_IMPL(ker, dtype, isa, k) \
return p; \ p = std::dynamic_pointer_cast<ker<dtype>>( \
} \ std::make_shared<ker##Impl<dtype, isa, k>>())
return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key)); \
#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key, \
marco_declare, macro_key, macro_impl) \
marco_declare(ker_class, ker_dtype) { \
std::string key = macro_key(ker_key, dtype_key); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
SEARCH_ISA_BLOCK(macro_impl, ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
kers_.at(key)); \
} }
#define REGISTER_JITKERNEL(ker_key, ker_class) \ #define REGISTER_JITKERNEL(ker_key, ker_class) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f); \ JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, JITKERNEL_DECLARE, \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d) JITKERNEL_KEY, JITKERNEL_NEW_IMPL); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, JITKERNEL_DECLARE, \
JITKERNEL_KEY, JITKERNEL_NEW_IMPL)
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_declare, macro_key, \
macro_impl) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, marco_declare, macro_key, \
macro_impl); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, marco_declare, \
macro_key, macro_impl)
#define FOR_EACH_ISA(macro_, block) \ #define FOR_EACH_ISA(macro_, block) \
macro_(jit::avx512f, block); \ macro_(jit::avx512f, block); \
......
...@@ -388,16 +388,16 @@ TEST(JitKernel, pool) { ...@@ -388,16 +388,16 @@ TEST(JitKernel, pool) {
const auto& pvmul_f = const auto& pvmul_f =
jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(4); jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(4);
EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(plstm2) != EXPECT_TRUE(std::dynamic_pointer_cast<const jit::Kernel>(plstm2) !=
std::dynamic_pointer_cast<jit::Kernel>(pvmul_f)); std::dynamic_pointer_cast<const jit::Kernel>(pvmul_f));
const auto& pvmul_d = const auto& pvmul_d =
jit::KernelPool::Instance().template Get<jit::VMulKernel<double>>(4); jit::KernelPool::Instance().template Get<jit::VMulKernel<double>>(4);
EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(pvmul_f) != EXPECT_TRUE(std::dynamic_pointer_cast<const jit::Kernel>(pvmul_f) !=
std::dynamic_pointer_cast<jit::Kernel>(pvmul_d)); std::dynamic_pointer_cast<const jit::Kernel>(pvmul_d));
const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulf4"); const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulf4");
EXPECT_TRUE(pvmul_f == pvmul_from_key); EXPECT_EQ(pvmul_f, pvmul_from_key);
const auto& pvmul_from_key2 = jit::KernelPool::Instance().Get("vmulf5"); const auto& pvmul_from_key2 = jit::KernelPool::Instance().Get("vmulf5");
EXPECT_TRUE(pvmul_from_key2 == nullptr); EXPECT_TRUE(pvmul_from_key2 == nullptr);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册