diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc index 096127e184d622c228ad7120093b38dd567e3809..99ff4a5b0b0eb9e0e00b517bb8d52cab54314ecd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc @@ -23,6 +23,11 @@ using mindspore::lite::RET_OK; namespace mindspore::kernel { FullconnectionCPUKernel::~FullconnectionCPUKernel() { + FreeBuf(); + return; +} + +void FullconnectionCPUKernel::FreeBuf() { if (a_c8_ptr_ != nullptr) { free(a_c8_ptr_); a_c8_ptr_ = nullptr; @@ -41,7 +46,11 @@ FullconnectionCPUKernel::~FullconnectionCPUKernel() { } } -int FullconnectionCPUKernel::ReSize() { return RET_OK; } +int FullconnectionCPUKernel::ReSize() { + FreeBuf(); + Init(); + return RET_OK; +} int FullconnectionCPUKernel::Init() { if (context_->infer_shape_interrupt_ && !context_->running_) { @@ -75,16 +84,44 @@ int FullconnectionCPUKernel::Init() { return RET_MEMORY_FAILED; } memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(float)); - RowMajor2Col8Major(reinterpret_cast(in_tensors_[1]->Data()), b_r8_ptr_, fc_param_->col_, fc_param_->deep_); c_r8x8_ptr_ = reinterpret_cast(malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float))); if (c_r8x8_ptr_ == nullptr) { return RET_MEMORY_FAILED; } memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float)); + + fc_param_->a_const_ = false; + fc_param_->b_const_ = false; + InitMatrixA(reinterpret_cast(in_tensors_[0]->Data()), a_c8_ptr_); + InitMatrixB(reinterpret_cast(in_tensors_[1]->Data()), b_r8_ptr_); return RET_OK; } +void FullconnectionCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) { + if (fc_param_->a_const_ == true) { + return; + } + if (src_ptr == nullptr) { + return; + } + fc_param_->a_const_ = true; + RowMajor2Col8Major(src_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_); + return; +} + +void FullconnectionCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) { + if (fc_param_->b_const_ == true) { + return; + } + if (src_ptr == nullptr) { + return; + } + fc_param_->b_const_ = true; + RowMajor2Col8Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_); + return; +} + int FcFp32MatmulRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { auto fc = reinterpret_cast(cdata); auto error_code = fc->DoMatmul(task_id); @@ -115,9 +152,11 @@ int FullconnectionCPUKernel::Run() { return prepare_ret; } auto a_ptr = reinterpret_cast(in_tensors_.at(0)->Data()); + auto b_ptr = reinterpret_cast(in_tensors_.at(1)->Data()); auto output_ptr = reinterpret_cast(out_tensors_.at(0)->Data()); - RowMajor2Col8Major(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_); + InitMatrixA(a_ptr, a_c8_ptr_); + InitMatrixB(b_ptr, b_r8_ptr_); LiteBackendParallelLaunch(FcFp32MatmulRun, this, thread_count_); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h index b8b0d5defe1a00fa29fbbf8de61c656e51101a0e..9f66b03757551bb49681a5c47ddd14bdbbbb1cc1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h @@ -40,6 +40,11 @@ class FullconnectionCPUKernel : public FullconnectionBaseCPUKernel { public: int DoMatmul(int task_id); + void FreeBuf(); + + private: + void InitMatrixA(float *src_ptr, float *dst_ptr); + void InitMatrixB(float *src_ptr, float *dst_ptr); private: float *a_c8_ptr_; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc index 2e35323ba6d98a18a117e7385fd29ed74a0a2905..5aa685be4031752aaa5419a50371c35713ba3137 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc @@ -78,6 +78,11 @@ int MatmulCPUKernel::Init() { } memset(c_r8x8_ptr_, 0, params_->row_8_ * params_->col_8_ * sizeof(float)); + params_->a_const_ = false; + params_->b_const_ = false; + InitMatrixA(reinterpret_cast(in_tensors_[0]->Data()), a_c8_ptr_); + InitMatrixB(reinterpret_cast(in_tensors_[1]->Data()), b_r8_ptr_); + if (in_tensors_.size() == 3) { bias_ptr_ = reinterpret_cast(malloc(params_->col_8_ * sizeof(float))); memset(bias_ptr_, 0, params_->col_8_ * sizeof(float)); @@ -89,6 +94,40 @@ int MatmulCPUKernel::Init() { return RET_OK; } +void MatmulCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) { + if (params_->a_const_ == true) { + return; + } + if (src_ptr == nullptr) { + return; + } + params_->a_const_ = true; + + if (params_->a_transpose_) { + RowMajor2Row8Major(src_ptr, dst_ptr, params_->deep_, params_->row_); + } else { + RowMajor2Col8Major(src_ptr, a_c8_ptr_, params_->row_, params_->deep_); + } + return; +} + +void MatmulCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) { + if (params_->b_const_ == true) { + return; + } + if (src_ptr == nullptr) { + return; + } + params_->b_const_ = true; + + if (params_->b_transpose_) { + RowMajor2Col8Major(src_ptr, dst_ptr, params_->col_, params_->deep_); + } else { + RowMajor2Row8Major(src_ptr, dst_ptr, params_->deep_, params_->col_); + } + return; +} + int MatmulCPUKernel::RunImpl(int task_id) { int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_8_, 8) - task_id * thread_stride_); if (cur_oc <= 0) { @@ -131,16 +170,10 @@ int MatmulCPUKernel::Run() { auto cur_a_ptr = a_ptr + i * a_stride; auto cur_b_ptr = b_ptr + i * b_stride; auto cur_c_ptr = c_ptr + i * c_stride; - if (params_->a_transpose_) { - RowMajor2Row8Major(cur_a_ptr, a_c8_ptr_, params_->deep_, params_->row_); - } else { - RowMajor2Col8Major(cur_a_ptr, a_c8_ptr_, params_->row_, params_->deep_); - } - if (params_->b_transpose_) { - RowMajor2Col8Major(cur_b_ptr, b_r8_ptr_, params_->col_, params_->deep_); - } else { - RowMajor2Row8Major(cur_b_ptr, b_r8_ptr_, params_->deep_, params_->col_); - } + + InitMatrixA(cur_a_ptr, a_c8_ptr_); + InitMatrixB(cur_b_ptr, b_r8_ptr_); + LiteBackendParallelLaunch(MatmulFloatRun, this, thread_count_); Row8x8Major2RowMajor(c_r8x8_ptr_, cur_c_ptr, params_->row_, params_->col_, params_->col_); } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h index 4642c8ad6c09212ee4e93c2b6df5100ed5cb2c79..38c0e445ac29a62df450059426f553d55dc6c4d9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h @@ -35,6 +35,10 @@ class MatmulCPUKernel : public MatmulBaseCPUKernel { int Run() override; int RunImpl(int task_id); + private: + void InitMatrixA(float *src_ptr, float *dst_ptr); + void InitMatrixB(float *src_ptr, float *dst_ptr); + private: float *a_c8_ptr_; float *b_r8_ptr_; diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h index be01e4beb2960c9160cc0db6dee0cf35d08fbb3c..b2b24064d354af1d296d8778503402c0fd3d6000 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h @@ -33,6 +33,8 @@ typedef struct MatMulParameter { int batch; bool a_transpose_; /* false : row-major */ bool b_transpose_; /* true : col-major */ + bool a_const_; + bool b_const_; ActType act_type_; } MatMulParameter;