提交 23cc2142 编写于 作者: H HydrogenSulfate 提交者: chajchaj

Update test_cross_entropy_loss.py

上级 d1a11056
......@@ -848,30 +848,6 @@ class CrossEntropyLoss(unittest.TestCase):
label_np = np.random.randint(0, C, size=(N)).astype(np.int64)
label_np[0] = 255
weight_np = np.random.random([C]).astype(self.dtype)
paddle.enable_static()
prog = fluid.Program()
startup_prog = fluid.Program()
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
with fluid.program_guard(prog, startup_prog):
input = fluid.data(name='input', shape=[N, C], dtype=self.dtype)
label = fluid.data(name='label', shape=[N], dtype='int64')
weight = fluid.data(
name='weight', shape=[C],
dtype=self.dtype) #weight for each class
cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
weight=weight, ignore_index=255)
ret = cross_entropy_loss(input, label)
exe = fluid.Executor(place)
static_ret = exe.run(prog,
feed={
'input': input_np,
'label': label_np,
"weight": weight_np
},
fetch_list=[ret])
self.assertIsNotNone(static_ret)
with fluid.dygraph.guard():
cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
......@@ -886,8 +862,6 @@ class CrossEntropyLoss(unittest.TestCase):
expected = cross_entropy_loss_1d(
input_np, label_np, weight=weight_np, ignore_index=255)[0]
self.assertTrue(np.allclose(static_ret, dy_ret_value))
self.assertTrue(np.allclose(static_ret, expected))
self.assertTrue(np.allclose(dy_ret_value, expected))
def test_cross_entropy_loss_1d_with_weight_mean(self):
......@@ -1214,31 +1188,6 @@ class CrossEntropyLoss(unittest.TestCase):
label_np = np.random.randint(0, C, size=(N, H, W)).astype(np.int64)
label_np[0, 0, 0] = 255
weight_np = np.random.random([C]).astype(self.dtype)
paddle.enable_static()
prog = fluid.Program()
startup_prog = fluid.Program()
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
with fluid.program_guard(prog, startup_prog):
input = fluid.data(
name='input', shape=[N, H, W, C], dtype=self.dtype)
label = fluid.data(name='label', shape=[N, H, W], dtype='int64')
weight = fluid.data(
name='weight', shape=[C],
dtype=self.dtype) #weight for each class
cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
weight=weight, ignore_index=255)
ret = cross_entropy_loss(input, label)
exe = fluid.Executor(place)
static_ret = exe.run(prog,
feed={
'input': input_np,
'label': label_np,
"weight": weight_np
},
fetch_list=[ret])
self.assertIsNotNone(static_ret)
with fluid.dygraph.guard():
cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册