Commit 5613116c, authored by mindspore-ci-bot, committed by Gitee

!4201 modify CPU op arithmetic fp16

Merge pull request !4201 from 陶云浩/test
@@ -48,14 +48,6 @@ using mindspore::schema::PrimitiveType_Sub;
 namespace mindspore::kernel {
 void ArithmeticFP16CPUKernel::FreeTmpBuffer() {
-  if (tile_data0_ != nullptr) {
-    free(tile_data0_);
-    tile_data0_ = nullptr;
-  }
-  if (tile_data1_ != nullptr) {
-    free(tile_data1_);
-    tile_data1_ = nullptr;
-  }
   if (input0_fp16_ != nullptr) {
     context_->allocator->Free(input0_fp16_);
     input0_fp16_ = nullptr;
@@ -70,7 +62,7 @@ void ArithmeticFP16CPUKernel::FreeTmpBuffer() {
   }
 }
-ArithmeticFP16CPUKernel::~ArithmeticFP16CPUKernel() { FreeTmpBuffer(); }
+ArithmeticFP16CPUKernel::~ArithmeticFP16CPUKernel() {}
 int ArithmeticFP16CPUKernel::Init() {
   switch (op_parameter_->type_) {
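With this patch the fp16 staging buffers live only for the duration of a single `Run()` call (note the `FreeTmpBuffer()` added right after `ParallelLaunch` further down), so the destructor, `ReSize()`, and `DoArithmetic()` no longer free them. For illustration only, here is a scope-guard sketch that would keep every early-return path in `Run()` leak-free without the repeated explicit `FreeTmpBuffer()` calls; the guard class is hypothetical, not part of this patch, and assumes `FreeTmpBuffer()` is accessible to it:

```cpp
// Hypothetical RAII sketch, not the patch's code: frees the staging buffers
// on every exit path of the enclosing scope, including early error returns.
class TmpBufferGuard {
 public:
  explicit TmpBufferGuard(ArithmeticFP16CPUKernel *kernel) : kernel_(kernel) {}
  ~TmpBufferGuard() { kernel_->FreeTmpBuffer(); }  // runs on scope exit
  TmpBufferGuard(const TmpBufferGuard &) = delete;
  TmpBufferGuard &operator=(const TmpBufferGuard &) = delete;

 private:
  ArithmeticFP16CPUKernel *kernel_;
};
```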
@@ -162,7 +154,6 @@ int ArithmeticFP16CPUKernel::Init() {
 }
 int ArithmeticFP16CPUKernel::ReSize() {
-  FreeTmpBuffer();
   arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
   arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
   arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
@@ -286,7 +277,7 @@ int ArithmeticFP16CPUKernel::ReSize() {
 }
 int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim,
-                                        int out_count, int out_thread_stride) {
+                                          int out_count, int out_thread_stride) {
   if (dim > break_pos_) {
     int error_code =
       arithmetic_run_(input0 + out_thread_stride, input1 + out_thread_stride, output + out_thread_stride, out_count);
@@ -303,8 +294,8 @@ int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1,
     int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
     int error_code =
       BroadcastRun(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim],
-                 input1 + pos1_ * arithmeticParameter_->in_strides1_[dim],
-                 output + i * arithmeticParameter_->out_strides_[dim], dim + 1, out_count, out_thread_stride);
+                   input1 + pos1_ * arithmeticParameter_->in_strides1_[dim],
+                   output + i * arithmeticParameter_->out_strides_[dim], dim + 1, out_count, out_thread_stride);
     if (error_code != RET_OK) {
       return RET_ERROR;
     }
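This hunk is the heart of the broadcasting scheme: at each dimension, the input whose shape is 1 along that axis keeps offset 0 (`pos = 0`), so its single slice is reused, while the other input and the output advance by their per-dimension strides; once `dim` passes `break_pos_`, the remaining inner block is contiguous and is handed to the element-wise `arithmetic_run_` routine in one call. A minimal self-contained sketch of the same recursion, with illustrative names and plain `float` in place of `float16_t`:

```cpp
// Minimal sketch of stride-based broadcasting, mirroring BroadcastRun above.
// All names are illustrative; the real kernel stops recursing at break_pos_
// and lets an element-wise routine consume the contiguous inner block.
void BroadcastAdd(const float *in0, const float *in1, float *out,
                  const int *shape0, const int *shape1, const int *out_shape,
                  const int *stride0, const int *stride1, const int *out_stride,
                  int dim, int ndim) {
  if (dim == ndim) {  // scalar position reached: apply the element-wise op
    *out = *in0 + *in1;
    return;
  }
  for (int i = 0; i < out_shape[dim]; ++i) {
    // A size-1 axis contributes offset 0, so its data is broadcast (reused).
    int p0 = (shape0[dim] == 1) ? 0 : i;
    int p1 = (shape1[dim] == 1) ? 0 : i;
    BroadcastAdd(in0 + p0 * stride0[dim], in1 + p1 * stride1[dim],
                 out + i * out_stride[dim], shape0, shape1, out_shape,
                 stride0, stride1, out_stride, dim + 1, ndim);
  }
}
```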
@@ -327,7 +318,6 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
   if (arithmetic_run_ == nullptr) {
     MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
-    FreeTmpBuffer();
     return RET_ERROR;
   }
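For context, `arithmetic_run_` is a function pointer that `Init()` fills in by switching on `op_parameter_->type_` (the hunk at old line 162 above sits inside that switch), which is what lets one kernel class serve Mul, Add, Sub, and the rest. A hedged sketch of that dispatch idiom, using placeholder routine names rather than MindSpore's actual nnacl symbols:

```cpp
// Sketch of function-pointer dispatch: select the element-wise routine once,
// then invoke it from the hot path. Names below are placeholders.
typedef int (*ArithmeticRun)(const float *in0, const float *in1, float *out, int n);

static int ElementAdd(const float *a, const float *b, float *o, int n) {
  for (int i = 0; i < n; ++i) o[i] = a[i] + b[i];
  return 0;
}

static int ElementMul(const float *a, const float *b, float *o, int n) {
  for (int i = 0; i < n; ++i) o[i] = a[i] * b[i];
  return 0;
}

enum OpType { kAdd, kMul };

ArithmeticRun SelectRun(OpType type) {
  switch (type) {
    case kAdd:
      return ElementAdd;
    case kMul:
      return ElementMul;
    default:
      return nullptr;  // callers must check, exactly as DoArithmetic does above
  }
}
```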
@@ -383,8 +373,7 @@ int ArithmeticFP16CPUKernel::Run() {
   arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
   arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
   if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) {
-    output_fp16_ = reinterpret_cast<float16_t *>(
-      context_->allocator->Malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
+    output_fp16_ = reinterpret_cast<float16_t *>(malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
     if (output_fp16_ == nullptr) {
       MS_LOG(ERROR) << "malloc data fail!";
       FreeTmpBuffer();
@@ -392,8 +381,7 @@
     }
   }
   if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
-    input0_fp16_ = reinterpret_cast<float16_t *>(
-      context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
+    input0_fp16_ = reinterpret_cast<float16_t *>(malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
     if (input0_fp16_ == nullptr) {
       MS_LOG(ERROR) << "malloc data fail!";
       FreeTmpBuffer();
@@ -403,8 +391,7 @@
                      arithmeticParameter_->in_elements_num0_);
   }
   if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
-    input1_fp16_ = reinterpret_cast<float16_t *>(
-      context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
+    input1_fp16_ = reinterpret_cast<float16_t *>(malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
     if (input1_fp16_ == nullptr) {
       MS_LOG(ERROR) << "malloc data fail!";
       FreeTmpBuffer();
@@ -414,6 +401,7 @@
                      arithmeticParameter_->in_elements_num1_);
   }
   ret = ParallelLaunch(THREAD_POOL_DEFAULT, ArithmeticsRun_Fp16, this, context_->thread_num_);
+  FreeTmpBuffer();
   return ret;
 }
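The `Run()` hunks above implement an fp32-to-fp16 staging pattern: any tensor that arrives as fp32 gets a scratch fp16 buffer, inputs are converted in, the fp16 op runs in parallel, and the single `FreeTmpBuffer()` after `ParallelLaunch` releases everything. Two review observations, both hedged: the scratch buffers are now obtained with plain `malloc` while the unchanged `FreeTmpBuffer()` releases them through `context_->allocator->Free`, which is only safe if that allocator tolerates pointers it did not allocate; and `input1_fp16_` is sized with `in_elements_num0_` but filled with `in_elements_num1_` elements, which looks suspicious unless those counts are guaranteed equal here. A self-contained sketch of the staging lifecycle, assuming a `float16_t`-capable target (e.g. ARM with `arm_fp16.h`); the helper names are illustrative:

```cpp
#include <cstdlib>

#include <arm_fp16.h>  // assumption: an ARM target where float16_t is available

// Sketch of the staging pattern Run() follows after this patch: allocate fp16
// scratch for fp32 tensors, convert in, compute in fp16, convert out, and
// free everything at one point, as the FreeTmpBuffer() after ParallelLaunch does.
static void Fp32ToFp16(const float *src, float16_t *dst, int n) {
  for (int i = 0; i < n; ++i) dst[i] = static_cast<float16_t>(src[i]);
}

static void Fp16ToFp32(const float16_t *src, float *dst, int n) {
  for (int i = 0; i < n; ++i) dst[i] = static_cast<float>(src[i]);
}

int RunFp16Add(const float *in0, const float *in1, float *out, int n) {
  auto *a = static_cast<float16_t *>(malloc(n * sizeof(float16_t)));
  auto *b = static_cast<float16_t *>(malloc(n * sizeof(float16_t)));
  auto *o = static_cast<float16_t *>(malloc(n * sizeof(float16_t)));
  if (a == nullptr || b == nullptr || o == nullptr) {
    free(a);
    free(b);
    free(o);  // free(nullptr) is a no-op, so partial failures are safe
    return -1;
  }
  Fp32ToFp16(in0, a, n);
  Fp32ToFp16(in1, b, n);
  for (int i = 0; i < n; ++i) o[i] = a[i] + b[i];  // stand-in for the fp16 op
  Fp16ToFp32(o, out, n);
  free(a);
  free(b);
  free(o);  // single cleanup point, like FreeTmpBuffer() in the kernel
  return 0;
}
```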
@@ -441,21 +429,21 @@ kernel::LiteKernel *CpuArithmeticFp16KernelCreator(const std::vector<lite::tenso
   return kernel;
 }
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Mul, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Add, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Sub, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Div, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorMod, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorDiv, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalAnd, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalOr, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Maximum, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Minimum, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_NotEqual, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Equal, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Eltwise, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Mul, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Add, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Sub, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Div, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorMod, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorDiv, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalAnd, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalOr, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Maximum, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Minimum, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_NotEqual, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Equal, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Eltwise, CpuArithmeticFp16KernelCreator)
 }  // namespace mindspore::kernel
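The final .cc hunk switches the seventeen REG_KERNEL lines from commented-out to live, which is what actually makes these fp16 kernels discoverable by the runtime's kernel factory. REG_KERNEL is the registration-macro style in which a static object's constructor inserts a creator function into a registry keyed by backend, data type, and primitive type; a hedged sketch of that idiom follows, with an illustrative registry rather than MindSpore's real API. The header hunk below then drops the now-unused tile_data0_/tile_data1_ members from the class declaration, matching the FreeTmpBuffer() change in the first hunk.

```cpp
#include <functional>
#include <map>
#include <tuple>
#include <utility>

// Illustrative sketch of the static-registrar idiom behind macros like
// REG_KERNEL; the registry layout and names are not MindSpore's actual API.
using KernelCreator = std::function<void *()>;
using KernelKey = std::tuple<int /*arch*/, int /*dtype*/, int /*op*/>;

std::map<KernelKey, KernelCreator> &Registry() {
  static std::map<KernelKey, KernelCreator> registry;  // built on first use
  return registry;
}

struct KernelRegistrar {
  KernelRegistrar(int arch, int dtype, int op, KernelCreator creator) {
    // Runs during static initialization, before main(): commenting a
    // registration line out removes the kernel from the registry entirely.
    Registry().emplace(KernelKey{arch, dtype, op}, std::move(creator));
  }
};

#define REG_KERNEL_SKETCH(arch, dtype, op, creator) \
  static KernelRegistrar g_##op##_registrar(arch, dtype, op, creator)
```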
@@ -50,8 +50,6 @@ class ArithmeticFP16CPUKernel : public LiteKernel {
   int break_pos_;
   int out_thread_stride_;
   int out_count_;
-  float16_t *tile_data0_ = nullptr;
-  float16_t *tile_data1_ = nullptr;
   float16_t *input0_fp16_ = nullptr;
   float16_t *input1_fp16_ = nullptr;
   float16_t *output_fp16_ = nullptr;