diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc index c32a210743a887946f2a1111b020c20c088d766e..4c0e7e0f1ba5eec7171ff40b114a462193993739 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc @@ -15,6 +15,7 @@ */ #include "src/runtime/kernel/arm/fp16/arithmetic_fp16.h" +#include "src/runtime/kernel/arm/fp16/common_fp16.h" #include "nnacl/fp16/arithmetic_fp16.h" #include "nnacl/fp16/cast_fp16.h" #include "schema/model_generated.h" @@ -29,7 +30,6 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_Add; using mindspore::schema::PrimitiveType_Div; -using mindspore::schema::PrimitiveType_Eltwise; using mindspore::schema::PrimitiveType_Equal; using mindspore::schema::PrimitiveType_FloorDiv; using mindspore::schema::PrimitiveType_FloorMod; @@ -47,121 +47,57 @@ using mindspore::schema::PrimitiveType_SquaredDifference; using mindspore::schema::PrimitiveType_Sub; namespace mindspore::kernel { -void ArithmeticFP16CPUKernel::FreeTmpBuffer() { - if (input0_fp16_ != nullptr) { - context_->allocator->Free(input0_fp16_); - input0_fp16_ = nullptr; - } - if (input1_fp16_ != nullptr) { - context_->allocator->Free(input1_fp16_); - input1_fp16_ = nullptr; - } - if (output_fp16_ != nullptr) { - context_->allocator->Free(output_fp16_); - output_fp16_ = nullptr; +ARITHMETIC_FUNC_INFO_FP16 arithmetic_fun_table_fp16[] = { + {PrimitiveType_Mul, schema::ActivationType_RELU, ElementMulReluFp16, ElementOptMulReluFp16}, + {PrimitiveType_Mul, schema::ActivationType_RELU6, ElementMulRelu6Fp16, ElementOptMulRelu6Fp16}, + {PrimitiveType_Mul, schema::ActivationType_NO_ACTIVATION, ElementMulFp16, ElementOptMulFp16}, + {PrimitiveType_Add, schema::ActivationType_RELU, ElementAddReluFp16, ElementOptAddReluFp16}, + {PrimitiveType_Add, schema::ActivationType_RELU6, ElementAddRelu6Fp16, ElementOptAddRelu6Fp16}, + {PrimitiveType_Add, schema::ActivationType_NO_ACTIVATION, ElementAddFp16, ElementOptAddFp16}, + {PrimitiveType_Sub, schema::ActivationType_RELU, ElementSubReluFp16, ElementOptSubReluFp16}, + {PrimitiveType_Sub, schema::ActivationType_RELU6, ElementSubRelu6Fp16, ElementOptSubRelu6Fp16}, + {PrimitiveType_Sub, schema::ActivationType_NO_ACTIVATION, ElementSubFp16, ElementOptSubFp16}, + {PrimitiveType_Div, schema::ActivationType_RELU, ElementDivReluFp16, ElementOptDivReluFp16}, + {PrimitiveType_Div, schema::ActivationType_RELU6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16}, + {PrimitiveType_Div, schema::ActivationType_NO_ACTIVATION, ElementDivFp16, ElementOptDivFp16}, + {PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorModFp16, ElementOptFloorModFp16}, + {PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDivFp16, ElementOptFloorDivFp16}, + {PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAndFp16, ElementOptLogicalAndFp16}, + {PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOrFp16, ElementOptLogicalOrFp16}, + {PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifferenceFp16, + ElementOptSquaredDifferenceFp16}, + {PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximumFp16, ElementOptMaximumFp16}, + {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimumFp16, ElementOptMinimumFp16}, + {PrimitiveType_NotEqual, schema::ActivationType_NO_ACTIVATION, ElementNotEqualFp16, ElementOptNotEqualFp16}, + {PrimitiveType_Equal, schema::ActivationType_NO_ACTIVATION, ElementEqualFp16, ElementOptEqualFp16}, + {PrimitiveType_Less, schema::ActivationType_NO_ACTIVATION, ElementLessFp16, ElementOptLessFp16}, + {PrimitiveType_LessEqual, schema::ActivationType_NO_ACTIVATION, ElementLessEqual, ElementOptLessEqualFp16}, + {PrimitiveType_Greater, schema::ActivationType_NO_ACTIVATION, ElementGreaterFp16, ElementOptGreaterFp16}, + {PrimitiveType_GreaterEqual, schema::ActivationType_NO_ACTIVATION, ElementGreaterEqualFp16, + ElementOptGreaterEqualFp16}, +}; + +ArithmeticFuncFp16 GetArithmeticFun(int primitive_type, int activation_type) { + for (size_t i = 0; i < sizeof(arithmetic_fun_table_fp16); i++) { + if (arithmetic_fun_table_fp16[i].primitive_type_ == primitive_type && + arithmetic_fun_table_fp16[i].activation_type_ == activation_type) { + return arithmetic_fun_table_fp16[i].func_; + } } + return nullptr; } -ArithmeticFP16CPUKernel::~ArithmeticFP16CPUKernel() {} +ArithmeticOptFuncFp16 GetOptimizedArithmeticFun(int primitive_type, int activation_type) { + for (size_t i = 0; i < sizeof(arithmetic_fun_table_fp16); i++) { + if (arithmetic_fun_table_fp16[i].primitive_type_ == primitive_type && + arithmetic_fun_table_fp16[i].activation_type_ == activation_type) { + return arithmetic_fun_table_fp16[i].opt_func_; + } + } + return nullptr; +} int ArithmeticFP16CPUKernel::Init() { - switch (op_parameter_->type_) { - case PrimitiveType_Mul: - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmetic_run_ = ElementMulReluFp16; - break; - case schema::ActivationType_RELU6: - arithmetic_run_ = ElementMulRelu6Fp16; - break; - default: - arithmetic_run_ = ElementMulFp16; - break; - } - break; - case PrimitiveType_Add: - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmetic_run_ = ElementAddReluFp16; - break; - case schema::ActivationType_RELU6: - arithmetic_run_ = ElementAddRelu6Fp16; - break; - default: - arithmetic_run_ = ElementAddFp16; - break; - } - break; - case PrimitiveType_Sub: - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmetic_run_ = ElementSubReluFp16; - break; - case schema::ActivationType_RELU6: - arithmetic_run_ = ElementSubRelu6Fp16; - break; - default: - arithmetic_run_ = ElementSubFp16; - break; - } - break; - case PrimitiveType_Div: - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmetic_run_ = ElementDivReluFp16; - break; - case schema::ActivationType_RELU6: - arithmetic_run_ = ElementDivRelu6Fp16; - break; - default: - arithmetic_run_ = ElementDivFp16; - break; - } - break; - case PrimitiveType_FloorMod: - arithmetic_run_ = ElementFloorModFp16; - break; - case PrimitiveType_FloorDiv: - arithmetic_run_ = ElementFloorDivFp16; - break; - case PrimitiveType_LogicalAnd: - arithmetic_run_ = ElementLogicalAndFp16; - break; - case PrimitiveType_LogicalOr: - arithmetic_run_ = ElementLogicalOrFp16; - break; - case PrimitiveType_SquaredDifference: - arithmetic_run_ = ElementSquaredDifferenceFp16; - break; - case PrimitiveType_Maximum: - arithmetic_run_ = ElementMaximumFp16; - break; - case PrimitiveType_Minimum: - arithmetic_run_ = ElementMinimumFp16; - break; - case PrimitiveType_NotEqual: - arithmetic_run_ = ElementNotEqualFp16; - break; - case PrimitiveType_Equal: - arithmetic_run_ = ElementEqualFp16; - break; - case PrimitiveType_Less: - arithmetic_run_ = ElementLessFp16; - break; - case PrimitiveType_LessEqual: - arithmetic_run_ = ElementLessEqual; - break; - case PrimitiveType_Greater: - arithmetic_run_ = ElementGreaterFp16; - break; - case PrimitiveType_GreaterEqual: - arithmetic_run_ = ElementGreaterEqualFp16; - break; - default: - MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_; - arithmetic_run_ = nullptr; - break; - } if (!InferShapeDone()) { return RET_OK; } @@ -169,162 +105,47 @@ int ArithmeticFP16CPUKernel::Init() { } int ArithmeticFP16CPUKernel::ReSize() { - arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); - arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); - arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); + param_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); + param_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); + param_->out_elements_num_ = out_tensors_[0]->ElementsNum(); - if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { - switch (arithmeticParameter_->op_parameter_.type_) { - case PrimitiveType_Mul: - arithmeticParameter_->broadcasting_ = false; - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmetic_opt_run_ = ElementOptMulReluFp16; - break; - case schema::ActivationType_RELU6: - arithmetic_opt_run_ = ElementOptMulRelu6Fp16; - break; - default: - arithmetic_opt_run_ = ElementOptMulFp16; - break; - } - break; - case PrimitiveType_Add: - arithmeticParameter_->broadcasting_ = false; - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmetic_opt_run_ = ElementOptAddReluFp16; - break; - case schema::ActivationType_RELU6: - arithmetic_opt_run_ = ElementOptAddRelu6Fp16; - break; - default: - arithmetic_opt_run_ = ElementOptAddFp16; - break; - } - break; - case PrimitiveType_Sub: - arithmeticParameter_->broadcasting_ = false; - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmetic_opt_run_ = ElementOptSubReluFp16; - break; - case schema::ActivationType_RELU6: - arithmetic_opt_run_ = ElementOptSubRelu6Fp16; - break; - default: - arithmetic_opt_run_ = ElementOptSubFp16; - break; - } - break; - case PrimitiveType_Div: - arithmeticParameter_->broadcasting_ = false; - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmetic_opt_run_ = ElementOptDivReluFp16; - break; - case schema::ActivationType_RELU6: - arithmetic_opt_run_ = ElementOptDivRelu6Fp16; - break; - default: - arithmetic_opt_run_ = ElementOptDivFp16; - break; - } - break; - case PrimitiveType_FloorMod: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptFloorModFp16; - break; - case PrimitiveType_FloorDiv: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptFloorDivFp16; - break; - case PrimitiveType_LogicalAnd: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptLogicalAndFp16; - break; - case PrimitiveType_LogicalOr: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptLogicalOrFp16; - break; - case PrimitiveType_SquaredDifference: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptSquaredDifferenceFp16; - break; - case PrimitiveType_Maximum: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptMaximumFp16; - break; - case PrimitiveType_Minimum: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptMinimumFp16; - break; - case PrimitiveType_NotEqual: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptNotEqualFp16; - break; - case PrimitiveType_Equal: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptEqualFp16; - break; - case PrimitiveType_Less: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptLessFp16; - break; - case PrimitiveType_LessEqual: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptLessEqualFp16; - break; - case PrimitiveType_Greater: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptGreaterFp16; - break; - case PrimitiveType_GreaterEqual: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptGreaterEqualFp16; - break; - default: - break; - } + if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) { + param_->broadcasting_ = false; + arithmetic_opt_func_ = GetOptimizedArithmeticFun(param_->op_parameter_.type_, param_->activation_type_); + } else { + arithmetic_func_ = GetArithmeticFun(param_->op_parameter_.type_, param_->activation_type_); } - - if (arithmeticParameter_->broadcasting_) { + if (arithmetic_opt_func_ == nullptr && arithmetic_func_ == nullptr) { + MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!"; + return RET_ERROR; + } + if (param_->broadcasting_) { outside_ = 1; - for (int i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) { - if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) { + for (int i = param_->ndim_ - 1; i >= 0; --i) { + if (param_->in_shape0_[i] != param_->in_shape1_[i]) { break_pos_ = i; break; } - outside_ *= arithmeticParameter_->out_shape_[i]; + outside_ *= param_->out_shape_[i]; } - ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_); - ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_); - ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_); + ComputeStrides(param_->in_shape0_, param_->in_strides0_, param_->ndim_); + ComputeStrides(param_->in_shape1_, param_->in_strides1_, param_->ndim_); + ComputeStrides(param_->out_shape_, param_->out_strides_, param_->ndim_); } return RET_OK; } int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, - int out_count, int out_thread_stride) { + int out_count, int cur_offset) { if (dim > break_pos_) { - int error_code = - arithmetic_run_(input0 + out_thread_stride, input1 + out_thread_stride, output + out_thread_stride, out_count); - if (output_fp16_ != nullptr) { - auto output_fp32 = reinterpret_cast(out_tensors_[0]->Data()); - int bias = output - output_fp16_; - output_fp32 += bias; - Float16ToFloat32(output + out_thread_stride, output_fp32 + out_thread_stride, out_count); - } - return error_code; + return arithmetic_func_(input0 + cur_offset, input1 + cur_offset, output + cur_offset, out_count); } - for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) { - int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i; - int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i; - int error_code = - BroadcastRun(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim], - input1 + pos1_ * arithmeticParameter_->in_strides1_[dim], - output + i * arithmeticParameter_->out_strides_[dim], dim + 1, out_count, out_thread_stride); - if (error_code != RET_OK) { + for (int i = 0; i < param_->out_shape_[dim]; ++i) { + int pos0 = param_->in_shape0_[dim] == 1 ? 0 : i; + int pos1 = param_->in_shape1_[dim] == 1 ? 0 : i; + int ret = BroadcastRun(input0 + pos0 * param_->in_strides0_[dim], input1 + pos1 * param_->in_strides1_[dim], + output + i * param_->out_strides_[dim], dim + 1, out_count, cur_offset); + if (ret != RET_OK) { return RET_ERROR; } } @@ -332,62 +153,33 @@ int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1, } int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) { - auto input0 = reinterpret_cast(in_tensors_[0]->Data()); - auto input1 = reinterpret_cast(in_tensors_[1]->Data()); - auto output = reinterpret_cast(out_tensors_[0]->Data()); - auto element_num = out_tensors_[0]->ElementsNum(); - - float16_t *input0_data = input0_fp16_ == nullptr ? input0 : input0_fp16_; - float16_t *input1_data1 = input1_fp16_ == nullptr ? input1 : input1_fp16_; - auto output_data = output_fp16_ == nullptr ? output : output_fp16_; - int stride = UP_DIV(element_num, context_->thread_num_); - int count = MSMIN(stride, element_num - stride * task_id); - auto thread_stride = stride * task_id; + int stride_per_thread = UP_DIV(param_->broadcasting_ ? outside_ : param_->out_elements_num_, context_->thread_num_); + int cur_offset = stride_per_thread * task_id; + int cur_count = MSMIN(stride_per_thread, param_->out_elements_num_ - cur_offset); - if (arithmetic_run_ == nullptr) { - MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; - return RET_ERROR; - } - - int error_code = RET_OK; - if (arithmeticParameter_->broadcasting_) { - stride = UP_DIV(outside_, context_->thread_num_); - out_count_ = MSMIN(stride, outside_ - stride * task_id); - out_thread_stride_ = stride * task_id; - error_code = BroadcastRun(input0_data, input1_data1, output_data, 0, out_count_, out_thread_stride_); - } else if (arithmetic_opt_run_ != nullptr) { - if (arithmeticParameter_->in_elements_num0_ == 1) { - error_code = arithmetic_opt_run_(input0_data, input1_data1 + thread_stride, output_data + thread_stride, count, - arithmeticParameter_); - } else if (arithmeticParameter_->in_elements_num1_ == 1) { - error_code = arithmetic_opt_run_(input0_data + thread_stride, input1_data1, output_data + thread_stride, count, - arithmeticParameter_); - } else { - error_code = arithmetic_opt_run_(input0_data + thread_stride, input1_data1 + thread_stride, - output_data + thread_stride, count, arithmeticParameter_); - } + int ret = RET_OK; + if (param_->broadcasting_) { + ret = BroadcastRun(input0_fp16_, input1_fp16_, output_fp16_, 0, cur_count, cur_offset); + } else if (param_->in_elements_num0_ == 1) { + ret = arithmetic_opt_func_(input0_fp16_, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count, param_); + } else if (param_->in_elements_num1_ == 1) { + ret = arithmetic_opt_func_(input0_fp16_ + cur_offset, input1_fp16_, output_fp16_ + cur_offset, cur_count, param_); } else { - error_code = - arithmetic_run_(input0_data + thread_stride, input1_data1 + thread_stride, output_data + thread_stride, count); + ret = arithmetic_func_(input0_fp16_ + cur_offset, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count); } - if (error_code != RET_OK) { - return RET_ERROR; - } - if (output_fp16_ != nullptr && !arithmeticParameter_->broadcasting_) { - auto output_fp32 = reinterpret_cast(out_tensors_[0]->Data()); - Float16ToFloat32(output_data + thread_stride, output_fp32 + thread_stride, count); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoArithmetic failed, ret = " << ret; } - return RET_OK; + return ret; } -static int ArithmeticsRun_Fp16(void *cdata, int task_id) { +static int ArithmeticsRunFp16(void *cdata, int task_id) { auto arithmetic_kernel = reinterpret_cast(cdata); - auto error_code = arithmetic_kernel->DoArithmetic(task_id); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "ArithmeticsRun error task_id[" << task_id << "] error_code[" << error_code << "]"; - return RET_ERROR; + auto ret = arithmetic_kernel->DoArithmetic(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ArithmeticsRunFp16 error task_id[" << task_id << "] ret[" << ret << "]"; } - return RET_OK; + return ret; } int ArithmeticFP16CPUKernel::Run() { @@ -396,43 +188,45 @@ int ArithmeticFP16CPUKernel::Run() { MS_LOG(ERROR) << "Prepare fail!ret: " << ret; return ret; } + auto output_tensor = out_tensors_.at(0); + is_input0_fp32_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32; + is_input1_fp32_ = in_tensors_.at(1)->data_type() == kNumberTypeFloat32; + is_output_fp32_ = output_tensor->data_type() == kNumberTypeFloat32; - arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); - arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); - arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); - if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) { - output_fp16_ = reinterpret_cast(malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t))); - if (output_fp16_ == nullptr) { - MS_LOG(ERROR) << "malloc data fail!"; - FreeTmpBuffer(); - return RET_ERROR; - } + input0_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_); + input1_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(1), context_); + output_fp16_ = MallocOutputFp16(output_tensor, context_); + if (input0_fp16_ == nullptr || input1_fp16_ == nullptr || output_fp16_ == nullptr) { + MS_LOG(ERROR) << "Memory allocation failed"; + FreeTmpBuffer(); + return RET_ERROR; } - if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) { - input0_fp16_ = reinterpret_cast(malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t))); - if (input0_fp16_ == nullptr) { - MS_LOG(ERROR) << "malloc data fail!"; - FreeTmpBuffer(); - return RET_ERROR; - } - Float32ToFloat16(reinterpret_cast(in_tensors_[0]->Data()), input0_fp16_, - arithmeticParameter_->in_elements_num0_); + ret = ParallelLaunch(THREAD_POOL_DEFAULT, ArithmeticsRunFp16, this, context_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ArithmeticsRunFp16 run error error_code[" << ret << "]"; } - if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) { - input1_fp16_ = reinterpret_cast(malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t))); - if (input1_fp16_ == nullptr) { - MS_LOG(ERROR) << "malloc data fail!"; - FreeTmpBuffer(); - return RET_ERROR; - } - Float32ToFloat16(reinterpret_cast(in_tensors_[1]->Data()), input1_fp16_, - arithmeticParameter_->in_elements_num1_); + if (is_output_fp32_) { + Float16ToFloat32(output_fp16_, reinterpret_cast(output_tensor->Data()), output_tensor->ElementsNum()); } - ret = ParallelLaunch(THREAD_POOL_DEFAULT, ArithmeticsRun_Fp16, this, context_->thread_num_); FreeTmpBuffer(); return ret; } +void ArithmeticFP16CPUKernel::FreeTmpBuffer() { + if (is_input0_fp32_) { + context_->allocator->Free(input0_fp16_); + input0_fp16_ = nullptr; + } + if (is_input1_fp32_) { + context_->allocator->Free(input1_fp16_); + input1_fp16_ = nullptr; + } + if (is_output_fp32_) { + context_->allocator->Free(output_fp16_); + output_fp16_ = nullptr; + } +} + kernel::LiteKernel *CpuArithmeticFp16KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *parameter, const lite::Context *ctx, @@ -473,5 +267,4 @@ REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticFp16Kernel REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticFp16KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticFp16KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticFp16KernelCreator) -REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Eltwise, CpuArithmeticFp16KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h index 5c8a52b77688a213668f6e9fa00b7cf7a091ab15..d3ec77e461457f8b6e4735a6b5921556f9b5f145 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h @@ -23,39 +23,46 @@ #include "schema/model_generated.h" namespace mindspore::kernel { -class ArithmeticFP16CPUKernel : public LiteKernel { - typedef int (*ArithmeticRun)(float16_t *input0, float16_t *input1, float16_t *output, int element_size); - typedef int (*ArithmeticOptRun)(float16_t *input0, float16_t *input1, float16_t *output, int element_size, - ArithmeticParameter *param); +typedef int (*ArithmeticFuncFp16)(float16_t *input0, float16_t *input1, float16_t *output, int element_size); +typedef int (*ArithmeticOptFuncFp16)(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +typedef struct { + int primitive_type_; + int activation_type_; + ArithmeticFuncFp16 func_; + ArithmeticOptFuncFp16 opt_func_; +} ARITHMETIC_FUNC_INFO_FP16; +class ArithmeticFP16CPUKernel : public LiteKernel { public: ArithmeticFP16CPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) : LiteKernel(parameter, inputs, outputs, ctx, primitive) { - arithmeticParameter_ = reinterpret_cast(parameter); + param_ = reinterpret_cast(parameter); } - ~ArithmeticFP16CPUKernel() override; + ~ArithmeticFP16CPUKernel() = default; int Init() override; int ReSize() override; int Run() override; int DoArithmetic(int task_id); int BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, int out_count, - int out_thread_stride); + int out_thread_stride); private: void FreeTmpBuffer(); int outside_; int break_pos_; - int out_thread_stride_; - int out_count_; + bool is_input0_fp32_ = false; + bool is_input1_fp32_ = false; + bool is_output_fp32_ = false; float16_t *input0_fp16_ = nullptr; float16_t *input1_fp16_ = nullptr; float16_t *output_fp16_ = nullptr; - ArithmeticParameter *arithmeticParameter_ = nullptr; - ArithmeticRun arithmetic_run_ = nullptr; - ArithmeticOptRun arithmetic_opt_run_ = nullptr; + ArithmeticParameter *param_ = nullptr; + ArithmeticFuncFp16 arithmetic_func_ = nullptr; + ArithmeticOptFuncFp16 arithmetic_opt_func_ = nullptr; }; } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_FP16_H_