diff --git a/src/operators/kernel/arm/convolution/conv_common.cpp b/src/operators/kernel/arm/convolution/conv_common.cpp index dabc4f9915aeebd8651bcf7195cdc844e002a324..86c6e8d3373f743f485e4a69599e8c8323ba0083 100644 --- a/src/operators/kernel/arm/convolution/conv_common.cpp +++ b/src/operators/kernel/arm/convolution/conv_common.cpp @@ -66,7 +66,7 @@ void InitBaseConvKernel(ConvParam *param) { param->transformed_filter_ = new framework::LoDTensor; operators::math::winograd_transform_weight<8, 3>( *param->Filter(), param->transformed_filter_); - } else if (conv3x3 && !depth3x3 && + } else if (conv3x3 && param->Groups() == 1 && param->Strides()[0] == param->Strides()[1] && param->Dilations()[0] == param->Dilations()[1] && param->Strides()[0] == 1 && param->Dilations()[0] == 1 @@ -76,7 +76,7 @@ void InitBaseConvKernel(ConvParam *param) { #endif ) { param->ExecMode() = ConvParam::EXEC_SLIDINGWINDOW3x3S1_FLOAT; - } else if (conv3x3 && !depth3x3 && + } else if (conv3x3 && param->Groups() == 1 && param->Strides()[0] == param->Strides()[1] && param->Dilations()[0] == param->Dilations()[1] && param->Strides()[0] == 2 && param->Dilations()[0] == 1