diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index 582c75872ab2818cdf834f9a46278db1d6f91d54..916f84cb4a78c3721cb67bd3cf8d3759a8eaf1bf 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -30,14 +30,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { "Input(WeightX) of GRU should not be null."); PADDLE_ENFORCE(ctx->HasInput("WeightH"), "Input(WeightH) of GRU should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), - "Output(ReorderedH0) of GRU should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), - "Output(BatchedInput) of GRU should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"), - "Output(BatchedOut) of GRU should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Hidden"), "Output(Hidden) of GRU should not be null."); @@ -80,15 +73,20 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { } framework::DDim out_dims({x_dims[0], frame_size}); ctx->SetOutputDim("Hidden", out_dims); - ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); - ctx->SetOutputDim("BatchedOut", out_dims); ctx->ShareLoD("X", "Hidden"); - int xx_width; if (ctx->Attrs().Get("use_seq")) { xx_width = wx_dims[1]; } else { xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; + PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), + "Output(ReorderedH0) of GRU should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), + "Output(BatchedInput) of GRU should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"), + "Output(BatchedOut) of GRU should not be null."); + ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); + ctx->SetOutputDim("BatchedOut", out_dims); } ctx->SetOutputDim("XX", {x_dims[0], xx_width}); ctx->ShareLoD("X", "XX");