提交 d1a11056 编写于 作者: H HydrogenSulfate 提交者: chajchaj

Update test_cross_entropy_loss.py

上级 dd0140bd
...@@ -1533,9 +1533,9 @@ class TestCrossEntropyFAPIError(unittest.TestCase): ...@@ -1533,9 +1533,9 @@ class TestCrossEntropyFAPIError(unittest.TestCase):
self.assertRaises(ValueError, test_LabelValue_ExceedMin) self.assertRaises(ValueError, test_LabelValue_ExceedMin)
def static_test_WeightLength_NotEqual(): def static_test_WeightLength_NotEqual():
input_np = np.random.random([2, 4]).astype(self.dtype) input_np = np.random.random([2, 4]).astype('float32')
label_np = np.random.randint(0, 4, size=(2)).astype(np.int64) label_np = np.random.randint(0, 4, size=(2)).astype(np.int64)
weight_np = np.random.random([3]).astype(self.dtype) #shape:C weight_np = np.random.random([3]).astype('float32')
paddle.enable_static() paddle.enable_static()
prog = fluid.Program() prog = fluid.Program()
startup_prog = fluid.Program() startup_prog = fluid.Program()
...@@ -1543,11 +1543,11 @@ class TestCrossEntropyFAPIError(unittest.TestCase): ...@@ -1543,11 +1543,11 @@ class TestCrossEntropyFAPIError(unittest.TestCase):
) else fluid.CPUPlace() ) else fluid.CPUPlace()
with fluid.program_guard(prog, startup_prog): with fluid.program_guard(prog, startup_prog):
input = fluid.data( input = fluid.data(
name='input', shape=[2, 4], dtype=self.dtype) name='input', shape=[2, 4], dtype='float32')
label = fluid.data(name='label', shape=[2], dtype='int64') label = fluid.data(name='label', shape=[2], dtype='int64')
weight = fluid.data( weight = fluid.data(
name='weight', shape=[3], name='weight', shape=[3],
dtype=self.dtype) #weight for each class dtype='float32') #weight for each class
cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
weight=weight) weight=weight)
ret = cross_entropy_loss(input, label) ret = cross_entropy_loss(input, label)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册