提交 ef7e5fc7 编写于 作者: R root 提交者: chajchaj

imporve efficiency

上级 6cd96c19
...@@ -1411,11 +1411,13 @@ def cross_entropy(input, ...@@ -1411,11 +1411,13 @@ def cross_entropy(input,
out = core.ops.elementwise_mul(out, weight_gather_reshape) out = core.ops.elementwise_mul(out, weight_gather_reshape)
else: else:
for label_val in label.flatten(): label_min = paddle.min(label)
if label_val < 0 or label_val >= input.shape[-1]: label_max = paddle.max(label)
if label_min < 0 or label_max >= input.shape[-1]:
raise ValueError( raise ValueError(
'Expected 0 <= label_value < class_dimension({}), but got label_value {}'. 'Expected 0 <= label_value < class_dimension({}), but got {} <= label_value <= {} '.
format(input.shape[-1], label_val.numpy())) format(input.shape[-1],
label_min.numpy(), label_max.numpy()))
weight_gather = core.ops.gather_nd(weight, label) weight_gather = core.ops.gather_nd(weight, label)
input_shape = list(label.shape) input_shape = list(label.shape)
weight_gather_reshape = reshape( weight_gather_reshape = reshape(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册