提交 24fade6d 编写于 作者: 李寅

Optimize gemm x84

上级 4aa83602
此差异已折叠。
......@@ -34,6 +34,7 @@ void Gemm(const float *A,
void GemmRef(const float *A,
const float *B,
const index_t batch,
const index_t height,
const index_t K,
const index_t width,
......
......@@ -21,62 +21,98 @@
namespace mace {
TEST(GEMMTest, gemm) {
index_t N = 17;
index_t M = 33;
index_t K = 64;
std::unique_ptr<float[]> A(new float[N * K]);
std::unique_ptr<float[]> B(new float[K * M]);
std::unique_ptr<float[]> C(new float[N * M]);
std::unique_ptr<float[]> C_ref(new float[N * M]);
namespace {
// Test helper: runs kernels::Gemm and kernels::GemmRef on the same random
// inputs and expects every output element to agree within 0.1.
//
// Buffers allocated here: A holds batch * N * K floats, B holds
// batch * K * M floats, and C / C_ref each hold batch * N * M floats.
// N and M are passed in the `height` and `width` argument positions of the
// kernels, with K in between.
//
// NOTE(review): this listing appears to interleave lines from the previous
// revision (the generate/Gemm/GemmRef/for lines without the `batch` factor)
// with the current ones; confirm against the actual source file.
void GemmTest(index_t batch, index_t N, index_t K, index_t M) {
std::unique_ptr<float[]> A(new float[batch * N * K]);
std::unique_ptr<float[]> B(new float[batch * K * M]);
std::unique_ptr<float[]> C(new float[batch * N * M]);
std::unique_ptr<float[]> C_ref(new float[batch * N * M]);
// The generator is seeded from std::random_device, so the input values
// differ from run to run.
std::random_device rd;
std::mt19937 gen(rd());
// Samples are drawn from a normal distribution with mean 0 and stddev 1.
std::normal_distribution<float> nd(0, 1);
// Fill A with random samples.
std::generate(A.get(), A.get() + N * K,
std::generate(A.get(), A.get() + batch * N * K,
[&gen, &nd] {
return nd(gen);
});
// Fill B with random samples.
std::generate(B.get(), B.get() + K * M,
std::generate(B.get(), B.get() + batch * K * M,
[&gen, &nd] {
return nd(gen);
});
// Compute the result with both implementations into separate buffers.
kernels::Gemm(A.get(), B.get(), 1, N, K, M, C.get());
kernels::GemmRef(A.get(), B.get(), N, K, M, C_ref.get());
kernels::Gemm(A.get(), B.get(), batch, N, K, M, C.get());
kernels::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get());
// Element-wise comparison of the two outputs with an absolute tolerance.
for (int i = 0; i < N * M; ++i) {
for (int i = 0; i < batch * N * M; ++i) {
EXPECT_NEAR(C_ref[i], C[i], 0.1);
}
}
TEST(GEMMTest, gemv) {
index_t N = 17;
index_t K = 63;
std::unique_ptr<float[]> A(new float[N * K]);
std::unique_ptr<float[]> B(new float[K]);
std::unique_ptr<float[]> C(new float[N]);
std::unique_ptr<float[]> C_ref(new float[N]);
// Test helper: runs kernels::Gemv and kernels::GemvRef on the same random
// inputs and expects every output element to agree within 0.1.
//
// Buffers allocated here: A holds N * M floats (no batch factor), B holds
// batch * M floats, and C / C_ref each hold batch * N floats. The kernels
// are called with the arguments in the order (A, B, batch, M, N, out).
//
// NOTE(review): this listing appears to interleave lines from the previous
// revision (the ones written in terms of `K` and a literal batch of 1) with
// the current ones; confirm against the actual source file.
void GemvTest(index_t batch, index_t N, index_t M) {
std::unique_ptr<float[]> A(new float[N * M]);
std::unique_ptr<float[]> B(new float[batch * M]);
std::unique_ptr<float[]> C(new float[batch * N]);
std::unique_ptr<float[]> C_ref(new float[batch * N]);
// The generator is seeded from std::random_device, so the input values
// differ from run to run.
std::random_device rd;
std::mt19937 gen(rd());
// Samples are drawn from a normal distribution with mean 0 and stddev 1.
std::normal_distribution<float> nd(0, 1);
// Fill A with random samples.
std::generate(A.get(), A.get() + N * K,
std::generate(A.get(), A.get() + N * M,
[&gen, &nd] {
return nd(gen);
});
// Fill B with random samples.
std::generate(B.get(), B.get() + K,
std::generate(B.get(), B.get() + batch * M,
[&gen, &nd] {
return nd(gen);
});
// Compute the result with both implementations into separate buffers.
kernels::Gemv(A.get(), B.get(), 1, K, N, C.get());
kernels::GemvRef(A.get(), B.get(), 1, K, N, C_ref.get());
kernels::Gemv(A.get(), B.get(), batch, M, N, C.get());
kernels::GemvRef(A.get(), B.get(), batch, M, N, C_ref.get());
// Element-wise comparison of the two outputs with an absolute tolerance.
for (int i = 0; i < N; ++i) {
for (int i = 0; i < batch * N; ++i) {
EXPECT_NEAR(C_ref[i], C[i], 0.1);
}
}
} // namespace
// Single-batch GEMM checks with K = 64 and M = 128, sweeping the N
// argument over 1..7 and 17.
TEST(GEMMTest, AlignedWithoutBatch) {
  const index_t kHeights[] = {1, 2, 3, 4, 5, 6, 7, 17};
  for (const index_t height : kHeights) {
    GemmTest(1, height, 64, 128);
  }
}
// Single-batch GEMM checks with K = 63 and M = 127, sweeping the N
// argument over 1..7 and 17.
TEST(GEMMTest, UnalignedWithoutBatch) {
  const index_t kHeights[] = {1, 2, 3, 4, 5, 6, 7, 17};
  for (const index_t height : kHeights) {
    GemmTest(1, height, 63, 127);
  }
}
// GEMM checks with batch = 3, K = 63 and M = 127, sweeping the N
// argument over 1..7 and 17.
TEST(GEMMTest, UnalignedWithBatch) {
  const index_t kHeights[] = {1, 2, 3, 4, 5, 6, 7, 17};
  for (const index_t height : kHeights) {
    GemmTest(3, height, 63, 127);
  }
}
// GEMV checks with N = 17 and M = 63, first with a batch of 1 and then
// with a batch of 3.
TEST(GEMMTest, gemv) {
  const index_t kBatches[] = {1, 3};
  for (const index_t batch : kBatches) {
    GemvTest(batch, 17, 63);
  }
}
} // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册