Commit 116792db authored by Blake Hechtman, committed by TensorFlower Gardener

[XLA:CLIENT] Support all gradients of <= 2 operand einsums

PiperOrigin-RevId: 328171480
Change-Id: I1adbe658c9e3f435d4a42c2627cbbfef297e02f3
Parent 2d0592a0
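In practice this change lets XLA compile einsum gradients whose backward pass needs repeated labels (hence the `b/131919749` test exclusions removed below). A minimal sketch of such a case, assuming a recent TensorFlow where `tf.function(jit_compile=True)` is available (the commit itself predates that spelling of the flag):

```python
import tensorflow as tf

@tf.function(jit_compile=True)  # force XLA compilation
def einsum_grad(x, y):
  with tf.GradientTape() as tape:
    tape.watch([x, y])
    # 'aab' repeats the label 'a'; the gradient of this einsum requires an
    # einsum with a repeated *output* label, which is what this change adds.
    out = tf.einsum('aab,bc->ac', x, y)
  return tape.gradient(out, [x, y])

dx, dy = einsum_grad(tf.ones([2, 2, 3]), tf.ones([3, 4]))
print(dx.shape, dy.shape)  # (2, 2, 3) (3, 4)
```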
......@@ -199,6 +199,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
......
......@@ -30,6 +30,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
......@@ -235,84 +236,92 @@ XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); }
XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }
namespace {
-std::vector<int64> EinsumDiagonalLabels(absl::Span<const int64> config) {
+absl::optional<std::array<std::vector<int64>, 3>> EinsumDiagonalLabels(
+    absl::Span<const int64> config) {
  std::vector<int64> unique_labels;
+  std::vector<int64> reduce_dims;
+  std::vector<int64> broadcast_dims;
  for (auto label = config.begin(); label != config.end(); ++label) {
    auto first_label = absl::c_find(config, *label);
+    auto dim = label - config.begin();
    if (first_label == label) {
      unique_labels.push_back(*label);
+      broadcast_dims.push_back(dim);
+    } else {
+      reduce_dims.push_back(dim);
    }
  }
  if (unique_labels.size() == config.size()) {
-    unique_labels.clear();
+    return absl::nullopt;
  }
-  return unique_labels;
+  return {{unique_labels, reduce_dims, broadcast_dims}};
}
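A rough Python model of the new return contract (an illustration, not code from this change): given integer labels, the helper yields the deduplicated labels, the dimensions to reduce (repeat occurrences), and the surviving dimensions to broadcast into (first occurrences), or nothing when no label repeats.

```python
def einsum_diagonal_labels(config):
  """Python sketch of EinsumDiagonalLabels for a list of integer labels."""
  unique_labels, reduce_dims, broadcast_dims = [], [], []
  for dim, label in enumerate(config):
    if config.index(label) == dim:   # first occurrence of this label
      unique_labels.append(label)
      broadcast_dims.append(dim)
    else:                            # repeat: reduced by EinsumDiagonal
      reduce_dims.append(dim)
  if len(unique_labels) == len(config):
    return None                      # no repeated labels, nothing to do
  return unique_labels, reduce_dims, broadcast_dims

# "aab" -> labels [0, 0, 1]: dim 1 is reduced, dims 0 and 2 survive.
assert einsum_diagonal_labels([0, 0, 1]) == ([0, 1], [1], [0, 2])
assert einsum_diagonal_labels([0, 1, 2]) is None
```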
} // namespace
-xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config) {
+// Masks a tensor such that only the diagonal of repeated indices is non-zero.
+// The result of this can be used to create a diagonal matrix with an identity
+// reduction.
+xla::XlaOp EinsumDiagonalMask(XlaOp x, absl::Span<const int64> config) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    if (EinsumDiagonalLabels(config).empty()) {
-      return x;
-    }
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    Shape iota_shape = x_shape;
    iota_shape.set_element_type(S32);
    XlaOp mask = ConstantR0(builder, true);
-    absl::InlinedVector<int64, 8> reduce_dims;
    for (auto label = config.begin(); label != config.end(); ++label) {
      const int64 dim = label - config.begin();
      auto first_label = absl::c_find(config, *label);
-      if (first_label == label) {
-        continue;
-      }
-      reduce_dims.push_back(dim);
      if (first_label != label) {
        const int64 first_dim = first_label - config.begin();
        mask = And(mask, Eq(Iota(builder, iota_shape, first_dim),
                            Iota(builder, iota_shape, dim)));
      }
    }
+    return Select(mask, x, ZerosLike(x));
  });
}

+xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config) {
+  XlaBuilder* builder = x.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    auto labels = EinsumDiagonalLabels(config);
+    if (!labels) {
+      return x;
+    }
    auto zero = ScalarLike(x, 0);
-    return Reduce(Select(mask, x, zero), zero,
+    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
+    return Reduce(EinsumDiagonalMask(x, config), zero,
                  CreateScalarIdentityWithZeroComputation(
                      x_shape.element_type(), builder),
-                  reduce_dims);
+                  labels->at(1));
  });
}
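A rough numpy model of the mask-then-reduce pipeline above (an illustration of the semantics, not the XLA lowering): for config "aab", EinsumDiagonalMask zeroes entries whose repeated 'a' indices disagree, and EinsumDiagonal then sums over the repeated dimension, recovering the diagonal.

```python
import numpy as np

x = np.arange(2 * 2 * 3).reshape(2, 2, 3)  # operand with labels "aab"
# EinsumDiagonalMask: compare Iota along the first and the repeated 'a' dim.
i0 = np.arange(2).reshape(2, 1, 1)
i1 = np.arange(2).reshape(1, 2, 1)
masked = np.where(i0 == i1, x, 0)          # keep x[i, j, k] only where i == j
# EinsumDiagonal: reduce over reduce_dims ([1] here, i.e. labels->at(1)).
diag = masked.sum(axis=1)
np.testing.assert_array_equal(diag, np.einsum('aab->ab', x))
```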
-Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
-                                       absl::Span<const int64> y_config,
-                                       absl::Span<const int64> output_config) {
-  for (auto dim : output_config) {
-    if (absl::c_linear_search(x_config, dim) ||
-        absl::c_linear_search(y_config, dim)) {
-      if (absl::c_count(output_config, dim) > 1) {
-        return InvalidArgument("Einsum has repeated output dimension.");
-      }
-      continue;
-    }
-    return InvalidArgument(
-        "Einsum has output dimension without corresponding input dimension.");
-  }
-  for (auto dim : x_config) {
-    if (absl::c_linear_search(y_config, dim) ||
-        absl::c_linear_search(output_config, dim)) {
-      if (absl::c_count(x_config, dim) > 1) {
-        return InvalidArgument("Einsum has repeated lhs dimension.");
-      }
-    }
-  }
-  for (auto dim : y_config) {
-    if (absl::c_linear_search(x_config, dim) ||
-        absl::c_linear_search(output_config, dim)) {
-      if (absl::c_count(y_config, dim) > 1) {
-        return InvalidArgument("Einsum has repeated rhs dimension.");
-      }
-    }
-  }
-  return Status::OK();
-}

+xla::XlaOp EinsumInverseDiagonal(XlaOp x, absl::Span<const int64> config) {
+  XlaBuilder* builder = x.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    auto labels = EinsumDiagonalLabels(config);
+    if (!labels) {
+      return x;
+    }
+    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
+    std::vector<int64> broadcast_sizes;
+    int64 x_dim = 0;
+    for (auto label = config.begin(); label != config.end(); ++label) {
+      auto first_label = absl::c_find(config, *label);
+      if (first_label == label) {
+        broadcast_sizes.push_back(x_shape.dimensions(x_dim));
+        ++x_dim;
+      } else {
+        broadcast_sizes.push_back(
+            broadcast_sizes[first_label - config.begin()]);
+      }
+    }
+    x = BroadcastInDim(x, broadcast_sizes, labels->at(2));
+    return EinsumDiagonalMask(x, config);
+  });
+}
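EinsumInverseDiagonal is the adjoint of EinsumDiagonal: broadcast the deduplicated tensor back across the repeated dimensions, then mask everything off the diagonal. The Einsum recursion below uses it to rewrite a repeated output label, e.g. "ab,bc->aac" becomes the plain contraction "ab,bc->ac" followed by this inverse. A rough numpy model (illustration only):

```python
import numpy as np

inner = np.einsum('ab,bc->ac', np.ones((2, 3)), np.ones((3, 4)))
# Broadcast "ac" into the "aac" shape along broadcast_dims = [0, 2] ...
full = np.broadcast_to(inner[:, None, :], (2, 2, 4))
# ... then zero off-diagonal entries, leaving out[i, i, k] = inner[i, k].
i0 = np.arange(2).reshape(2, 1, 1)
i1 = np.arange(2).reshape(1, 2, 1)
out = np.where(i0 == i1, full, 0)
np.testing.assert_array_equal(out[0, 0], inner[0])
assert out[0, 1].sum() == 0  # off-diagonal slices are all zero
```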
} // namespace
namespace {
// Helper method to remove dimensions from a shape and dot dimension numbers
......@@ -347,21 +356,23 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto x_diagonal_labels = EinsumDiagonalLabels(x_config);
+    if (x_diagonal_labels) {
+      return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels->at(0), y,
+                    y_config, output_config, precision);
+    }
    auto y_diagonal_labels = EinsumDiagonalLabels(y_config);
-    if (!x_diagonal_labels.empty() && !y_diagonal_labels.empty()) {
-      return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels,
-                    EinsumDiagonal(y, y_config), y_diagonal_labels,
-                    output_config, precision);
-    } else if (!x_diagonal_labels.empty()) {
-      return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels, y, y_config,
-                    output_config, precision);
-    } else if (!y_diagonal_labels.empty()) {
-      return Einsum(x, x_config, EinsumDiagonal(y, y_config), y_diagonal_labels,
-                    output_config, precision);
-    }
-    TF_RETURN_IF_ERROR(
-        ValidateEinsumNumericDimensions(x_config, y_config, output_config));
+    if (y_diagonal_labels) {
+      return Einsum(x, x_config, EinsumDiagonal(y, y_config),
+                    y_diagonal_labels->at(0), output_config, precision);
+    }
+    auto output_diagonal_labels = EinsumDiagonalLabels(output_config);
+    if (output_diagonal_labels) {
+      return EinsumInverseDiagonal(
+          Einsum(x, x_config, y, y_config, output_diagonal_labels->at(0),
+                 precision),
+          output_config);
+    }
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
const int64 x_rank = x_config.size();
......@@ -372,21 +383,15 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
absl::flat_hash_set<int64> output_map;
for (auto d : x_config) {
-      if (!x_map.insert(d).second) {
-        return InvalidArgument("XLA Einsum does not support rhs tracing");
-      }
+      x_map.insert(d);
    }
    for (auto d : y_config) {
-      if (!y_map.insert(d).second) {
-        return InvalidArgument("XLA Einsum does not support lhs tracing");
-      }
+      y_map.insert(d);
    }
    for (auto d : output_config) {
-      if (!output_map.insert(d).second) {
-        return InvalidArgument("XLA Einsum does not support output tracing");
-      }
+      output_map.insert(d);
}
DotDimensionNumbers dnums;
......@@ -397,6 +402,7 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
auto is_contracting = [&](int64 d) {
return x_map.contains(d) && y_map.contains(d);
};
auto rhs_dimension_number = [&](int64 d) {
return absl::c_find(y_config, d) - y_config.begin();
};
......@@ -468,8 +474,9 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
output_dimension_number(y_config[d]);
}
+    const int64 transpose_rank = output_transpose_dims.size();
    std::vector<int64> transpose_dims(output_rank);
-    for (int64 i = 0; i < output_rank; ++i) {
+    for (int64 i = 0; i < transpose_rank; ++i) {
transpose_dims[output_transpose_dims[i]] = i;
}
......@@ -498,7 +505,27 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
CreateScalarAddComputation(x_shape.element_type(), builder),
output_reduce_dims);
}
-    return Transpose(dot, transpose_dims);
+    dot = Transpose(dot, transpose_dims);
+    if (transpose_rank == output_rank) {
+      return dot;
+    }
+    auto is_output_only = [&](int64 d) {
+      return output_map.contains(d) && !x_map.contains(d) && !y_map.contains(d);
+    };
+    int64 dot_dim = 0;
+    std::vector<int64> new_dims;
+    new_dims.reserve(output_rank);
+    TF_ASSIGN_OR_RETURN(Shape dot_shape, builder->GetShape(dot));
+    for (auto d : output_config) {
+      if (is_output_only(d)) {
+        new_dims.push_back(1);
+      } else {
+        new_dims.push_back(dot_shape.dimensions(dot_dim++));
+      }
+    }
+    return Reshape(dot, new_dims);
});
}
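The trailing reshape covers labels that appear only in the output (the new "ab,bc->ad"-style test cases below): the dot produces no dimension for such a label, so it is materialized with size 1. A rough numpy sketch of the intent (illustration only):

```python
import numpy as np

x, y = np.ones((2, 3)), np.ones((3, 4))  # "ab" and "bc"
# For "ab,bc->ad": 'b' is contracted, 'c' is summed away, and 'd' appears in
# neither input, so new_dims inserts a size-1 dimension for it.
dot = np.einsum('ab,bc->a', x, y)
out = dot.reshape([2, 1])                # real size for 'a', 1 for 'd'
assert out.shape == (2, 1)
```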
......
......@@ -112,14 +112,6 @@ StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString(
// Returns an empty string if the einsum string already has an ->.
std::string NormalizeEinsumString(absl::string_view einsum_config);
-// Determine if each dimension label is in at least two inputs.
-//
-// NOTE: This function is meant for testing, there is no need to call it
-// directly.
-Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
-                                       absl::Span<const int64> y_config,
-                                       absl::Span<const int64> output_config);
// Supports two operand einsum notation like "ab,cb->ac".
xla::XlaOp Einsum(
xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config,
......@@ -128,9 +120,6 @@ xla::XlaOp Einsum(
xla::XlaOp x, absl::string_view einsum_config,
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
-// Handles repeated indices within an operand by taking the tensor diagonal of
-// the input.
-xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config);
// Same as above but supporting numeric labels on dimensions. So "ab,cb->ac"
// becomes:
......
......@@ -233,12 +233,23 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) {
};
std::vector<std::vector<string>> good_test_cases = {
{"ab", "bc", "ac"}, {"Bab", "Bbc", "Bac"},
{"ab", "cd", "dcba"}, {"abc", "abd", "cbd"},
{"...ab", "...bc", "...ac"}, {"a...bc", "...abd", "cbd..."},
{"...ab", "...bc", "ac"}, {"...b", "...bc", "...c"},
{"...abz", "...bc", "...ac"}, {"...ab", "...bcz", "...ac"},
{"abz", "bc", "ac"}, {"ab", "bcz", "ac"},
{"ab", "bc", "ac"},
{"Bab", "Bbc", "Bac"},
{"ab", "cd", "dcba"},
{"abc", "abd", "cbd"},
{"...ab", "...bc", "...ac"},
{"a...bc", "...abd", "cbd..."},
{"...ab", "...bc", "ac"},
{"...b", "...bc", "...c"},
{"...abz", "...bc", "...ac"},
{"...ab", "...bcz", "...ac"},
{"abz", "bc", "ac"},
{"ab", "bcz", "ac"},
{"a", "b", "c"},
{"...a", "...b", "...c"},
{"abb", "bcc", "ac"},
{"ab", "bc", "ad"},
};
for (auto test_case : good_test_cases) {
auto parse_result_or_status =
......@@ -249,9 +260,6 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) {
for (int i = 0; i < 3; ++i) {
EXPECT_EQ(parse_result[i], to_vec(test_case[i]));
}
-    EXPECT_TRUE(ValidateEinsumNumericDimensions(
-                    parse_result[0], parse_result[1], parse_result[2])
-                    .ok());
}
std::vector<string> einsum_strings_that_fail_parsing = {
......@@ -261,24 +269,6 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) {
auto parse_result_or_status = ParseEinsumString(test_case, 3, 3);
EXPECT_FALSE(parse_result_or_status.status().ok());
}
-  std::vector<std::vector<string>> einsum_strings_that_fail_numeric_validation =
-      {
-          {"a", "b", "c"},
-          {"...a", "...b", "...c"},
-          {"abb", "bcc", "ac"},
-          {"ab", "bc", "ad"},
-      };
-  for (auto test_case : einsum_strings_that_fail_numeric_validation) {
-    auto parse_result_or_status =
-        ParseEinsumString(to_string(test_case[0], test_case[1], test_case[2]),
-                          test_case[0].size(), test_case[1].size());
-    EXPECT_TRUE(parse_result_or_status.status().ok());
-    auto parse_result = parse_result_or_status.ValueOrDie();
-    EXPECT_FALSE(ValidateEinsumNumericDimensions(
-                     parse_result[0], parse_result[1], parse_result[2])
-                     .ok());
-  }
}
XLA_TEST_F(MatrixTest, NormalizeEinsumString) {
......
......@@ -237,7 +237,6 @@ class EinsumOpTest(test.TestCase):
((4, 3), (None, 3)))
check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))
-  @test_util.disable_xla('b/131919749')
def testOutputRepeatedLabels(self):
# This is the reverse operation of generalized traces, to be used for
# computing symbolic gradients of einsum. Note: this operation is not
......@@ -264,7 +263,6 @@ class EinsumOpTest(test.TestCase):
# From transformer xl.
check('ibnd,ijbn->jnd', [(1, 0, 5, 10), (1, 1, 0, 5)], (1, 5, 10))
-  @test_util.disable_xla('b/131919749')
def testEmptyWithRepeatedLabels(self):
def check(equation, input_shapes, output_shape):
......@@ -310,7 +308,6 @@ class EinsumGradTest(test.TestCase):
self.assertLess(
gradient_checker_v2.max_error(analytical, numerical), tol)
-  @test_util.disable_xla('b/131919749')
def testUnary(self):
# Unary cases.
self._check_gradient('->', ())
......@@ -319,7 +316,6 @@ class EinsumGradTest(test.TestCase):
self._check_gradient('aabcd->add', (3, 3, 5, 4, 4))
self._check_gradient('abcd->da', (3, 5, 4, 2))
-  @test_util.disable_xla('b/131919749')
def testUnaryEllipsis(self):
self._check_gradient('...->...', ())
self._check_gradient('...->', ())
......@@ -362,11 +358,9 @@ class EinsumGradTest(test.TestCase):
self._check_gradient('ijkm,ijln->ijmn', (2, 3, 3, 4), (2, 3, 3, 2))
self._check_gradient('abce,badf->abcd', (1, 2, 3, 4), (2, 1, 4, 3))
-  @test_util.disable_xla('b/131919749')
def testReducedIndicesWithRepeatedLabels(self):
self._check_gradient('abce,badf->bcba', (1, 2, 3, 4), (2, 1, 4, 3))
-  @test_util.disable_xla('b/131919749')
def testRepeatedLabels(self):
# Repeated indices.
self._check_gradient('aba,a->b', (3, 4, 3), (3,))
......@@ -376,7 +370,6 @@ class EinsumGradTest(test.TestCase):
self._check_gradient('aab,bc->ac', (1, 1, 3), (3, 4))
self._check_gradient('aab,bcc->ac', (2, 2, 3), (3, 4, 4))
-  @test_util.disable_xla('b/131919749')
def testEmptyWithRepeatedLabels(self):
self._check_gradient('aab,bc->ac', (0, 0, 10), (10, 10))
self._check_gradient('aab,bc->ac', (1, 1, 0), (0, 10))
......@@ -388,7 +381,6 @@ class EinsumGradTest(test.TestCase):
self._check_gradient('...ij,...jk->...ik', (3, 1, 3, 2), (1, 5, 2, 4))
self._check_gradient('i...j,j...k->i...k', (3, 1, 2, 2), (2, 2, 3, 1, 4))
-  @test_util.disable_xla('b/131919749')
def testBroadcastingWithRepeatedLabels(self):
self._check_gradient('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
self._check_gradient('aab,b...c->a...c', (1, 1, 3), (3, 1, 1, 4))
......