diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc index 34d6f31c3791a9b700be4dac991faa2de4f53647..6a0d4e4709f354ad07b489e91c45f7079fc56fa2 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -235,13 +235,15 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vectorinput_w_ = inputs.front()->Width(); conv_param->output_h_ = outputs.front()->Height(); conv_param->output_w_ = outputs.front()->Width(); - bool prefer_flag = false; - if (conv_param->output_h_ * conv_param->output_w_ > 64) { - prefer_flag = true; - } + // bool prefer_flag = false; + // if (conv_param->output_h_ * conv_param->output_w_ > 64) { + // prefer_flag = true; + // } kernel::LiteKernel *kernel = nullptr; - if (kernel_h == 1 && kernel_w == 1) { + if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { + kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); + } else if (kernel_h == 1 && kernel_w == 1) { // kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); } else { @@ -253,9 +255,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector