From d21b5f6d08b0e63fb12db0ecf06546874e356d19 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 25 Feb 2019 09:53:04 -0800 Subject: [PATCH] [XLA:CPU] Enable F64 convolutions There's no good reason for disabling it completely. Right now this falls back to the slow generic implementation. We could make it use Eigen if we want to. PiperOrigin-RevId: 235548224 --- tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc | 3 ++- tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 2 +- tensorflow/compiler/xla/tests/convolution_test.cc | 8 ++++---- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index a8b139aec9e..2cc618e4302 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -72,7 +72,8 @@ bool PotentiallyImplementedAsEigenConvolution( CHECK( ShapeUtil::SameElementTypeIgnoringFpPrecision(input_shape, kernel_shape)); // TODO(b/65408531): Explore using Eigen dot for complex64 type. - if (ShapeUtil::ElementIsComplex(input_shape)) { + PrimitiveType primitive_type = input_shape.element_type(); + if (primitive_type != F16 && primitive_type != F32) { return false; } if (window_util::HasWindowReversal(convolution.window())) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 2418d96440f..47041eb348b 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1090,7 +1090,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { auto rhs = convolution->operand(1); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*convolution, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F16, F32, C64, C128})); + /*supported_types=*/{F16, F32, F64, C64, C128})); // TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support // different data layouts. diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 9db9f2563b6..cfee9c0f8a4 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -1945,7 +1945,7 @@ XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) { class ConvolutionHloTest : public HloTestBase {}; -XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64Forward)) { +XLA_TEST_F(ConvolutionHloTest, ConvolveF64Forward) { constexpr char kHlo[] = R"( HloModule TestModule @@ -1957,7 +1957,7 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); } -XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF32ForwardReversed)) { +XLA_TEST_F(ConvolutionHloTest, ConvolveF32ForwardReversed) { constexpr char kHlo[] = R"( HloModule TestModule @@ -1969,7 +1969,7 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); } -XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardFilter)) { +XLA_TEST_F(ConvolutionHloTest, ConvolveF64BackwardFilter) { constexpr char kHlo[] = R"( HloModule TestModule @@ -1981,7 +1981,7 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); } -XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardInput)) { +XLA_TEST_F(ConvolutionHloTest, ConvolveF64BackwardInput) { constexpr char kHlo[] = R"( HloModule TestModule -- GitLab