Commit d31a4fff authored by Megvii Engine Team

refactor(mge): use logsumexp in cross_entropy

GitOrigin-RevId: 4a14aabb94cc083762490ce009c69394db012e96
Parent 20e93630
@@ -11,7 +11,7 @@ import numpy as np
 from ..core.tensor.utils import make_shape_tuple
 from ..tensor import Tensor
 from .elemwise import abs, equal, exp, log, maximum, pow, relu
-from .nn import indexing_one_hot, logsigmoid, logsoftmax
+from .nn import indexing_one_hot, logsigmoid, logsumexp
 from .tensor import where

 __all__ = [
@@ -191,9 +191,7 @@ def cross_entropy(
     )
     # Denominator of the softmax
-    offset = pred.detach().max(axis=axis, keepdims=True)
-    pred = pred - offset
-    down = log(exp(pred).sum(axis=axis, keepdims=True))
+    down = logsumexp(pred, axis=axis, keepdims=True)
     up = indexing_one_hot(pred, label, axis)
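This hunk replaces an inline max-shift with the library's own `logsumexp`, which performs the same numerically stable reduction. As a minimal sketch of why the shift matters (plain NumPy, not MegEngine code):

```python
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])

# Naive log(sum(exp(x))) overflows: exp(1000) is inf in float64.
with np.errstate(over="ignore"):
    naive = np.log(np.exp(x).sum())       # inf

# Subtracting the max keeps every exp() argument <= 0, so nothing overflows.
m = x.max()
stable = m + np.log(np.exp(x - m).sum())  # ~1002.4076

print(naive, stable)
```

Centralizing this in `logsumexp` keeps `cross_entropy` shorter and means any later numerical-stability fix lands in one place.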
@@ -546,7 +546,7 @@ def logsumexp(
         [-0.5481 4.4519]
     """
-    max_value = max(inp, axis, keepdims=True)
+    max_value = max(inp.detach(), axis, keepdims=True)
     if keepdims:
         return max_value + log(sum(exp(inp - max_value), axis, keepdims))
     else:
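The new `.detach()` stops gradients from flowing through the max offset. This is sound because the offset cancels analytically: `logsumexp` subtracts `max_value` inside `exp` and adds it back outside, so its total contribution to the gradient is zero, while detaching avoids backpropagating through the non-smooth `max`. A finite-difference check in plain NumPy (`logsumexp_np` is a hypothetical helper, not the MegEngine function) showing the gradient is just softmax, regardless of how the offset is treated:

```python
import numpy as np

def logsumexp_np(x):
    # Same max-shift as above; the offset m cancels in the gradient.
    m = x.max()
    return m + np.log(np.exp(x - m).sum())

x = np.array([1.0, 2.0, 3.0])
eps = 1e-6

# Central finite differences of logsumexp w.r.t. each input coordinate.
grad = np.array([
    (logsumexp_np(x + eps * np.eye(3)[i]) - logsumexp_np(x - eps * np.eye(3)[i])) / (2 * eps)
    for i in range(3)
])

softmax = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(np.allclose(grad, softmax, atol=1e-5))  # True: gradient ignores the offset
```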