diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index 0d5f2c2f587bc3a2d2e42477a15a015cf2504aa5..602edf4cfef66532c0586d7ea9a52a28897b011e 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -619,9 +619,10 @@ void conv_depthwise_3x3_fp32(const void* din, int stride = param.strides[1]; int pad = pad_w; bool flag_bias = param.bias != nullptr; + bool ch_four = ch_in <= 4 * w_in; bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2)); if (stride == 1) { - if (pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1] + if (ch_four && pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1] conv_depthwise_3x3s1_fp32(reinterpret_cast(din), reinterpret_cast(dout), num, @@ -676,7 +677,7 @@ void conv_depthwise_3x3_fp32(const void* din, #endif } } else if (stride == 2) { - if (pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1] + if (ch_four && pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1] conv_depthwise_3x3s2_fp32(reinterpret_cast(din), reinterpret_cast(dout), num, @@ -734,6 +735,7 @@ void conv_depthwise_5x5_fp32(const void* din, int stride = param.strides[1]; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; + bool ch_four = ch_in > 4 * w_in; ctx->ExtendWorkspace((w_in + w_out) * sizeof(float)); bool flag_act = act_param.has_active; if (stride == 2) { @@ -752,7 +754,7 @@ void conv_depthwise_5x5_fp32(const void* din, act_param, ctx); } else if (stride == 1) { - if (h_in < 5 || w_in < 5) { + if (ch_four || h_in < 5 || w_in < 5) { conv_depthwise_5x5s1_fp32(reinterpret_cast(dout), reinterpret_cast(din), reinterpret_cast(weights), diff --git a/lite/kernels/arm/conv_depthwise.cc b/lite/kernels/arm/conv_depthwise.cc index ef939d19b179c187a949bc6a466344a5f3c38234..1ad8bd599dd5a1882ddb7c70edba968cbb398aa4 100644 --- a/lite/kernels/arm/conv_depthwise.cc +++ b/lite/kernels/arm/conv_depthwise.cc @@ -28,12 +28,16 @@ void DepthwiseConv::PrepareForRun() { auto& ctx = this->ctx_->template As(); auto w_dims = param.filter->dims(); auto kw = w_dims[3]; + auto channel = w_dims[0]; + auto hin = param.x->dims()[2]; + auto win = param.x->dims()[3]; auto paddings = *param.paddings; + bool ch_four = channel <= 4 * win; // select dw conv kernel if (kw == 3) { // VLOG(5) << "invoke 3x3 dw conv fp32"; bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2)); - if (pads_less && paddings[0] == paddings[2] && + if (ch_four && pads_less && paddings[0] == paddings[2] && (paddings[0] == 0 || paddings[0] == 1)) { flag_trans_weights_ = false; } else { @@ -56,9 +60,7 @@ void DepthwiseConv::PrepareForRun() { } else if (kw == 5) { // VLOG(5) << "invoke 5x5 dw conv fp32"; auto strides = param.strides; - auto hin = param.x->dims()[2]; - auto win = param.x->dims()[3]; - if (win >= kw && hin >= kw && (strides[0] == 1 && strides[1] == 1)) { + if (ch_four && win >= kw && hin >= kw && (strides[0] == 1 && strides[1] == 1)) { flag_trans_weights_ = false; impl_ = lite::arm::math::conv_depthwise_5x5_fp32; #ifdef LITE_WITH_PROFILE