From 116792db459ca63398d5373bd5b139e32e67eb47 Mon Sep 17 00:00:00 2001
From: Blake Hechtman
Date: Mon, 24 Aug 2020 10:57:37 -0700
Subject: [PATCH] [XLA:CLIENT] Support all gradients of <= 2 operand einsums

PiperOrigin-RevId: 328171480
Change-Id: I1adbe658c9e3f435d4a42c2627cbbfef297e02f3
---
 tensorflow/compiler/xla/client/lib/BUILD      |   1 +
 tensorflow/compiler/xla/client/lib/matrix.cc  | 167 ++++++++++--------
 tensorflow/compiler/xla/client/lib/matrix.h   |  11 --
 .../compiler/xla/client/lib/matrix_test.cc    |  44 ++---
 .../python/kernel_tests/einsum_op_test.py     |   8 -
 5 files changed, 115 insertions(+), 116 deletions(-)

diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index a3c7c39e3ff..eb09e9c8867 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -199,6 +199,7 @@ cc_library(
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
 )
diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc
index b7721f2bbc5..ec1cc7e0487 100644
--- a/tensorflow/compiler/xla/client/lib/matrix.cc
+++ b/tensorflow/compiler/xla/client/lib/matrix.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include "absl/strings/str_join.h"
 #include "absl/strings/str_split.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
@@ -235,84 +236,92 @@ XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); }
 XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }

 namespace {
-std::vector<int64> EinsumDiagonalLabels(absl::Span<const int64> config) {
+absl::optional<std::array<std::vector<int64>, 3>> EinsumDiagonalLabels(
+    absl::Span<const int64> config) {
   std::vector<int64> unique_labels;
+  std::vector<int64> reduce_dims;
+  std::vector<int64> broadcast_dims;
   for (auto label = config.begin(); label != config.end(); ++label) {
     auto first_label = absl::c_find(config, *label);
+    auto dim = label - config.begin();
     if (first_label == label) {
       unique_labels.push_back(*label);
+      broadcast_dims.push_back(dim);
+    } else {
+      reduce_dims.push_back(dim);
     }
   }
   if (unique_labels.size() == config.size()) {
-    unique_labels.clear();
+    return absl::nullopt;
   }
-  return unique_labels;
+  return {{unique_labels, reduce_dims, broadcast_dims}};
 }
-}  // namespace

-xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config) {
+// Masks a tensor such that only the diagonal of repeated indices are non-zero.
+// The result of this can be used to create a diagonal matrix with an identity
+// reduction.
+xla::XlaOp EinsumDiagonalMask(XlaOp x, absl::Span<const int64> config) {
   XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    if (EinsumDiagonalLabels(config).empty()) {
-      return x;
-    }
     TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
     Shape iota_shape = x_shape;
     iota_shape.set_element_type(S32);
     XlaOp mask = ConstantR0<bool>(builder, true);
-    absl::InlinedVector<int64, 8> reduce_dims;
     for (auto label = config.begin(); label != config.end(); ++label) {
       const int64 dim = label - config.begin();
       auto first_label = absl::c_find(config, *label);
-      if (first_label == label) {
-        continue;
+      if (first_label != label) {
+        const int64 first_dim = first_label - config.begin();
+        mask = And(mask, Eq(Iota(builder, iota_shape, first_dim),
+                            Iota(builder, iota_shape, dim)));
       }
-      reduce_dims.push_back(dim);
-      const int64 first_dim = first_label - config.begin();
-      mask = And(mask, Eq(Iota(builder, iota_shape, first_dim),
-                          Iota(builder, iota_shape, dim)));
+    }
+    return Select(mask, x, ZerosLike(x));
+  });
+}
+
+xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config) {
+  XlaBuilder* builder = x.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    auto labels = EinsumDiagonalLabels(config);
+    if (!labels) {
+      return x;
+    }
     auto zero = ScalarLike(x, 0);
-    return Reduce(Select(mask, x, zero), zero,
+    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
+    return Reduce(EinsumDiagonalMask(x, config), zero,
                   CreateScalarIdentityWithZeroComputation(
                       x_shape.element_type(), builder),
-                  reduce_dims);
+                  labels->at(1));
   });
 }

-Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
-                                       absl::Span<const int64> y_config,
-                                       absl::Span<const int64> output_config) {
-  for (auto dim : output_config) {
-    if (absl::c_linear_search(x_config, dim) ||
-        absl::c_linear_search(y_config, dim)) {
-      if (absl::c_count(output_config, dim) > 1) {
-        return InvalidArgument("Einsum has repeated output dimension.");
-      }
-      continue;
-    }
-    return InvalidArgument(
-        "Einsum has output dimension without corresponding input dimension.");
-  }
-  for (auto dim : x_config) {
-    if (absl::c_linear_search(y_config, dim) ||
-        absl::c_linear_search(output_config, dim)) {
-      if (absl::c_count(x_config, dim) > 1) {
-        return InvalidArgument("Einsum has repeated lhs dimension.");
-      }
+xla::XlaOp EinsumInverseDiagonal(XlaOp x, absl::Span<const int64> config) {
+  XlaBuilder* builder = x.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    auto labels = EinsumDiagonalLabels(config);
+    if (!labels) {
+      return x;
     }
-  }
-  for (auto dim : y_config) {
-    if (absl::c_linear_search(x_config, dim) ||
-        absl::c_linear_search(output_config, dim)) {
-      if (absl::c_count(y_config, dim) > 1) {
-        return InvalidArgument("Einsum has repeated rhs dimension.");
+    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
+    std::vector<int64> broadcast_sizes;
+    int64 x_dim = 0;
+    for (auto label = config.begin(); label != config.end(); ++label) {
+      auto first_label = absl::c_find(config, *label);
+      if (first_label == label) {
+        broadcast_sizes.push_back(x_shape.dimensions(x_dim));
+        ++x_dim;
+      } else {
+        broadcast_sizes.push_back(
+            broadcast_sizes[first_label - config.begin()]);
       }
     }
-  }
-  return Status::OK();
+    x = BroadcastInDim(x, broadcast_sizes, labels->at(2));
+    return EinsumDiagonalMask(x, config);
+  });
 }
+}  // namespace

 namespace {
 // Helper method to remove dimensions from a shape and dot dimension numbers
@@ -347,21 +356,23 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
   XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     auto x_diagonal_labels = EinsumDiagonalLabels(x_config);
+    if (x_diagonal_labels) {
+      return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels->at(0), y,
+                    y_config, output_config, precision);
+    }
     auto y_diagonal_labels = EinsumDiagonalLabels(y_config);
-    if (!x_diagonal_labels.empty() && !y_diagonal_labels.empty()) {
-      return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels,
-                    EinsumDiagonal(y, y_config), y_diagonal_labels,
-                    output_config, precision);
-    } else if (!x_diagonal_labels.empty()) {
-      return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels, y, y_config,
-                    output_config, precision);
-    } else if (!y_diagonal_labels.empty()) {
-      return Einsum(x, x_config, EinsumDiagonal(y, y_config), y_diagonal_labels,
-                    output_config, precision);
-    }
-
-    TF_RETURN_IF_ERROR(
-        ValidateEinsumNumericDimensions(x_config, y_config, output_config));
+    if (y_diagonal_labels) {
+      return Einsum(x, x_config, EinsumDiagonal(y, y_config),
+                    y_diagonal_labels->at(0), output_config, precision);
+    }
+    auto output_diagonal_labels = EinsumDiagonalLabels(output_config);
+    if (output_diagonal_labels) {
+      return EinsumInverseDiagonal(
+          Einsum(x, x_config, y, y_config, output_diagonal_labels->at(0),
+                 precision),
+          output_config);
+    }
+
     TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
     TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));

     const int64 x_rank = x_config.size();
@@ -372,21 +383,15 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
     absl::flat_hash_set<int64> output_map;

     for (auto d : x_config) {
-      if (!x_map.insert(d).second) {
-        return InvalidArgument("XLA Einsum does not support rhs tracing");
-      }
+      x_map.insert(d);
     }

     for (auto d : y_config) {
-      if (!y_map.insert(d).second) {
-        return InvalidArgument("XLA Einsum does not support lhs tracing");
-      }
+      y_map.insert(d);
     }

     for (auto d : output_config) {
-      if (!output_map.insert(d).second) {
-        return InvalidArgument("XLA Einsum does not support output tracing");
-      }
+      output_map.insert(d);
     }

     DotDimensionNumbers dnums;
@@ -397,6 +402,7 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
     auto is_contracting = [&](int64 d) {
       return x_map.contains(d) && y_map.contains(d);
     };
+
     auto rhs_dimension_number = [&](int64 d) {
       return absl::c_find(y_config, d) - y_config.begin();
     };
@@ -468,8 +474,9 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
           output_dimension_number(y_config[d]);
     }

+    const int64 transpose_rank = output_transpose_dims.size();
     std::vector<int64> transpose_dims(output_rank);
-    for (int64 i = 0; i < output_rank; ++i) {
+    for (int64 i = 0; i < transpose_rank; ++i) {
       transpose_dims[output_transpose_dims[i]] = i;
     }
@@ -498,7 +505,27 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
           CreateScalarAddComputation(x_shape.element_type(), builder),
           output_reduce_dims);
     }
-    return Transpose(dot, transpose_dims);
+    dot = Transpose(dot, transpose_dims);
+    if (transpose_rank == output_rank) {
+      return dot;
+    }
+
+    auto is_output_only = [&](int64 d) {
+      return output_map.contains(d) && !x_map.contains(d) && !y_map.contains(d);
+    };
+
+    int64 dot_dim = 0;
+    std::vector<int64> new_dims;
+    new_dims.reserve(output_rank);
+    TF_ASSIGN_OR_RETURN(Shape dot_shape, builder->GetShape(dot));
+    for (auto d : output_config) {
+      if (is_output_only(d)) {
+        new_dims.push_back(1);
+      } else {
+        new_dims.push_back(dot_shape.dimensions(dot_dim++));
+      }
+    }
+    return Reshape(dot, new_dims);
   });
 }
diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h
index 46f70ed27b9..1a9f72dedf2 100644
--- a/tensorflow/compiler/xla/client/lib/matrix.h
+++ b/tensorflow/compiler/xla/client/lib/matrix.h
@@ -112,14 +112,6 @@ StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString(
 // Returns an empty string if the einsum string already has an ->.
 std::string NormalizeEinsumString(absl::string_view einsum_config);

-// Determine if each dimension label is in at least two inputs.
-//
-// NOTE: This function is meant for testing, there is no need to call it
-// directly.
-Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
-                                       absl::Span<const int64> y_config,
-                                       absl::Span<const int64> output_config);
-
 // Supports two operand einsum notation like "ab,cb->ac".
 xla::XlaOp Einsum(
     xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config,
@@ -128,9 +120,6 @@ xla::XlaOp Einsum(
     xla::XlaOp x, absl::string_view einsum_config,
     xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);

-// Handles repeated indices within an operand by taking the tensor diagonal of
-// the input.
-xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config);

 // Same as above but supporting numeric labels on dimensions. So "ab,cb->ac"
 // becomes:
diff --git a/tensorflow/compiler/xla/client/lib/matrix_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc
index ebbf39ec096..628447c289e 100644
--- a/tensorflow/compiler/xla/client/lib/matrix_test.cc
+++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc
@@ -233,12 +233,23 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) {
   };

   std::vector<std::vector<string>> good_test_cases = {
-      {"ab", "bc", "ac"},          {"Bab", "Bbc", "Bac"},
-      {"ab", "cd", "dcba"},        {"abc", "abd", "cbd"},
-      {"...ab", "...bc", "...ac"}, {"a...bc", "...abd", "cbd..."},
-      {"...ab", "...bc", "ac"},    {"...b", "...bc", "...c"},
-      {"...abz", "...bc", "...ac"}, {"...ab", "...bcz", "...ac"},
-      {"abz", "bc", "ac"},         {"ab", "bcz", "ac"},
+      {"ab", "bc", "ac"},
+      {"Bab", "Bbc", "Bac"},
+      {"ab", "cd", "dcba"},
+      {"abc", "abd", "cbd"},
+      {"...ab", "...bc", "...ac"},
+      {"a...bc", "...abd", "cbd..."},
+      {"...ab", "...bc", "ac"},
+      {"...b", "...bc", "...c"},
+      {"...abz", "...bc", "...ac"},
+      {"...ab", "...bcz", "...ac"},
+      {"abz", "bc", "ac"},
+      {"ab", "bcz", "ac"},
+
+      {"a", "b", "c"},
+      {"...a", "...b", "...c"},
+      {"abb", "bcc", "ac"},
+      {"ab", "bc", "ad"},
   };
   for (auto test_case : good_test_cases) {
     auto parse_result_or_status =
@@ -249,9 +260,6 @@
     for (int i = 0; i < 3; ++i) {
       EXPECT_EQ(parse_result[i], to_vec(test_case[i]));
     }
-    EXPECT_TRUE(ValidateEinsumNumericDimensions(
-                    parse_result[0], parse_result[1], parse_result[2])
-                    .ok());
   }

   std::vector<string> einsum_strings_that_fail_parsing = {
@@ -261,24 +269,6 @@
     auto parse_result_or_status = ParseEinsumString(test_case, 3, 3);
     EXPECT_FALSE(parse_result_or_status.status().ok());
   }
-  std::vector<std::vector<string>> einsum_strings_that_fail_numeric_validation =
-      {
-          {"a", "b", "c"},
-          {"...a", "...b", "...c"},
-          {"abb", "bcc", "ac"},
-          {"ab", "bc", "ad"},
-      };
-
-  for (auto test_case : einsum_strings_that_fail_numeric_validation) {
-    auto parse_result_or_status =
-        ParseEinsumString(to_string(test_case[0], test_case[1], test_case[2]),
-                          test_case[0].size(), test_case[1].size());
-    EXPECT_TRUE(parse_result_or_status.status().ok());
-    auto parse_result = parse_result_or_status.ValueOrDie();
-    EXPECT_FALSE(ValidateEinsumNumericDimensions(
-                     parse_result[0], parse_result[1], parse_result[2])
-                     .ok());
-  }
 }

 XLA_TEST_F(MatrixTest, NormalizeEinsumString) {
diff --git a/tensorflow/python/kernel_tests/einsum_op_test.py b/tensorflow/python/kernel_tests/einsum_op_test.py
index 10b96716580..aa9e356bea5 100644
--- a/tensorflow/python/kernel_tests/einsum_op_test.py
+++ b/tensorflow/python/kernel_tests/einsum_op_test.py
@@ -237,7 +237,6 @@ class EinsumOpTest(test.TestCase):
           ((4, 3), (None, 3)))
     check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))

-  @test_util.disable_xla('b/131919749')
   def testOutputRepeatedLabels(self):
     # This is the reverse operation of generalized traces, to be used for
     # computing symbolic gradients of einsum. Note: this operation is not
@@ -264,7 +263,6 @@ class EinsumOpTest(test.TestCase):
     # From transformer xl.
     check('ibnd,ijbn->jnd', [(1, 0, 5, 10), (1, 1, 0, 5)], (1, 5, 10))

-  @test_util.disable_xla('b/131919749')
   def testEmptyWithRepeatedLabels(self):

     def check(equation, input_shapes, output_shape):
@@ -310,7 +308,6 @@ class EinsumGradTest(test.TestCase):
     self.assertLess(
         gradient_checker_v2.max_error(analytical, numerical), tol)

-  @test_util.disable_xla('b/131919749')
   def testUnary(self):
     # Unary cases.
     self._check_gradient('->', ())
@@ -319,7 +316,6 @@
     self._check_gradient('aabcd->add', (3, 3, 5, 4, 4))
     self._check_gradient('abcd->da', (3, 5, 4, 2))

-  @test_util.disable_xla('b/131919749')
   def testUnaryEllipsis(self):
     self._check_gradient('...->...', ())
     self._check_gradient('...->', ())
@@ -362,11 +358,9 @@
     self._check_gradient('ijkm,ijln->ijmn', (2, 3, 3, 4), (2, 3, 3, 2))
     self._check_gradient('abce,badf->abcd', (1, 2, 3, 4), (2, 1, 4, 3))

-  @test_util.disable_xla('b/131919749')
   def testReducedIndicesWithRepeatedLabels(self):
     self._check_gradient('abce,badf->bcba', (1, 2, 3, 4), (2, 1, 4, 3))

-  @test_util.disable_xla('b/131919749')
   def testRepeatedLabels(self):
     # Repeated indices.
     self._check_gradient('aba,a->b', (3, 4, 3), (3,))
@@ -376,7 +370,6 @@
     self._check_gradient('aab,bc->ac', (1, 1, 3), (3, 4))
     self._check_gradient('aab,bcc->ac', (2, 2, 3), (3, 4, 4))

-  @test_util.disable_xla('b/131919749')
   def testEmptyWithRepeatedLabels(self):
     self._check_gradient('aab,bc->ac', (0, 0, 10), (10, 10))
     self._check_gradient('aab,bc->ac', (1, 1, 0), (0, 10))
@@ -388,7 +381,6 @@
     self._check_gradient('...ij,...jk->...ik', (3, 1, 3, 2), (1, 5, 2, 4))
     self._check_gradient('i...j,j...k->i...k', (3, 1, 2, 2), (2, 2, 3, 1, 4))

-  @test_util.disable_xla('b/131919749')
   def testBroadcastingWithRepeatedLabels(self):
     self._check_gradient('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
     self._check_gradient('aab,b...c->a...c', (1, 1, 3), (3, 1, 1, 4))
--
GitLab
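
Editor's note, not part of the patch: the `@test_util.disable_xla('b/131919749')` annotations removed above mean these einsum gradient tests now also run on the XLA path. The gradient of an operand with a repeated label is itself an einsum whose *output* carries the repeated label (e.g. the gradient of 'aab,bc->ac' with respect to x is the einsum 'ac,bc->aab'), which is why the C++ change adds EinsumInverseDiagonal: it broadcasts the reduced diagonal back to the full rank and masks the off-diagonal entries. A minimal end-to-end sketch of what this enables, assuming TensorFlow 2.5+ for the `jit_compile` argument (earlier releases spelled it `experimental_compile`):

import tensorflow as tf

@tf.function(jit_compile=True)  # force the XLA compilation path this patch extends
def einsum_grads(x, y):
  with tf.GradientTape() as tape:
    tape.watch([x, y])
    # 'aab' repeats the label 'a'; differentiating through this einsum
    # produces an einsum with a repeated output label under the hood.
    loss = tf.reduce_sum(tf.einsum('aab,bc->ac', x, y))
  return tape.gradient(loss, [x, y])

dx, dy = einsum_grads(tf.ones([2, 2, 3]), tf.ones([3, 4]))
print(dx.shape, dy.shape)  # (2, 2, 3) (3, 4): gradients match the input shapes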