未验证 提交 a20b2b43 编写于 作者: H Hongyu Liu 提交者: GitHub

fix cudnn lstm shape bug; test=develop (#18492)

上级 c0a82748
...@@ -45,7 +45,11 @@ class CudnnLSTMOp : public framework::OperatorWithKernel { ...@@ -45,7 +45,11 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
auto in_dims = ctx->GetInputDim("Input"); auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 3, "Input(X)'s rank must be 3."); PADDLE_ENFORCE_EQ(in_dims.size(), 3, "Input(X)'s rank must be 3.");
ctx->SetOutputDim("Out", ctx->GetInputDim("Input")); auto out_dims = in_dims;
auto hidden_size = ctx->Attrs().Get<int>("hidden_size");
out_dims[2] = hidden_size;
ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputDim("last_h", ctx->GetInputDim("InitH")); ctx->SetOutputDim("last_h", ctx->GetInputDim("InitH"));
ctx->SetOutputDim("last_c", ctx->GetInputDim("InitC")); ctx->SetOutputDim("last_c", ctx->GetInputDim("InitC"));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册