diff --git a/lite/api/_paddle_use_kernels.h b/lite/api/_paddle_use_kernels.h
index 2f4d7350b56f7c56a329b629b27ed5b517708ef6..16924fbb0b0952411c6d73e675ecd57fc0236b92 100644
--- a/lite/api/_paddle_use_kernels.h
+++ b/lite/api/_paddle_use_kernels.h
@@ -35,7 +35,7 @@ USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def);
 #ifdef LITE_WITH_ARM
 USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
-USE_LITE_KERNEL(matmul, kARM, kFloat, kNCHW, def);  // for x2paddle
+USE_LITE_KERNEL(matmul, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(lrn, kARM, kFloat, kNCHW, def);
diff --git a/lite/api/_paddle_use_ops.h b/lite/api/_paddle_use_ops.h
index 5cf62224de33f14d5e0637fc9cc54752a79ba445..5e3c5f2e28037e71d4d6a7053f7e0d8559531807 100644
--- a/lite/api/_paddle_use_ops.h
+++ b/lite/api/_paddle_use_ops.h
@@ -19,7 +19,7 @@
 #include "paddle_lite_factory_helper.h"  // NOLINT
 
 USE_LITE_OP(mul);
-USE_LITE_OP(matmul);  // for x2paddle
+USE_LITE_OP(matmul);
 USE_LITE_OP(fc);
 USE_LITE_OP(relu);
 USE_LITE_OP(relu6);
diff --git a/lite/kernels/arm/matmul_compute.cc b/lite/kernels/arm/matmul_compute.cc
index ba34228b48d6b66bd8fc64b6d7e03ffefb6105db..d6928652eee456b1345eb7882267a40540aa88a9 100644
--- a/lite/kernels/arm/matmul_compute.cc
+++ b/lite/kernels/arm/matmul_compute.cc
@@ -36,6 +36,7 @@ void MatMulCompute::Run() {
   auto x_dims = param.X->dims();
   auto y_dims = param.Y->dims();
+  auto o_dims = param.Out->dims();
   bool x_transpose = param.transpose_X;
   bool y_transpose = param.transpose_Y;
   float alpha = param.alpha;
@@ -44,137 +45,103 @@ void MatMulCompute::Run() {
   if (x_dims.size() > 2 && y_dims.size() >= 2) {
     // x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
     // x: [B, M, K], y: [K, N], out: [B, M, N]
-    if (x_transpose || y_transpose) {
-      LOG(FATAL) << "not supported transpose for x or y.";
+
+    if (!x_transpose && !y_transpose) {
+      CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 2])
+          << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
+          << ") x_transpose is " << x_transpose << "y_transpose is "
+          << y_transpose;
+    } else if (!x_transpose && y_transpose) {
+      CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 1])
+          << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
+          << ") x_transpose is " << x_transpose << "y_transpose is "
+          << y_transpose;
+    } else if (x_transpose && !y_transpose) {
+      CHECK_EQ(x_dims[x_dims.size() - 2], y_dims[y_dims.size() - 2])
+          << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
+          << ") x_transpose is " << x_transpose << "y_transpose is "
+          << y_transpose;
+    } else {
+      CHECK_EQ(x_dims[x_dims.size() - 2], y_dims[y_dims.size() - 1])
+          << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
+          << ") x_transpose is " << x_transpose << "y_transpose is "
+          << y_transpose;
     }
-    CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 2])
-        << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
-        << ")";
-    if (y_dims.size() > 2) {
+    int lda, ldb, ldc;
+    if (!x_transpose) {
       m_ = x_dims[x_dims.size() - 2];
-      k_ = y_dims[y_dims.size() - 2];
+      k_ = x_dims[x_dims.size() - 1];
+      lda = k_;
+    } else {
+      m_ = x_dims[x_dims.size() - 1];
+      k_ = x_dims[x_dims.size() - 2];
+      lda = m_;
+    }
+
+    if (!y_transpose) {
       n_ = y_dims[y_dims.size() - 1];
-      int hblock = lite::arm::math::get_hblock(ctx.arch());
-      int m_round = 0;
-      m_round = hblock * ((m_ + hblock - 1) / hblock);
-      ctx.ExtendWorkspace(m_round * k_ * sizeof(float));
-      int x_inner = x_dims[x_dims.size() - 2] * x_dims[x_dims.size() - 1];
-      int y_inner = y_dims[y_dims.size() - 2] * y_dims[y_dims.size() - 1];
-      int out_inner = x_dims[x_dims.size() - 2] * y_dims[y_dims.size() - 1];
-      if (n_ == 1) {
-        for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
-          lite::arm::math::sgemv(x_data + i * x_inner,
-                                 y_data + i * y_inner,
-                                 o_data + i * out_inner,
-                                 false,
-                                 m_,
-                                 k_,
-                                 false,
-                                 nullptr,
-                                 false);
-        }
-        if (fabsf(alpha - 1.f) > 1e-8f) {
-          for (size_t i = 0; i < param.Out->dims().production(); ++i) {
-            o_data[i] *= alpha;
-          }
-        }
-      } else {
-        for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
-          float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
-                            ctx.llc_size() / sizeof(float);
-          lite::arm::math::prepackA(packed_x,
-                                    x_data + i * x_inner,
-                                    alpha,
-                                    k_,
-                                    0,
-                                    m_,
-                                    0,
-                                    k_,
-                                    false,
-                                    &ctx);
-          int ldb = n_;
-          if (y_transpose) {
-            ldb = k_;
-          }
-          lite::arm::math::sgemm_prepack(y_transpose,
-                                         m_,
-                                         n_,
-                                         k_,
-                                         packed_x,
-                                         y_data + i * y_inner,
-                                         ldb,
-                                         0.f,
-                                         o_data + i * out_inner,
-                                         n_,
-                                         nullptr,
-                                         false,
-                                         false,
-                                         &ctx);
-        }
+      ldb = n_;
+    } else {
+      n_ = y_dims[y_dims.size() - 2];
+      ldb = k_;
+    }
+
+    ldc = n_;
+
+    int x_inner = x_dims[x_dims.size() - 2] * x_dims[x_dims.size() - 1];
+    int y_inner = y_dims[y_dims.size() - 2] * y_dims[y_dims.size() - 1];
+    int out_inner = o_dims[o_dims.size() - 2] * o_dims[o_dims.size() - 1];
+
+    float* x_data_trans = nullptr;
+    if (x_transpose) {
+      x_data_trans = static_cast<float*>(malloc(sizeof(float) * x_inner));
+    }
+
+    if (y_dims.size() > 2) {
+      for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
+        lite::arm::math::sgemm(x_transpose,
+                               y_transpose,
+                               m_,
+                               n_,
+                               k_,
+                               alpha,
+                               x_data + i * x_inner,
+                               lda,
+                               y_data + i * y_inner,
+                               ldb,
+                               0.f,
+                               o_data + i * out_inner,
+                               ldc,
+                               nullptr,
+                               false,
+                               false,
+                               &ctx);
       }
     } else {
-      m_ = x_dims[x_dims.size() - 2];
-      k_ = y_dims[0];
-      n_ = y_dims[1];
-      int hblock = lite::arm::math::get_hblock(ctx.arch());
-      int m_round = 0;
-      m_round = hblock * ((m_ + hblock - 1) / hblock);
-      ctx.ExtendWorkspace(m_round * k_ * sizeof(float));
-      int x_inner = x_dims[x_dims.size() - 2] * x_dims[x_dims.size() - 1];
-      int out_inner = x_dims[x_dims.size() - 2] * y_dims[1];
-      if (n_ == 1) {
-        for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
-          lite::arm::math::sgemv(x_data + i * x_inner,
-                                 y_data,
-                                 o_data + i * out_inner,
-                                 false,
-                                 m_,
-                                 k_,
-                                 false,
-                                 nullptr,
-                                 false);
-        }
-        if (fabsf(param.alpha - 1.f) > 1e-8f) {
-          for (size_t i = 0; i < param.Out->dims().production(); ++i) {
-            o_data[i] *= param.alpha;
-          }
-        }
-      } else {
-        for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
-          float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
-                            ctx.llc_size() / sizeof(float);
-          lite::arm::math::prepackA(packed_x,
-                                    x_data + i * x_inner,
-                                    alpha,
-                                    k_,
-                                    0,
-                                    m_,
-                                    0,
-                                    k_,
-                                    false,
-                                    &ctx);
-          int ldb = n_;
-          if (y_transpose) {
-            ldb = k_;
-          }
-          lite::arm::math::sgemm_prepack(y_transpose,
-                                         m_,
-                                         n_,
-                                         k_,
-                                         packed_x,
-                                         y_data,
-                                         ldb,
-                                         0.f,
-                                         o_data + i * out_inner,
-                                         n_,
-                                         nullptr,
-                                         false,
-                                         false,
-                                         &ctx);
-        }
+      for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
+        lite::arm::math::sgemm(x_transpose,
+                               y_transpose,
+                               m_,
+                               n_,
+                               k_,
+                               alpha,
+                               x_data + i * x_inner,
+                               lda,
+                               y_data,
+                               ldb,
+                               0.f,
+                               o_data + i * out_inner,
+                               ldc,
+                               nullptr,
+                               false,
+                               false,
+                               &ctx);
       }
     }
+    if (x_data_trans) {
+      free(x_data_trans);
+    }
   } else if (x_dims.size() == 2 && y_dims.size() == 2) {
     // x: [M, K], y: [K, N], out: [M, N]
     if (!x_transpose && !y_transpose) {
@@ -198,50 +165,43 @@ void MatMulCompute::Run() {
           << "), x_transpose is " << x_transpose << ", y_transpose is "
           << y_transpose;
     }
-    // not supported transpose
-    if (x_transpose || y_transpose) {
-      LOG(FATAL) << "not supported transpose for x and y.";
-    }
-    m_ = x_dims[0];
-    k_ = x_dims[1];
-    n_ = y_dims[1];
-    int hblock = lite::arm::math::get_hblock(ctx.arch());
-    int m_round = 0;
-    m_round = hblock * ((m_ + hblock - 1) / hblock);
-    ctx.ExtendWorkspace(m_round * k_ * sizeof(float));
-    if (n_ == 1) {
-      lite::arm::math::sgemv(
-          x_data, y_data, o_data, x_transpose, m_, k_, false, nullptr, false);
-      if (fabsf(param.alpha - 1.f) > 1e-8f) {
-        for (size_t i = 0; i < param.Out->dims().production(); ++i) {
-          o_data[i] *= param.alpha;
-        }
-      }
+    int lda, ldb, ldc;
+    if (!x_transpose) {
+      m_ = x_dims[0];
+      k_ = x_dims[1];
+      lda = k_;
     } else {
-      float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
-                        ctx.llc_size() / sizeof(float);
-      lite::arm::math::prepackA(
-          packed_x, x_data, alpha, k_, 0, m_, 0, k_, x_transpose, &ctx);
-      int ldb = n_;
-      if (y_transpose) {
-        ldb = k_;
-      }
-      lite::arm::math::sgemm_prepack(y_transpose,
-                                     m_,
-                                     n_,
-                                     k_,
-                                     packed_x,
-                                     y_data,
-                                     ldb,
-                                     0.f,
-                                     o_data,
-                                     n_,
-                                     nullptr,
-                                     false,
-                                     false,
-                                     &ctx);
+      m_ = x_dims[1];
+      k_ = x_dims[0];
+      lda = m_;
     }
+    if (!y_transpose) {
+      n_ = y_dims[1];
+      ldb = n_;
+    } else {
+      n_ = y_dims[0];
+      ldb = k_;
+    }
+    ldc = n_;
+
+    lite::arm::math::sgemm(x_transpose,
+                           y_transpose,
+                           m_,
+                           n_,
+                           k_,
+                           alpha,
+                           x_data,
+                           lda,
+                           y_data,
+                           ldb,
+                           0.f,
+                           o_data,
+                           ldc,
+                           nullptr,
+                           false,
+                           false,
+                           &ctx);
   } else if (x_dims.size() > 2 && y_dims.size() == 1) {
     // x: [B, M, K], y: [K], out: [B, M]
     CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[0])
@@ -267,6 +227,9 @@ void MatMulCompute::Run() {
       m_ = x_dims[0];
       k_ = 1;
       n_ = y_dims[0];
+      int lda = k_;
+      int ldb = n_;
+      int ldc = n_;
       if (n_ == 1) {
         lite::arm::math::sgemv(
             x_data, y_data, o_data, false, m_, k_, false, nullptr, false);
@@ -276,25 +239,23 @@ void MatMulCompute::Run() {
           }
         }
       } else {
-        float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
-                          ctx.llc_size() / sizeof(float);
-        lite::arm::math::prepackA(
-            packed_x, x_data, alpha, k_, 0, m_, 0, k_, false, &ctx);
-        int ldb = n_;
-        lite::arm::math::sgemm_prepack(false,
-                                       m_,
-                                       n_,
-                                       k_,
-                                       packed_x,
-                                       y_data,
-                                       ldb,
-                                       0.f,
-                                       o_data,
-                                       n_,
-                                       nullptr,
-                                       false,
-                                       false,
-                                       &ctx);
+        lite::arm::math::sgemm(false,
+                               false,
+                               m_,
+                               n_,
+                               k_,
+                               alpha,
+                               x_data,
+                               lda,
+                               y_data,
+                               ldb,
+                               0.f,
+                               o_data,
+                               ldc,
+                               nullptr,
+                               false,
+                               false,
+                               &ctx);
       }
     }
   } else {
diff --git a/lite/operators/matmul_op.cc b/lite/operators/matmul_op.cc
index 90cfd1ddadb2e3167fa0463a2cd5dc10d7d62882..286ade7b2130ce662eea2b7ba4e142bf489306ca 100644
--- a/lite/operators/matmul_op.cc
+++ b/lite/operators/matmul_op.cc
@@ -37,14 +37,41 @@ bool MatMulOpLite::InferShape() const {
   if (x_dims.size() > 2 && y_dims.size() >= 2) {
     // x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
     // x: [B, M, K], y: [K, N], out: [B, M, N]
-    CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 2])
-        << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
-        << ")";
+    if (!x_transpose && !y_transpose) {
+      CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 2])
+          << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
+          << ")";
+    } else if (!x_transpose && y_transpose) {
+      CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 1])
+          << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
+          << ")";
+    } else if (x_transpose && !y_transpose) {
+      CHECK_EQ(x_dims[x_dims.size() - 2], y_dims[y_dims.size() - 2])
+          << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
+          << ")";
+    } else {
+      CHECK_EQ(x_dims[x_dims.size() - 2], y_dims[y_dims.size() - 1])
+          << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
+          << ")";
+    }
+
     dim_out_vec.resize(x_dims.size());
-    for (size_t i = 0; i < x_dims.size() - 1; ++i) {
+    for (size_t i = 0; i < x_dims.size() - 2; ++i) {
       dim_out_vec[i] = x_dims[i];
     }
-    dim_out_vec[x_dims.size() - 1] = y_dims[y_dims.size() - 1];
+    if (!x_transpose && !y_transpose) {
+      dim_out_vec[x_dims.size() - 2] = x_dims[x_dims.size() - 2];
+      dim_out_vec[x_dims.size() - 1] = y_dims[y_dims.size() - 1];
+    } else if (!x_transpose && y_transpose) {
+      dim_out_vec[x_dims.size() - 2] = x_dims[x_dims.size() - 2];
+      dim_out_vec[x_dims.size() - 1] = y_dims[y_dims.size() - 2];
+    } else if (x_transpose && !y_transpose) {
+      dim_out_vec[x_dims.size() - 2] = x_dims[x_dims.size() - 1];
+      dim_out_vec[x_dims.size() - 1] = y_dims[y_dims.size() - 1];
+    } else {
+      dim_out_vec[x_dims.size() - 2] = x_dims[x_dims.size() - 1];
+      dim_out_vec[x_dims.size() - 1] = y_dims[y_dims.size() - 2];
+    }
   } else if (x_dims.size() == 2 && y_dims.size() == 2) {
     // x: [M, K], y: [K, N], out: [M, N]
     // x: [M, K], y: [K, N], out: [M, N]
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 2d1fd8bfe6d9e1775dec8da506efa5acb82eafbd..deac6410b31da20b0456f419f6d53411f25d12c2 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -695,7 +695,7 @@ struct SliceParam {
   std::vector<int> decrease_axis{};
 };
 
-/// ----------------------- shape operators ----------------------
+/// ----------------------- squeeze operators ----------------------
 struct SqueezeParam {
   const lite::Tensor* X{};
   lite::Tensor* Out{};
@@ -719,7 +719,6 @@ struct MatMulParam {
   bool transpose_Y{false};
   float alpha{1.0f};
 };
-
 }  // namespace operators
 }  // namespace lite
 }  // namespace paddle
diff --git a/lite/tests/kernels/matmul_compute_test.cc b/lite/tests/kernels/matmul_compute_test.cc
index 648180f832671ae9ff7e5b4da1e2f15e90c90fb9..8b70f59d4756c47ceee039ab7797a66e8f695c2e 100644
--- a/lite/tests/kernels/matmul_compute_test.cc
+++ b/lite/tests/kernels/matmul_compute_test.cc
@@ -152,29 +152,36 @@ class MatMulComputeTester : public arena::TestCase {
     auto* out = scope->NewTensor(out_);
     CHECK(out);
 
-    // todo alpha
     std::vector<int64_t> dim_out_vec;
     if (x_dims_.size() > 2 && y_dims_.size() >= 2) {
       // x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
      // x: [B, M, K], y: [K, N], out: [B, M, N]
-      if (x_transpose_ || y_transpose_) {
-        LOG(FATAL) << "not supported transpose for x and y.";
-      }
-      CHECK_EQ(x_dims_[x_dims_.size() - 1], y_dims_[y_dims_.size() - 2])
-          << "not supported x_dims(" << x_dims_ << ") and y_dims(" << y_dims_
-          << ")";
       dim_out_vec.resize(x_dims_.size());
-      for (size_t i = 0; i < x_dims_.size() - 1; ++i) {
+      for (size_t i = 0; i < x_dims_.size() - 2; ++i) {
         dim_out_vec[i] = x_dims_[i];
       }
-      dim_out_vec[x_dims_.size() - 1] = y_dims_[y_dims_.size() - 1];
+      if (!x_transpose_ && !y_transpose_) {
+        dim_out_vec[x_dims_.size() - 2] = x_dims_[x_dims_.size() - 2];
+        dim_out_vec[x_dims_.size() - 1] = y_dims_[y_dims_.size() - 1];
+      } else if (!x_transpose_ && y_transpose_) {
+        dim_out_vec[x_dims_.size() - 2] = x_dims_[x_dims_.size() - 2];
+        dim_out_vec[x_dims_.size() - 1] = y_dims_[y_dims_.size() - 2];
+      } else if (x_transpose_ && !y_transpose_) {
+        dim_out_vec[x_dims_.size() - 2] = x_dims_[x_dims_.size() - 1];
+        dim_out_vec[x_dims_.size() - 1] = y_dims_[y_dims_.size() - 1];
+      } else {
+        dim_out_vec[x_dims_.size() - 2] = x_dims_[x_dims_.size() - 1];
+        dim_out_vec[x_dims_.size() - 1] = y_dims_[y_dims_.size() - 2];
+      }
+
       out->Resize(dim_out_vec);
       auto* out_data = out->mutable_data<float>();
       int x_inner = x_dims_[x_dims_.size() - 2] * x_dims_[x_dims_.size() - 1];
 
       if (y_dims_.size() > 2) {
         int y_inner = y_dims_[y_dims_.size() - 2] * y_dims_[y_dims_.size() - 1];
-        int o_inner = x_dims_[x_dims_.size() - 2] * y_dims_[y_dims_.size() - 1];
+        int o_inner =
+            dim_out_vec[x_dims_.size() - 2] * dim_out_vec[x_dims_.size() - 1];
         for (size_t i = 0; i < x_dims_.count(0, x_dims_.size() - 2); ++i) {
           mul_low_efficiency(
               DDim({x_dims_[x_dims_.size() - 2], x_dims_[x_dims_.size() - 1]}),
@@ -187,7 +194,8 @@ class MatMulComputeTester : public arena::TestCase {
               out_data + i * o_inner);
         }
       } else {
-        int o_inner = x_dims_[x_dims_.size() - 2] * y_dims_[1];
+        int o_inner =
+            dim_out_vec[x_dims_.size() - 2] * dim_out_vec[x_dims_.size() - 1];
         for (size_t i = 0; i < x_dims_.count(0, x_dims_.size() - 2); ++i) {
           mul_low_efficiency(
               DDim({x_dims_[x_dims_.size() - 2], x_dims_[x_dims_.size() - 1]}),
@@ -240,7 +248,7 @@ class MatMulComputeTester : public arena::TestCase {
           out_data[i] += x_data[i * y_dims_[0] + j] * y_data[j] * alpha_;
         }
       }
-    } else if (x_dims_.size() == 1 && y_dims_.size() == 1) {  // todo
+    } else if (x_dims_.size() == 1 && y_dims_.size() == 1) {
       // x: [K], y: [K], out: [1]
       if (x_dims_[0] == y_dims_[0] && x_transpose_ == false &&
           y_transpose_ == false) {
@@ -325,14 +333,40 @@ void test_matmul2x2_no_transform(Place place) {
   }
 }
 
-void test_matmul2x2_transform(Place place) {
-  DDim x_dim({3, 2});
-  DDim y_dim({3, 2});
-  float alpha = 1.f;
-  std::unique_ptr<arena::TestCase> tester(
-      new MatMulComputeTester(place, "def", false, true, alpha, x_dim, y_dim));
-  arena::Arena arena(std::move(tester), place, 2e-5);
-  arena.TestPrecision();
+void test_matmul2x2_x_transpose(Place place) {
+  std::vector<DDim> x_dims({DDim({3, 4}), DDim({2, 5})});
+  std::vector<DDim> y_dims({DDim({3, 2}), DDim({2, 1})});
+  std::vector<float> alphas({1.f, 2.f});
+  for (int i = 0; i < x_dims.size(); ++i) {
+    std::unique_ptr<arena::TestCase> tester(new MatMulComputeTester(
+        place, "def", true, false, alphas[i], x_dims[i], y_dims[i]));
+    arena::Arena arena(std::move(tester), place, 2e-5);
+    arena.TestPrecision();
+  }
+}
+
+void test_matmul2x2_y_transpose(Place place) {
+  std::vector<DDim> x_dims({DDim({5, 2}), DDim({2, 5})});
+  std::vector<DDim> y_dims({DDim({3, 2}), DDim({1, 5})});
+  std::vector<float> alphas({1.f, 2.f});
+  for (int i = 0; i < x_dims.size(); ++i) {
+    std::unique_ptr<arena::TestCase> tester(new MatMulComputeTester(
+        place, "def", false, true, alphas[i], x_dims[i], y_dims[i]));
+    arena::Arena arena(std::move(tester), place, 2e-5);
+    arena.TestPrecision();
+  }
+}
+
+void test_matmul2x2_transpose(Place place) {
+  std::vector<DDim> x_dims({DDim({6, 2}), DDim({5, 3})});
+  std::vector<DDim> y_dims({DDim({3, 6}), DDim({1, 5})});
+  std::vector<float> alphas({1.f, 2.f});
+  for (int i = 0; i < x_dims.size(); ++i) {
+    std::unique_ptr<arena::TestCase> tester(new MatMulComputeTester(
+        place, "def", true, true, alphas[i], x_dims[i], y_dims[i]));
+    arena::Arena arena(std::move(tester), place, 5e-5);
+    arena.TestPrecision();
+  }
 }
 
 void test_matmul1x1_no_transpose(Place place) {
@@ -366,9 +400,9 @@ void test_matmul_nx1(Place place) {
 }
 
 void test_matmul_nx2_1(Place place) {
-  DDim x_dim({3, 4, 2, 5});
-  DDim y_dim({5, 1});
-  float alpha = 1.5f;
+  DDim x_dim({1, 2, 2, 3});
+  DDim y_dim({3, 1});
+  float alpha = 1.f;
   std::unique_ptr<arena::TestCase> tester(
       new MatMulComputeTester(place, "def", false, false, alpha, x_dim, y_dim));
   arena::Arena arena(std::move(tester), place, 2e-5);
@@ -376,8 +410,8 @@ void test_matmul_nx2_2(Place place) {
-  DDim x_dim({3, 4, 2, 5});
-  DDim y_dim({5, 3});
+  DDim x_dim({1, 2, 2, 3});
+  DDim y_dim({3, 3});
   float alpha = 1.5f;
   std::unique_ptr<arena::TestCase> tester(
       new MatMulComputeTester(place, "def", false, false, alpha, x_dim, y_dim));
@@ -385,27 +419,127 @@ void test_matmul_nx2_2(Place place) {
   arena.TestPrecision();
 }
 
+void test_matmulnx2_x_transpose(Place place) {
+  std::vector<DDim> x_dims({DDim({3, 4, 6, 2}), DDim({5, 3, 5, 2})});
+  std::vector<DDim> y_dims({DDim({6, 2}), DDim({5, 1})});
+  std::vector<float> alphas({1.f, 2.f});
+  for (int i = 0; i < x_dims.size(); ++i) {
+    std::unique_ptr<arena::TestCase> tester(new MatMulComputeTester(
+        place, "def", true, false, alphas[i], x_dims[i], y_dims[i]));
+    arena::Arena arena(std::move(tester), place, 2e-4);
+    arena.TestPrecision();
+  }
+}
+
+void test_matmulnx2_y_transpose(Place place) {
+  std::vector<DDim> x_dims({DDim({3, 4, 6, 2}), DDim({5, 3, 5, 2})});
+  std::vector<DDim> y_dims({DDim({6, 2}), DDim({1, 2})});
+  std::vector<float> alphas({1.f, 2.f});
+  for (int i = 0; i < x_dims.size(); ++i) {
+    std::unique_ptr<arena::TestCase> tester(new MatMulComputeTester(
+        place, "def", false, true, alphas[i], x_dims[i], y_dims[i]));
+    arena::Arena arena(std::move(tester), place, 5e-5);
+    arena.TestPrecision();
+  }
+}
+
+void test_matmulnx2_transpose(Place place) {
+  std::vector<DDim> x_dims({DDim({3, 4, 4, 3}), DDim({5, 3, 3, 2})});
+  std::vector<DDim> y_dims({DDim({2, 4}), DDim({1, 3})});
+  std::vector<float> alphas({1.f, 2.f});
+  for (int i = 0; i < x_dims.size(); ++i) {
+    std::unique_ptr<arena::TestCase> tester(new MatMulComputeTester(
+        place, "def", true, true, alphas[i], x_dims[i], y_dims[i]));
+    arena::Arena arena(std::move(tester), place, 5e-5);
+    arena.TestPrecision();
+  }
+}
+
 void test_matmul_nxn(Place place) {
   DDim x_dim({3, 4, 2, 5});
   DDim y_dim({3, 4, 5, 2});
   float alpha = 1.5f;
   std::unique_ptr<arena::TestCase> tester(
       new MatMulComputeTester(place, "def", false, false, alpha, x_dim, y_dim));
-  arena::Arena arena(std::move(tester), place, 2e-5);
+  arena::Arena arena(std::move(tester), place, 1e-3);
   arena.TestPrecision();
 }
 
+void test_matmulnxn_x_transpose(Place place) {
+  std::vector<DDim> x_dims({DDim({3, 4, 6, 2}), DDim({5, 3, 5, 2})});
+  std::vector<DDim> y_dims({DDim({3, 4, 6, 2}), DDim({5, 3, 5, 1})});
+  std::vector<float> alphas({1.f, 2.f});
+  for (int i = 0; i < x_dims.size(); ++i) {
+    std::unique_ptr<arena::TestCase> tester(new MatMulComputeTester(
+        place, "def", true, false, alphas[i], x_dims[i], y_dims[i]));
+    arena::Arena arena(std::move(tester), place, 1e-3);
+    arena.TestPrecision();
+  }
+}
+
+void test_matmulnxn_y_transpose(Place place) {
+  std::vector<DDim> x_dims({DDim({3, 4, 6, 2}), DDim({5, 3, 5, 2})});
+  std::vector<DDim> y_dims({DDim({3, 4, 6, 2}), DDim({5, 3, 1, 2})});
+  std::vector<float> alphas({1.f, 2.f});
+  for (int i = 0; i < x_dims.size(); ++i) {
+    std::unique_ptr<arena::TestCase> tester(new MatMulComputeTester(
+        place, "def", false, true, alphas[i], x_dims[i], y_dims[i]));
+    arena::Arena arena(std::move(tester), place, 1e-3);
+    arena.TestPrecision();
+  }
+}
+
+void test_matmulnxn_transpose(Place place) {
+  std::vector<DDim> x_dims({DDim({3, 4, 4, 3}), DDim({5, 3, 3, 2})});
+  std::vector<DDim> y_dims({DDim({3, 4, 2, 4}), DDim({5, 3, 1, 3})});
+  std::vector<float> alphas({1.f, 2.f});
+  for (int i = 0; i < x_dims.size(); ++i) {
+    std::unique_ptr<arena::TestCase> tester(new MatMulComputeTester(
+        place, "def", true, true, alphas[i], x_dims[i], y_dims[i]));
+    arena::Arena arena(std::move(tester), place, 1e-3);
+    arena.TestPrecision();
+  }
+}
+
 TEST(Matmul2x2, precision) {
 #ifdef LITE_WITH_X86
   Place place(TARGET(kX86));
 #endif
 #ifdef LITE_WITH_ARM
   Place place(TARGET(kARM));
-  // test_matmul2x2_transform(place);
   test_matmul2x2_no_transform(place);
 #endif
 }
 
+TEST(Matmul2x2_x_transpose, precision) {
+#ifdef LITE_WITH_X86
+  Place place(TARGET(kX86));
+#endif
+#ifdef LITE_WITH_ARM
+  Place place(TARGET(kARM));
+  test_matmul2x2_x_transpose(place);
+#endif
+}
+TEST(Matmul2x2_y_transpose, precision) {
+#ifdef LITE_WITH_X86
+  Place place(TARGET(kX86));
+#endif
+#ifdef LITE_WITH_ARM
+  Place place(TARGET(kARM));
+  test_matmul2x2_y_transpose(place);
+#endif
+}
+
+TEST(Matmul2x2_transpose, precision) {
+#ifdef LITE_WITH_X86
+  Place place(TARGET(kX86));
+#endif
+#ifdef LITE_WITH_ARM
+  Place place(TARGET(kARM));
+  test_matmul2x2_transpose(place);
+#endif
+}
+
 TEST(Matmul1x1, precision) {
 #ifdef LITE_WITH_X86
   Place place(TARGET(kX86));
@@ -435,6 +569,9 @@ TEST(Matmulnx2, precision) {
   Place place(TARGET(kARM));
   test_matmul_nx2_1(place);
   test_matmul_nx2_2(place);
+  test_matmulnx2_x_transpose(place);
+  test_matmulnx2_y_transpose(place);
+  test_matmulnx2_transpose(place);
 #endif
 }
 
@@ -445,6 +582,9 @@ TEST(Matmulnxn, precision) {
 #ifdef LITE_WITH_ARM
   Place place(TARGET(kARM));
   test_matmul_nxn(place);
+  test_matmulnxn_x_transpose(place);
+  test_matmulnxn_y_transpose(place);
+  test_matmulnxn_transpose(place);
#endif
 }
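
Reviewer note: the core of the kernel change is how M, N, K and the GEMM leading dimensions are derived from transpose_X/transpose_Y before calling lite::arm::math::sgemm. The snippet below is a minimal standalone sketch of that mapping for the 2-D case; it is not part of the patch, and the struct and function names (GemmShape, InferGemmShape) are illustrative only.

#include <cstdio>

// Illustrative only: mirrors the lda/ldb/ldc selection added in
// MatMulCompute::Run() for row-major x[M,K] (or [K,M] when transposed)
// times y[K,N] (or [N,K] when transposed).
struct GemmShape {
  int m, n, k, lda, ldb, ldc;
};

GemmShape InferGemmShape(int x_rows, int x_cols, int y_rows, int y_cols,
                         bool x_transpose, bool y_transpose) {
  GemmShape s{};
  if (!x_transpose) {
    s.m = x_rows;  // x: [M, K]
    s.k = x_cols;
    s.lda = s.k;   // row stride of the stored x
  } else {
    s.m = x_cols;  // x stored as [K, M], used as x^T
    s.k = x_rows;
    s.lda = s.m;
  }
  s.n = y_transpose ? y_rows : y_cols;  // y: [K, N] or stored as [N, K]
  s.ldb = y_transpose ? s.k : s.n;      // row stride of the stored y
  s.ldc = s.n;                          // out is always [M, N]
  return s;
}

int main() {
  // Matches test_matmul2x2_x_transpose: x {3, 4} transposed, y {3, 2} -> out {4, 2}.
  GemmShape s = InferGemmShape(3, 4, 3, 2, /*x_transpose=*/true, /*y_transpose=*/false);
  std::printf("m=%d n=%d k=%d lda=%d ldb=%d ldc=%d\n", s.m, s.n, s.k, s.lda, s.ldb, s.ldc);
  return 0;
}

The same mapping is applied per batch slice in the [B, M, K] paths of the patch, with x_inner/y_inner/out_inner used as per-slice offsets into the batched tensors.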