From 532cff714c9ea75244a8cee0dac7e8222fd5dab5 Mon Sep 17 00:00:00 2001
From: hjchen2
Date: Thu, 7 Mar 2019 08:05:24 +0800
Subject: [PATCH] Refactor depthwise conv3x3 and fix its bugs for armv8

---
 .../convolution/conv_add_bn_relu_kernel.cpp   |   22 +-
 .../arm/convolution/conv_add_kernel.cpp       |   35 +-
 .../arm/convolution/conv_add_relu_kernel.cpp  |   21 +-
 .../convolution/conv_bn_add_relu_kernel.cpp   |   38 +-
 .../arm/convolution/conv_bn_relu_kernel.cpp   |   22 +-
 .../kernel/arm/convolution/conv_common.cpp    |   28 +-
 .../kernel/arm/convolution/conv_kernel.cpp    |   22 +-
 .../arm/convolution/dwconv_bn_relu_kernel.cpp |   22 +-
 .../central-arm-func/conv_add_arm_func.h      |   29 -
 .../conv_bn_add_relu_arm_func.h               |   25 -
 .../conv_transpose_arm_func.h                 |    1 -
 src/operators/math/depthwise_conv3x3.cpp      | 2968 ++++++-----------
 src/operators/math/depthwise_conv3x3.h        |   42 -
 src/operators/math/gemm/cblas.cc              |    6 +-
 src/operators/math/gemm/executor.h            |   23 +-
 src/operators/math/im2col.cpp                 |  332 +-
 src/operators/op_param.h                      |    6 +-
 17 files changed, 1110 insertions(+), 2532 deletions(-)

diff --git a/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp
index b0bfae799c..28cb2c3e40 100644
--- a/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp
+++ b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp
@@ -61,25 +61,15 @@ template <>
 void ConvAddBNReluKernel::Compute(
     const FusionConvAddBNReluParam &param) {
   switch (param.ExecMode()) {
-    case ConvParam::EXEC_DEPTHWISE3x3S1P1_FLOAT:
-      math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
-                                          param.Output(), param.NewScale(),
-                                          param.NewBias(), true);
-      break;
-    case ConvParam::EXEC_DEPTHWISE3x3S2P1_FLOAT:
-      math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
-                                            param.Output(), param.NewScale(),
-                                            param.NewBias(), true);
-      break;
-    case ConvParam::EXEC_DEPTHWISE3x3S2P0_FLOAT:
-      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
-                                 nullptr, false, false);
+    case ConvParam::EXEC_DEPTHWISE3x3S1_FLOAT:
+      math::DepthwiseConv3x3S1(*param.Input(), *param.Filter(),
+                               param.Paddings(), param.Output());
       math::ScaleAddChannelWise(param.Output(), param.NewScale(),
                                 param.NewBias(), param.Output());
       break;
-    case ConvParam::EXEC_DEPTHWISE3x3_FLOAT:
-      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
-                             param.Filter(), nullptr, param.Output(), false);
+    case ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT:
+      math::DepthwiseConv3x3S2(*param.Input(), *param.Filter(),
+                               param.Paddings(), param.Output());
       math::ScaleAddChannelWise(param.Output(), param.NewScale(),
                                 param.NewBias(), param.Output());
       break;
diff --git a/src/operators/kernel/arm/convolution/conv_add_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp
index 9e27dc62fb..b62fdf71f8 100644
--- a/src/operators/kernel/arm/convolution/conv_add_kernel.cpp
+++ b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp
@@ -14,6 +14,7 @@ limitations under the License.
 */
 #ifdef FUSION_CONVADD_OP
 #include "operators/kernel/conv_add_kernel.h"
+#include "operators/kernel/arm/convolution/conv_common.h"
 #include "operators/kernel/central-arm-func/conv_add_arm_func.h"
 
 namespace paddle_mobile {
@@ -21,12 +22,44 @@ namespace operators {
 
 template <>
 bool ConvAddKernel::Init(FusionConvAddParam *param) {
+  InitBaseConvKernel(param);
   return true;
 }
 
 template <>
 void ConvAddKernel::Compute(const FusionConvAddParam &param) {
-  ConvAddCompute(param);
+  switch (param.ExecMode()) {
+    case ConvParam::EXEC_DEPTHWISE3x3S1_FLOAT:
+      math::DepthwiseConv3x3S1(*param.Input(), *param.Filter(),
+                               param.Paddings(), param.Output());
+      math::AddChannelWise(param.Output(), param.Bias(),
+                           param.Output());
+      break;
+    case ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT:
+      math::DepthwiseConv3x3S2(*param.Input(), *param.Filter(),
+                               param.Paddings(), param.Output());
+      math::AddChannelWise(param.Output(), param.Bias(),
+                           param.Output());
+      break;
+#ifndef __aarch64__
+    case ConvParam::EXEC_DEPTHWISE5x5_FLOAT:
+      DepthwiseConv5x5(param);
+      math::AddChannelWise(param.Output(), param.Bias(),
+                           param.Output());
+      break;
+    case ConvParam::EXEC_WINOGRAD3X3_FLOAT:
+      WinogradConv3x3<8, 3>(param);
+      math::AddChannelWise(param.Output(), param.Bias(),
+                           param.Output());
+      break;
+#endif  // __aarch64__
+    case ConvParam::EXEC_GEMM_FLOAT:
+      ConvAddBasic(param);
+      break;
+    default:
+      PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
+                                    param.ExecMode());
+  }
 }
 
 template class ConvAddKernel;
diff --git a/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp
index 054cfd4c45..4060c56312 100644
--- a/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp
+++ b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp
@@ -31,21 +31,14 @@ template <>
 void ConvAddReluKernel::Compute(
     const FusionConvAddReluParam &param) {
   switch (param.ExecMode()) {
-    case ConvParam::EXEC_DEPTHWISE3x3S1P1_FLOAT:
-      math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
-                                 param.Bias(), true, true);
-      break;
-    case ConvParam::EXEC_DEPTHWISE3x3S2P1_FLOAT:
-      math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
-                                   param.Output(), param.Bias(), true, true);
-      break;
-    case ConvParam::EXEC_DEPTHWISE3x3S2P0_FLOAT:
-      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
-                                 param.Bias(), true, true);
+    case ConvParam::EXEC_DEPTHWISE3x3S1_FLOAT:
+      math::DepthwiseConv3x3S1(*param.Input(), *param.Filter(),
+                               param.Paddings(), param.Output());
+      math::AddChannelWise(param.Output(), param.Bias(), param.Output());
       break;
-    case ConvParam::EXEC_DEPTHWISE3x3_FLOAT:
-      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
-                             param.Filter(), nullptr, param.Output(), false);
+    case ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT:
+      math::DepthwiseConv3x3S2(*param.Input(), *param.Filter(),
+                               param.Paddings(), param.Output());
       math::AddChannelWise(param.Output(), param.Bias(), param.Output());
       break;
 #ifndef __aarch64__
diff --git a/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp
index fb5cbb68e6..86a9eb2250 100644
--- a/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp
+++ b/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp
@@ -16,7 +16,8 @@ limitations under the License.
*/ #include "operators/kernel/conv_bn_add_relu_kernel.h" #include -#include "operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h" +#include "operators/kernel/arm/convolution/conv_common.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" namespace paddle_mobile { namespace operators { @@ -51,13 +52,46 @@ bool ConvBNAddReluKernel::Init( } param->SetNewScale(new_scale); param->SetNewBias(new_bias); + + InitBaseConvKernel(param); return true; } template <> void ConvBNAddReluKernel::Compute( const FusionConvBNAddReluParam ¶m) { - ConvBNAddReluCompute(param); + switch (param.ExecMode()) { + case ConvParam::EXEC_DEPTHWISE3x3S1_FLOAT: + math::DepthwiseConv3x3S1(*param.Input(), *param.Filter(), + param.Paddings(), param.Output()); + math::ScaleAddChannelWise(param.Output(), param.NewScale(), + param.NewBias(), param.Output()); + break; + case ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT: + math::DepthwiseConv3x3S2(*param.Input(), *param.Filter(), + param.Paddings(), param.Output()); + math::ScaleAddChannelWise(param.Output(), param.NewScale(), + param.NewBias(), param.Output()); + break; +#ifndef __aarch64__ + case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: + DepthwiseConv5x5(param); + math::ScaleAddChannelWise(param.Output(), param.NewScale(), + param.NewBias(), param.Output()); + break; + case ConvParam::EXEC_WINOGRAD3X3_FLOAT: + WinogradConv3x3<8, 3>(param); + math::ScaleAddChannelWise(param.Output(), param.NewScale(), + param.NewBias(), param.Output()); + break; +#endif // __aarch64__ + case ConvParam::EXEC_GEMM_FLOAT: + ConvBNReluBasic>(param); + break; + default: + PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d", + param.ExecMode()); + } } template class ConvBNAddReluKernel; diff --git a/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp index 738e85b01b..0de0704884 100644 --- a/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp @@ -60,25 +60,15 @@ template <> void ConvBNReluKernel::Compute( const FusionConvBNReluParam ¶m) { switch (param.ExecMode()) { - case ConvParam::EXEC_DEPTHWISE3x3S1P1_FLOAT: - math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), - param.Output(), param.NewScale(), - param.NewBias(), true); - break; - case ConvParam::EXEC_DEPTHWISE3x3S2P1_FLOAT: - math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(), - param.Output(), param.NewScale(), - param.NewBias(), true); - break; - case ConvParam::EXEC_DEPTHWISE3x3S2P0_FLOAT: - math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(), - nullptr, false, false); + case ConvParam::EXEC_DEPTHWISE3x3S1_FLOAT: + math::DepthwiseConv3x3S1(*param.Input(), *param.Filter(), + param.Paddings(), param.Output()); math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; - case ConvParam::EXEC_DEPTHWISE3x3_FLOAT: - math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), - param.Filter(), nullptr, param.Output(), false); + case ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT: + math::DepthwiseConv3x3S1(*param.Input(), *param.Filter(), + param.Paddings(), param.Output()); math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; diff --git a/src/operators/kernel/arm/convolution/conv_common.cpp b/src/operators/kernel/arm/convolution/conv_common.cpp index e070db0ecc..3ff08eae0d 100644 --- 
a/src/operators/kernel/arm/convolution/conv_common.cpp +++ b/src/operators/kernel/arm/convolution/conv_common.cpp @@ -44,29 +44,25 @@ void InitBaseConvKernel(ConvParam *param) { #endif // __aarch64__ } else { if (depth3x3 && param->Strides()[0] == param->Strides()[1] && - param->Strides()[0] == 1 && param->Paddings()[0] == 1 && - param->Paddings()[0] == param->Paddings()[1]) { - param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S1P1_FLOAT; + param->Strides()[0] == 1) { + param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S1_FLOAT; } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] && - param->Strides()[0] == 2 && param->Paddings()[0] == 0 && - param->Paddings()[0] == param->Paddings()[1]) { - param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S2P0_FLOAT; - } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] && - param->Strides()[0] == 2 && param->Paddings()[0] == 1 && - param->Paddings()[0] == param->Paddings()[1]) { - param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S2P1_FLOAT; - } else if (depth3x3) { - param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3_FLOAT; + param->Strides()[0] == 2) { + param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT; #ifndef __aarch64__ } else if (depth5x5 && param->Strides()[0] == param->Strides()[1] && param->Strides()[0] == 1) { param->ExecMode() = ConvParam::EXEC_DEPTHWISE5x5_FLOAT; - } else if (conv3x3 && param->Strides()[0] == param->Strides()[1] && + } else if (conv3x3 && !depth3x3 && + param->Strides()[0] == param->Strides()[1] && param->Dilations()[0] == param->Dilations()[1] && - param->Strides()[0] == 1 && param->Dilations()[0] == 1 /* && - param->Output()->dims()[1] >= 16 && + param->Strides()[0] == 1 && param->Dilations()[0] == 1 +#if 0 + && param->Output()->dims()[1] >= 16 && param->Input()->dims()[1] >= 16 && - param->Input()->dims()[2] <= 140 */ /* refered from ncnn */) { + param->Input()->dims()[2] <= 140 */ /* refered from ncnn */ +#endif + ) { param->ExecMode() = ConvParam::EXEC_WINOGRAD3X3_FLOAT; // transform weight param->transformed_filter_ = new framework::LoDTensor; diff --git a/src/operators/kernel/arm/convolution/conv_kernel.cpp b/src/operators/kernel/arm/convolution/conv_kernel.cpp index 45776de8b8..59026be4d6 100644 --- a/src/operators/kernel/arm/convolution/conv_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_kernel.cpp @@ -18,6 +18,8 @@ limitations under the License. 
*/ #include "operators/kernel/arm/convolution/conv_common.h" #include "operators/kernel/central-arm-func/conv_arm_func.h" +#include + namespace paddle_mobile { namespace operators { @@ -41,21 +43,13 @@ void ConvKernel::Compute(const ConvParam ¶m) { DepthwiseConv5x5(param); break; #endif // __aarch64__ - case ConvParam::EXEC_DEPTHWISE3x3S1P1_FLOAT: - math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), - nullptr, false, false); - break; - case ConvParam::EXEC_DEPTHWISE3x3S2P1_FLOAT: - math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), - param.Output(), nullptr, false, false); - break; - case ConvParam::EXEC_DEPTHWISE3x3S2P0_FLOAT: - math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(), - nullptr, false, false); + case ConvParam::EXEC_DEPTHWISE3x3S1_FLOAT: + math::DepthwiseConv3x3S1(*param.Input(), *param.Filter(), + param.Paddings(), param.Output()); break; - case ConvParam::EXEC_DEPTHWISE3x3_FLOAT: - math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), - param.Filter(), nullptr, param.Output(), false); + case ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT: + math::DepthwiseConv3x3S2(*param.Input(), *param.Filter(), + param.Paddings(), param.Output()); break; #ifndef __aarch64__ case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: diff --git a/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp index f600d8af0c..9b5f87b1a7 100644 --- a/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp @@ -60,25 +60,15 @@ template <> void DWConvBNReluKernel::Compute( const FusionDWConvBNReluParam ¶m) { switch (param.ExecMode()) { - case ConvParam::EXEC_DEPTHWISE3x3S1P1_FLOAT: - math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), - param.Output(), param.NewScale(), - param.NewBias(), true); - break; - case ConvParam::EXEC_DEPTHWISE3x3S2P1_FLOAT: - math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(), - param.Output(), param.NewScale(), - param.NewBias(), true); - break; - case ConvParam::EXEC_DEPTHWISE3x3S2P0_FLOAT: - math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(), - nullptr, false, false); + case ConvParam::EXEC_DEPTHWISE3x3S1_FLOAT: + math::DepthwiseConv3x3S1(*param.Input(), *param.Filter(), + param.Paddings(), param.Output()); math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; - case ConvParam::EXEC_DEPTHWISE3x3_FLOAT: - math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), - param.Filter(), nullptr, param.Output(), false); + case ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT: + math::DepthwiseConv3x3S2(*param.Input(), *param.Filter(), + param.Paddings(), param.Output()); math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h index d6aa5052dd..24b35229b3 100644 --- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h @@ -115,35 +115,6 @@ void ConvAddBasic(const FusionConvAddParam ¶m) { } } -template -void ConvAddCompute(const FusionConvAddParam ¶m) { - param.Output()->mutable_data(); - if (param.Groups() == param.Input()->dims()[1] && - param.Input()->dims()[1] == param.Output()->dims()[1] && - param.Filter()->dims()[2] == 
param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { - math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), - param.Bias(), true, false); - } else if (param.Groups() == param.Input()->dims()[1] && - param.Input()->dims()[1] == param.Output()->dims()[1] && - param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { - // math::DepthwiseConv3x3(param.Input(), param.Strides(), - // param.Paddings(), - // param.Filter(), param.Bias(), - // param.Output(), false); - if (param.Paddings()[0] == 0) { - math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(), - param.Bias(), true, false); - } else { - math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), - param.Output(), param.Bias(), true, false); - } - } else { - ConvAddBasic(param); - } -} - } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h index caaf467141..9e32d20291 100644 --- a/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h @@ -115,31 +115,6 @@ void ConvBNAddReluBasic(const FusionConvBNAddReluParam ¶m) { } } } -template -void ConvBNAddReluCompute(const FusionConvBNAddReluParam ¶m) { - Tensor Bias; - Bias.mutable_data({param.Groups()}); - if (param.Groups() == param.Input()->dims()[1] && - param.Input()->dims()[1] == param.Output()->dims()[1] && - param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { - math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), - param.Output(), param.NewScale(), - param.NewBias(), true); - } else if (param.Groups() == param.Input()->dims()[1] && - param.Input()->dims()[1] == param.Output()->dims()[1] && - param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { - // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), - // param.Output(), param.NewScale(), - // param.NewBias(), 1); - math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(), - param.Output(), param.NewScale(), - param.NewBias(), true); - } else { - ConvBNAddReluBasic(param); - } -} } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h b/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h index 34e9e120ae..33ceefadd8 100644 --- a/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h @@ -99,7 +99,6 @@ void ConvTransposeCompute(const ConvTransposeParam ¶m) { std::vector{paddings[0], paddings[1], paddings[0], paddings[1]}, &out_slice); - } else if (data_dim == 3U) { col2vol(col, dilations, strides, paddings, &out_slice); } diff --git a/src/operators/math/depthwise_conv3x3.cpp b/src/operators/math/depthwise_conv3x3.cpp index 6bdfed0b1f..fe571918ba 100644 --- a/src/operators/math/depthwise_conv3x3.cpp +++ b/src/operators/math/depthwise_conv3x3.cpp @@ -12,2066 +12,1042 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + #include "operators/math/depthwise_conv3x3.h" -#include -#if __ARM_NEON #include -#endif namespace paddle_mobile { namespace operators { namespace math { -void DepthwiseConv3x3(const framework::Tensor *input, - const std::vector &strides, - const std::vector &paddings, - const framework::Tensor *filter, framework::Tensor *bias, - framework::Tensor *output, bool if_bias) { - const int batch_size = input->dims()[0]; - - const int input_height = input->dims()[2]; - - const int input_width = input->dims()[3]; - - const int output_channels = output->dims()[1]; - - const int output_height = output->dims()[2]; - const int output_width = output->dims()[3]; - const int _kernel_size = 3; - const int stride_height = strides[0]; - const int stride_width = strides[1]; - const int padding_height = paddings[0]; - const int padding_width = paddings[1]; - const float zero = 0; - const int input_channel_stride = input_height * input_width; - const int output_channel_stride = output_height * output_width; - const int filter_channel_stride = 9; - - const float *input_ptr = input->data(); - const float *filter_ptr = filter->data(); - if (if_bias) { - math::expand_bias(*bias, 1, output->dims()); - output->ShareDataWith(*bias); - } - float *output_ptr = output->mutable_data(); - - const float *pos1, *pos2, *pos3, *filter1, *filter2, *filter3, *output_ptr2; - int hstart, wstart, hend, wend; - float result; - for (int i = 0; i < batch_size; ++i) { -#pragma omp parallel for - for (int c = 0; c < output_channels; ++c) { - const float *input_data = - input_ptr + (i * output_channels + c) * input_channel_stride; - float *output_data = - output_ptr + (i * output_channels + c) * output_channel_stride; - filter1 = filter_ptr + c * filter_channel_stride; - filter2 = filter1 + 3; - filter3 = filter2 + 3; - for (int ph = 0; ph < output_height; ph++) { - for (int pw = 0; pw < output_width; pw++) { - hstart = ph * stride_height - padding_height; - wstart = pw * stride_width - padding_width; - hend = std::min(hstart + _kernel_size, input_height + padding_height); - wend = std::min(wstart + _kernel_size, input_width + padding_width); - hstart = std::max(hstart, 0); - wstart = std::max(wstart, 0); - hend = std::min(hend, input_height); - wend = std::min(wend, input_width); - pos1 = input_data + hstart * input_width + wstart; - pos2 = input_data + (hstart + 1) * input_width + wstart; - pos3 = input_data + (hstart + 2) * input_width + wstart; - output_ptr2 = output_data + ph * output_width + pw; - - if (hend - hstart != 3 || wend - wstart != 3) { - result = 0; - float fake_input[9] = {0}; - if (hstart == 0 && wstart == 0) { - // 左上角 - for (int j = 0; j < 3; ++j) { - for (int k = 0; k < 3; ++k) { - if (j >= 3 - hend && k >= 3 - wend) { - fake_input[3 * j + k] = - input_data[(j - (3 - hend)) * input_width + k - - (3 - wend)]; - } - } - } - } else if (hstart == 0 && wend == input_width) { - // 右上角 - for (int j = 0; j < 3; ++j) { - for (int k = 0; k < 3; ++k) { - if (j >= 3 - hend && k <= input_width - wstart - 1) { - fake_input[3 * j + k] = - input_data[(j - (3 - hend)) * input_width + k + wstart]; - } - } - } - - } else if (hend == input_height && wstart == 0) { - // 左下角 - - for (int j = 0; j < 3; ++j) { - for (int k = 0; k < 3; ++k) { - if (j <= input_height - 1 - hstart && k >= 3 - wend) { - fake_input[3 * j + k] = - input_data[(j + hstart) * input_width + k - (3 - wend)]; - } - } - } - } else if (hend == input_height && wend == input_width) { - // 右下角 - for (int j = 0; j < 3; 
++j) { - for (int k = 0; k < 3; ++k) { - if (j <= input_height - hstart - 1 && - k <= input_width - wstart - 1) { - fake_input[3 * j + k] = - input_data[(j + hstart) * input_width + k + wstart]; - } - } - } - } else if (hstart == 0) { - // 顶部 - for (int j = 0; j < 3; ++j) { - for (int k = 0; k < 3; ++k) { - if (j >= 3 - hend) { - fake_input[3 * j + k] = - input_data[(j - (3 - hend)) * input_width + k + wstart]; - } - } - } - - } else if (hend == input_height) { - // 底部 - for (int j = 0; j < 3; ++j) { - for (int k = 0; k < 3; ++k) { - if (j <= input_height - hstart - 1) { - fake_input[3 * j + k] = - input_data[(j + hstart) * input_width + k + wstart]; - } - } - } - - } else if (wstart == 0) { - // 左侧 - for (int j = 0; j < 3; ++j) { - for (int k = 0; k < 3; ++k) { - if (k >= 3 - wend) { - fake_input[3 * j + k] = - input_data[(j + hstart) * input_width + - (k - (3 - wend))]; - } - } - } - - } else if (wend == input_width) { - // 右侧 - for (int j = 0; j < 3; ++j) { - for (int k = 0; k < 3; ++k) { - if (k <= input_width - wstart - 1) { - fake_input[3 * j + k] = - input_data[(j + hstart) * input_width + k + wstart]; - } - } - } - } - for (int l = 0; l < 9; ++l) { - result += fake_input[l] * filter1[l]; - } - if (if_bias) { - output_data[ph * output_width + pw] += result; - } else { - output_data[ph * output_width + pw] = result; - } - - } else { -#if __ARM_NEON -#if __aarch64__ - const float32x4_t data1 = vld1q_f32(pos1); - const float32x4_t data2 = vld1q_f32(pos2); - const float32x4_t data3 = vld1q_f32(pos3); - - const float32x4_t v_filter1 = vld1q_f32(filter1); - const float32x4_t v_filter2 = vld1q_f32(filter2); - const float32x4_t v_filter3 = vld1q_f32(filter3); - float32x4_t mula = vmulq_f32(data1, v_filter1); - mula = vmlaq_f32(mula, data2, v_filter2); - mula = vmlaq_f32(mula, data3, v_filter3); - float32x2_t res = vpadd_f32( - vget_high_f32(vsetq_lane_f32(0, mula, 3)), vget_low_f32(mula)); - res = vpadd_f32(res, res); - if (if_bias) { - output_data[ph * output_width + pw] += vget_lane_f32(res, 0); - } else { - output_data[ph * output_width + pw] = vget_lane_f32(res, 0); - } -#else - asm volatile( - - "vld1.32 {q1}, [%[pos1]] \n\t" - "vld1.32 {q4}, [%[filter1]] \n\t" - "vmov.f32 q0, #0.0 \n\t" - - "vld1.32 {q2}, [%[pos2]] \n\t" - "vld1.32 {q5}, [%[filter2]] \n\t" - "vmla.f32 q0, q1, q4 \n\t" - - "vld1.32 {q3}, [%[pos3]] \n\t" - "vld1.32 {q6}, [%[filter3]] \n\t" - - "vmla.f32 q0, q2, q5 \n\t" - "vmla.f32 q0, q3, q6 \n\t" - - "vmov.f32 d1[1], %[zero] \n\t" - - "vadd.f32 d4, d0, d1 \n\t" - "vadd.f32 s10, s8, s9 \n\t" - "vst1.32 {d5[0]},[%[output_ptr]] \n\t" - : - : [input_data] "r"(input_data), [pos1] "r"(pos1), - [pos2] "r"(pos2), [pos3] "r"(pos3), [filter1] "r"(filter1), - [filter2] "r"(filter2), [filter3] "r"(filter3), - [output_ptr] "r"(output_ptr2), [zero] "r"(zero) - : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); -#endif // __aarch64__ -#else - -#endif // __ARM_NEON - } - } - } - } - } +#ifndef __aarch64__ +inline float32x4_t vpaddq_f32(float32x4_t r0, float32x4_t r1) { + float32x2_t sum0 = vpadd_f32(vget_low_f32(r0), vget_high_f32(r0)); + float32x2_t sum1 = vpadd_f32(vget_low_f32(r1), vget_high_f32(r1)); + return vcombine_f32(sum0, sum1); } +#endif -void DepthwiseConv3x3s1p1(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor *bias, - bool if_bias, bool if_relu) { -#if __ARM_NEON - const int batch_size = static_cast(input->dims()[0]); - const int c = static_cast(input->dims()[1]); - const int h = 
static_cast(input->dims()[2]); - const int w = static_cast(input->dims()[3]); - const int hxw = h * w; - // const int l = h; - - // leftTop, rightTop, leftBottom, rightBottom - const int lt = 0; - const int rt = w - 1; - const int lb = (h - 1) * w; - const int rb = h * w - 1; - - const float *bias_data; - if (if_bias) { - bias_data = bias->data(); - } - float32x4_t zero = vdupq_n_f32(0.0); - - for (int b = 0; b < batch_size; ++b) { -#pragma omp parallel for - for (int j = 0; j < c; ++j) { - const float *filter_data_tmp = filter->data() + j * 9; - const float *input_data = input->data() + j * hxw; - float *output_data = output->mutable_data() + j * hxw; - float32x4_t vbias; - if (if_bias) { - vbias = vdupq_n_f32(bias_data[j]); - } - - int w_mid = w - 2; // l=1->l_mid=-1,l=2->l_mid=0 - float w00 = filter_data_tmp[0]; - float w01 = filter_data_tmp[1]; - float w02 = filter_data_tmp[2]; - float w10 = filter_data_tmp[3]; - float w11 = filter_data_tmp[4]; - float w12 = filter_data_tmp[5]; - float w20 = filter_data_tmp[6]; - float w21 = filter_data_tmp[7]; - float w22 = filter_data_tmp[8]; - - output_data[lt] = w11 * input_data[0] + w12 * input_data[1] + - w21 * input_data[w] + w22 * input_data[w + 1]; - output_data[rt] = w10 * input_data[w - 2] + w11 * input_data[w - 1] + - w20 * input_data[2 * w - 2] + - w21 * input_data[2 * w - 1]; - output_data[lb] = - w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] + - w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1]; - output_data[rb] = - w00 * input_data[h * w - w - 2] + w01 * input_data[h * w - w - 1] + - w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1]; - if (if_bias) { - output_data[lt] += bias_data[j]; - output_data[rt] += bias_data[j]; - output_data[lb] += bias_data[j]; - output_data[rb] += bias_data[j]; - } - if (if_relu) { - output_data[lt] = output_data[lt] < 0 ? 0 : output_data[lt]; - output_data[rt] = output_data[rt] < 0 ? 0 : output_data[rt]; - output_data[lb] = output_data[lb] < 0 ? 0 : output_data[lb]; - output_data[rb] = output_data[rb] < 0 ? 0 : output_data[rb]; - } - - for (int i = 1; i < h - 1; ++i) { - int left = i * w; - int right = i * w + w - 1; - output_data[left] = - w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] + - w11 * input_data[i * w] + w12 * input_data[i * w + 1] + - w21 * input_data[i * w + w] + w22 * input_data[i * w + w + 1]; - - output_data[right] = w00 * input_data[i * w + w - 1 - w - 1] + - w01 * input_data[i * w + w - 1 - w] + - w10 * input_data[i * w + w - 1 - 1] + - w11 * input_data[i * w + w - 1] + - w20 * input_data[i * w + w - 1 + w - 1] + - w21 * input_data[i * w + w - 1 + w]; - if (if_bias) { - output_data[left] += bias_data[j]; - output_data[right] += bias_data[j]; - } - if (if_relu) { - output_data[left] = output_data[left] < 0 ? 0 : output_data[left]; - output_data[right] = output_data[right] < 0 ? 
0 : output_data[right]; - } - } - - // top 1 row and bottom 1 row - const float *input_tmp = input_data; - - float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, - tmp3, tmp4, tmp5, out0; - in0 = vld1q_f32(input_tmp); - in2 = vld1q_f32(input_tmp + w); - const float *input_tmp_end = input_tmp + (h - 2) * w; - in4 = vld1q_f32(input_tmp_end); - in6 = vld1q_f32(input_tmp_end + w); - int c_mid = w_mid; - auto output_ptr = output_data + 1; - for (; c_mid > 3; c_mid -= 4) { - in1 = vld1q_f32(input_tmp + 4); - in3 = vld1q_f32(input_tmp + w + 4); - - tmp0 = vextq_f32(in0, in1, 1); - tmp1 = vextq_f32(in0, in1, 2); - - tmp2 = vextq_f32(in2, in3, 1); - tmp3 = vextq_f32(in2, in3, 2); - - out0 = vmulq_n_f32(in0, w10); - out0 = vmlaq_n_f32(out0, tmp0, w11); - out0 = vmlaq_n_f32(out0, tmp1, w12); - out0 = vmlaq_n_f32(out0, in2, w20); - out0 = vmlaq_n_f32(out0, tmp2, w21); - out0 = vmlaq_n_f32(out0, tmp3, w22); - out0 = vaddq_f32(out0, vbias); - if (if_relu) { - out0 = vmaxq_f32(out0, zero); - } - vst1q_f32(output_ptr, out0); - - in5 = vld1q_f32(input_tmp_end + 4); - in7 = vld1q_f32(input_tmp_end + w + 4); - - tmp0 = vextq_f32(in4, in5, 1); - tmp1 = vextq_f32(in4, in5, 2); - tmp2 = vextq_f32(in6, in7, 1); - tmp3 = vextq_f32(in6, in7, 2); - - out0 = vmulq_n_f32(in4, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in6, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vaddq_f32(out0, vbias); - if (if_relu) { - out0 = vmaxq_f32(out0, zero); - } - vst1q_f32(output_ptr + (h - 1) * w, out0); - - // can optimize to each 8 stride. - input_tmp += 4; - input_tmp_end += 4; - output_ptr += 4; - in0 = in1; - in2 = in3; - in4 = in5; - in6 = in7; - } - - // top right pad - float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]); - - tmp0 = vextq_f32(in0, pad0, 1); - tmp1 = vextq_f32(in0, pad0, 2); - tmp2 = vextq_f32(in2, pad1, 1); - tmp3 = vextq_f32(in2, pad1, 2); - - out0 = vmulq_n_f32(in0, w10); - out0 = vmlaq_n_f32(out0, tmp0, w11); - out0 = vmlaq_n_f32(out0, tmp1, w12); - out0 = vmlaq_n_f32(out0, in2, w20); - out0 = vmlaq_n_f32(out0, tmp2, w21); - out0 = vmlaq_n_f32(out0, tmp3, w22); - out0 = vaddq_f32(out0, vbias); - if (if_relu) { - out0 = vmaxq_f32(out0, zero); - } - - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + i, out0, 1); - } - if (i == 2) { - vst1q_lane_f32(output_ptr + i, out0, 2); - } - } - - // bottom right pad - float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]); - float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]); - - tmp0 = vextq_f32(in4, pad2, 1); - tmp1 = vextq_f32(in4, pad2, 2); - tmp2 = vextq_f32(in6, pad3, 1); - tmp3 = vextq_f32(in6, pad3, 2); - - out0 = vmulq_n_f32(in4, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in6, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vaddq_f32(out0, vbias); - if (if_relu) { - out0 = vmaxq_f32(out0, zero); - } - - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1); - } - if (i == 2) { - vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2); - } - } - // mid - - for (int i = 0; i < h - 2; ++i) { - auto output_ptr = output_data + (i + 1) * w + 1; - 
input_tmp = input_data + i * w; - auto in0_tmp = vld1q_f32(input_tmp); - auto in2_tmp = vld1q_f32(input_tmp + w); - auto in4_tmp = vld1q_f32(input_tmp + w + w); - c_mid = w_mid; - for (; c_mid > 3; c_mid -= 4) { - auto in1_tmp = vld1q_f32(input_tmp + 4); - auto in3_tmp = vld1q_f32(input_tmp + w + 4); - auto in5_tmp = vld1q_f32(input_tmp + w + w + 4); - - tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); - tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); - tmp2 = vextq_f32(in2_tmp, in3_tmp, 1); - tmp3 = vextq_f32(in2_tmp, in3_tmp, 2); - tmp4 = vextq_f32(in4_tmp, in5_tmp, 1); - tmp5 = vextq_f32(in4_tmp, in5_tmp, 2); - - out0 = vmulq_n_f32(in0_tmp, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2_tmp, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4_tmp, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, w22); - out0 = vaddq_f32(out0, vbias); - if (if_relu) { - out0 = vmaxq_f32(out0, zero); - } - - vst1q_f32(output_ptr, out0); +template +inline void Depth3x3NormalRowLoadInput(const float *input, float32x4_t *y) { + y[0] = vld1q_f32(input); + y[2] = vld1q_f32(input + 4); + y[1] = vextq_f32(y[0], y[2], 1); + y[2] = vextq_f32(y[0], y[2], 2); +} - output_ptr += 4; - input_tmp += 4; - in0_tmp = in1_tmp; - in2_tmp = in3_tmp; - in4_tmp = in5_tmp; - } +template <> +inline void Depth3x3NormalRowLoadInput<2>(const float *input, float32x4_t *y) { + float32x4x2_t x = vld2q_f32(input); + y[0] = x.val[0]; + y[1] = x.val[1]; + y[2] = vextq_f32(y[0], y[0], 1); + y[2] = vsetq_lane_f32(input[8], y[2], 3); +} - float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]); - float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]); - - tmp0 = vextq_f32(in0_tmp, pad0, 1); - tmp1 = vextq_f32(in0_tmp, pad0, 2); - tmp2 = vextq_f32(in2_tmp, pad1, 1); - tmp3 = vextq_f32(in2_tmp, pad1, 2); - tmp4 = vextq_f32(in4_tmp, pad2, 1); - tmp5 = vextq_f32(in4_tmp, pad2, 2); - - out0 = vmulq_n_f32(in0_tmp, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2_tmp, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4_tmp, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, w22); - out0 = vaddq_f32(out0, vbias); - if (if_relu) { - out0 = vmaxq_f32(out0, zero); - } +#define DEPTHWISE_CONV3X3_NORMAL_BORDER(start, end) \ + for (int w = start; w < end; ++w) { \ + const int w_in_start = -padding_w + w * Stride_w; \ + const int w_in_end = w_in_start + 3; \ + const int w_start = w_in_start > 0 ? w_in_start : 0; \ + const int w_end = w_in_end < input_w ? 
w_in_end : input_w; \ + float value = 0; \ + for (int h_in = h_start; h_in < h_end; ++h_in) { \ + for (int w_in = w_start; w_in < w_end; ++w_in) { \ + value += filter[(h_in - h_in_start) * 3 + (w_in - w_in_start)] * \ + input[h_in * input_w + w_in]; \ + } \ + } \ + output_ptr[w] = value; \ + } - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + i, out0, 1); - } - if (i == 2) { - vst1q_lane_f32(output_ptr + i, out0, 2); - } - } - } +template +inline void DepthwiseConv3x3NormalRow(const float *input, const float *filter, + const int h_output, const int input_h, + const int input_w, const int padding_h, + const int padding_w, const int output_w, + float *output, float32x4_t *ker) { + const int h_in_start = -padding_h + h_output * Stride_h; + const int h_in_end = h_in_start + 3; + const int h_start = h_in_start > 0 ? h_in_start : 0; + const int h_end = h_in_end < input_h ? h_in_end : input_h; + + const int valid_w_start = (padding_w + Stride_w - 1) / Stride_w; + const int valid_w_end = (input_w + padding_w - 3) / Stride_w + 1; + // const int valid_w_end = output_w - valid_w_start; + float *output_ptr = output + h_output * output_w; + // border left + DEPTHWISE_CONV3X3_NORMAL_BORDER(0, valid_w_start) + // middle + int output_tiles = (valid_w_end - valid_w_start) >> 2; + float32x4_t _sum, _x[3]; + // valid w + for (int w = 0; w < output_tiles * 4; w += 4) { + _sum = vdupq_n_f32(0.f); + int output_offset = valid_w_start + w; + int input_w_offset = output_offset * Stride_w - padding_w; + for (int h_in = h_start; h_in < h_end; ++h_in) { + int index = h_in - h_in_start; + Depth3x3NormalRowLoadInput( + input + h_in * input_w + input_w_offset, _x); + _sum = vmlaq_lane_f32(_sum, _x[0], vget_low_f32(ker[index]), 0); + _sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 1); + _sum = vmlaq_lane_f32(_sum, _x[2], vget_high_f32(ker[index]), 0); } + vst1q_f32(output_ptr + output_offset, _sum); } -#endif + // remain valid w + int remain = (valid_w_end - valid_w_start) & 0x3; + if (remain > 0) { + _sum = vdupq_n_f32(0.f); + int remain_start = valid_w_start + (output_tiles << 2); + int input_w_offset = remain_start * Stride_w - padding_w; + float *output_ptr0 = output_ptr + remain_start; + + for (int h_in = h_start; h_in < h_end; ++h_in) { + int index = h_in - h_in_start; + Depth3x3NormalRowLoadInput( + input + h_in * input_w + input_w_offset, _x); + _sum = vmlaq_lane_f32(_sum, _x[0], vget_low_f32(ker[index]), 0); + _sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 1); + _sum = vmlaq_lane_f32(_sum, _x[2], vget_high_f32(ker[index]), 0); + } + switch (remain) { + case 3: + vst1q_lane_f32(output_ptr0 + 2, _sum, 2); + case 2: + vst1_f32(output_ptr0, vget_low_f32(_sum)); + break; + case 1: + vst1_lane_f32(output_ptr0, vget_low_f32(_sum), 0); + break; + } + } + // border right + DEPTHWISE_CONV3X3_NORMAL_BORDER(valid_w_end, output_w) } -void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, - const framework::Tensor *new_scale, - const framework::Tensor *new_bias, - bool if_relu) { -#if __ARM_NEON - const float *input_data = input->data(); - const float *filter_data = filter->data(); - float *output_data = output->mutable_data(); - const float *newscale_data = new_scale->data(); - const float *newbias_data = new_bias->data(); - - const int batch_size = static_cast(input->dims()[0]); - const int input_channel = 
static_cast(input->dims()[1]); - - const int input_height = static_cast(input->dims()[2]); - const int input_width = static_cast(input->dims()[3]); - const int output_height = static_cast(output->dims()[2]); - const int output_width = static_cast(output->dims()[3]); - - const int hxw = input_height * input_width; - - // const int l = input_height; - const int h = input_height; - const int w = input_width; - float32x4_t vzero = vdupq_n_f32(0); - - for (int b = 0; b < batch_size; b++) { -#pragma omp parallel for - for (int c = 0; c < input_channel; c++) { - const float *filter_data = filter->data() + c * 9; - const float *input_data = input->data() + c * hxw; - float *output_data = output->data() + c * hxw; - float32x4_t vnewbias = vdupq_n_f32(newbias_data[c]); - float32x4_t vnewscale = vdupq_n_f32(newscale_data[c]); - - float w00 = filter_data[0]; - float w01 = filter_data[1]; - float w02 = filter_data[2]; - float w10 = filter_data[3]; - float w11 = filter_data[4]; - float w12 = filter_data[5]; - float w20 = filter_data[6]; - float w21 = filter_data[7]; - float w22 = filter_data[8]; - - for (int i = 1; i < output_height - 1; i++) { - float *output_ptr; - float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3, tmp4, - tmp5, out0; - for (int m = 1; m < output_width - 4; m += 4) { - output_ptr = output_data + i * output_width + m; - in0 = vld1q_f32(input_data + (i - 1) * input_width + m - 1); - in1 = vld1q_f32(input_data + (i - 1) * input_width + m + 3); - in2 = vld1q_f32(input_data + i * input_width + m - 1); - in3 = vld1q_f32(input_data + i * input_width + m + 3); - in4 = vld1q_f32(input_data + (i + 1) * input_width + m - 1); - in5 = vld1q_f32(input_data + (i + 1) * input_width + m + 3); - - tmp0 = vextq_f32(in0, in1, 1); - tmp1 = vextq_f32(in0, in1, 2); - tmp2 = vextq_f32(in2, in3, 1); - tmp3 = vextq_f32(in2, in3, 2); - tmp4 = vextq_f32(in4, in5, 1); - tmp5 = vextq_f32(in4, in5, 2); - - out0 = vmulq_n_f32(in0, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, w22); - - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - vst1q_f32(output_ptr, out0); - } - int m; - for (m = 1; (m + 3) < output_width - 1; m = m + 4) { - } +template <> +void DepthwiseConv3x3S1(const framework::Tensor &input, + const framework::Tensor &filter, + const std::vector &paddings, + framework::Tensor *output) { + const float *input_data = input.data(); + const float *filter_data = filter.data(); + float *out_data = output->mutable_data(); + int input_h = input.dims()[2]; + int input_w = input.dims()[3]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + int padding_h = paddings[0]; + int padding_w = paddings[1]; + int image_size = input_h * input_w; + int out_image_size = output_h * output_w; + int valid_h_start = padding_h; + int valid_h_end = output_h - valid_h_start; + int valid_h = valid_h_end - valid_h_start; + int valid_w_start = padding_w; + int valid_w_end = output_w - valid_w_start; + int valid_w = valid_w_end - valid_w_start; + + #pragma omp parallel for + for (int g = 0; g < input.dims()[1]; ++g) { + const float *input_ptr = input_data + g * image_size; + const float *filter_ptr = filter_data + g * 9; + float *output_ptr = out_data + g * out_image_size; + + const float 
*filter_ptr0 = filter_ptr; + const float *filter_ptr1 = filter_ptr0 + 3; + const float *filter_ptr2 = filter_ptr1 + 3; + float32x4_t _ker[3]; + _ker[0] = vld1q_f32(filter_ptr0); + _ker[1] = vld1q_f32(filter_ptr1); + _ker[2] = vld1q_f32(filter_ptr2); + + // pad top + for (int h = 0; h < valid_h_start; ++h) { + DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr, _ker); + } - for (int j = m; j < output_width - 1; j++) { - output_data[i * output_width + j] = - input_data[(i - 1) * input_width + j - 1] * w00 + - input_data[(i - 1) * input_width + j] * w01 + - input_data[(i - 1) * input_width + j + 1] * w02 + - input_data[(i)*input_width + j - 1] * w10 + - input_data[(i)*input_width + j] * w11 + - input_data[(i)*input_width + j + 1] * w12 + - input_data[(i + 1) * input_width + j - 1] * w20 + - input_data[(i + 1) * input_width + j] * w21 + - input_data[(i + 1) * input_width + j + 1] * w22; - output_data[i * output_width + j] = - newscale_data[c] * output_data[i * output_width + j] + - newbias_data[c]; - if (if_relu) { - output_data[i * output_width + j] = - output_data[i * output_width + j] < 0 - ? 0 - : output_data[i * output_width + j]; + // output 2x6 + int output_w_tiles = valid_w / 6; + int output_w_remain = valid_w - output_w_tiles * 6; + for (int h = valid_h_start; h < valid_h_end - 1; h += 2) { + const float *input_ptr0 = input_ptr + (h - padding_h) * input_w; + const float *input_ptr1 = input_ptr0 + input_w; + const float *input_ptr2 = input_ptr1 + input_w; + const float *input_ptr3 = input_ptr2 + input_w; + float *output_ptr0 = output_ptr + h * output_w; + float *output_ptr1 = output_ptr0 + output_w; + // pad left + if (padding_w) { + float32x4_t row0 = vld1q_f32(input_ptr0); + float32x4_t row1 = vld1q_f32(input_ptr1); + float32x4_t row2 = vld1q_f32(input_ptr2); + float32x4_t row3 = vld1q_f32(input_ptr3); + float32x4_t zero = vdupq_n_f32(0.f); + row0 = vextq_f32(zero, row0, 3); + row1 = vextq_f32(zero, row1, 3); + row2 = vextq_f32(zero, row2, 3); + row3 = vextq_f32(zero, row3, 3); + float32x4_t acc0, acc1; + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w; + if (padding >= 3) { + output_ptr0[w] = 0.f; + output_ptr1[w] = 0.f; + } else { + acc0 = vmulq_f32(row0, _ker[0]); + acc0 = vmlaq_f32(acc0, row1, _ker[1]); + acc0 = vmlaq_f32(acc0, row2, _ker[2]); + acc0 = vextq_f32(acc0, acc0, 1); + acc1 = vmulq_f32(row1, _ker[0]); + acc1 = vmlaq_f32(acc1, row2, _ker[1]); + acc1 = vmlaq_f32(acc1, row3, _ker[2]); + acc1 = vextq_f32(acc1, acc1, 1); + float32x2_t sum = vpadd_f32(vget_low_f32(acc0), vget_low_f32(acc1)); + vst1_lane_f32(output_ptr0 + w, sum, 0); + vst1_lane_f32(output_ptr1 + w, sum, 1); + + row0 = vextq_f32(zero, row0, 3); + row1 = vextq_f32(zero, row1, 3); + row2 = vextq_f32(zero, row2, 3); + row3 = vextq_f32(zero, row3, 3); } } + output_ptr0 += valid_w_start; + output_ptr1 += valid_w_start; } - - output_data[0] = w11 * input_data[0] + w12 * input_data[1] + - w21 * input_data[w] + w22 * input_data[w + 1]; - output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w - 1] + - w20 * input_data[2 * w - 2] + - w21 * input_data[2 * w - 1]; - output_data[(h - 1) * w] = - w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] + - w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1]; - output_data[h * w - 1] = - w00 * input_data[h * w - w - 2] + w01 * input_data[h * w - w - 1] + - w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1]; - output_data[0] = 
output_data[0] * newscale_data[c] + newbias_data[c]; - output_data[w - 1] = - output_data[w - 1] * newscale_data[c] + newbias_data[c]; - output_data[(h - 1) * w] = - output_data[(h - 1) * w] * newscale_data[c] + newbias_data[c]; - output_data[h * w - 1] = - output_data[h * w - 1] * newscale_data[c] + newbias_data[c]; - - if (if_relu) { - output_data[0] = output_data[0] < 0 ? 0 : output_data[0]; - output_data[w - 1] = output_data[w - 1] < 0 ? 0 : output_data[w - 1]; - output_data[(h - 1) * w] = - output_data[(h - 1) * w] < 0 ? 0 : output_data[(h - 1) * w]; - output_data[h * w - 1] = - output_data[h * w - 1] < 0 ? 0 : output_data[h * w - 1]; - } - for (int i = 1; i < h - 1; ++i) { - output_data[i * w] = - w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] + - w11 * input_data[i * w] + w12 * input_data[i * w + 1] + - w21 * input_data[i * w + w] + w22 * input_data[i * w + w + 1]; - - output_data[i * w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] + - w01 * input_data[i * w + w - 1 - w] + - w10 * input_data[i * w + w - 1 - 1] + - w11 * input_data[i * w + w - 1] + - w20 * input_data[i * w + w - 1 + w - 1] + - w21 * input_data[i * w + w - 1 + w]; - output_data[i * w] = - output_data[i * w] * newscale_data[c] + newbias_data[c]; - output_data[i * w + w - 1] = - output_data[i * w + w - 1] * newscale_data[c] + newbias_data[c]; - - if (if_relu) { - output_data[i * w] = output_data[i * w] < 0 ? 0 : output_data[i * w]; - output_data[i * w + w - 1] = - output_data[i * w + w - 1] < 0 ? 0 : output_data[i * w + w - 1]; - } + // valid + float32x4_t _result0, _result1, _result2, _result3; + for (int loop = 0; loop < output_w_tiles; ++loop) { + float32x4_t _row00 = vld1q_f32(input_ptr0); + float32x4_t _row01 = vld1q_f32(input_ptr0 + 4); + float32x4_t _row10 = vld1q_f32(input_ptr1); + float32x4_t _row11 = vld1q_f32(input_ptr1 + 4); + + float32x4_t _ext01 = vextq_f32(_row00, _row01, 1); + float32x4_t _ext02 = vextq_f32(_row00, _row01, 2); + float32x4_t _ext03 = vextq_f32(_row01, _row01, 1); + float32x4_t _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0); + _result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0); + + _ext01 = vextq_f32(_row10, _row11, 1); + _ext02 = vextq_f32(_row10, _row11, 2); + _ext03 = vextq_f32(_row11, _row11, 1); + _ext04 = vextq_f32(_row11, _row11, 2); + + _result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0); + + _result2 = vmulq_lane_f32(_row10, vget_low_f32(_ker[0]), 0); + _result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[0]), 1); + _result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[0]), 0); + _result3 = vmulq_lane_f32(_row11, vget_low_f32(_ker[0]), 0); + _result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[0]), 1); + _result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[0]), 0); + + _row00 = 
vld1q_f32(input_ptr2); + _row01 = vld1q_f32(input_ptr2 + 4); + _row10 = vld1q_f32(input_ptr3); + _row11 = vld1q_f32(input_ptr3 + 4); + + _ext01 = vextq_f32(_row00, _row01, 1); + _ext02 = vextq_f32(_row00, _row01, 2); + _ext03 = vextq_f32(_row01, _row01, 1); + _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0); + + _result2 = vmlaq_lane_f32(_result2, _row00, vget_low_f32(_ker[1]), 0); + _result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[1]), 1); + _result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[1]), 0); + _result3 = vmlaq_lane_f32(_result3, _row01, vget_low_f32(_ker[1]), 0); + _result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[1]), 1); + _result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[1]), 0); + + _ext01 = vextq_f32(_row10, _row11, 1); + _ext02 = vextq_f32(_row10, _row11, 2); + _ext03 = vextq_f32(_row11, _row11, 1); + _ext04 = vextq_f32(_row11, _row11, 2); + + _result2 = vmlaq_lane_f32(_result2, _row10, vget_low_f32(_ker[2]), 0); + _result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[2]), 1); + _result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[2]), 0); + _result3 = vmlaq_lane_f32(_result3, _row11, vget_low_f32(_ker[2]), 0); + _result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[2]), 1); + _result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[2]), 0); + + vst1q_f32(output_ptr0, _result0); + vst1_f32(output_ptr0 + 4, vget_low_f32(_result1)); + vst1q_f32(output_ptr1, _result2); + vst1_f32(output_ptr1 + 4, vget_low_f32(_result3)); + + input_ptr0 += 6; + input_ptr1 += 6; + input_ptr2 += 6; + input_ptr3 += 6; + output_ptr0 += 6; + output_ptr1 += 6; } - - int m; - for (m = 1; m < output_width - 4; m += 4) { - float *output_ptr = output_data + m; - float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0; - in0 = vld1q_f32(input_data + m - 1); - in1 = vld1q_f32(input_data + m + 3); - in2 = vld1q_f32(input_data + input_width + m - 1); - in3 = vld1q_f32(input_data + input_width + m + 3); - tmp0 = vextq_f32(in0, in1, 1); - tmp1 = vextq_f32(in0, in1, 2); - tmp2 = vextq_f32(in2, in3, 1); - tmp3 = vextq_f32(in2, in3, 2); - out0 = vmulq_n_f32(in0, w10); - out0 = vmlaq_n_f32(out0, tmp0, w11); - out0 = vmlaq_n_f32(out0, tmp1, w12); - out0 = vmlaq_n_f32(out0, in2, w20); - out0 = vmlaq_n_f32(out0, tmp2, w21); - out0 = vmlaq_n_f32(out0, tmp3, w22); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); + // remain w + if (output_w_remain > 0) { + float32x4_t _row00 = vld1q_f32(input_ptr0); + float32x4_t _row01 = vld1q_f32(input_ptr0 + 4); + float32x4_t _row10 = vld1q_f32(input_ptr1); + float32x4_t _row11 = vld1q_f32(input_ptr1 + 4); + + float32x4_t _ext01 = vextq_f32(_row00, _row01, 1); + float32x4_t _ext02 = vextq_f32(_row00, _row01, 2); + float32x4_t _ext03 = vextq_f32(_row01, _row01, 1); + float32x4_t _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0); + 
_result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0); + + _ext01 = vextq_f32(_row10, _row11, 1); + _ext02 = vextq_f32(_row10, _row11, 2); + _ext03 = vextq_f32(_row11, _row11, 1); + _ext04 = vextq_f32(_row11, _row11, 2); + + _result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0); + + _result2 = vmulq_lane_f32(_row10, vget_low_f32(_ker[0]), 0); + _result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[0]), 1); + _result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[0]), 0); + _result3 = vmulq_lane_f32(_row11, vget_low_f32(_ker[0]), 0); + _result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[0]), 1); + _result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[0]), 0); + + _row00 = vld1q_f32(input_ptr2); + _row01 = vld1q_f32(input_ptr2 + 4); + _row10 = vld1q_f32(input_ptr3); + _row11 = vld1q_f32(input_ptr3 + 4); + + _ext01 = vextq_f32(_row00, _row01, 1); + _ext02 = vextq_f32(_row00, _row01, 2); + _ext03 = vextq_f32(_row01, _row01, 1); + _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0); + + _result2 = vmlaq_lane_f32(_result2, _row00, vget_low_f32(_ker[1]), 0); + _result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[1]), 1); + _result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[1]), 0); + _result3 = vmlaq_lane_f32(_result3, _row01, vget_low_f32(_ker[1]), 0); + _result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[1]), 1); + _result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[1]), 0); + + _ext01 = vextq_f32(_row10, _row11, 1); + _ext02 = vextq_f32(_row10, _row11, 2); + _ext03 = vextq_f32(_row11, _row11, 1); + _ext04 = vextq_f32(_row11, _row11, 2); + + _result2 = vmlaq_lane_f32(_result2, _row10, vget_low_f32(_ker[2]), 0); + _result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[2]), 1); + _result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[2]), 0); + _result3 = vmlaq_lane_f32(_result3, _row11, vget_low_f32(_ker[2]), 0); + _result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[2]), 1); + _result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[2]), 0); + + switch (output_w_remain) { + case 5: + vst1q_lane_f32(output_ptr0 + 4, _result1, 0); + vst1q_lane_f32(output_ptr1 + 4, _result3, 0); + case 4: + vst1q_f32(output_ptr0, _result0); + vst1q_f32(output_ptr1, _result2); + break; + case 3: + vst1q_lane_f32(output_ptr0 + 2, _result0, 2); + vst1q_lane_f32(output_ptr1 + 2, _result2, 2); + case 2: + vst1_f32(output_ptr0, vget_low_f32(_result0)); + vst1_f32(output_ptr1, vget_low_f32(_result2)); + break; + case 1: + vst1q_lane_f32(output_ptr0, _result0, 0); + 
vst1q_lane_f32(output_ptr1, _result2, 0); + break; } - vst1q_f32(output_ptr, out0); - } - for (m = 1; (m + 3) < output_width - 1; m += 4) { + input_ptr0 += output_w_remain; + input_ptr1 += output_w_remain; + input_ptr2 += output_w_remain; + input_ptr3 += output_w_remain; + output_ptr0 += output_w_remain; + output_ptr1 += output_w_remain; } - for (int j = m; j < output_width - 1; j++) { - output_data[j] = input_data[j - 1] * w10 + input_data[j] * w11 + - input_data[j + 1] * w12 + - input_data[input_width + j - 1] * w20 + - input_data[input_width + j] * w21 + - input_data[input_width + j + 1] * w22; - output_data[j] = output_data[j] * newscale_data[c] + newbias_data[c]; - - if (if_relu) { - output_data[j] = output_data[j] < 0 ? 0 : output_data[j]; + // pad right + if (padding_w) { + float32x2_t row0 = vld1_f32(input_ptr0); + float32x2_t row1 = vld1_f32(input_ptr1); + float32x2_t row2 = vld1_f32(input_ptr2); + float32x2_t row3 = vld1_f32(input_ptr3); + float32x2_t zero = vdup_n_f32(0.f); + float32x2_t acc0, acc1; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0.f; + *output_ptr1 = 0.f; + } else { + acc0 = vmul_f32(row0, vget_low_f32(_ker[0])); + acc0 = vmla_f32(acc0, row1, vget_low_f32(_ker[1])); + acc0 = vmla_f32(acc0, row2, vget_low_f32(_ker[2])); + acc1 = vmul_f32(row1, vget_low_f32(_ker[0])); + acc1 = vmla_f32(acc1, row2, vget_low_f32(_ker[1])); + acc1 = vmla_f32(acc1, row3, vget_low_f32(_ker[2])); + float32x2_t sum = vpadd_f32(acc0, acc1); + vst1_lane_f32(output_ptr0, sum, 0); + vst1_lane_f32(output_ptr1, sum, 1); + row0 = vext_f32(row0, zero, 1); + row1 = vext_f32(row1, zero, 1); + row2 = vext_f32(row2, zero, 1); + row3 = vext_f32(row3, zero, 1); + } + output_ptr0++; + output_ptr1++; } } - - for (m = 1; m < output_width - 4; m += 4) { - float *output_ptr = - output_data + (output_height - 1) * output_width + m; - - float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0; - in0 = vld1q_f32(input_data + (output_height - 2) * input_width + m - 1); - in1 = vld1q_f32(input_data + (output_height - 2) * input_width + m + 3); - in2 = vld1q_f32(input_data + (output_height - 1) * input_width + m - 1); - in3 = vld1q_f32(input_data + (output_height - 1) * input_width + m + 3); - tmp0 = vextq_f32(in0, in1, 1); - tmp1 = vextq_f32(in0, in1, 2); - tmp2 = vextq_f32(in2, in3, 1); - tmp3 = vextq_f32(in2, in3, 2); - out0 = vmulq_n_f32(in0, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); + } + // remain height + int start_h = valid_h_start + (valid_h & 0xfffe); + if (start_h < valid_h_end) { + const float *input_ptr0 = input_ptr + (start_h - padding_h) * input_w; + const float *input_ptr1 = input_ptr0 + input_w; + const float *input_ptr2 = input_ptr1 + input_w; + float *output_ptr0 = output_ptr + start_h * output_w; + // pad left + if (padding_w) { + float32x4_t row0 = vld1q_f32(input_ptr0); + float32x4_t row1 = vld1q_f32(input_ptr1); + float32x4_t row2 = vld1q_f32(input_ptr2); + float32x4_t zero = vdupq_n_f32(0.f); + row0 = vextq_f32(zero, row0, 3); + row1 = vextq_f32(zero, row1, 3); + row2 = vextq_f32(zero, row2, 3); + float32x4_t acc; + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w; + if (padding >= 3) { + output_ptr0[w] = 0.f; + } else { 
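+            // The row vectors were pre-shifted with vextq_f32 so lane 0 holds the
+            // zero padding; the left-border output therefore reduces to the sum of
+            // lanes 1 and 2 of the row-accumulated products (lane 3 is excluded,
+            // since the 4-wide kernel loads carry one stray element), and one more
+            // zero is shifted in before handling the next, more-padded column.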
+ acc = vmulq_f32(row0, _ker[0]); + acc = vmlaq_f32(acc, row1, _ker[1]); + acc = vmlaq_f32(acc, row2, _ker[2]); + acc = vextq_f32(acc, acc, 1); + float32x2_t sum = vpadd_f32(vget_low_f32(acc), vget_low_f32(acc)); + vst1_lane_f32(output_ptr0 + w, sum, 0); + + row0 = vextq_f32(zero, row0, 3); + row1 = vextq_f32(zero, row1, 3); + row2 = vextq_f32(zero, row2, 3); + } } - vst1q_f32(output_ptr, out0); + output_ptr0 += valid_w_start; } - for (m = 1; (m + 3) < output_width - 1; m = m + 4) { + // valid + float32x4_t _result0, _result1; + for (int loop = 0; loop < output_w_tiles; ++loop) { + float32x4_t _row00 = vld1q_f32(input_ptr0); + float32x4_t _row01 = vld1q_f32(input_ptr0 + 4); + float32x4_t _row10 = vld1q_f32(input_ptr1); + float32x4_t _row11 = vld1q_f32(input_ptr1 + 4); + + float32x4_t _ext01 = vextq_f32(_row00, _row01, 1); + float32x4_t _ext02 = vextq_f32(_row00, _row01, 2); + float32x4_t _ext03 = vextq_f32(_row01, _row01, 1); + float32x4_t _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0); + _result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0); + + _ext01 = vextq_f32(_row10, _row11, 1); + _ext02 = vextq_f32(_row10, _row11, 2); + _ext03 = vextq_f32(_row11, _row11, 1); + _ext04 = vextq_f32(_row11, _row11, 2); + + _result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0); + + _row00 = vld1q_f32(input_ptr2); + _row01 = vld1q_f32(input_ptr2 + 4); + + _ext01 = vextq_f32(_row00, _row01, 1); + _ext02 = vextq_f32(_row00, _row01, 2); + _ext03 = vextq_f32(_row01, _row01, 1); + _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0); + + vst1q_f32(output_ptr0, _result0); + vst1_f32(output_ptr0 + 4, vget_low_f32(_result1)); + + input_ptr0 += 6; + input_ptr1 += 6; + input_ptr2 += 6; + output_ptr0 += 6; } - for (int j = m; j < output_width - 1; j++) { - output_data[(output_height - 1) * input_width + j] = - input_data[(output_height - 2) * input_width + j - 1] * w00 + - input_data[(output_height - 2) * input_width + j] * w01 + - input_data[(output_height - 2) * input_width + j + 1] * w02 + - input_data[(output_height - 1) * input_width + j - 1] * w10 + - input_data[(output_height - 1) * input_width + j] * w11 + - input_data[(output_height - 1) * input_width + j + 1] * w12; - output_data[(output_height - 1) * output_width + j] = - output_data[(output_height - 1) * output_width + j] * - newscale_data[c] + - newbias_data[c]; - - if (if_relu) { - output_data[(output_height 
- 1) * output_width + j] = - output_data[(output_height - 1) * output_width + j] < 0 - ? 0 - : output_data[(output_height - 1) * output_width + j]; - } - } - } - } - - /* - const float *input_data = input->data(); - const float *filter_data = filter->data(); - float *output_data = output->data(); - const float *newscale_data = new_scale->data(); - const float *newbias_data = new_bias->data(); - - const int h = static_cast(input->dims()[2]); - const int w = static_cast(input->dims()[3]); -// const int l = h; - - const int batch_size = static_cast(input->dims()[0]); - const int c = static_cast(input->dims()[1]); - const int hxw = h * w; - float32x4_t vnewbias = vdupq_n_f32(0.0); - float32x4_t vnewscale = vdupq_n_f32(1.0); - float32x4_t vzero = vdupq_n_f32(0); - - for (int b = 0; b < batch_size; ++b) { - const float *filter_data_tmp = filter_data; - - for (int j = 0; j < c; ++j) { - vnewbias = vdupq_n_f32(newbias_data[j]); - vnewscale = vdupq_n_f32(newscale_data[j]); - - int w_mid = w - 2; // l=1->l_mid=-1,l=2->l_mid=0 - float w00 = filter_data_tmp[0]; - float w01 = filter_data_tmp[1]; - float w02 = filter_data_tmp[2]; - float w10 = filter_data_tmp[3]; - float w11 = filter_data_tmp[4]; - float w12 = filter_data_tmp[5]; - float w20 = filter_data_tmp[6]; - float w21 = filter_data_tmp[7]; - float w22 = filter_data_tmp[8]; - - output_data[0] = w11 * input_data[0] + w12 * input_data[1] + - w21 * input_data[w] + w22 * input_data[w + 1]; - - output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w - - 1] + w20 * input_data[2 * w - 2] + w21 * input_data[2 * w - 1]; - - output_data[(h - 1) * w] = - w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + - 1] + w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1]; - output_data[h * w - 1] = w00 * input_data[h*w-w-2] + - w01 * input_data[h*w-w-1] + - w10 * input_data[h * w - 2] + - w11 * input_data[h * w - 1]; - output_data[0] = output_data[0] * newscale_data[j] + - newbias_data[j]; output_data[w - 1] = output_data[w - 1] * - newscale_data[j] + newbias_data[j]; output_data[(h - 1) * w] = - output_data[(h - 1) * w] * newscale_data[j] + newbias_data[j]; - output_data[h * w - 1] = - output_data[h * w - 1] * newscale_data[j] + newbias_data[j]; - - if (if_relu) { - output_data[0] = output_data[0] < 0 ? 0 : output_data[0]; - output_data[w - 1] = output_data[w - 1] < 0 ? 0 : output_data[w - - 1]; output_data[(h - 1) * w] = output_data[(h - 1) * w] < 0 ? 0 : - output_data[(h - 1) * w]; output_data[h * w - 1] = output_data[h * w - 1] - < 0 ? 0 : output_data[h * w - 1]; - } - for (int i = 1; i < h - 1; ++i) { - output_data[i * w] = - w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] - + w11 * input_data[i * w] + w12 * input_data[i * w + 1] + w21 * - input_data[i * w + w] + w22 * input_data[i * w + w + 1]; output_data[i * - w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] + w01 * input_data[i - * w + w - 1 - w] + w10 * input_data[i * w + w - 1 - 1] + w11 * - input_data[i * w + w - 1] + w20 * input_data[i * w + w - 1 + w - 1] + w21 - * input_data[i * w + w - 1 + w]; output_data[i * w] = output_data[i * w] - * newscale_data[j] + newbias_data[j]; output_data[i * w + w - 1] = - output_data[i * w + w - 1] * newscale_data[j] + - newbias_data[j]; - - if (if_relu) { - output_data[i * w] = output_data[i * w] < 0 ? 0 : output_data[i - * w]; output_data[i * w + w - 1] = output_data[i * w + w - 1] < 0 ? 
0 : - output_data[i * w + w - 1]; - } - } - - // top 1 row and bottom 1 row - const float *input_tmp = input_data; - - float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, - tmp2, tmp3, tmp4, tmp5, out0; in0 = vld1q_f32(input_tmp); in2 = - vld1q_f32(input_tmp + w); const float *input_tmp_end = input_tmp + (h - - 2) * w; in4 = vld1q_f32(input_tmp_end); in6 = vld1q_f32(input_tmp_end + - w); int c_mid = w_mid; auto output_ptr = output_data + 1; for (; c_mid > - 3; c_mid -= 4) { in1 = vld1q_f32(input_tmp + 4); in3 = - vld1q_f32(input_tmp + w + 4); - - tmp0 = vextq_f32(in0, in1, 1); - tmp1 = vextq_f32(in0, in1, 2); - - tmp2 = vextq_f32(in2, in3, 1); - tmp3 = vextq_f32(in2, in3, 2); - - out0 = vmulq_n_f32(in0, w10); - out0 = vmlaq_n_f32(out0, tmp0, w11); - out0 = vmlaq_n_f32(out0, tmp1, w12); - out0 = vmlaq_n_f32(out0, in2, w20); - out0 = vmlaq_n_f32(out0, tmp2, w21); - out0 = vmlaq_n_f32(out0, tmp3, w22); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - vst1q_f32(output_ptr, out0); - - in5 = vld1q_f32(input_tmp_end + 4); - in7 = vld1q_f32(input_tmp_end + w + 4); - - tmp0 = vextq_f32(in4, in5, 1); - tmp1 = vextq_f32(in4, in5, 2); - tmp2 = vextq_f32(in6, in7, 1); - tmp3 = vextq_f32(in6, in7, 2); - - out0 = vmulq_n_f32(in4, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in6, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - vst1q_f32(output_ptr + (h - 1) * w, out0); - - // can optimize to each 8 stride. - input_tmp += 4; - input_tmp_end += 4; - output_ptr += 4; - in0 = in1; - in2 = in3; - in4 = in5; - in6 = in7; - } - // top right pad - float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]); - - tmp0 = vextq_f32(in0, pad0, 1); - tmp1 = vextq_f32(in0, pad0, 2); - tmp2 = vextq_f32(in2, pad1, 1); - tmp3 = vextq_f32(in2, pad1, 2); - - out0 = vmulq_n_f32(in0, w10); - out0 = vmlaq_n_f32(out0, tmp0, w11); - out0 = vmlaq_n_f32(out0, tmp1, w12); - out0 = vmlaq_n_f32(out0, in2, w20); - out0 = vmlaq_n_f32(out0, tmp2, w21); - out0 = vmlaq_n_f32(out0, tmp3, w22); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + i, out0, 1); - } - if (i == 2) { - vst1q_lane_f32(output_ptr + i, out0, 2); - } - } - - // bottom right pad - float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]); - float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]); - - tmp0 = vextq_f32(in4, pad2, 1); - tmp1 = vextq_f32(in4, pad2, 2); - tmp2 = vextq_f32(in6, pad3, 1); - tmp3 = vextq_f32(in6, pad3, 2); - - out0 = vmulq_n_f32(in4, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in6, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1); - } - if (i == 2) { - vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2); - } - } - // mid - - - for (int i = 0; i < h - 2; ++i) { 
- auto output_ptr = output_data + (i + 1) * w + 1; - input_tmp = input_data + i * w; - auto in0_tmp = vld1q_f32(input_tmp); - auto in2_tmp = vld1q_f32(input_tmp + w); - auto in4_tmp = vld1q_f32(input_tmp + w + w); - c_mid = w_mid; - for (; c_mid > 3; c_mid -= 4) { - auto in1_tmp = vld1q_f32(input_tmp + 4); - auto in3_tmp = vld1q_f32(input_tmp + w + 4); - auto in5_tmp = vld1q_f32(input_tmp + w + w + 4); - - tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); - tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); - tmp2 = vextq_f32(in2_tmp, in3_tmp, 1); - tmp3 = vextq_f32(in2_tmp, in3_tmp, 2); - tmp4 = vextq_f32(in4_tmp, in5_tmp, 1); - tmp5 = vextq_f32(in4_tmp, in5_tmp, 2); - - out0 = vmulq_n_f32(in0_tmp, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2_tmp, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4_tmp, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, w22); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - vst1q_f32(output_ptr, out0); - - output_ptr += 4; - input_tmp += 4; - in0_tmp = in1_tmp; - in2_tmp = in3_tmp; - in4_tmp = in5_tmp; - } - - float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]); - float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]); - - tmp0 = vextq_f32(in0_tmp, pad0, 1); - tmp1 = vextq_f32(in0_tmp, pad0, 2); - tmp2 = vextq_f32(in2_tmp, pad1, 1); - tmp3 = vextq_f32(in2_tmp, pad1, 2); - tmp4 = vextq_f32(in4_tmp, pad2, 1); - tmp5 = vextq_f32(in4_tmp, pad2, 2); - - out0 = vmulq_n_f32(in0_tmp, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2_tmp, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4_tmp, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, w22); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + i, out0, 1); - } - if (i == 2) { - vst1q_lane_f32(output_ptr + i, out0, 2); - } - } - } - output_data += hxw; - input_data += hxw; - filter_data_tmp += 9; - } + if (output_w_remain > 0) { + float32x4_t _row00 = vld1q_f32(input_ptr0); + float32x4_t _row01 = vld1q_f32(input_ptr0 + 4); + float32x4_t _row10 = vld1q_f32(input_ptr1); + float32x4_t _row11 = vld1q_f32(input_ptr1 + 4); + + float32x4_t _ext01 = vextq_f32(_row00, _row01, 1); + float32x4_t _ext02 = vextq_f32(_row00, _row01, 2); + float32x4_t _ext03 = vextq_f32(_row01, _row01, 1); + float32x4_t _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0); + _result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0); + + _ext01 = vextq_f32(_row10, _row11, 1); + _ext02 = vextq_f32(_row10, _row11, 2); + _ext03 = vextq_f32(_row11, _row11, 1); + _ext04 = vextq_f32(_row11, _row11, 2); + + _result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0); + _result0 = 
vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0); + + _row00 = vld1q_f32(input_ptr2); + _row01 = vld1q_f32(input_ptr2 + 4); + + _ext01 = vextq_f32(_row00, _row01, 1); + _ext02 = vextq_f32(_row00, _row01, 2); + _ext03 = vextq_f32(_row01, _row01, 1); + _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0); + + switch (output_w_remain) { + case 5: + vst1q_lane_f32(output_ptr0 + 4, _result1, 0); + case 4: + vst1q_f32(output_ptr0, _result0); + break; + case 3: + vst1q_lane_f32(output_ptr0 + 2, _result0, 2); + case 2: + vst1_f32(output_ptr0, vget_low_f32(_result0)); + break; + case 1: + vst1q_lane_f32(output_ptr0, _result0, 0); + break; } - */ -#endif -} - -/// w!=h not fix -void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, - const framework::Tensor *new_scale, - const framework::Tensor *new_bias, - bool if_relu) { -#if __ARM_NEON - - const int batch_size = input->dims()[0]; - - const int input_height = input->dims()[2]; - - const int input_width = input->dims()[3]; - - const int output_channels = output->dims()[1]; - - const int output_height = output->dims()[2]; - const int output_width = output->dims()[3]; - const int _kernel_size = 3; - const int stride_height = 2; - const int stride_width = 2; - const int padding_height = 1; - const int padding_width = 1; - const float zero = 0; - const int input_channel_stride = input_height * input_width; - const int output_channel_stride = output_height * output_width; - const int filter_channel_stride = 9; - const float *newscale_data = new_scale->data(); - const float *newbias_data = new_bias->data(); - - const float *input_data = input->data(); - const float *filter_data = filter->data(); - - float *output_data = output->mutable_data(); - - const int input_batch_stride = output_channels * input_channel_stride; - const int output_batch_stride = output_channels * output_channel_stride; - const int filter_batch_stride = output_channels * output_channel_stride; - const float *pos1, *pos2, *pos3, *filter1, *filter2, *filter3, *output_ptr; - int hstart, wstart, hend, wend; - float result; - for (int i = 0; i < batch_size; ++i) { - for (int c = 0; c < output_channels; ++c) { - filter1 = filter_data; - filter2 = filter1 + 3; - filter3 = filter2 + 3; - - for (int ph = 0; ph < output_height; ph++) { - for (int pw = 0; pw < output_width; pw++) { - hstart = ph * stride_height - padding_height; - wstart = pw * stride_width - padding_width; - hend = std::min(hstart + _kernel_size, input_height + padding_height); - wend = std::min(wstart + _kernel_size, input_width + padding_width); - hstart = std::max(hstart, 0); - wstart = std::max(wstart, 0); - hend = std::min(hend, input_height); - wend = std::min(wend, input_width); - pos1 = input_data + hstart * input_width + 
wstart; - pos2 = input_data + (hstart + 1) * input_width + wstart; - pos3 = input_data + (hstart + 2) * input_width + wstart; - output_ptr = output_data + ph * output_width + pw; - - if (hend - hstart != 3 || wend - wstart != 3) { - result = 0; - float fake_input[9] = {0}; - if (hstart == 0 && wstart == 0) { - // 左上角 - for (int j = 0; j < 3; ++j) { - for (int k = 0; k < 3; ++k) { - if (j >= 3 - hend && k >= 3 - wend) { - fake_input[3 * j + k] = - input_data[(j - (3 - hend)) * input_width + k - - (3 - wend)]; - } - } - } - } else if (hstart == 0 && wend == input_width) { - // 右上角 - for (int j = 0; j < 3; ++j) { - for (int k = 0; k < 3; ++k) { - if (j >= 3 - hend && k <= input_width - wstart - 1) { - fake_input[3 * j + k] = - input_data[(j - (3 - hend)) * input_width + k + wstart]; - } - } - } - - } else if (hend == input_height && wstart == 0) { - // 左下角 - - for (int j = 0; j < 3; ++j) { - for (int k = 0; k < 3; ++k) { - if (j <= input_height - 1 - hstart && k >= 3 - wend) { - fake_input[3 * j + k] = - input_data[(j + hstart) * input_width + k - (3 - wend)]; - } - } - } - } else if (hend == input_height && wend == input_width) { - // 右下角 - for (int j = 0; j < 3; ++j) { - for (int k = 0; k < 3; ++k) { - if (j <= input_height - hstart - 1 && - k <= input_width - wstart - 1) { - fake_input[3 * j + k] = - input_data[(j + hstart) * input_width + k + wstart]; - } - } - } - } else if (hstart == 0) { - // 顶部 - for (int j = 0; j < 3; ++j) { - for (int k = 0; k < 3; ++k) { - if (j >= 3 - hend) { - fake_input[3 * j + k] = - input_data[(j - (3 - hend)) * input_width + k + wstart]; - } - } - } - - } else if (hend == input_height) { - // 底部 - for (int j = 0; j < 3; ++j) { - for (int k = 0; k < 3; ++k) { - if (j <= input_height - hstart - 1) { - fake_input[3 * j + k] = - input_data[(j + hstart) * input_width + k + wstart]; - } - } - } - - } else if (wstart == 0) { - // 左侧 - for (int j = 0; j < 3; ++j) { - for (int k = 0; k < 3; ++k) { - if (k >= 3 - wend) { - fake_input[3 * j + k] = - input_data[(j + hstart) * input_width + - (k - (3 - wend))]; - } - } - } - - } else if (wend == input_width) { - // 右侧 - for (int j = 0; j < 3; ++j) { - for (int k = 0; k < 3; ++k) { - if (k <= input_width - wstart - 1) { - fake_input[3 * j + k] = - input_data[(j + hstart) * input_width + k + wstart]; - } - } - } - } - for (int l = 0; l < 9; ++l) { - result += fake_input[l] * filter1[l]; - } - output_data[ph * output_width + pw] = - newscale_data[c] * result + newbias_data[c]; - - if (if_relu) { - output_data[ph * output_width + pw] = - output_data[ph * output_width + pw] < 0 - ? 
0 - : output_data[ph * output_width + pw]; - } + input_ptr0 += output_w_remain; + input_ptr1 += output_w_remain; + input_ptr2 += output_w_remain; + output_ptr0 += output_w_remain; + } + // pad right + if (padding_w) { + float32x2_t row0 = vld1_f32(input_ptr0); + float32x2_t row1 = vld1_f32(input_ptr1); + float32x2_t row2 = vld1_f32(input_ptr2); + float32x2_t zero = vdup_n_f32(0.f); + float32x2_t acc; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0.f; } else { - const float32x4_t data1 = vld1q_f32(pos1); - const float32x4_t data2 = vld1q_f32(pos2); - const float32x4_t data3 = vld1q_f32(pos3); - - const float32x4_t v_filter1 = vld1q_f32(filter1); - const float32x4_t v_filter2 = vld1q_f32(filter2); - const float32x4_t v_filter3 = vld1q_f32(filter3); - float32x4_t mula = vmulq_f32(data1, v_filter1); - mula = vmlaq_f32(mula, data2, v_filter2); - mula = vmlaq_f32(mula, data3, v_filter3); - float32x2_t res = vpadd_f32( - vget_high_f32(vsetq_lane_f32(0, mula, 3)), vget_low_f32(mula)); - res = vpadd_f32(res, res); - output_data[ph * output_width + pw] = - vget_lane_f32(res, 0) * newscale_data[c] + newbias_data[c]; - - if (if_relu) { - output_data[ph * output_width + pw] = - output_data[ph * output_width + pw] < 0 - ? 0 - : output_data[ph * output_width + pw]; - } + acc = vmul_f32(row0, vget_low_f32(_ker[0])); + acc = vmla_f32(acc, row1, vget_low_f32(_ker[1])); + acc = vmla_f32(acc, row2, vget_low_f32(_ker[2])); + float32x2_t sum = vpadd_f32(acc, acc); + vst1_lane_f32(output_ptr0, sum, 0); + row0 = vext_f32(row0, zero, 1); + row1 = vext_f32(row1, zero, 1); + row2 = vext_f32(row2, zero, 1); } + output_ptr0++; } } - input_data += input_channel_stride; - output_data += output_channel_stride; - filter_data += filter_channel_stride; } - input_data += input_batch_stride; - output_data += output_batch_stride; + // pad bottom + for (int h = valid_h_end; h < output_h; ++h) { + DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr, _ker); + } } -#endif } -void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor *bias, - bool if_bias, bool if_relu) { -#if __ARM_NEON - const float *input_data = input->data(); - const float *filter_data = filter->data(); - float *output_data = output->mutable_data(); - const float *bias_data; - if (if_bias) { - bias_data = bias->data(); - } - - const int in_h = static_cast(input->dims()[2]); - const int in_w = static_cast(input->dims()[3]); - const int out_h = static_cast(output->dims()[2]); - const int out_w = static_cast(output->dims()[3]); - const int out_l = out_h; - const int in_l = in_h; - const int inhxw = in_h * in_w; - const int outhxw = out_h * out_w; - /// todo : fix if_pad when w != h - const int if_pad_r = in_w - 1 == (out_w - 1) * 2 ? 1 : 0; - const int if_pad_b = in_h - 1 == (out_h - 1) * 2 ? 
1 : 0; - const int batch_size = static_cast(input->dims()[0]); - const int c = static_cast(input->dims()[1]); - const float *input_row_ptr; - float *output_row_ptr; - - const int w_times = (out_w - 2) / 3; - - float32x4_t vbias = vdupq_n_f32(0.0); - - float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1]; - float32x4_t elewise_res0, elewise_res1, elewise_res2, res3; - int out2in_mid; - float32x4_t zero = vdupq_n_f32(0.0); - for (int b = batch_size; b > 0; --b) { - const float *filter_data_tmp = filter_data; - for (int j = 0; j < c; ++j) { - auto output_data_tmp = output_data + j * out_h * out_w; - auto input_data_tmp = input_data + j * in_h * in_w; - auto input_const = input_data_tmp; - - if (if_bias) { - vbias = vdupq_n_f32(bias_data[j]); - } - - float w00 = filter_data_tmp[0]; - float w01 = filter_data_tmp[1]; - float w02 = filter_data_tmp[2]; - float w10 = filter_data_tmp[3]; - float w11 = filter_data_tmp[4]; - float w12 = filter_data_tmp[5]; - float w20 = filter_data_tmp[6]; - float w21 = filter_data_tmp[7]; - float w22 = filter_data_tmp[8]; - - int h_mid = 0; - - for (; h_mid < out_h - 1; h_mid++) { - input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; - output_row_ptr = output_data_tmp + 1 + h_mid * out_w; - - for (int w4 = 0; w4 < w_times + 1; w4++) { - if (h_mid == 0) { - elewise_res1 = zero; - elewise_res0 = zero; - elewise_res2 = zero; +template <> +void DepthwiseConv3x3S2(const framework::Tensor &input, + const framework::Tensor &filter, + const std::vector &paddings, + framework::Tensor *output) { + const float *input_data = input.data(); + const float *filter_data = filter.data(); + float *out_data = output->mutable_data(); + int input_h = input.dims()[2]; + int input_w = input.dims()[3]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + int padding_h = paddings[0]; + int padding_w = paddings[1]; + int image_size = input_h * input_w; + int out_image_size = output_h * output_w; + int valid_h_start = (padding_h + 1) / 2; + int valid_h_end = (input_h + padding_h - 1) / 2; + int valid_h = valid_h_end - valid_h_start; + int valid_w_start = (padding_w + 1) / 2; + int valid_w_end = (input_w + padding_w - 1) / 2; + int valid_w = valid_w_end - valid_w_start; + int input_w_start = 2 * valid_w_start - padding_w; + + #pragma omp parallel for + for (int g = 0; g < input.dims()[1]; ++g) { + const float *input_ptr = input_data + g * image_size; + const float *filter_ptr = filter_data + g * 9; + float *output_ptr = out_data + g * out_image_size; + + const float *filter_ptr0 = filter_ptr; + const float *filter_ptr1 = filter_ptr0 + 3; + const float *filter_ptr2 = filter_ptr1 + 3; + float32x4_t _ker[3]; + _ker[0] = vld1q_f32(filter_ptr0); + _ker[1] = vld1q_f32(filter_ptr1); + _ker[2] = vld1q_f32(filter_ptr2); + + // pad top + for (int h = 0; h < valid_h_start; ++h) { + DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr, _ker); + } + // valid 2x4 + int output_w_tiles = valid_w / 4; + int output_w_remain = valid_w - output_w_tiles * 4; + for (int h = valid_h_start; h < valid_h_end - 1; h += 2) { + const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w; + const float *input_ptr1 = input_ptr0 + input_w; + const float *input_ptr2 = input_ptr1 + input_w; + const float *input_ptr3 = input_ptr2 + input_w; + const float *input_ptr4 = input_ptr3 + input_w; + float *output_ptr0 = output_ptr + h * output_w; + float *output_ptr1 = output_ptr0 + output_w; + // pad left + if (padding_w) { + 
for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - (w << 1); + if (padding >= 3) { + output_ptr0[w] = 0; + output_ptr1[w] = 0; } else { - elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); - elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); - elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); - } - input_buff_mid = vld2q_f32(input_row_ptr); - input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); - - elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); - elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); - elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); - - elewise_res1 = - vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); - elewise_res0 = - vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); - elewise_res2 = - vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); - - res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), - vaddq_f32(elewise_res0, elewise_res1)); - res3 = vaddq_f32(res3, vbias); - if (if_relu) { - res3 = vmaxq_f32(res3, zero); - } - vst1q_f32(output_row_ptr, res3); - - input_row_ptr += 6; - output_row_ptr += 3; - } - } - clock(); - - input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; - output_row_ptr = output_data_tmp + 1 + h_mid * out_w; - - for (int w4 = 0; w4 < w_times + 1; w4++) { - elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); - elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); - elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); - - input_buff_mid = vld2q_f32(input_row_ptr); - input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); - - elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); - elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); - elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); - - if (!if_pad_b) { - elewise_res1 = - vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); - elewise_res0 = - vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); - elewise_res2 = - vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); - } - res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), - vaddq_f32(elewise_res0, elewise_res1)); - res3 = vaddq_f32(res3, vbias); - if (if_relu) { - res3 = vmaxq_f32(res3, zero); - } - - if ((w4 != w_times)) { - vst1q_f32(output_row_ptr, res3); - } else { - if (out_w - 2 - w_times * 3 == 1) { - vst1q_lane_f32(output_row_ptr, res3, 0); - } else if (out_w - 2 - w_times * 3 == 2) { - vst1q_lane_f32(output_row_ptr, res3, 0); - vst1q_lane_f32(output_row_ptr + 1, res3, 1); + float32x4_t row0 = vld1q_f32(input_ptr0 - padding); + float32x4_t row1 = vld1q_f32(input_ptr1 - padding); + float32x4_t row2 = vld1q_f32(input_ptr2 - padding); + float32x4_t row3 = vld1q_f32(input_ptr3 - padding); + float32x4_t row4 = vld1q_f32(input_ptr4 - padding); + float32x4_t acc0 = vmulq_f32(row0, _ker[0]); + float32x4_t acc1 = vmulq_f32(row2, _ker[0]); + acc0 = vmlaq_f32(acc0, row1, _ker[1]); + acc1 = vmlaq_f32(acc1, row3, _ker[1]); + acc0 = vmlaq_f32(acc0, row2, _ker[2]); + acc1 = vmlaq_f32(acc1, row4, _ker[2]); + float sum0 = vgetq_lane_f32(acc0, 2); + float sum1 = vgetq_lane_f32(acc1, 2); + if (padding == 1) { + sum0 += vgetq_lane_f32(acc0, 1); + sum1 += vgetq_lane_f32(acc1, 1); + } + output_ptr0[w] = sum0; + output_ptr1[w] = sum1; } } - input_row_ptr += 6; - output_row_ptr += 3; + input_ptr0 += input_w_start; + input_ptr1 += input_w_start; + input_ptr2 += input_w_start; + input_ptr3 += 
input_w_start; + input_ptr4 += input_w_start; + output_ptr0 += valid_w_start; + output_ptr1 += valid_w_start; } - - // leftTop, rightTop, leftBottom, rightBottom - int lt = 0; - int rt = out_w - 1; - int lb = out_w * (out_h - 1); - int rb = out_h * out_w - 1; - - output_data_tmp[lt] = input_const[0] * w11 + input_const[1] * w12 + - input_const[in_w] * w21 + - input_const[in_w + 1] * w22; - - out2in_mid = (out_w - 1) * 2; - output_data_tmp[rt] = - w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + - w20 * input_const[out2in_mid + in_w - 1] + - w21 * input_const[out2in_mid + in_w] + - (1 - if_pad_r) * (w12 * input_const[out2in_mid + 1] + - w22 * input_const[out2in_mid + in_w + 1]); - - out2in_mid = (out_h - 1) * 2 * in_w; - - output_data_tmp[lb] = - w01 * input_const[out2in_mid - in_w] + - w02 * input_const[out2in_mid - in_w + 1] + - w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] + - (1 - if_pad_b) * (w21 * input_const[out2in_mid + in_w] + - w22 * input_const[out2in_mid + in_w + 1]); - out2in_mid = (out_h - 1) * 2 * in_w + (out_w - 1) * 2; - - output_data_tmp[rb] = - w00 * input_const[out2in_mid - in_w - 1] + - w01 * input_const[out2in_mid - in_w] + - w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + - (1 - if_pad_r) * (w20 * input_const[out2in_mid + in_w - 1] + - w21 * input_const[out2in_mid + in_w]) + - (1 - if_pad_b) * (w02 * input_const[out2in_mid - in_w + 1] + - w12 * input_const[out2in_mid + 1]) + - (1 - if_pad_r) * (1 - if_pad_b) * w22 * - input_const[out2in_mid + in_w + 1]; - if (if_bias) { - output_data_tmp[lt] += bias_data[j]; - output_data_tmp[rt] += bias_data[j]; - output_data_tmp[lb] += bias_data[j]; - output_data_tmp[rb] += bias_data[j]; + // valid + float32x4_t _result0, _result1, _ext; + for (int loop = 0; loop < output_w_tiles; ++loop) { + float32x4x2_t _row0 = vld2q_f32(input_ptr0); + float32x4x2_t _row1 = vld2q_f32(input_ptr1); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr0[8], _ext, 3); + _result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0); + + _ext = vextq_f32(_row1.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr1[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0); + + _row0 = vld2q_f32(input_ptr2); + _row1 = vld2q_f32(input_ptr3); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr2[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[0], vget_low_f32(_ker[2]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0); + _result1 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0); + _result1 = + vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[0]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[0]), 0); + + _ext = vextq_f32(_row1.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr3[8], _ext, 3); + _result1 = + vmlaq_lane_f32(_result1, _row1.val[0], vget_low_f32(_ker[1]), 0); + _result1 = + vmlaq_lane_f32(_result1, _row1.val[1], vget_low_f32(_ker[1]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[1]), 0); + + _row0 = 
vld2q_f32(input_ptr4); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr4[8], _ext, 3); + _result1 = + vmlaq_lane_f32(_result1, _row0.val[0], vget_low_f32(_ker[2]), 0); + _result1 = + vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[2]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[2]), 0); + + vst1q_f32(output_ptr0, _result0); + vst1q_f32(output_ptr1, _result1); + + input_ptr0 += 8; + input_ptr1 += 8; + input_ptr2 += 8; + input_ptr3 += 8; + input_ptr4 += 8; + output_ptr0 += 4; + output_ptr1 += 4; } - if (if_relu) { - output_data_tmp[lt] = output_data_tmp[lt] < 0 ? 0 : output_data_tmp[lt]; - output_data_tmp[rt] = output_data_tmp[rt] < 0 ? 0 : output_data_tmp[rt]; - output_data_tmp[lb] = output_data_tmp[lb] < 0 ? 0 : output_data_tmp[lb]; - output_data_tmp[rb] = output_data_tmp[rb] < 0 ? 0 : output_data_tmp[rb]; - } - for (int i = 1; i < out_h - 1; i++) { - out2in_mid = i * 2 * in_w; - int left = i * out_w; - output_data_tmp[left] = w01 * input_const[out2in_mid - in_w] + - w02 * input_const[out2in_mid - in_w + 1] + - w11 * input_const[out2in_mid] + - w12 * input_const[out2in_mid + 1] + - w21 * input_const[out2in_mid + in_w] + - w22 * input_const[out2in_mid + in_w + 1]; - - out2in_mid = i * 2 * in_w + (out_w - 1) * 2; - int right = i * out_w + out_w - 1; - output_data_tmp[right] = - w00 * input_const[out2in_mid - in_w - 1] + - w01 * input_const[out2in_mid - in_w] + - w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + - w20 * input_const[out2in_mid + in_w - 1] + - w21 * input_const[out2in_mid + in_w] + - (1 - if_pad_r) * (w02 * input_const[out2in_mid - in_w + 1] + - w12 * input_const[out2in_mid + 1] + - w22 * input_const[out2in_mid + in_w + 1]); - if (if_bias) { - output_data_tmp[left] += bias_data[j]; - output_data_tmp[right] += bias_data[j]; - } - if (if_relu) { - output_data_tmp[left] = - output_data_tmp[left] < 0 ? 0 : output_data_tmp[left]; - output_data_tmp[right] = - output_data_tmp[right] < 0 ? 
0 : output_data_tmp[right]; + // remain w + if (output_w_remain > 0) { + float32x4x2_t _row0 = vld2q_f32(input_ptr0); + float32x4x2_t _row1 = vld2q_f32(input_ptr1); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr0[8], _ext, 3); + _result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0); + + _ext = vextq_f32(_row1.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr1[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0); + + _row0 = vld2q_f32(input_ptr2); + _row1 = vld2q_f32(input_ptr3); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr2[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[0], vget_low_f32(_ker[2]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0); + _result1 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0); + _result1 = + vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[0]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[0]), 0); + + _ext = vextq_f32(_row1.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr3[8], _ext, 3); + _result1 = + vmlaq_lane_f32(_result1, _row1.val[0], vget_low_f32(_ker[1]), 0); + _result1 = + vmlaq_lane_f32(_result1, _row1.val[1], vget_low_f32(_ker[1]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[1]), 0); + + _row0 = vld2q_f32(input_ptr4); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr4[8], _ext, 3); + _result1 = + vmlaq_lane_f32(_result1, _row0.val[0], vget_low_f32(_ker[2]), 0); + _result1 = + vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[2]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[2]), 0); + + switch (output_w_remain) { + case 3: + vst1q_lane_f32(output_ptr0 + 2, _result0, 2); + vst1q_lane_f32(output_ptr1 + 2, _result1, 2); + case 2: + vst1_f32(output_ptr0, vget_low_f32(_result0)); + vst1_f32(output_ptr1, vget_low_f32(_result1)); + break; + case 1: + vst1q_lane_f32(output_ptr0, _result0, 0); + vst1q_lane_f32(output_ptr1, _result1, 0); + break; } + input_ptr0 += output_w_remain * 2; + input_ptr1 += output_w_remain * 2; + input_ptr2 += output_w_remain * 2; + input_ptr3 += output_w_remain * 2; + input_ptr4 += output_w_remain * 2; + output_ptr0 += output_w_remain; + output_ptr1 += output_w_remain; } - filter_data_tmp += 9; - } - input_data += inhxw * c; - output_data += outhxw * c; - } -#endif -} - -void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, - const framework::Tensor *new_scale, - const framework::Tensor *new_bias, - bool if_relu) { -#if __ARM_NEON - // #ifdef _OPENMP - // const float *newscale_data = new_scale->data(); - // const float *newbias_data = new_bias->data(); - // - // const int batch_size = static_cast(input->dims()[0]); - // const int input_channel = static_cast(input->dims()[1]); - // - // const int input_height = static_cast(input->dims()[2]); - // const int input_width = static_cast(input->dims()[3]); - // const int output_height = static_cast(output->dims()[2]); - // const int output_width = 
static_cast(output->dims()[3]); - // const int inhxw = input_height * input_width; - // const int outhxw = output_height * output_width; - // - // float32x4_t zero = vdupq_n_f32(0.0); - // for (int b = 0; b < batch_size; b++) { - // #pragma omp parallel for - // for (int c = 0; c < input_channel; c++) { - // const float *filter_data = filter->data() + c * 9; - // const float *input_data = input->data() + c * inhxw; - // float *output_data = output->data() + c * outhxw; - // float32x4_t vnewbias = vdupq_n_f32(newbias_data[c]); - // float32x4_t vnewscale = vdupq_n_f32(newscale_data[c]); - // - // float w00 = filter_data[0]; - // float w01 = filter_data[1]; - // float w02 = filter_data[2]; - // float w10 = filter_data[3]; - // float w11 = filter_data[4]; - // float w12 = filter_data[5]; - // float w20 = filter_data[6]; - // float w21 = filter_data[7]; - // float w22 = filter_data[8]; - // - // int m; - // for (m = 1; m < output_width - 2; m = m + 3) { - // float *output_ptr = output_data + m; - // float32x4x2_t input_buff_mid{}, input_buff_bottom{}; - // float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0; - // input_buff_mid = vld2q_f32(input_data + (2 * m - 1)); - // input_buff_bottom = vld2q_f32(input_data + input_width + (2 * m - - // 1)); - // - // in0 = input_buff_mid.val[0]; - // tmp0 = input_buff_mid.val[1]; - // tmp1 = vextq_f32(in0, zero, 1); - // - // in2 = input_buff_bottom.val[0]; - // tmp2 = input_buff_bottom.val[1]; - // tmp3 = vextq_f32(in2, zero, 1); - // - // out0 = vmulq_n_f32(in0, w10); - // out0 = vmlaq_n_f32(out0, tmp0, w11); - // out0 = vmlaq_n_f32(out0, tmp1, w12); - // out0 = vmlaq_n_f32(out0, in2, w20); - // out0 = vmlaq_n_f32(out0, tmp2, w21); - // out0 = vmlaq_n_f32(out0, tmp3, w22); - // out0 = vmlaq_f32(vnewbias, vnewscale, out0); - // if (if_relu) { - // out0 = vmaxq_f32(out0, zero); - // } - // vst1q_lane_f32(output_ptr, out0, 0); - // vst1q_lane_f32(output_ptr + 1, out0, 1); - // vst1q_lane_f32(output_ptr + 2, out0, 2); - // } - // for (m = 1; m < output_width - 2; m += 3) { - // } - // for (int j = m; j < output_width; j++) { - // output_data[j] = input_data[2 * j - 1] * w10 + input_data[2 * j] * - // w11 + - // input_data[2 * j + 1] * w12 + - // input_data[2 * j - 1 + input_width] * w20 + - // input_data[2 * j + input_width] * w21 + - // input_data[2 * j + 1 + input_width] * w22; - // output_data[j] = newscale_data[c] * output_data[j] + - // newbias_data[c]; if (if_relu) { - // output_data[j] = output_data[j] < 0 ? 
0 : output_data[j]; - // } - // } - // - // for (int i = 1; i < output_height; i += 1) { - // for (int m = 1; m < output_width - 2; m += 3) { - // float *output_ptr = output_data + i * output_width + m; - // float32x4x2_t input_buff_top{}, input_buff_mid{}, - // input_buff_bottom{}; float32x4_t in0, in1, in2, in3, in4, in5, - // tmp0, tmp1, tmp2, tmp3, - // tmp4, tmp5, out0; - // input_buff_top = - // vld2q_f32(input_data + (2 * i - 1) * input_width + (2 * m - - // 1)); - // input_buff_mid = - // vld2q_f32(input_data + (2 * i) * input_width + (2 * m - 1)); - // input_buff_bottom = - // vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m - - // 1)); - // - // in0 = input_buff_top.val[0]; - // tmp0 = input_buff_top.val[1]; - // tmp1 = vextq_f32(in0, zero, 1); - // - // in2 = input_buff_mid.val[0]; - // tmp2 = input_buff_mid.val[1]; - // tmp3 = vextq_f32(in2, zero, 1); - // - // in4 = input_buff_bottom.val[0]; - // tmp4 = input_buff_bottom.val[1]; - // tmp5 = vextq_f32(in4, zero, 1); - // - // out0 = vmulq_n_f32(in0, w00); - // out0 = vmlaq_n_f32(out0, tmp0, w01); - // out0 = vmlaq_n_f32(out0, tmp1, w02); - // out0 = vmlaq_n_f32(out0, in2, w10); - // out0 = vmlaq_n_f32(out0, tmp2, w11); - // out0 = vmlaq_n_f32(out0, tmp3, w12); - // out0 = vmlaq_n_f32(out0, in4, w20); - // out0 = vmlaq_n_f32(out0, tmp4, w21); - // out0 = vmlaq_n_f32(out0, tmp5, w22); - // out0 = vmlaq_f32(vnewbias, vnewscale, out0); - // if (if_relu) { - // out0 = vmaxq_f32(out0, zero); - // } - // vst1q_lane_f32(output_ptr, out0, 0); - // vst1q_lane_f32(output_ptr + 1, out0, 1); - // vst1q_lane_f32(output_ptr + 2, out0, 2); - // } - // int m; - // for (m = 1; m < output_width - 2; m += 3) { - // } - // for (int j = m; j < output_width; j++) { - // output_data[i * output_width + j] = - // input_data[(2 * i - 1) * input_width + 2 * j - 1] * w00 + - // input_data[(2 * i - 1) * input_width + 2 * j] * w01 + - // input_data[(2 * i - 1) * input_width + 2 * j + 1] * w02 + - // input_data[(2 * i) * input_width + 2 * j - 1] * w10 + - // input_data[(2 * i) * input_width + 2 * j] * w11 + - // input_data[(2 * i) * input_width + 2 * j + 1] * w12 + - // input_data[(2 * i + 1) * input_width + 2 * j - 1] * w20 + - // input_data[(2 * i + 1) * input_width + 2 * j] * w21 + - // input_data[(2 * i + 1) * input_width + 2 * j + 1] * w22; - // output_data[i * output_width + j] = - // newscale_data[c] * output_data[i * output_width + j] + - // newbias_data[c]; - // if (if_relu) { - // output_data[i * output_width + j] = - // output_data[i * output_width + j] < 0 - // ? 0 - // : output_data[i * output_width + j]; - // } - // } - // } - // output_data[0] = input_data[0] * w11 + input_data[1] * w12 + - // input_data[input_height] * w21 + - // input_data[input_height + 1] * w22; - // - // output_data[0] = newscale_data[c] * output_data[0] + newbias_data[c]; - // if (if_relu) { - // output_data[0] = output_data[0] < 0 ? 0 : output_data[0]; - // } - // for (int i = 1; i < output_height; i++) { - // output_data[i * output_width] = - // input_data[(2 * i - 1) * input_width] * w01 + - // input_data[(2 * i - 1) * input_width + 1] * w02 + - // input_data[(2 * i) * input_width] * w11 + - // input_data[(2 * i) * input_width + 1] * w12 + - // input_data[(2 * i + 1) * input_width] * w21 + - // input_data[(2 * i + 1) * input_width + 1] * w22; - // - // output_data[i * output_width] = - // newscale_data[c] * output_data[i * output_width] + - // newbias_data[c]; - // if (if_relu) { - // output_data[i * output_width] = output_data[i * output_width] < 0 - // ? 
0 - // : output_data[i * - // output_width]; - // } - // } - // } - // } - // - // #else - - const float *input_data = input->data(); - const float *filter_data = filter->data(); - float *output_data = output->mutable_data(); - const float *newscale_data = new_scale->data(); - const float *newbias_data = new_bias->data(); - - const int in_h = static_cast(input->dims()[2]); - const int in_w = static_cast(input->dims()[3]); - const int out_h = static_cast(output->dims()[2]); - const int out_w = static_cast(output->dims()[3]); - // const int out_l = out_h; - // const int in_l = in_h; - const int inhxw = in_h * in_w; - const int outhxw = out_h * out_w; - /// todo : fix if_pad when w != h - const int if_pad_r = in_w - 1 == (out_w - 1) * 2 ? 1 : 0; - const int if_pad_b = in_h - 1 == (out_h - 1) * 2 ? 1 : 0; - const int batch_size = static_cast(input->dims()[0]); - const int c = static_cast(input->dims()[1]); - const int w_times = (out_w - 2) / 3; - float32x4_t zero = vdupq_n_f32(0.0); - for (int b = batch_size; b > 0; --b) { -#pragma omp parallel for - for (int j = 0; j < c; j++) { - const float *input_row_ptr; - float *output_row_ptr; - float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1]; - float32x4_t elewise_res0, elewise_res1, elewise_res2, res3; - int out2in_mid; - float32x4_t vnewbias = vdupq_n_f32(0.0); - float32x4_t vnewscale = vdupq_n_f32(1.0); - auto output_data_tmp = output_data + j * out_h * out_w; - auto input_data_tmp = input_data + j * in_h * in_w; - auto input_const = input_data_tmp; - const float *filter_data_tmp = filter_data + 9 * j; - vnewbias = vdupq_n_f32(newbias_data[j]); - vnewscale = vdupq_n_f32(newscale_data[j]); - - float w00 = filter_data_tmp[0]; - float w01 = filter_data_tmp[1]; - float w02 = filter_data_tmp[2]; - float w10 = filter_data_tmp[3]; - float w11 = filter_data_tmp[4]; - float w12 = filter_data_tmp[5]; - float w20 = filter_data_tmp[6]; - float w21 = filter_data_tmp[7]; - float w22 = filter_data_tmp[8]; - - int h_mid = 0; - - for (; h_mid < out_h - 1; h_mid++) { - input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; - output_row_ptr = output_data_tmp + 1 + h_mid * out_w; - - for (int w4 = 0; w4 < w_times + 1; w4++) { - if (h_mid == 0) { - elewise_res1 = zero; - elewise_res0 = zero; - elewise_res2 = zero; + // pad right + if (padding_w > 0) { + float32x4_t row0 = vld1q_f32(input_ptr0); + float32x4_t row1 = vld1q_f32(input_ptr1); + float32x4_t row2 = vld1q_f32(input_ptr2); + float32x4_t row3 = vld1q_f32(input_ptr3); + float32x4_t row4 = vld1q_f32(input_ptr4); + float32x4_t acc0, acc1; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = 2 * w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0; + *output_ptr1 = 0; } else { - elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); - elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); - elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); - } - input_buff_mid = vld2q_f32(input_row_ptr); - input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); - - elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); - elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); - elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); - - elewise_res1 = - vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); - elewise_res0 = - vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); - elewise_res2 = - vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); - - res3 = vaddq_f32(vextq_f32(elewise_res2, 
zero, 1), - vaddq_f32(elewise_res0, elewise_res1)); - res3 = vmlaq_f32(vnewbias, vnewscale, res3); - - if (if_relu) { - res3 = vmaxq_f32(res3, zero); + acc0 = vmulq_f32(row0, _ker[0]); + acc1 = vmulq_f32(row2, _ker[0]); + acc0 = vmlaq_f32(acc0, row1, _ker[1]); + acc1 = vmlaq_f32(acc1, row3, _ker[1]); + acc0 = vmlaq_f32(acc0, row2, _ker[2]); + acc1 = vmlaq_f32(acc1, row4, _ker[2]); + float sum0 = vgetq_lane_f32(acc0, 0); + float sum1 = vgetq_lane_f32(acc1, 0); + if (padding == 1) { + sum0 += vgetq_lane_f32(acc0, 1); + sum1 += vgetq_lane_f32(acc1, 1); + } + *output_ptr0 = sum0; + *output_ptr1 = sum1; } - vst1q_lane_f32(output_row_ptr, res3, 0); - vst1q_lane_f32(output_row_ptr + 1, res3, 1); - vst1q_lane_f32(output_row_ptr + 2, res3, 2); - - input_row_ptr += 6; - output_row_ptr += 3; + output_ptr0++; + output_ptr1++; } } - clock(); - - input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; - output_row_ptr = output_data_tmp + 1 + h_mid * out_w; - - for (int w4 = 0; w4 < w_times + 1; w4++) { - elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); - elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); - elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); - - input_buff_mid = vld2q_f32(input_row_ptr); - input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); - - elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); - elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); - elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); - - if (!if_pad_b) { - elewise_res1 = - vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); - elewise_res0 = - vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); - elewise_res2 = - vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); - } - res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), - vaddq_f32(elewise_res0, elewise_res1)); - res3 = vmlaq_f32(vnewbias, vnewscale, res3); - - if (if_relu) { - res3 = vmaxq_f32(res3, zero); - } - if ((w4 != w_times)) { - vst1q_lane_f32(output_row_ptr, res3, 0); - vst1q_lane_f32(output_row_ptr + 1, res3, 1); - vst1q_lane_f32(output_row_ptr + 2, res3, 2); - } else { - if (out_w - 2 - w_times * 3 == 1) { - vst1q_lane_f32(output_row_ptr, res3, 0); - } else if (out_w - 2 - w_times * 3 == 2) { - vst1q_lane_f32(output_row_ptr, res3, 0); - vst1q_lane_f32(output_row_ptr + 1, res3, 1); + } + // remain height + int start_h = valid_h_start + (valid_h & 0xfffe); + if (start_h < valid_h_end) { + const float *input_ptr0 = input_ptr + (2 * start_h - padding_h) * input_w; + const float *input_ptr1 = input_ptr0 + input_w; + const float *input_ptr2 = input_ptr1 + input_w; + float *output_ptr0 = output_ptr + start_h * output_w; + // pad left + if (padding_w) { + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - (w << 1); + if (padding >= 3) { + output_ptr0[w] = 0; + } else { + float32x4_t row0 = vld1q_f32(input_ptr0 - padding); + float32x4_t row1 = vld1q_f32(input_ptr1 - padding); + float32x4_t row2 = vld1q_f32(input_ptr2 - padding); + float32x4_t acc0 = vmulq_f32(row0, _ker[0]); + acc0 = vmlaq_f32(acc0, row1, _ker[1]); + acc0 = vmlaq_f32(acc0, row2, _ker[2]); + float sum0 = vgetq_lane_f32(acc0, 2); + if (padding == 1) { + sum0 += vgetq_lane_f32(acc0, 1); + } + output_ptr0[w] = sum0; } } - input_row_ptr += 6; - output_row_ptr += 3; + input_ptr0 += input_w_start; + input_ptr1 += input_w_start; + input_ptr2 += input_w_start; + output_ptr0 += valid_w_start; } - - output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 + 
- input_const[in_w] * w21 + - input_const[in_w + 1] * w22; - - out2in_mid = (out_w - 1) * 2; - output_data_tmp[out_w - 1] = - w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + - w20 * input_const[out2in_mid + in_w - 1] + - w21 * input_const[out2in_mid + in_w] + - (1 - if_pad_r) * (w12 * input_const[out2in_mid + 1] + - w22 * input_const[out2in_mid + in_w + 1]); - - out2in_mid = (out_h - 1) * 2 * in_w; - - output_data_tmp[out_w * (out_h - 1)] = - w01 * input_const[out2in_mid - in_w] + - w02 * input_const[out2in_mid - in_w + 1] + - w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] + - (1 - if_pad_b) * (w21 * input_const[out2in_mid + in_w] + - w22 * input_const[out2in_mid + in_w + 1]); - out2in_mid = (out_h - 1) * 2 * in_w + (out_w - 1) * 2; - - output_data_tmp[out_h * out_w - 1] = - w00 * input_const[out2in_mid - in_w - 1] + - w01 * input_const[out2in_mid - in_w] + - w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + - (1 - if_pad_r) * (w20 * input_const[out2in_mid + in_w - 1] + - w21 * input_const[out2in_mid + in_w]) + - (1 - if_pad_b) * (w02 * input_const[out2in_mid - in_w + 1] + - w12 * input_const[out2in_mid + 1]) + - (1 - if_pad_r) * (1 - if_pad_b) * w22 * - input_const[out2in_mid + in_w + 1]; - output_data_tmp[0] = - output_data_tmp[0] * newscale_data[j] + newbias_data[j]; - output_data_tmp[out_w - 1] = - output_data_tmp[out_w - 1] * newscale_data[j] + newbias_data[j]; - output_data_tmp[out_w * (out_h - 1)] = - output_data_tmp[out_w * (out_h - 1)] * newscale_data[j] + - newbias_data[j]; - output_data_tmp[out_h * out_w - 1] = - output_data_tmp[out_h * out_w - 1] * newscale_data[j] + - newbias_data[j]; - if (if_relu) { - output_data_tmp[0] = output_data_tmp[0] < 0 ? 0 : output_data_tmp[0]; - output_data_tmp[out_w - 1] = - output_data_tmp[out_w - 1] < 0 ? 0 : output_data_tmp[out_w - 1]; - output_data_tmp[out_w * (out_h - 1)] = - output_data_tmp[out_w * (out_h - 1)] < 0 - ? 0 - : output_data_tmp[out_w * (out_h - 1)]; - output_data_tmp[out_h * out_w - 1] = - output_data_tmp[out_h * out_w - 1] < 0 - ? 
0 - : output_data_tmp[out_h * out_w - 1]; + // valid + float32x4_t _result0, _ext; + for (int loop = 0; loop < output_w_tiles; ++loop) { + float32x4x2_t _row0 = vld2q_f32(input_ptr0); + float32x4x2_t _row1 = vld2q_f32(input_ptr1); + float32x4x2_t _row2 = vld2q_f32(input_ptr2); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr0[8], _ext, 3); + _result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0); + + _ext = vextq_f32(_row1.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr1[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0); + + _ext = vextq_f32(_row2.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr2[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row2.val[0], vget_low_f32(_ker[2]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row2.val[1], vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0); + + vst1q_f32(output_ptr0, _result0); + + input_ptr0 += 8; + input_ptr1 += 8; + input_ptr2 += 8; + output_ptr0 += 4; } - for (int i = 1; i < out_h - 1; i++) { - out2in_mid = i * 2 * in_w; - output_data_tmp[i * out_w] = w01 * input_const[out2in_mid - in_w] + - w02 * input_const[out2in_mid - in_w + 1] + - w11 * input_const[out2in_mid] + - w12 * input_const[out2in_mid + 1] + - w21 * input_const[out2in_mid + in_w] + - w22 * input_const[out2in_mid + in_w + 1]; - - out2in_mid = i * 2 * in_w + (out_w - 1) * 2; - output_data_tmp[i * out_w + out_w - 1] = - w00 * input_const[out2in_mid - in_w - 1] + - w01 * input_const[out2in_mid - in_w] + - w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + - w20 * input_const[out2in_mid + in_w - 1] + - w21 * input_const[out2in_mid + in_w] + - (1 - if_pad_r) * (w02 * input_const[out2in_mid - in_w + 1] + - w12 * input_const[out2in_mid + 1] + - w22 * input_const[out2in_mid + in_w + 1]); - output_data_tmp[i * out_w] = - output_data_tmp[i * out_w] * newscale_data[j] + newbias_data[j]; - output_data_tmp[i * out_w + out_w - 1] = - output_data_tmp[i * out_w + out_w - 1] * newscale_data[j] + - newbias_data[j]; - if (if_relu) { - output_data_tmp[i * out_w] = - output_data_tmp[i * out_w] < 0 ? 0 : output_data_tmp[i * out_w]; - output_data_tmp[i * out_w + out_w - 1] = - output_data_tmp[i * out_w + out_w - 1] < 0 - ? 
0 - : output_data_tmp[i * out_w + out_w - 1]; + // remain w + if (output_w_remain > 0) { + float32x4x2_t _row0 = vld2q_f32(input_ptr0); + float32x4x2_t _row1 = vld2q_f32(input_ptr1); + float32x4x2_t _row2 = vld2q_f32(input_ptr2); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr0[8], _ext, 3); + _result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0); + + _ext = vextq_f32(_row1.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr1[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0); + + _ext = vextq_f32(_row2.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr2[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row2.val[0], vget_low_f32(_ker[2]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row2.val[1], vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0); + + switch (output_w_remain) { + case 3: + vst1q_lane_f32(output_ptr0 + 2, _result0, 2); + case 2: + vst1_f32(output_ptr0, vget_low_f32(_result0)); + break; + case 1: + vst1q_lane_f32(output_ptr0, _result0, 0); + break; } + input_ptr0 += output_w_remain * 2; + input_ptr1 += output_w_remain * 2; + input_ptr2 += output_w_remain * 2; + output_ptr0 += output_w_remain; } - } - input_data += inhxw * c; - output_data += outhxw * c; - } -// #endif -#endif -} - -void DepthwiseConv3x3s2p0(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor *bias, - bool if_bias, bool if_relu) { -#if __ARM_NEON - const int batch_size = static_cast(input->dims()[0]); - const int input_channel = static_cast(input->dims()[1]); - - const int input_height = static_cast(input->dims()[2]); - const int input_width = static_cast(input->dims()[3]); - const int output_height = static_cast(output->dims()[2]); - const int output_width = static_cast(output->dims()[3]); - const int inhxw = input_height * input_width; - const int outhxw = output_height * output_width; - output->mutable_data(); - - float32x4_t zero = vdupq_n_f32(0.0); - for (int b = 0; b < batch_size; b++) { -#pragma omp parallel for - for (int c = 0; c < input_channel; c++) { - const float *filter_data = filter->data() + c * 9; - const float *input_data = input->data() + c * inhxw; - const float *bias_data; - float32x4_t biasv; - if (if_bias) { - bias_data = bias->data() + c; - biasv = vld1q_dup_f32(bias_data); - } - float *output_data = output->data() + c * outhxw; - float w00 = filter_data[0]; - float w01 = filter_data[1]; - float w02 = filter_data[2]; - float w10 = filter_data[3]; - float w11 = filter_data[4]; - float w12 = filter_data[5]; - float w20 = filter_data[6]; - float w21 = filter_data[7]; - float w22 = filter_data[8]; - for (int i = 0; i < output_height; i += 1) { - for (int m = 0; m < output_width - 2; m += 3) { - float *output_ptr = output_data + i * output_width + m; - float32x4x2_t input_buff_top{}, input_buff_mid{}, input_buff_bottom{}; - float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3, - tmp4, tmp5, out0; - input_buff_top = - vld2q_f32(input_data + (2 * i) * input_width + (2 * m)); - input_buff_mid = - vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m)); - input_buff_bottom = - 
vld2q_f32(input_data + (2 * i + 2) * input_width + (2 * m)); - - in0 = input_buff_top.val[0]; - tmp0 = input_buff_top.val[1]; - tmp1 = vextq_f32(in0, zero, 1); - - in2 = input_buff_mid.val[0]; - tmp2 = input_buff_mid.val[1]; - tmp3 = vextq_f32(in2, zero, 1); - - in4 = input_buff_bottom.val[0]; - tmp4 = input_buff_bottom.val[1]; - tmp5 = vextq_f32(in4, zero, 1); - - out0 = vmulq_n_f32(in0, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, w22); - if (if_bias) { - out0 = vaddq_f32(out0, biasv); - } - if (if_relu) { - out0 = vmaxq_f32(out0, zero); - } - vst1q_lane_f32(output_ptr, out0, 0); - vst1q_lane_f32(output_ptr + 1, out0, 1); - vst1q_lane_f32(output_ptr + 2, out0, 2); - } - int m; - for (m = 0; m < output_width - 2; m += 3) { - } - for (int j = m; j < output_width; j++) { - int index = i * output_width + j; - output_data[index] = - input_data[(2 * i) * input_width + 2 * j] * w00 + - input_data[(2 * i) * input_width + 2 * j + 1] * w01 + - input_data[(2 * i) * input_width + 2 * j + 2] * w02 + - input_data[(2 * i + 1) * input_width + 2 * j] * w10 + - input_data[(2 * i + 1) * input_width + 2 * j + 1] * w11 + - input_data[(2 * i + 1) * input_width + 2 * j + 2] * w12 + - input_data[(2 * i + 2) * input_width + 2 * j] * w20 + - input_data[(2 * i + 2) * input_width + 2 * j + 1] * w21 + - input_data[(2 * i + 2) * input_width + 2 * j + 2] * w22; - if (if_bias) { - output_data[index] += *bias_data; - } - if (if_relu) { - output_data[index] = - output_data[index] < 0 ? 0 : output_data[index]; + // pad right + if (padding_w) { + float32x4_t row0 = vld1q_f32(input_ptr0); + float32x4_t row1 = vld1q_f32(input_ptr1); + float32x4_t row2 = vld1q_f32(input_ptr2); + float32x4_t acc0; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = 2 * w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0; + } else { + acc0 = vmulq_f32(row0, _ker[0]); + acc0 = vmlaq_f32(acc0, row1, _ker[1]); + acc0 = vmlaq_f32(acc0, row2, _ker[2]); + float sum0 = vgetq_lane_f32(acc0, 0); + if (padding == 1) { + sum0 += vgetq_lane_f32(acc0, 1); + } + *output_ptr0 = sum0; } + output_ptr0++; } } } + // pad bottom + for (int h = valid_h_end; h < output_h; ++h) { + DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr, _ker); + } } - -#endif } } // namespace math } // namespace operators } // namespace paddle_mobile + +#endif // __ARM_NEON__ diff --git a/src/operators/math/depthwise_conv3x3.h b/src/operators/math/depthwise_conv3x3.h index fde5d878c8..9b9c5c0a6d 100644 --- a/src/operators/math/depthwise_conv3x3.h +++ b/src/operators/math/depthwise_conv3x3.h @@ -23,48 +23,6 @@ namespace paddle_mobile { namespace operators { namespace math { -void DepthwiseConv3x3(const framework::Tensor *input, - const std::vector &strides, - const std::vector &paddings, - const framework::Tensor *filter, framework::Tensor *bias, - framework::Tensor *output, bool if_bias); - -void DepthwiseConv3x3s1p1(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor *bias, - bool if_bias, bool if_relu); - -void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, - const framework::Tensor 
*new_scale, - const framework::Tensor *new_bias, - bool if_relu); - -void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, - const framework::Tensor *new_scale, - const framework::Tensor *new_bias, - bool if_relu); - -void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor *bias, - bool if_bias, bool if_relu); - -void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, - const framework::Tensor *new_scale, - const framework::Tensor *new_bias, - bool if_relu); - -void DepthwiseConv3x3s2p0(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor *bias, - bool if_bias, bool if_relu); - // TODO(hjchen2) need to be implemented // template // void DepthwiseConv3x3(const framework::Tensor *input, diff --git a/src/operators/math/gemm/cblas.cc b/src/operators/math/gemm/cblas.cc index ccca4d7681..0cda7197f7 100644 --- a/src/operators/math/gemm/cblas.cc +++ b/src/operators/math/gemm/cblas.cc @@ -31,16 +31,14 @@ void cblas_sgemm(const bool transA, const bool transB, const int M, const int N, // return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C); // } - CPUInfo *info = CPUInfo::Info(); - GemmExecutor exec(info, transA, transB, M, N, K); + GemmExecutor exec(transA, transB, M, N, K); exec(alpha, A, lda, B, ldb, beta, C, ldc); } void cblas_sgemv(const bool trans, const int M, const int N, const float alpha, const float *A, const int lda, const float *B, const float beta, float *C) { - CPUInfo *info = CPUInfo::Info(); - GemvExecutor exec(info, trans, M, N); + GemvExecutor exec(trans, M, N); exec(alpha, A, lda, B, beta, C); } diff --git a/src/operators/math/gemm/executor.h b/src/operators/math/gemm/executor.h index b6ca66eb7e..93bf90c8b7 100644 --- a/src/operators/math/gemm/executor.h +++ b/src/operators/math/gemm/executor.h @@ -19,7 +19,6 @@ limitations under the License. 
*/ #include #endif #include -#include #include "common/log.h" #include "memory/t_malloc.h" #include "operators/math/gemm/cpu_info.h" @@ -29,6 +28,8 @@ namespace paddle_mobile { namespace operators { namespace math { +static CPUInfo *info = CPUInfo::Info(); + int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; } unsigned int ResetL1Cache(const unsigned int L1_size, const int thread_num, const int N, const int K) { @@ -62,15 +63,9 @@ class GemmExecutor : public Executor { typedef typename Strategy::Otype Otype; public: - GemmExecutor(const CPUInfo *info, const bool transA, const bool transB, - const int M, const int N, const int K) - : Executor(), - info_(info), - transA_(transA), - transB_(transB), - M_(M), - N_(N), - K_(K) { + GemmExecutor(const bool transA, const bool transB, const int M, const int N, + const int K) + : Executor(), transA_(transA), transB_(transB), M_(M), N_(N), K_(K) { unsigned int L1_size = 0; unsigned int L2_size = 0; if (M_ > N_) { @@ -212,8 +207,6 @@ class GemmExecutor : public Executor { virtual ~GemmExecutor() {} private: - const CPUInfo *info_; - const unsigned int M_; const unsigned int N_; const unsigned int K_; @@ -242,8 +235,8 @@ class GemvExecutor : public Executor { typedef typename Strategy::Otype Otype; public: - GemvExecutor(const CPUInfo *info, const bool transA, const int M, const int N) - : Executor(), info_(info), M_(M), N_(N) {} + GemvExecutor(const bool transA, const int M, const int N) + : Executor(), M_(M), N_(N) {} void operator()(const float alpha, const Itype *A, const int lda, const Itype *B, const float beta, Otype *C) { @@ -253,8 +246,6 @@ class GemvExecutor : public Executor { virtual ~GemvExecutor() {} private: - const CPUInfo *const info_; - const unsigned int M_; const unsigned int N_; diff --git a/src/operators/math/im2col.cpp b/src/operators/math/im2col.cpp index 9449ad7081..fedd17ed0c 100644 --- a/src/operators/math/im2col.cpp +++ b/src/operators/math/im2col.cpp @@ -44,7 +44,17 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height, for (int i = start_height; i < end_height; i += stride_h) { if (stride_w == 1) { - memcpy(col_data, im_data, extract * sizeof(float)); + // memcpy(col_data, im_data, extract * sizeof(float)); + int s = 0; +#if __ARM_NEON + for (; s < extract - 3; s += 4) { + float32x4_t img = vld1q_f32(im_data + s); + vst1q_f32(col_data + s, img); + } +#endif + for (; s < extract; ++s) { + col_data[s] = im_data[s]; + } } else if (stride_w == 2) { int s = 0; #if __ARM_NEON @@ -109,325 +119,7 @@ void Im2ColFunctor::operator()( const float *im_data = im.data(); float *col_data = col->data(); #if __ARM_NEON - const int osize = col_height; - const int isize = im_height; - bool pad1 = padding[0] > 0; - bool pad2 = - (pad1 && padding[1] && - (((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 
1 : 0)); - int fill = isize % 2; - if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 && - dilation[0] == 1 && im_height > 2 && im_height == im_width) { - for (int c = 0; c < im_channels; ++c) { - int oosize = osize * osize; - int nk4 = osize / 4; - int mk4 = osize % 4; - - float *col0 = col_data + 0 * oosize + 2 * osize + 2; - float *col1 = col_data + 1 * oosize + 2 * osize + 1; - float *col2 = col_data + 2 * oosize + 2 * osize; - - float *col3 = col_data + 3 * oosize + osize + 2; - float *col4 = col_data + 4 * oosize + osize + 1; - float *col5 = col_data + 5 * oosize + osize; - - float *col6 = col_data + 6 * oosize + 2; - float *col7 = col_data + 7 * oosize + 1; - float *col8 = col_data + 8 * oosize; - - float32x4_t im1; - const float *im_tmp_data = im_data + osize + 1; - - int rrsize = oosize - osize - 1; - int nr4 = rrsize / 4; - int mr4 = rrsize % 4; - for (int i = 0; i < nr4; ++i) { - im1 = vld1q_f32(im_tmp_data); - vst1q_f32(col0, im1); - vst1q_f32(col1, im1); - vst1q_f32(col2, im1); - vst1q_f32(col3, im1); - vst1q_f32(col4, im1); - vst1q_f32(col5, im1); - vst1q_f32(col6, im1); - vst1q_f32(col7, im1); - vst1q_f32(col8, im1); - - col0 += 4; - col1 += 4; - col2 += 4; - col3 += 4; - col4 += 4; - col5 += 4; - col6 += 4; - col7 += 4; - col8 += 4; - - im_tmp_data += 4; - } - for (int i = 0; i < mr4; ++i) { - *col0 = *im_tmp_data; - *col1 = *im_tmp_data; - *col2 = *im_tmp_data; - *col3 = *im_tmp_data; - *col4 = *im_tmp_data; - *col5 = *im_tmp_data; - *col6 = *im_tmp_data; - *col7 = *im_tmp_data; - *col8 = *im_tmp_data; - - col0++; - col1++; - col2++; - col3++; - col4++; - col5++; - col6++; - col7++; - col8++; - - im_tmp_data++; - } - - im_tmp_data = im_data + 1; - col0 = col_data + 0 * oosize + osize + 2; - col1 = col_data + 1 * oosize + osize + 1; - col2 = col_data + 2 * oosize + osize; - - col3 = col_data + 3 * oosize + 2; - col4 = col_data + 4 * oosize + 1; - col5 = col_data + 5 * oosize; - - for (int i = 0; i < nk4; i++) { - im1 = vld1q_f32(im_tmp_data); - vst1q_f32(col0, im1); - vst1q_f32(col1, im1); - vst1q_f32(col2, im1); - vst1q_f32(col3, im1); - vst1q_f32(col4, im1); - vst1q_f32(col5, im1); - - col0 += 4; - col1 += 4; - col2 += 4; - col3 += 4; - col4 += 4; - col5 += 4; - im_tmp_data += 4; - } - - for (int i = 0; i < mk4; i++) { - *col0 = *im_tmp_data; - *col1 = *im_tmp_data; - *col2 = *im_tmp_data; - *col3 = *im_tmp_data; - *col4 = *im_tmp_data; - *col5 = *im_tmp_data; - col0++; - col1++; - col2++; - col3++; - col4++; - col5++; - - im_tmp_data++; - } - - // fill 0 1 11; - for (int i = 0; i < osize; ++i) { - col_data[0 * oosize + i * osize] = 0.0; - col_data[3 * oosize + i * osize] = 0.0; - col_data[6 * oosize + i * osize] = 0.0; - - col_data[2 * oosize + osize - 1 + i * osize] = 0.0; - col_data[5 * oosize + osize - 1 + i * osize] = 0.0; - col_data[8 * oosize + osize - 1 + i * osize] = 0.0; - } - - col_data[0 * oosize + osize + 1] = im_data[0]; - col_data[3 * oosize + 1] = im_data[0]; - col_data[6 * oosize + 1] = im_data[osize]; - - col_data[1 * oosize + osize] = im_data[0]; - col_data[4 * oosize] = im_data[0]; - col_data[7 * oosize] = im_data[osize]; - - float32x4_t zero4; - zero4 = vdupq_n_f32(0.0); - auto col_z0 = col_data; - auto col_z1 = col_data + oosize; - auto col_z2 = col_data + 2 * oosize; - auto col_z6 = col_data + 6 * oosize + osize * (osize - 1); - auto col_z7 = col_data + 7 * oosize + osize * (osize - 1); - auto col_z8 = col_data + 8 * oosize + osize * (osize - 1); - - for (int i = 0; i < nk4; ++i) { - vst1q_f32(col_z0, zero4); - vst1q_f32(col_z1, zero4); - 
vst1q_f32(col_z2, zero4); - vst1q_f32(col_z6, zero4); - vst1q_f32(col_z7, zero4); - vst1q_f32(col_z8, zero4); - - col_z0 += 4; - col_z1 += 4; - col_z2 += 4; - col_z6 += 4; - col_z7 += 4; - col_z8 += 4; - } - - for (int i = 0; i < mk4; ++i) { - col_z0[i] = 0.0; - col_z1[i] = 0.0; - col_z2[i] = 0.0; - col_z6[i] = 0.0; - col_z7[i] = 0.0; - col_z8[i] = 0.0; - } - col_data += 9 * oosize; - im_data += isize * isize; - } - } else if (stride[0] == 2 && filter_height == 3 && pad1 && dilation[0] == 1 && - im_height > 2 && im_height == im_width) { - for (int c = 0; c < im_channels; ++c) { - int oosize = osize * osize; - int nk4 = osize / 4; - int mk4 = osize % 4; - - // 3 2 3 1 0 1 3 2 3 - float *col0 = col_data + 0 * oosize + osize + 1; - float *col1 = col_data + 1 * oosize + osize; - float *col2 = col_data + 2 * oosize + osize; - - float *col3 = col_data + 3 * oosize + 1; - float *col4 = col_data + 4 * oosize; - float *col5 = col_data + 5 * oosize; - - float *col6 = col_data + 6 * oosize + 1; - float *col7 = col_data + 7 * oosize; - float *col8 = col_data + 8 * oosize; - - float32x4x2_t im01; - float32x4x2_t im23; - const float *im_tmp_data0 = im_data; - const float *im_tmp_data2 = im_data + isize; - - for (int j = 0; j < osize; ++j) { - for (int i = 0; i < nk4; ++i) { - im01 = vld2q_f32(im_tmp_data0); - im23 = vld2q_f32(im_tmp_data2); - vst1q_f32(col0, im23.val[1]); - vst1q_f32(col1, im23.val[0]); - vst1q_f32(col2, im23.val[1]); - vst1q_f32(col3, im01.val[1]); - vst1q_f32(col4, im01.val[0]); - vst1q_f32(col5, im01.val[1]); - vst1q_f32(col6, im23.val[1]); - vst1q_f32(col7, im23.val[0]); - vst1q_f32(col8, im23.val[1]); - - col0 += 4; - col1 += 4; - col2 += 4; - col3 += 4; - col4 += 4; - col5 += 4; - col6 += 4; - col7 += 4; - col8 += 4; - - im_tmp_data0 += 8; - im_tmp_data2 += 8; - } - const float *im_tmp_data1 = im_tmp_data0 + 1; - const float *im_tmp_data3 = im_tmp_data2 + 1; - for (int i = 0; i < mk4; ++i) { - *col0 = *im_tmp_data3; - *col1 = *im_tmp_data2; - *col2 = *im_tmp_data3; - *col3 = *im_tmp_data1; - *col4 = *im_tmp_data0; - *col5 = *im_tmp_data1; - *col6 = *im_tmp_data3; - *col7 = *im_tmp_data2; - *col8 = *im_tmp_data3; - - col0++; - col1++; - col2++; - col3++; - col4++; - col5++; - col6++; - col7++; - col8++; - im_tmp_data0 += 2; - im_tmp_data1 += 2; - im_tmp_data2 += 2; - im_tmp_data3 += 2; - } - im_tmp_data0 += (isize - fill); - im_tmp_data2 += (isize - fill); - } - for (int i = 0; i < osize; ++i) { - col_data[0 * oosize + i * osize] = 0.0; - col_data[3 * oosize + i * osize] = 0.0; - col_data[6 * oosize + i * osize] = 0.0; - if (pad2) { - col_data[2 * oosize + osize - 1 + i * osize] = 0.0; - col_data[5 * oosize + osize - 1 + i * osize] = 0.0; - col_data[8 * oosize + osize - 1 + i * osize] = 0.0; - } - } - float32x4_t zero4; - zero4 = vdupq_n_f32(0.0); - auto col_z0 = col_data; - auto col_z1 = col_data + oosize; - auto col_z2 = col_data + 2 * oosize; - auto col_z6 = col_data + 6 * oosize + osize * (osize - 1); - auto col_z7 = col_data + 7 * oosize + osize * (osize - 1); - auto col_z8 = col_data + 8 * oosize + osize * (osize - 1); - - for (int i = 0; i < nk4; ++i) { - vst1q_f32(col_z0, zero4); - vst1q_f32(col_z1, zero4); - vst1q_f32(col_z2, zero4); - if (pad2) { - vst1q_f32(col_z6, zero4); - vst1q_f32(col_z7, zero4); - vst1q_f32(col_z8, zero4); - } - col_z0 += 4; - col_z1 += 4; - col_z2 += 4; - col_z6 += 4; - col_z7 += 4; - col_z8 += 4; - } - - for (int i = 0; i < mk4; ++i) { - col_z0[i] = 0.0; - col_z1[i] = 0.0; - col_z2[i] = 0.0; - if (pad2) { - col_z6[i] = 0.0; - col_z7[i] = 0.0; - 
col_z8[i] = 0.0; - } - } - - col_data[1 * oosize + osize] = im_data[isize]; - for (int i = 1; i < osize; ++i) { - col_data[3 * oosize + i] = im_data[(i - 1) * stride[0] + 1]; - } - col_data[4 * oosize] = im_data[0]; - col_data[7 * oosize] = im_data[isize]; - - col_data += 9 * oosize; - im_data += isize * isize; - } - } else if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) { + if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) { int im_spatial_size = im_height * im_width; int col_spatial_size = col_height * col_width; // pad 0 diff --git a/src/operators/op_param.h b/src/operators/op_param.h index a735fbee48..6bd2470cb4 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -441,10 +441,8 @@ class ConvParam : public OpParam { enum ExecMode { EXEC_INVALID = 0, EXEC_GEMM_FLOAT, - EXEC_DEPTHWISE3x3S1P1_FLOAT, - EXEC_DEPTHWISE3x3S2P0_FLOAT, - EXEC_DEPTHWISE3x3S2P1_FLOAT, - EXEC_DEPTHWISE3x3_FLOAT, + EXEC_DEPTHWISE3x3S1_FLOAT, + EXEC_DEPTHWISE3x3S2_FLOAT, EXEC_WINOGRAD3X3_FLOAT, EXEC_WINOGRAD5X5_FLOAT, EXEC_DEPTHWISE5x5_FLOAT, -- GitLab
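A note on the refactored stride-2 path in the src/operators/math/depthwise_conv3x3.cpp hunk above: the NEON code de-interleaves even and odd input columns with vld2q_f32 so that each iteration of the "valid" loop produces four stride-2 outputs, while the pad-left / pad-right / pad-bottom branches reproduce zero-padding scalar-wise at the borders. The sketch below is a minimal scalar reference of that computation (a 3x3 depthwise convolution over one NCHW channel with zero padding); the function and parameter names are illustrative only and are not part of the library's API.

// Scalar reference: depthwise 3x3 convolution for a single channel.
// Output size follows the usual formula:
//   out_h = (in_h + 2 * pad_h - 3) / stride + 1, likewise for out_w.
static void DepthwiseConv3x3RefChannel(const float *in, const float *kernel9,
                                       int in_h, int in_w, int stride,
                                       int pad_h, int pad_w, float *out) {
  const int out_h = (in_h + 2 * pad_h - 3) / stride + 1;
  const int out_w = (in_w + 2 * pad_w - 3) / stride + 1;
  for (int oh = 0; oh < out_h; ++oh) {
    for (int ow = 0; ow < out_w; ++ow) {
      float acc = 0.f;
      for (int kh = 0; kh < 3; ++kh) {
        for (int kw = 0; kw < 3; ++kw) {
          const int ih = oh * stride + kh - pad_h;
          const int iw = ow * stride + kw - pad_w;
          // Zero padding: taps that fall outside the input contribute nothing,
          // matching the border branches in the vectorized kernel.
          if (ih >= 0 && ih < in_h && iw >= 0 && iw < in_w) {
            acc += in[ih * in_w + iw] * kernel9[kh * 3 + kw];
          }
        }
      }
      out[oh * out_w + ow] = acc;
    }
  }
}

With stride == 2 this yields the same values the vectorized loop assembles four at a time from _row*.val[0] (even columns), _row*.val[1] (odd columns), and the vextq_f32/vsetq_lane_f32 shifted lane; the switch on output_w_remain at the end of each row stores the leftover one to three outputs.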