From f1840b534da7d007a4dd79b915bb012dbe53f56f Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Thu, 28 Feb 2019 18:05:03 +0800 Subject: [PATCH] Fix entering the winograd branch for depthwise convolution, fix ShareDataWith bug --- src/operators/kernel/arm/conv_kernel.cpp | 6 ++++++ src/operators/kernel/central-arm-func/conv_arm_func.h | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/operators/kernel/arm/conv_kernel.cpp b/src/operators/kernel/arm/conv_kernel.cpp index f7f55b790d..de19127e68 100644 --- a/src/operators/kernel/arm/conv_kernel.cpp +++ b/src/operators/kernel/arm/conv_kernel.cpp @@ -57,6 +57,8 @@ bool ConvKernel::Init(ConvParam *param) { param->Strides()[0] == 2 && param->Paddings()[0] == 1 && param->Paddings()[0] == param->Paddings()[1]) { param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S2P1_FLOAT; + } else if (depth3x3) { + param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3_FLOAT; #ifndef __aarch64__ } else if (depth5x5 && param->Strides()[0] == param->Strides()[1] && param->Strides()[0] == 1) { @@ -106,6 +108,10 @@ void ConvKernel::Compute(const ConvParam ¶m) { math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(), nullptr, false, false); break; + case ConvParam::EXEC_DEPTHWISE3x3_FLOAT: + math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), + param.Filter(), nullptr, param.Output(), false); + break; #ifndef __aarch64__ case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: DepthwiseConv5x5(param); diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index 86a3c7a969..93be71f554 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -90,8 +90,8 @@ inline void GemmConv(const ConvParam ¶m) { Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); if (!is_expand) { - col.ShareDataWith(in_slice); - col_matrix.ShareDataWith(col); + // col_matrix.ShareDataWith(in_slice); + col_matrix = in_slice; col_matrix.Resize(col_matrix_shape); } else if (data_dim == 2U) { // im2col -- GitLab