diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp
index 6ef9fb2a8252e82014ebebc22f82066eeb324c0d..14269817ededd097c4c9ade20be5ee773c02d692 100644
--- a/src/operators/math/math_function.cpp
+++ b/src/operators/math/math_function.cpp
@@ -36,13 +36,35 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a,
   int N = dim_out[1];
   int K = (!trans_a) ? dim_a[1] : dim_a[0];
 
+  if (trans_a) {
+    int numel = matrix_a.numel();
+    int m = matrix_a.dims()[0];
+    int n = matrix_a.dims()[1];
+    float *tmp = (float *)(matrix_a.data<float>());
+    float *a = static_cast<float *>(
+        paddle_mobile::memory::Alloc(sizeof(float) * numel));
+    int index = 0;
+    for (int j = 0; j < n; j++) {
+      for (int i = 0; i < m; i++) {
+        a[index++] = tmp[i * n + j];
+      }
+    }
+#ifdef _OPENMP
+    Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
+              matrix_out->data<float>(), N, relu, bias);
+#else
+    Sgemm(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
+          matrix_out->data<float>(), N, relu, bias);
+#endif
+  } else {
 #ifdef _OPENMP
-  Sgemm_omp(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
-            N, beta, matrix_out->data<float>(), N, relu, bias);
+    Sgemm_omp(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
+              N, beta, matrix_out->data<float>(), N, relu, bias);
 #else
-  Sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
-        beta, matrix_out->data<float>(), N, relu, bias);
+    Sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
+          beta, matrix_out->data<float>(), N, relu, bias);
 #endif
+  }
 }
 
 template <>