From 0a374c7182fb8f0970f3dbf13db9e6f8c5464c00 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 14 Mar 2017 11:33:04 -0800
Subject: [PATCH] Add matmul large array tests.

Change: 150099362
---
 .../compiler/xla/tests/dot_operation_test.cc | 80 +++++++++++++++++++
 1 file changed, 80 insertions(+)

diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 197a8f86cb0..45df8114534 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -54,6 +54,8 @@ class DotOperationTest : public ClientLibraryTestBase {
   template <typename Element>
   void TestNonsquareMatrixDot(bool lhs_row_major = false,
                               bool rhs_row_major = false);
+  void TestMatrixDot(int M, int K, int N, bool lhs_row_major = false,
+                     bool rhs_row_major = false);
 };
 
 XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) {
@@ -170,6 +172,84 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major,
       &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
 }
 
+void DotOperationTest::TestMatrixDot(int M, int K, int N, bool lhs_row_major,
+                                     bool rhs_row_major) {
+  std::unique_ptr<Array2D<float>> lhs_data =
+      MakeLinspaceArray2D(0.0, 1.0, M, K);
+  std::unique_ptr<Literal> lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
+      *lhs_data,
+      LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)));
+  auto lhs_handle = client_->TransferToServer(*lhs_lit).ConsumeValueOrDie();
+
+  std::unique_ptr<Array2D<float>> rhs_data =
+      MakeLinspaceArray2D(0.0, 1.0, K, N);
+  std::unique_ptr<Literal> rhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
+      *rhs_data,
+      LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)));
+  auto rhs_handle = client_->TransferToServer(*rhs_lit).ConsumeValueOrDie();
+
+  ComputationBuilder builder(client_, TestName());
+  auto prim_type = primitive_util::NativeToPrimitiveType<float>();
+  auto result = builder.Dot(
+      builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {M, K}), "lhs"),
+      builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {K, N}), "rhs"));
+
+  std::unique_ptr<Array2D<float>> expected =
+      ReferenceUtil::MatmulArray2D(*lhs_data, *rhs_data);
+
+  ComputeAndCompareR2<float>(&builder, *expected,
+                             {lhs_handle.get(), rhs_handle.get()},
+                             ErrorSpec(0.3, 3e-3));
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorTF) {
+  TestMatrixDot(12, 117, 7, true, false);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorFT) {
+  TestMatrixDot(12, 117, 7, false, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorTT) {
+  TestMatrixDot(12, 117, 7, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorFF) {
+  TestMatrixDot(12, 117, 7, false, false);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorTT) {
+  TestMatrixDot(270, 270, 520, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorTF) {
+  TestMatrixDot(270, 270, 520, true, false);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorFT) {
+  TestMatrixDot(270, 270, 520, false, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorFF) {
+  TestMatrixDot(270, 270, 520, false, false);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorTT) {
+  TestMatrixDot(269, 3, 520, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorTF) {
+  TestMatrixDot(260, 3, 520, true, false);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFT) {
+  TestMatrixDot(260, 3, 520, false, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFF) {
+  TestMatrixDot(260, 3, 520, false, false);
+}
+
 XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) {
   constexpr bool kLhsRowMajor = false;
   constexpr bool kRhsRowMajor = false;
--
GitLab