提交 d9ccbb32 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5361 [MS][LITE][Develop]fix batchnorm output fp16

Merge pull request !5361 from sunsuodong/fix_batchnorm_output_fp16
......@@ -17,8 +17,8 @@
#include "nnacl/fp16/batchnorm_fp16.h"
#include <math.h>
void BatchNormFp16(const void *input, const void *mean, const void *variance,
BatchNormParameter *param, int task_id, void *output) {
void BatchNormFp16(const float16_t *input, const void *mean, const void *variance,
BatchNormParameter *param, int task_id, float16_t *output) {
int units_per_thread = UP_DIV(param->unit_, param->op_parameter_.thread_num_);
int completed_units = task_id * units_per_thread;
int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units);
......@@ -27,8 +27,9 @@ void BatchNormFp16(const void *input, const void *mean, const void *variance,
for (int i = 0; i < cur_unit; i++) {
for (int c = 0; c < param->channel_; c++) {
float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
((float16_t *)output)[cur_offset + c] =
(((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
if (variance_sqrt != 0) {
output[cur_offset + c] = (input[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
}
}
cur_offset += param->channel_;
}
......@@ -44,8 +45,12 @@ void FusedBatchNormFp16(const void *input, const void *scale, const void *offset
for (int i = 0; i < cur_unit; i++) {
for (int c = 0; c < param->channel_; c++) {
float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
float16_t norm_val = (((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
((float16_t *)output)[cur_offset + c] = norm_val * ((const float16_t *)scale)[c] + ((const float16_t *)offset)[c];
if (variance_sqrt != 0) {
float16_t norm_val =
(((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
((float16_t *)output)[cur_offset + c] =
norm_val * ((const float16_t *)scale)[c] + ((const float16_t *)offset)[c];
}
}
cur_offset += param->channel_;
}
......
......@@ -25,8 +25,8 @@
extern "C" {
#endif
void BatchNormFp16(const void *input, const void *mean, const void *variance, BatchNormParameter *param, int task_id,
void *output);
void BatchNormFp16(const float16_t *input, const void *mean, const void *variance, BatchNormParameter *param,
int task_id, float16_t *output);
void FusedBatchNormFp16(const void *input, const void *scale, const void *offset, const void *mean,
const void *variance, BatchNormParameter *param, int task_id, void *output);
......
......@@ -15,7 +15,6 @@
*/
#include "nnacl/fp32/batchnorm.h"
#include "nnacl/fp16/batchnorm_fp16.h"
#include <math.h>
#include "nnacl/batchnorm_parameter.h"
#include "nnacl/op_base.h"
......
......@@ -15,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp16/batchnorm_fp16.h"
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
#include "nnacl/fp16/batchnorm_fp16.h"
#include "nnacl/fp16/cast_fp16.h"
#include "src/kernel_registry.h"
......@@ -24,8 +25,9 @@ using mindspore::schema::PrimitiveType_BatchNorm;
namespace mindspore::kernel {
int BatchnormFp16CPUKernel::InitConstTensor() {
isFloat32Tensor_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
if (isFloat32Tensor_) {
is_input_fp32_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
is_output_fp32_ = out_tensors_.at(0)->data_type() == kNumberTypeFloat32;
if (is_input_fp32_) {
auto mean_fp32 = in_tensors_.at(1);
auto variance_fp32 = in_tensors_.at(2);
mean_ = malloc(mean_fp32->ElementsNum() * sizeof(float16_t));
......@@ -50,30 +52,24 @@ int BatchnormFp16CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret;
return ret;
}
auto input_fp32 = in_tensors_.at(0);
auto output_fp32 = out_tensors_.at(0);
if (isFloat32Tensor_) {
input_ = context_->allocator->Malloc(input_fp32->ElementsNum() * sizeof(float16_t));
output_ = context_->allocator->Malloc(output_fp32->ElementsNum() * sizeof(float16_t));
if (input_ == nullptr || output_ == nullptr) {
FreeInputAndOutput();
return RET_ERROR;
}
Float32ToFloat16(reinterpret_cast<float *>(input_fp32->Data()),
reinterpret_cast<float16_t *>(input_), input_fp32->ElementsNum());
} else {
input_ = in_tensors_.at(0)->Data();
output_ = out_tensors_.at(0)->Data();
auto input_tensor = in_tensors_.at(0);
auto output_tensor = out_tensors_.at(0);
input_ = ConvertInputFp32toFp16(input_tensor, context_);
output_ = MallocOutputFp16(output_tensor, context_);
if (input_ == nullptr || output_ == nullptr) {
FreeInputAndOutput();
MS_LOG(ERROR) << "input or output is nullptr";
return RET_ERROR;
}
ret = ParallelLaunch(THREAD_POOL_DEFAULT, BatchNormRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "BatchnormRun error error_code[" << ret << "]";
}
if (isFloat32Tensor_) {
Float16ToFloat32(reinterpret_cast<float16_t *>(output_), reinterpret_cast<float *>(output_fp32->Data()),
output_fp32->ElementsNum());
FreeInputAndOutput();
if (is_output_fp32_) {
Float16ToFloat32(output_, reinterpret_cast<float *>(output_tensor->Data()), output_tensor->ElementsNum());
}
FreeInputAndOutput();
return ret;
}
......@@ -84,11 +80,11 @@ int BatchnormFp16CPUKernel::DoExecute(int task_id) {
}
void BatchnormFp16CPUKernel::FreeInputAndOutput() {
if (input_ != nullptr) {
if (is_input_fp32_) {
context_->allocator->Free(input_);
input_ = nullptr;
}
if (output_ != nullptr) {
if (is_output_fp32_) {
context_->allocator->Free(output_);
output_ = nullptr;
}
......
......@@ -35,9 +35,10 @@ class BatchnormFp16CPUKernel : public BatchnormCPUKernel {
private:
void FreeInputAndOutput();
bool isFloat32Tensor_ = false;
void *input_ = nullptr;
void *output_ = nullptr;
bool is_input_fp32_ = false;
bool is_output_fp32_ = false;
float16_t *input_ = nullptr;
float16_t *output_ = nullptr;
};
} // namespace mindspore::kernel
......
......@@ -45,7 +45,8 @@ TEST_F(TestTopKFp32, TopK) {
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), nullptr, desc, nullptr);
auto ctx = std::make_shared<lite::Context>();
auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc, nullptr);
ASSERT_NE(kernel, nullptr);
auto ret = kernel->Run();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册