提交 b4eec5d5 编写于 作者: H HydrogenSulfate 提交者: chajchaj

replace .where to '=='

上级 9765be09
......@@ -1703,11 +1703,16 @@ def cross_entropy(input,
"when weight is provided" \
.format(input.shape[axis], weight.shape[-1]))
valid_label = paddle.where(label == ignore_index,
paddle.zeros_like(label), label)
ignore_weight_mask = (
label != ignore_index) # ignored position will be False
valid_label = paddle.cast(
ignore_weight_mask,
dtype=label.dtype) * label # ignored position will be 0
ignore_weight_mask = paddle.cast(
ignore_weight_mask, out.dtype) # convert from 0 to 0.0
ignore_weight_mask = paddle.cast((label != ignore_index),
out.dtype)
if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
axis] == 1:
# TODO: Temporarily use squeeze instead of squeeze_
......@@ -1821,10 +1826,16 @@ def cross_entropy(input,
"when weight is provided" \
.format(input.shape[axis], weight.shape[-1]))
valid_label = paddle.where(label == ignore_index,
paddle.zeros_like(label), label)
ignore_weight_mask = paddle.cast((label != ignore_index),
input.dtype)
ignore_weight_mask = (
label != ignore_index) # ignored position will be False
valid_label = paddle.cast(
ignore_weight_mask,
dtype=label.dtype) * label # ignored position will be 0
ignore_weight_mask = paddle.cast(ignore_weight_mask,
out.dtype) # convert from 0 to 0.0
if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
axis] == 1:
ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册