提交 9939c334 编写于 作者: 吴承辉

Merge branch 'gemm' into 'master'

Optimize gemm x84 (v8/v7)  gemv v7

See merge request !544
此差异已折叠。
...@@ -34,6 +34,7 @@ void Gemm(const float *A, ...@@ -34,6 +34,7 @@ void Gemm(const float *A,
void GemmRef(const float *A, void GemmRef(const float *A,
const float *B, const float *B,
const index_t batch,
const index_t height, const index_t height,
const index_t K, const index_t K,
const index_t width, const index_t width,
......
...@@ -21,62 +21,98 @@ ...@@ -21,62 +21,98 @@
namespace mace { namespace mace {
TEST(GEMMTest, gemm) { namespace {
index_t N = 17;
index_t M = 33; void GemmTest(index_t batch, index_t N, index_t K, index_t M) {
index_t K = 64; std::unique_ptr<float[]> A(new float[batch * N * K]);
std::unique_ptr<float[]> A(new float[N * K]); std::unique_ptr<float[]> B(new float[batch * K * M]);
std::unique_ptr<float[]> B(new float[K * M]); std::unique_ptr<float[]> C(new float[batch * N * M]);
std::unique_ptr<float[]> C(new float[N * M]); std::unique_ptr<float[]> C_ref(new float[batch * N * M]);
std::unique_ptr<float[]> C_ref(new float[N * M]);
std::random_device rd; std::random_device rd;
std::mt19937 gen(rd()); std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1); std::normal_distribution<float> nd(0, 1);
std::generate(A.get(), A.get() + N * K, std::generate(A.get(), A.get() + batch * N * K,
[&gen, &nd] { [&gen, &nd] {
return nd(gen); return nd(gen);
}); });
std::generate(B.get(), B.get() + K * M, std::generate(B.get(), B.get() + batch * K * M,
[&gen, &nd] { [&gen, &nd] {
return nd(gen); return nd(gen);
}); });
kernels::Gemm(A.get(), B.get(), 1, N, K, M, C.get()); kernels::Gemm(A.get(), B.get(), batch, N, K, M, C.get());
kernels::GemmRef(A.get(), B.get(), N, K, M, C_ref.get()); kernels::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get());
for (int i = 0; i < N * M; ++i) { for (int i = 0; i < batch * N * M; ++i) {
EXPECT_NEAR(C_ref[i], C[i], 0.1); EXPECT_NEAR(C_ref[i], C[i], 0.1);
} }
} }
TEST(GEMMTest, gemv) { void GemvTest(index_t batch, index_t N, index_t M) {
index_t N = 17; std::unique_ptr<float[]> A(new float[N * M]);
index_t K = 63; std::unique_ptr<float[]> B(new float[batch * M]);
std::unique_ptr<float[]> A(new float[N * K]); std::unique_ptr<float[]> C(new float[batch * N]);
std::unique_ptr<float[]> B(new float[K]); std::unique_ptr<float[]> C_ref(new float[batch * N]);
std::unique_ptr<float[]> C(new float[N]);
std::unique_ptr<float[]> C_ref(new float[N]);
std::random_device rd; std::random_device rd;
std::mt19937 gen(rd()); std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1); std::normal_distribution<float> nd(0, 1);
std::generate(A.get(), A.get() + N * K, std::generate(A.get(), A.get() + N * M,
[&gen, &nd] { [&gen, &nd] {
return nd(gen); return nd(gen);
}); });
std::generate(B.get(), B.get() + K, std::generate(B.get(), B.get() + batch * M,
[&gen, &nd] { [&gen, &nd] {
return nd(gen); return nd(gen);
}); });
kernels::Gemv(A.get(), B.get(), 1, K, N, C.get()); kernels::Gemv(A.get(), B.get(), batch, M, N, C.get());
kernels::GemvRef(A.get(), B.get(), 1, K, N, C_ref.get()); kernels::GemvRef(A.get(), B.get(), batch, M, N, C_ref.get());
for (int i = 0; i < N; ++i) { for (int i = 0; i < batch * N; ++i) {
EXPECT_NEAR(C_ref[i], C[i], 0.1); EXPECT_NEAR(C_ref[i], C[i], 0.1);
} }
} }
} // namespace
// Exercises GemmTest with batch = 1, K = 64 and M = 128 (the dimensions this
// test's name labels as "aligned") across several values of N.
TEST(GEMMTest, AlignedWithoutBatch) {
  // N values to sweep, in the same order as the individual calls they replace.
  const int kRowCounts[] = {1, 2, 3, 4, 5, 6, 7, 17};
  for (const int rows : kRowCounts) {
    GemmTest(1, rows, 64, 128);
  }
}
// Exercises GemmTest with batch = 1, K = 63 and M = 127 (the dimensions this
// test's name labels as "unaligned") across several values of N.
TEST(GEMMTest, UnalignedWithoutBatch) {
  // N values to sweep, in the same order as the individual calls they replace.
  const int kRowCounts[] = {1, 2, 3, 4, 5, 6, 7, 17};
  for (const int rows : kRowCounts) {
    GemmTest(1, rows, 63, 127);
  }
}
// Exercises GemmTest with batch = 3, K = 63 and M = 127 across several
// values of N, covering the batched path with the "unaligned" dimensions.
TEST(GEMMTest, UnalignedWithBatch) {
  // N values to sweep, in the same order as the individual calls they replace.
  const int kRowCounts[] = {1, 2, 3, 4, 5, 6, 7, 17};
  for (const int rows : kRowCounts) {
    GemmTest(3, rows, 63, 127);
  }
}
// Exercises GemvTest with N = 17 and M = 63, once with a single batch and
// once with three batches.
TEST(GEMMTest, gemv) {
  const int kBatchSizes[] = {1, 3};
  for (const int batch : kBatchSizes) {
    GemvTest(batch, 17, 63);
  }
}
} // namespace mace } // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册