提交 f57f1ce4 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

[TF:XLA] Support LeakyRelu for alpha values outside of [0,1]

PiperOrigin-RevId: 251594608
上级 9eb67b17
......@@ -57,8 +57,9 @@ class LeakyReluOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
auto features = ctx->Input("features");
auto output =
xla::Max(features, features * xla::ScalarLike(features, alpha_));
auto prod_with_alpha = features * xla::ScalarLike(features, alpha_);
auto gt_zero = xla::Gt(features, xla::ScalarLike(features, 0));
auto output = xla::Select(gt_zero, features, prod_with_alpha);
ctx->SetOutput(0, output);
}
float alpha_;
......
......@@ -406,7 +406,6 @@ class LeakyReluTest(test.TestCase):
self.evaluate(optimizer.minimize(loss))
self.assertAllClose(x.read_value(), -99.9)
@test_util.disable_xla("XLA does not support values outside of [0,1]")
def testUnexpectedAlphaValue(self):
self.assertAllClose(
np.array([[-9.0, 0.7, -5.0, 0.3, -0.1], [0.1, -3.0, 0.5, -27.0, 0.9]]),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册