提交 94fb88b4 编写于 作者: S Shuai Yuan 提交者: Houjiang Chen

BugFix: fix multi-threading incorrectness in CPU GEMM (#1569)

* BugFix: fix multi-threading incorrectness in CPU GEMM

* BugFix: add missing header file for memset

* BugFix: add missing header file for memset
上级 769c8083
...@@ -106,13 +106,13 @@ class GemmExecutor : public Executor { ...@@ -106,13 +106,13 @@ class GemmExecutor : public Executor {
// struct timeval tv_begin, tv_end; // struct timeval tv_begin, tv_end;
// gettimeofday(&tv_begin,NULL); // gettimeofday(&tv_begin,NULL);
if (M_ > N_) { if (M_ > N_) {
int nblock = CeilDiv(N_, Strategy::out_width()) * Strategy::out_width(); nblock = CeilDiv(N_, Strategy::out_width()) * Strategy::out_width();
lhs_worksize_ = sizeof(Itype) * lhs_tile_num_ * K_ * num_threads_; lhs_worksize_ = sizeof(Itype) * lhs_tile_num_ * K_ * num_threads_;
rhs_worksize_ = sizeof(Itype) * K_ * nblock; rhs_worksize_ = sizeof(Itype) * K_ * nblock;
out_worksize_ = sizeof(Otype) * lhs_tile_num_ * nblock * num_threads_; out_worksize_ = sizeof(Otype) * lhs_tile_num_ * nblock * num_threads_;
ldc_ = nblock; ldc_ = nblock;
} else { } else {
int mblock = CeilDiv(M_, Strategy::out_height()) * Strategy::out_height(); mblock = CeilDiv(M_, Strategy::out_height()) * Strategy::out_height();
lhs_worksize_ = sizeof(Itype) * mblock * K_; lhs_worksize_ = sizeof(Itype) * mblock * K_;
rhs_worksize_ = sizeof(Itype) * K_ * rhs_tile_num_ * num_threads_; rhs_worksize_ = sizeof(Itype) * K_ * rhs_tile_num_ * num_threads_;
out_worksize_ = sizeof(Otype) * mblock * rhs_tile_num_ * num_threads_; out_worksize_ = sizeof(Otype) * mblock * rhs_tile_num_ * num_threads_;
...@@ -174,7 +174,7 @@ class GemmExecutor : public Executor { ...@@ -174,7 +174,7 @@ class GemmExecutor : public Executor {
int thread_id = 0; int thread_id = 0;
#endif #endif
float *local_B = rhs_workspace_ + K_ * rhs_tile_num_ * thread_id; float *local_B = rhs_workspace_ + K_ * rhs_tile_num_ * thread_id;
float *local_C = out_workspace_ + lhs_tile_num_ * ldc_ * thread_id; float *local_C = out_workspace_ + mblock * ldc_ * thread_id;
// load rhs into rhs_workspace // load rhs into rhs_workspace
strategy_.pack_rhs(K_, rhs_range, B + rhs_block, ldb, local_B, false); strategy_.pack_rhs(K_, rhs_range, B + rhs_block, ldb, local_B, false);
for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) { for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) {
...@@ -225,6 +225,9 @@ class GemmExecutor : public Executor { ...@@ -225,6 +225,9 @@ class GemmExecutor : public Executor {
unsigned int out_worksize_ = 0; unsigned int out_worksize_ = 0;
unsigned int ldc_ = 0; unsigned int ldc_ = 0;
unsigned int mblock = 0;
unsigned int nblock = 0;
Itype *lhs_workspace_ = nullptr; Itype *lhs_workspace_ = nullptr;
Itype *rhs_workspace_ = nullptr; Itype *rhs_workspace_ = nullptr;
Otype *out_workspace_ = nullptr; Otype *out_workspace_ = nullptr;
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h> #include <arm_neon.h>
#include <memory>
#include "operators/math/math.h" #include "operators/math/math.h"
namespace paddle_mobile { namespace paddle_mobile {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册