diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc index 134f84d59cafa661fce727adc3303444c4ef483e..73e04da3b0db275ed4d49878e8c0a8879b3106dd 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cc @@ -45,7 +45,11 @@ class CudnnLSTMOp : public framework::OperatorWithKernel { auto in_dims = ctx->GetInputDim("Input"); PADDLE_ENFORCE_EQ(in_dims.size(), 3, "Input(X)'s rank must be 3."); - ctx->SetOutputDim("Out", ctx->GetInputDim("Input")); + auto out_dims = in_dims; + auto hidden_size = ctx->Attrs().Get("hidden_size"); + out_dims[2] = hidden_size; + + ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("last_h", ctx->GetInputDim("InitH")); ctx->SetOutputDim("last_c", ctx->GetInputDim("InitC")); }