提交 1e3e17df 编写于 作者: H HydrogenSulfate 提交者: chajchaj

Remove the labels range check under the dynamic graph

上级 87d9fdae
...@@ -1696,6 +1696,13 @@ def cross_entropy(input, ...@@ -1696,6 +1696,13 @@ def cross_entropy(input,
out = _C_ops.elementwise_mul(out, weight_gather_reshape) out = _C_ops.elementwise_mul(out, weight_gather_reshape)
else: else:
if input.shape[axis] != weight.shape[-1]:
raise ValueError(
"input's class_dimension({}) must equal to "
"weight's class_dimension({}) "
"when weight is provided" \
.format(input.shape[axis], weight.shape[-1]))
valid_label = paddle.where(label == ignore_index, valid_label = paddle.where(label == ignore_index,
paddle.zeros_like(label), label) paddle.zeros_like(label), label)
# TODO: Temporarily use paddle.nonzero instead of paddle.max # TODO: Temporarily use paddle.nonzero instead of paddle.max
...@@ -1715,12 +1722,6 @@ def cross_entropy(input, ...@@ -1715,12 +1722,6 @@ def cross_entropy(input,
raise ValueError( raise ValueError(
"Target({}) is out of class_dimension's upper bound({})". "Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[axis] - 1)) format(invalid_label[0], input.shape[axis] - 1))
if input.shape[axis] != weight.shape[-1]:
raise ValueError(
"input's class_dimension({}) must equal to "
"weight's class_dimension({}) "
"when weight is provided" \
.format(input.shape[axis], weight.shape[-1]))
ignore_weight_mask = paddle.cast((label != ignore_index), ignore_weight_mask = paddle.cast((label != ignore_index),
out.dtype) out.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册