diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/batchnorm_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/batchnorm_fp16.cc index 6de61af026cfed12af4a92b89442c9044cf55e9b..8faee1f705feeca668f5e6ddd295977b858b2a79 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/batchnorm_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/batchnorm_fp16.cc @@ -23,44 +23,75 @@ using mindspore::lite::KernelRegistrar; using mindspore::schema::PrimitiveType_BatchNorm; namespace mindspore::kernel { -int BatchnormFp16CPUKernel::DoExecute(int task_id) { - auto param = reinterpret_cast<BatchNormParameter *>(op_parameter_); - - if (in_tensors_.at(0)->data_type() == kNumberTypeFloat32) { - auto input = in_tensors_.at(0); - auto mean = in_tensors_.at(1); - auto variance = in_tensors_.at(2); - auto output = out_tensors_.at(0); +int BatchnormFp16CPUKernel::InitConstTensor() { + isFloat32Tensor_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32; + if (isFloat32Tensor_) { + auto mean_fp32 = in_tensors_.at(1); + auto variance_fp32 = in_tensors_.at(2); + mean_ = malloc(mean_fp32->ElementsNum() * sizeof(float16_t)); + variance_ = malloc(variance_fp32->ElementsNum() * sizeof(float16_t)); + if (mean_ == nullptr || variance_ == nullptr) { + FreeMeanAndVariance(); + return RET_ERROR; + } + Float32ToFloat16(reinterpret_cast<float *>(mean_fp32->Data()), + reinterpret_cast<float16_t *>(mean_), mean_fp32->ElementsNum()); + Float32ToFloat16(reinterpret_cast<float *>(variance_fp32->Data()), + reinterpret_cast<float16_t *>(variance_), variance_fp32->ElementsNum()); + } else { + BatchnormCPUKernel::InitConstTensor(); + } + return RET_OK; +} - auto input_fp16 = context_->allocator->Malloc(input->ElementsNum() * sizeof(float16_t)); - auto mean_fp16 = context_->allocator->Malloc(mean->ElementsNum() * sizeof(float16_t)); - auto variance_fp16 = context_->allocator->Malloc(variance->ElementsNum() * sizeof(float16_t)); - auto output_fp16 = context_->allocator->Malloc(output->ElementsNum() * sizeof(float16_t)); - if (input_fp16 == 
nullptr || variance_fp16 == nullptr || output_fp16 == nullptr) { - context_->allocator->Free(input_fp16); - context_->allocator->Free(mean_fp16); - context_->allocator->Free(variance_fp16); - context_->allocator->Free(output_fp16); +int BatchnormFp16CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret; + return ret; + } + auto input_fp32 = in_tensors_.at(0); + auto output_fp32 = out_tensors_.at(0); + if (isFloat32Tensor_) { + input_ = context_->allocator->Malloc(input_fp32->ElementsNum() * sizeof(float16_t)); + output_ = context_->allocator->Malloc(output_fp32->ElementsNum() * sizeof(float16_t)); + if (input_ == nullptr || output_ == nullptr) { + FreeInputAndOutput(); + return RET_ERROR; } - Float32ToFloat16(reinterpret_cast<float *>(input->Data()), - reinterpret_cast<float16_t *>(input_fp16), input->ElementsNum()); - Float32ToFloat16(reinterpret_cast<float *>(mean->Data()), - reinterpret_cast<float16_t *>(mean_fp16), mean->ElementsNum()); - Float32ToFloat16(reinterpret_cast<float *>(variance->Data()), - reinterpret_cast<float16_t *>(variance_fp16), variance->ElementsNum()); + Float32ToFloat16(reinterpret_cast<float *>(input_fp32->Data()), + reinterpret_cast<float16_t *>(input_), input_fp32->ElementsNum()); + } else { + input_ = in_tensors_.at(0)->Data(); + output_ = out_tensors_.at(0)->Data(); + } + ret = LiteBackendParallelLaunch(BatchNormRun, this, op_parameter_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "BatchnormRun error error_code[" << ret << "]"; + } + if (isFloat32Tensor_) { + Float16ToFloat32(reinterpret_cast<float16_t *>(output_), reinterpret_cast<float *>(output_fp32->Data()), + output_fp32->ElementsNum()); + FreeInputAndOutput(); + } + return ret; +} - BatchNormFp16(input_fp16, mean_fp16, variance_fp16, param, task_id, output_fp16); +int BatchnormFp16CPUKernel::DoExecute(int task_id) { + auto param = reinterpret_cast<BatchNormParameter *>(op_parameter_); + BatchNormFp16(input_, mean_, variance_, param, task_id, output_); + return mindspore::lite::RET_OK; +} - 
Float16ToFloat32(reinterpret_cast<float16_t *>(output_fp16), reinterpret_cast<float *>(output), - output->ElementsNum()); - context_->allocator->Free(input_fp16); - context_->allocator->Free(mean_fp16); - context_->allocator->Free(variance_fp16); - context_->allocator->Free(output_fp16); - return mindspore::lite::RET_OK; +void BatchnormFp16CPUKernel::FreeInputAndOutput() { + if (input_ != nullptr) { + context_->allocator->Free(input_); + input_ = nullptr; + } + if (output_ != nullptr) { + context_->allocator->Free(output_); + output_ = nullptr; } - BatchNormFp16(in_tensors_.at(0)->Data(), mean_, variance_, param, task_id, out_tensors_.at(0)->Data()); - return mindspore::lite::RET_OK; } kernel::LiteKernel *CpuBatchnormFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, @@ -83,5 +114,5 @@ kernel::LiteKernel *CpuBatchnormFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,