Add an element-wise multiply JIT kernel (VMulKernel) with ISA / block-size
dispatch, register float and double instances in KernelPool, give LSTMKernel
the d2_/d3_ offsets plus function-object entry points, and extend the pool
unit test to check that pooled kernels are shared and distinct per type.

Notes on this revision of the patch:
  * One diff line per text line; hunk line counts are the same as before.
  * Angle-bracket arguments (#include <...>, template <...>, Foo<...>) were
    absent from the text and are filled in from the surrounding code. The
    ones on unchanged context lines (system includes, VecActivations<...>,
    the act_gate_ std::function signature, KernelPool::Get's template head)
    are assumptions -- TODO confirm against the tree before applying.
  * The MKL specializations are now guarded by PADDLE_WITH_MKLML, the same
    macro that guards the dynload/mklml.h include.
  * DEFINE_VMUL_COMPUTE_DOUBLE takes `double* z` (was `float* z`).
  * FOR_EACH_ISA_BLOCK uses jit::isa_any (was jit::any), matching the
    fallback branch of SEARCH_ISA_BLOCK.
  * `static` is dropped from the explicit specializations; a storage class is
    not allowed there. The primary template keeps it.
  * VMulKernel's constructor parameter is named `d` in the header as well.
  * The blob ids on the `index` lines are carried over unchanged and do not
    describe the corrected post-image.

diff --git a/paddle/fluid/operators/math/jit_kernel.cc b/paddle/fluid/operators/math/jit_kernel.cc
index 452a79e4907f30e5af408d4a2a4f7cb11d770c1d..81b56ef2e8aa083da270171be462cc5a7ba73507 100644
--- a/paddle/fluid/operators/math/jit_kernel.cc
+++ b/paddle/fluid/operators/math/jit_kernel.cc
@@ -16,23 +16,132 @@ limitations under the License. */
 #include <functional>
 #include <string>
 #include "paddle/fluid/operators/math/cpu_vec.h"
-#include "paddle/fluid/platform/cpu_info.h"
+
+#ifdef PADDLE_WITH_MKLML
+#include "paddle/fluid/platform/dynload/mklml.h"
+#endif
+
+#ifdef __AVX__
+#include <immintrin.h>
+#endif
 
 namespace paddle {
 namespace operators {
 namespace math {
 namespace jitkernel {
 
+namespace jit = platform::jit;
+
 KernelPool& KernelPool::Instance() {
   static KernelPool g_jit_kernels;
   return g_jit_kernels;
 }
+#define SEARCH_BLOCK(src, t, isa)                             \
+  if (d < AVX_FLOAT_BLOCK) {                                  \
+    Compute = src<t, isa, kLT8>;                              \
+  } else if (d == AVX_FLOAT_BLOCK) {                          \
+    Compute = src<t, isa, kEQ8>;                              \
+  } else if (d == AVX512_FLOAT_BLOCK) {                       \
+    Compute = src<t, isa, kEQ16>;                             \
+  } else {                                                    \
+    Compute = src<t, isa, kGT16>;                             \
+  }
+
+#define SEARCH_ISA_BLOCK(src, t)                              \
+  if (jit::MayIUse(jit::avx512_common)) {                     \
+    SEARCH_BLOCK(src, t, jit::avx512_common);                 \
+  } else if (jit::MayIUse(jit::avx2)) {                       \
+    SEARCH_BLOCK(src, t, jit::avx2);                          \
+  } else if (jit::MayIUse(jit::avx)) {                        \
+    SEARCH_BLOCK(src, t, jit::avx);                           \
+  } else {                                                    \
+    SEARCH_BLOCK(src, t, jit::isa_any);                       \
+  }
+
+#define FOR_EACH_BLOCK(macro_, isa) \
+  macro_(isa, kLT8) macro_(isa, kEQ8) macro_(isa, kEQ16) macro_(isa, kGT16)
+
+#define FOR_EACH_ISA_BLOCK(macro_)            \
+  FOR_EACH_BLOCK(macro_, jit::avx512_common)  \
+  FOR_EACH_BLOCK(macro_, jit::avx2)           \
+  FOR_EACH_BLOCK(macro_, jit::avx)            \
+  FOR_EACH_BLOCK(macro_, jit::isa_any)
+
+#define VMUL_ANY                  \
+  for (int i = 0; i < n; ++i) {   \
+    z[i] = x[i] * y[i];           \
+  }
+
+template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+static void VMulCompute(const int n, const T* x, const T* y, T* z) {
+  VMUL_ANY
+}
+
+#ifdef PADDLE_WITH_MKLML
+#define DEFINE_VMUL_COMPUTE_FLOAT(isa, block)                         \
+  template <>                                                         \
+  void VMulCompute<float, isa, block>(const int n, const float* x,    \
+                                      const float* y, float* z) {     \
+    platform::dynload::vsMul(n, x, y, z);                             \
+  }
+
+#define DEFINE_VMUL_COMPUTE_DOUBLE(isa, block)                        \
+  template <>                                                         \
+  void VMulCompute<double, isa, block>(const int n, const double* x,  \
+                                       const double* y, double* z) {  \
+    platform::dynload::vdMul(n, x, y, z);                             \
+  }
+
+FOR_EACH_ISA_BLOCK(DEFINE_VMUL_COMPUTE_FLOAT)
+FOR_EACH_ISA_BLOCK(DEFINE_VMUL_COMPUTE_DOUBLE)
+// TODO(TJ): add EQ8
+#endif
+
+#undef DEFINE_VMUL_COMPUTE_FLOAT
+#undef DEFINE_VMUL_COMPUTE_DOUBLE
+#undef VMUL_ANY
+
+template <>
+VMulKernel<float>::VMulKernel(int d) {
+  SEARCH_ISA_BLOCK(VMulCompute, float);
+}
+
+template <>
+VMulKernel<double>::VMulKernel(int d) {
+  SEARCH_ISA_BLOCK(VMulCompute, double);
+}
+
+template <>
+const std::shared_ptr<VMulKernel<float>> KernelPool::Get<VMulKernel<float>>(
+    int d) {
+  std::string key = "f" + std::to_string(d);
+  if (kers_.find(key) == kers_.end()) {
+    auto p = std::make_shared<VMulKernel<float>>(d);
+    kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
+    return p;
+  }
+  return std::dynamic_pointer_cast<VMulKernel<float>>(kers_.at(key));
+}
+
+template <>
+const std::shared_ptr<VMulKernel<double>> KernelPool::Get<VMulKernel<double>>(
+    int d) {
+  std::string key = "d" + std::to_string(d);
+  if (kers_.find(key) == kers_.end()) {
+    auto p = std::make_shared<VMulKernel<double>>(d);
+    kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
+    return p;
+  }
+  return std::dynamic_pointer_cast<VMulKernel<double>>(kers_.at(key));
+}
 
 template <>
 LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str,
                               const std::string& act_cand_str,
                               const std::string& act_cell_str)
     : Kernel(), d_(d) {
+  d2_ = d * 2;
+  d3_ = d * 3;
   if (platform::jit::MayIUse(platform::jit::avx512_common)) {
     math::VecActivations<float, platform::jit::avx512_common> act_functor;
     act_gate_ = act_functor(act_gate_str);
@@ -48,6 +157,22 @@ LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str,
     act_gate_ = act_functor(act_gate_str);
     act_cell_ = act_functor(act_cell_str);
     act_cand_ = act_functor(act_cand_str);
+    // ComputeCtHt = [&](float*gates,const float*ct_1,float*ct, float*ht) {
+    //   // gates: W_ch, W_ih, W_fh, W_oh
+    //   act_gate(d3_, gates + d_, gates + d_);
+
+    //   /* C_t = C_t-1 * fgated + cand_gated * igated */
+    //   act_cand(d_, gates, gates);
+    //   blas.VMUL(d_, gates, gates + d_, gates + d_);
+    //   blas.VMUL(d_, ct_1, gates + d2_, gates + d2_);
+    //   blas.VADD(d_, gates + d_, gates + d2_, ct);
+
+    //   /* H_t = act_cell(C_t) * ogated */
+    //   act_cell(d_, ct, gates + d2_);
+    //   blas.VMUL(d_, gates + d2_, gates + d3_, ht)
+    //   GET_Ct(ct_1, gates, ct);
+    //   GET_Ht(ct, gates, ht);
+    // };
   } else {
     math::VecActivations<float, platform::jit::isa_any> act_functor;
     act_gate_ = act_functor(act_gate_str);
diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h
index 29aac71060f520b6c8c4b3c813f6da0062e1a983..b65653498318625d8abb17e9f846aaa2d56ac4eb 100644
--- a/paddle/fluid/operators/math/jit_kernel.h
+++ b/paddle/fluid/operators/math/jit_kernel.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <memory>  // for shared_ptr
 #include <string>
 #include <unordered_map>
+#include "paddle/fluid/platform/cpu_info.h"
 #include "paddle/fluid/platform/macros.h"
 
 // Note: Only support on CPU yet.
@@ -25,6 +26,18 @@ namespace operators {
 namespace math {
 namespace jitkernel {
 
+#define SIGMOID_THRESHOLD_MIN -40.0
+#define SIGMOID_THRESHOLD_MAX 13.0
+
+#define AVX_FLOAT_BLOCK 8
+#define AVX_DOUBLE_BLOCK 4
+#define AVX2_FLOAT_BLOCK 8
+#define AVX2_DOUBLE_BLOCK 4
+#define AVX512_FLOAT_BLOCK 16
+#define AVX512_DOUBLE_BLOCK 8
+
+typedef enum { kLT8, kEQ8, kEQ16, kGT16 } jit_block;
+
 class Kernel {
  public:
   Kernel() {}
@@ -36,7 +49,7 @@ class Kernel {
 
 class KernelPool {
  public:
-  static KernelPool& Instance();
+  static KernelPool &Instance();
 
   template <typename Ker, typename... ARGS>
   const std::shared_ptr<Ker> Get(ARGS... args);
@@ -48,17 +61,24 @@ class KernelPool {
   DISABLE_COPY_AND_ASSIGN(KernelPool);
 };
 
+template <typename T>
+class VMulKernel : public Kernel {
+ public:
+  explicit VMulKernel(int d);
+  void (*Compute)(const int n, const T *, const T *, T *);
+};
+
 template <typename T>
 class LSTMKernel : public Kernel {
  public:
-  explicit LSTMKernel(int d, const std::string& act_gate,
-                      const std::string& act_cand, const std::string& act_cell);
+  explicit LSTMKernel(int d, const std::string &act_gate,
+                      const std::string &act_cand, const std::string &act_cell);
 
-  void ComputeCtHt(T* gates, const T* ct_1, T* ct);
-  void ComputeCtHt_NoC0H0(T* gates, const T* ct_1, T* ct);
+  void (*jit_ker)(T *, const T *, T *, T *);
+  std::function<void(T *, const T *, T *, T *)> ComputeCtHt, ComputeCtHt_NoC0H0;
 
  private:
-  int d_;
+  int d_, d2_, d3_;
   std::function<void(const int, const T*, T*)> act_gate_, act_cell_, act_cand_;
 };
 
diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc
index 15193f0d940a9abac4f7d1be642001c2674fda1f..041234442d372df8deacb3663cff1fef478bac2c 100644
--- a/paddle/fluid/operators/math/jit_kernel_test.cc
+++ b/paddle/fluid/operators/math/jit_kernel_test.cc
@@ -23,10 +23,25 @@ TEST(JitKernel, pool) {
   namespace jit = paddle::operators::math::jitkernel;
   const int frame_size = 4;
   std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
-  const auto& t =
+  const auto& p1 =
       jit::KernelPool::Instance()
           .template Get<jit::LSTMKernel<float>, int, const std::string&,
                         const std::string&, const std::string&>(
               frame_size, act_gate, act_cand, act_cell);
-  LOG(INFO) << t;
+  const auto& p2 =
+      jit::KernelPool::Instance()
+          .template Get<jit::LSTMKernel<float>, int, const std::string&,
+                        const std::string&, const std::string&>(
+              frame_size, act_gate, act_cand, act_cell);
+  EXPECT_EQ(p1, p2);
+
+  const auto& p3 =
+      jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(4);
+  EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(p2) !=
+              std::dynamic_pointer_cast<jit::Kernel>(p3));
+
+  const auto& p4 =
+      jit::KernelPool::Instance().template Get<jit::VMulKernel<double>>(4);
+  EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(p3) !=
+              std::dynamic_pointer_cast<jit::Kernel>(p4));
 }