提交 ed892eba 编写于 作者: T tensor-tang

update

test=develop
上级 411b9ba5
...@@ -70,9 +70,13 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -70,9 +70,13 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
if (ctx->HasInput("H0")) { if (ctx->HasInput("H0")) {
auto h_dims = ctx->GetInputDim("H0"); auto h_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE(h_dims == c_dims, PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, "Input(H0)'s rank must be 2.");
"The dimension of Input(H0) and Input(C0) " if (ctx->IsRuntime() ||
"should be the same."); (framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) {
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
}
} }
auto atten_w_dims = ctx->GetInputDim("AttentionWeight"); auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册