提交 55e44761 编写于 作者: T tensor-tang

refine code and init vsigmoid

上级 2d0ff6a3
......@@ -28,7 +28,7 @@ KernelPool& KernelPool::Instance() {
return g_jit_kernels;
}
const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const {
std::shared_ptr<const Kernel> KernelPool::Get(const std::string& key) const {
if (kers_.find(key) == kers_.end()) {
return nullptr;
}
......@@ -36,7 +36,7 @@ const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const {
}
template <>
const std::shared_ptr<LSTMKernel<float>>
std::shared_ptr<const LSTMKernel<float>>
KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
const std::string&>(int d, const std::string& act_gate,
const std::string& act_cand,
......@@ -49,7 +49,7 @@ KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
return p;
}
return std::dynamic_pointer_cast<LSTMKernel<float>>(kers_.at(key));
return std::dynamic_pointer_cast<const LSTMKernel<float>>(kers_.at(key));
}
} // namespace jitkernel
......
......@@ -52,13 +52,13 @@ class KernelPool {
static KernelPool &Instance();
template <typename Ker, typename... ARGS>
const std::shared_ptr<Ker> Get(ARGS... args);
std::shared_ptr<const Ker> Get(ARGS... args);
const std::shared_ptr<Kernel> Get(const std::string &key) const;
std::shared_ptr<const Kernel> Get(const std::string &key) const;
private:
KernelPool() = default;
std::unordered_map<std::string, std::shared_ptr<Kernel>> kers_;
std::unordered_map<std::string, std::shared_ptr<const Kernel>> kers_;
DISABLE_COPY_AND_ASSIGN(KernelPool);
};
......@@ -66,26 +66,38 @@ class KernelPool {
template <typename T>
class VMulKernel : public Kernel {
public:
virtual void Compute(const int n, const T *x, const T *y, T *z) = 0;
virtual void Compute(const int n, const T *x, const T *y, T *z) const = 0;
};
template <typename T>
class VAddKernel : public Kernel {
public:
virtual void Compute(const int n, const T *x, const T *y, T *z) = 0;
virtual void Compute(const int n, const T *x, const T *y, T *z) const = 0;
};
template <typename T>
class VScalKernel : public Kernel {
public:
virtual void Compute(const int n, const T a, const T *x, T *y) = 0;
virtual void Compute(const int n, const T a, T *x) = 0;
virtual void Compute(const int n, const T a, const T *x, T *y) const = 0;
virtual void Compute(const int n, const T a, T *x) const = 0;
};
template <typename T>
class VExpKernel : public Kernel {
public:
virtual void Compute(const int n, const T *x, T *y) = 0;
virtual void Compute(const int n, const T *x, T *y) const = 0;
};
template <typename T>
class VSigmoidKernel : public Kernel {
public:
virtual void Compute(const int n, const T *x, T *y) const = 0;
};
template <typename T>
class VTanhKernel : public Kernel {
public:
virtual void Compute(const int n, const T *x, T *y) const = 0;
};
template <typename T>
......
......@@ -34,7 +34,7 @@ namespace jit = platform::jit;
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VMulKernelImpl : public VMulKernel<T> {
public:
void Compute(const int n, const T* x, const T* y, T* z) override {
void Compute(const int n, const T* x, const T* y, T* z) const override {
for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i];
}
......@@ -42,33 +42,33 @@ class VMulKernelImpl : public VMulKernel<T> {
};
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \
template <> \
void VMulKernelImpl<float, isa, block>::Compute(const int n, const float* x, \
const float* y, float* z) { \
platform::dynload::vsMul(n, x, y, z); \
#define MKL_FLOAT(isa, block) \
template <> \
void VMulKernelImpl<float, isa, block>::Compute( \
const int n, const float* x, const float* y, float* z) const { \
platform::dynload::vsMul(n, x, y, z); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VMulKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, const double* y, double* z) { \
platform::dynload::vdMul(n, x, y, z); \
#define MKL_DOUBLE(isa, block) \
template <> \
void VMulKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, const double* y, double* z) const { \
platform::dynload::vdMul(n, x, y, z); \
}
FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#endif
#define INTRI8_FLOAT(isa) \
template <> \
void VMulKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
const float* y, float* z) { \
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_mul_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
#define INTRI8_FLOAT(isa) \
template <> \
void VMulKernelImpl<float, isa, kEQ8>::Compute( \
const int n, const float* x, const float* y, float* z) const { \
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_mul_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
}
// avx > for > mkl
......@@ -90,7 +90,7 @@ INTRI8_FLOAT(jit::avx512f);
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VAddKernelImpl : public VAddKernel<T> {
public:
void Compute(const int n, const T* x, const T* y, T* z) override {
void Compute(const int n, const T* x, const T* y, T* z) const override {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
}
......@@ -98,33 +98,33 @@ class VAddKernelImpl : public VAddKernel<T> {
};
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \
template <> \
void VAddKernelImpl<float, isa, block>::Compute(const int n, const float* x, \
const float* y, float* z) { \
platform::dynload::vsAdd(n, x, y, z); \
#define MKL_FLOAT(isa, block) \
template <> \
void VAddKernelImpl<float, isa, block>::Compute( \
const int n, const float* x, const float* y, float* z) const { \
platform::dynload::vsAdd(n, x, y, z); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VAddKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, const double* y, double* z) { \
platform::dynload::vdAdd(n, x, y, z); \
#define MKL_DOUBLE(isa, block) \
template <> \
void VAddKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, const double* y, double* z) const { \
platform::dynload::vdAdd(n, x, y, z); \
}
FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#endif
#define INTRI8_FLOAT(isa) \
template <> \
void VAddKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
const float* y, float* z) { \
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_add_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
#define INTRI8_FLOAT(isa) \
template <> \
void VAddKernelImpl<float, isa, kEQ8>::Compute( \
const int n, const float* x, const float* y, float* z) const { \
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_add_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
......@@ -145,12 +145,12 @@ INTRI8_FLOAT(jit::avx512f);
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VScalKernelImpl : public VScalKernel<T> {
public:
void Compute(const int n, const T a, const T* x, T* y) override {
void Compute(const int n, const T a, const T* x, T* y) const override {
for (int i = 0; i < n; ++i) {
y[i] = a * x[i];
}
}
void Compute(const int n, const T a, T* x) override {
void Compute(const int n, const T a, T* x) const override {
for (int i = 0; i < n; ++i) {
x[i] = a * x[i];
}
......@@ -161,35 +161,35 @@ class VScalKernelImpl : public VScalKernel<T> {
#define MKL_FLOAT(isa, block) \
template <> \
void VScalKernelImpl<float, isa, block>::Compute(const int n, const float a, \
float* x) { \
float* x) const { \
platform::dynload::cblas_sscal(n, a, x, 1); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VScalKernelImpl<double, isa, block>::Compute( \
const int n, const double a, double* x) { \
platform::dynload::cblas_dscal(n, a, x, 1); \
#define MKL_DOUBLE(isa, block) \
template <> \
void VScalKernelImpl<double, isa, block>::Compute( \
const int n, const double a, double* x) const { \
platform::dynload::cblas_dscal(n, a, x, 1); \
}
FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#endif
#define INTRI8_FLOAT(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a, \
const float* x, float* y) { \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(y, tmp); \
#define INTRI8_FLOAT(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute( \
const int n, const float a, const float* x, float* y) const { \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI8_INPLACE_FLOAT(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a, \
float* x) { \
float* x) const { \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
......
......@@ -34,14 +34,13 @@ __m256 Exp(__m256 a);
#endif
namespace jitkernel {
namespace jit = platform::jit;
/* VExp JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class VExpKernelImpl : public VExpKernel<T> {
public:
void Compute(const int n, const T* x, T* y) override {
void Compute(const int n, const T* x, T* y) const override {
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
......@@ -52,15 +51,15 @@ class VExpKernelImpl : public VExpKernel<T> {
#define MKL_FLOAT(isa, block) \
template <> \
void VExpKernelImpl<float, isa, block>::Compute(const int n, const float* x, \
float* y) { \
float* y) const { \
platform::dynload::vsExp(n, x, y); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VExpKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, double* y) { \
platform::dynload::vdExp(n, x, y); \
#define MKL_DOUBLE(isa, block) \
template <> \
void VExpKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, double* y) const { \
platform::dynload::vdExp(n, x, y); \
}
FOR_EACH_ISA(MKL_FLOAT, kLT8);
FOR_EACH_ISA(MKL_FLOAT, kGT8LT16);
......@@ -71,7 +70,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#define INTRI8_FLOAT(isa) \
template <> \
void VExpKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
float* y) { \
float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
_mm256_storeu_ps(y, detail::Exp(tmp)); \
}
......@@ -79,7 +78,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#define INTRI16_FLOAT(isa) \
template <> \
void VExpKernelImpl<float, isa, kEQ16>::Compute(const int n, const float* x, \
float* y) { \
float* y) const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = detail::Exp(tmp0); \
......@@ -109,6 +108,38 @@ INTRI16_FLOAT(jit::avx512f);
REGISTER_JITKERNEL(vexp, VExpKernel);
/* VSigmoid JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class VSigmoidKernelImpl : public VSigmoidKernel<T> {
public:
explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() {
vexp_ = KernelPool::Instance().template Get<VExpKernel<T>>(d);
}
void Compute(const int n, const T* x, T* y) const override {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i];
}
vexp_->Compute(n, y, y);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
}
}
private:
std::shared_ptr<const VExpKernel<T>> vexp_;
};
#define JITKERNEL_NEW_ACT_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(d))
REGISTER_JITKERNEL_ARGS(vsigmoid, VSigmoidKernel, JITKERNEL_DECLARE,
JITKERNEL_KEY, JITKERNEL_NEW_ACT_IMPL);
#undef JITKERNEL_NEW_ACT_IMPL
} // namespace jitkernel
} // namespace math
} // namespace operators
......
......@@ -23,51 +23,68 @@ namespace jitkernel {
namespace jit = platform::jit;
#define NEW_JITKERNEL_IMPL(src, t, isa, k) \
p = std::dynamic_pointer_cast<src<t>>( \
std::make_shared<src##Impl<t, isa, k>>())
#define SEARCH_BLOCK(src, t, isa) \
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d < AVX_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kLT8); \
macro_(ker, dtype, isa, kLT8); \
} else if (d == AVX_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kEQ8); \
macro_(ker, dtype, isa, kEQ8); \
} else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kGT8LT16); \
macro_(ker, dtype, isa, kGT8LT16); \
} else if (d == AVX512_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kEQ16); \
macro_(ker, dtype, isa, kEQ16); \
} else { \
NEW_JITKERNEL_IMPL(src, t, isa, kGT16); \
macro_(ker, dtype, isa, kGT16); \
}
#define SEARCH_ISA_BLOCK(src, t) \
if (jit::MayIUse(jit::avx512f)) { \
SEARCH_BLOCK(src, t, jit::avx512f); \
} else if (jit::MayIUse(jit::avx2)) { \
SEARCH_BLOCK(src, t, jit::avx2); \
} else if (jit::MayIUse(jit::avx)) { \
SEARCH_BLOCK(src, t, jit::avx); \
} else { \
SEARCH_BLOCK(src, t, jit::isa_any); \
#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \
if (jit::MayIUse(jit::avx512f)) { \
SEARCH_BLOCK(macro_, ker, dtype, jit::avx512f); \
} else if (jit::MayIUse(jit::avx2)) { \
SEARCH_BLOCK(macro_, ker, dtype, jit::avx2); \
} else if (jit::MayIUse(jit::avx)) { \
SEARCH_BLOCK(macro_, ker, dtype, jit::avx); \
} else { \
SEARCH_BLOCK(macro_, ker, dtype, jit::isa_any); \
}
#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key) \
template <> \
const std::shared_ptr<ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>>(int d) { \
std::string key = #ker_key #dtype_key + std::to_string(d); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
SEARCH_ISA_BLOCK(ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key)); \
#define JITKERNEL_DECLARE(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>, int>(int d)
#define JITKERNEL_KEY(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d)
#define JITKERNEL_NEW_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>())
#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key, \
marco_declare, macro_key, macro_impl) \
marco_declare(ker_class, ker_dtype) { \
std::string key = macro_key(ker_key, dtype_key); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
SEARCH_ISA_BLOCK(macro_impl, ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
kers_.at(key)); \
}
#define REGISTER_JITKERNEL(ker_key, ker_class) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d)
#define REGISTER_JITKERNEL(ker_key, ker_class) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, JITKERNEL_DECLARE, \
JITKERNEL_KEY, JITKERNEL_NEW_IMPL); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, JITKERNEL_DECLARE, \
JITKERNEL_KEY, JITKERNEL_NEW_IMPL)
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_declare, macro_key, \
macro_impl) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, marco_declare, macro_key, \
macro_impl); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, marco_declare, \
macro_key, macro_impl)
#define FOR_EACH_ISA(macro_, block) \
macro_(jit::avx512f, block); \
......
......@@ -388,16 +388,16 @@ TEST(JitKernel, pool) {
const auto& pvmul_f =
jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(4);
EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(plstm2) !=
std::dynamic_pointer_cast<jit::Kernel>(pvmul_f));
EXPECT_TRUE(std::dynamic_pointer_cast<const jit::Kernel>(plstm2) !=
std::dynamic_pointer_cast<const jit::Kernel>(pvmul_f));
const auto& pvmul_d =
jit::KernelPool::Instance().template Get<jit::VMulKernel<double>>(4);
EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(pvmul_f) !=
std::dynamic_pointer_cast<jit::Kernel>(pvmul_d));
EXPECT_TRUE(std::dynamic_pointer_cast<const jit::Kernel>(pvmul_f) !=
std::dynamic_pointer_cast<const jit::Kernel>(pvmul_d));
const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulf4");
EXPECT_TRUE(pvmul_f == pvmul_from_key);
EXPECT_EQ(pvmul_f, pvmul_from_key);
const auto& pvmul_from_key2 = jit::KernelPool::Instance().Get("vmulf5");
EXPECT_TRUE(pvmul_from_key2 == nullptr);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册