提交 0a374c71 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Add matmul large array tests.

Change: 150099362
上级 90bc9563
......@@ -54,6 +54,8 @@ class DotOperationTest : public ClientLibraryTestBase {
template <typename Element>
void TestNonsquareMatrixDot(bool lhs_row_major = false,
bool rhs_row_major = false);
void TestMatrixDot(int M, int K, int N, bool lhs_row_major = false,
bool rhs_row_major = false);
};
XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) {
......@@ -170,6 +172,84 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major,
&builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
}
void DotOperationTest::TestMatrixDot(int M, int K, int N, bool lhs_row_major,
bool rhs_row_major) {
std::unique_ptr<Array2D<float>> lhs_data =
MakeLinspaceArray2D(0.0, 1.0, M, K);
std::unique_ptr<Literal> lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
*lhs_data,
LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)));
auto lhs_handle = client_->TransferToServer(*lhs_lit).ConsumeValueOrDie();
std::unique_ptr<Array2D<float>> rhs_data =
MakeLinspaceArray2D(0.0, 1.0, K, N);
std::unique_ptr<Literal> rhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
*rhs_data,
LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)));
auto rhs_handle = client_->TransferToServer(*rhs_lit).ConsumeValueOrDie();
ComputationBuilder builder(client_, TestName());
auto prim_type = primitive_util::NativeToPrimitiveType<float>();
auto result = builder.Dot(
builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {M, K}), "lhs"),
builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {K, N}), "rhs"));
std::unique_ptr<Array2D<float>> expected =
ReferenceUtil::MatmulArray2D(*lhs_data, *rhs_data);
ComputeAndCompareR2<float>(&builder, *expected,
{lhs_handle.get(), rhs_handle.get()},
ErrorSpec(0.3, 3e-3));
}
XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorTF) {
TestMatrixDot(12, 117, 7, true, false);
}
XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorFT) {
TestMatrixDot(12, 117, 7, false, true);
}
XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorTT) {
TestMatrixDot(12, 117, 7, true, true);
}
XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorFF) {
TestMatrixDot(12, 117, 7, false, false);
}
XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorTT) {
TestMatrixDot(270, 270, 520, true, true);
}
XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorTF) {
TestMatrixDot(270, 270, 520, true, false);
}
XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorFT) {
TestMatrixDot(270, 270, 520, false, true);
}
XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorFF) {
TestMatrixDot(270, 270, 520, false, false);
}
XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorTT) {
TestMatrixDot(269, 3, 520, true, true);
}
XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorTF) {
TestMatrixDot(260, 3, 520, true, false);
}
XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFT) {
TestMatrixDot(260, 3, 520, false, true);
}
XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFF) {
TestMatrixDot(260, 3, 520, false, false);
}
XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) {
constexpr bool kLhsRowMajor = false;
constexpr bool kRhsRowMajor = false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册