未验证 提交 c3195de5 编写于 作者: H Hongyu Liu 提交者: GitHub

Fix concat shape check (#17247)

* fix shape_check; test=develop

* fix format; test=develop

* fix format; test=develop

* fix ddim bug; test=develop

* fix c++ format; test=develop

* change function name; test=develop
上级 6d1d7c8a
......@@ -121,6 +121,16 @@ int64_t product(const DDim& ddim) {
return ddim.apply_visitor(ProductVisitor());
}
bool contain_unknown_dim(const DDim& ddim) {
for (int i = 0; i < ddim.size(); ++i) {
if (ddim[i] < 0) {
return true;
}
}
return false;
}
DDim slice_ddim(const DDim& dim, int begin, int end) {
PADDLE_ENFORCE(begin >= 0 && end <= dim.size(),
"[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.",
......
......@@ -182,6 +182,8 @@ std::vector<int> vectorize2int(const DDim& ddim);
int64_t product(const DDim& ddim);
bool contain_unknown_dim(const DDim& ddim);
/**
* \brief Slice a ddim
*
......
......@@ -59,18 +59,13 @@ class ConcatOp : public framework::OperatorWithKernel {
}
}
} else {
if (ctx->IsRuntime()) {
bool check_shape =
ctx->IsRuntime() || (out_dims[j] > 0 && ins[i][j] > 0);
if (check_shape) {
// check all shape in run time
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same "
"elements except the specify axis.");
} else {
// not check -1 with other in compile time
if (out_dims[j] > 0 && ins[i][j] > 0) {
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same "
"elements except the specify axis.");
}
}
}
}
......
......@@ -35,11 +35,9 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, label_dims.size(),
"Input(X) and Input(Label) shall have the same rank.");
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
framework::product(label_dims) <= 0)) {
check = false;
}
bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
framework::contain_unknown_dim(label_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_dims, 0, rank - 1),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册