diff --git a/test/common/test_gemm.cpp b/test/common/test_gemm.cpp index 4ae4a1a7c5ba37c8745ff23ff2afccaaeb184176..27cd6d474a83f7f24b9e6fd5991a81322dd6d9ab 100644 --- a/test/common/test_gemm.cpp +++ b/test/common/test_gemm.cpp @@ -25,7 +25,6 @@ limitations under the License. */ #define c(i, j) c[(i)*ldc + (j)] #define c1(i, j) c1[(i)*ldc + (j)] - void print_matirx(int m, int n, int ldc, float *c) { for (int i = 0; i < m; ++i) { std::cout << c(i, 0); @@ -48,7 +47,7 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { float *c1 = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m * n)); float* scale = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m)); float* bias = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m)); - + srand(unsigned(time(0))); for (int i = 0; i < m * k; ++i) { a[i] = t1 + rand() % t2; @@ -62,7 +61,7 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { for (int i = 0; i < m; ++i) { bias[i] = t1 + rand() % t2; } - + for (int i = 0; i < m; ++i) { for (int j = 0; j < n; ++j) { float r = 0; @@ -77,7 +76,7 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { c1(i, j) = r; } } - + paddle_mobile::operators::math::SgemmWithBn(m, n, k, 0.9, a, lda, b, ldb, 0.3, c, ldc, relu, scale, bias); int eq = 0; @@ -89,22 +88,19 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { ++neq; } } - + if (pr > 0) { std::cout << "A:" << std::endl; print_matirx(m, k, lda, a); - std::cout << "B:" << std::endl; print_matirx(k, n, ldb, b); - std::cout << "C:" << std::endl; print_matirx(m, n, ldc, c); - std::cout << "C1:" << std::endl; print_matirx(m, n, ldc, c1); } - - std::cout << "mnk=" << m << " " << n << " " << k << + + std::cout << "mnk=" << m << " " << n << " " << k << " relu=" << relu << " eq=" << eq << " neq=" << neq << std::endl; @@ -114,19 +110,19 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { paddle_mobile::memory::Free(c1); paddle_mobile::memory::Free(scale); paddle_mobile::memory::Free(bias); - + return 0; } int main() { - + do_sgemm(9, 9, 9, true, 10, 10, 10); do_sgemm(10, 6, 12, false, 10, 10, 0); do_sgemm(512, 256, 384, false, 10, 10, 0); do_sgemm(1366, 768, 256, false, 10, 10, 0); do_sgemm(1255, 755, 333, false, 10, 10, 0); do_sgemm(555, 777, 999, false, 10, 10, 0); - + do_sgemm(10, 6, 12, true, -4, 10, 0); do_sgemm(512, 256, 384, true, -4, 10, 0); do_sgemm(1366, 768, 256, true, -4, 10, 0);