From b9acbcc8c525fba28a14c6a04640950a96c65bd1 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 18 Sep 2018 00:27:41 +0800 Subject: [PATCH] init lstm kernel --- paddle/fluid/operators/math/jit_kernel.cc | 40 ++++++++++++++++++- paddle/fluid/operators/math/jit_kernel.h | 27 +++++++++++-- paddle/fluid/operators/math/jit_kernel_impl.h | 7 +--- 3 files changed, 63 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/math/jit_kernel.cc b/paddle/fluid/operators/math/jit_kernel.cc index 83fb1b38b7c..452a79e4907 100644 --- a/paddle/fluid/operators/math/jit_kernel.cc +++ b/paddle/fluid/operators/math/jit_kernel.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/jit_kernel.h" +#include #include +#include "paddle/fluid/operators/math/cpu_vec.h" +#include "paddle/fluid/platform/cpu_info.h" namespace paddle { namespace operators { @@ -25,13 +28,48 @@ KernelPool& KernelPool::Instance() { return g_jit_kernels; } +template <> +LSTMKernel::LSTMKernel(int d, const std::string& act_gate_str, + const std::string& act_cand_str, + const std::string& act_cell_str) + : Kernel(), d_(d) { + if (platform::jit::MayIUse(platform::jit::avx512_common)) { + math::VecActivations act_functor; + act_gate_ = act_functor(act_gate_str); + act_cell_ = act_functor(act_cell_str); + act_cand_ = act_functor(act_cand_str); + } else if (platform::jit::MayIUse(platform::jit::avx2)) { + math::VecActivations act_functor; + act_gate_ = act_functor(act_gate_str); + act_cell_ = act_functor(act_cell_str); + act_cand_ = act_functor(act_cand_str); + } else if (platform::jit::MayIUse(platform::jit::avx)) { + math::VecActivations act_functor; + act_gate_ = act_functor(act_gate_str); + act_cell_ = act_functor(act_cell_str); + act_cand_ = act_functor(act_cand_str); + } else { + math::VecActivations act_functor; + act_gate_ = act_functor(act_gate_str); + act_cell_ = act_functor(act_cell_str); + act_cand_ = act_functor(act_cand_str); + } +} + template <> const std::shared_ptr> KernelPool::Get, int, const std::string&, const std::string&, const std::string&>(int d, const std::string& act_gate, const std::string& act_cand, const std::string& act_cell) { - return nullptr; + std::string key = "f" + std::to_string(d) + act_gate + act_cand + act_cell; + if (kers_.find(key) == kers_.end()) { + auto p = + std::make_shared>(d, act_gate, act_cand, act_cell); + kers_.insert({key, std::dynamic_pointer_cast(p)}); + return p; + } + return std::dynamic_pointer_cast>(kers_.at(key)); } } // namespace jitkernel diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index cfe4e8b0788..29aac71060f 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -14,10 +14,9 @@ limitations under the License. */ #pragma once #include -#include #include // for shared_ptr #include -#include +#include #include "paddle/fluid/platform/macros.h" // Note: Only support on CPU yet. @@ -27,23 +26,43 @@ namespace math { namespace jitkernel { class Kernel { + public: + Kernel() {} + virtual ~Kernel() = default; + + private: DISABLE_COPY_AND_ASSIGN(Kernel); }; class KernelPool { public: - static KernelPool &Instance(); + static KernelPool& Instance(); template const std::shared_ptr Get(ARGS... args); private: KernelPool() = default; - // std::unordered_map kers_; + std::unordered_map> kers_; DISABLE_COPY_AND_ASSIGN(KernelPool); }; +template +class LSTMKernel : public Kernel { + public: + explicit LSTMKernel(int d, const std::string& act_gate, + const std::string& act_cand, const std::string& act_cell); + + void ComputeCtHt(T* gates, const T* ct_1, T* ct); + void ComputeCtHt_NoC0H0(T* gates, const T* ct_1, T* ct); + + private: + int d_; + std::function act_gate_, act_cell_, + act_cand_; +}; + } // namespace jitkernel } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/jit_kernel_impl.h b/paddle/fluid/operators/math/jit_kernel_impl.h index 9c11143da6d..46fef31ff03 100644 --- a/paddle/fluid/operators/math/jit_kernel_impl.h +++ b/paddle/fluid/operators/math/jit_kernel_impl.h @@ -21,12 +21,7 @@ limitations under the License. */ namespace paddle { namespace operators { namespace math { -namespace jitkernel { - -template -class LSTMKernel : public Kernel {}; - -} // namespace jitkernel +namespace jitkernel {} // namespace jitkernel } // namespace math } // namespace operators } // namespace paddle -- GitLab