diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc index 0959539068eef5b550a8e3997d3f11ea67ae0707..8021a896ceaa808f5754b5d165aa0c8c8cb8034a 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -236,27 +236,31 @@ class FuisonLSTMKernel : public framework::OpKernel { const int D = wh_dims[0]; \ const int D4 = wh_dims[1] -#define INIT_OTHER_DEFINES \ - const T* x_data = x->data(); \ - const T* wx_data = wx->data(); \ - const T* wh_data = wh->data(); \ - /* diagonal weight*/ \ - const T* wp_data = bias->data() + D4; \ - /* for peephole only*/ \ - T* checked_cell_data = nullptr; \ - auto place = ctx.GetPlace(); \ - if (use_peepholes) { \ - /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ - auto* checked_cell = ctx.Output("CheckedCell"); \ - checked_cell_data = checked_cell->mutable_data(place); \ - } \ - const auto& ker = \ - math::jitkernel::KernelPool::Instance() \ - .template Get, const std::string&, \ - const std::string&, const std::string&>( \ - ctx.Attr("gate_activation"), \ - ctx.Attr("candidate_activation"), \ - ctx.Attr("cell_activation"), D, use_peepholes) +#define INIT_OTHER_DEFINES \ + const T* x_data = x->data(); \ + const T* wx_data = wx->data(); \ + const T* wh_data = wh->data(); \ + /* diagonal weight*/ \ + const T* wp_data = bias->data() + D4; \ + /* for peephole only*/ \ + T* checked_cell_data = nullptr; \ + auto place = ctx.GetPlace(); \ + if (use_peepholes) { \ + /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ + auto* checked_cell = ctx.Output("CheckedCell"); \ + checked_cell_data = checked_cell->mutable_data(place); \ + } \ + const math::jitkernel::lstm_attr_t attr( \ + D, ctx.Attr("gate_activation"), \ + ctx.Attr("candidate_activation"), \ + ctx.Attr("cell_activation"), use_peepholes); \ + math::jitkernel::lstm_t one_step; \ + one_step.wp = wp_data; \ + one_step.checked = checked_cell_data; \ + const auto& ker = \ + math::jitkernel::KernelPool::Instance() \ + .template Get, \ + const math::jitkernel::lstm_attr_t&>(attr) // Wh GEMM #define GEMM_WH_ADDON(bs, prev, out) \ @@ -299,7 +303,10 @@ class FuisonLSTMKernel : public framework::OpKernel { prev_h_data = h0_data + bid * D; prev_c_data = c0_data + bid * D; } else { - ker->ComputeC1H1(xx_data, c_out_data, h_out_data, wp_data); + one_step.gates = xx_data; + one_step.ct = c_out_data; + one_step.ht = h_out_data; + ker->ComputeC1H1(&one_step, &attr); tstart = 1; // move one step prev_h_data = h_out_data; @@ -310,8 +317,12 @@ class FuisonLSTMKernel : public framework::OpKernel { } for (int step = tstart; step < seq_len; ++step) { GEMM_WH_ADDON(1, prev_h_data, xx_data); - ker->ComputeCtHt(xx_data, prev_c_data, c_out_data, h_out_data, wp_data, - checked_cell_data); + + one_step.gates = xx_data; + one_step.ct_1 = prev_c_data; + one_step.ct = c_out_data; + one_step.ht = h_out_data; + ker->ComputeCtHt(&one_step, &attr); // move one step prev_h_data = h_out_data; prev_c_data = c_out_data; @@ -388,7 +399,11 @@ class FuisonLSTMKernel : public framework::OpKernel { T* cur_h_out_data = batched_h_out_data; T* cur_c_out_data = batched_c_out_data; for (int i = 0; i < max_bs; ++i) { - ker->ComputeC1H1(cur_in_data, cur_c_out_data, cur_h_out_data, wp_data); + one_step.gates = cur_in_data; + one_step.ct = cur_c_out_data; + one_step.ht = cur_h_out_data; + ker->ComputeC1H1(&one_step, &attr); + cur_in_data += D4; cur_c_out_data += D; cur_h_out_data += D; @@ -413,8 +428,12 @@ class FuisonLSTMKernel : public framework::OpKernel { T* cur_c_out_data = batched_c_out_data; T* cur_h_out_data = batched_h_out_data; for (int i = 0; i < cur_bs; ++i) { - ker->ComputeCtHt(cur_in_data, cur_prev_c_data, cur_c_out_data, - cur_h_out_data, wp_data, checked_cell_data); + one_step.gates = cur_in_data; + one_step.ct_1 = cur_prev_c_data; + one_step.ct = cur_c_out_data; + one_step.ht = cur_h_out_data; + ker->ComputeCtHt(&one_step, &attr); + // move one batch cur_in_data += D4; cur_prev_c_data += D; diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index 418c8433625cb17d43bc6a55f9c746d6a22b3b82..ccc9206f5cda888ea11be9f5492570340dec87db 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -233,7 +233,7 @@ void LSTMJitCode::generate() { vmovups(ymm_src, ptr[reg_ptr_gates + offset + num_]); act(ymm_i, ymm_src, act_gate_); vmulps(ymm_c, ymm_c, ymm_i); - if (first_) { + if (!compute_c1h1_) { // f vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * num_]); act(ymm_f, ymm_src, act_gate_); @@ -242,8 +242,8 @@ void LSTMJitCode::generate() { vaddps(ymm_f, ymm_f, ymm_c); } /* H_t = act_cell(C_t) * ogated */ - ymm_t ymm_ct = first_ ? ymm_c : ymm_f; - ymm_t ymm_o = first_ ? ymm_f : ymm_c; + ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f; + ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c; ymm_t ymm_tmp = ymm_i; act(ymm_tmp, ymm_ct, act_cell_); vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * num_]); diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 9782f5414c7de33af785810c164a5b89edc171c6..bf28d444b7712a5bac32985df31b5d24b94837a1 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -319,6 +319,12 @@ class LSTMJitCode : public VActJitCode { public: const char* name() const override { std::string base = "LSTMJitCode"; + if (use_peephole_) { + base += "_Peephole"; + } + if (compute_c1h1_) { + base += "_C1H1"; + } auto AddTypeStr = [&](operand_type type) { switch (type) { case operand_type::relu: @@ -340,30 +346,42 @@ class LSTMJitCode : public VActJitCode { break; } }; - if (first_) { - base += "_C1H1"; - } AddTypeStr(act_gate_); AddTypeStr(act_cand_); AddTypeStr(act_cell_); return base.c_str(); } - explicit LSTMJitCode(int d, bool first, operand_type act_gate, - operand_type act_cand, operand_type act_cell, + explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr, size_t code_size = 256 * 1024, void* code_ptr = nullptr) - : VActJitCode(d, act_gate, code_size, code_ptr), - num_(d), - first_(first), - act_gate_(act_gate), - act_cand_(act_cand), - act_cell_(act_cell) {} + : VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size, + code_ptr), + compute_c1h1_(compute_c1h1) { + auto typeExchange = [](const std::string& type) -> gen::operand_type { + if (type == "sigmoid") { + return operand_type::sigmoid; + } else if (type == "relu") { + return operand_type::relu; + } else if (type == "tanh") { + return operand_type::tanh; + } else if (type == "identity" || type == "") { + return operand_type::identity; + } // else throw error + return operand_type::identity; + }; + num_ = attr.d; + use_peephole_ = attr.use_peephole; + act_gate_ = typeExchange(attr.act_gate); + act_cand_ = typeExchange(attr.act_cand); + act_cell_ = typeExchange(attr.act_cell); + } static bool init(int d); void generate() override; protected: int num_; - bool first_; + bool compute_c1h1_; + bool use_peephole_; operand_type act_gate_; operand_type act_cand_; operand_type act_cell_; diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 36199eddaf5f1a4b401804d8c65b574d4b74a57a..bb5ba5813a7e6df1e81cc4f2f95b4a86d093fed2 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -122,18 +122,9 @@ class VTanhKernel : public VActKernel {}; template class LSTMKernel : public Kernel { public: - virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht, - /* below only used in peephole*/ - const T *wp_data = nullptr, - T *checked = nullptr) const = 0; - - virtual void ComputeC1H1(T *gates, T *ct, T *ht, - /* below only used in peephole*/ - const T *wp_data = nullptr) const = 0; - - // void (*ComputeCtHt)(lstm_t *); - // // compute c1 and h1 without c0 or h0 - // void (*ComputeC1H1)(lstm_t *); + void (*ComputeCtHt)(lstm_t *, const lstm_attr_t *); + // compute c1 and h1 without c0 or h0 + void (*ComputeC1H1)(lstm_t *, const lstm_attr_t *); }; template diff --git a/paddle/fluid/operators/math/jit_kernel_impl.h b/paddle/fluid/operators/math/jit_kernel_impl.h index 337d5ae91413ae712946b80a1e7a0c7ff182d58c..2e734ca940895382c7d8476c453b816446397fc6 100644 --- a/paddle/fluid/operators/math/jit_kernel_impl.h +++ b/paddle/fluid/operators/math/jit_kernel_impl.h @@ -33,18 +33,24 @@ typedef struct { const void* ct_1; void* ct; void* ht; - /* below only used in peephole*/ - const void* wp_data{nullptr}; + /* weight_peephole and checked data are only used in peephole*/ + const void* wp{nullptr}; void* checked{nullptr}; } lstm_t; typedef struct lstm_attr_s { + bool use_peephole; int d; std::string act_gate, act_cand, act_cell; lstm_attr_s() = default; lstm_attr_s(int _d, const std::string& _act_gate, - const std::string& _act_cand, const std::string& _act_cell) - : d(_d), act_gate(_act_gate), act_cand(_act_cand), act_cell(_act_cell) {} + const std::string& _act_cand, const std::string& _act_cell, + bool _use_peephole = false) + : use_peephole(_use_peephole), + d(_d), + act_gate(_act_gate), + act_cand(_act_cand), + act_cell(_act_cell) {} } lstm_attr_t; } // namespace jitkernel diff --git a/paddle/fluid/operators/math/jit_kernel_macro.h b/paddle/fluid/operators/math/jit_kernel_macro.h index 8acf60cfbfd3d47ad52862241b7635aba6982ebf..5a3efd979f803d396a5084c199b1d71b88a77126 100644 --- a/paddle/fluid/operators/math/jit_kernel_macro.h +++ b/paddle/fluid/operators/math/jit_kernel_macro.h @@ -82,10 +82,10 @@ namespace jitkernel { #define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \ marco_declare, macro_find_key, macro_impl) \ marco_define_name(ker_key, ker_class); \ - REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, JITKERNEL_DECLARE, \ - JITKERNEL_FIND_KEY, JITKERNEL_IMPL); \ - REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, JITKERNEL_DECLARE, \ - JITKERNEL_FIND_KEY, JITKERNEL_IMPL) + REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, marco_declare, \ + macro_find_key, macro_impl); \ + REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, marco_declare, \ + macro_find_key, macro_impl) #define REGISTER_JITKERNEL(ker_key, ker_class) \ REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \ diff --git a/paddle/fluid/operators/math/jit_kernel_refer.h b/paddle/fluid/operators/math/jit_kernel_refer.h index 9c60ebc5873d2e2d65ce79000b5e32a57ef7f37a..097bb8595612b528fc2d349862f6ce0542b2a0bb 100644 --- a/paddle/fluid/operators/math/jit_kernel_refer.h +++ b/paddle/fluid/operators/math/jit_kernel_refer.h @@ -117,11 +117,13 @@ void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT } template -void LSTMCtHt(lstm_t* step, lstm_attr_t* attr) { +void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) { T* gates = reinterpret_cast(step->gates); const T* ct_1 = reinterpret_cast(step->ct_1); T* ct = reinterpret_cast(step->ct); T* ht = reinterpret_cast(step->ht); + const T* wp = reinterpret_cast(step->wp); + T* checked = reinterpret_cast(step->checked); auto act_gate = getActFunc(attr->act_gate); auto act_cand = getActFunc(attr->act_cand); auto act_cell = getActFunc(attr->act_cell); @@ -129,23 +131,36 @@ void LSTMCtHt(lstm_t* step, lstm_attr_t* attr) { int d2 = d * 2; int d3 = d * 3; // gates: W_ch, W_ih, W_fh, W_oh - act_gate(gates + d, gates + d, d3); + if (attr->use_peephole) { + VMul(wp, ct_1, checked, d); + VMul(wp + d, ct_1, checked + d, d); + VAdd(checked, gates + d, gates + d, d2); + act_gate(gates + d, gates + d, d2); + } else { + act_gate(gates + d, gates + d, d3); + } - /* C_t = C_t-1 * fgated + cand_gated * igated */ + // C_t = C_t-1 * fgated + cand_gated * igated act_cand(gates, gates, d); VMul(gates, gates + d, gates + d, d); VMul(ct_1, gates + d2, gates + d2, d); VAdd(gates + d, gates + d2, ct, d); - /* H_t = act_cell(C_t) * ogated */ + if (attr->use_peephole) { + // get ogated + VMul(wp + d2, ct, gates + d, d); + VAdd(gates + d, gates + d3, gates + d3, d); + act_gate(gates + d3, gates + d3, d); + } + // H_t = act_cell(C_t) * ogated act_cell(ct, gates + d2, d); VMul(gates + d2, gates + d3, ht, d); } +// compute c1 and h1 without c0 or h0 template -void LSTMC1H1(lstm_t* step, lstm_attr_t* attr) { +void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) { T* gates = reinterpret_cast(step->gates); - const T* ct_1 = reinterpret_cast(step->ct_1); T* ct = reinterpret_cast(step->ct); T* ht = reinterpret_cast(step->ht); auto act_gate = getActFunc(attr->act_gate); @@ -158,10 +173,16 @@ void LSTMC1H1(lstm_t* step, lstm_attr_t* attr) { act_gate(gates + d, gates + d, d); act_cand(gates, gates, d); VMul(gates, gates + d, ct, d); + if (attr->use_peephole) { + // get outgated, put W_oc * C_t on igated + const T* wp = reinterpret_cast(step->wp); + VMul(wp + d2, ct, gates + d, d); + VAdd(gates + d, gates + d3, gates + d3, d); + } /* H_t = act_cell(C_t) * ogated */ act_gate(gates + d3, gates + d3, d); act_cell(ct, gates + d2, d); - Vmul(gates + d2, gates + d3, ht, d); + VMul(gates + d2, gates + d3, ht, d); } } // namespace refer diff --git a/paddle/fluid/operators/math/jit_kernel_rnn.cc b/paddle/fluid/operators/math/jit_kernel_rnn.cc index e79b0400ab75d1488a26450bd8cde4a0979fc995..6b7463aa52b9313bb127bc86a929f21ad91b4e87 100644 --- a/paddle/fluid/operators/math/jit_kernel_rnn.cc +++ b/paddle/fluid/operators/math/jit_kernel_rnn.cc @@ -15,9 +15,14 @@ limitations under the License. */ #include "paddle/fluid/operators/math/jit_kernel.h" #include #include "paddle/fluid/operators/math/jit_kernel_macro.h" +#include "paddle/fluid/operators/math/jit_kernel_refer.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/macros.h" +#ifdef PADDLE_WITH_XBYAK +#include "paddle/fluid/operators/math/jit_code.h" +#endif + #ifdef __AVX__ #include #endif @@ -154,211 +159,136 @@ static std::unique_ptr GetAVXAct(const std::string& type) { #endif /* LSTM JitKernel */ -template +template class LSTMKernelImpl : public LSTMKernel { public: - explicit LSTMKernelImpl(const std::string& act_gate, - const std::string& act_cand, - const std::string& act_cell, int d) - : LSTMKernel() { - d_ = d; - d2_ = d * 2; - d3_ = d * 3; - act_gate_d3_ = GetActKernel(act_gate, d3_); - act_gate_d_ = GetActKernel(act_gate, d); - act_cand_d_ = GetActKernel(act_cand, d); - act_cell_d_ = GetActKernel(act_cell, d); - vmul_d_ = KernelPool::Instance().template Get>(d); - vadd_d_ = KernelPool::Instance().template Get>(d); + static inline std::string name(const lstm_attr_t& attr) { + PADDLE_THROW("DType should be either float or double"); } + static inline bool useJIT(int d) { return false; } + static inline bool useMKL(int d) { return false; } + explicit LSTMKernelImpl(const lstm_attr_t& attr) : LSTMKernel() { +#ifdef PADDLE_WITH_XBYAK + if (useJIT(attr.d)) { + size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8; // should change + jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096)); + this->ComputeCtHt = + jitcode0_->getCode(); + + jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? sz : 4096)); + this->ComputeC1H1 = + jitcode1_->getCode(); + return; + } +#endif - void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, - T* checked) const override { - // gates: W_ch, W_ih, W_fh, W_oh - act_gate_d3_->Compute(gates + d_, gates + d_, d3_); - - /* C_t = C_t-1 * fgated + cand_gated * igated */ - act_cand_d_->Compute(gates, gates, d_); - vmul_d_->Compute(gates, gates + d_, gates + d_, d_); - vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); - vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); - - /* H_t = act_cell(C_t) * ogated */ - act_cell_d_->Compute(ct, gates + d2_, d_); - vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); - } - void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { - /* C_t = igated * cgated*/ - act_gate_d_->Compute(gates + d_, gates + d_, d_); - act_cand_d_->Compute(gates, gates, d_); - vmul_d_->Compute(gates, gates + d_, ct, d_); - /* H_t = act_cell(C_t) * ogated */ - act_gate_d_->Compute(gates + d3_, gates + d3_, d_); - act_cell_d_->Compute(ct, gates + d2_, d_); - vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); + this->ComputeCtHt = refer::LSTMCtHt; + this->ComputeC1H1 = refer::LSTMC1H1; } +#ifdef PADDLE_WITH_XBYAK + private: - int d_, d2_, d3_; - std::shared_ptr> act_gate_d3_, act_gate_d_, act_cand_d_, - act_cell_d_; - std::shared_ptr> vmul_d_; - std::shared_ptr> vadd_d_; -#ifdef __AVX__ - std::unique_ptr avx_act_gate_, avx_act_cand_, avx_act_cell_; + std::unique_ptr jitcode0_{nullptr}, jitcode1_{nullptr}; #endif }; -#define INTRI8_FLOAT(isa) \ - template <> \ - LSTMKernelImpl::LSTMKernelImpl( \ - const std::string& act_gate, const std::string& act_cand, \ - const std::string& act_cell, int d) \ - : LSTMKernel() { \ - avx_act_gate_ = GetAVXAct(act_gate); \ - avx_act_cand_ = GetAVXAct(act_cand); \ - avx_act_cell_ = GetAVXAct(act_cell); \ - } \ - template <> \ - void LSTMKernelImpl::ComputeCtHt( \ - float* gates, const float* ct_1, float* ct, float* ht, \ - const float* wp_data, float* checked) const { \ - /* gates: W_ch, W_ih, W_fh, W_oh */ \ - __m256 c, i, f, o; \ - c = _mm256_loadu_ps(gates); \ - i = _mm256_loadu_ps(gates + 8); \ - f = _mm256_loadu_ps(gates + 16); \ - o = _mm256_loadu_ps(gates + 24); \ - /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ - c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ - i = _mm256_loadu_ps(ct_1); \ - f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ - f = _mm256_add_ps(c, f); \ - _mm256_storeu_ps(ct, f); \ - /* H_t = act_cell(C_t) * ogated */ \ - o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ - _mm256_storeu_ps(ht, o); \ - } \ - template <> \ - void LSTMKernelImpl::ComputeC1H1( \ - float* gates, float* ct, float* ht, const float* wp_data) const { \ - __m256 c, i, o; \ - c = _mm256_loadu_ps(gates); \ - i = _mm256_loadu_ps(gates + 8); \ - o = _mm256_loadu_ps(gates + 24); \ - /* C_t = igated * cgated*/ \ - c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \ - _mm256_storeu_ps(ct, c); \ - /* H_t = act_cell(C_t) * ogated */ \ - o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \ - _mm256_storeu_ps(ht, o); \ - } - -// TODO(TJ): optimize keq16 - -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx); -#endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); -#endif -#ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f); +#ifdef PADDLE_WITH_XBYAK +template <> +bool LSTMKernelImpl::useJIT(int d) { + return false; // not ready yet gen::LSTMJitCode::init(d); +} #endif /* Peephole JitKernel */ -template +template class PeepholeKernelImpl : public LSTMKernel { public: - explicit PeepholeKernelImpl(const std::string& act_gate, - const std::string& act_cand, - const std::string& act_cell, int d) - : LSTMKernel() { - d_ = d; - d2_ = d * 2; - d3_ = d * 3; - act_gate_d_ = GetActKernel(act_gate, d); - act_cand_d_ = GetActKernel(act_cand, d); - act_cell_d_ = GetActKernel(act_cell, d); - vmul_d_ = KernelPool::Instance().template Get>(d); - vadd_d_ = KernelPool::Instance().template Get>(d); - vadd_d2_ = KernelPool::Instance().template Get>(d2_); - act_gate_d2_ = GetActKernel(act_gate, d2_); + static inline std::string name(const lstm_attr_t& attr) { + PADDLE_THROW("DType should be either float or double"); } + static inline bool useJIT(int d) { return false; } + static inline bool useMKL(int d) { return false; } + explicit PeepholeKernelImpl(const lstm_attr_t& attr) : LSTMKernel() { +#ifdef PADDLE_WITH_XBYAK + if (useJIT(attr.d)) { + size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8; // should change + jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096)); + this->ComputeCtHt = + jitcode0_->getCode(); + + jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? sz : 4096)); + this->ComputeC1H1 = + jitcode1_->getCode(); + return; + } +#endif - void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, - T* checked) const override { - /* get fgated and igated*/ - vmul_d_->Compute(wp_data, ct_1, checked, d_); - vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_); - vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_); - act_gate_d2_->Compute(gates + d_, gates + d_, d2_); - /* C_t = C_t-1 * fgated + cand_gated * igated*/ - act_cand_d_->Compute(gates, gates, d_); - vmul_d_->Compute(gates, gates + d_, gates + d_, d_); - vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); - vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); - /* get ogated*/ - vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); - vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); - act_gate_d_->Compute(gates + d3_, gates + d3_, d_); - /* H_t = act_cell(C_t) * ogated */ - act_cell_d_->Compute(ct, gates + d2_, d_); - vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); + this->ComputeCtHt = refer::LSTMCtHt; + this->ComputeC1H1 = refer::LSTMC1H1; } - void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { - /* C_t = igated * cgated*/ - act_gate_d_->Compute(gates + d_, gates + d_, d_); - act_cand_d_->Compute(gates, gates, d_); - vmul_d_->Compute(gates, gates + d_, ct, d_); - /* get outgated, put W_oc * C_t on igated */ - vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); - vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); - /* H_t = act_cell(C_t) * ogated */ - act_gate_d_->Compute(gates + d3_, gates + d3_, d_); - act_cell_d_->Compute(ct, gates + d2_, d_); - vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); - } +#ifdef PADDLE_WITH_XBYAK private: - int d_, d2_, d3_; - std::shared_ptr> act_gate_d2_, act_gate_d_, act_cand_d_, - act_cell_d_; - std::shared_ptr> vmul_d_; - std::shared_ptr> vadd_d_, vadd_d2_; + std::unique_ptr jitcode0_{nullptr}, jitcode1_{nullptr}; +#endif }; -#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \ - template <> \ - std::shared_ptr> \ - KernelPool::Get, const std::string&, \ - const std::string&, const std::string&, int, bool>( \ - const std::string& act_gate, const std::string& act_cand, \ - const std::string& act_cell, int d, bool use_peephole) - -#define JITKERNEL_KEY_LSTM(ker_key, dtype_key) \ - #ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell + \ - (use_peephole ? "p" : "n") - -#define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k) \ - if (use_peephole) { \ - p = std::dynamic_pointer_cast>( \ - std::make_shared>( \ - act_gate, act_cand, act_cell, d)); \ - } else { \ - p = std::dynamic_pointer_cast>( \ - std::make_shared>(act_gate, act_cand, \ - act_cell, d)); \ +#ifdef PADDLE_WITH_XBYAK +template <> +bool PeepholeKernelImpl::useJIT(int d) { + return false; // peephole jitcode not ready yet +} +#endif + +#define JITKERNEL_DEFINE_NAME_LSTM(ker_key, ker_class) \ + template <> \ + std::string ker_class##Impl::name(const lstm_attr_t& attr) { \ + std::string key(#ker_key "f"); \ + key += (attr.act_gate + attr.act_cand + attr.act_cell + \ + (attr.use_peephole ? "p" : "n")); \ + if (useJIT(attr.d)) { \ + /* only jit code need record d*/ \ + return key + "jit" + std::to_string(attr.d); \ + } else if (useMKL(attr.d)) { \ + return key + "mkl"; \ + } else { \ + return key + "any"; \ + } \ + } \ + template <> \ + std::string ker_class##Impl::name(const lstm_attr_t& attr) { \ + std::string key(#ker_key "d"); \ + /* jit code do not support double yet*/ \ + if (useMKL(attr.d)) { \ + return key + "mkl"; \ + } else { \ + return key + "any"; \ + } \ } -REGISTER_JITKERNEL_ARGS_DEPRECATED(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM, - JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL); +#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \ + template <> \ + std::shared_ptr> \ + KernelPool::Get, const lstm_attr_t&>( \ + const lstm_attr_t& attr) + +#define JITKERNEL_FIND_KEY_LSTM(ker_class, ker_dtype) \ + std::string key = ker_class##Impl::name(attr) + +#define JITKERNEL_LSTM_IMPL(ker, dtype) \ + if (attr.use_peephole) { \ + p = std::dynamic_pointer_cast>( \ + std::make_shared>(attr)); \ + } else { \ + p = std::dynamic_pointer_cast>( \ + std::make_shared>(attr)); \ + } -#undef INTRI8_FLOAT -#undef JITKERNEL_DECLARE_LSTM -#undef JITKERNEL_KEY_LSTM -#undef JITKERNEL_NEW_LSTM_IMPL +REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DEFINE_NAME_LSTM, + JITKERNEL_DECLARE_LSTM, JITKERNEL_FIND_KEY_LSTM, + JITKERNEL_LSTM_IMPL); /* GRU JitKernel */ template diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index a1705a81c47d77b5c6cbed2a48d05a4833154b49..1cbe1b5d952f03869d77eb32d79956eb440429b9 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -341,11 +341,11 @@ TEST(JitKernel, lstm) { RandomVec(d, ct_1.data(), -2.f, 2.f); memcpy(xref.data(), x.data(), sizeof(float) * d4); std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; + const jit::lstm_attr_t attr(d, act_gate, act_cand, act_cell, false); const auto& ker = jit::KernelPool::Instance() - .template Get, const std::string&, - const std::string&, const std::string&>( - act_gate, act_cand, act_cell, d, false); + .template Get, const jit::lstm_attr_t&>( + attr); // below kernels are used to compute refer const auto& vsigmoid_3d = jit::KernelPool::Instance().template Get>( @@ -366,14 +366,16 @@ TEST(JitKernel, lstm) { float* ht_ref_data = ht_ref.data(); // compute once to check correctness jit::lstm_t step; - jit::lstm_attr_t attr(d, act_gate, act_cand, act_cell); step.gates = xref_data; step.ct_1 = ct_1_data; step.ct = ct_ref_data; step.ht = ht_ref_data; refer::LSTMCtHt(&step, &attr); - ker->ComputeCtHt(x_data, ct_1_data, ct_tgt_data, ht_tgt_data); + step.gates = x_data; + step.ct = ct_tgt_data; + step.ht = ht_tgt_data; + ker->ComputeCtHt(&step, &attr); for (int i = 0; i < d; ++i) { EXPECT_NEAR(ct_tgt_data[i], ct_ref_data[i], 1e-3); EXPECT_NEAR(ht_tgt_data[i], ht_ref_data[i], 1e-3); @@ -392,7 +394,7 @@ TEST(JitKernel, lstm) { auto trefe = GetCurrentUS(); auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->ComputeCtHt(x_data, ct_1_data, ct_tgt_data, ht_tgt_data); + ker->ComputeCtHt(&step, &attr); } auto ttgte = GetCurrentUS(); VLOG(30) << "Vec size " << d @@ -710,21 +712,21 @@ TEST(JitKernel, pool) { namespace jit = paddle::operators::math::jitkernel; const int frame_size = 4; std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; + jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false); + const auto& plstm1 = jit::KernelPool::Instance() - .template Get, const std::string&, - const std::string&, const std::string&>( - act_gate, act_cand, act_cell, frame_size, false); + .template Get, const jit::lstm_attr_t&>(attr); + const auto& plstm2 = jit::KernelPool::Instance() - .template Get, const std::string&, - const std::string&, const std::string&>( - act_gate, act_cand, act_cell, frame_size, false); + .template Get, const jit::lstm_attr_t&>(attr); + EXPECT_EQ(plstm1, plstm2); + const auto& peephole = jit::KernelPool::Instance() - .template Get, const std::string&, - const std::string&, const std::string&>( - act_gate, act_cand, act_cell, frame_size, true); + .template Get, const jit::lstm_attr_t&>( + jit::lstm_attr_t(frame_size, act_gate, act_cand, act_cell, true)); EXPECT_TRUE(plstm1 != peephole); const auto& pvmul_f =