From c3195de522202fe211e6be1a871c8091a3caecce Mon Sep 17 00:00:00 2001 From: Hongyu Liu <43953930+phlrain@users.noreply.github.com> Date: Wed, 8 May 2019 16:20:32 +0800 Subject: [PATCH] Fix concat shape check (#17247) * fix shape_check; test=develop * fix format; test=develop * fix format; test=develop * fix ddim bug; test=develop * fix c++ format; test=develop * change function name; test=develop --- paddle/fluid/framework/ddim.cc | 10 ++++++++++ paddle/fluid/framework/ddim.h | 2 ++ paddle/fluid/operators/concat_op.cc | 11 +++-------- paddle/fluid/operators/cross_entropy_op.cc | 8 +++----- 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/ddim.cc b/paddle/fluid/framework/ddim.cc index e7a6df57e..bbc9982d9 100644 --- a/paddle/fluid/framework/ddim.cc +++ b/paddle/fluid/framework/ddim.cc @@ -121,6 +121,16 @@ int64_t product(const DDim& ddim) { return ddim.apply_visitor(ProductVisitor()); } +bool contain_unknown_dim(const DDim& ddim) { + for (int i = 0; i < ddim.size(); ++i) { + if (ddim[i] < 0) { + return true; + } + } + + return false; +} + DDim slice_ddim(const DDim& dim, int begin, int end) { PADDLE_ENFORCE(begin >= 0 && end <= dim.size(), "[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.", diff --git a/paddle/fluid/framework/ddim.h b/paddle/fluid/framework/ddim.h index 31a41dab2..7d2e296b6 100644 --- a/paddle/fluid/framework/ddim.h +++ b/paddle/fluid/framework/ddim.h @@ -182,6 +182,8 @@ std::vector vectorize2int(const DDim& ddim); int64_t product(const DDim& ddim); +bool contain_unknown_dim(const DDim& ddim); + /** * \brief Slice a ddim * diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index b1a6d66b8..029b05bb6 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -59,18 +59,13 @@ class ConcatOp : public framework::OperatorWithKernel { } } } else { - if (ctx->IsRuntime()) { + bool check_shape = + ctx->IsRuntime() || (out_dims[j] > 0 && ins[i][j] > 0); + if (check_shape) { // check all shape in run time PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j], "Input tensors should have the same " "elements except the specify axis."); - } else { - // not check -1 with other in compile time - if (out_dims[j] > 0 && ins[i][j] > 0) { - PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j], - "Input tensors should have the same " - "elements except the specify axis."); - } } } } diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index ad32de53e..da2c74b0c 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -35,11 +35,9 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel { int rank = x_dims.size(); PADDLE_ENFORCE_EQ(rank, label_dims.size(), "Input(X) and Input(Label) shall have the same rank."); - bool check = true; - if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || - framework::product(label_dims) <= 0)) { - check = false; - } + bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) || + framework::contain_unknown_dim(label_dims); + bool check = ctx->IsRuntime() || !contain_unknown_dim; if (check) { PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), framework::slice_ddim(label_dims, 0, rank - 1), -- GitLab