提交 53dc0143 编写于 作者: H HydrogenSulfate 提交者: chajchaj

Update loss.py

上级 8c2fbc31
......@@ -1712,9 +1712,9 @@ def cross_entropy(input,
if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
axis] == 1:
ignore_weight_mask.squeeze_(axis)
if axis != -1:
if axis != -1 and axis != valid_label.ndim - 1:
temp_perm = list(range(axis % valid_label.ndim)) \
+ list(range((axis + 1) % valid_label.ndim, valid_label.ndim)) \
+ list(range((axis % valid_label.ndim + 1) , valid_label.ndim)) \
+ [axis%valid_label.ndim]
weight_gather = _C_ops.gather_nd(
weight, valid_label.transpose(temp_perm))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册