From d31a4fff733b6df39366fb5b5818a32d37a57e38 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 12 Oct 2020 14:17:49 +0800
Subject: [PATCH] refactor(mge): use logsumexp in cross_entropy

GitOrigin-RevId: 4a14aabb94cc083762490ce009c69394db012e96
---
 imperative/python/megengine/functional/loss.py | 6 ++----
 imperative/python/megengine/functional/nn.py   | 2 +-
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/imperative/python/megengine/functional/loss.py b/imperative/python/megengine/functional/loss.py
index 8ed5958a4..40da78d8c 100644
--- a/imperative/python/megengine/functional/loss.py
+++ b/imperative/python/megengine/functional/loss.py
@@ -11,7 +11,7 @@ import numpy as np
 from ..core.tensor.utils import make_shape_tuple
 from ..tensor import Tensor
 from .elemwise import abs, equal, exp, log, maximum, pow, relu
-from .nn import indexing_one_hot, logsigmoid, logsoftmax
+from .nn import indexing_one_hot, logsigmoid, logsumexp
 from .tensor import where
 
 __all__ = [
@@ -191,9 +191,7 @@ def cross_entropy(
     )
 
     # Denominator of the softmax
-    offset = pred.detach().max(axis=axis, keepdims=True)
-    pred = pred - offset
-    down = log(exp(pred).sum(axis=axis, keepdims=True))
+    down = logsumexp(pred, axis=axis, keepdims=True)
 
     up = indexing_one_hot(pred, label, axis)
 
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 3fc2e223d..6ba8ee0ec 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -546,7 +546,7 @@ def logsumexp(
     [-0.5481 4.4519]
 
     """
-    max_value = max(inp, axis, keepdims=True)
+    max_value = max(inp.detach(), axis, keepdims=True)
     if keepdims:
         return max_value + log(sum(exp(inp - max_value), axis, keepdims))
     else:
-- 
GitLab
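
For reference, both hunks rest on the same log-sum-exp trick: cross_entropy previously inlined a max-shift before exp, and this patch folds that into logsumexp, which now detaches the max itself. Below is a minimal standalone sketch of the trick in plain NumPy rather than MegEngine; the function names and sample inputs are illustrative only, not part of the patch.

import numpy as np

def naive_logsumexp(x, axis):
    # Direct translation of log(sum(exp(x))): overflows or underflows
    # for large-magnitude inputs, e.g. exp(1000) == inf even in float64.
    return np.log(np.exp(x).sum(axis=axis, keepdims=True))

def stable_logsumexp(x, axis):
    # Subtract the per-slice max before exponentiating, then add it back.
    # Mathematically identical, but exp() now only sees values <= 0 and
    # cannot overflow. The patch additionally detach()-es this max in
    # MegEngine, treating it as a constant offset with no gradient path.
    m = x.max(axis=axis, keepdims=True)
    return m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))

x = np.array([[1000.0, 999.0], [-1000.0, -999.0]])
print(naive_logsumexp(x, axis=1))   # [[inf] [-inf]] -- broken
print(stable_logsumexp(x, axis=1))  # [[1000.3133] [-998.6867]]

Detaching the max is safe because the gradient of log-sum-exp is the softmax of the input regardless of which constant offset is subtracted; routing no gradient through max() merely prunes a path whose contribution would cancel to zero anyway.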