Commit 10879a3c authored by tensor-tang

separate runtime infershape

test=develop
Parent de26df44
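
The change below wraps the dimension-value checks in AttentionLSTMOp::InferShape with ctx->IsRuntime(), so that compile-time shape inference (where some dimensions may still be unknown placeholders) only validates tensor ranks, while checks on concrete dimension values run once the op executes. A minimal sketch of that guard pattern, assuming the usual PaddlePaddle InferShapeContext API; the function name and the modulo check are illustrative, not taken from the commit:

    // Sketch of the IsRuntime guard pattern applied in this commit.
    // Rank checks stay unconditional; value checks run only at runtime,
    // when every dimension is concrete.
    void GuardedInferShapeSketch(framework::InferShapeContext* ctx) {
      auto w_dims = ctx->GetInputDim("LSTMWeight");
      PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
      if (ctx->IsRuntime()) {
        // Safe here: at runtime w_dims holds actual sizes, not placeholders.
        PADDLE_ENFORCE_EQ(w_dims[1] % 4, 0, "LSTMWeight's 2nd dim must be 4 * D.");
      }
    }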
@@ -54,17 +54,25 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
  auto w_dims = ctx->GetInputDim("LSTMWeight");
  const int D = w_dims[1] / 4;
  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
  if (ctx->IsRuntime()) {
    PADDLE_ENFORCE_EQ(w_dims[0], D + M,
                      "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);
  }

  auto b_dims = ctx->GetInputDim("LSTMBias");
  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
  if (ctx->IsRuntime()) {
    PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
    PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.",
                      4 * D);
  }

  auto c_dims = ctx->GetInputDim("C0");
  PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
  if (ctx->IsRuntime()) {
    PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
  }

  if (ctx->HasInput("H0")) {
    auto h_dims = ctx->GetInputDim("H0");
    PADDLE_ENFORCE(h_dims == c_dims,
@@ -75,27 +83,34 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
  auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
  PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
                    "Input(AttentionWeight)'s rank must be 2.");
  if (ctx->IsRuntime()) {
    PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
                      "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
    PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
                      "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
  }

  if (ctx->HasInput("AttentionBias")) {
    auto atten_b_dims = ctx->GetInputDim("AttentionBias");
    PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
                      "Input(AttentionBias)'s rank must be 2.");
    if (ctx->IsRuntime()) {
      PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
                        "AttentionBias shapes must be 1 * 1.");
      PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
                        "AttentionBias shapes must be 1 * 1.");
    }
  }

  if (ctx->HasInput("AttentionScalar")) {
    auto dims = ctx->GetInputDim("AttentionScalar");
    PADDLE_ENFORCE_EQ(dims.size(), 2,
                      "Input(AttentionScalar)'s rank must be 2.");
    if (ctx->IsRuntime()) {
      PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
      PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
    }
  }

  if (ctx->HasInput("AttentionScalarBias")) {
    auto dims = ctx->GetInputDim("AttentionScalarBias");
@@ -104,8 +119,12 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
        "AttentionScalar should not be null when have AttentionScalarBias.");
    PADDLE_ENFORCE_EQ(dims.size(), 2,
                      "Input(AttentionScalarBias)'s rank must be 2.");
    if (ctx->IsRuntime()) {
      PADDLE_ENFORCE_EQ(dims[0], 1,
                        "AttentionScalarBias shapes must be 1 * 1.");
      PADDLE_ENFORCE_EQ(dims[1], 1,
                        "AttentionScalarBias shapes must be 1 * 1.");
    }
  }

  framework::DDim out_dims({x_dims[0], D});
...