diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index ca24261bcc84e2d476891ef5ab7b89a981437b36..2ea15c85f338165df06763afc9a886228de8722e 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -145,11 +145,10 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel { "But received: Y@Grad's rank is [%d], Y's rank is [%d]", dy_dims.size(), label_dims.size())); - bool check = true; - if ((!ctx->IsRuntime()) && - (framework::product(x_dims) <= 0 || framework::product(dy_dims) <= 0)) { - check = false; - } + bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) || + framework::contain_unknown_dim(dy_dims); + + bool check = ctx->IsRuntime() || !contain_unknown_dim; if (check) { PADDLE_ENFORCE_EQ(