Commit 77287093 authored by sunsuodong

batch_norm_fp16

Parent 0bbce936
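This change reworks the fp16 BatchNorm CPU kernel so that float32 inputs are converted once per inference instead of on every parallel task: InitConstTensor() converts the fp32 mean and variance tensors into kernel-owned float16 buffers, Run() allocates and fills fp16 input/output scratch buffers around a single parallel launch, and DoExecute() shrinks to one BatchNormFp16 call on the pre-converted buffers. The fp16 kernel registration, previously commented out, is now enabled.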
@@ -23,44 +23,75 @@ using mindspore::lite::KernelRegistrar;
 using mindspore::schema::PrimitiveType_BatchNorm;
 namespace mindspore::kernel {
-int BatchnormFp16CPUKernel::DoExecute(int task_id) {
-  auto param = reinterpret_cast<BatchNormParameter *>(op_parameter_);
-  if (in_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
-    auto input = in_tensors_.at(0);
-    auto mean = in_tensors_.at(1);
-    auto variance = in_tensors_.at(2);
-    auto output = out_tensors_.at(0);
+int BatchnormFp16CPUKernel::InitConstTensor() {
+  isFloat32Tensor_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
+  if (isFloat32Tensor_) {
+    auto mean_fp32 = in_tensors_.at(1);
+    auto variance_fp32 = in_tensors_.at(2);
+    mean_ = malloc(mean_fp32->ElementsNum() * sizeof(float16_t));
+    variance_ = malloc(variance_fp32->ElementsNum() * sizeof(float16_t));
+    if (mean_ == nullptr || variance_ == nullptr) {
+      FreeMeanAndVariance();
+      return RET_ERROR;
+    }
+    Float32ToFloat16(reinterpret_cast<float *>(mean_fp32->Data()),
+                     reinterpret_cast<float16_t *>(mean_), mean_fp32->ElementsNum());
+    Float32ToFloat16(reinterpret_cast<float *>(variance_fp32->Data()),
+                     reinterpret_cast<float16_t *>(variance_), variance_fp32->ElementsNum());
+  } else {
+    BatchnormCPUKernel::InitConstTensor();
+  }
+  return RET_OK;
+}
-    auto input_fp16 = context_->allocator->Malloc(input->ElementsNum() * sizeof(float16_t));
-    auto mean_fp16 = context_->allocator->Malloc(mean->ElementsNum() * sizeof(float16_t));
-    auto variance_fp16 = context_->allocator->Malloc(variance->ElementsNum() * sizeof(float16_t));
-    auto output_fp16 = context_->allocator->Malloc(output->ElementsNum() * sizeof(float16_t));
-    if (input_fp16 == nullptr || mean_fp16 == nullptr || variance_fp16 == nullptr || output_fp16 == nullptr) {
-      context_->allocator->Free(input_fp16);
-      context_->allocator->Free(mean_fp16);
-      context_->allocator->Free(variance_fp16);
-      context_->allocator->Free(output_fp16);
+int BatchnormFp16CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret;
+    return ret;
+  }
+  auto input_fp32 = in_tensors_.at(0);
+  auto output_fp32 = out_tensors_.at(0);
+  if (isFloat32Tensor_) {
+    input_ = context_->allocator->Malloc(input_fp32->ElementsNum() * sizeof(float16_t));
+    output_ = context_->allocator->Malloc(output_fp32->ElementsNum() * sizeof(float16_t));
+    if (input_ == nullptr || output_ == nullptr) {
+      FreeInputAndOutput();
+      return RET_ERROR;
+    }
-    Float32ToFloat16(reinterpret_cast<float *>(input->Data()),
-                     reinterpret_cast<float16_t *>(input_fp16), input->ElementsNum());
-    Float32ToFloat16(reinterpret_cast<float *>(mean->Data()),
-                     reinterpret_cast<float16_t *>(mean_fp16), mean->ElementsNum());
-    Float32ToFloat16(reinterpret_cast<float *>(variance->Data()),
-                     reinterpret_cast<float16_t *>(variance_fp16), variance->ElementsNum());
+    Float32ToFloat16(reinterpret_cast<float *>(input_fp32->Data()),
+                     reinterpret_cast<float16_t *>(input_), input_fp32->ElementsNum());
+  } else {
+    input_ = in_tensors_.at(0)->Data();
+    output_ = out_tensors_.at(0)->Data();
+  }
+  ret = LiteBackendParallelLaunch(BatchNormRun, this, op_parameter_->thread_num_);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "BatchnormRun error error_code[" << ret << "]";
+  }
+  if (isFloat32Tensor_) {
+    Float16ToFloat32(reinterpret_cast<float16_t *>(output_), reinterpret_cast<float *>(output_fp32->Data()),
+                     output_fp32->ElementsNum());
+    FreeInputAndOutput();
+  }
+  return ret;
+}
-    BatchNormFp16(input_fp16, mean_fp16, variance_fp16, param, task_id, output_fp16);
+int BatchnormFp16CPUKernel::DoExecute(int task_id) {
+  auto param = reinterpret_cast<BatchNormParameter *>(op_parameter_);
+  BatchNormFp16(input_, mean_, variance_, param, task_id, output_);
+  return mindspore::lite::RET_OK;
+}
-    Float16ToFloat32(reinterpret_cast<float16_t *>(output_fp16), reinterpret_cast<float *>(output),
-                     output->ElementsNum());
-    context_->allocator->Free(input_fp16);
-    context_->allocator->Free(mean_fp16);
-    context_->allocator->Free(variance_fp16);
-    context_->allocator->Free(output_fp16);
-    return mindspore::lite::RET_OK;
+void BatchnormFp16CPUKernel::FreeInputAndOutput() {
+  if (input_ != nullptr) {
+    context_->allocator->Free(input_);
+    input_ = nullptr;
+  }
+  if (output_ != nullptr) {
+    context_->allocator->Free(output_);
+    output_ = nullptr;
+  }
-  BatchNormFp16(in_tensors_.at(0)->Data(), mean_, variance_, param, task_id, out_tensors_.at(0)->Data());
-  return mindspore::lite::RET_OK;
 }
 kernel::LiteKernel *CpuBatchnormFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
@@ -83,5 +114,5 @@ kernel::LiteKernel *CpuBatchnormFp16KernelCreator(const std::vector<lite::tensor
   return kernel;
 }
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_BatchNorm, CpuBatchnormFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_BatchNorm, CpuBatchnormFp16KernelCreator)
 }  // namespace mindspore::kernel
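Run() schedules BatchNormRun across op_parameter_->thread_num_ threads, but the callback itself is outside this diff. In MindSpore Lite kernels of this generation such callbacks are typically thin trampolines back into the kernel's DoExecute; a minimal sketch under that assumption (the LiteParallelGroupEnv parameter and the exact callback signature are assumptions, not shown in this commit):

```cpp
// Hypothetical sketch of the BatchNormRun trampoline (not part of this diff):
// the parallel runtime invokes it once per task_id, and it forwards to the
// kernel's DoExecute, which reads the pre-converted input_/mean_/variance_.
int BatchNormRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
  auto kernel = reinterpret_cast<BatchnormFp16CPUKernel *>(cdata);
  auto ret = kernel->DoExecute(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "BatchNormRun error task_id[" << task_id << "] error_code[" << ret << "]";
  }
  return ret;
}
```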
@@ -29,7 +29,15 @@ class BatchnormFp16CPUKernel : public BatchnormCPUKernel {
       : BatchnormCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   virtual ~BatchnormFp16CPUKernel() {}
-  virtual int DoExecute(int task_id);
+  int Run() override;
+  int InitConstTensor() override;
+  int DoExecute(int task_id) override;
+
+ private:
+  void FreeInputAndOutput();
+  bool isFloat32Tensor_ = false;
+  void *input_ = nullptr;
+  void *output_ = nullptr;
 };
 }  // namespace mindspore::kernel
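For reference, the arithmetic BatchNormFp16 performs on the pre-converted buffers is inference-mode batch normalization, out = (in - mean) / sqrt(var + eps), applied channel-wise with the spatial units split across task ids. A standalone sketch — the NHWC layout, the epsilon_/units_/channel_ field names, and the interleaved task split are illustrative assumptions, not confirmed by this diff:

```cpp
#include <arm_neon.h>  // float16_t (requires an fp16-capable ARM toolchain)
#include <cmath>

// Illustrative stand-in for BatchNormParameter; real field names may differ.
struct BatchNormParam {
  float epsilon_;
  int units_;    // N * H * W
  int channel_;  // C
  int thread_num_;
};

// Sketch of the per-task body: each task normalizes an interleaved slice of
// the spatial units using the channel-wise mean/variance prepared up front.
void BatchNormFp16Sketch(const float16_t *input, const float16_t *mean,
                         const float16_t *variance, const BatchNormParam *param,
                         int task_id, float16_t *output) {
  for (int u = task_id; u < param->units_; u += param->thread_num_) {
    for (int c = 0; c < param->channel_; ++c) {
      int i = u * param->channel_ + c;
      float x = static_cast<float>(input[i]);
      float m = static_cast<float>(mean[c]);
      float v = static_cast<float>(variance[c]);
      output[i] = static_cast<float16_t>((x - m) / std::sqrt(v + param->epsilon_));
    }
  }
}
```

Doing the mean/variance conversion once in InitConstTensor() and keeping the fp16 math in this tight loop is the point of the refactor: the per-task body no longer allocates or converts anything.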