diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index f91236975d0cf0c89a464188bd6ea1b5b01e0f6d..4187e313865a3d434d7afe6adce329d684720b55 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -38,16 +38,6 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { "Output(Hidden) of LSTM should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Cell"), "Output(Cell) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), - "Output(BatchedInput) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"), - "Output(BatchedHidden) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"), - "Output(BatchedCell) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), - "Output(ReorderedH0) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"), - "Output(ReorderedC0) of LSTM should not be null."); auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); @@ -99,17 +89,26 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { framework::DDim out_dims({x_dims[0], frame_size}); ctx->SetOutputDim("Hidden", out_dims); ctx->SetOutputDim("Cell", out_dims); - ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); - ctx->SetOutputDim("BatchedHidden", out_dims); - ctx->SetOutputDim("BatchedCell", out_dims); ctx->ShareLoD("X", "Hidden"); ctx->ShareLoD("X", "Cell"); - int xx_width; if (ctx->Attrs().Get("use_seq")) { xx_width = wx_dims[1]; } else { xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; + PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), + "Output(BatchedInput) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"), + "Output(BatchedHidden) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"), + "Output(BatchedCell) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), + "Output(ReorderedH0) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"), + "Output(ReorderedC0) of LSTM should not be null."); + ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); + ctx->SetOutputDim("BatchedHidden", out_dims); + ctx->SetOutputDim("BatchedCell", out_dims); } ctx->SetOutputDim("XX", {x_dims[0], xx_width}); ctx->ShareLoD("X", "XX");