Commit 77287093 authored by: S sunsuodong

batch_norm_fp16

Parent 0bbce936
@@ -23,44 +23,75 @@ using mindspore::lite::KernelRegistrar;
 using mindspore::schema::PrimitiveType_BatchNorm;
 
 namespace mindspore::kernel {
-int BatchnormFp16CPUKernel::DoExecute(int task_id) {
-  auto param = reinterpret_cast<BatchNormParameter *>(op_parameter_);
-  if (in_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
-    auto input = in_tensors_.at(0);
-    auto mean = in_tensors_.at(1);
-    auto variance = in_tensors_.at(2);
-    auto output = out_tensors_.at(0);
-    auto input_fp16 = context_->allocator->Malloc(input->ElementsNum() * sizeof(float16_t));
-    auto mean_fp16 = context_->allocator->Malloc(mean->ElementsNum() * sizeof(float16_t));
-    auto variance_fp16 = context_->allocator->Malloc(variance->ElementsNum() * sizeof(float16_t));
-    auto output_fp16 = context_->allocator->Malloc(output->ElementsNum() * sizeof(float16_t));
-    if (input_fp16 == nullptr || mean_fp16 == nullptr || variance_fp16 == nullptr || output_fp16 == nullptr) {
-      context_->allocator->Free(input_fp16);
-      context_->allocator->Free(mean_fp16);
-      context_->allocator->Free(variance_fp16);
-      context_->allocator->Free(output_fp16);
-    }
-    Float32ToFloat16(reinterpret_cast<float *>(input->Data()),
-                     reinterpret_cast<float16_t *>(input_fp16), input->ElementsNum());
-    Float32ToFloat16(reinterpret_cast<float *>(mean->Data()),
-                     reinterpret_cast<float16_t *>(mean_fp16), mean->ElementsNum());
-    Float32ToFloat16(reinterpret_cast<float *>(variance->Data()),
-                     reinterpret_cast<float16_t *>(variance_fp16), variance->ElementsNum());
-    BatchNormFp16(input_fp16, mean_fp16, variance_fp16, param, task_id, output_fp16);
-    Float16ToFloat32(reinterpret_cast<float16_t *>(output_fp16), reinterpret_cast<float *>(output->Data()),
-                     output->ElementsNum());
-    context_->allocator->Free(input_fp16);
-    context_->allocator->Free(mean_fp16);
-    context_->allocator->Free(variance_fp16);
-    context_->allocator->Free(output_fp16);
-    return mindspore::lite::RET_OK;
-  }
-  BatchNormFp16(in_tensors_.at(0)->Data(), mean_, variance_, param, task_id, out_tensors_.at(0)->Data());
-  return mindspore::lite::RET_OK;
-}
+int BatchnormFp16CPUKernel::InitConstTensor() {
+  isFloat32Tensor_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
+  if (isFloat32Tensor_) {
+    auto mean_fp32 = in_tensors_.at(1);
+    auto variance_fp32 = in_tensors_.at(2);
+    mean_ = malloc(mean_fp32->ElementsNum() * sizeof(float16_t));
+    variance_ = malloc(variance_fp32->ElementsNum() * sizeof(float16_t));
+    if (mean_ == nullptr || variance_ == nullptr) {
+      FreeMeanAndVariance();
+      return RET_ERROR;
+    }
+    Float32ToFloat16(reinterpret_cast<float *>(mean_fp32->Data()),
+                     reinterpret_cast<float16_t *>(mean_), mean_fp32->ElementsNum());
+    Float32ToFloat16(reinterpret_cast<float *>(variance_fp32->Data()),
+                     reinterpret_cast<float16_t *>(variance_), variance_fp32->ElementsNum());
+  } else {
+    BatchnormCPUKernel::InitConstTensor();
+  }
+  return RET_OK;
+}
+
+int BatchnormFp16CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret;
+    return ret;
+  }
+  auto input_fp32 = in_tensors_.at(0);
+  auto output_fp32 = out_tensors_.at(0);
+  if (isFloat32Tensor_) {
+    input_ = context_->allocator->Malloc(input_fp32->ElementsNum() * sizeof(float16_t));
+    output_ = context_->allocator->Malloc(output_fp32->ElementsNum() * sizeof(float16_t));
+    if (input_ == nullptr || output_ == nullptr) {
+      FreeInputAndOutput();
+      return RET_ERROR;
+    }
+    Float32ToFloat16(reinterpret_cast<float *>(input_fp32->Data()),
+                     reinterpret_cast<float16_t *>(input_), input_fp32->ElementsNum());
+  } else {
+    input_ = in_tensors_.at(0)->Data();
+    output_ = out_tensors_.at(0)->Data();
+  }
+  ret = LiteBackendParallelLaunch(BatchNormRun, this, op_parameter_->thread_num_);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "BatchnormRun error error_code[" << ret << "]";
+  }
+  if (isFloat32Tensor_) {
+    Float16ToFloat32(reinterpret_cast<float16_t *>(output_), reinterpret_cast<float *>(output_fp32->Data()),
+                     output_fp32->ElementsNum());
+    FreeInputAndOutput();
+  }
+  return ret;
+}
+
+int BatchnormFp16CPUKernel::DoExecute(int task_id) {
+  auto param = reinterpret_cast<BatchNormParameter *>(op_parameter_);
+  BatchNormFp16(input_, mean_, variance_, param, task_id, output_);
+  return mindspore::lite::RET_OK;
+}
+
+void BatchnormFp16CPUKernel::FreeInputAndOutput() {
+  if (input_ != nullptr) {
+    context_->allocator->Free(input_);
+    input_ = nullptr;
+  }
+  if (output_ != nullptr) {
+    context_->allocator->Free(output_);
+    output_ = nullptr;
+  }
+}
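The refactor splits the old all-in-one DoExecute into three stages: InitConstTensor converts mean/variance to fp16 once (plain malloc, since those buffers live as long as the kernel and are released by FreeMeanAndVariance), while Run converts the activations through the context allocator on every call and distributes the work with LiteBackendParallelLaunch. The BatchNormRun callback passed to the launcher is not part of this diff; below is a minimal sketch of the usual MindSpore Lite trampoline pattern, assuming the launcher's (task_id, penv, cdata) callback signature:

int BatchNormRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
  // Recover the kernel object and run its slice of the work (assumed pattern).
  auto kernel = reinterpret_cast<BatchnormFp16CPUKernel *>(cdata);
  auto ret = kernel->DoExecute(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "BatchNormRun error task_id[" << task_id << "] error_code[" << ret << "]";
  }
  return ret;
}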
@@ -83,5 +114,5 @@ kernel::LiteKernel *CpuBatchnormFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
   return kernel;
 }
 
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_BatchNorm, CpuBatchnormFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_BatchNorm, CpuBatchnormFp16KernelCreator)
 }  // namespace mindspore::kernel
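Uncommenting REG_KERNEL is what makes the fp16 kernel reachable: the runtime resolves creators by (arch, data type, primitive type), so the scheduler can now select CpuBatchnormFp16KernelCreator for float16 BatchNorm nodes on CPU. As a hedged sketch, the macro plausibly expands to a static KernelRegistrar whose constructor records the creator in the registry (the exact expansion is an assumption and the variable name hypothetical):

// Static registrar: its constructor inserts the creator into the kernel
// registry under the key (kCPU, kNumberTypeFloat16, PrimitiveType_BatchNorm).
static KernelRegistrar g_batchNormFp16KernelReg(kCPU, kNumberTypeFloat16,
                                                PrimitiveType_BatchNorm,
                                                CpuBatchnormFp16KernelCreator);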
@@ -29,7 +29,15 @@ class BatchnormFp16CPUKernel : public BatchnormCPUKernel {
       : BatchnormCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   virtual ~BatchnormFp16CPUKernel() {}
 
-  virtual int DoExecute(int task_id);
+  int Run() override;
+  int InitConstTensor() override;
+  int DoExecute(int task_id) override;
+
+ private:
+  void FreeInputAndOutput();
+  bool isFloat32Tensor_ = false;
+  void *input_ = nullptr;
+  void *output_ = nullptr;
 };
 }  // namespace mindspore::kernel
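With the new members in place, DoExecute only dispatches one thread's slice to BatchNormFp16, whose nnacl implementation is not shown in this diff. For orientation, here is a self-contained sketch of the arithmetic it is expected to perform, y = (x - mean) / sqrt(var + eps) per channel, with the outer units striped across threads by task_id; the function name, parameter list, and striding scheme are assumptions, not the actual nnacl code:

#include <arm_neon.h>  // float16_t (assumes an fp16-capable ARM target)
#include <cmath>

void BatchNormFp16Sketch(const float16_t *in, const float16_t *mean,
                         const float16_t *var, float eps, int units, int channels,
                         int task_id, int thread_num, float16_t *out) {
  for (int u = task_id; u < units; u += thread_num) {  // thread-strided outer loop
    for (int c = 0; c < channels; ++c) {
      // Compute in fp32 for accuracy, store back as fp16.
      float x = static_cast<float>(in[u * channels + c]);
      float m = static_cast<float>(mean[c]);
      float v = static_cast<float>(var[c]);
      out[u * channels + c] = static_cast<float16_t>((x - m) / sqrtf(v + eps));
    }
  }
}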