未验证 提交 8bfd45ad 编写于 作者: Z Zhang Zheng 提交者: GitHub

[Cherry-Pick]Move valid check from python to kernel (#46980)

为了提升性能,将label的边界检查从python端转移到kernel内,减少额外op的调用,如min、max和同步拷贝等
    当前的模板参数IgnoreIndex仅在ignore_index取值范围在[0, dim)时才生效,但是当某个label值超出了边界,ignore_index等于该label,这种情况下是应该仍然能正常计算。虽然当前的计算逻辑在结果上不会出错,但逻辑上仍是有问题的,且模板参数IgnoreIndex是没有必要的
上级 5c2bea17
...@@ -2386,14 +2386,6 @@ def cross_entropy(input, ...@@ -2386,14 +2386,6 @@ def cross_entropy(input,
if soft_label == False: if soft_label == False:
valid_label = paddle.cast(label != ignore_index, valid_label = paddle.cast(label != ignore_index,
dtype=label.dtype) * label dtype=label.dtype) * label
label_min = paddle.min(valid_label)
label_max = paddle.max(valid_label)
if label_min < 0:
raise ValueError("Target {} is out of lower bound.".format(
label_min.item()))
if label_max >= input.shape[axis]:
raise ValueError("Target {} is out of upper bound.".format(
label_max.item()))
if core.is_compiled_with_npu() or core.is_compiled_with_mlu(): if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
if soft_label == False: if soft_label == False:
_, _, out = _legacy_C_ops.softmax_with_cross_entropy( _, _, out = _legacy_C_ops.softmax_with_cross_entropy(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册