Commit d9ccbb32, authored by mindspore-ci-bot, committed by Gitee

!5361 [MS][LITE][Develop]fix batchnorm output fp16

Merge pull request !5361 from sunsuodong/fix_batchnorm_output_fp16
@@ -17,8 +17,8 @@
 #include "nnacl/fp16/batchnorm_fp16.h"
 #include <math.h>
 
-void BatchNormFp16(const void *input, const void *mean, const void *variance,
-                   BatchNormParameter *param, int task_id, void *output) {
+void BatchNormFp16(const float16_t *input, const void *mean, const void *variance,
+                   BatchNormParameter *param, int task_id, float16_t *output) {
   int units_per_thread = UP_DIV(param->unit_, param->op_parameter_.thread_num_);
   int completed_units = task_id * units_per_thread;
   int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units);
@@ -27,8 +27,9 @@ void BatchNormFp16(const void *input, const void *mean, const void *variance,
   for (int i = 0; i < cur_unit; i++) {
     for (int c = 0; c < param->channel_; c++) {
       float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
-      ((float16_t *)output)[cur_offset + c] =
-        (((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
+      if (variance_sqrt != 0) {
+        output[cur_offset + c] = (input[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
+      }
     }
     cur_offset += param->channel_;
   }
@@ -44,8 +45,12 @@ void FusedBatchNormFp16(const void *input, const void *scale, const void *offset,
   for (int i = 0; i < cur_unit; i++) {
     for (int c = 0; c < param->channel_; c++) {
       float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
-      float16_t norm_val = (((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
-      ((float16_t *)output)[cur_offset + c] = norm_val * ((const float16_t *)scale)[c] + ((const float16_t *)offset)[c];
+      if (variance_sqrt != 0) {
+        float16_t norm_val =
+          (((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
+        ((float16_t *)output)[cur_offset + c] =
+          norm_val * ((const float16_t *)scale)[c] + ((const float16_t *)offset)[c];
+      }
     }
     cur_offset += param->channel_;
   }
......
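Both kernels above apply the standard batch-norm transform per channel, y = (x - mean) / sqrt(variance + epsilon), and the fused variant then multiplies by scale and adds offset; the patch guards the division so a zero denominator skips the write instead of producing inf/NaN in fp16. The standalone sketch below mirrors that logic; the names, the tiny shapes, and the use of plain float in place of float16_t are illustrative assumptions so it builds on a host without ARM fp16 support, not the nnacl code itself.

#include <math.h>
#include <stdio.h>

/* Sketch of the per-channel normalization the fp16 kernels perform,
 * including the zero-divisor guard added by this patch. float replaces
 * float16_t so the example compiles off-device. */
static void BatchNormSketch(const float *input, const float *mean, const float *variance,
                            float epsilon, int unit, int channel, float *output) {
  for (int i = 0; i < unit; i++) {
    for (int c = 0; c < channel; c++) {
      float variance_sqrt = sqrtf(variance[c] + epsilon);
      if (variance_sqrt != 0) {  /* skip the element instead of dividing by zero */
        output[i * channel + c] = (input[i * channel + c] - mean[c]) / variance_sqrt;
      }
    }
  }
}

int main(void) {
  const float input[4] = {1.0f, 2.0f, 3.0f, 4.0f};  /* unit = 2, channel = 2 */
  const float mean[2] = {2.0f, 3.0f};
  const float variance[2] = {1.0f, 4.0f};
  float output[4] = {0};
  BatchNormSketch(input, mean, variance, 1e-5f, 2, 2, output);
  printf("%f %f %f %f\n", output[0], output[1], output[2], output[3]);
  return 0;
}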
@@ -25,8 +25,8 @@
 extern "C" {
 #endif
 
-void BatchNormFp16(const void *input, const void *mean, const void *variance, BatchNormParameter *param, int task_id,
-                   void *output);
+void BatchNormFp16(const float16_t *input, const void *mean, const void *variance, BatchNormParameter *param,
+                   int task_id, float16_t *output);
 void FusedBatchNormFp16(const void *input, const void *scale, const void *offset, const void *mean,
                         const void *variance, BatchNormParameter *param, int task_id, void *output);
......
@@ -15,7 +15,6 @@
  */
 #include "nnacl/fp32/batchnorm.h"
-#include "nnacl/fp16/batchnorm_fp16.h"
 #include <math.h>
 #include "nnacl/batchnorm_parameter.h"
 #include "nnacl/op_base.h"
......
@@ -15,6 +15,7 @@
  */
 #include "src/runtime/kernel/arm/fp16/batchnorm_fp16.h"
+#include "src/runtime/kernel/arm/fp16/common_fp16.h"
 #include "nnacl/fp16/batchnorm_fp16.h"
 #include "nnacl/fp16/cast_fp16.h"
 #include "src/kernel_registry.h"
@@ -24,8 +25,9 @@ using mindspore::schema::PrimitiveType_BatchNorm;
 namespace mindspore::kernel {
 int BatchnormFp16CPUKernel::InitConstTensor() {
-  isFloat32Tensor_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
-  if (isFloat32Tensor_) {
+  is_input_fp32_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
+  is_output_fp32_ = out_tensors_.at(0)->data_type() == kNumberTypeFloat32;
+  if (is_input_fp32_) {
     auto mean_fp32 = in_tensors_.at(1);
     auto variance_fp32 = in_tensors_.at(2);
     mean_ = malloc(mean_fp32->ElementsNum() * sizeof(float16_t));
@@ -50,30 +52,24 @@ int BatchnormFp16CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret;
     return ret;
   }
-  auto input_fp32 = in_tensors_.at(0);
-  auto output_fp32 = out_tensors_.at(0);
-  if (isFloat32Tensor_) {
-    input_ = context_->allocator->Malloc(input_fp32->ElementsNum() * sizeof(float16_t));
-    output_ = context_->allocator->Malloc(output_fp32->ElementsNum() * sizeof(float16_t));
-    if (input_ == nullptr || output_ == nullptr) {
-      FreeInputAndOutput();
-      return RET_ERROR;
-    }
-    Float32ToFloat16(reinterpret_cast<float *>(input_fp32->Data()),
-                     reinterpret_cast<float16_t *>(input_), input_fp32->ElementsNum());
-  } else {
-    input_ = in_tensors_.at(0)->Data();
-    output_ = out_tensors_.at(0)->Data();
-  }
+  auto input_tensor = in_tensors_.at(0);
+  auto output_tensor = out_tensors_.at(0);
+  input_ = ConvertInputFp32toFp16(input_tensor, context_);
+  output_ = MallocOutputFp16(output_tensor, context_);
+  if (input_ == nullptr || output_ == nullptr) {
+    FreeInputAndOutput();
+    MS_LOG(ERROR) << "input or output is nullptr";
+    return RET_ERROR;
+  }
   ret = ParallelLaunch(THREAD_POOL_DEFAULT, BatchNormRun, this, op_parameter_->thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "BatchnormRun error error_code[" << ret << "]";
   }
-  if (isFloat32Tensor_) {
-    Float16ToFloat32(reinterpret_cast<float16_t *>(output_), reinterpret_cast<float *>(output_fp32->Data()),
-                     output_fp32->ElementsNum());
-    FreeInputAndOutput();
-  }
+  if (is_output_fp32_) {
+    Float16ToFloat32(output_, reinterpret_cast<float *>(output_tensor->Data()), output_tensor->ElementsNum());
+  }
+  FreeInputAndOutput();
   return ret;
 }
@@ -84,11 +80,11 @@ int BatchnormFp16CPUKernel::DoExecute(int task_id) {
 }
 
 void BatchnormFp16CPUKernel::FreeInputAndOutput() {
-  if (input_ != nullptr) {
+  if (is_input_fp32_) {
     context_->allocator->Free(input_);
     input_ = nullptr;
   }
-  if (output_ != nullptr) {
+  if (is_output_fp32_) {
     context_->allocator->Free(output_);
     output_ = nullptr;
   }
......
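The reworked Run() converts the input to fp16 only when the input tensor is fp32, runs the fp16 compute, converts the result back only when the output tensor is fp32, and FreeInputAndOutput() now keys off the is_input_fp32_/is_output_fp32_ flags so only buffers allocated for conversion are returned to the allocator. Below is a rough sketch of that ownership pattern, not the MindSpore Lite API: the struct and function names, plain malloc/free in place of the context allocator, and double/float standing in for fp32 tensor data and float16_t are all illustrative assumptions.

#include <cstdio>
#include <cstdlib>

// Sketch of the allocate-convert-compute-convert-free flow of the reworked kernel.
struct KernelScratch {
  bool is_input_fp32 = false;
  bool is_output_fp32 = false;
  float *input = nullptr;    // fp16-side view of the input
  float *output = nullptr;   // fp16-side view of the output
};

static void FreeInputAndOutputSketch(KernelScratch *k) {
  // Only conversion buffers are freed; tensor-owned memory is left alone.
  if (k->is_input_fp32) { free(k->input); k->input = nullptr; }
  if (k->is_output_fp32) { free(k->output); k->output = nullptr; }
}

static int RunSketch(KernelScratch *k, const double *in_fp32, double *out_fp32, int n) {
  k->is_input_fp32 = true;   // in the real kernel these flags come from the tensor data types
  k->is_output_fp32 = true;
  k->input = static_cast<float *>(malloc(n * sizeof(float)));
  k->output = static_cast<float *>(malloc(n * sizeof(float)));
  if (k->input == nullptr || k->output == nullptr) {
    FreeInputAndOutputSketch(k);
    return -1;                                       // RET_ERROR in the real kernel
  }
  for (int i = 0; i < n; i++) {
    k->input[i] = static_cast<float>(in_fp32[i]);    // fp32 -> fp16 conversion of the input
  }
  for (int i = 0; i < n; i++) {
    k->output[i] = k->input[i];                      // stand-in for the fp16 batch-norm compute
  }
  for (int i = 0; i < n; i++) {
    out_fp32[i] = k->output[i];                      // fp16 -> fp32 conversion of the result
  }
  FreeInputAndOutputSketch(k);
  return 0;
}

int main() {
  const double in[3] = {1.0, 2.0, 3.0};
  double out[3] = {0.0, 0.0, 0.0};
  KernelScratch k;
  if (RunSketch(&k, in, out, 3) == 0) {
    printf("%f %f %f\n", out[0], out[1], out[2]);
  }
  return 0;
}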
@@ -35,9 +35,10 @@ class BatchnormFp16CPUKernel : public BatchnormCPUKernel {
  private:
  void FreeInputAndOutput();
-  bool isFloat32Tensor_ = false;
-  void *input_ = nullptr;
-  void *output_ = nullptr;
+  bool is_input_fp32_ = false;
+  bool is_output_fp32_ = false;
+  float16_t *input_ = nullptr;
+  float16_t *output_ = nullptr;
 };
 }  // namespace mindspore::kernel
......
@@ -45,7 +45,8 @@ TEST_F(TestTopKFp32, TopK) {
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
   ASSERT_NE(creator, nullptr);
-  auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), nullptr, desc, nullptr);
+  auto ctx = std::make_shared<lite::Context>();
+  auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc, nullptr);
   ASSERT_NE(kernel, nullptr);
   auto ret = kernel->Run();
......