提交 f77083bb 编写于 作者: H HydrogenSulfate 提交者: chajchaj

Update loss.py

上级 59841e6f
......@@ -1713,7 +1713,8 @@ def cross_entropy(input,
if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
axis] == 1:
# TODO: Temporarily use squeeze instead of squeeze_
ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis)
ignore_weight_mask = paddle.squeeze(ignore_weight_mask,
axis)
if axis != -1 and axis != valid_label.ndim - 1:
temp_perm = list(range(axis % valid_label.ndim)) \
+ list(range((axis % valid_label.ndim + 1) , valid_label.ndim)) \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册