未验证 提交 f17b2de8 编写于 作者: G Guanghua Yu 提交者: GitHub

Add the input check for softmax_with_cross_entropy (#49333)

上级 4a8708bb
......@@ -859,6 +859,7 @@ void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits,
auto logits_dims = logits.dims();
auto labels_dims = label.dims();
auto logits_rank = logits_dims.size();
auto labels_rank = labels_dims.size();
PADDLE_ENFORCE_GE(axis,
-logits_rank,
phi::errors::InvalidArgument(
......@@ -891,6 +892,12 @@ void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits,
"when not in numeric_stable_mode."));
}
PADDLE_ENFORCE_EQ(
(logits_rank - 1 != labels_rank) && (logits_rank != labels_rank),
false,
phi::errors::InvalidArgument("Expected input_dims - 1 == label_dims "
"or input_dims == label_dims."));
if (soft_label) {
if (config.is_runtime || (logits_dims[axis] > 0 && labels_dims[axis] > 0)) {
PADDLE_ENFORCE_EQ(logits_dims[axis],
......
......@@ -2540,13 +2540,6 @@ def cross_entropy(
raise ValueError('The dimention of input should be larger than zero!')
label_dims = len(list(label.shape))
if input_dims - 1 != label_dims and input_dims != label_dims:
raise ValueError(
'Expected nput_dims - 1 = label_dims or input_dims == label_dims\
(got nput_dims{}, label_dims{})'.format(
input_dims, label_dims
)
)
if input_dims - 1 == label_dims:
label = paddle.unsqueeze(label, axis=axis)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册