提交 411b9ba5 编写于 作者: T tensor-tang

update

test=develop
上级 10879a3c
...@@ -54,18 +54,13 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -54,18 +54,13 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
auto w_dims = ctx->GetInputDim("LSTMWeight"); auto w_dims = ctx->GetInputDim("LSTMWeight");
const int D = w_dims[1] / 4; const int D = w_dims[1] / 4;
PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2."); PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
if (ctx->IsRuntime()) { PADDLE_ENFORCE_EQ(w_dims[0], D + M,
PADDLE_ENFORCE_EQ(w_dims[0], D + M, "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);
"LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);
}
auto b_dims = ctx->GetInputDim("LSTMBias"); auto b_dims = ctx->GetInputDim("LSTMBias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2."); PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
if (ctx->IsRuntime()) { PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D); PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);
PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.",
4 * D);
}
auto c_dims = ctx->GetInputDim("C0"); auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2."); PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
...@@ -83,33 +78,27 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -83,33 +78,27 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
auto atten_w_dims = ctx->GetInputDim("AttentionWeight"); auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2, PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
"Input(AttentionWeight)'s rank must be 2."); "Input(AttentionWeight)'s rank must be 2.");
if (ctx->IsRuntime()) { PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D, "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
"AttentionWeight shapes must be (%d + %d) * 1.", M, D); PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
PADDLE_ENFORCE_EQ(atten_w_dims[1], 1, "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
}
if (ctx->HasInput("AttentionBias")) { if (ctx->HasInput("AttentionBias")) {
auto atten_b_dims = ctx->GetInputDim("AttentionBias"); auto atten_b_dims = ctx->GetInputDim("AttentionBias");
PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2, PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
"Input(AttentionBias)'s rank must be 2."); "Input(AttentionBias)'s rank must be 2.");
if (ctx->IsRuntime()) { PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
PADDLE_ENFORCE_EQ(atten_b_dims[0], 1, "AttentionBias shapes must be 1 * 1.");
"AttentionBias shapes must be 1 * 1."); PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
PADDLE_ENFORCE_EQ(atten_b_dims[1], 1, "AttentionBias shapes must be 1 * 1.");
"AttentionBias shapes must be 1 * 1.");
}
} }
if (ctx->HasInput("AttentionScalar")) { if (ctx->HasInput("AttentionScalar")) {
auto dims = ctx->GetInputDim("AttentionScalar"); auto dims = ctx->GetInputDim("AttentionScalar");
PADDLE_ENFORCE_EQ(dims.size(), 2, PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalar)'s rank must be 2."); "Input(AttentionScalar)'s rank must be 2.");
if (ctx->IsRuntime()) { PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1."); PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
}
} }
if (ctx->HasInput("AttentionScalarBias")) { if (ctx->HasInput("AttentionScalarBias")) {
...@@ -119,12 +108,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -119,12 +108,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"AttentionScalar should not be null when have AttentionScalarBias."); "AttentionScalar should not be null when have AttentionScalarBias.");
PADDLE_ENFORCE_EQ(dims.size(), 2, PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalarBias)'s rank must be 2."); "Input(AttentionScalarBias)'s rank must be 2.");
if (ctx->IsRuntime()) { PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(dims[0], 1, PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
"AttentionScalarBias shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(dims[1], 1,
"AttentionScalarBias shapes must be 1 * 1.");
}
} }
framework::DDim out_dims({x_dims[0], D}); framework::DDim out_dims({x_dims[0], D});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册