提交 2114fd51 编写于 作者: P Peter Hawkins 提交者: TensorFlower Gardener

[TF:XLA] Improve numerical stability of SoftPlus.

PiperOrigin-RevId: 171003559
上级 727d6270
......@@ -309,11 +309,6 @@ class UnaryOpsTest(XLATestCase):
[0.032058604, 0.087144323, 0.23688284, 0.64391428]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.softplus,
np.array([[-2, 0, 8]], dtype=dtype),
expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.softsign,
np.array([[-2, -1, 0, 1, 2]], dtype=dtype),
......@@ -543,6 +538,25 @@ class UnaryOpsTest(XLATestCase):
[[9, 10, 11, 12],
[13, 14, 15, 16]]]], dtype=dtype))
def _assertSoftplusMatchesExpected(self, features, dtype):
features = np.array(features, dtype=dtype)
zero = np.asarray(0).astype(dtype)
expected = np.logaddexp(zero, features)
self._assertOpOutputMatchesExpected(
nn_ops.softplus, features, expected=expected)
def testSoftplus(self):
for dtype in self.float_types:
self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype)
self._assertSoftplusMatchesExpected(
[[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]], dtype)
log_eps = np.log(np.finfo(dtype).eps)
one = dtype(1)
ten = dtype(10)
self._assertSoftplusMatchesExpected([
log_eps, log_eps - one, log_eps + one, log_eps - ten,
log_eps + ten, -log_eps, -log_eps - one, -log_eps + one,
-log_eps - ten, -log_eps + ten], dtype)
if __name__ == "__main__":
googletest.main()
......@@ -129,8 +129,28 @@ XLAJIT_MAKE_UNARY(Sign, b->Sign(x));
XLAJIT_MAKE_UNARY(Sinh,
b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))),
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
XLAJIT_MAKE_UNARY(Softplus,
b->Log(b->Add(b->Exp(x), XlaHelpers::One(b, input_type(0)))));
static xla::ComputationDataHandle Softplus(
xla::ComputationBuilder* b, DataType dtype,
const xla::ComputationDataHandle& features) {
xla::ComputationDataHandle threshold =
b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)),
XlaHelpers::FloatLiteral(b, dtype, 2.0));
// Value above which exp(x) may overflow, but softplus(x) == x
// is within machine epsilon.
xla::ComputationDataHandle too_large = b->Gt(features, b->Neg(threshold));
// Value below which exp(x) may underflow, but softplus(x) == exp(x)
// is within machine epsilon.
xla::ComputationDataHandle too_small = b->Lt(features, threshold);
xla::ComputationDataHandle features_exp = b->Exp(features);
xla::ComputationDataHandle output = b->Select(
too_large, features,
b->Select(too_small, features_exp,
b->Log(b->Add(features_exp, XlaHelpers::One(b, dtype)))));
return output;
}
XLAJIT_MAKE_UNARY(Softplus, Softplus(b, input_type(0), x));
// softsign(x) = x / (abs(x) + 1)
XLAJIT_MAKE_UNARY(Softsign,
b->Div(x,
......
......@@ -54,6 +54,19 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b,
return b->ConstantLiteral(xla::Literal::One(type));
}
xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b,
DataType data_type) {
switch (data_type) {
case DT_FLOAT:
return b->ConstantR0<float>(std::numeric_limits<float>::epsilon());
case DT_DOUBLE:
return b->ConstantR0<double>(std::numeric_limits<double>::epsilon());
default:
LOG(FATAL) << "Unsupported type in XlaHelpers::Epsilon: "
<< DataTypeString(data_type);
}
}
xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
xla::ComputationBuilder* b, DataType data_type, int64 value) {
xla::Literal literal;
......
......@@ -48,6 +48,11 @@ class XlaHelpers {
static xla::ComputationDataHandle One(xla::ComputationBuilder* b,
DataType data_type);
// Returns the machine epsilon for floating-point type `data_type`, i.e.,
// the difference between 1.0 and the next representable value.
static xla::ComputationDataHandle Epsilon(xla::ComputationBuilder* b,
DataType data_type);
// Returns a handle representing the given value of an integer scalar
// element of data_type.
// Note that unlike One and Zero, does not work on boolean types.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册