diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 605fa17c3c70ec3151cc1a2fb249edab336548a1..d3e6de3134ff91f47c66c927194a5ba688e931b0 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -3230,6 +3230,8 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, int L1 = 64 / max_threads * 1024; KC = k; + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + memset(static_cast(zero), 0, sizeof(float) * KC); if (m > n) { // 对 A 分块 MC = L1 / (KC * sizeof(float)); @@ -3255,7 +3257,7 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - (*this.*procPackB)(KC, NC, NC % NR, B, ldb, packedB); + (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); } else { @@ -3284,12 +3286,10 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - (*this.*procPackA)(MC, KC, MC % MR, A, lda, packedA); + (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); } - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); packedC = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); @@ -3352,6 +3352,8 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int L1 = 64 / max_threads * 1024; KC = k; + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + memset(static_cast(zero), 0, sizeof(float) * KC); if (m > n) { // 对 A 分块 MC = L1 / (KC * sizeof(float)); @@ -3377,7 +3379,7 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - (*this.*procPackB)(KC, NC, NC % NR, B, ldb, packedB); + (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); } else { @@ -3405,12 +3407,10 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - (*this.*procPackA)(MC, KC, MC % MR, A, lda, packedA); + (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); } - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); packedC = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); @@ -3480,6 +3480,8 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, int L1 = 8 * 1024; KC = k; + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + memset(static_cast(zero), 0, sizeof(float) * KC); if (m > n) { // 对 A 分块 MC = L1 / (KC * sizeof(float)); @@ -3505,7 +3507,7 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - (*this.*procPackB)(KC, NC, NC % NR, B, ldb, packedB); + (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); } else { @@ -3533,12 +3535,10 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - (*this.*procPackA)(MC, KC, MC % MR, A, lda, packedA); + (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); } - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); packedC = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads));