From f17b2de8edac2c1b9d5dd42bed49d1249ae18db1 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Wed, 4 Jan 2023 19:38:59 +0800 Subject: [PATCH] Add the input check for softmax_with_cross_entropy (#49333) --- paddle/phi/infermeta/binary.cc | 7 +++++++ python/paddle/nn/functional/loss.py | 7 ------- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 532aed7f66d..74eecad101c 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -859,6 +859,7 @@ void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits, auto logits_dims = logits.dims(); auto labels_dims = label.dims(); auto logits_rank = logits_dims.size(); + auto labels_rank = labels_dims.size(); PADDLE_ENFORCE_GE(axis, -logits_rank, phi::errors::InvalidArgument( @@ -891,6 +892,12 @@ void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits, "when not in numeric_stable_mode.")); } + PADDLE_ENFORCE_EQ( + (logits_rank - 1 != labels_rank) && (logits_rank != labels_rank), + false, + phi::errors::InvalidArgument("Expected input_dims - 1 == label_dims " + "or input_dims == label_dims.")); + if (soft_label) { if (config.is_runtime || (logits_dims[axis] > 0 && labels_dims[axis] > 0)) { PADDLE_ENFORCE_EQ(logits_dims[axis], diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index a255495c8de..74f42804341 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2540,13 +2540,6 @@ def cross_entropy( raise ValueError('The dimention of input should be larger than zero!') label_dims = len(list(label.shape)) - if input_dims - 1 != label_dims and input_dims != label_dims: - raise ValueError( - 'Expected nput_dims - 1 = label_dims or input_dims == label_dims\ - (got nput_dims{}, label_dims{})'.format( - input_dims, label_dims - ) - ) if input_dims - 1 == label_dims: label = paddle.unsqueeze(label, axis=axis) -- GitLab