提交 689efef7 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4940 modify conv to make fp16 run pass

Merge pull request !4940 from zhaozhenlong/lite/issue/modify_conv_for_fp16
...@@ -235,13 +235,15 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Ten ...@@ -235,13 +235,15 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Ten
conv_param->input_w_ = inputs.front()->Width(); conv_param->input_w_ = inputs.front()->Width();
conv_param->output_h_ = outputs.front()->Height(); conv_param->output_h_ = outputs.front()->Height();
conv_param->output_w_ = outputs.front()->Width(); conv_param->output_w_ = outputs.front()->Width();
bool prefer_flag = false; // bool prefer_flag = false;
if (conv_param->output_h_ * conv_param->output_w_ > 64) { // if (conv_param->output_h_ * conv_param->output_w_ > 64) {
prefer_flag = true; // prefer_flag = true;
} // }
kernel::LiteKernel *kernel = nullptr; kernel::LiteKernel *kernel = nullptr;
if (kernel_h == 1 && kernel_w == 1) { if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else if (kernel_h == 1 && kernel_w == 1) {
// kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); // kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else { } else {
...@@ -253,9 +255,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Ten ...@@ -253,9 +255,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Ten
if (use_winograd) { if (use_winograd) {
kernel = new (std::nothrow) kernel = new (std::nothrow)
kernel::ConvolutionWinogradFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit); kernel::ConvolutionWinogradFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit);
} else if (prefer_flag && kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 &&
dilation_w == 1) {
kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} }
if (kernel_h != 1 && kernel_w != 1 && !use_winograd) { if (kernel_h != 1 && kernel_w != 1 && !use_winograd) {
kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册