提交 9a09cf28 编写于 作者: C chenjiaoAngel

fix conv_dw choose kernel method

上级 b0fe17e4
......@@ -619,9 +619,10 @@ void conv_depthwise_3x3_fp32(const void* din,
int stride = param.strides[1];
int pad = pad_w;
bool flag_bias = param.bias != nullptr;
bool ch_four = ch_in <= 4 * w_in;
bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
if (stride == 1) {
if (pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1]
if (ch_four && pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s1_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
......@@ -676,7 +677,7 @@ void conv_depthwise_3x3_fp32(const void* din,
#endif
}
} else if (stride == 2) {
if (pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1]
if (ch_four && pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s2_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
......@@ -734,6 +735,7 @@ void conv_depthwise_5x5_fp32(const void* din,
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
bool ch_four = ch_in > 4 * w_in;
ctx->ExtendWorkspace((w_in + w_out) * sizeof(float));
bool flag_act = act_param.has_active;
if (stride == 2) {
......@@ -752,7 +754,7 @@ void conv_depthwise_5x5_fp32(const void* din,
act_param,
ctx);
} else if (stride == 1) {
if (h_in < 5 || w_in < 5) {
if (ch_four || h_in < 5 || w_in < 5) {
conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout),
reinterpret_cast<const float*>(din),
reinterpret_cast<const float*>(weights),
......
......@@ -28,12 +28,16 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto& ctx = this->ctx_->template As<ARMContext>();
auto w_dims = param.filter->dims();
auto kw = w_dims[3];
auto channel = w_dims[0];
auto hin = param.x->dims()[2];
auto win = param.x->dims()[3];
auto paddings = *param.paddings;
bool ch_four = channel <= 4 * win;
// select dw conv kernel
if (kw == 3) {
// VLOG(5) << "invoke 3x3 dw conv fp32";
bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
if (pads_less && paddings[0] == paddings[2] &&
if (ch_four && pads_less && paddings[0] == paddings[2] &&
(paddings[0] == 0 || paddings[0] == 1)) {
flag_trans_weights_ = false;
} else {
......@@ -56,9 +60,7 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
} else if (kw == 5) {
// VLOG(5) << "invoke 5x5 dw conv fp32";
auto strides = param.strides;
auto hin = param.x->dims()[2];
auto win = param.x->dims()[3];
if (win >= kw && hin >= kw && (strides[0] == 1 && strides[1] == 1)) {
if (ch_four && win >= kw && hin >= kw && (strides[0] == 1 && strides[1] == 1)) {
flag_trans_weights_ = false;
impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
#ifdef LITE_WITH_PROFILE
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册