提交 87d9fdae 编写于 作者: H HydrogenSulfate 提交者: chajchaj

Remove the labels range check under the dynamic graph

上级 46e856c7
...@@ -1681,7 +1681,7 @@ def cross_entropy(input, ...@@ -1681,7 +1681,7 @@ def cross_entropy(input,
# trans weight from class to sample, shape:N or [N,H,W] for 1d and 2d cases. # trans weight from class to sample, shape:N or [N,H,W] for 1d and 2d cases.
if soft_label == True: if soft_label == True:
# chajchaj: # chajchaj:
# weight's shape is C, where C is class num. # weight's shape is C, where C is class num.
# for 1d case: label's shape is [N,C], weight_gather's shape is N. # for 1d case: label's shape is [N,C], weight_gather's shape is N.
# for 2d case: label's shape is [N,H,W,C], weight_gather's shape is [N,H,W]. # for 2d case: label's shape is [N,H,W,C], weight_gather's shape is [N,H,W].
weight_gather = paddle.matmul( weight_gather = paddle.matmul(
...@@ -1697,7 +1697,7 @@ def cross_entropy(input, ...@@ -1697,7 +1697,7 @@ def cross_entropy(input,
else: else:
valid_label = paddle.where(label == ignore_index, valid_label = paddle.where(label == ignore_index,
paddle.zeros_like(label), label) paddle.zeros_like(label), label)
# TODO: Temporarily use paddle.nonzero instead of paddle.max # TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values # to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label < 0)) > 0: if len(paddle.nonzero(valid_label < 0)) > 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册