提交 05eaa9bc 编写于 作者: H HydrogenSulfate 提交者: chajchaj

Update loss.py

上级 a2327374
...@@ -1811,8 +1811,7 @@ def cross_entropy(input, ...@@ -1811,8 +1811,7 @@ def cross_entropy(input,
.format(input.shape[-1], weight.shape[-1])) .format(input.shape[-1], weight.shape[-1]))
valid_label = paddle.where( valid_label = paddle.where(
label == ignore_index, label == ignore_index,
paddle.to_tensor( paddle.zeros([1], dtype=label.dtype), dtype=label.dtype),
0, dtype=label.dtype),
label) label)
if (len(paddle.nonzero(valid_label < 0)) > 0) or ( if (len(paddle.nonzero(valid_label < 0)) > 0) or (
len(paddle.nonzero(valid_label >= input.shape[-1])) > 0): len(paddle.nonzero(valid_label >= input.shape[-1])) > 0):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册