提交 9b30c51e 编写于 作者: H Hongyu Liu 提交者: phlrain

Merge pull request #16861 from tensor-tang/refine/infershape

separate runtime infershape
上级 0cc984bd
...@@ -64,13 +64,20 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -64,13 +64,20 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
auto c_dims = ctx->GetInputDim("C0"); auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2."); PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D); PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
}
if (ctx->HasInput("H0")) { if (ctx->HasInput("H0")) {
auto h_dims = ctx->GetInputDim("H0"); auto h_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, "Input(H0)'s rank must be 2.");
if (ctx->IsRuntime() ||
(framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) {
PADDLE_ENFORCE(h_dims == c_dims, PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) " "The dimension of Input(H0) and Input(C0) "
"should be the same."); "should be the same.");
} }
}
auto atten_w_dims = ctx->GetInputDim("AttentionWeight"); auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2, PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
...@@ -79,6 +86,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -79,6 +86,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"AttentionWeight shapes must be (%d + %d) * 1.", M, D); "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
PADDLE_ENFORCE_EQ(atten_w_dims[1], 1, PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D); "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
if (ctx->HasInput("AttentionBias")) { if (ctx->HasInput("AttentionBias")) {
auto atten_b_dims = ctx->GetInputDim("AttentionBias"); auto atten_b_dims = ctx->GetInputDim("AttentionBias");
PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2, PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册