提交 4e4ad85c 编写于 作者: Y yangruoqi713

[MS][LITE] fix bug of arm cpu fp16 infer: set subgraph output tensor data_type float32

上级 fc9161a4
......@@ -21,6 +21,7 @@
#include "include/errorcode.h"
#include "nnacl/op_base.h"
#include "nnacl/fp16/cast_fp16.h"
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
......@@ -29,29 +30,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Pooling;
namespace mindspore::kernel {
int PoolingFp16CPUKernel::InitBuffer() {
int in_batch = pooling_param_->input_batch_;
int in_h = pooling_param_->input_h_;
int in_w = pooling_param_->input_w_;
int in_channel = pooling_param_->input_channel_;
fp16_input_ = reinterpret_cast<float16_t *>(malloc(in_batch * in_h * in_w * in_channel * sizeof(float16_t)));
if (fp16_input_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_input_ failed.";
return RET_ERROR;
}
int out_batch = pooling_param_->output_batch_;
int out_h = pooling_param_->output_h_;
int out_w = pooling_param_->output_w_;
int out_channel = pooling_param_->output_channel_;
fp16_output_ = reinterpret_cast<float16_t *>(malloc(out_batch * out_h * out_w * out_channel * sizeof(float16_t)));
if (fp16_output_ == nullptr) {
MS_LOG(ERROR) << "fp16_out malloc failed.";
return RET_ERROR;
}
return RET_OK;
}
int PoolingFp16CPUKernel::Init() {
auto ret = PoolingBaseCPUKernel::Init();
if (ret != RET_OK) {
......@@ -71,12 +49,6 @@ int PoolingFp16CPUKernel::ReSize() {
MS_LOG(ERROR) << "PoolingBase ReSize fai1!ret: " << ret;
return ret;
}
ret = InitBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init Buffer fail!ret: " << ret;
return ret;
}
return RET_OK;
}
......@@ -105,9 +77,16 @@ int PoolingFp16CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
auto ele_num = in_tensors_.front()->ElementsNum();
auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->Data());
Float32ToFloat16(input_ptr, fp16_input_, ele_num);
auto input_tensor = in_tensors_.at(kInputIndex);
auto in_data_type_ = input_tensor->data_type();
MS_ASSERT(in_data_type_ == kNumberTypeFloat32 || in_data_type_ == kNumberTypeFloat16);
fp16_input_ = ConvertInputFp32toFp16(input_tensor, context_);
auto out_tensor = out_tensors_.at(kOutputIndex);
auto out_data_type_ = out_tensor->data_type();
MS_ASSERT(out_data_type_ == kNumberTypeFloat32 || out_data_type_ == kNumberTypeFloat16);
fp16_output_ = MallocOutputFp16(out_tensor, context_);
int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, PoolingFp16Impl, this, thread_count_);
if (error_code != RET_OK) {
......@@ -115,9 +94,15 @@ int PoolingFp16CPUKernel::Run() {
return RET_ERROR;
}
auto out_ele_num = out_tensors_.front()->ElementsNum();
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
Float16ToFloat32(fp16_output_, output_ptr, out_ele_num);
if (in_data_type_ == kNumberTypeFloat32) {
context_->allocator->Free(fp16_input_);
}
if (out_data_type_ == kNumberTypeFloat32) {
auto out_ele_num = out_tensor->ElementsNum();
auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
Float16ToFloat32(fp16_output_, output_addr, out_ele_num);
context_->allocator->Free(fp16_output_);
}
return RET_OK;
}
......
......@@ -28,17 +28,9 @@ class PoolingFp16CPUKernel : public PoolingBaseCPUKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: PoolingBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~PoolingFp16CPUKernel() override {
if (fp16_input_ != nullptr) {
free(fp16_input_);
}
if (fp16_output_ != nullptr) {
free(fp16_output_);
}
};
~PoolingFp16CPUKernel() override = default;
int Init() override;
int InitBuffer();
int ReSize() override;
int Run() override;
int RunImpl(int task_id);
......
......@@ -182,9 +182,12 @@ void Scheduler::ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels) {
for (auto kernel : temp_kernels) {
for (auto tensor : kernel->out_tensors()) {
tensor->set_allocator(context_->allocator.get());
if (context_->float16_priority && tensor->data_type() == kNumberTypeFloat16) {
tensor->set_data_type(kNumberTypeFloat32);
}
}
}
std::vector<tensor::Tensor *> output_tensor = kernel::LiteKernelUtil::SubgraphOutputTensors(temp_kernels);
for (auto tensor : output_tensor) {
if (context_->float16_priority && tensor->data_type() == kNumberTypeFloat16) {
tensor->set_data_type(kNumberTypeFloat32);
}
}
std::copy(temp_kernels.begin(), temp_kernels.end(), std::back_inserter(subgraph_kernels));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部