diff --git a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc index b4539db98c3ffb1a143c38dd3c4dd9e9924bd63e..25ee9f940481a0c92f354e819d6d2b8d45eff169 100644 --- a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc @@ -645,7 +645,6 @@ void conv_3x3s1_depthwise_fp32_bias(const float* i_data, bool flag_bias = param.bias != nullptr; /// get workspace - LOG(INFO) << "conv_3x3s1_depthwise_fp32_bias: "; float* ptr_zero = ctx->workspace_data(); memset(ptr_zero, 0, sizeof(float) * win_round); float* ptr_write = ptr_zero + win_round; diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index fa2f85311b3ff4247d52505d750566ec80e47256..af722fd6413c22c2be7474ba38b54d3f30d0011c 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -620,10 +620,8 @@ void conv_depthwise_3x3_fp32(const void* din, int pad = pad_w; bool flag_bias = param.bias != nullptr; bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2)); - bool ch_four = ch_in <= 4 * w_in; if (stride == 1) { - if (ch_four && pads_less && (pad_h == pad_w) && - (pad < 2)) { // support pad = [0, 1] + if (pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1] conv_depthwise_3x3s1_fp32(reinterpret_cast(din), reinterpret_cast(dout), num, @@ -656,8 +654,7 @@ void conv_depthwise_3x3_fp32(const void* din, ctx); } } else if (stride == 2) { - if (ch_four && pads_less && pad_h == pad_w && - (pad < 2)) { // support pad = [0, 1] + if (pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1] conv_depthwise_3x3s2_fp32(reinterpret_cast(din), reinterpret_cast(dout), num, diff --git a/lite/kernels/arm/conv_depthwise.cc b/lite/kernels/arm/conv_depthwise.cc index c5b43a31a0f495f3635d389939acf44e979a3dc7..e04e774cce3af5bd6f8b67c6adfeba06fa814768 100644 --- a/lite/kernels/arm/conv_depthwise.cc +++ b/lite/kernels/arm/conv_depthwise.cc @@ -32,11 +32,10 @@ void DepthwiseConv::PrepareForRun() { auto hin = param.x->dims()[2]; auto win = param.x->dims()[3]; auto paddings = *param.paddings; - bool ch_four = channel <= 4 * win; // select dw conv kernel if (kw == 3) { bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2)); - if (ch_four && pads_less && paddings[0] == paddings[2] && + if (pads_less && paddings[0] == paddings[2] && (paddings[0] == 0 || paddings[0] == 1)) { flag_trans_weights_ = false; } else {