提交 11e9d4e3 编写于 作者: H HydrogenSulfate 提交者: chajchaj

Update loss.py

上级 0c2d6bcb
...@@ -1806,8 +1806,8 @@ def cross_entropy(input, ...@@ -1806,8 +1806,8 @@ def cross_entropy(input,
valid_label = paddle.where(label == ignore_index, valid_label = paddle.where(label == ignore_index,
paddle.zeros_like(label), label) paddle.zeros_like(label), label)
ignore_weight_mask = paddle.cast( ignore_weight_mask = paddle.cast((label != ignore_index),
(label != ignore_index), input.dtype) input.dtype)
if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
-1] == 1: -1] == 1:
ignore_weight_mask = paddle.squeeze(ignore_weight_mask, -1) ignore_weight_mask = paddle.squeeze(ignore_weight_mask, -1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册