diff --git a/src/operators/kernel/arm/convolution/conv_common.cpp b/src/operators/kernel/arm/convolution/conv_common.cpp index 29a9c471aff3e30d92c6de6605fadfde870cb293..2a3a5e17e1d9da8db3ee30924c066bf195ddb97e 100644 --- a/src/operators/kernel/arm/convolution/conv_common.cpp +++ b/src/operators/kernel/arm/convolution/conv_common.cpp @@ -57,8 +57,8 @@ void InitBaseConvKernel(ConvParam *param) { param->Dilations()[0] == param->Dilations()[1] && param->Strides()[0] == 1 && param->Dilations()[0] == 1 #if 1 - && param->Input()->dims()[1] >= 4 - && param->Input()->dims()[2] >= 16 + && (param->Input()->dims()[1] >= 4 || + param->Output()->dims()[1] >= 16) #endif ) { param->ExecMode() = ConvParam::EXEC_WINOGRAD3X3_FLOAT;