From a9539cbf7d1e6c94df0284fe3dbc1ee22bbaab61 Mon Sep 17 00:00:00 2001 From: Hongyu Liu <43953930+phlrain@users.noreply.github.com> Date: Wed, 17 Apr 2019 09:59:49 +0800 Subject: [PATCH] Merge pull request #16913 from phlrain/fix_bpr_loss Fix bpr loss --- paddle/fluid/operators/bpr_loss_op.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/bpr_loss_op.cc b/paddle/fluid/operators/bpr_loss_op.cc index b2dbaecfcfd..51c4d878142 100644 --- a/paddle/fluid/operators/bpr_loss_op.cc +++ b/paddle/fluid/operators/bpr_loss_op.cc @@ -32,10 +32,14 @@ class BprLossOp : public framework::OperatorWithKernel { int rank = x_dims.size(); PADDLE_ENFORCE_EQ(rank, label_dims.size(), "Input(X) and Input(Label) shall have the same rank."); - PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), - framework::slice_ddim(label_dims, 0, rank - 1), - "Input(X) and Input(Label) shall have the same shape " - "except the last dimension."); + + if (ctx->IsRuntime() || (framework::product(x_dims) > 0 && + framework::product(label_dims) > 0)) { + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(label_dims, 0, rank - 1), + "Input(X) and Input(Label) shall have the same shape " + "except the last dimension."); + } auto y_dims = x_dims; y_dims[rank - 1] = 1; -- GitLab