提交 f6b4b58c，作者：eclipsess

Merge remote-tracking branch 'upstream/develop' into develop

...@@ -3230,6 +3230,8 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, ...@@ -3230,6 +3230,8 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
int L1 = 64 / max_threads * 1024; int L1 = 64 / max_threads * 1024;
KC = k; KC = k;
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
if (m > n) { if (m > n) {
// 对 A 分块 // 对 A 分块
MC = L1 / (KC * sizeof(float)); MC = L1 / (KC * sizeof(float));
...@@ -3255,7 +3257,7 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, ...@@ -3255,7 +3257,7 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
packedB = static_cast<float *>( packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
(*this.*procPackB)(KC, NC, NC % NR, B, ldb, packedB); (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB);
packedA = static_cast<float *>( packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
} else { } else {
...@@ -3284,12 +3286,10 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, ...@@ -3284,12 +3286,10 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
packedA = static_cast<float *>( packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
(*this.*procPackA)(MC, KC, MC % MR, A, lda, packedA); (*this.*procPackA)(m, KC, m % MR, A, lda, packedA);
packedB = static_cast<float *>( packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
} }
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
packedC = static_cast<float *>( packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads));
...@@ -3352,6 +3352,8 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, ...@@ -3352,6 +3352,8 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
int L1 = 64 / max_threads * 1024; int L1 = 64 / max_threads * 1024;
KC = k; KC = k;
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
if (m > n) { if (m > n) {
// 对 A 分块 // 对 A 分块
MC = L1 / (KC * sizeof(float)); MC = L1 / (KC * sizeof(float));
...@@ -3377,7 +3379,7 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, ...@@ -3377,7 +3379,7 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
packedB = static_cast<float *>( packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
(*this.*procPackB)(KC, NC, NC % NR, B, ldb, packedB); (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB);
packedA = static_cast<float *>( packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
} else { } else {
...@@ -3405,12 +3407,10 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, ...@@ -3405,12 +3407,10 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
packedA = static_cast<float *>( packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
(*this.*procPackA)(MC, KC, MC % MR, A, lda, packedA); (*this.*procPackA)(m, KC, m % MR, A, lda, packedA);
packedB = static_cast<float *>( packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
} }
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
packedC = static_cast<float *>( packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads));
...@@ -3480,6 +3480,8 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, ...@@ -3480,6 +3480,8 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
int L1 = 8 * 1024; int L1 = 8 * 1024;
KC = k; KC = k;
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
if (m > n) { if (m > n) {
// 对 A 分块 // 对 A 分块
MC = L1 / (KC * sizeof(float)); MC = L1 / (KC * sizeof(float));
...@@ -3505,7 +3507,7 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, ...@@ -3505,7 +3507,7 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
packedB = static_cast<float *>( packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
(*this.*procPackB)(KC, NC, NC % NR, B, ldb, packedB); (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB);
packedA = static_cast<float *>( packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
} else { } else {
...@@ -3533,12 +3535,10 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, ...@@ -3533,12 +3535,10 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
packedA = static_cast<float *>( packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
(*this.*procPackA)(MC, KC, MC % MR, A, lda, packedA); (*this.*procPackA)(m, KC, m % MR, A, lda, packedA);
packedB = static_cast<float *>( packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
} }
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
packedC = static_cast<float *>( packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册