未验证 提交 c3195de5 编写于 作者: H Hongyu Liu 提交者: GitHub

Fix concat shape check (#17247)

* fix shape_check; test=develop

* fix format; test=develop

* fix format; test=develop

* fix ddim bug; test=develop

* fix c++ format; test=develop

* change function name; test=develop
上级 6d1d7c8a
...@@ -121,6 +121,16 @@ int64_t product(const DDim& ddim) { ...@@ -121,6 +121,16 @@ int64_t product(const DDim& ddim) {
return ddim.apply_visitor(ProductVisitor()); return ddim.apply_visitor(ProductVisitor());
} }
bool contain_unknown_dim(const DDim& ddim) {
for (int i = 0; i < ddim.size(); ++i) {
if (ddim[i] < 0) {
return true;
}
}
return false;
}
DDim slice_ddim(const DDim& dim, int begin, int end) { DDim slice_ddim(const DDim& dim, int begin, int end) {
PADDLE_ENFORCE(begin >= 0 && end <= dim.size(), PADDLE_ENFORCE(begin >= 0 && end <= dim.size(),
"[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.", "[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.",
......
...@@ -182,6 +182,8 @@ std::vector<int> vectorize2int(const DDim& ddim); ...@@ -182,6 +182,8 @@ std::vector<int> vectorize2int(const DDim& ddim);
int64_t product(const DDim& ddim); int64_t product(const DDim& ddim);
bool contain_unknown_dim(const DDim& ddim);
/** /**
* \brief Slice a ddim * \brief Slice a ddim
* *
......
...@@ -59,18 +59,13 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -59,18 +59,13 @@ class ConcatOp : public framework::OperatorWithKernel {
} }
} }
} else { } else {
if (ctx->IsRuntime()) { bool check_shape =
ctx->IsRuntime() || (out_dims[j] > 0 && ins[i][j] > 0);
if (check_shape) {
// check all shape in run time // check all shape in run time
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j], PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same " "Input tensors should have the same "
"elements except the specify axis."); "elements except the specify axis.");
} else {
// not check -1 with other in compile time
if (out_dims[j] > 0 && ins[i][j] > 0) {
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same "
"elements except the specify axis.");
}
} }
} }
} }
......
...@@ -35,11 +35,9 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel { ...@@ -35,11 +35,9 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
int rank = x_dims.size(); int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, label_dims.size(), PADDLE_ENFORCE_EQ(rank, label_dims.size(),
"Input(X) and Input(Label) shall have the same rank."); "Input(X) and Input(Label) shall have the same rank.");
bool check = true; bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || framework::contain_unknown_dim(label_dims);
framework::product(label_dims) <= 0)) { bool check = ctx->IsRuntime() || !contain_unknown_dim;
check = false;
}
if (check) { if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_dims, 0, rank - 1), framework::slice_ddim(label_dims, 0, rank - 1),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册