diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
index 0d69dfa79aa26940f8f56f84b35ffed34f29f703..9512fd056e73836cdc34a9e409ab2d7809a6aff6 100644
--- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
@@ -11,6 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
 #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
 #include <string>
 #include "paddle/fluid/framework/lod_tensor.h"
diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc
index 582c75872ab2818cdf834f9a46278db1d6f91d54..916f84cb4a78c3721cb67bd3cf8d3759a8eaf1bf 100644
--- a/paddle/fluid/operators/fusion_gru_op.cc
+++ b/paddle/fluid/operators/fusion_gru_op.cc
@@ -30,14 +30,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
                  "Input(WeightX) of GRU should not be null.");
   PADDLE_ENFORCE(ctx->HasInput("WeightH"),
                  "Input(WeightH) of GRU should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
-                 "Output(ReorderedH0) of GRU should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
-                 "Output(BatchedInput) of GRU should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
-                 "Output(BatchedOut) of GRU should not be null.");
   PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                  "Output(Hidden) of GRU should not be null.");

@@ -80,15 +73,20 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
   }
   framework::DDim out_dims({x_dims[0], frame_size});
   ctx->SetOutputDim("Hidden", out_dims);
-  ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
-  ctx->SetOutputDim("BatchedOut", out_dims);
   ctx->ShareLoD("X", "Hidden");
-
   int xx_width;
   if (ctx->Attrs().Get<bool>("use_seq")) {
     xx_width = wx_dims[1];
   } else {
     xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+    PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
+                   "Output(ReorderedH0) of GRU should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
+                   "Output(BatchedInput) of GRU should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
+                   "Output(BatchedOut) of GRU should not be null.");
+    ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
+    ctx->SetOutputDim("BatchedOut", out_dims);
   }
   ctx->SetOutputDim("XX", {x_dims[0], xx_width});
   ctx->ShareLoD("X", "XX");
diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc
index 104e160e2d7069ec247cc51e927ce8824f1b69e8..ef23ab3f981786d33567619ad0194d21f31bdc8e 100644
--- a/paddle/fluid/operators/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fusion_lstm_op.cc
@@ -38,16 +38,6 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
                  "Output(Hidden) of LSTM should not be null.");
   PADDLE_ENFORCE(ctx->HasOutput("Cell"),
                  "Output(Cell) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
-                 "Output(BatchedInput) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
-                 "Output(BatchedHidden) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
-                 "Output(BatchedCell) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
-                 "Output(ReorderedH0) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
-                 "Output(ReorderedC0) of LSTM should not be null.");

   auto x_dims = ctx->GetInputDim("X");
   PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
@@ -88,28 +78,36 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
   PADDLE_ENFORCE_EQ(b_dims[0], 1,
                     "The first dimension of Input(Bias) should be 1.");
-
-  auto use_peepholes = ctx->Attrs().Get<bool>("use_peepholes");
-  PADDLE_ENFORCE_EQ(b_dims[1], (use_peepholes ? 7 : 4) * frame_size,
-                    "The second dimension of Input(Bias) should be "
-                    "7 * %d if enable peepholes connection or"
-                    "4 * %d if disable peepholes",
-                    frame_size, frame_size);
+  PADDLE_ENFORCE_EQ(
+      b_dims[1], (ctx->Attrs().Get<bool>("use_peepholes") ? 7 : 4) * frame_size,
+      "The second dimension of Input(Bias) should be "
+      "7 * %d if enable peepholes connection or"
+      "4 * %d if disable peepholes",
+      frame_size, frame_size);

   framework::DDim out_dims({x_dims[0], frame_size});
   ctx->SetOutputDim("Hidden", out_dims);
   ctx->SetOutputDim("Cell", out_dims);
-  ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
-  ctx->SetOutputDim("BatchedHidden", out_dims);
-  ctx->SetOutputDim("BatchedCell", out_dims);
   ctx->ShareLoD("X", "Hidden");
   ctx->ShareLoD("X", "Cell");
-
   int xx_width;
   if (ctx->Attrs().Get<bool>("use_seq")) {
     xx_width = wx_dims[1];
   } else {
     xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+    PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
+                   "Output(BatchedInput) of LSTM should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
+                   "Output(BatchedHidden) of LSTM should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
+                   "Output(BatchedCell) of LSTM should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
+                   "Output(ReorderedH0) of LSTM should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
+                   "Output(ReorderedC0) of LSTM should not be null.");
+    ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
+    ctx->SetOutputDim("BatchedHidden", out_dims);
+    ctx->SetOutputDim("BatchedCell", out_dims);
   }
   ctx->SetOutputDim("XX", {x_dims[0], xx_width});
   ctx->ShareLoD("X", "XX");
@@ -232,18 +230,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     act_cand = act_functor(act_cand_str);                   \
   }

-#define INIT_BASE_INPUT_OUTPUT                          \
-  auto* x = ctx.Input<LoDTensor>("X");                  \
-  auto* h0 = ctx.Input<Tensor>("H0");                   \
-  auto* c0 = ctx.Input<Tensor>("C0");                   \
-  auto* wx = ctx.Input<Tensor>("WeightX");              \
-  auto* wh = ctx.Input<Tensor>("WeightH");              \
-  auto* bias = ctx.Input<Tensor>("Bias");               \
-  auto* xx = ctx.Output<LoDTensor>("XX");               \
-  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");   \
-  auto* cell_out = ctx.Output<LoDTensor>("Cell");       \
-  bool use_peepholes = ctx.Attr<bool>("use_peepholes"); \
-  bool is_reverse = ctx.Attr<bool>("is_reverse");
+#define INIT_BASE_INPUT_OUTPUT                          \
+  auto* x = ctx.Input<LoDTensor>("X");                  \
+  auto* h0 = ctx.Input<Tensor>("H0");                   \
+  auto* c0 = ctx.Input<Tensor>("C0");                   \
+  auto* wx = ctx.Input<Tensor>("WeightX");              \
+  auto* wh = ctx.Input<Tensor>("WeightH");              \
+  auto* bias = ctx.Input<Tensor>("Bias");               \
+  auto* xx = ctx.Output<LoDTensor>("XX");               \
+  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");   \
+  auto* cell_out = ctx.Output<LoDTensor>("Cell");       \
+  bool is_reverse = ctx.Attr<bool>("is_reverse");       \
+  bool use_peepholes = ctx.Attr<bool>("use_peepholes");

 #define INIT_BASE_SIZES                  \
   auto x_dims = x->dims();  /* T x M*/   \
@@ -254,172 +252,183 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
   const int D3 = D * 3;                  \
   const int D4 = wh_dims[1];

+#define INIT_BASE_INPUT_DATAS                                        \
+  const T* x_data = x->data<T>();                                    \
+  const T* wx_data = wx->data<T>();                                  \
+  const T* wh_data = wh->data<T>();                                  \
+  /* diagonal weight*/                                               \
+  const T* wc_data = bias->data<T>() + D4;                           \
+  /* for peephole only*/                                             \
+  Tensor checked_cell;                                               \
+  T* checked_cell_data = nullptr;                                    \
+  auto place = ctx.GetPlace();                                       \
+  if (use_peepholes) {                                               \
+    /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/                  \
+    checked_cell_data = checked_cell.mutable_data<T>({2, D}, place); \
+  }
+
+/// Compute LSTM
+#define GEMM_WH_ADDON(bs, prev, out)                                           \
+  blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
+            wh_data, D4, static_cast<T>(1), out, D4)
+
+// gates: W_ch, W_ih, W_fh, W_oh
+#define GET_Ct(ct_1, gates, ct)                   \
+  /* C_t = C_t-1 * fgated + cand_gated * igated*/ \
+  act_cand(D, gates, gates);                      \
+  blas.VMUL(D, gates, gates + D, gates + D);      \
+  blas.VMUL(D, ct_1, gates + D2, gates + D2);     \
+  blas.VADD(D, gates + D, gates + D2, ct)
+
+#define GET_Ht(ct, gates, ht)        \
+  /* H_t = act_cell(C_t) * ogated */ \
+  act_cell(D, ct, gates + D2);       \
+  blas.VMUL(D, gates + D2, gates + D3, ht)
+
+#define GET_Ct_NOH0C0(gates, ct)     \
+  /* C_t = igated * cgated*/         \
+  act_gate(D, gates + D, gates + D); \
+  act_cand(D, gates, gates);         \
+  blas.VMUL(D, gates, gates + D, ct)
+
+#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \
+  GET_Ct_NOH0C0(gates, ct);                \
+  act_gate(D, gates + D3, gates + D3);     \
+  GET_Ht(ct, gates, ht)
+
+#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \
+  GET_Ct_NOH0C0(gates, ct);                         \
+  /* get outgated, put W_oc * C_t on igated */      \
+  blas.VMUL(D, wc_data + D2, ct, gates + D);        \
+  blas.VADD(D, gates + D, gates + D3, gates + D3);  \
+  act_gate(D, gates + D3, gates + D3);              \
+  GET_Ht(ct, gates, ht)
+
+#define COMPUTE_CtHt(gates, ct_1, ct, ht) \
+  act_gate(D3, gates + D, gates + D);     \
+  GET_Ct(ct_1, gates, ct);                \
+  GET_Ht(ct, gates, ht)
+
+#define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht)        \
+  /* get fgated and igated*/                              \
+  blas.VMUL(D, wc_data, ct_1, checked_cell_data);         \
+  blas.VMUL(D, wc_data + D, ct_1, checked_cell_data + D); \
+  blas.VADD(D2, checked_cell_data, gates + D, gates + D); \
+  act_gate(D2, gates + D, gates + D);                     \
+  GET_Ct(ct_1, gates, ct);                                \
+  /* get ogated*/                                         \
+  blas.VMUL(D, wc_data + D2, ct, gates + D);              \
+  blas.VADD(D, gates + D, gates + D3, gates + D3);        \
+  act_gate(D, gates + D3, gates + D3);                    \
+  GET_Ht(ct, gates, ht)
+
   void SeqCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = paddle::platform::CPUDeviceContext;
     INIT_BASE_INPUT_OUTPUT
     INIT_BASE_SIZES
     INIT_VEC_FUNC
+    INIT_BASE_INPUT_DATAS

     auto x_lod = x->lod();
     const int total_T = x_dims[0];
-    const int N = x_lod[0].size() - 1;  // batch size
-
-    const T* x_data = x->data<T>();
+    const int N = x_lod[0].size() - 1;
     const T* h0_data = h0 ? h0->data<T>() : nullptr;
     const T* c0_data = c0 ? c0->data<T>() : nullptr;
-    const T* bias_data = bias->data<T>();
-    const T* wc_data = bias_data + D4;  // w_ic, w_fc, w_oc
-    const T* wx_data = wx->data<T>();
-    const T* wh_data = wh->data<T>();
-
-    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
-    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
-    T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
-
-    // use local variable
-    framework::DDim check_dims({3, D});
-    Tensor checked_cell;  // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
-    auto checked_cell_data =
-        checked_cell.mutable_data<T>(check_dims, ctx.GetPlace());
-
+    T* xx_data = xx->mutable_data<T>(place);
+    T* h_out_data = hidden_out->mutable_data<T>(place);
+    T* c_out_data = cell_out->mutable_data<T>(place);
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
     math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
                                       xx_data, bias->data<T>());
+
     int xx_offset = D4;
     int gate_offset = D;
     if (is_reverse) {
       const int offset = (total_T - 1) * D;
       xx_data = xx_data + offset * 4;
-      hidden_out_data = hidden_out_data + offset;
-      cell_out_data = cell_out_data + offset;
+      h_out_data = h_out_data + offset;
+      c_out_data = c_out_data + offset;
       xx_offset = -D4;
       gate_offset = -D;
     }

-    auto move_step = [&]() {
-      xx_data = xx_data + xx_offset;
-      hidden_out_data = hidden_out_data + gate_offset;
-      cell_out_data = cell_out_data + gate_offset;
-    };
-
-    for (int i = 0; i < N; ++i) {
-      int bid = is_reverse ? N - 1 - i : i;
-      int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
-      const T* prev_c_data = nullptr;
-      const T* prev_h_data = nullptr;
-
-      int tstart = 0;
-      if (h0_data) {
-        prev_h_data = h0_data + bid * D;
-        prev_c_data = c0_data + bid * D;
-      } else {
-        // If step == 0 and there is no initialized hidden state, that is to say
-        // the H0 is zeros. Then W_h * H_t-1 can be skipped
-
-        // ~C_t
-        act_cand(D, xx_data, xx_data);
-        if (use_peepholes) {
-          // I_t, F_t
-          act_gate(D2, xx_data + D, xx_data + D);
-        } else {
-          // I_t, F_t, O_t
-          act_gate(D3, xx_data + D, xx_data + D);
-        }
-        // C_t = I_t * ~C_t
-        blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
-
-        if (use_peepholes) {
-          // + W_oc * C_t for peephole connection
-          blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2);
-          blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3);
-          // O_t
-          act_gate(D, xx_data + D3, xx_data + D3);
-        }
-
-        // hidden out= act_state(cellout) * outgate
-        act_cell(D, cell_out_data, xx_data + D2);
-        // H_t = O_t * act_state(C_t)
-        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
-
-        // prev
-        prev_h_data = hidden_out_data;
-        prev_c_data = cell_out_data;
-
-        tstart = 1;
-        move_step();
-      }
-
-      for (int step = tstart; step < seq_len; ++step) {
-        // + W_h * H_t-1
-        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
-                  prev_h_data, D, wh_data, D4, static_cast<T>(1), xx_data, D4);
+#define MOVE_ONE_STEP                    \
+  prev_h_data = h_out_data;              \
+  prev_c_data = c_out_data;              \
+  xx_data = xx_data + xx_offset;         \
+  h_out_data = h_out_data + gate_offset; \
+  c_out_data = c_out_data + gate_offset
+
+#define PROCESS_H0C0_DEFINES                       \
+  int bid = is_reverse ? N - 1 - i : i;            \
+  int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \
+  const T* prev_c_data = nullptr;                  \
+  const T* prev_h_data = nullptr;                  \
+  int tstart = 0
+
+#define PROCESS_H0C0_PEEPHOLE                                      \
+  PROCESS_H0C0_DEFINES;                                            \
+  if (h0_data) {                                                   \
+    prev_h_data = h0_data + bid * D;                               \
+    prev_c_data = c0_data + bid * D;                               \
+  } else {                                                         \
+    COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \
+    MOVE_ONE_STEP;                                                 \
+    tstart = 1;                                                    \
+  }

-        // ~C_t
-        act_cand(D, xx_data, xx_data);
+#define PROCESS_H0C0                                      \
+  PROCESS_H0C0_DEFINES;                                   \
+  if (h0_data) {                                          \
+    prev_h_data = h0_data + bid * D;                      \
+    prev_c_data = c0_data + bid * D;                      \
+  } else {                                                \
+    COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \
+    MOVE_ONE_STEP;                                        \
+    tstart = 1;                                           \
+  }

-        if (use_peepholes) {
-          // + W_ic|W_fc * C_t-1 for peephole connection
-          blas.VMUL(D, wc_data, prev_c_data, checked_cell_data);
-          blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D);
-          blas.VADD(D2, xx_data + D, checked_cell_data, xx_data + D);
-          // I_t, F_t
-          act_gate(D2, xx_data + D, xx_data + D);
-        } else {
-          // I_t, F_t, O_t
-          act_gate(D3, xx_data + D, xx_data + D);
+    if (use_peepholes) {
+      for (int i = 0; i < N; ++i) {
+        PROCESS_H0C0_PEEPHOLE
+        for (int step = tstart; step < seq_len; ++step) {
+          GEMM_WH_ADDON(1, prev_h_data, xx_data);
+          COMPUTE_CtHt_PEEPHOLE(xx_data, prev_c_data, c_out_data, h_out_data);
+          MOVE_ONE_STEP;
         }
-
-        // F_t * C_t-1
-        blas.VMUL(D, xx_data + D2, prev_c_data, xx_data + D2);
-        // I_t * ~C_t
-        blas.VMUL(D, xx_data, xx_data + D, xx_data + D);
-        // C_t = F_t * C_t-1 + I_t * ~C_t
-        blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);
-
-        if (use_peepholes) {
-          // + W_oc * C_t for peephole connection
-          blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2);
-          blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3);
-          // O_t
-          act_gate(D, xx_data + D3, xx_data + D3);
+      }
+    } else {
+      for (int i = 0; i < N; ++i) {
+        PROCESS_H0C0
+        for (int step = tstart; step < seq_len; ++step) {
+          GEMM_WH_ADDON(1, prev_h_data, xx_data);
+          COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data);
+          MOVE_ONE_STEP;
         }
-
-        // hidden out= act_state(cellout) * outgate
-        act_cell(D, cell_out_data, xx_data + D2);
-        // H_t = O_t * act_state(C_t)
-        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
-
-        // prev
-        prev_h_data = hidden_out_data;
-        prev_c_data = cell_out_data;
-
-        move_step();
-      }  // for each step in batch
-    }    // for each batch
+      }
+    }
+#undef PROCESS_H0C0_DEFINES
+#undef PROCESS_H0C0_PEEPHOLE
+#undef PROCESS_H0C0
+#undef MOVE_ONE_STEP
   }

   void BatchCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = platform::CPUDeviceContext;
     INIT_BASE_INPUT_OUTPUT
-    if (x->lod()[0].size() == 2) {  // batch size == 1
+    if (x->lod()[0].size() == 2) {
       SeqCompute(ctx);
       return;
     }
     INIT_BASE_SIZES
     INIT_VEC_FUNC
+    INIT_BASE_INPUT_DATAS

     auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
     auto* reordered_c0 = ctx.Output<Tensor>("ReorderedC0");
     auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
     auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell");
     auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden");
-
-    const T* x_data = x->data<T>();
-    const T* wx_data = wx->data<T>();
-    const T* wh_data = wh->data<T>();
-    const T* bias_data = bias->data<T>();
-    const T* wc_data = bias_data + D4;  // w_ic, w_fc, w_oc
-    auto place = ctx.GetPlace();
     T* xx_data = xx->mutable_data<T>(place);
     T* batched_input_data = batched_input->mutable_data<T>(place);
     T* batched_c_out_data = batched_c_out->mutable_data<T>(place);
@@ -427,12 +436,6 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     hidden_out->mutable_data<T>(place);
     cell_out->mutable_data<T>(place);

-    // use local variable
-    framework::DDim check_dims({3, D});
-    Tensor checked_cell;  // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
-    auto checked_cell_data =
-        checked_cell.mutable_data<T>(check_dims, ctx.GetPlace());
-
     math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
@@ -454,27 +457,17 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     reordered_h0->Resize({max_bs, D});
     reordered_c0->Resize({max_bs, D});

-    T* prev_batch_h_data = nullptr;
-    T* prev_batch_c_data = nullptr;
-    T* cur_batch_in_data = batched_input_data;
-    T* cur_batch_h_out_data = batched_h_out_data;
-    T* cur_batch_c_out_data = batched_c_out_data;
-
-    auto move_step = [&](int bs) {
-      cur_batch_in_data += bs * D4;
-      cur_batch_c_out_data += bs * D;
-      cur_batch_h_out_data += bs * D;
-    };
-
     int tstart = 0;
+    T* prev_h_data = nullptr;
+    T* prev_c_data = nullptr;
     if (h0) {
       // reorder h0, c0
       T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
       T* reordered_c0_data = reordered_c0->mutable_data<T>(place);
       const T* h0_data = h0->data<T>();
       const T* c0_data = c0->data<T>();
-      prev_batch_h_data = reordered_h0_data;
-      prev_batch_c_data = reordered_c0_data;
+      prev_h_data = reordered_h0_data;
+      prev_c_data = reordered_c0_data;
       size_t sz = sizeof(T) * D;
       for (int i = 0; i < max_bs; ++i) {
         std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz);
@@ -483,123 +476,80 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
         reordered_c0_data += D;
       }
     } else {
-      // Compute with no H0/C0
-      T* cur_in_data = cur_batch_in_data;
-      T* cur_c_out_data = cur_batch_c_out_data;
-      T* cur_h_out_data = cur_batch_h_out_data;
-
-      // If step == 0 and there is no initialized hidden state, that is to say
-      // the H0 is zeros. Then W_h * H_t-1 can be skiped
-
-      for (int i = 0; i < max_bs; ++i) {  // iterate each data in 1st batch
-        // ~C_t
-        act_cand(D, cur_in_data, cur_in_data);
-
-        if (use_peepholes) {
-          // I_t, F_t
-          act_gate(D2, cur_in_data + D, cur_in_data + D);
-        } else {
-          // I_t, F_t, O_t
-          act_gate(D3, cur_in_data + D, cur_in_data + D);
-        }
-
-        // C_t = I_t * ~C_t
-        blas.VMUL(D, cur_in_data, cur_in_data + D, cur_c_out_data);
-
+      // compute without h0, c0
+      T* cur_in_data = batched_input_data;
+      T* cur_h_out_data = batched_h_out_data;
+      T* cur_c_out_data = batched_c_out_data;
+      for (int i = 0; i < max_bs; ++i) {
+        GET_Ct_NOH0C0(cur_in_data, cur_c_out_data);
         if (use_peepholes) {
-          // + W_oc * C_t for peephole connection
-          blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2);
-          blas.VADD(D, cur_in_data + D3, checked_cell_data + D2,
-                    cur_in_data + D3);
-          // O_t
-          act_gate(D, cur_in_data + D3, cur_in_data + D3);
+          blas.VMUL(D, wc_data + D2, cur_c_out_data, cur_in_data + D);
+          blas.VADD(D, cur_in_data + D, cur_in_data + D3, cur_in_data + D3);
         }
-
-        // hidden out= act_state(cellout) * outgate
-        act_cell(D, cur_c_out_data, cur_in_data + D2);
-        // H_t = O_t * act_state(C_t)
-        blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);
-
-        // move to next data in the same batch
+        act_gate(D, cur_in_data + D3, cur_in_data + D3);
+        GET_Ht(cur_c_out_data, cur_in_data, cur_h_out_data);
         cur_in_data += D4;
         cur_c_out_data += D;
         cur_h_out_data += D;
       }
-
-      // move to data for next timestep
-      prev_batch_h_data = cur_batch_h_out_data;
-      prev_batch_c_data = cur_batch_c_out_data;
-      move_step(max_bs);
       tstart = 1;
+      prev_h_data = batched_h_out_data;
+      prev_c_data = batched_c_out_data;
     }
-
     const auto& batch_starts = batched_lod[0];
     const int max_seq_len = batch_starts.size() - 1;
-    for (int step = tstart; step < max_seq_len; ++step) {
-      const int cur_bs = batch_starts[step + 1] - batch_starts[step];
-      // + W_h * H_t-1
-      blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D4, D, static_cast<T>(1),
-                prev_batch_h_data, D, wh_data, D4, static_cast<T>(1),
-                cur_batch_in_data, D4);
-
-      T* cur_in_data = cur_batch_in_data;
-      T* cur_c_out_data = cur_batch_c_out_data;
-      T* cur_h_out_data = cur_batch_h_out_data;
-      T* prev_c_data = prev_batch_c_data;  // NULL if no C0 in step0
-      T* prev_h_data = prev_batch_h_data;  // NULL if no H0 in step0
-      auto next_data_in_batch = [&]() {
-        cur_in_data += D4;
-        cur_c_out_data += D;
-        cur_h_out_data += D;
-        prev_c_data = prev_c_data ? prev_c_data + D : nullptr;
-        prev_h_data = prev_h_data ? prev_h_data + D : nullptr;
-      };
-
-      for (int i = 0; i < cur_bs; ++i) {  // iterate each data in same batch
-        // ~C_t
-        act_cand(D, cur_in_data, cur_in_data);
-
-        if (use_peepholes) {
-          // + W_ic|W_fc * C_t-1 for peephole connection
-          blas.VMUL(D, wc_data, prev_c_data, checked_cell_data);
-          blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D);
-          blas.VADD(D2, cur_in_data + D, checked_cell_data, cur_in_data + D);
-          // I_t, F_t
-          act_gate(D2, cur_in_data + D, cur_in_data + D);
-        } else {
-          // I_t, F_t, O_t
-          act_gate(D3, cur_in_data + D, cur_in_data + D);
+    const int offset = tstart * max_bs * D;
+    batched_input_data = batched_input_data + offset * 4;
+    batched_h_out_data = batched_h_out_data + offset;
+    batched_c_out_data = batched_c_out_data + offset;
+
+#define DEFINE_CUR                        \
+  T* cur_in_data = batched_input_data;    \
+  T* cur_prev_c_data = prev_c_data;       \
+  T* cur_c_out_data = batched_c_out_data; \
+  T* cur_h_out_data = batched_h_out_data
+
+#define MOVE_ONE_BATCH  \
+  cur_in_data += D4;    \
+  cur_prev_c_data += D; \
+  cur_c_out_data += D;  \
+  cur_h_out_data += D
+
+#define MOVE_ONE_STEP                   \
+  prev_c_data = batched_c_out_data;     \
+  prev_h_data = batched_h_out_data;     \
+  batched_c_out_data = cur_c_out_data;  \
+  batched_h_out_data = cur_h_out_data;  \
+  batched_input_data = cur_in_data
+
+    if (use_peepholes) {
+      for (int step = tstart; step < max_seq_len; ++step) {
+        const int cur_bs = batch_starts[step + 1] - batch_starts[step];
+        GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
+        DEFINE_CUR;
+        for (int i = 0; i < cur_bs; ++i) {
+          COMPUTE_CtHt_PEEPHOLE(cur_in_data, cur_prev_c_data, cur_c_out_data,
+                                cur_h_out_data);
+          MOVE_ONE_BATCH;
         }
-
-        // F_t * C_t-1
-        blas.VMUL(D, cur_in_data + D2, prev_c_data, cur_in_data + D2);
-        // I_t * ~C_t
-        blas.VMUL(D, cur_in_data, cur_in_data + D, cur_in_data + D);
-        // C_t = F_t * C_t-1 + I_t * ~C_t
-        blas.VADD(D, cur_in_data + D, cur_in_data + D2, cur_c_out_data);
-
-        if (use_peepholes) {
-          // + W_oc * C_t for peephole connection
-          blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2);
-          blas.VADD(D, cur_in_data + D3, checked_cell_data + D2,
-                    cur_in_data + D3);
-          // O_t
-          act_gate(D, cur_in_data + D3, cur_in_data + D3);
+        MOVE_ONE_STEP;
+      }
+    } else {
+      for (int step = tstart; step < max_seq_len; ++step) {
+        const int cur_bs = batch_starts[step + 1] - batch_starts[step];
+        GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
+        DEFINE_CUR;
+        for (int i = 0; i < cur_bs; ++i) {
+          COMPUTE_CtHt(cur_in_data, cur_prev_c_data, cur_c_out_data,
+                       cur_h_out_data);
+          MOVE_ONE_BATCH;
         }
-
-        // hidden out= act_state(cellout) * outgate
-        act_cell(D, cur_c_out_data, cur_in_data + D2);
-        // H_t = O_t * act_state(C_t)
-        blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);
-
-        // move to next data in same batch
-        next_data_in_batch();
+        MOVE_ONE_STEP;
       }
-      // move to data for next timestep
-      prev_batch_h_data = cur_batch_h_out_data;
-      prev_batch_c_data = cur_batch_c_out_data;
-      move_step(cur_bs);
     }
+#undef MOVE_ONE_STEP
+#undef MOVE_ONE_BATCH
+#undef DEFINE_CUR

     math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
     batched_h_out->set_lod(batched_lod);
@@ -615,6 +565,16 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       BatchCompute(ctx);
     }
   }
+
+#undef COMPUTE_CtHt_PEEPHOLE
+#undef COMPUTE_CtHt
+#undef GET_Ct_NOH0C0
+#undef COMPUTE_CtHt_NOH0C0
+#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
+#undef GET_Ht
+#undef GET_Ct
+#undef GEMM_WH_ADDON
+#undef INIT_BASE_INPUT_DATAS
 #undef INIT_BASE_SIZES
 #undef INIT_BASE_INPUT_OUTPUT
 #undef INIT_VEC_FUNC
diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
index 4767e9433ea74d5da83867d646f2a63c9a092668..de0c86f96db958eebd7e74346bec244f0c804ed9 100644
--- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
@@ -53,12 +53,11 @@ class TestFusionLSTMOp(OpTest):
         self.M = 8
         self.D = 16
         self.has_initial_state = False
+        self.use_peepholes = False
         self.is_reverse = False
         self.act_gate = 'sigmoid'
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
-        self.use_peepholes = False
-        self.use_seq = False
         self.set_conf()

         T = sum(self.lod[0])
@@ -108,7 +107,6 @@ class TestFusionLSTMOp(OpTest):
         }
         self.attrs = {
             'use_peepholes': self.use_peepholes,
-            'use_seq': self.use_seq,
             'is_reverse': self.is_reverse,
             'gate_activation': self.act_gate,
             'cell_activation': self.act_cell,
@@ -178,50 +176,18 @@ class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp):
         self.is_reverse = True


-class TestFusionLSTMOpPoopholesBS1(TestFusionLSTMOp):
+class TestFusionLSTMOpPeepholesInitReverse(TestFusionLSTMOp):
     def set_conf(self):
         self.use_peepholes = True
-        self.lod = [[3]]
-        self.D = 16
-
-
-class TestFusionLSTMOpSeqInit(TestFusionLSTMOp):
-    def set_conf(self):
-        self.use_seq = True
-        self.has_initial_state = True
-
-
-class TestFusionLSTMOpSeqReverse(TestFusionLSTMOp):
-    def set_conf(self):
-        self.use_seq = True
-        self.is_reverse = True
-
-
-class TestFusionLSTMOpSeqInitReverse(TestFusionLSTMOp):
-    def set_conf(self):
-        self.use_seq = True
         self.has_initial_state = True
         self.is_reverse = True


-class TestFusionLSTMOpSeqPeepholes(TestFusionLSTMOp):
+class TestFusionLSTMOpPeepholesBS1(TestFusionLSTMOp):
     def set_conf(self):
-        self.use_seq = True
         self.use_peepholes = True
-
-
-class TestFusionLSTMOpSeqPeepholesInit(TestFusionLSTMOp):
-    def set_conf(self):
-        self.use_seq = True
-        self.use_peepholes = True
-        self.has_initial_state = True
-
-
-class TestFusionLSTMOpSeqPeepholesReverse(TestFusionLSTMOp):
-    def set_conf(self):
-        self.use_seq = True
-        self.use_peepholes = True
-        self.is_reverse = True
+        self.lod = [[2]]
+        self.D = 8


 if __name__ == '__main__':