diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc
index 35334f005439dbc8d2909eeacc1be2ef4acbd958..cb487411bf0a4699526eed65e5e5b5ec1600c975 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc
@@ -162,34 +162,9 @@ int ArithmeticFP16CPUKernel::Init() {
 }
 
 int ArithmeticFP16CPUKernel::ReSize() {
-  FreeTmpBuffer();
   arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
   arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
   arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
-  if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
-    input0_fp16_ = reinterpret_cast<float16_t *>(
-      context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
-    if (input0_fp16_ == nullptr) {
-      MS_LOG(ERROR) << "malloc data fail!";
-      return RET_ERROR;
-    }
-  }
-  if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
-    input1_fp16_ = reinterpret_cast<float16_t *>(
-      context_->allocator->Malloc(arithmeticParameter_->in_elements_num1_ * sizeof(float16_t)));
-    if (input0_fp16_ == nullptr) {
-      MS_LOG(ERROR) << "malloc data fail!";
-      return RET_ERROR;
-    }
-  }
-  if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) {
-    output_fp16_ = reinterpret_cast<float16_t *>(
-      context_->allocator->Malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
-    if (output_fp16_ == nullptr) {
-      MS_LOG(ERROR) << "malloc data fail!";
-      return RET_ERROR;
-    }
-  }
 
   if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
     switch (arithmeticParameter_->op_parameter_.type_) {
@@ -292,20 +267,6 @@ int ArithmeticFP16CPUKernel::ReSize() {
       break;
     }
   }
-
-  if (arithmeticParameter_->broadcasting_) {
-    outside_ = 1;
-    for (int i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) {
-      if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) {
-        break_pos_ = i;
-        break;
-      }
-      outside_ *= arithmeticParameter_->out_shape_[i];
-    }
-    ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_);
-    ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_);
-    ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_);
-  }
   return RET_OK;
 }
 
@@ -344,10 +305,8 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
 
   int error_code = RET_OK;
   if (arithmeticParameter_->broadcasting_) {
-    stride = UP_DIV(outside_, context_->thread_num_);
-    out_count_ = MSMIN(stride, outside_ - stride * task_id);
-    out_thread_stride_ = stride * task_id;
-    error_code = broadcast_run_(input0_data, input1_data1, output_data, 0);
+    error_code =
+      arithmetic_run_(tile_data0_ + thread_stride, tile_data1_ + thread_stride, output_data + thread_stride, count);
   } else if (arithmetic_opt_run_ != nullptr) {
     if (arithmeticParameter_->in_elements_num0_ == 1) {
       error_code = arithmetic_opt_run_(input0_data, input1_data1 + thread_stride, output_data + thread_stride, count,
@@ -364,6 +323,7 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
       arithmetic_run_(input0_data + thread_stride, input1_data1 + thread_stride, output_data + thread_stride, count);
   }
   if (error_code != RET_OK) {
+    FreeTmpBuffer();
     return RET_ERROR;
   }
   if (output_fp16_ != nullptr) {
@@ -390,6 +350,37 @@ int ArithmeticFP16CPUKernel::Run() {
     return ret;
   }
 
+  arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
+  arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
+  arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
+  if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
+    input0_fp16_ = reinterpret_cast<float16_t *>(
+      context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
+    if (input0_fp16_ == nullptr) {
+      MS_LOG(ERROR) << "malloc data fail!";
+      FreeTmpBuffer();
+      return RET_ERROR;
+    }
+  }
+  if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
+    input1_fp16_ = reinterpret_cast<float16_t *>(
+      context_->allocator->Malloc(arithmeticParameter_->in_elements_num1_ * sizeof(float16_t)));
+    if (input1_fp16_ == nullptr) {
+      MS_LOG(ERROR) << "malloc data fail!";
+      FreeTmpBuffer();
+      return RET_ERROR;
+    }
+  }
+  if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) {
+    output_fp16_ = reinterpret_cast<float16_t *>(
+      context_->allocator->Malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
+    if (output_fp16_ == nullptr) {
+      MS_LOG(ERROR) << "malloc data fail!";
+      FreeTmpBuffer();
+      return RET_ERROR;
+    }
+  }
+
   if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
     Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[0]->Data()), input0_fp16_,
                      arithmeticParameter_->in_elements_num0_);
@@ -399,9 +390,33 @@ int ArithmeticFP16CPUKernel::Run() {
                      arithmeticParameter_->in_elements_num1_);
   }
 
+  if (arithmeticParameter_->broadcasting_) {
+    auto tile_size = arithmeticParameter_->out_elements_num_ * sizeof(float16_t);
+    tile_data0_ = reinterpret_cast<float16_t *>(malloc(tile_size));
+    if (tile_data0_ == nullptr) {
+      MS_LOG(ERROR) << "malloc data fail!";
+      FreeTmpBuffer();
+      return RET_ERROR;
+    }
+    tile_data1_ = reinterpret_cast<float16_t *>(malloc(tile_size));
+    if (tile_data1_ == nullptr) {
+      MS_LOG(ERROR) << "malloc data fail!";
+      FreeTmpBuffer();
+      return RET_ERROR;
+    }
+    auto input0 = reinterpret_cast<float16_t *>(in_tensors_[0]->Data());
+    auto input1 = reinterpret_cast<float16_t *>(in_tensors_[1]->Data());
+
+    float16_t *input0_data = input0_fp16_ == nullptr ? input0 : input0_fp16_;
+    float16_t *input1_data1 = input1_fp16_ == nullptr ? input1 : input1_fp16_;
+
+    TileDimensionsFp16(input0_data, input1_data1, tile_data0_, tile_data1_, arithmeticParameter_);
+  }
+
   ret = LiteBackendParallelLaunch(ArithmeticsRun, this, context_->thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Arithmetic function fail!ret: " << ret;
+    FreeTmpBuffer();
     return ret;
   }
   return RET_OK;
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.c
index 35f8be39070cb91a50f81f76417d5b24ce2d4012..a801e936213121d84938c33bf2cb1b8f447702ef 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.c
@@ -18,6 +18,33 @@
 #include <math.h>
 #include "nnacl/arithmetic_common.h"
 
+void TileOneDimensionFp16(float16_t *inData, float16_t *outData, int dim, size_t ndim, int *inShape, int *inStrides,
+                          int *outStrides, int *multiple) {
+  int srcDimSize = inShape[dim];
+  if (dim == ndim - 1) {
+    for (int i = 0; i < multiple[dim]; i++) {
+      memcpy(outData, inData, srcDimSize * sizeof(float16_t));
+      outData += srcDimSize;
+    }
+    return;
+  }
+  for (size_t i = 0; i < srcDimSize; i++) {
+    for (size_t j = 0; j < multiple[dim]; j++) {
+      TileOneDimensionFp16(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, ndim,
+                           inShape, inStrides, outStrides, multiple);
+    }
+  }
+}
+
+void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1,
+                        ArithmeticParameter *param) {
+  CalcMultiplesAndStrides(param);
+  TileOneDimensionFp16(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_,
+                       param->multiples0_);
+  TileOneDimensionFp16(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_,
+                       param->multiples1_);
+}
+
 int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
   int block_mod = element_size % C8NUM;
   int block_c8 = element_size - block_mod;
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h
index 5bffd41e5cb1cf7cfafbb047166f4cfe63b80131..25aaa01d4c2d06e82bfc30981921b7be56d4ea73 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h
@@ -111,6 +111,8 @@ int ElementLessEqual(float16_t *input0, float16_t *input1, float16_t *output, in
 int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
 int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
+void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1,
+                        ArithmeticParameter *param);
 
 #ifdef __cplusplus
 }
 #endif
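
Reviewer note, not part of the patch: the change drops the ReSize()-time broadcast bookkeeping (outside_, break_pos_, ComputeStrides and the broadcast_run_ path) in favor of an explicit pre-tiling step in Run(). TileDimensionsFp16() expands both operands to the full output shape, so DoArithmetic() can run the plain element-wise arithmetic_run_ on every thread slice. The standalone C sketch below illustrates only the tiling step on made-up {2,1} and {1,3} inputs; it assumes the ArithmeticParameter layout implied by the diff (fixed-size int shape arrays plus ndim_, as declared in nnacl/arithmetic_common.h) and that the nnacl fp16 headers are on the include path.

/* Illustrative sketch only -- not part of the diff above. */
#include <stdio.h>
#include <string.h>
#include "nnacl/fp16/arithmetic_fp16.h"

int main(void) {
  float16_t in0[2] = {1, 2};        /* logical shape {2, 1} */
  float16_t in1[3] = {10, 20, 30};  /* logical shape {1, 3} */
  float16_t tile0[6], tile1[6];     /* both expanded to the {2, 3} output */

  ArithmeticParameter param;
  memset(&param, 0, sizeof(param));
  param.ndim_ = 2;
  param.in_shape0_[0] = 2;
  param.in_shape0_[1] = 1;
  param.in_shape1_[0] = 1;
  param.in_shape1_[1] = 3;
  param.out_shape_[0] = 2;
  param.out_shape_[1] = 3;

  /* CalcMultiplesAndStrides() runs inside and derives the strides and the
     per-dimension repeat counts (multiples0_/multiples1_) from the shapes. */
  TileDimensionsFp16(in0, in1, tile0, tile1, &param);

  /* tile0 is now {1,1,1,2,2,2} and tile1 is {10,20,30,10,20,30}; the
     element-wise kernels can then process two equal-length buffers. */
  for (int i = 0; i < 6; ++i) {
    printf("%g + %g = %g\n", (float)tile0[i], (float)tile1[i],
           (float)(tile0[i] + tile1[i]));
  }
  return 0;
}

The trade-off is memory for simplicity: Run() now allocates two output-sized scratch buffers (tile_data0_/tile_data1_) per invocation and releases them through FreeTmpBuffer() on every early-exit path, in exchange for removing the per-task broadcast stride math from the hot loop.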