未验证 提交 23d3e36a 编写于 作者: G Guanghua Yu 提交者: GitHub

fix cross_entropy calculation error (#32545)

* fix cross_entropy calculation error

* add unittest and fix static
上级 97794eca
......@@ -59,8 +59,8 @@ def cross_entropy_loss_1d(input,
if reduction == 'sum':
return np.sum(out), np.array([total_weight]).astype('float64')
elif reduction == 'mean':
return out.sum() / total_weight, np.array(
[total_weight]).astype('float64')
out = out.sum() / total_weight if total_weight != 0 else out.sum()
return out, np.array([total_weight]).astype('float64')
elif reduction == 'none':
return out
......@@ -92,8 +92,8 @@ def cross_entropy_loss_2d(input,
if reduction == 'sum':
return np.sum(out), np.array([total_weight]).astype('float64')
elif reduction == 'mean':
return out.sum() / total_weight, np.array(
[total_weight]).astype('float64')
out = out.sum() / total_weight if total_weight != 0 else out.sum()
return out, np.array([total_weight]).astype('float64')
elif reduction == 'none':
return out
......@@ -759,6 +759,45 @@ class CrossEntropyLoss(unittest.TestCase):
self.assertTrue(np.allclose(static_ret, expected))
self.assertTrue(np.allclose(dy_ret_value, expected))
def test_cross_entropy_loss_1d_with_mean_ignore_negative(self):
N = 100
C = 200
input_np = np.random.random([N, C]).astype(self.dtype)
label_np = -np.ones((N)).astype(np.int64)
paddle.enable_static()
prog = fluid.Program()
startup_prog = fluid.Program()
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
with fluid.program_guard(prog, startup_prog):
input = fluid.data(name='input', shape=[N, C], dtype=self.dtype)
label = fluid.data(name='label', shape=[N], dtype='int64')
cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
ignore_index=-1)
ret = cross_entropy_loss(input, label)
exe = fluid.Executor(place)
static_ret = exe.run(prog,
feed={
'input': input_np,
'label': label_np,
},
fetch_list=[ret])
self.assertIsNotNone(static_ret)
with fluid.dygraph.guard():
cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
axis=1, ignore_index=-1)
dy_ret = cross_entropy_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
dy_ret_value = dy_ret.numpy()
self.assertIsNotNone(dy_ret_value)
expected = cross_entropy_loss_1d(input_np, label_np, ignore_index=-1)[0]
self.assertTrue(np.allclose(static_ret, dy_ret_value))
self.assertTrue(np.allclose(static_ret, expected))
self.assertTrue(np.allclose(dy_ret_value, expected))
def test_cross_entropy_loss_1d_with_weight_mean_ignore(self):
N = 100
C = 200
......
......@@ -1454,20 +1454,20 @@ def cross_entropy(input,
if weight is None:
mask = paddle.cast(mask, dtype=out_sum.dtype)
count = core.ops.reduce_sum(mask, 'reduce_all', True)
ret = out_sum / count
ret = out_sum / (count + (count == 0.0))
else:
mask = paddle.cast(mask, weight_gather_reshape.dtype)
weight_ignored = core.ops.elementwise_mul(
mask, weight_gather_reshape)
weight_sum = core.ops.reduce_sum(weight_ignored,
'reduce_all', True)
ret = out_sum / weight_sum
ret = out_sum / (weight_sum + (weight_sum == 0.0))
return ret
elif weight is not None:
out_sum = core.ops.reduce_sum(out, 'reduce_all', True)
total_weight = core.ops.reduce_sum(weight_gather_reshape,
'reduce_all', True)
return out_sum / total_weight
return out_sum / (total_weight + (total_weight == 0.0))
else:
return core.ops.mean(out)
......@@ -1537,17 +1537,17 @@ def cross_entropy(input,
if (weight is None):
mask = paddle.cast(mask, dtype=out_sum.dtype)
count = paddle.sum(mask, name=name)
ret = out_sum / count
ret = out_sum / (count + (count == 0.0))
else:
mask = paddle.cast(mask, weight_gather_reshape.dtype)
weight_ignored = paddle.multiply(mask, weight_gather_reshape)
weight_sum = paddle.sum(weight_ignored, name=name)
ret = out_sum / weight_sum
ret = out_sum / (weight_sum + (weight_sum == 0.0))
return ret
elif weight is not None:
out_sum = paddle.sum(out, name=name)
total_weight = paddle.sum(weight_gather_reshape)
return out_sum / total_weight
return out_sum / (total_weight + (total_weight == 0.0))
else:
return paddle.mean(out, name=name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册