提交 d31a4fff 编写于 作者: M Megvii Engine Team

refactor(mge): use logsumexp in cross_entropy

GitOrigin-RevId: 4a14aabb94cc083762490ce009c69394db012e96
上级 20e93630
...@@ -11,7 +11,7 @@ import numpy as np ...@@ -11,7 +11,7 @@ import numpy as np
from ..core.tensor.utils import make_shape_tuple from ..core.tensor.utils import make_shape_tuple
from ..tensor import Tensor from ..tensor import Tensor
from .elemwise import abs, equal, exp, log, maximum, pow, relu from .elemwise import abs, equal, exp, log, maximum, pow, relu
from .nn import indexing_one_hot, logsigmoid, logsoftmax from .nn import indexing_one_hot, logsigmoid, logsumexp
from .tensor import where from .tensor import where
__all__ = [ __all__ = [
...@@ -191,9 +191,7 @@ def cross_entropy( ...@@ -191,9 +191,7 @@ def cross_entropy(
) )
# Denominator of the softmax # Denominator of the softmax
offset = pred.detach().max(axis=axis, keepdims=True) down = logsumexp(pred, axis=axis, keepdims=True)
pred = pred - offset
down = log(exp(pred).sum(axis=axis, keepdims=True))
up = indexing_one_hot(pred, label, axis) up = indexing_one_hot(pred, label, axis)
......
...@@ -546,7 +546,7 @@ def logsumexp( ...@@ -546,7 +546,7 @@ def logsumexp(
[-0.5481 4.4519] [-0.5481 4.4519]
""" """
max_value = max(inp, axis, keepdims=True) max_value = max(inp.detach(), axis, keepdims=True)
if keepdims: if keepdims:
return max_value + log(sum(exp(inp - max_value), axis, keepdims)) return max_value + log(sum(exp(inp - max_value), axis, keepdims))
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册