diff --git a/paddle/fluid/operators/bpr_loss_op.cc b/paddle/fluid/operators/bpr_loss_op.cc index b2dbaecfcfd67cc679d02e22d4e89cfedeeba80c..51c4d878142dcd93a170c9ea4211b9c6ec8e4422 100644 --- a/paddle/fluid/operators/bpr_loss_op.cc +++ b/paddle/fluid/operators/bpr_loss_op.cc @@ -32,10 +32,14 @@ class BprLossOp : public framework::OperatorWithKernel { int rank = x_dims.size(); PADDLE_ENFORCE_EQ(rank, label_dims.size(), "Input(X) and Input(Label) shall have the same rank."); - PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), - framework::slice_ddim(label_dims, 0, rank - 1), - "Input(X) and Input(Label) shall have the same shape " - "except the last dimension."); + + if (ctx->IsRuntime() || (framework::product(x_dims) > 0 && + framework::product(label_dims) > 0)) { + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(label_dims, 0, rank - 1), + "Input(X) and Input(Label) shall have the same shape " + "except the last dimension."); + } auto y_dims = x_dims; y_dims[rank - 1] = 1;