Commit ea5463cd authored by tao_yunhao

modify arm cpu op: Arithmetic_fp16

Parent d921d853
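The change below drops the tile-buffer broadcasting path (tile_data0_/tile_data1_ plus TileDimensionsFp16) and replaces it with a recursive, stride-based walk over the output shape (broadcast_run_). The standalone sketch that follows illustrates that approach under stated assumptions: all names (BroadcastWalk, ElementAdd) are illustrative and not part of the kernel, and float is used in place of float16_t so the sketch compiles on any toolchain.

    // Minimal sketch of stride-based broadcasting for a binary element-wise op.
    // Instead of materializing tiled copies of both inputs, walk the output
    // shape recursively and run a plain loop on the contiguous inner block.
    #include <cstdio>

    static void ElementAdd(const float *in0, const float *in1, float *out, int n) {
      for (int i = 0; i < n; ++i) out[i] = in0[i] + in1[i];
    }

    // dim walks from the outermost dimension; break_pos is the last dimension in
    // which the two input shapes differ, so everything after it is contiguous.
    static void BroadcastWalk(const float *in0, const float *in1, float *out,
                              const int *shape0, const int *shape1, const int *out_shape,
                              const int *stride0, const int *stride1, const int *out_stride,
                              int dim, int break_pos, int inner_count) {
      if (dim > break_pos) {
        ElementAdd(in0, in1, out, inner_count);
        return;
      }
      for (int i = 0; i < out_shape[dim]; ++i) {
        int pos0 = shape0[dim] == 1 ? 0 : i;  // stay at index 0 on a broadcast dimension
        int pos1 = shape1[dim] == 1 ? 0 : i;
        BroadcastWalk(in0 + pos0 * stride0[dim], in1 + pos1 * stride1[dim],
                      out + i * out_stride[dim], shape0, shape1, out_shape,
                      stride0, stride1, out_stride, dim + 1, break_pos, inner_count);
      }
    }

    int main() {
      // input0 broadcasts along dim 1: {2, 1, 3} + {2, 2, 3} -> {2, 2, 3}
      const int shape0[] = {2, 1, 3}, shape1[] = {2, 2, 3}, out_shape[] = {2, 2, 3};
      const int stride0[] = {3, 3, 1}, stride1[] = {6, 3, 1}, out_stride[] = {6, 3, 1};
      float in0[6] = {0, 1, 2, 3, 4, 5};
      float in1[12] = {0};
      float out[12];
      // Shapes differ at dim 1, so break_pos = 1; the inner contiguous block has 3 elements.
      BroadcastWalk(in0, in1, out, shape0, shape1, out_shape, stride0, stride1, out_stride,
                    /*dim=*/0, /*break_pos=*/1, /*inner_count=*/3);
      for (int i = 0; i < 12; ++i) printf("%g ", out[i]);
      printf("\n");
      return 0;
    }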
@@ -29,6 +29,7 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Add;
using mindspore::schema::PrimitiveType_Div;
using mindspore::schema::PrimitiveType_Eltwise;
using mindspore::schema::PrimitiveType_Equal;
using mindspore::schema::PrimitiveType_FloorDiv;
using mindspore::schema::PrimitiveType_FloorMod;
@@ -172,8 +173,6 @@ int ArithmeticFP16CPUKernel::ReSize() {
MS_LOG(ERROR) << "malloc data fail!";
return RET_ERROR;
}
Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[0]->Data()), input0_fp16_,
arithmeticParameter_->in_elements_num0_);
}
if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
input1_fp16_ = reinterpret_cast<float16_t *>(
@@ -182,8 +181,6 @@ int ArithmeticFP16CPUKernel::ReSize() {
MS_LOG(ERROR) << "malloc data fail!";
return RET_ERROR;
}
Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[1]->Data()), input1_fp16_,
arithmeticParameter_->in_elements_num1_);
}
if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) {
output_fp16_ = reinterpret_cast<float16_t *>(
@@ -297,15 +294,33 @@ int ArithmeticFP16CPUKernel::ReSize() {
}
if (arithmeticParameter_->broadcasting_) {
auto tile_size = arithmeticParameter_->out_elements_num_ * sizeof(float16_t);
tile_data0_ = reinterpret_cast<float16_t *>(malloc(tile_size));
tile_data1_ = reinterpret_cast<float16_t *>(malloc(tile_size));
if (tile_data0_ == nullptr || tile_data1_ == nullptr) {
MS_LOG(ERROR) << "malloc tile data fail!";
return RET_ERROR;
outside_ = 1;
for (int i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) {
if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) {
break_pos_ = i;
break;
}
outside_ *= arithmeticParameter_->out_shape_[i];
}
ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_);
ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_);
ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_);
}
return RET_OK;
}
int ArithmeticFP16CPUKernel::broadcast_run_(float16_t *input0, float16_t *input1, float16_t *output, int dim) {
if (dim > break_pos_) {
return arithmetic_run_(input0 + out_thread_stride_, input1 + out_thread_stride_, output + out_thread_stride_,
out_count_);
}
for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) {
  int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i;
  int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
  int error_code = broadcast_run_(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim],
                                  input1 + pos1_ * arithmeticParameter_->in_strides1_[dim],
                                  output + i * arithmeticParameter_->out_strides_[dim], dim + 1);
  if (error_code != RET_OK) {
    return error_code;
  }
}
return RET_OK;
}
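A concrete read of the ReSize logic above, using hypothetical shapes: with in_shape0_ = {1, 1, 16} and in_shape1_ = {2, 3, 16}, the scan from the last dimension sees 16 == 16, so outside_ becomes 16, then stops at dimension 1 where 1 != 3, setting break_pos_ = 1. DoArithmetic then splits that 16-element contiguous inner block across threads (stride = UP_DIV(outside_, thread_num_)), and broadcast_run_ recurses over dimensions 0 through break_pos_ before calling arithmetic_run_ on each thread's slice.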
@@ -329,8 +344,10 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
int error_code = RET_OK;
if (arithmeticParameter_->broadcasting_) {
error_code =
arithmetic_run_(tile_data0_ + thread_stride, tile_data1_ + thread_stride, output_data + thread_stride, count);
stride = UP_DIV(outside_, context_->thread_num_);
out_count_ = MSMIN(stride, outside_ - stride * task_id);
out_thread_stride_ = stride * task_id;
error_code = broadcast_run_(input0_data, input1_data1, output_data, 0);
} else if (arithmetic_opt_run_ != nullptr) {
if (arithmeticParameter_->in_elements_num0_ == 1) {
error_code = arithmetic_opt_run_(input0_data, input1_data1 + thread_stride, output_data + thread_stride, count,
@@ -373,13 +390,15 @@ int ArithmeticFP16CPUKernel::Run() {
return ret;
}
if (arithmeticParameter_->broadcasting_) {
auto input_data0 = reinterpret_cast<float16_t *>(in_tensors_[0]->Data());
auto input_data1 = reinterpret_cast<float16_t *>(in_tensors_[1]->Data());
float16_t *input0 = input0_fp16_ == nullptr ? input_data0 : input0_fp16_;
float16_t *input1 = input1_fp16_ == nullptr ? input_data1 : input1_fp16_;
TileDimensionsFp16(input0, input1, tile_data0_, tile_data1_, arithmeticParameter_);
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[0]->Data()), input0_fp16_,
arithmeticParameter_->in_elements_num0_);
}
if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[1]->Data()), input1_fp16_,
arithmeticParameter_->in_elements_num1_);
}
ret = LiteBackendParallelLaunch(ArithmeticsRun, this, context_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Arithmetic function fail!ret: " << ret;
@@ -428,4 +447,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticFp16Kernel
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Eltwise, CpuArithmeticFp16KernelCreator)
} // namespace mindspore::kernel
@@ -30,20 +30,25 @@ class ArithmeticFP16CPUKernel : public LiteKernel {
public:
ArithmeticFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
}
arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
}
~ArithmeticFP16CPUKernel() override;
int Init() override;
int ReSize() override;
int Run() override;
int DoArithmetic(int task_id);
int broadcast_run_(float16_t *input0, float16_t *input1, float16_t *output, int dim);
private:
void FreeTmpBuffer();
int break_pos_;
int outside_;
int out_thread_stride_;
int out_count_;
float16_t *tile_data0_ = nullptr;
float16_t *tile_data1_ = nullptr;
float16_t *input0_fp16_ = nullptr;
......
@@ -18,33 +18,6 @@
#include <math.h>
#include "nnacl/arithmetic_common.h"
void TileOneDimensionFp16(float16_t *inData, float16_t *outData, int dim, size_t ndim, int *inShape, int *inStrides,
int *outStrides, int *multiple) {
int srcDimSize = inShape[dim];
if (dim == ndim - 1) {
for (int i = 0; i < multiple[dim]; i++) {
memcpy(outData, inData, srcDimSize * sizeof(float16_t));
outData += srcDimSize;
}
return;
}
for (size_t i = 0; i < srcDimSize; i++) {
for (size_t j = 0; j < multiple[dim]; j++) {
TileOneDimensionFp16(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, ndim,
inShape, inStrides, outStrides, multiple);
}
}
}
void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1,
ArithmeticParameter *param) {
CalcMultiplesAndStrides(param);
TileOneDimensionFp16(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_,
param->multiples0_);
TileOneDimensionFp16(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_,
param->multiples1_);
}
int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
......
@@ -111,8 +111,6 @@ int ElementLessEqual(float16_t *input0, float16_t *input1, float16_t *output, in
int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1,
ArithmeticParameter *param);
#ifdef __cplusplus
}
#endif
......