Commit 67d563c4 authored by Gabriel de Marmiesse, committed by François Chollet

[updated] improve softmax implementation (#11189)

* adjust softmax implementation

* Theano only handles the 2-dim case for softmax (see the NumPy sketch below)

[theano softmax](http://deeplearning.net/software/theano/library/tensor/nnet/nnet.html#theano.tensor.nnet.nnet.softmax)

* clean softmax activation and add 3d softmax test

* fix axis

* fix 1d case and test case

* fix 1d case and test case

* change to standard test

* fix test case

* Update activations_test.py

* Restored the ValueError for the 1D case.

* Moved correctness test to the backend tests.

* Added a correctness test to activations_test.py.
Parent 91f8e450
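For orientation before the diffs: a minimal NumPy sketch of the axis-general, numerically stable softmax that the non-2D code paths compute. It mirrors the reference implementation added to the tests below; the name `softmax_ref` is ours, not Keras's.

```python
import numpy as np

def softmax_ref(x, axis=-1):
    # Subtracting the per-axis max before exponentiating avoids overflow
    # and leaves the result mathematically unchanged.
    m = np.max(x, axis=axis, keepdims=True)
    e = np.exp(x - m)
    return e / np.sum(e, axis=axis, keepdims=True)

x = np.random.rand(4, 5, 3)
assert np.allclose(softmax_ref(x, axis=1).sum(axis=1), 1.0)
```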
--- a/keras/activations.py
+++ b/keras/activations.py
@@ -25,7 +25,9 @@ def softmax(x, axis=-1):
         ValueError: In case `dim(x) == 1`.
     """
     ndim = K.ndim(x)
-    if ndim == 2:
+    if ndim == 1:
+        raise ValueError('Cannot apply softmax to a tensor that is 1D')
+    elif ndim == 2:
         return K.softmax(x)
     elif ndim > 2:
         e = K.exp(x - K.max(x, axis=axis, keepdims=True))
......
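A small usage sketch of the behaviour after this change, assuming the Keras 2 API (`K.placeholder` and `activations.softmax`, as used in the tests below):

```python
from keras import activations
from keras import backend as K

activations.softmax(K.placeholder(ndim=2))          # fast path: K.softmax
activations.softmax(K.placeholder(ndim=3), axis=1)  # manual exp/sum path for ndim > 2

try:
    activations.softmax(K.placeholder(ndim=1))      # 1D input is now rejected
except ValueError as e:
    print(e)  # Cannot apply softmax to a tensor that is 1D
```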
--- a/keras/backend/theano_backend.py
+++ b/keras/backend/theano_backend.py
@@ -1731,10 +1731,11 @@ def relu(x, alpha=0., max_value=None, threshold=0.):
 def softmax(x, axis=-1):
-    if axis == -1 or axis == x.ndim - 1:
+    if (axis == -1 or axis == x.ndim - 1) and x.ndim == 2:
         return T.nnet.softmax(x)
-    return T.exp(x - x.max()) / T.exp(
-        x - x.max()).sum(axis=axis, keepdims=True)
+    xm = x.max(axis=axis, keepdims=True)
+    return T.exp(x - xm) / T.exp(
+        x - xm).sum(axis=axis, keepdims=True)


 def softplus(x):
......
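Two things changed here: `T.nnet.softmax` only accepts 2D inputs, so the fast path is now gated on `x.ndim == 2`; and the fallback subtracts the per-axis max rather than the global scalar `x.max()`, a mathematically equivalent shift that keeps each slice's exponents small. A standalone sketch of the new fallback (assumes Theano is installed; the variable names are ours):

```python
import numpy as np
import theano
import theano.tensor as T

x = T.tensor3('x')
xm = x.max(axis=1, keepdims=True)  # per-axis max, as in the new code
y = T.exp(x - xm) / T.exp(x - xm).sum(axis=1, keepdims=True)
f = theano.function([x], y)

v = np.random.rand(2, 3, 4).astype(theano.config.floatX)
assert np.allclose(f(v).sum(axis=1), 1.0)  # softmax over axis 1 sums to 1
```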
--- a/tests/keras/activations_test.py
+++ b/tests/keras/activations_test.py
@@ -79,6 +79,23 @@ def test_softmax_invalid():
     f = K.function([x], [activations.softmax(x)])


+def test_softmax_3d():
+    """Test using a reference implementation of softmax.
+    """
+    def softmax(values, axis):
+        m = np.max(values, axis=axis, keepdims=True)
+        e = np.exp(values - m)
+        return e / np.sum(e, axis=axis, keepdims=True)
+
+    x = K.placeholder(ndim=3)
+    f = K.function([x], [activations.softmax(x, axis=1)])
+    test_values = get_standard_values()[:, :, np.newaxis].copy()
+
+    result = f([test_values])[0]
+    expected = softmax(test_values, axis=1)
+    assert_allclose(result, expected, rtol=1e-05)
+
+
 def test_time_distributed_softmax():
     x = K.placeholder(shape=(1, 1, 5))
     f = K.function([x], [activations.softmax(x)])
......
--- a/tests/keras/backend/backend_test.py
+++ b/tests/keras/backend/backend_test.py
@@ -991,6 +991,7 @@ class TestBackend(object):
         check_single_tensor_operation('tanh', (4, 2), WITH_NP)
         check_single_tensor_operation('softmax', (4, 10), WITH_NP)
         check_single_tensor_operation('softmax', (4, 5, 3), WITH_NP, axis=1)
+        check_single_tensor_operation('softmax', (4, 5, 3, 10), WITH_NP, axis=2)

         check_two_tensor_operation('binary_crossentropy', (4, 2), (4, 2),
                                    WITH_NP, from_logits=True)
......
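`check_single_tensor_operation` compares a backend op against a NumPy reference. A minimal standalone version of the new 4D check might look like this (a sketch assuming the Keras 2 backend API, not the test helper itself):

```python
import numpy as np
from keras import backend as K

v = np.random.rand(4, 5, 3, 10)
out = K.eval(K.softmax(K.variable(v), axis=2))

# NumPy reference: stable softmax over axis 2.
e = np.exp(v - v.max(axis=2, keepdims=True))
assert np.allclose(out, e / e.sum(axis=2, keepdims=True), atol=1e-5)
```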