diff --git a/test/common/test_gemm.cpp b/test/common/test_gemm.cpp index 8cb778c458034aecf6cea89fcf0d3e2a3d8118ba..4ae4a1a7c5ba37c8745ff23ff2afccaaeb184176 100644 --- a/test/common/test_gemm.cpp +++ b/test/common/test_gemm.cpp @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include +#include #include "../test_helper.h" #include "common/log.h" #include "memory/t_malloc.h" @@ -20,63 +22,115 @@ limitations under the License. */ #define a(i, j) a[(i)*lda + (j)] #define b(i, j) b[(i)*ldb + (j)] +#define c(i, j) c[(i)*ldc + (j)] #define c1(i, j) c1[(i)*ldc + (j)] -#define m 62 -#define n 63 -#define k 74 -int main() { +void print_matirx(int m, int n, int ldc, float *c) { + for (int i = 0; i < m; ++i) { + std::cout << c(i, 0); + for (int j = 1; j < n; ++j) { + std::cout << " | " << c(i, j); + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { int lda = k; int ldb = n; int ldc = n; - float *a = - static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m * k)); - float *b = - static_cast(paddle_mobile::memory::Alloc(sizeof(float) * k * n)); - float *c = - static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m * n)); - float *c1 = - static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m * n)); - + float *a = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m * k)); + float *b = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * k * n)); + float *c = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m * n)); + float *c1 = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m * n)); + float* scale = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m)); + float* bias = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m)); + + srand(unsigned(time(0))); for (int i = 0; i < m * k; ++i) { - a[i] = 2; + a[i] = t1 + rand() % t2; } for (int i = 0; i < k * n; ++i) { - b[i] = 2; + b[i] = t1 + rand() % t2; } - for (int i = 0; i < m * n; ++i) { - c[i] = 2; - c1[i] = 2; + for (int i = 0; i < m; ++i) { + scale[i] = t1 + rand() % t2; } - - auto time1 = time(); - // paddle_mobile::operators::math::Sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, - // c, - // ldc); - auto time2 = time(); - DLOG << "gemm cost :" << time_diff(time1, time2) << "ms\n"; - for (int i = 0; i < m * n; ++i) { - std::cout << c[i] << " | "; - if (i % n == (n - 1)) { - std::cout << std::endl; - } + for (int i = 0; i < m; ++i) { + bias[i] = t1 + rand() % t2; } - for (int j = 0; j < n; ++j) { - for (int i = 0; i < m; ++i) { - c1(i, j) *= 0.3; - for (int p = 0; p < k; ++p) { - c1(i, j) += 0.9 * a(i, p) * b(p, j); + + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + float r = 0; + for (int p = 0; p < k; p++) { + r += a(i, p) * b(p, j); + } + r *= scale[i]; + r += bias[i]; + if (relu && (r < 0)) { + r = 0; } + c1(i, j) = r; } } - std::cout << "正确结果对比:" << std::endl; + + paddle_mobile::operators::math::SgemmWithBn(m, n, k, 0.9, a, lda, + b, ldb, 0.3, c, ldc, relu, scale, bias); + int eq = 0; + int neq = 0; for (int i = 0; i < m * n; ++i) { - std::cout << c1[i] << " | "; - if (i % n == (n - 1)) { - std::cout << std::endl; + if (static_cast(c[i]) == static_cast(c1[i])) { + ++eq; + } else { + ++neq; } } + + if (pr > 0) { + std::cout << "A:" << std::endl; + print_matirx(m, k, lda, a); + + std::cout << "B:" << std::endl; + print_matirx(k, n, ldb, b); + + std::cout << "C:" << std::endl; + print_matirx(m, n, ldc, c); + + std::cout << "C1:" << std::endl; + print_matirx(m, n, ldc, c1); + } + + std::cout << "mnk=" << m << " " << n << " " << k << + " relu=" << relu << + " eq=" << eq << " neq=" << neq << std::endl; + + paddle_mobile::memory::Free(a); + paddle_mobile::memory::Free(b); + paddle_mobile::memory::Free(c); + paddle_mobile::memory::Free(c1); + paddle_mobile::memory::Free(scale); + paddle_mobile::memory::Free(bias); + + return 0; +} + +int main() { + + do_sgemm(9, 9, 9, true, 10, 10, 10); + do_sgemm(10, 6, 12, false, 10, 10, 0); + do_sgemm(512, 256, 384, false, 10, 10, 0); + do_sgemm(1366, 768, 256, false, 10, 10, 0); + do_sgemm(1255, 755, 333, false, 10, 10, 0); + do_sgemm(555, 777, 999, false, 10, 10, 0); + + do_sgemm(10, 6, 12, true, -4, 10, 0); + do_sgemm(512, 256, 384, true, -4, 10, 0); + do_sgemm(1366, 768, 256, true, -4, 10, 0); + do_sgemm(1255, 755, 333, true, -4, 10, 0); + do_sgemm(555, 777, 999, true, -4, 10, 0); return 0; }