diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index abaa9237c09766fbe6bf1c2e52670f72dd4d3463..0ba51012c4f974ccfa99b77194c0babcc0ea8864 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -400,10 +400,9 @@ class FuisonLSTMKernel : public framework::OpKernel { } else { const auto& ker = math::jitkernel::KernelPool::Instance() - .template Get, int, - const std::string&, const std::string&, - const std::string&>(D, act_gate_str, act_cand_str, - act_cell_str); + .template Get, const std::string&, + const std::string&, const std::string&>( + act_gate_str, act_cand_str, act_cell_str, D, false); for (int i = 0; i < N; ++i) { PROCESS_H0C0 @@ -545,10 +544,9 @@ class FuisonLSTMKernel : public framework::OpKernel { } else { const auto& ker = math::jitkernel::KernelPool::Instance() - .template Get, int, - const std::string&, const std::string&, - const std::string&>(D, act_gate_str, act_cand_str, - act_cell_str); + .template Get, const std::string&, + const std::string&, const std::string&>( + act_gate_str, act_cand_str, act_cell_str, D, false); for (int step = tstart; step < max_seq_len; ++step) { const int cur_bs = batch_starts[step + 1] - batch_starts[step]; diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 6edfdf22d174c3e17d201d04798c86d76375c028..aeb439bb8615b21cefe64a8df0be34196036ee6e 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -125,7 +125,8 @@ class VTanhKernel : public VActKernel { template class LSTMKernel : public Kernel { public: - virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht) const = 0; + virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht, + T *checked = nullptr) const = 0; }; } // namespace jitkernel diff --git a/paddle/fluid/operators/math/jit_kernel_lstm.cc b/paddle/fluid/operators/math/jit_kernel_lstm.cc index 71531d833dfac78702b9183b6e9116d10a8c9d43..17e2d1fbb44093bc9e0941fb2e578269398f3573 100644 --- a/paddle/fluid/operators/math/jit_kernel_lstm.cc +++ b/paddle/fluid/operators/math/jit_kernel_lstm.cc @@ -86,9 +86,9 @@ __m256 AVXActImpl::Compute(__m256 x) const { template class LSTMKernelImpl : public LSTMKernel { public: - explicit LSTMKernelImpl(int d, const std::string& act_gate, + explicit LSTMKernelImpl(const std::string& act_gate, const std::string& act_cand, - const std::string& act_cell) + const std::string& act_cell, int d) : LSTMKernel() { d_ = d; d2_ = d * 2; @@ -134,7 +134,8 @@ class LSTMKernelImpl : public LSTMKernel { #endif } - void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht) const override { + void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, + T* checked) const override { // gates: W_ch, W_ih, W_fh, W_oh act_gate_3d_->Compute(gates + d_, gates + d_); @@ -162,7 +163,8 @@ class LSTMKernelImpl : public LSTMKernel { #define INTRI8_FLOAT(isa) \ template <> \ void LSTMKernelImpl::ComputeCtHt( \ - float* gates, const float* ct_1, float* ct, float* ht) const { \ + float* gates, const float* ct_1, float* ct, float* ht, float* checked) \ + const { \ /* gates: W_ch, W_ih, W_fh, W_oh */ \ __m256 c, i, f, o; \ c = _mm256_loadu_ps(gates); \ @@ -192,21 +194,86 @@ INTRI8_FLOAT(jit::avx2); INTRI8_FLOAT(jit::avx512f); #endif -#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \ - template <> \ - std::shared_ptr> \ - KernelPool::Get, int, const std::string&, \ - const std::string&, const std::string&>( \ - int d, const std::string& act_gate, const std::string& act_cand, \ - const std::string& act_cell) +/* Peephole JitKernel */ +template +class PeepholeKernelImpl : public LSTMKernel { + public: + explicit PeepholeKernelImpl(const std::string& act_gate, + const std::string& act_cand, + const std::string& act_cell, int d) + : LSTMKernel() { + d_ = d; + d2_ = d * 2; + d3_ = d * 3; + auto GetActKernel = [&](const std::string& type, + int n) -> std::shared_ptr> { + if (type == "sigmoid") { + return std::dynamic_pointer_cast>( + KernelPool::Instance().template Get>(n)); + } else if (type == "relu") { + return std::dynamic_pointer_cast>( + KernelPool::Instance().template Get>(n)); + } else if (type == "tanh") { + return std::dynamic_pointer_cast>( + KernelPool::Instance().template Get>(n)); + } else if (type == "identity" || type == "") { + return std::dynamic_pointer_cast>( + KernelPool::Instance().template Get>(n)); + } + PADDLE_THROW("Not support type: %s", type); + }; + act_gate_3d_ = GetActKernel(act_gate, d * 3); + act_cand_d_ = GetActKernel(act_cand, d); + act_cell_d_ = GetActKernel(act_cell, d); + vmul_d_ = KernelPool::Instance().template Get>(d); + vadd_d_ = KernelPool::Instance().template Get>(d); + } + + void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, + T* checked) const override { + // gates: W_ch, W_ih, W_fh, W_oh + act_gate_3d_->Compute(gates + d_, gates + d_); + + /* C_t = C_t-1 * fgated + cand_gated * igated */ + act_cand_d_->Compute(gates, gates); + vmul_d_->Compute(gates, gates + d_, gates + d_); + vmul_d_->Compute(ct_1, gates + d2_, gates + d2_); + vadd_d_->Compute(gates + d_, gates + d2_, ct); + + /* H_t = act_cell(C_t) * ogated */ + act_cell_d_->Compute(ct, gates + d2_); + vmul_d_->Compute(gates + d2_, gates + d3_, ht); + } -#define JITKERNEL_KEY_LSTM(ker_key, dtype_key) \ - #ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell + private: + int d_, d2_, d3_; + std::shared_ptr> act_gate_3d_, act_cand_d_, act_cell_d_; + std::shared_ptr> vmul_d_; + std::shared_ptr> vadd_d_; +}; + +#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \ + template <> \ + std::shared_ptr> \ + KernelPool::Get, const std::string&, \ + const std::string&, const std::string&, int, bool>( \ + const std::string& act_gate, const std::string& act_cand, \ + const std::string& act_cell, int d, bool use_peephole) -#define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k) \ - p = std::dynamic_pointer_cast>( \ - std::make_shared>(d, act_gate, act_cand, \ - act_cell)) +#define JITKERNEL_KEY_LSTM(ker_key, dtype_key) \ + #ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell + \ + (use_peephole ? "p" : "n") + +#define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k) \ + if (use_peephole) { \ + p = std::dynamic_pointer_cast>( \ + std::make_shared>( \ + act_gate, act_cand, act_cell, d)); \ + } else { \ + p = std::dynamic_pointer_cast>( \ + std::make_shared>(act_gate, act_cand, \ + act_cell, d)); \ + } REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM, JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL); @@ -215,7 +282,6 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM, #undef JITKERNEL_DECLARE_LSTM #undef JITKERNEL_KEY_LSTM #undef JITKERNEL_NEW_LSTM_IMPL - } // namespace jitkernel } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index d65a3299c5bc4e87f50a58042221b52553533f0c..26590171bbeaa385ac09b04e5faf483924176598 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -390,9 +390,9 @@ TEST(JitKernel, lstm) { std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; const auto& ker = jit::KernelPool::Instance() - .template Get, int, const std::string&, + .template Get, const std::string&, const std::string&, const std::string&>( - d, act_gate, act_cand, act_cell); + act_gate, act_cand, act_cell, d, false); // below kernels are used to compute refer const auto& vsigmoid_3d = jit::KernelPool::Instance().template Get>( @@ -717,15 +717,20 @@ TEST(JitKernel, pool) { std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; const auto& plstm1 = jit::KernelPool::Instance() - .template Get, int, const std::string&, + .template Get, const std::string&, const std::string&, const std::string&>( - frame_size, act_gate, act_cand, act_cell); + act_gate, act_cand, act_cell, frame_size, false); const auto& plstm2 = jit::KernelPool::Instance() - .template Get, int, const std::string&, + .template Get, const std::string&, const std::string&, const std::string&>( - frame_size, act_gate, act_cand, act_cell); - EXPECT_EQ(plstm1, plstm2); + act_gate, act_cand, act_cell, frame_size, false); + const auto& peephole = + jit::KernelPool::Instance() + .template Get, const std::string&, + const std::string&, const std::string&>( + act_gate, act_cand, act_cell, frame_size, true); + EXPECT_TRUE(plstm1 != peephole); const auto& pvmul_f = jit::KernelPool::Instance().template Get>(4);