Commit ff6b90d9 authored by mindspore-ci-bot, committed by Gitee

!4145 [MS][LITE]modify fp16 conv creator

Merge pull request !4145 from 张学同/to_merge
+include_directories(${CMAKE_CURRENT_SOURCE_DIR}/)
 file(GLOB KERNEL_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc
     nnacl/*.cc
...
@@ -23,6 +23,7 @@
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "nnacl/winograd_utils.h"
using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
@@ -242,7 +243,7 @@ int ConvolutionFP16CPUKernel::Run() {
   auto out_tensor = outputs_.at(kOutputIndex);
   auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
   for (int j = 0; j < out_tensor->ElementsNum(); ++j) {
-    output_addr[j] = static_cast<float >(fp16_out_[j]);
+    output_addr[j] = static_cast<float>(fp16_out_[j]);
   }
   return RET_OK;
 }
@@ -264,20 +265,27 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
   conv_param->input_w_ = inputs.front()->Width();
   conv_param->output_h_ = outputs.front()->Height();
   conv_param->output_w_ = outputs.front()->Width();
-  kernel::LiteKernel *kernel;
+  kernel::LiteKernel *kernel = nullptr;
   if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
     kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx);
   } else {
-    kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx);
+    bool use_winograd = false;
+    int out_unit;
+    InputTransformUnitFunc input_trans_func = nullptr;
+    OutputTransformUnitFunc output_trans_func = nullptr;
+    CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
+    if (kernel_h != 1 && kernel_w != 1 && !use_winograd) {
+      kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx);
+    }
   }
   if (kernel == nullptr) {
-    MS_LOG(ERROR) << "Create conv fp16 kernel failed.";
+    MS_LOG(DEBUG) << "Create conv fp16 kernel failed.";
     return nullptr;
   }
   auto ret = kernel->Init();
   if (ret != RET_OK) {
     delete kernel;
-    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
+    MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_ << ", type: "
                   << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
     return nullptr;
   }
...
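Read as a whole, this hunk makes the fp16 creator more selective: 3x3 stride-1, dilation-1 shapes still get the dedicated 3x3 kernel, every other shape is first probed with CheckIfUseWinograd, and only shapes that are neither 1x1 nor winograd-eligible receive the generic fp16 kernel; for everything else `kernel` stays nullptr and the creator returns nullptr. A condensed sketch of the resulting selection logic, reconstructed from the hunk above (kernel_h, kernel_w, stride_*, dilation_*, conv_param, and the creator arguments come from the surrounding function, which is omitted here):

// Sketch of the fp16 kernel selection after this patch (not a verbatim excerpt).
kernel::LiteKernel *kernel = nullptr;
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 &&
    dilation_h == 1 && dilation_w == 1) {
  // Dedicated 3x3 fp16 kernel for the common stride-1 case.
  kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx);
} else {
  bool use_winograd = false;
  int out_unit = 0;
  InputTransformUnitFunc input_trans_func = nullptr;
  OutputTransformUnitFunc output_trans_func = nullptr;
  CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
  // Generic fp16 kernel only for shapes that are neither 1x1 nor winograd-eligible;
  // otherwise kernel stays nullptr and the creator returns nullptr.
  if (kernel_h != 1 && kernel_w != 1 && !use_winograd) {
    kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx);
  }
}

Downgrading the two logs (ERROR to DEBUG/INFO) fits this reading: with the stricter condition, a nullptr result becomes an expected "this creator declines the shape" outcome rather than a hard failure.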
@@ -220,32 +220,6 @@ int ConvolutionCPUKernel::Run() {
   return RET_OK;
 }
-void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param,
-                        InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func) {
-  if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 &&
-      conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) {
-    *output_unit = SelectOutputUnit(conv_param);
-    if (*output_unit > 1) {
-      *use_winograd = true;
-      int input_unit = conv_param->kernel_h_ + *output_unit - 1;
-      input_trans_func = GetInputTransFunc(input_unit);
-      if (input_trans_func == nullptr) {
-        MS_LOG(INFO) << "No matching input trans func. Turn back to common conv.";
-        *use_winograd = false;
-      }
-      output_trans_func = GetOutputTransFunc(input_unit, *output_unit);
-      if (output_trans_func == nullptr) {
-        MS_LOG(INFO) << "No matching output trans func. Turn back to common conv.";
-        *use_winograd = false;
-      }
-    } else {
-      *use_winograd = false;
-    }
-  } else {
-    *use_winograd = false;
-  }
-}
 kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                              const std::vector<lite::tensor::Tensor *> &outputs,
                                              OpParameter *opParameter, const Context *ctx,
@@ -270,7 +244,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
   CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
   kernel::LiteKernel *kernel;
   if (kernel_h == 1 && kernel_w == 1) {
-    kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx);
+    // kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx);
+    kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx);
   } else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
     kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx);
   } else if (use_winograd) {
...
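On the fp32 side the dispatch structure is unchanged; the only behavioral change is that the 1x1 fast path is disabled for now, so 1x1 shapes fall back to the generic kernel. A condensed sketch of the dispatch chain after this hunk (the winograd and default branches are elided by the "..." above; all variables come from CpuConvFp32KernelCreator):

// Sketch of the fp32 dispatch after this patch (not a verbatim excerpt).
if (kernel_h == 1 && kernel_w == 1) {
  // Convolution1x1CPUKernel is commented out; use the generic kernel instead.
  kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx);
} else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 &&
           dilation_h == 1 && dilation_w == 1) {
  kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx);
} else if (use_winograd) {
  // Winograd branch elided in the diff view above.
}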
@@ -4708,3 +4708,28 @@ OutputTransformUnitFunc GetOutputTransFunc(int input_unit, int output_unit) {
     return nullptr;
   }
 }
+void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param,
+                        InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func) {
+  if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 &&
+      conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) {
+    *output_unit = SelectOutputUnit(conv_param);
+    if (*output_unit > 1) {
+      *use_winograd = true;
+      int input_unit = conv_param->kernel_h_ + *output_unit - 1;
+      input_trans_func = GetInputTransFunc(input_unit);
+      if (input_trans_func == nullptr) {
+        *use_winograd = false;
+      }
+      output_trans_func = GetOutputTransFunc(input_unit, *output_unit);
+      if (output_trans_func == nullptr) {
+        *use_winograd = false;
+      }
+    } else {
+      *use_winograd = false;
+    }
+  } else {
+    *use_winograd = false;
+  }
+}
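CheckIfUseWinograd moves here essentially verbatim, minus the two INFO logs. One detail worth noting when calling it: use_winograd and output_unit are returned through pointers, but the two transform-function parameters are taken by value, so the assignments to input_trans_func and output_trans_func inside the function never reach the caller; as written, callers only learn whether winograd applies and which output unit to use. A minimal caller-side sketch under that reading (variable names mirror CpuConvFp32KernelCreator above; the refetch in the winograd branch is an assumption, not shown in this diff):

// Minimal caller-side sketch (not a verbatim excerpt from the repo).
bool use_winograd = false;
int out_unit = 0;
InputTransformUnitFunc input_trans_func = nullptr;
OutputTransformUnitFunc output_trans_func = nullptr;
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
if (use_winograd) {
  // The assignments inside CheckIfUseWinograd landed on by-value copies, so a
  // winograd kernel would re-derive the transforms itself (assumption):
  int input_unit = conv_param->kernel_h_ + out_unit - 1;
  input_trans_func = GetInputTransFunc(input_unit);
  output_trans_func = GetOutputTransFunc(input_unit, out_unit);
}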
@@ -54,5 +54,7 @@ InputTransformUnitFunc GetInputTransFunc(int input_unit);
 OutputTransformUnitFunc GetOutputTransFunc(int input_unit, int output_unit);
+void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param,
+                        InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func);
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_WINOGRAD_UTILS_H_