未验证 提交 ad2a2bb0 编写于 作者: H Hongyu Liu 提交者: GitHub

Merge pull request #16913 from phlrain/fix_bpr_loss

Fix bpr loss
...@@ -32,10 +32,14 @@ class BprLossOp : public framework::OperatorWithKernel { ...@@ -32,10 +32,14 @@ class BprLossOp : public framework::OperatorWithKernel {
int rank = x_dims.size(); int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, label_dims.size(), PADDLE_ENFORCE_EQ(rank, label_dims.size(),
"Input(X) and Input(Label) shall have the same rank."); "Input(X) and Input(Label) shall have the same rank.");
if (ctx->IsRuntime() || (framework::product(x_dims) > 0 &&
framework::product(label_dims) > 0)) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_dims, 0, rank - 1), framework::slice_ddim(label_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape " "Input(X) and Input(Label) shall have the same shape "
"except the last dimension."); "except the last dimension.");
}
auto y_dims = x_dims; auto y_dims = x_dims;
y_dims[rank - 1] = 1; y_dims[rank - 1] = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册