Commit 0075e6ac authored by Megvii Engine Team

refactor(mge): refactor cross_entropy

GitOrigin-RevId: 1fac5b5b14e6de742f1373e6834384c12718ec25
Parent f4b16932
@@ -176,30 +176,19 @@ def cross_entropy(
             "target_ndim={}".format(n0, n1)
         )
     num_classes = pred.shape[axis]
-    no_label_smooth = (
-        label_smooth is None or type(label_smooth) in (int, float) and label_smooth == 0
-    )
+    ls = label_smooth
 
-    if not with_logits:
-        if no_label_smooth:
-            return -log(indexing_one_hot(pred, label, axis)).mean()
-        pred = log(pred)
-        return (
-            label_smooth * pred.mean()
-            - (1 - label_smooth) * indexing_one_hot(pred, label, axis).mean()
-        )
-
-    # Denominator of the softmax
-    down = logsumexp(pred, axis=axis, keepdims=True)
-
-    up = indexing_one_hot(pred, label, axis)
-
-    if not no_label_smooth:
-        factor = label_smooth / num_classes
-        up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor
-
-    return (down - up).mean()
+    if with_logits:
+        logZ = logsumexp(pred, axis).mean()
+        primary_term = indexing_one_hot(pred, label, axis).mean()
+    else:
+        logZ = 0
+        primary_term = log(indexing_one_hot(pred, label, axis)).mean()
+    if ls is None or type(ls) in (int, float) and ls == 0:
+        return logZ - primary_term
+
+    return logZ - ls * pred.mean() - (1 - ls) * primary_term
 
 
 def binary_cross_entropy(
......
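The refactor folds the old two-branch computation into one closed form. For smoothing factor ls and C classes, the smoothed loss on the with_logits path is mean_i[logZ_i - (1 - ls) * pred[i, y_i] - (ls / C) * sum_j pred[i, j]], and since pred.mean() averages over both the batch and the class axis, the last term equals ls * pred.mean() exactly. Below is a minimal standalone NumPy sketch checking that equivalence numerically; old_ce and new_ce are illustrative names, not repository code.

import numpy as np

def logsumexp(x, axis):
    # numerically stable log(sum(exp(x))) along `axis`
    m = x.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def old_ce(pred, label, ls):
    # pre-refactor form: softmax log-denominator minus the smoothed true-class logit
    down = logsumexp(pred, axis=1)
    up = pred[np.arange(len(label)), label]
    up = up * (1 - ls) + pred.sum(axis=1) * (ls / pred.shape[1])
    return (down - up).mean()

def new_ce(pred, label, ls):
    # post-refactor form: logZ - ls * pred.mean() - (1 - ls) * primary_term
    logZ = logsumexp(pred, axis=1).mean()
    primary_term = pred[np.arange(len(label)), label].mean()
    return logZ - ls * pred.mean() - (1 - ls) * primary_term

rng = np.random.default_rng(0)
pred = rng.standard_normal((4, 3))
label = np.array([0, 2, 1, 1])
for ls in (0.0, 0.1, 0.5):
    # both formulations agree for any smoothing factor
    assert np.isclose(old_ce(pred, label, ls), new_ce(pred, label, ls))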
@@ -13,15 +13,15 @@ from megengine import tensor
 def test_cross_entropy_with_logits():
-    data = tensor([1, 100]).astype(np.float32).reshape((1, 2))
-    label = tensor([1]).astype(np.int32)
+    data = tensor([[0, 50], [0, -150]]).astype(np.float32)
+    label = tensor([1, 0]).astype(np.int32)
     loss = F.nn.cross_entropy(data, label)
     np.testing.assert_allclose(loss.numpy(), 0.0)
 
-    label = tensor([0]).astype(np.int32)
+    label = tensor([0, 1]).astype(np.int32)
     loss = F.nn.cross_entropy(data, label)
-    np.testing.assert_allclose(loss.numpy(), 100 - 1)
+    np.testing.assert_allclose(loss.numpy(), 100)
 
-    label = np.array([1])
+    label = np.array([1, 0])
     loss = F.nn.cross_entropy(data, label)
     np.testing.assert_allclose(loss.numpy(), 0.0)
......
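The updated expectations can be verified by hand. Without label smoothing the per-sample loss is logsumexp(row) - row[label]; for data [[0, 50], [0, -150]] the row logsumexp values are approximately 50 and 0, so labels [1, 0] give a loss of roughly 0 in both rows, while labels [0, 1] give 50 - 0 and 0 - (-150), averaging to 100. A small standalone NumPy check (illustrative only, not part of the test suite):

import numpy as np

def ce_with_logits(pred, label):
    # per-sample loss: logsumexp(pred) - pred[label], averaged over the batch
    m = pred.max(axis=1, keepdims=True)
    lse = (m + np.log(np.exp(pred - m).sum(axis=1, keepdims=True))).squeeze(1)
    return (lse - pred[np.arange(len(label)), label]).mean()

data = np.array([[0.0, 50.0], [0.0, -150.0]])
print(ce_with_logits(data, np.array([1, 0])))  # ~0.0: the true-class logit dominates each row
print(ce_with_logits(data, np.array([0, 1])))  # ~100.0: (50 + 150) / 2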