diff --git a/lite/kernels/arm/conv_compute.cc b/lite/kernels/arm/conv_compute.cc index 6baa81173718f78fa7a491efe0c9d7492f6fa86d..2a545e70691f030a3a1e3f2a9a9822f5cd8b85b9 100644 --- a/lite/kernels/arm/conv_compute.cc +++ b/lite/kernels/arm/conv_compute.cc @@ -109,6 +109,8 @@ void ConvCompute::PrepareForRun() { int pw = paddings[2]; int sh = param.strides[1]; int sw = param.strides[0]; + int hin = param.x->dims()[2]; + int win = param.x->dims()[3]; bool pads_all_equal = (pads_equal && paddings[0] == paddings[2]); bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); @@ -116,13 +118,12 @@ void ConvCompute::PrepareForRun() { bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2)); bool flag_dw_5x5 = pads_all_equal && (kw == 5 && (sw == 1 || sw == 2)); bool flag_dw = flag_dw_3x3 || flag_dw_5x5; - if (param.groups == ic && ic == oc && kps_equal && pads_equal && no_dilation && flag_dw) { impl_ = new DepthwiseConv; // VLOG(3) << "Run DepthwiseConv Int8"; } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) && - kps_equal && no_dilation) { + ic * oc < 4 * hin * win && kps_equal && no_dilation) { impl_ = new DirectConv; // VLOG(3) << "Run DirectConv Int8"; } else { @@ -154,6 +155,8 @@ void ConvCompute::PrepareForRun() { int pw = paddings[2]; int sh = param.strides[1]; int sw = param.strides[0]; + int hin = param.x->dims()[2]; + int win = param.x->dims()[3]; bool pads_all_equal = (pads_equal && paddings[0] == paddings[2]); bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); @@ -167,7 +170,7 @@ void ConvCompute::PrepareForRun() { impl_ = new DepthwiseConv; // VLOG(3) << "Run DepthwiseConv Int8"; } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) && - kps_equal && no_dilation) { + ic * oc < 4 * hin * win && kps_equal && no_dilation) { impl_ = new DirectConv; // VLOG(3) << "Run DirectConv Int8"; } else {