diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index 7bed1570507f97c8f84da5b6138ff49a2912dc91..c404a6c44ccea8287ddfad976889a9f80cf6bad9 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h" +#include #include "paddle/fluid/framework/lod_tensor.h" namespace paddle { @@ -97,6 +98,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"}); op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse")); op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes")); + // TODO(TJ): get from attr + op_desc.SetAttr("use_seq", true); #define TMP_NAME(x) "at.new.tmp." #x #define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)}) @@ -134,7 +137,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, auto fc_no_bias_handler = [&]( const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - #define GET_NODE(name__) \ std::string name__##key = name_scope + "/" + #name__; \ auto* name__##n = pattern->RetrieveNode(name__##key); \ diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index d6439acf272f0be225beb54c2b85e335e79fc4e7..f91236975d0cf0c89a464188bd6ea1b5b01e0f6d 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -16,14 +16,10 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" -#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/fc_compute.h" -#include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/platform/cpu_info.h" -DEFINE_bool(seq_mode, true, "Use sequence mode"); - namespace paddle { namespace operators { @@ -110,7 +106,7 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ctx->ShareLoD("X", "Cell"); int xx_width; - if (FLAGS_seq_mode) { + if (ctx->Attrs().Get("use_seq")) { xx_width = wx_dims[1]; } else { xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; @@ -189,6 +185,10 @@ void FusionLSTMOpMaker::Make() { "(bool, defalut: False) " "whether to compute reversed LSTM.") .SetDefault(false); + AddAttr("use_seq", + "(bool, defalut: True) " + "whether to use seq mode to compute.") + .SetDefault(true); AddAttr("gate_activation", "(string, default: sigmoid)" "The activation for input gate, forget gate and output " @@ -264,8 +264,8 @@ class FuisonLSTMKernel : public framework::OpKernel { const int N = x_lod[0].size() - 1; // batch size const T* x_data = x->data(); - const T* h0_data = h0 ? h0->data() : NULL; - const T* c0_data = c0 ? c0->data() : NULL; + const T* h0_data = h0 ? h0->data() : nullptr; + const T* c0_data = c0 ? c0->data() : nullptr; const T* wx_data = wx->data(); const T* wh_data = wh->data(); T* xx_data = xx->mutable_data(ctx.GetPlace()); @@ -295,8 +295,8 @@ class FuisonLSTMKernel : public framework::OpKernel { for (int i = 0; i < N; ++i) { int bid = is_reverse ? N - 1 - i : i; int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; - const T* prev_c_data = NULL; - const T* prev_h_data = NULL; + const T* prev_c_data = nullptr; + const T* prev_h_data = nullptr; int tstart = 0; if (h0_data) { prev_h_data = h0_data + bid * D; @@ -351,8 +351,9 @@ class FuisonLSTMKernel : public framework::OpKernel { void BatchCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = platform::CPUDeviceContext; INIT_BASE_INPUT_OUTPUT - if (x->lod()[0].size() == 2) { // batch size == 1 + if (x->lod()[0].size() == 2) { SeqCompute(ctx); + return; } INIT_BASE_SIZES INIT_VEC_FUNC @@ -396,8 +397,8 @@ class FuisonLSTMKernel : public framework::OpKernel { reordered_c0->Resize({max_bs, D}); int tstart = 0; - T* prev_h_data = NULL; - T* prev_c_data = NULL; + T* prev_h_data = nullptr; + T* prev_c_data = nullptr; if (h0) { // reorder h0, c0 T* reordered_h0_data = reordered_h0->mutable_data(place); @@ -489,7 +490,7 @@ class FuisonLSTMKernel : public framework::OpKernel { } void Compute(const framework::ExecutionContext& ctx) const override { - if (FLAGS_seq_mode) { + if (ctx.Attr("use_seq")) { SeqCompute(ctx); } else { BatchCompute(ctx);