提交 fa4805b4 编写于 作者: H HydrogenSulfate 提交者: chajchaj

Update loss.py

上级 39e81532
......@@ -1812,10 +1812,11 @@ def cross_entropy(input,
valid_label = paddle.where(
label == ignore_index,
paddle.zeros(
[1], dtype=label.dtype),
label)
[1], dtype=label.dtype),
label)
if (paddle.numel(paddle.nonzero(valid_label < 0)) > 0) or (
paddle.numel(paddle.nonzero(valid_label >= input.shape[-1])) > 0):
paddle.numel(
paddle.nonzero(valid_label >= input.shape[-1])) > 0):
invalid_label = paddle.gather_nd(
input, paddle.nonzero(valid_label < 0))
if paddle.numel(invalid_label) > 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册