提交 6cd41cec 编写于 作者: H HydrogenSulfate 提交者: chajchaj

Update loss.py

上级 3675f25d
......@@ -1712,11 +1712,12 @@ def cross_entropy(input,
out.dtype)
if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
axis] == 1:
ignore_weight_mask.squeeze_(axis)
# TODO: Temporarily use squeeze instead of squeeze_
ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis)
if axis != -1 and axis != valid_label.ndim - 1:
temp_perm = list(range(axis % valid_label.ndim)) \
+ list(range((axis % valid_label.ndim + 1) , valid_label.ndim)) \
+ [axis%valid_label.ndim]
+ [axis % valid_label.ndim]
weight_gather = _C_ops.gather_nd(
weight, valid_label.transpose(temp_perm))
else:
......@@ -1828,9 +1829,9 @@ def cross_entropy(input,
if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
axis] == 1:
ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis)
if axis != -1:
if axis != -1 and axis != valid_label.ndim - 1:
temp_perm = list(range(axis % valid_label.ndim)) \
+ list(range((axis + 1) % valid_label.ndim, valid_label.ndim)) \
+ list(range((axis % valid_label.ndim + 1), valid_label.ndim)) \
+ [axis % valid_label.ndim]
weight_gather = paddle.gather_nd(
weight, paddle.transpose(valid_label, temp_perm))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册