提交 ef7e5fc7 编写于 作者: R root 提交者: chajchaj

imporve efficiency

上级 6cd96c19
......@@ -1411,11 +1411,13 @@ def cross_entropy(input,
out = core.ops.elementwise_mul(out, weight_gather_reshape)
else:
for label_val in label.flatten():
if label_val < 0 or label_val >= input.shape[-1]:
raise ValueError(
'Expected 0 <= label_value < class_dimension({}), but got label_value {}'.
format(input.shape[-1], label_val.numpy()))
label_min = paddle.min(label)
label_max = paddle.max(label)
if label_min < 0 or label_max >= input.shape[-1]:
raise ValueError(
'Expected 0 <= label_value < class_dimension({}), but got {} <= label_value <= {} '.
format(input.shape[-1],
label_min.numpy(), label_max.numpy()))
weight_gather = core.ops.gather_nd(weight, label)
input_shape = list(label.shape)
weight_gather_reshape = reshape(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册