From a20b2b43fc68943a39735c1d051bc07f463415df Mon Sep 17 00:00:00 2001 From: Hongyu Liu <43953930+phlrain@users.noreply.github.com> Date: Thu, 11 Jul 2019 16:10:20 +0800 Subject: [PATCH] fix cudnn lstm shape bug; test=develop (#18492) --- paddle/fluid/operators/cudnn_lstm_op.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc index 134f84d59ca..73e04da3b0d 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cc @@ -45,7 +45,11 @@ class CudnnLSTMOp : public framework::OperatorWithKernel { auto in_dims = ctx->GetInputDim("Input"); PADDLE_ENFORCE_EQ(in_dims.size(), 3, "Input(X)'s rank must be 3."); - ctx->SetOutputDim("Out", ctx->GetInputDim("Input")); + auto out_dims = in_dims; + auto hidden_size = ctx->Attrs().Get("hidden_size"); + out_dims[2] = hidden_size; + + ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("last_h", ctx->GetInputDim("InitH")); ctx->SetOutputDim("last_c", ctx->GetInputDim("InitC")); } -- GitLab