提交 e2c293f6 编写于 作者: R root 提交者: chajchaj

fix ci bug

上级 400eb9d8
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
import unittest import unittest
from test_softmax_op import stable_softmax from test_softmax_op import stable_softmax
from test_softmax_with_cross_entropy_op import cross_entropy from test_softmax_with_cross_entropy_op import cross_entropy
from paddle.fluid import Program, program_guard
def stable_softmax(x): def stable_softmax(x):
......
...@@ -1411,7 +1411,7 @@ def cross_entropy(input, ...@@ -1411,7 +1411,7 @@ def cross_entropy(input,
out = core.ops.elementwise_mul(out, weight_gather_reshape) out = core.ops.elementwise_mul(out, weight_gather_reshape)
else: else:
for label_val in label: for label_val in label.flatten():
if label_val < 0 or label_val >= input.shape[-1]: if label_val < 0 or label_val >= input.shape[-1]:
raise ValueError( raise ValueError(
'Expected 0 <= label_value < class_dimension({}), but got label_value {}'. 'Expected 0 <= label_value < class_dimension({}), but got label_value {}'.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册