diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index bdd03caa3bc62d6b06ac9298002caf47ea8f87b7..582c75872ab2818cdf834f9a46278db1d6f91d54 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -21,8 +21,6 @@ limitations under the License. */ #include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/platform/cpu_info.h" -DEFINE_bool(gru_use_seq, true, "Use sequence mode"); - namespace paddle { namespace operators { @@ -87,7 +85,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ctx->ShareLoD("X", "Hidden"); int xx_width; - if (FLAGS_gru_use_seq) { + if (ctx->Attrs().Get("use_seq")) { xx_width = wx_dims[1]; } else { xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; @@ -136,7 +134,10 @@ void FusionGRUOpMaker::Make() { " where T is the total time steps in this mini-batch," " D is the hidden size, M is the dim size of x input.") .AsIntermediate(); - AddOutput("BatchedInput", "(LoDTensor) (T x 3D)").AsIntermediate(); + AddOutput("BatchedInput", + "(LoDTensor) This is the batched result of input X" + "or the batched result after fc, shape (T x 3D)") + .AsIntermediate(); AddOutput("BatchedOut", "(LoDTensor) (T X D) save batched hidden.") .AsIntermediate(); AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp"); @@ -153,6 +154,10 @@ void FusionGRUOpMaker::Make() { "(bool, defalut: False) " "whether to compute reversed GRU.") .SetDefault(false); + AddAttr("use_seq", + "(bool, defalut: True) " + "whether to use seq mode to compute GRU.") + .SetDefault(true); AddComment(R"DOC( The Fusion complete GRU Operator. This operator fuse the fully-connected operator into GRU, @@ -164,7 +169,7 @@ template class FusionGRUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - if (FLAGS_gru_use_seq) { + if (ctx.Attr("use_seq")) { SeqCompute(ctx); } else { BatchCompute(ctx); @@ -188,31 +193,35 @@ class FusionGRUKernel : public framework::OpKernel { cross = math::vec_cross; \ } +#define INIT_BASE_INPUT_OUTPUT \ + auto* h0 = ctx.Input("H0"); \ + auto* wx = ctx.Input("WeightX"); \ + auto* wh = ctx.Input("WeightH"); \ + auto* bias = ctx.Input("Bias"); \ + auto* xx = ctx.Output("XX"); \ + auto* hidden_out = ctx.Output("Hidden"); \ + bool is_reverse = ctx.Attr("is_reverse"); + +#define INIT_BASE_SIZES \ + auto x_dims = x->dims(); /* T x M*/ \ + auto wh_dims = wh->dims(); /* D x 3D*/ \ + const int total_T = x_dims[0]; \ + const int M = x_dims[1]; \ + const int D = wh_dims[0]; \ + const int D3 = wh_dims[1]; \ + const int D2 = D * 2; + void SeqCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = paddle::platform::CPUDeviceContext; auto* x = ctx.Input("X"); - auto* h0 = ctx.Input("H0"); - auto* wx = ctx.Input("WeightX"); - auto* wh = ctx.Input("WeightH"); - auto* bias = ctx.Input("Bias"); - - auto* xx = ctx.Output("XX"); - auto* hidden_out = ctx.Output("Hidden"); - bool is_reverse = ctx.Attr("is_reverse"); + INIT_BASE_INPUT_OUTPUT + INIT_BASE_SIZES INIT_VEC_FUNC auto x_lod = x->lod(); - auto x_dims = x->dims(); // T x M - auto wh_dims = wh->dims(); // D x 3D const int N = x_lod[0].size() - 1; - const int total_T = x_dims[0]; - const int M = x_dims[1]; - const int D3 = wh_dims[1]; - const int D = wh_dims[0]; - const int D2 = D * 2; - const T* x_data = x->data(); - const T* h0_data = h0 ? h0->data() : NULL; + const T* h0_data = h0 ? h0->data() : nullptr; const T* wx_data = wx->data(); const T* wh_data = wh->data(); const T* wh_state_data = wh_data + D * D2; @@ -221,7 +230,8 @@ class FusionGRUKernel : public framework::OpKernel { auto blas = math::GetBlas(ctx); math::FCCompute(blas, total_T, D3, M, x_data, wx_data, - xx_data, bias ? bias->data() : NULL); + xx_data, + bias ? bias->data() : nullptr); int xx_offset = D3; int gate_offset = D; @@ -239,7 +249,7 @@ class FusionGRUKernel : public framework::OpKernel { for (int i = 0; i < N; ++i) { int bid = is_reverse ? N - 1 - i : i; int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; - const T* prev_hidden_data = NULL; + const T* prev_hidden_data = nullptr; int tstart = 0; if (h0_data) { prev_hidden_data = h0_data + bid * D; @@ -282,19 +292,17 @@ class FusionGRUKernel : public framework::OpKernel { void BatchCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = paddle::platform::CPUDeviceContext; auto* x = ctx.Input("X"); - auto* wx = ctx.Input("WeightX"); - auto* wh = ctx.Input("WeightH"); - auto* bias = ctx.Input("Bias"); - auto* h0 = ctx.Input("H0"); + if (x->lod()[0].size() == 2) { + SeqCompute(ctx); + return; + } + INIT_BASE_INPUT_OUTPUT + INIT_BASE_SIZES + INIT_VEC_FUNC auto* reordered_h0 = ctx.Output("ReorderedH0"); - auto* xx = ctx.Output("XX"); auto* batched_input = ctx.Output("BatchedInput"); auto* batched_out = ctx.Output("BatchedOut"); - auto* hidden_out = ctx.Output("Hidden"); - - bool is_reverse = ctx.Attr("is_reverse"); - INIT_VEC_FUNC const T* x_data = x->data(); const T* wx_data = wx->data(); @@ -304,25 +312,20 @@ class FusionGRUKernel : public framework::OpKernel { T* batched_out_data = batched_out->mutable_data(ctx.GetPlace()); hidden_out->mutable_data(ctx.GetPlace()); - auto x_dims = x->dims(); - auto wx_dims = wx->dims(); - const int D3 = wx_dims[1]; - const int D = D3 / 3; - const int D2 = D * 2; auto& dev_ctx = ctx.template device_context(); auto blas = math::GetBlas(dev_ctx); math::LoDTensor2BatchFunctor to_batch; - if (x_dims[1] > wx_dims[1]) { - math::FCCompute(blas, x_dims[0], wx_dims[1], x_dims[1], - x_data, wx_data, xx_data, - bias ? bias->data() : NULL); + if (M > D3) { + math::FCCompute(blas, total_T, D3, M, x_data, wx_data, + xx_data, + bias ? bias->data() : nullptr); to_batch(dev_ctx, *xx, batched_input, true, is_reverse); } else { to_batch(dev_ctx, *x, xx, true, is_reverse); batched_input->set_lod(xx->lod()); - math::FCCompute(blas, x_dims[0], wx_dims[1], x_dims[1], - xx_data, wx_data, batched_input_data, - bias ? bias->data() : NULL); + math::FCCompute(blas, total_T, D3, M, xx_data, wx_data, + batched_input_data, + bias ? bias->data() : nullptr); } auto batched_lod = batched_input->lod(); @@ -331,7 +334,7 @@ class FusionGRUKernel : public framework::OpKernel { reordered_h0->Resize({max_bs, D}); int tstart = 0; - T* prev_hidden_data = NULL; + T* prev_hidden_data = nullptr; if (h0) { // reorder h0 T* reordered_h0_data = reordered_h0->mutable_data(ctx.GetPlace()); @@ -415,6 +418,8 @@ class FusionGRUKernel : public framework::OpKernel { to_seq(dev_ctx, *batched_out, hidden_out); } #undef INIT_VEC_FUNC +#undef INIT_BASE_SIZES +#undef INIT_BASE_INPUT_OUTPUT }; } // namespace operators