diff --git a/paddle/operators/math/softmax.h b/paddle/operators/math/softmax.h
index 08dafed971e3a5b402ea883ed7dd93573b7b990f..3d2f0d0aecffcd0fe51166c3d863aa8b91bba196 100644
--- a/paddle/operators/math/softmax.h
+++ b/paddle/operators/math/softmax.h
@@ -25,6 +25,14 @@ template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
+template <typename T>
+struct ValueClip {
+  HOSTDEVICE T operator()(const T& x) const {
+    const T kThreshold = -64.;
+    return x < kThreshold ? kThreshold : x;
+  }
+};
+
 template <typename Place, typename T>
 class SoftmaxFunctor {
  public:
@@ -47,7 +55,8 @@
                            logits.maximum(along_class)
                                .eval()
                                .reshape(batch_by_one)
-                               .broadcast(one_by_class));
+                               .broadcast(one_by_class))
+                              .unaryExpr(ValueClip<T>());
 
     softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
     softmax.device(context.GetEigenDevice<Place>()) =
diff --git a/python/paddle/v2/framework/tests/test_softmax_op.py b/python/paddle/v2/framework/tests/test_softmax_op.py
index 1b948f252fa631e9886840b377de2996e110dc91..b41c810d9a6269c934a434b085748a86deccb475 100644
--- a/python/paddle/v2/framework/tests/test_softmax_op.py
+++ b/python/paddle/v2/framework/tests/test_softmax_op.py
@@ -5,7 +5,7 @@ from op_test import OpTest
 
 def stable_softmax(x):
     """Compute the softmax of vector x in a numerically stable way."""
-    shiftx = x - np.max(x)
+    shiftx = (x - np.max(x)).clip(-64.)
     exps = np.exp(shiftx)
     return exps / np.sum(exps)
 
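
For reference, a minimal NumPy sketch (separate from the patch itself; the input values are made up for illustration) of what the `-64` clip on the shifted logits buys: `exp(-64)` is roughly `1.6e-28`, comfortably above the smallest positive normal float32 (about `1.18e-38`), so no softmax probability collapses to exact zero and a downstream `log`, as in a cross-entropy loss, stays finite.

```python
import numpy as np


def stable_softmax(x):
    """Max-shifted softmax with the clip, mirroring the patch above."""
    shiftx = (x - np.max(x)).clip(-64.)
    exps = np.exp(shiftx)
    return exps / np.sum(exps)


# Hypothetical extreme input: a logit gap of 1000.
x = np.array([1000., 0.], dtype=np.float32)

# Without the clip, exp(-1000) silently underflows to exactly 0, so a
# later log produces -inf (with a divide-by-zero RuntimeWarning).
naive = np.exp(x - np.max(x)) / np.sum(np.exp(x - np.max(x)))
print(np.log(naive))              # [  0. -inf]

# With the clip, the smallest probability is exp(-64) ~ 1.6e-28, which is
# representable in float32, so the log stays finite.
print(np.log(stable_softmax(x)))  # approximately [  0. -64.]
```

Note that clipping must be applied to the whole shifted difference, `(x - np.max(x)).clip(-64.)`, just as the C++ side applies `ValueClip` to `logits` minus the broadcast row maximum; clipping only `np.max(x)` would leave the shifted logits unbounded below.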