diff --git a/paddle/fluid/framework/ddim.cc b/paddle/fluid/framework/ddim.cc index e7a6df57e538164969bc101ced4b91de8f75ca56..bbc9982d9db4cd5bec872b44d2385afccd77ffd3 100644 --- a/paddle/fluid/framework/ddim.cc +++ b/paddle/fluid/framework/ddim.cc @@ -121,6 +121,16 @@ int64_t product(const DDim& ddim) { return ddim.apply_visitor(ProductVisitor()); } +bool contain_unknown_dim(const DDim& ddim) { + for (int i = 0; i < ddim.size(); ++i) { + if (ddim[i] < 0) { + return true; + } + } + + return false; +} + DDim slice_ddim(const DDim& dim, int begin, int end) { PADDLE_ENFORCE(begin >= 0 && end <= dim.size(), "[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.", diff --git a/paddle/fluid/framework/ddim.h b/paddle/fluid/framework/ddim.h index 31a41dab2a1f1d6bad9fe697c5d367f32e219160..7d2e296b6c1a99180acc105eb73754233cfa15f4 100644 --- a/paddle/fluid/framework/ddim.h +++ b/paddle/fluid/framework/ddim.h @@ -182,6 +182,8 @@ std::vector vectorize2int(const DDim& ddim); int64_t product(const DDim& ddim); +bool contain_unknown_dim(const DDim& ddim); + /** * \brief Slice a ddim * diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index b1a6d66b80efdae3e78d7c3321a6107d2dd607aa..029b05bb662440bcf94521376b56d234a828ddf5 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -59,18 +59,13 @@ class ConcatOp : public framework::OperatorWithKernel { } } } else { - if (ctx->IsRuntime()) { + bool check_shape = + ctx->IsRuntime() || (out_dims[j] > 0 && ins[i][j] > 0); + if (check_shape) { // check all shape in run time PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j], "Input tensors should have the same " "elements except the specify axis."); - } else { - // not check -1 with other in compile time - if (out_dims[j] > 0 && ins[i][j] > 0) { - PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j], - "Input tensors should have the same " - "elements except the specify axis."); - } } } } diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index ad32de53e7019b438b7106ddd031a8f00bd79b5d..da2c74b0c8a8b0fbeee13c4a3d490d7761abb93c 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -35,11 +35,9 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel { int rank = x_dims.size(); PADDLE_ENFORCE_EQ(rank, label_dims.size(), "Input(X) and Input(Label) shall have the same rank."); - bool check = true; - if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || - framework::product(label_dims) <= 0)) { - check = false; - } + bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) || + framework::contain_unknown_dim(label_dims); + bool check = ctx->IsRuntime() || !contain_unknown_dim; if (check) { PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), framework::slice_ddim(label_dims, 0, rank - 1),