diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc
index d67029a392e5d973424fd9322b9000ea5f34ac3f..bdd03caa3bc62d6b06ac9298002caf47ea8f87b7 100644
--- a/paddle/fluid/operators/fusion_gru_op.cc
+++ b/paddle/fluid/operators/fusion_gru_op.cc
@@ -21,6 +21,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
+DEFINE_bool(gru_use_seq, true, "Use sequence mode");
+
 namespace paddle {
 namespace operators {
 
@@ -84,7 +86,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
   ctx->SetOutputDim("BatchedOut", out_dims);
   ctx->ShareLoD("X", "Hidden");
 
-  int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+  int xx_width;
+  if (FLAGS_gru_use_seq) {
+    xx_width = wx_dims[1];
+  } else {
+    xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+  }
   ctx->SetOutputDim("XX", {x_dims[0], xx_width});
   ctx->ShareLoD("X", "XX");
 }
@@ -157,6 +164,122 @@ template <typename T>
 class FusionGRUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    if (FLAGS_gru_use_seq) {
+      SeqCompute(ctx);
+    } else {
+      BatchCompute(ctx);
+    }
+  }
+
+#define INIT_VEC_FUNC                                                        \
+  std::function<void(const int, const T*, T*)> act_gate, act_state;         \
+  std::function<void(const int, const T*, const T*, const T*, T*)> cross;   \
+  auto& act_gate_str = ctx.Attr<std::string>("gate_activation");            \
+  auto& act_state_str = ctx.Attr<std::string>("activation");                \
+  if (platform::jit::MayIUse(platform::jit::avx)) {                         \
+    math::VecActivations<T, platform::jit::avx> act_functor;                \
+    act_gate = act_functor(act_gate_str);                                   \
+    act_state = act_functor(act_state_str);                                 \
+    cross = math::vec_cross<T, platform::jit::avx>;                         \
+  } else {                                                                  \
+    math::VecActivations<T, platform::jit::isa_any> act_functor;            \
+    act_gate = act_functor(act_gate_str);                                   \
+    act_state = act_functor(act_state_str);                                 \
+    cross = math::vec_cross<T, platform::jit::isa_any>;                     \
+  }
+
+  void SeqCompute(const framework::ExecutionContext& ctx) const {
+    using DeviceContext = paddle::platform::CPUDeviceContext;
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* h0 = ctx.Input<Tensor>("H0");
+    auto* wx = ctx.Input<Tensor>("WeightX");
+    auto* wh = ctx.Input<Tensor>("WeightH");
+    auto* bias = ctx.Input<Tensor>("Bias");
+
+    auto* xx = ctx.Output<LoDTensor>("XX");
+    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
+    bool is_reverse = ctx.Attr<bool>("is_reverse");
+    INIT_VEC_FUNC
+
+    auto x_lod = x->lod();
+    auto x_dims = x->dims();    // T x M
+    auto wh_dims = wh->dims();  // D x 3D
+    const int N = x_lod[0].size() - 1;
+    const int total_T = x_dims[0];
+    const int M = x_dims[1];
+    const int D3 = wh_dims[1];
+    const int D = wh_dims[0];
+    const int D2 = D * 2;
+
+    const T* x_data = x->data<T>();
+    const T* h0_data = h0 ? h0->data<T>() : NULL;
+    const T* wx_data = wx->data<T>();
+    const T* wh_data = wh->data<T>();
+    const T* wh_state_data = wh_data + D * D2;
+    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
+    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
+
+    auto blas = math::GetBlas<DeviceContext, T>(ctx);
+    math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
+                                      xx_data, bias ? bias->data<T>() : NULL);
+
+    int xx_offset = D3;
+    int gate_offset = D;
+    if (is_reverse) {
+      const int offset = (total_T - 1) * D;
+      xx_data = xx_data + offset * 3;
+      hidden_out_data = hidden_out_data + offset;
+      xx_offset = -D3;
+      gate_offset = -D;
+    }
+    auto move_step = [&]() {
+      xx_data = xx_data + xx_offset;
+      hidden_out_data = hidden_out_data + gate_offset;
+    };
+    for (int i = 0; i < N; ++i) {
+      int bid = is_reverse ? N - 1 - i : i;
+      int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
+      const T* prev_hidden_data = NULL;
+      int tstart = 0;
+      if (h0_data) {
+        prev_hidden_data = h0_data + bid * D;
+      } else {
+        // W: {W_update, W_reset; W_state}
+        // update gate
+        act_gate(D, xx_data, xx_data);
+        // state gate
+        act_state(D, xx_data + D2, xx_data + D2);
+        // out = a*b
+        blas.VMUL(D, xx_data, xx_data + D2, hidden_out_data);
+        // save prev
+        prev_hidden_data = hidden_out_data;
+        tstart = 1;
+        move_step();
+      }
+      for (int step = tstart; step < seq_len; ++step) {
+        // gemm prev * (Wu + Wr)
+        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
+                  prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
+                  D3);
+        act_gate(D2, xx_data, xx_data);
+        // rt = rt*ht_1 inplace result
+        blas.VMUL(D, prev_hidden_data, xx_data + D, hidden_out_data);
+
+        // gemm rt * Ws
+        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
+                  hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
+                  xx_data + D2, D3);
+        act_state(D, xx_data + D2, xx_data + D2);
+        // out = zt*ht~ + (1-zt)*ht_1
+        cross(D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data);
+        // save prev
+        prev_hidden_data = hidden_out_data;
+        move_step();
+      }
+    }
+  }
+
+  void BatchCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = paddle::platform::CPUDeviceContext;
     auto* x = ctx.Input<LoDTensor>("X");
     auto* wx = ctx.Input<Tensor>("WeightX");
@@ -171,21 +294,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
     bool is_reverse = ctx.Attr<bool>("is_reverse");
 
-    std::function<void(const int, const T*, T*)> act_gate, act_state;
-    std::function<void(const int, const T*, const T*, const T*, T*)> cross;
-    auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
-    auto& act_state_str = ctx.Attr<std::string>("activation");
-    if (platform::jit::MayIUse(platform::jit::avx)) {
-      math::VecActivations<T, platform::jit::avx> act_functor;
-      act_gate = act_functor(act_gate_str);
-      act_state = act_functor(act_state_str);
-      cross = math::vec_cross<T, platform::jit::avx>;
-    } else {
-      math::VecActivations<T, platform::jit::isa_any> act_functor;
-      act_gate = act_functor(act_gate_str);
-      act_state = act_functor(act_state_str);
-      cross = math::vec_cross<T, platform::jit::isa_any>;
-    }
+    INIT_VEC_FUNC
 
     const T* x_data = x->data<T>();
     const T* wx_data = wx->data<T>();
@@ -305,6 +414,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     batched_out->set_lod(batched_lod);
    to_seq(dev_ctx, *batched_out, hidden_out);
   }
+#undef INIT_VEC_FUNC
 };
 
 }  // namespace operators
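Note on the new runtime switch: `gru_use_seq` is a gflags flag, so the new sequence path is the default while the original batched path stays available as a fallback. Per the inline comments, each SeqCompute step applies the update/reset gates from `xx`, computes the candidate state from `r_t * h_{t-1}`, and blends with `h_t = z_t * h~_t + (1 - z_t) * h_{t-1}`. The sketch below shows one way the flag could be flipped from a test or another translation unit; it is illustrative only and assumes gflags is linked the same way as for Paddle's other `DEFINE_bool` flags (the helper name is made up).

    // Illustrative sketch: selecting the FusionGRU kernel path at runtime.
    #include <gflags/gflags.h>

    DECLARE_bool(gru_use_seq);  // defined in fusion_gru_op.cc via DEFINE_bool

    // Hypothetical helper: force the original batched kernel, e.g. to
    // compare outputs or benchmark it against the new SeqCompute path.
    void UseBatchedGRUPath() { FLAGS_gru_use_seq = false; }

The same switch can also be set on the command line of any binary that parses gflags, e.g. --gru_use_seq=false.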