未验证 提交 5a2d6d6b 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #16956 from guoshengCS/cherry-pick-infer-shape

cherry-pick #16898 and #16902 to release/1.4
...@@ -71,8 +71,16 @@ class BinaryLogicalOpInferShape : public framework::InferShapeBase { ...@@ -71,8 +71,16 @@ class BinaryLogicalOpInferShape : public framework::InferShapeBase {
"Input(Y) of %s operator must not be null", comment.type); "Input(Y) of %s operator must not be null", comment.type);
auto dim_x = context->GetInputDim("X"); auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y"); auto dim_y = context->GetInputDim("Y");
PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y),
"The number of elements in X and Y should be same"); int product_x = framework::product(dim_x);
int product_y = framework::product(dim_y);
bool check = context->IsRuntime() || (product_x >= 0 && product_y >= 0);
if (check) {
PADDLE_ENFORCE_EQ(
product_x, product_y,
"The number of elements in X and Y should be same, %d != %d",
product_x, product_y);
}
context->SetOutputDim("Out", context->GetInputDim("X")); context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out"); context->ShareLoD("X", "Out");
......
...@@ -47,8 +47,11 @@ class GRUOp : public framework::OperatorWithKernel { ...@@ -47,8 +47,11 @@ class GRUOp : public framework::OperatorWithKernel {
auto weight_dims = ctx->GetInputDim("Weight"); auto weight_dims = ctx->GetInputDim("Weight");
int input_size = input_dims[1]; int input_size = input_dims[1];
int frame_size = weight_dims[0]; int frame_size = weight_dims[0];
PADDLE_ENFORCE_EQ(input_size, frame_size * 3, if (ctx->IsRuntime()) {
"The input_size must be 3 times of frame_size in GRUOp."); PADDLE_ENFORCE_EQ(
input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUOp.");
}
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_dims[1], frame_size * 3, weight_dims[1], frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3]."); "The shape of Weight matrix must be [frame_size, frame_size * 3].");
......
...@@ -34,10 +34,12 @@ class LstmUnitOp : public framework::OperatorWithKernel { ...@@ -34,10 +34,12 @@ class LstmUnitOp : public framework::OperatorWithKernel {
auto c_prev_dims = ctx->GetInputDim("C_prev"); auto c_prev_dims = ctx->GetInputDim("C_prev");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0], if (ctx->IsRuntime()) {
"Batch size of inputs and states must be equal"); PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0],
PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4, "Batch size of inputs and states must be equal");
"Dimension of FC should equal to prev state * 4"); PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4,
"Dimension of FC should equal to prev state * 4");
}
int b_size = c_prev_dims[0]; // batch size int b_size = c_prev_dims[0]; // batch size
int s_dim = c_prev_dims[1]; // state dim int s_dim = c_prev_dims[1]; // state dim
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册