From 01eda557cd6dbc6b4c8bc53d26b0d8f0f3a893ee Mon Sep 17 00:00:00 2001 From: phlrain Date: Tue, 16 Apr 2019 09:02:12 +0000 Subject: [PATCH] fix bpr loss; test=developp --- paddle/fluid/operators/bpr_loss_op.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/bpr_loss_op.cc b/paddle/fluid/operators/bpr_loss_op.cc index b2dbaecfcfd..51c4d878142 100644 --- a/paddle/fluid/operators/bpr_loss_op.cc +++ b/paddle/fluid/operators/bpr_loss_op.cc @@ -32,10 +32,14 @@ class BprLossOp : public framework::OperatorWithKernel { int rank = x_dims.size(); PADDLE_ENFORCE_EQ(rank, label_dims.size(), "Input(X) and Input(Label) shall have the same rank."); - PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), - framework::slice_ddim(label_dims, 0, rank - 1), - "Input(X) and Input(Label) shall have the same shape " - "except the last dimension."); + + if (ctx->IsRuntime() || (framework::product(x_dims) > 0 && + framework::product(label_dims) > 0)) { + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(label_dims, 0, rank - 1), + "Input(X) and Input(Label) shall have the same shape " + "except the last dimension."); + } auto y_dims = x_dims; y_dims[rank - 1] = 1; -- GitLab