diff --git a/src/operators/math/gemm/executor.h b/src/operators/math/gemm/executor.h index ebff90d4eea901905f0c8c8c11ac2b907f7ef7f9..1a536cba4e7ce6a52ba409856af9151152bf87eb 100644 --- a/src/operators/math/gemm/executor.h +++ b/src/operators/math/gemm/executor.h @@ -106,13 +106,13 @@ class GemmExecutor : public Executor { // struct timeval tv_begin, tv_end; // gettimeofday(&tv_begin,NULL); if (M_ > N_) { - int nblock = CeilDiv(N_, Strategy::out_width()) * Strategy::out_width(); + nblock = CeilDiv(N_, Strategy::out_width()) * Strategy::out_width(); lhs_worksize_ = sizeof(Itype) * lhs_tile_num_ * K_ * num_threads_; rhs_worksize_ = sizeof(Itype) * K_ * nblock; out_worksize_ = sizeof(Otype) * lhs_tile_num_ * nblock * num_threads_; ldc_ = nblock; } else { - int mblock = CeilDiv(M_, Strategy::out_height()) * Strategy::out_height(); + mblock = CeilDiv(M_, Strategy::out_height()) * Strategy::out_height(); lhs_worksize_ = sizeof(Itype) * mblock * K_; rhs_worksize_ = sizeof(Itype) * K_ * rhs_tile_num_ * num_threads_; out_worksize_ = sizeof(Otype) * mblock * rhs_tile_num_ * num_threads_; @@ -174,7 +174,7 @@ class GemmExecutor : public Executor { int thread_id = 0; #endif float *local_B = rhs_workspace_ + K_ * rhs_tile_num_ * thread_id; - float *local_C = out_workspace_ + lhs_tile_num_ * ldc_ * thread_id; + float *local_C = out_workspace_ + mblock * ldc_ * thread_id; // load rhs into rhs_workspace strategy_.pack_rhs(K_, rhs_range, B + rhs_block, ldb, local_B, false); for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) { @@ -225,6 +225,9 @@ class GemmExecutor : public Executor { unsigned int out_worksize_ = 0; unsigned int ldc_ = 0; + unsigned int mblock = 0; + unsigned int nblock = 0; + Itype *lhs_workspace_ = nullptr; Itype *rhs_workspace_ = nullptr; Otype *out_workspace_ = nullptr; diff --git a/src/operators/math/gemm/gemm_kernel.h b/src/operators/math/gemm/gemm_kernel.h index fcffd5ec86daf52e8e4a07dc6dead8766b1ba123..a3c1eabf41fa7a325038824f7c518dd41a45b582 100644 --- a/src/operators/math/gemm/gemm_kernel.h +++ b/src/operators/math/gemm/gemm_kernel.h @@ -17,6 +17,7 @@ limitations under the License. */ #if defined(__ARM_NEON__) || defined(__ARM_NEON) #include +#include #include "operators/math/math.h" namespace paddle_mobile {