diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 0ba51012c4f974ccfa99b77194c0babcc0ea8864..067e6a3e7cccc1f15ebdd984f3a2441339a989ab 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -15,11 +15,9 @@ limitations under the License. */ #include "paddle/fluid/operators/fusion_lstm_op.h" #include #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/sequence2batch.h" -#include "paddle/fluid/platform/cpu_info.h" namespace paddle { namespace operators { @@ -219,116 +217,55 @@ This operator fuse the X into LSTM, more details can refer to LSTM op. template class FuisonLSTMKernel : public framework::OpKernel { public: -#define INIT_VEC_FUNC \ - std::function act_gate, act_cell, act_cand; \ - auto& act_gate_str = ctx.Attr("gate_activation"); \ - auto& act_cell_str = ctx.Attr("cell_activation"); \ - auto& act_cand_str = ctx.Attr("candidate_activation"); \ - if (platform::jit::MayIUse(platform::jit::avx)) { \ - math::VecActivations act_functor; \ - act_gate = act_functor(act_gate_str); \ - act_cell = act_functor(act_cell_str); \ - act_cand = act_functor(act_cand_str); \ - } else { \ - math::VecActivations act_functor; \ - act_gate = act_functor(act_gate_str); \ - act_cell = act_functor(act_cell_str); \ - act_cand = act_functor(act_cand_str); \ - } - -#define INIT_BASE_INPUT_OUTPUT \ - auto* x = ctx.Input("X"); \ - auto* h0 = ctx.Input("H0"); \ - auto* c0 = ctx.Input("C0"); \ - auto* wx = ctx.Input("WeightX"); \ - auto* wh = ctx.Input("WeightH"); \ - auto* bias = ctx.Input("Bias"); \ - auto* xx = ctx.Output("XX"); \ - auto* hidden_out = ctx.Output("Hidden"); \ - auto* cell_out = ctx.Output("Cell"); \ - bool is_reverse = ctx.Attr("is_reverse"); \ - bool use_peepholes = ctx.Attr("use_peepholes"); - -#define INIT_BASE_SIZES \ - auto x_dims = x->dims(); /* T x M*/ \ - auto wh_dims = wh->dims(); /* D x 4D*/ \ - const int M = x_dims[1]; \ - const int D = wh_dims[0]; \ - const int D2 = D * 2; \ - const int D3 = D * 3; \ - const int D4 = wh_dims[1]; - -#define INIT_BASE_INPUT_DATAS \ - const T* x_data = x->data(); \ - const T* wx_data = wx->data(); \ - const T* wh_data = wh->data(); \ - /* diagonal weight*/ \ - const T* wc_data = bias->data() + D4; \ - /* for peephole only*/ \ - T* checked_cell_data = nullptr; \ - auto place = ctx.GetPlace(); \ - if (use_peepholes) { \ - /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ - auto* checked_cell = ctx.Output("CheckedCell"); \ - checked_cell_data = checked_cell->mutable_data(place); \ - } - -/// Compute LSTM +#define INIT_BASE_DEFINES \ + using DeviceContext = paddle::platform::CPUDeviceContext; \ + auto* x = ctx.Input("X"); \ + auto* h0 = ctx.Input("H0"); \ + auto* c0 = ctx.Input("C0"); \ + auto* wx = ctx.Input("WeightX"); \ + auto* wh = ctx.Input("WeightH"); \ + auto* bias = ctx.Input("Bias"); \ + auto* xx = ctx.Output("XX"); \ + auto* hidden_out = ctx.Output("Hidden"); \ + auto* cell_out = ctx.Output("Cell"); \ + bool is_reverse = ctx.Attr("is_reverse"); \ + bool use_peepholes = ctx.Attr("use_peepholes"); \ + auto x_dims = x->dims(); /* T x M*/ \ + auto wh_dims = wh->dims(); /* D x 4D*/ \ + const int M = x_dims[1]; \ + const int D = wh_dims[0]; \ + const int D4 = wh_dims[1] + +#define INIT_OTHER_DEFINES \ + const T* x_data = x->data(); \ + const T* wx_data = wx->data(); \ + const T* wh_data = wh->data(); \ + /* diagonal weight*/ \ + const T* wp_data = bias->data() + D4; \ + /* for peephole only*/ \ + T* checked_cell_data = nullptr; \ + auto place = ctx.GetPlace(); \ + if (use_peepholes) { \ + /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ + auto* checked_cell = ctx.Output("CheckedCell"); \ + checked_cell_data = checked_cell->mutable_data(place); \ + } \ + const auto& ker = \ + math::jitkernel::KernelPool::Instance() \ + .template Get, const std::string&, \ + const std::string&, const std::string&>( \ + ctx.Attr("gate_activation"), \ + ctx.Attr("candidate_activation"), \ + ctx.Attr("cell_activation"), D, use_peepholes) + +// Wh GEMM #define GEMM_WH_ADDON(bs, prev, out) \ blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast(1), prev, D, \ wh_data, D4, static_cast(1), out, D4) -#define GET_Ct(ct_1, gates, ct) \ - /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ - act_cand(D, gates, gates); \ - blas.VMUL(D, gates, gates + D, gates + D); \ - blas.VMUL(D, ct_1, gates + D2, gates + D2); \ - blas.VADD(D, gates + D, gates + D2, ct) - -#define GET_Ht(ct, gates, ht) \ - /* H_t = act_cell(C_t) * ogated */ \ - act_cell(D, ct, gates + D2); \ - blas.VMUL(D, gates + D2, gates + D3, ht) - -#define GET_Ct_NOH0C0(gates, ct) \ - /* C_t = igated * cgated*/ \ - act_gate(D, gates + D, gates + D); \ - act_cand(D, gates, gates); \ - blas.VMUL(D, gates, gates + D, ct) - -#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \ - GET_Ct_NOH0C0(gates, ct); \ - act_gate(D, gates + D3, gates + D3); \ - GET_Ht(ct, gates, ht) - -#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \ - GET_Ct_NOH0C0(gates, ct); \ - /* get outgated, put W_oc * C_t on igated */ \ - blas.VMUL(D, wc_data + D2, ct, gates + D); \ - blas.VADD(D, gates + D, gates + D3, gates + D3); \ - act_gate(D, gates + D3, gates + D3); \ - GET_Ht(ct, gates, ht) - -#define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht) \ - /* get fgated and igated*/ \ - blas.VMUL(D, wc_data, ct_1, checked_cell_data); \ - blas.VMUL(D, wc_data + D, ct_1, checked_cell_data + D); \ - blas.VADD(D2, checked_cell_data, gates + D, gates + D); \ - act_gate(D2, gates + D, gates + D); \ - GET_Ct(ct_1, gates, ct); \ - /* get ogated*/ \ - blas.VMUL(D, wc_data + D2, ct, gates + D); \ - blas.VADD(D, gates + D, gates + D3, gates + D3); \ - act_gate(D, gates + D3, gates + D3); \ - GET_Ht(ct, gates, ht) - void SeqCompute(const framework::ExecutionContext& ctx) const { - using DeviceContext = paddle::platform::CPUDeviceContext; - INIT_BASE_INPUT_OUTPUT - INIT_BASE_SIZES - INIT_VEC_FUNC - INIT_BASE_INPUT_DATAS - + INIT_BASE_DEFINES; + INIT_OTHER_DEFINES; auto x_lod = x->lod(); const int total_T = x_dims[0]; const int N = x_lod[0].size() - 1; @@ -352,84 +289,47 @@ class FuisonLSTMKernel : public framework::OpKernel { gate_offset = -D; } -#define MOVE_ONE_STEP \ - prev_h_data = h_out_data; \ - prev_c_data = c_out_data; \ - xx_data = xx_data + xx_offset; \ - h_out_data = h_out_data + gate_offset; \ - c_out_data = c_out_data + gate_offset - -#define PROCESS_H0C0_DEFINES \ - int bid = is_reverse ? N - 1 - i : i; \ - int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \ - const T* prev_c_data = nullptr; \ - const T* prev_h_data = nullptr; \ - int tstart = 0 - -#define PROCESS_H0C0_PEEPHOLE \ - PROCESS_H0C0_DEFINES; \ - if (h0_data) { \ - prev_h_data = h0_data + bid * D; \ - prev_c_data = c0_data + bid * D; \ - } else { \ - COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \ - MOVE_ONE_STEP; \ - tstart = 1; \ - } - -#define PROCESS_H0C0 \ - PROCESS_H0C0_DEFINES; \ - if (h0_data) { \ - prev_h_data = h0_data + bid * D; \ - prev_c_data = c0_data + bid * D; \ - } else { \ - COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \ - MOVE_ONE_STEP; \ - tstart = 1; \ - } - - if (use_peepholes) { - for (int i = 0; i < N; ++i) { - PROCESS_H0C0_PEEPHOLE - for (int step = tstart; step < seq_len; ++step) { - GEMM_WH_ADDON(1, prev_h_data, xx_data); - COMPUTE_CtHt_PEEPHOLE(xx_data, prev_c_data, c_out_data, h_out_data); - MOVE_ONE_STEP; - } + for (int i = 0; i < N; ++i) { + int bid = is_reverse ? N - 1 - i : i; + int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; + const T* prev_c_data = nullptr; + const T* prev_h_data = nullptr; + int tstart = 0; + if (h0_data) { + prev_h_data = h0_data + bid * D; + prev_c_data = c0_data + bid * D; + } else { + ker->ComputeC1H1(xx_data, c_out_data, h_out_data, wp_data); + tstart = 1; + // move one step + prev_h_data = h_out_data; + prev_c_data = c_out_data; + xx_data = xx_data + xx_offset; + h_out_data = h_out_data + gate_offset; + c_out_data = c_out_data + gate_offset; } - } else { - const auto& ker = - math::jitkernel::KernelPool::Instance() - .template Get, const std::string&, - const std::string&, const std::string&>( - act_gate_str, act_cand_str, act_cell_str, D, false); - - for (int i = 0; i < N; ++i) { - PROCESS_H0C0 - for (int step = tstart; step < seq_len; ++step) { - GEMM_WH_ADDON(1, prev_h_data, xx_data); - ker->ComputeCtHt(xx_data, prev_c_data, c_out_data, h_out_data); - MOVE_ONE_STEP; - } + for (int step = tstart; step < seq_len; ++step) { + GEMM_WH_ADDON(1, prev_h_data, xx_data); + ker->ComputeCtHt(xx_data, prev_c_data, c_out_data, h_out_data, wp_data, + checked_cell_data); + // move one step + prev_h_data = h_out_data; + prev_c_data = c_out_data; + xx_data = xx_data + xx_offset; + h_out_data = h_out_data + gate_offset; + c_out_data = c_out_data + gate_offset; } } -#undef PROCESS_H0C0_DEFINES -#undef PROCESS_H0C0_PEEPHOLE -#undef PROCESS_H0C0 -#undef MOVE_ONE_STEP } void BatchCompute(const framework::ExecutionContext& ctx) const { - using DeviceContext = platform::CPUDeviceContext; - INIT_BASE_INPUT_OUTPUT - INIT_BASE_SIZES + INIT_BASE_DEFINES; if (x->lod()[0].size() == 2) { xx->Resize({x_dims[0], D4}); SeqCompute(ctx); return; } - INIT_VEC_FUNC - INIT_BASE_INPUT_DATAS + INIT_OTHER_DEFINES; auto* reordered_h0 = ctx.Output("ReorderedH0"); auto* reordered_c0 = ctx.Output("ReorderedC0"); @@ -477,8 +377,8 @@ class FuisonLSTMKernel : public framework::OpKernel { prev_c_data = reordered_c0_data; size_t sz = sizeof(T) * D; for (int i = 0; i < max_bs; ++i) { - std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz); - std::memcpy(reordered_c0_data, c0_data + seq_order[i] * D, sz); + blas.VCOPY(sz, h0_data + seq_order[i] * D, reordered_h0_data); + blas.VCOPY(sz, c0_data + seq_order[i] * D, reordered_c0_data); reordered_h0_data += D; reordered_c0_data += D; } @@ -488,13 +388,7 @@ class FuisonLSTMKernel : public framework::OpKernel { T* cur_h_out_data = batched_h_out_data; T* cur_c_out_data = batched_c_out_data; for (int i = 0; i < max_bs; ++i) { - GET_Ct_NOH0C0(cur_in_data, cur_c_out_data); - if (use_peepholes) { - blas.VMUL(D, wc_data + D2, cur_c_out_data, cur_in_data + D); - blas.VADD(D, cur_in_data + D, cur_in_data + D3, cur_in_data + D3); - } - act_gate(D, cur_in_data + D3, cur_in_data + D3); - GET_Ht(cur_c_out_data, cur_in_data, cur_h_out_data); + ker->ComputeC1H1(cur_in_data, cur_c_out_data, cur_h_out_data, wp_data); cur_in_data += D4; cur_c_out_data += D; cur_h_out_data += D; @@ -503,66 +397,37 @@ class FuisonLSTMKernel : public framework::OpKernel { prev_h_data = batched_h_out_data; prev_c_data = batched_c_out_data; } + + // compute kernel part const auto& batch_starts = batched_lod[0]; const int max_seq_len = batch_starts.size() - 1; const int offset = tstart * max_bs * D; batched_input_data = batched_input_data + offset * 4; batched_h_out_data = batched_h_out_data + offset; batched_c_out_data = batched_c_out_data + offset; - -#define DEFINE_CUR \ - T* cur_in_data = batched_input_data; \ - T* cur_prev_c_data = prev_c_data; \ - T* cur_c_out_data = batched_c_out_data; \ - T* cur_h_out_data = batched_h_out_data - -#define MOVE_ONE_BATCH \ - cur_in_data += D4; \ - cur_prev_c_data += D; \ - cur_c_out_data += D; \ - cur_h_out_data += D - -#define MOVE_ONE_STEP \ - prev_c_data = batched_c_out_data; \ - prev_h_data = batched_h_out_data; \ - batched_c_out_data = cur_c_out_data; \ - batched_h_out_data = cur_h_out_data; \ - batched_input_data = cur_in_data - - if (use_peepholes) { - for (int step = tstart; step < max_seq_len; ++step) { - const int cur_bs = batch_starts[step + 1] - batch_starts[step]; - GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); - DEFINE_CUR; - for (int i = 0; i < cur_bs; ++i) { - COMPUTE_CtHt_PEEPHOLE(cur_in_data, cur_prev_c_data, cur_c_out_data, - cur_h_out_data); - MOVE_ONE_BATCH; - } - MOVE_ONE_STEP; - } - } else { - const auto& ker = - math::jitkernel::KernelPool::Instance() - .template Get, const std::string&, - const std::string&, const std::string&>( - act_gate_str, act_cand_str, act_cell_str, D, false); - - for (int step = tstart; step < max_seq_len; ++step) { - const int cur_bs = batch_starts[step + 1] - batch_starts[step]; - GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); - DEFINE_CUR; - for (int i = 0; i < cur_bs; ++i) { - ker->ComputeCtHt(cur_in_data, cur_prev_c_data, cur_c_out_data, - cur_h_out_data); - MOVE_ONE_BATCH; - } - MOVE_ONE_STEP; + for (int step = tstart; step < max_seq_len; ++step) { + const int cur_bs = batch_starts[step + 1] - batch_starts[step]; + GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); + T* cur_in_data = batched_input_data; + T* cur_prev_c_data = prev_c_data; + T* cur_c_out_data = batched_c_out_data; + T* cur_h_out_data = batched_h_out_data; + for (int i = 0; i < cur_bs; ++i) { + ker->ComputeCtHt(cur_in_data, cur_prev_c_data, cur_c_out_data, + cur_h_out_data, wp_data, checked_cell_data); + // move one batch + cur_in_data += D4; + cur_prev_c_data += D; + cur_c_out_data += D; + cur_h_out_data += D; } + // move one step + prev_c_data = batched_c_out_data; + prev_h_data = batched_h_out_data; + batched_c_out_data = cur_c_out_data; + batched_h_out_data = cur_h_out_data; + batched_input_data = cur_in_data; } -#undef MOVE_ONE_STEP -#undef MOVE_ONE_BATCH -#undef DEFINE_CUR math::Batch2LoDTensorFunctor to_seq; batched_h_out->set_lod(batched_lod); @@ -579,17 +444,9 @@ class FuisonLSTMKernel : public framework::OpKernel { } } -#undef COMPUTE_CtHt_PEEPHOLE -#undef GET_Ct_NOH0C0 -#undef COMPUTE_CtHt_NOH0C0 -#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0 -#undef GET_Ht -#undef GET_Ct #undef GEMM_WH_ADDON -#undef INIT_BASE_INPUT_DATAS -#undef INIT_BASE_SIZES -#undef INIT_BASE_INPUT_OUTPUT -#undef INIT_VEC_FUNC +#undef INIT_OTHER_DEFINES +#undef INIT_BASE_DEFINES }; } // namespace operators diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index aeb439bb8615b21cefe64a8df0be34196036ee6e..b4dfda6db76fd4231be0acd1f90c98a2d62134b8 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -126,7 +126,14 @@ template class LSTMKernel : public Kernel { public: virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht, + /* below only used in peephole*/ + const T *wp_data = nullptr, T *checked = nullptr) const = 0; + + // compute c1 and h1 without c0 or h0 + virtual void ComputeC1H1(T *gates, T *ct, T *ht, + /* below only used in peephole*/ + const T *wp_data = nullptr) const = 0; }; } // namespace jitkernel diff --git a/paddle/fluid/operators/math/jit_kernel_lstm.cc b/paddle/fluid/operators/math/jit_kernel_lstm.cc index 17e2d1fbb44093bc9e0941fb2e578269398f3573..42a2b96fd945c516f8c26ca51ecb452345a9a86f 100644 --- a/paddle/fluid/operators/math/jit_kernel_lstm.cc +++ b/paddle/fluid/operators/math/jit_kernel_lstm.cc @@ -82,6 +82,26 @@ __m256 AVXActImpl::Compute(__m256 x) const { } #endif +template +static std::shared_ptr> GetActKernel( + const std::string& type, int n) { + if (type == "sigmoid") { + return std::dynamic_pointer_cast>( + KernelPool::Instance().template Get>(n)); + } else if (type == "relu") { + return std::dynamic_pointer_cast>( + KernelPool::Instance().template Get>(n)); + } else if (type == "tanh") { + return std::dynamic_pointer_cast>( + KernelPool::Instance().template Get>(n)); + } else if (type == "identity" || type == "") { + return std::dynamic_pointer_cast>( + KernelPool::Instance().template Get>(n)); + } + PADDLE_THROW("Not support type: %s", type); + return nullptr; +} + /* LSTM JitKernel */ template class LSTMKernelImpl : public LSTMKernel { @@ -93,26 +113,10 @@ class LSTMKernelImpl : public LSTMKernel { d_ = d; d2_ = d * 2; d3_ = d * 3; - auto GetActKernel = [&](const std::string& type, - int n) -> std::shared_ptr> { - if (type == "sigmoid") { - return std::dynamic_pointer_cast>( - KernelPool::Instance().template Get>(n)); - } else if (type == "relu") { - return std::dynamic_pointer_cast>( - KernelPool::Instance().template Get>(n)); - } else if (type == "tanh") { - return std::dynamic_pointer_cast>( - KernelPool::Instance().template Get>(n)); - } else if (type == "identity" || type == "") { - return std::dynamic_pointer_cast>( - KernelPool::Instance().template Get>(n)); - } - PADDLE_THROW("Not support type: %s", type); - }; - act_gate_3d_ = GetActKernel(act_gate, d * 3); - act_cand_d_ = GetActKernel(act_cand, d); - act_cell_d_ = GetActKernel(act_cell, d); + act_gate_d3_ = GetActKernel(act_gate, d3_); + act_gate_d_ = GetActKernel(act_gate, d); + act_cand_d_ = GetActKernel(act_cand, d); + act_cell_d_ = GetActKernel(act_cell, d); vmul_d_ = KernelPool::Instance().template Get>(d); vadd_d_ = KernelPool::Instance().template Get>(d); #ifdef __AVX__ @@ -134,10 +138,10 @@ class LSTMKernelImpl : public LSTMKernel { #endif } - void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, + void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, T* checked) const override { // gates: W_ch, W_ih, W_fh, W_oh - act_gate_3d_->Compute(gates + d_, gates + d_); + act_gate_d3_->Compute(gates + d_, gates + d_); /* C_t = C_t-1 * fgated + cand_gated * igated */ act_cand_d_->Compute(gates, gates); @@ -149,10 +153,21 @@ class LSTMKernelImpl : public LSTMKernel { act_cell_d_->Compute(ct, gates + d2_); vmul_d_->Compute(gates + d2_, gates + d3_, ht); } + void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { + /* C_t = igated * cgated*/ + act_gate_d_->Compute(gates + d_, gates + d_); + act_cand_d_->Compute(gates, gates); + vmul_d_->Compute(gates, gates + d_, ct); + /* H_t = act_cell(C_t) * ogated */ + act_gate_d_->Compute(gates + d3_, gates + d3_); + act_cell_d_->Compute(ct, gates + d2_); + vmul_d_->Compute(gates + d2_, gates + d3_, ht); + } private: int d_, d2_, d3_; - std::shared_ptr> act_gate_3d_, act_cand_d_, act_cell_d_; + std::shared_ptr> act_gate_d3_, act_gate_d_, act_cand_d_, + act_cell_d_; std::shared_ptr> vmul_d_; std::shared_ptr> vadd_d_; #ifdef __AVX__ @@ -163,8 +178,8 @@ class LSTMKernelImpl : public LSTMKernel { #define INTRI8_FLOAT(isa) \ template <> \ void LSTMKernelImpl::ComputeCtHt( \ - float* gates, const float* ct_1, float* ct, float* ht, float* checked) \ - const { \ + float* gates, const float* ct_1, float* ct, float* ht, \ + const float* wp_data, float* checked) const { \ /* gates: W_ch, W_ih, W_fh, W_oh */ \ __m256 c, i, f, o; \ c = _mm256_loadu_ps(gates); \ @@ -205,51 +220,56 @@ class PeepholeKernelImpl : public LSTMKernel { d_ = d; d2_ = d * 2; d3_ = d * 3; - auto GetActKernel = [&](const std::string& type, - int n) -> std::shared_ptr> { - if (type == "sigmoid") { - return std::dynamic_pointer_cast>( - KernelPool::Instance().template Get>(n)); - } else if (type == "relu") { - return std::dynamic_pointer_cast>( - KernelPool::Instance().template Get>(n)); - } else if (type == "tanh") { - return std::dynamic_pointer_cast>( - KernelPool::Instance().template Get>(n)); - } else if (type == "identity" || type == "") { - return std::dynamic_pointer_cast>( - KernelPool::Instance().template Get>(n)); - } - PADDLE_THROW("Not support type: %s", type); - }; - act_gate_3d_ = GetActKernel(act_gate, d * 3); - act_cand_d_ = GetActKernel(act_cand, d); - act_cell_d_ = GetActKernel(act_cell, d); + act_gate_d_ = GetActKernel(act_gate, d); + act_cand_d_ = GetActKernel(act_cand, d); + act_cell_d_ = GetActKernel(act_cell, d); vmul_d_ = KernelPool::Instance().template Get>(d); vadd_d_ = KernelPool::Instance().template Get>(d); + vadd_d2_ = KernelPool::Instance().template Get>(d2_); + act_gate_d2_ = GetActKernel(act_gate, d2_); } - void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, + void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, T* checked) const override { - // gates: W_ch, W_ih, W_fh, W_oh - act_gate_3d_->Compute(gates + d_, gates + d_); - - /* C_t = C_t-1 * fgated + cand_gated * igated */ + /* get fgated and igated*/ + vmul_d_->Compute(wp_data, ct_1, checked); + vmul_d_->Compute(wp_data + d_, ct_1, checked + d_); + vadd_d2_->Compute(checked, gates + d_, gates + d_); + act_gate_d2_->Compute(gates + d_, gates + d_); + /* C_t = C_t-1 * fgated + cand_gated * igated*/ act_cand_d_->Compute(gates, gates); vmul_d_->Compute(gates, gates + d_, gates + d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_); vadd_d_->Compute(gates + d_, gates + d2_, ct); + /* get ogated*/ + vmul_d_->Compute(wp_data + d2_, ct, gates + d_); + vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_); + act_gate_d_->Compute(gates + d3_, gates + d3_); + /* H_t = act_cell(C_t) * ogated */ + act_cell_d_->Compute(ct, gates + d2_); + vmul_d_->Compute(gates + d2_, gates + d3_, ht); + } + void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { + /* C_t = igated * cgated*/ + act_gate_d_->Compute(gates + d_, gates + d_); + act_cand_d_->Compute(gates, gates); + vmul_d_->Compute(gates, gates + d_, ct); + /* get outgated, put W_oc * C_t on igated */ + vmul_d_->Compute(wp_data + d2_, ct, gates + d_); + vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_); /* H_t = act_cell(C_t) * ogated */ + act_gate_d_->Compute(gates + d3_, gates + d3_); act_cell_d_->Compute(ct, gates + d2_); vmul_d_->Compute(gates + d2_, gates + d3_, ht); } private: int d_, d2_, d3_; - std::shared_ptr> act_gate_3d_, act_cand_d_, act_cell_d_; + std::shared_ptr> act_gate_d2_, act_gate_d_, act_cand_d_, + act_cell_d_; std::shared_ptr> vmul_d_; - std::shared_ptr> vadd_d_; + std::shared_ptr> vadd_d_, vadd_d2_; }; #define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \