Commit 532cff71 authored by: H hjchen2

Refactor depthwise conv3x3 and fix its bugs for armv8

Parent cb5e15b9
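This change collapses the padding-specific depthwise-3x3 execute modes (EXEC_DEPTHWISE3x3S1P1_FLOAT, EXEC_DEPTHWISE3x3S2P1_FLOAT, EXEC_DEPTHWISE3x3S2P0_FLOAT and the generic EXEC_DEPTHWISE3x3_FLOAT) into two stride-keyed modes, EXEC_DEPTHWISE3x3S1_FLOAT and EXEC_DEPTHWISE3x3S2_FLOAT. The rewritten DepthwiseConv3x3S1/DepthwiseConv3x3S2 kernels take the paddings directly and handle them internally, and the fused bias / batch-norm / ReLU work is applied afterwards as a separate channel-wise pass. A minimal sketch of the dispatch pattern the kernels below now share (illustrative only, not part of the diff):

switch (param.ExecMode()) {
  case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
    math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
                                           param.Paddings(), param.Output());
    break;
  case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
    math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
                                           param.Paddings(), param.Output());
    break;
  default:
    // GEMM / Winograd / 5x5 paths, see the per-kernel switches below
    break;
}
// followed by the fused post-op, e.g.
// math::ScaleAddChannelWise<RELU>(param.Output(), scale, bias, param.Output());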
@@ -61,25 +61,15 @@ template <>
void ConvAddBNReluKernel<CPU, float>::Compute(
    const FusionConvAddBNReluParam<CPU> &param) {
  switch (param.ExecMode()) {
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
-     math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
-                                         param.Output(), param.NewScale(),
-                                         param.NewBias(), true);
-     break;
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
-     math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
-                                           param.Output(), param.NewScale(),
-                                           param.NewBias(), true);
-     break;
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
-     math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
-                                nullptr, false, false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                      param.NewBias(), param.Output());
      break;
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
-     math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
-                            param.Filter(), nullptr, param.Output(), false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                      param.NewBias(), param.Output());
      break;
...
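The batch-norm + ReLU work that the removed DepthwiseConvAddBNRelu3x3* kernels fused internally is now a standalone channel-wise pass. A scalar reference of what ScaleAddChannelWise<RELU> computes, as far as the call sites above imply (hypothetical helper name, for illustration only):

#include <algorithm>
// For every channel c: y = max(scale[c] * x + bias[c], 0) over the H*W plane.
inline void ScaleAddChannelWiseReluRef(const float *x, const float *scale,
                                       const float *bias, float *y,
                                       int channels, int spatial_size) {
  for (int c = 0; c < channels; ++c) {
    for (int i = 0; i < spatial_size; ++i) {
      float v = scale[c] * x[c * spatial_size + i] + bias[c];
      y[c * spatial_size + i] = std::max(v, 0.f);
    }
  }
}

AddChannelWise<RELU>/<IDENTITY> used by the conv_add kernels is presumably the scale == 1 variant of the same pass.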
@@ -14,6 +14,7 @@ limitations under the License. */
#ifdef FUSION_CONVADD_OP
#include "operators/kernel/conv_add_kernel.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_add_arm_func.h"
namespace paddle_mobile {
@@ -21,12 +22,44 @@ namespace operators {
template <>
bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam<CPU> *param) {
  InitBaseConvKernel(param);
  return true;
}
template <>
void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam<CPU> &param) {
-   ConvAddCompute<float>(param);
  switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvAddBasic(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvAddKernel<CPU, float>;
...
@@ -31,21 +31,14 @@ template <>
void ConvAddReluKernel<CPU, float>::Compute(
    const FusionConvAddReluParam<CPU> &param) {
  switch (param.ExecMode()) {
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
-     math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
-                                param.Bias(), true, true);
-     break;
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
-     math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
-                                  param.Output(), param.Bias(), true, true);
-     break;
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
-     math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
-                                param.Bias(), true, true);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
      break;
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
-     math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
-                            param.Filter(), nullptr, param.Output(), false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
      break;
#ifndef __aarch64__
...
@@ -16,7 +16,8 @@ limitations under the License. */
#include "operators/kernel/conv_bn_add_relu_kernel.h"
#include <cmath>
-#include "operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
namespace paddle_mobile {
namespace operators {
@@ -51,13 +52,46 @@ bool ConvBNAddReluKernel<CPU, float>::Init(
  }
  param->SetNewScale(new_scale);
  param->SetNewBias(new_bias);
  InitBaseConvKernel(param);
  return true;
}
template <>
void ConvBNAddReluKernel<CPU, float>::Compute(
    const FusionConvBNAddReluParam<CPU> &param) {
-   ConvBNAddReluCompute<float>(param);
  switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvBNReluBasic<FusionConvBNAddReluParam<CPU>>(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvBNAddReluKernel<CPU, float>;
...
@@ -60,25 +60,15 @@ template <>
void ConvBNReluKernel<CPU, float>::Compute(
    const FusionConvBNReluParam<CPU> &param) {
  switch (param.ExecMode()) {
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
-     math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
-                                         param.Output(), param.NewScale(),
-                                         param.NewBias(), true);
-     break;
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
-     math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
-                                           param.Output(), param.NewScale(),
-                                           param.NewBias(), true);
-     break;
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
-     math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
-                                nullptr, false, false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                      param.NewBias(), param.Output());
      break;
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
-     math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
-                            param.Filter(), nullptr, param.Output(), false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                      param.NewBias(), param.Output());
      break;
...
@@ -44,29 +44,25 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
#endif  // __aarch64__
  } else {
    if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
-       param->Strides()[0] == 1 && param->Paddings()[0] == 1 &&
-       param->Paddings()[0] == param->Paddings()[1]) {
-     param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT;
        param->Strides()[0] == 1) {
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT;
    } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
-              param->Strides()[0] == 2 && param->Paddings()[0] == 0 &&
-              param->Paddings()[0] == param->Paddings()[1]) {
-     param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT;
-   } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
-              param->Strides()[0] == 2 && param->Paddings()[0] == 1 &&
-              param->Paddings()[0] == param->Paddings()[1]) {
-     param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
-   } else if (depth3x3) {
-     param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT;
               param->Strides()[0] == 2) {
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT;
#ifndef __aarch64__
    } else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
               param->Strides()[0] == 1) {
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT;
-   } else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
    } else if (conv3x3 && !depth3x3 &&
               param->Strides()[0] == param->Strides()[1] &&
               param->Dilations()[0] == param->Dilations()[1] &&
-              param->Strides()[0] == 1 && param->Dilations()[0] == 1 /* &&
-              param->Output()->dims()[1] >= 16 &&
               param->Strides()[0] == 1 && param->Dilations()[0] == 1
#if 0
               && param->Output()->dims()[1] >= 16 &&
               param->Input()->dims()[1] >= 16 &&
-              param->Input()->dims()[2] <= 140 */ /* refered from ncnn */) {
               param->Input()->dims()[2] <= 140 */ /* refered from ncnn */
#endif
               ) {
      param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
      // transform weight
      param->transformed_filter_ = new framework::LoDTensor;
...
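With the padding checks gone, the float-path selection above reduces to roughly the following (sketch; square strides/dilations are still required by the conditions in the hunk, and the final fallback branch is not shown in this hunk):

if (depth3x3 && stride == 1) {
  mode = EXEC_DEPTHWISE3x3S1_FLOAT;
} else if (depth3x3 && stride == 2) {
  mode = EXEC_DEPTHWISE3x3S2_FLOAT;
#ifndef __aarch64__
} else if (depth5x5 && stride == 1) {
  mode = EXEC_DEPTHWISE5x5_FLOAT;
} else if (conv3x3 && !depth3x3 && stride == 1 && dilation == 1) {
  mode = EXEC_WINOGRAD3X3_FLOAT;  // filter is pre-transformed below
#endif
} else {
  mode = EXEC_GEMM_FLOAT;  // presumably the untouched fallback branch
}

Depthwise 3x3 with any other stride no longer has a dedicated mode (the old EXEC_DEPTHWISE3x3_FLOAT branch is removed) and falls through to the fallback.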
@@ -18,6 +18,8 @@ limitations under the License. */
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include <iostream>
namespace paddle_mobile {
namespace operators {
@@ -41,21 +43,13 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
      DepthwiseConv5x5<int8_t, int32_t>(param);
      break;
#endif  // __aarch64__
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
-     math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
-                                nullptr, false, false);
-     break;
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
-     math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
-                                  param.Output(), nullptr, false, false);
-     break;
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
-     math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
-                                nullptr, false, false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      break;
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
-     math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
-                            param.Filter(), nullptr, param.Output(), false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      break;
#ifndef __aarch64__
    case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
...
@@ -60,25 +60,15 @@ template <>
void DWConvBNReluKernel<CPU, float>::Compute(
    const FusionDWConvBNReluParam<CPU> &param) {
  switch (param.ExecMode()) {
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
-     math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
-                                         param.Output(), param.NewScale(),
-                                         param.NewBias(), true);
-     break;
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
-     math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
-                                           param.Output(), param.NewScale(),
-                                           param.NewBias(), true);
-     break;
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
-     math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
-                                nullptr, false, false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                      param.NewBias(), param.Output());
      break;
-   case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
-     math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
-                            param.Filter(), nullptr, param.Output(), false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                      param.NewBias(), param.Output());
      break;
...
@@ -115,35 +115,6 @@ void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
}
}
template <typename P>
void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
param.Output()->mutable_data<float>();
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, false);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), param.Bias(),
// param.Output(), false);
if (param.Paddings()[0] == 0) {
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, false);
} else {
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.Bias(), true, false);
}
} else {
ConvAddBasic(param);
}
}
} // namespace operators
} // namespace paddle_mobile
...
@@ -115,31 +115,6 @@ void ConvBNAddReluBasic(const FusionConvBNAddReluParam<CPU> &param) {
}
}
}
template <typename P>
void ConvBNAddReluCompute(const FusionConvBNAddReluParam<CPU> &param) {
Tensor Bias;
Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else {
ConvBNAddReluBasic(param);
}
}
} // namespace operators
} // namespace paddle_mobile
...
@@ -99,7 +99,6 @@ void ConvTransposeCompute(const ConvTransposeParam<CPU> &param) {
            std::vector<int>{paddings[0], paddings[1], paddings[0],
                             paddings[1]},
            &out_slice);
    } else if (data_dim == 3U) {
      col2vol(col, dilations, strides, paddings, &out_slice);
    }
...
@@ -12,2066 +12,1042 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include "operators/math/depthwise_conv3x3.h"
-#include <vector>
-#if __ARM_NEON
#include <arm_neon.h>
-#endif
namespace paddle_mobile {
namespace operators {
namespace math {
#ifndef __aarch64__
inline float32x4_t vpaddq_f32(float32x4_t r0, float32x4_t r1) {
  float32x2_t sum0 = vpadd_f32(vget_low_f32(r0), vget_high_f32(r0));
  float32x2_t sum1 = vpadd_f32(vget_low_f32(r1), vget_high_f32(r1));
  return vcombine_f32(sum0, sum1);
}
#endif
-void DepthwiseConv3x3(const framework::Tensor *input,
-                      const std::vector<int> &strides,
-                      const std::vector<int> &paddings,
-                      const framework::Tensor *filter, framework::Tensor *bias,
-                      framework::Tensor *output, bool if_bias) {
const int batch_size = input->dims()[0];
const int input_height = input->dims()[2];
const int input_width = input->dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int _kernel_size = 3;
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const float zero = 0;
const int input_channel_stride = input_height * input_width;
const int output_channel_stride = output_height * output_width;
const int filter_channel_stride = 9;
const float *input_ptr = input->data<float>();
const float *filter_ptr = filter->data<float>();
if (if_bias) {
math::expand_bias(*bias, 1, output->dims());
output->ShareDataWith(*bias);
}
float *output_ptr = output->mutable_data<float>();
const float *pos1, *pos2, *pos3, *filter1, *filter2, *filter3, *output_ptr2;
int hstart, wstart, hend, wend;
float result;
for (int i = 0; i < batch_size; ++i) {
#pragma omp parallel for
for (int c = 0; c < output_channels; ++c) {
const float *input_data =
input_ptr + (i * output_channels + c) * input_channel_stride;
float *output_data =
output_ptr + (i * output_channels + c) * output_channel_stride;
filter1 = filter_ptr + c * filter_channel_stride;
filter2 = filter1 + 3;
filter3 = filter2 + 3;
for (int ph = 0; ph < output_height; ph++) {
for (int pw = 0; pw < output_width; pw++) {
hstart = ph * stride_height - padding_height;
wstart = pw * stride_width - padding_width;
hend = std::min(hstart + _kernel_size, input_height + padding_height);
wend = std::min(wstart + _kernel_size, input_width + padding_width);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, input_height);
wend = std::min(wend, input_width);
pos1 = input_data + hstart * input_width + wstart;
pos2 = input_data + (hstart + 1) * input_width + wstart;
pos3 = input_data + (hstart + 2) * input_width + wstart;
output_ptr2 = output_data + ph * output_width + pw;
if (hend - hstart != 3 || wend - wstart != 3) {
result = 0;
float fake_input[9] = {0};
if (hstart == 0 && wstart == 0) {
// 左上角
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend && k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k -
(3 - wend)];
}
}
}
} else if (hstart == 0 && wend == input_width) {
// 右上角
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend && k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k + wstart];
}
}
}
} else if (hend == input_height && wstart == 0) {
// 左下角
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - 1 - hstart && k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k - (3 - wend)];
}
}
}
} else if (hend == input_height && wend == input_width) {
// 右下角
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - hstart - 1 &&
k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
} else if (hstart == 0) {
// 顶部
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k + wstart];
}
}
}
} else if (hend == input_height) {
// 底部
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - hstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
} else if (wstart == 0) {
// 左侧
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width +
(k - (3 - wend))];
}
}
}
} else if (wend == input_width) {
// 右侧
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
}
for (int l = 0; l < 9; ++l) {
result += fake_input[l] * filter1[l];
}
if (if_bias) {
output_data[ph * output_width + pw] += result;
} else {
output_data[ph * output_width + pw] = result;
}
} else {
#if __ARM_NEON
#if __aarch64__
const float32x4_t data1 = vld1q_f32(pos1);
const float32x4_t data2 = vld1q_f32(pos2);
const float32x4_t data3 = vld1q_f32(pos3);
const float32x4_t v_filter1 = vld1q_f32(filter1);
const float32x4_t v_filter2 = vld1q_f32(filter2);
const float32x4_t v_filter3 = vld1q_f32(filter3);
float32x4_t mula = vmulq_f32(data1, v_filter1);
mula = vmlaq_f32(mula, data2, v_filter2);
mula = vmlaq_f32(mula, data3, v_filter3);
float32x2_t res = vpadd_f32(
vget_high_f32(vsetq_lane_f32(0, mula, 3)), vget_low_f32(mula));
res = vpadd_f32(res, res);
if (if_bias) {
output_data[ph * output_width + pw] += vget_lane_f32(res, 0);
} else {
output_data[ph * output_width + pw] = vget_lane_f32(res, 0);
}
#else
asm volatile(
"vld1.32 {q1}, [%[pos1]] \n\t"
"vld1.32 {q4}, [%[filter1]] \n\t"
"vmov.f32 q0, #0.0 \n\t"
"vld1.32 {q2}, [%[pos2]] \n\t"
"vld1.32 {q5}, [%[filter2]] \n\t"
"vmla.f32 q0, q1, q4 \n\t"
"vld1.32 {q3}, [%[pos3]] \n\t"
"vld1.32 {q6}, [%[filter3]] \n\t"
"vmla.f32 q0, q2, q5 \n\t"
"vmla.f32 q0, q3, q6 \n\t"
"vmov.f32 d1[1], %[zero] \n\t"
"vadd.f32 d4, d0, d1 \n\t"
"vadd.f32 s10, s8, s9 \n\t"
"vst1.32 {d5[0]},[%[output_ptr]] \n\t"
:
: [input_data] "r"(input_data), [pos1] "r"(pos1),
[pos2] "r"(pos2), [pos3] "r"(pos3), [filter1] "r"(filter1),
[filter2] "r"(filter2), [filter3] "r"(filter3),
[output_ptr] "r"(output_ptr2), [zero] "r"(zero)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
#endif // __aarch64__
#else
#endif // __ARM_NEON
}
}
}
}
}
}
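// Added explanatory note on the vpaddq_f32 helper defined above: aarch64
// provides this intrinsic natively, so the polyfill is only compiled for
// 32-bit ARM. It follows the aarch64 semantics, pairwise-adding r0 into the
// low half of the result and r1 into the high half:
//   vpaddq_f32({1,2,3,4}, {5,6,7,8}) == {1+2, 3+4, 5+6, 7+8} == {3, 7, 11, 15}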
void DepthwiseConv3x3s1p1(const framework::Tensor *input, template <int Stride = 1>
const framework::Tensor *filter, inline void Depth3x3NormalRowLoadInput(const float *input, float32x4_t *y) {
framework::Tensor *output, framework::Tensor *bias, y[0] = vld1q_f32(input);
bool if_bias, bool if_relu) { y[2] = vld1q_f32(input + 4);
#if __ARM_NEON y[1] = vextq_f32(y[0], y[2], 1);
const int batch_size = static_cast<int>(input->dims()[0]); y[2] = vextq_f32(y[0], y[2], 2);
const int c = static_cast<int>(input->dims()[1]); }
const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]);
const int hxw = h * w;
// const int l = h;
// leftTop, rightTop, leftBottom, rightBottom
const int lt = 0;
const int rt = w - 1;
const int lb = (h - 1) * w;
const int rb = h * w - 1;
const float *bias_data;
if (if_bias) {
bias_data = bias->data<float>();
}
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = 0; b < batch_size; ++b) {
#pragma omp parallel for
for (int j = 0; j < c; ++j) {
const float *filter_data_tmp = filter->data<float>() + j * 9;
const float *input_data = input->data<float>() + j * hxw;
float *output_data = output->mutable_data<float>() + j * hxw;
float32x4_t vbias;
if (if_bias) {
vbias = vdupq_n_f32(bias_data[j]);
}
int w_mid = w - 2; // l=1->l_mid=-1,l=2->l_mid=0
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
output_data[lt] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[w] + w22 * input_data[w + 1];
output_data[rt] = w10 * input_data[w - 2] + w11 * input_data[w - 1] +
w20 * input_data[2 * w - 2] +
w21 * input_data[2 * w - 1];
output_data[lb] =
w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] +
w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1];
output_data[rb] =
w00 * input_data[h * w - w - 2] + w01 * input_data[h * w - w - 1] +
w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1];
if (if_bias) {
output_data[lt] += bias_data[j];
output_data[rt] += bias_data[j];
output_data[lb] += bias_data[j];
output_data[rb] += bias_data[j];
}
if (if_relu) {
output_data[lt] = output_data[lt] < 0 ? 0 : output_data[lt];
output_data[rt] = output_data[rt] < 0 ? 0 : output_data[rt];
output_data[lb] = output_data[lb] < 0 ? 0 : output_data[lb];
output_data[rb] = output_data[rb] < 0 ? 0 : output_data[rb];
}
for (int i = 1; i < h - 1; ++i) {
int left = i * w;
int right = i * w + w - 1;
output_data[left] =
w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] +
w11 * input_data[i * w] + w12 * input_data[i * w + 1] +
w21 * input_data[i * w + w] + w22 * input_data[i * w + w + 1];
output_data[right] = w00 * input_data[i * w + w - 1 - w - 1] +
w01 * input_data[i * w + w - 1 - w] +
w10 * input_data[i * w + w - 1 - 1] +
w11 * input_data[i * w + w - 1] +
w20 * input_data[i * w + w - 1 + w - 1] +
w21 * input_data[i * w + w - 1 + w];
if (if_bias) {
output_data[left] += bias_data[j];
output_data[right] += bias_data[j];
}
if (if_relu) {
output_data[left] = output_data[left] < 0 ? 0 : output_data[left];
output_data[right] = output_data[right] < 0 ? 0 : output_data[right];
}
}
// top 1 row and bottom 1 row
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, out0;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + w);
const float *input_tmp_end = input_tmp + (h - 2) * w;
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + w);
int c_mid = w_mid;
auto output_ptr = output_data + 1;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + w + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + w + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
vst1q_f32(output_ptr + (h - 1) * w, out0);
// can optimize to each 8 stride.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
// top right pad
float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
// bottom right pad
float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]);
float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
tmp2 = vextq_f32(in6, pad3, 1);
tmp3 = vextq_f32(in6, pad3, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2);
}
}
// mid
for (int i = 0; i < h - 2; ++i) {
auto output_ptr = output_data + (i + 1) * w + 1;
input_tmp = input_data + i * w;
auto in0_tmp = vld1q_f32(input_tmp);
auto in2_tmp = vld1q_f32(input_tmp + w);
auto in4_tmp = vld1q_f32(input_tmp + w + w);
c_mid = w_mid;
for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + w + 4);
auto in5_tmp = vld1q_f32(input_tmp + w + w + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
vst1q_f32(output_ptr, out0);
output_ptr += 4; template <>
input_tmp += 4; inline void Depth3x3NormalRowLoadInput<2>(const float *input, float32x4_t *y) {
in0_tmp = in1_tmp; float32x4x2_t x = vld2q_f32(input);
in2_tmp = in3_tmp; y[0] = x.val[0];
in4_tmp = in5_tmp; y[1] = x.val[1];
} y[2] = vextq_f32(y[0], y[0], 1);
y[2] = vsetq_lane_f32(input[8], y[2], 3);
}
float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]); #define DEPTHWISE_CONV3X3_NORMAL_BORDER(start, end) \
float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]); for (int w = start; w < end; ++w) { \
float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]); const int w_in_start = -padding_w + w * Stride_w; \
const int w_in_end = w_in_start + 3; \
tmp0 = vextq_f32(in0_tmp, pad0, 1); const int w_start = w_in_start > 0 ? w_in_start : 0; \
tmp1 = vextq_f32(in0_tmp, pad0, 2); const int w_end = w_in_end < input_w ? w_in_end : input_w; \
tmp2 = vextq_f32(in2_tmp, pad1, 1); float value = 0; \
tmp3 = vextq_f32(in2_tmp, pad1, 2); for (int h_in = h_start; h_in < h_end; ++h_in) { \
tmp4 = vextq_f32(in4_tmp, pad2, 1); for (int w_in = w_start; w_in < w_end; ++w_in) { \
tmp5 = vextq_f32(in4_tmp, pad2, 2); value += filter[(h_in - h_in_start) * 3 + (w_in - w_in_start)] * \
input[h_in * input_w + w_in]; \
out0 = vmulq_n_f32(in0_tmp, w00); } \
out0 = vmlaq_n_f32(out0, tmp0, w01); } \
out0 = vmlaq_n_f32(out0, tmp1, w02); output_ptr[w] = value; \
out0 = vmlaq_n_f32(out0, in2_tmp, w10); }
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
for (int i = 0; i < c_mid; ++i) { template <int Stride_h, int Stride_w>
if (i == 0) { inline void DepthwiseConv3x3NormalRow(const float *input, const float *filter,
vst1q_lane_f32(output_ptr + i, out0, 0); const int h_output, const int input_h,
} const int input_w, const int padding_h,
if (i == 1) { const int padding_w, const int output_w,
vst1q_lane_f32(output_ptr + i, out0, 1); float *output, float32x4_t *ker) {
} const int h_in_start = -padding_h + h_output * Stride_h;
if (i == 2) { const int h_in_end = h_in_start + 3;
vst1q_lane_f32(output_ptr + i, out0, 2); const int h_start = h_in_start > 0 ? h_in_start : 0;
} const int h_end = h_in_end < input_h ? h_in_end : input_h;
}
} const int valid_w_start = (padding_w + Stride_w - 1) / Stride_w;
const int valid_w_end = (input_w + padding_w - 3) / Stride_w + 1;
// const int valid_w_end = output_w - valid_w_start;
float *output_ptr = output + h_output * output_w;
// border left
DEPTHWISE_CONV3X3_NORMAL_BORDER(0, valid_w_start)
// middle
int output_tiles = (valid_w_end - valid_w_start) >> 2;
float32x4_t _sum, _x[3];
// valid w
for (int w = 0; w < output_tiles * 4; w += 4) {
_sum = vdupq_n_f32(0.f);
int output_offset = valid_w_start + w;
int input_w_offset = output_offset * Stride_w - padding_w;
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
Depth3x3NormalRowLoadInput<Stride_w>(
input + h_in * input_w + input_w_offset, _x);
_sum = vmlaq_lane_f32(_sum, _x[0], vget_low_f32(ker[index]), 0);
_sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 1);
_sum = vmlaq_lane_f32(_sum, _x[2], vget_high_f32(ker[index]), 0);
} }
vst1q_f32(output_ptr + output_offset, _sum);
} }
#endif // remain valid w
int remain = (valid_w_end - valid_w_start) & 0x3;
if (remain > 0) {
_sum = vdupq_n_f32(0.f);
int remain_start = valid_w_start + (output_tiles << 2);
int input_w_offset = remain_start * Stride_w - padding_w;
float *output_ptr0 = output_ptr + remain_start;
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
Depth3x3NormalRowLoadInput<Stride_w>(
input + h_in * input_w + input_w_offset, _x);
_sum = vmlaq_lane_f32(_sum, _x[0], vget_low_f32(ker[index]), 0);
_sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 1);
_sum = vmlaq_lane_f32(_sum, _x[2], vget_high_f32(ker[index]), 0);
}
switch (remain) {
case 3:
vst1q_lane_f32(output_ptr0 + 2, _sum, 2);
case 2:
vst1_f32(output_ptr0, vget_low_f32(_sum));
break;
case 1:
vst1_lane_f32(output_ptr0, vget_low_f32(_sum), 0);
break;
}
}
// border right
DEPTHWISE_CONV3X3_NORMAL_BORDER(valid_w_end, output_w)
} }
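// Added explanatory note on the helpers above: DepthwiseConv3x3NormalRow
// computes one output row whose 3x3 window crosses the top or bottom padding.
// Columns whose window also crosses the left or right padding go through the
// scalar DEPTHWISE_CONV3X3_NORMAL_BORDER macro; interior columns are computed
// four at a time against the preloaded kernel rows in ker[0..2].
// Depth3x3NormalRowLoadInput builds the three shifted input vectors for those
// four outputs: the stride-1 version loads 8 consecutive floats and slides a
// window with vextq_f32, while the stride-2 specialization uses vld2q_f32 to
// split even/odd lanes and patches input[8] into the last lane.
// For example, with Stride_w = 1, padding_w = 1 and input_w = 8:
//   valid_w_start = (1 + 1 - 1) / 1 = 1, valid_w_end = (8 + 1 - 3) / 1 + 1 = 7,
// so columns 0 and 7 take the border path and columns 1..6 are vectorized.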
void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input, template <>
const framework::Tensor *filter, void DepthwiseConv3x3S1<float, float>(const framework::Tensor &input,
framework::Tensor *output, const framework::Tensor &filter,
const framework::Tensor *new_scale, const std::vector<int> &paddings,
const framework::Tensor *new_bias, framework::Tensor *output) {
bool if_relu) { const float *input_data = input.data<float>();
#if __ARM_NEON const float *filter_data = filter.data<float>();
const float *input_data = input->data<float>(); float *out_data = output->mutable_data<float>();
const float *filter_data = filter->data<float>(); int input_h = input.dims()[2];
float *output_data = output->mutable_data<float>(); int input_w = input.dims()[3];
const float *newscale_data = new_scale->data<float>(); int output_h = output->dims()[2];
const float *newbias_data = new_bias->data<float>(); int output_w = output->dims()[3];
int padding_h = paddings[0];
const int batch_size = static_cast<int>(input->dims()[0]); int padding_w = paddings[1];
const int input_channel = static_cast<int>(input->dims()[1]); int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
const int input_height = static_cast<int>(input->dims()[2]); int valid_h_start = padding_h;
const int input_width = static_cast<int>(input->dims()[3]); int valid_h_end = output_h - valid_h_start;
const int output_height = static_cast<int>(output->dims()[2]); int valid_h = valid_h_end - valid_h_start;
const int output_width = static_cast<int>(output->dims()[3]); int valid_w_start = padding_w;
int valid_w_end = output_w - valid_w_start;
const int hxw = input_height * input_width; int valid_w = valid_w_end - valid_w_start;
// const int l = input_height; #pragma omp parallel for
const int h = input_height; for (int g = 0; g < input.dims()[1]; ++g) {
const int w = input_width; const float *input_ptr = input_data + g * image_size;
float32x4_t vzero = vdupq_n_f32(0); const float *filter_ptr = filter_data + g * 9;
float *output_ptr = out_data + g * out_image_size;
for (int b = 0; b < batch_size; b++) {
#pragma omp parallel for const float *filter_ptr0 = filter_ptr;
for (int c = 0; c < input_channel; c++) { const float *filter_ptr1 = filter_ptr0 + 3;
const float *filter_data = filter->data<float>() + c * 9; const float *filter_ptr2 = filter_ptr1 + 3;
const float *input_data = input->data<float>() + c * hxw; float32x4_t _ker[3];
float *output_data = output->data<float>() + c * hxw; _ker[0] = vld1q_f32(filter_ptr0);
float32x4_t vnewbias = vdupq_n_f32(newbias_data[c]); _ker[1] = vld1q_f32(filter_ptr1);
float32x4_t vnewscale = vdupq_n_f32(newscale_data[c]); _ker[2] = vld1q_f32(filter_ptr2);
float w00 = filter_data[0]; // pad top
float w01 = filter_data[1]; for (int h = 0; h < valid_h_start; ++h) {
float w02 = filter_data[2]; DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
float w10 = filter_data[3]; input_w, padding_h, padding_w, output_w,
float w11 = filter_data[4]; output_ptr, _ker);
float w12 = filter_data[5]; }
float w20 = filter_data[6];
float w21 = filter_data[7];
float w22 = filter_data[8];
for (int i = 1; i < output_height - 1; i++) {
float *output_ptr;
float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3, tmp4,
tmp5, out0;
for (int m = 1; m < output_width - 4; m += 4) {
output_ptr = output_data + i * output_width + m;
in0 = vld1q_f32(input_data + (i - 1) * input_width + m - 1);
in1 = vld1q_f32(input_data + (i - 1) * input_width + m + 3);
in2 = vld1q_f32(input_data + i * input_width + m - 1);
in3 = vld1q_f32(input_data + i * input_width + m + 3);
in4 = vld1q_f32(input_data + (i + 1) * input_width + m - 1);
in5 = vld1q_f32(input_data + (i + 1) * input_width + m + 3);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
tmp4 = vextq_f32(in4, in5, 1);
tmp5 = vextq_f32(in4, in5, 2);
out0 = vmulq_n_f32(in0, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
}
int m;
for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
}
for (int j = m; j < output_width - 1; j++) { // output 2x6
output_data[i * output_width + j] = int output_w_tiles = valid_w / 6;
input_data[(i - 1) * input_width + j - 1] * w00 + int output_w_remain = valid_w - output_w_tiles * 6;
input_data[(i - 1) * input_width + j] * w01 + for (int h = valid_h_start; h < valid_h_end - 1; h += 2) {
input_data[(i - 1) * input_width + j + 1] * w02 + const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
input_data[(i)*input_width + j - 1] * w10 + const float *input_ptr1 = input_ptr0 + input_w;
input_data[(i)*input_width + j] * w11 + const float *input_ptr2 = input_ptr1 + input_w;
input_data[(i)*input_width + j + 1] * w12 + const float *input_ptr3 = input_ptr2 + input_w;
input_data[(i + 1) * input_width + j - 1] * w20 + float *output_ptr0 = output_ptr + h * output_w;
input_data[(i + 1) * input_width + j] * w21 + float *output_ptr1 = output_ptr0 + output_w;
input_data[(i + 1) * input_width + j + 1] * w22; // pad left
output_data[i * output_width + j] = if (padding_w) {
newscale_data[c] * output_data[i * output_width + j] + float32x4_t row0 = vld1q_f32(input_ptr0);
newbias_data[c]; float32x4_t row1 = vld1q_f32(input_ptr1);
if (if_relu) { float32x4_t row2 = vld1q_f32(input_ptr2);
output_data[i * output_width + j] = float32x4_t row3 = vld1q_f32(input_ptr3);
output_data[i * output_width + j] < 0 float32x4_t zero = vdupq_n_f32(0.f);
? 0 row0 = vextq_f32(zero, row0, 3);
: output_data[i * output_width + j]; row1 = vextq_f32(zero, row1, 3);
row2 = vextq_f32(zero, row2, 3);
row3 = vextq_f32(zero, row3, 3);
float32x4_t acc0, acc1;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 3) {
output_ptr0[w] = 0.f;
output_ptr1[w] = 0.f;
} else {
acc0 = vmulq_f32(row0, _ker[0]);
acc0 = vmlaq_f32(acc0, row1, _ker[1]);
acc0 = vmlaq_f32(acc0, row2, _ker[2]);
acc0 = vextq_f32(acc0, acc0, 1);
acc1 = vmulq_f32(row1, _ker[0]);
acc1 = vmlaq_f32(acc1, row2, _ker[1]);
acc1 = vmlaq_f32(acc1, row3, _ker[2]);
acc1 = vextq_f32(acc1, acc1, 1);
float32x2_t sum = vpadd_f32(vget_low_f32(acc0), vget_low_f32(acc1));
vst1_lane_f32(output_ptr0 + w, sum, 0);
vst1_lane_f32(output_ptr1 + w, sum, 1);
row0 = vextq_f32(zero, row0, 3);
row1 = vextq_f32(zero, row1, 3);
row2 = vextq_f32(zero, row2, 3);
row3 = vextq_f32(zero, row3, 3);
} }
} }
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
} }
// valid
output_data[0] = w11 * input_data[0] + w12 * input_data[1] + float32x4_t _result0, _result1, _result2, _result3;
w21 * input_data[w] + w22 * input_data[w + 1]; for (int loop = 0; loop < output_w_tiles; ++loop) {
output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w - 1] + float32x4_t _row00 = vld1q_f32(input_ptr0);
w20 * input_data[2 * w - 2] + float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
w21 * input_data[2 * w - 1]; float32x4_t _row10 = vld1q_f32(input_ptr1);
output_data[(h - 1) * w] = float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);
w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] +
w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1]; float32x4_t _ext01 = vextq_f32(_row00, _row01, 1);
output_data[h * w - 1] = float32x4_t _ext02 = vextq_f32(_row00, _row01, 2);
w00 * input_data[h * w - w - 2] + w01 * input_data[h * w - w - 1] + float32x4_t _ext03 = vextq_f32(_row01, _row01, 1);
w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1]; float32x4_t _ext04 = vextq_f32(_row01, _row01, 2);
output_data[0] = output_data[0] * newscale_data[c] + newbias_data[c];
output_data[w - 1] = _result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0);
output_data[w - 1] * newscale_data[c] + newbias_data[c]; _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1);
output_data[(h - 1) * w] = _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0);
output_data[(h - 1) * w] * newscale_data[c] + newbias_data[c]; _result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0);
output_data[h * w - 1] = _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1);
output_data[h * w - 1] * newscale_data[c] + newbias_data[c]; _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0);
if (if_relu) { _ext01 = vextq_f32(_row10, _row11, 1);
output_data[0] = output_data[0] < 0 ? 0 : output_data[0]; _ext02 = vextq_f32(_row10, _row11, 2);
output_data[w - 1] = output_data[w - 1] < 0 ? 0 : output_data[w - 1]; _ext03 = vextq_f32(_row11, _row11, 1);
output_data[(h - 1) * w] = _ext04 = vextq_f32(_row11, _row11, 2);
output_data[(h - 1) * w] < 0 ? 0 : output_data[(h - 1) * w];
output_data[h * w - 1] = _result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0);
output_data[h * w - 1] < 0 ? 0 : output_data[h * w - 1]; _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1);
} _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0);
for (int i = 1; i < h - 1; ++i) { _result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0);
output_data[i * w] = _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1);
w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0);
w11 * input_data[i * w] + w12 * input_data[i * w + 1] +
w21 * input_data[i * w + w] + w22 * input_data[i * w + w + 1]; _result2 = vmulq_lane_f32(_row10, vget_low_f32(_ker[0]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[0]), 1);
output_data[i * w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] + _result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[0]), 0);
w01 * input_data[i * w + w - 1 - w] + _result3 = vmulq_lane_f32(_row11, vget_low_f32(_ker[0]), 0);
w10 * input_data[i * w + w - 1 - 1] + _result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[0]), 1);
w11 * input_data[i * w + w - 1] + _result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[0]), 0);
w20 * input_data[i * w + w - 1 + w - 1] +
w21 * input_data[i * w + w - 1 + w]; _row00 = vld1q_f32(input_ptr2);
output_data[i * w] = _row01 = vld1q_f32(input_ptr2 + 4);
output_data[i * w] * newscale_data[c] + newbias_data[c]; _row10 = vld1q_f32(input_ptr3);
output_data[i * w + w - 1] = _row11 = vld1q_f32(input_ptr3 + 4);
output_data[i * w + w - 1] * newscale_data[c] + newbias_data[c];
_ext01 = vextq_f32(_row00, _row01, 1);
if (if_relu) { _ext02 = vextq_f32(_row00, _row01, 2);
output_data[i * w] = output_data[i * w] < 0 ? 0 : output_data[i * w]; _ext03 = vextq_f32(_row01, _row01, 1);
output_data[i * w + w - 1] = _ext04 = vextq_f32(_row01, _row01, 2);
output_data[i * w + w - 1] < 0 ? 0 : output_data[i * w + w - 1];
} _result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0);
_result2 = vmlaq_lane_f32(_result2, _row00, vget_low_f32(_ker[1]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[1]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[1]), 0);
_result3 = vmlaq_lane_f32(_result3, _row01, vget_low_f32(_ker[1]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[1]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[1]), 0);
_ext01 = vextq_f32(_row10, _row11, 1);
_ext02 = vextq_f32(_row10, _row11, 2);
_ext03 = vextq_f32(_row11, _row11, 1);
_ext04 = vextq_f32(_row11, _row11, 2);
_result2 = vmlaq_lane_f32(_result2, _row10, vget_low_f32(_ker[2]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[2]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[2]), 0);
_result3 = vmlaq_lane_f32(_result3, _row11, vget_low_f32(_ker[2]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[2]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[2]), 0);
vst1q_f32(output_ptr0, _result0);
vst1_f32(output_ptr0 + 4, vget_low_f32(_result1));
vst1q_f32(output_ptr1, _result2);
vst1_f32(output_ptr1 + 4, vget_low_f32(_result3));
input_ptr0 += 6;
input_ptr1 += 6;
input_ptr2 += 6;
input_ptr3 += 6;
output_ptr0 += 6;
output_ptr1 += 6;
} }
// remain w
int m; if (output_w_remain > 0) {
for (m = 1; m < output_width - 4; m += 4) { float32x4_t _row00 = vld1q_f32(input_ptr0);
float *output_ptr = output_data + m; float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0; float32x4_t _row10 = vld1q_f32(input_ptr1);
in0 = vld1q_f32(input_data + m - 1); float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);
in1 = vld1q_f32(input_data + m + 3);
in2 = vld1q_f32(input_data + input_width + m - 1); float32x4_t _ext01 = vextq_f32(_row00, _row01, 1);
in3 = vld1q_f32(input_data + input_width + m + 3); float32x4_t _ext02 = vextq_f32(_row00, _row01, 2);
tmp0 = vextq_f32(in0, in1, 1); float32x4_t _ext03 = vextq_f32(_row01, _row01, 1);
tmp1 = vextq_f32(in0, in1, 2); float32x4_t _ext04 = vextq_f32(_row01, _row01, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2); _result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0);
out0 = vmulq_n_f32(in0, w10); _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1);
out0 = vmlaq_n_f32(out0, tmp0, w11); _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0);
out0 = vmlaq_n_f32(out0, tmp1, w12); _result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0);
out0 = vmlaq_n_f32(out0, in2, w20); _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1);
out0 = vmlaq_n_f32(out0, tmp2, w21); _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0); _ext01 = vextq_f32(_row10, _row11, 1);
if (if_relu) { _ext02 = vextq_f32(_row10, _row11, 2);
out0 = vmaxq_f32(out0, vzero); _ext03 = vextq_f32(_row11, _row11, 1);
_ext04 = vextq_f32(_row11, _row11, 2);
_result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0);
_result2 = vmulq_lane_f32(_row10, vget_low_f32(_ker[0]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[0]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[0]), 0);
_result3 = vmulq_lane_f32(_row11, vget_low_f32(_ker[0]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[0]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[0]), 0);
_row00 = vld1q_f32(input_ptr2);
_row01 = vld1q_f32(input_ptr2 + 4);
_row10 = vld1q_f32(input_ptr3);
_row11 = vld1q_f32(input_ptr3 + 4);
_ext01 = vextq_f32(_row00, _row01, 1);
_ext02 = vextq_f32(_row00, _row01, 2);
_ext03 = vextq_f32(_row01, _row01, 1);
_ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0);
_result2 = vmlaq_lane_f32(_result2, _row00, vget_low_f32(_ker[1]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[1]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[1]), 0);
_result3 = vmlaq_lane_f32(_result3, _row01, vget_low_f32(_ker[1]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[1]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[1]), 0);
_ext01 = vextq_f32(_row10, _row11, 1);
_ext02 = vextq_f32(_row10, _row11, 2);
_ext03 = vextq_f32(_row11, _row11, 1);
_ext04 = vextq_f32(_row11, _row11, 2);
_result2 = vmlaq_lane_f32(_result2, _row10, vget_low_f32(_ker[2]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[2]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[2]), 0);
_result3 = vmlaq_lane_f32(_result3, _row11, vget_low_f32(_ker[2]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[2]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[2]), 0);
switch (output_w_remain) {
case 5:
vst1q_lane_f32(output_ptr0 + 4, _result1, 0);
vst1q_lane_f32(output_ptr1 + 4, _result3, 0);
case 4:
vst1q_f32(output_ptr0, _result0);
vst1q_f32(output_ptr1, _result2);
break;
case 3:
vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
vst1q_lane_f32(output_ptr1 + 2, _result2, 2);
case 2:
vst1_f32(output_ptr0, vget_low_f32(_result0));
vst1_f32(output_ptr1, vget_low_f32(_result2));
break;
case 1:
vst1q_lane_f32(output_ptr0, _result0, 0);
vst1q_lane_f32(output_ptr1, _result2, 0);
break;
          }
          input_ptr0 += output_w_remain;
          input_ptr1 += output_w_remain;
          input_ptr2 += output_w_remain;
          input_ptr3 += output_w_remain;
          output_ptr0 += output_w_remain;
          output_ptr1 += output_w_remain;
        }
        // pad right
        if (padding_w) {
          float32x2_t row0 = vld1_f32(input_ptr0);
          float32x2_t row1 = vld1_f32(input_ptr1);
          float32x2_t row2 = vld1_f32(input_ptr2);
          float32x2_t row3 = vld1_f32(input_ptr3);
          float32x2_t zero = vdup_n_f32(0.f);
          float32x2_t acc0, acc1;
          for (int w = valid_w_end; w < output_w; ++w) {
            int padding = w + 3 - (padding_w + input_w);
            if (padding >= 3) {
              *output_ptr0 = 0.f;
              *output_ptr1 = 0.f;
            } else {
              acc0 = vmul_f32(row0, vget_low_f32(_ker[0]));
              acc0 = vmla_f32(acc0, row1, vget_low_f32(_ker[1]));
              acc0 = vmla_f32(acc0, row2, vget_low_f32(_ker[2]));
              acc1 = vmul_f32(row1, vget_low_f32(_ker[0]));
              acc1 = vmla_f32(acc1, row2, vget_low_f32(_ker[1]));
              acc1 = vmla_f32(acc1, row3, vget_low_f32(_ker[2]));
              float32x2_t sum = vpadd_f32(acc0, acc1);
              vst1_lane_f32(output_ptr0, sum, 0);
              vst1_lane_f32(output_ptr1, sum, 1);
              row0 = vext_f32(row0, zero, 1);
              row1 = vext_f32(row1, zero, 1);
              row2 = vext_f32(row2, zero, 1);
              row3 = vext_f32(row3, zero, 1);
            }
            output_ptr0++;
            output_ptr1++;
          }
        }
      }

      }
      vst1q_f32(output_ptr, out0);
    }
    for (m = 1; (m + 3) < output_width - 1; m += 4) {
    }
    for (int j = m; j < output_width - 1; j++) {
      output_data[j] = input_data[j - 1] * w10 + input_data[j] * w11 +
                       input_data[j + 1] * w12 +
                       input_data[input_width + j - 1] * w20 +
                       input_data[input_width + j] * w21 +
                       input_data[input_width + j + 1] * w22;
      output_data[j] = output_data[j] * newscale_data[c] + newbias_data[c];
      if (if_relu) {
        output_data[j] = output_data[j] < 0 ? 0 : output_data[j];
      }
    }
    for (m = 1; m < output_width - 4; m += 4) {
      float *output_ptr =
          output_data + (output_height - 1) * output_width + m;
      float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
      in0 = vld1q_f32(input_data + (output_height - 2) * input_width + m - 1);
      in1 = vld1q_f32(input_data + (output_height - 2) * input_width + m + 3);
      in2 = vld1q_f32(input_data + (output_height - 1) * input_width + m - 1);
      in3 = vld1q_f32(input_data + (output_height - 1) * input_width + m + 3);
      tmp0 = vextq_f32(in0, in1, 1);
      tmp1 = vextq_f32(in0, in1, 2);
      tmp2 = vextq_f32(in2, in3, 1);
      tmp3 = vextq_f32(in2, in3, 2);
      out0 = vmulq_n_f32(in0, w00);
      out0 = vmlaq_n_f32(out0, tmp0, w01);
      out0 = vmlaq_n_f32(out0, tmp1, w02);
      out0 = vmlaq_n_f32(out0, in2, w10);
      out0 = vmlaq_n_f32(out0, tmp2, w11);
      out0 = vmlaq_n_f32(out0, tmp3, w12);
      out0 = vmlaq_f32(vnewbias, vnewscale, out0);
      if (if_relu) {
        out0 = vmaxq_f32(out0, vzero);
      }
      vst1q_f32(output_ptr, out0);
    }
    for (m = 1; (m + 3) < output_width - 1; m = m + 4) {

      // remain height
      int start_h = valid_h_start + (valid_h & 0xfffe);
      if (start_h < valid_h_end) {
        const float *input_ptr0 = input_ptr + (start_h - padding_h) * input_w;
        const float *input_ptr1 = input_ptr0 + input_w;
        const float *input_ptr2 = input_ptr1 + input_w;
        float *output_ptr0 = output_ptr + start_h * output_w;
        // pad left
        if (padding_w) {
          float32x4_t row0 = vld1q_f32(input_ptr0);
          float32x4_t row1 = vld1q_f32(input_ptr1);
          float32x4_t row2 = vld1q_f32(input_ptr2);
          float32x4_t zero = vdupq_n_f32(0.f);
          row0 = vextq_f32(zero, row0, 3);
          row1 = vextq_f32(zero, row1, 3);
          row2 = vextq_f32(zero, row2, 3);
          float32x4_t acc;
          for (int w = valid_w_start - 1; w >= 0; --w) {
            int padding = padding_w - w;
            if (padding >= 3) {
              output_ptr0[w] = 0.f;
            } else {
              acc = vmulq_f32(row0, _ker[0]);
              acc = vmlaq_f32(acc, row1, _ker[1]);
              acc = vmlaq_f32(acc, row2, _ker[2]);
              acc = vextq_f32(acc, acc, 1);
              float32x2_t sum = vpadd_f32(vget_low_f32(acc), vget_low_f32(acc));
              vst1_lane_f32(output_ptr0 + w, sum, 0);
              row0 = vextq_f32(zero, row0, 3);
              row1 = vextq_f32(zero, row1, 3);
              row2 = vextq_f32(zero, row2, 3);
            }
          }
          output_ptr0 += valid_w_start;
        }
        // valid
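        // Leftover single output row: each iteration below reads three input
        // rows and writes six output columns (four lanes of _result0 plus the
        // low half of _result1).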
float32x4_t _result0, _result1;
for (int loop = 0; loop < output_w_tiles; ++loop) {
float32x4_t _row00 = vld1q_f32(input_ptr0);
float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
float32x4_t _row10 = vld1q_f32(input_ptr1);
float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);
float32x4_t _ext01 = vextq_f32(_row00, _row01, 1);
float32x4_t _ext02 = vextq_f32(_row00, _row01, 2);
float32x4_t _ext03 = vextq_f32(_row01, _row01, 1);
float32x4_t _ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0);
_result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0);
_ext01 = vextq_f32(_row10, _row11, 1);
_ext02 = vextq_f32(_row10, _row11, 2);
_ext03 = vextq_f32(_row11, _row11, 1);
_ext04 = vextq_f32(_row11, _row11, 2);
_result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0);
_row00 = vld1q_f32(input_ptr2);
_row01 = vld1q_f32(input_ptr2 + 4);
_ext01 = vextq_f32(_row00, _row01, 1);
_ext02 = vextq_f32(_row00, _row01, 2);
_ext03 = vextq_f32(_row01, _row01, 1);
_ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0);
vst1q_f32(output_ptr0, _result0);
vst1_f32(output_ptr0 + 4, vget_low_f32(_result1));
input_ptr0 += 6;
input_ptr1 += 6;
input_ptr2 += 6;
output_ptr0 += 6;
        }
    }
for (int j = m; j < output_width - 1; j++) {
output_data[(output_height - 1) * input_width + j] =
input_data[(output_height - 2) * input_width + j - 1] * w00 +
input_data[(output_height - 2) * input_width + j] * w01 +
input_data[(output_height - 2) * input_width + j + 1] * w02 +
input_data[(output_height - 1) * input_width + j - 1] * w10 +
input_data[(output_height - 1) * input_width + j] * w11 +
input_data[(output_height - 1) * input_width + j + 1] * w12;
output_data[(output_height - 1) * output_width + j] =
output_data[(output_height - 1) * output_width + j] *
newscale_data[c] +
newbias_data[c];
if (if_relu) {
output_data[(output_height - 1) * output_width + j] =
output_data[(output_height - 1) * output_width + j] < 0
? 0
: output_data[(output_height - 1) * output_width + j];
}
}
}
}
/*
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]);
// const int l = h;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int hxw = h * w;
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
float32x4_t vzero = vdupq_n_f32(0);
for (int b = 0; b < batch_size; ++b) {
const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) {
vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]);
int w_mid = w - 2; // l=1->l_mid=-1,l=2->l_mid=0
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[w] + w22 * input_data[w + 1];
output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w -
1] + w20 * input_data[2 * w - 2] + w21 * input_data[2 * w - 1];
output_data[(h - 1) * w] =
w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w +
1] + w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1];
output_data[h * w - 1] = w00 * input_data[h*w-w-2] +
w01 * input_data[h*w-w-1] +
w10 * input_data[h * w - 2] +
w11 * input_data[h * w - 1];
output_data[0] = output_data[0] * newscale_data[j] +
newbias_data[j]; output_data[w - 1] = output_data[w - 1] *
newscale_data[j] + newbias_data[j]; output_data[(h - 1) * w] =
output_data[(h - 1) * w] * newscale_data[j] + newbias_data[j];
output_data[h * w - 1] =
output_data[h * w - 1] * newscale_data[j] + newbias_data[j];
if (if_relu) {
output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
output_data[w - 1] = output_data[w - 1] < 0 ? 0 : output_data[w -
1]; output_data[(h - 1) * w] = output_data[(h - 1) * w] < 0 ? 0 :
output_data[(h - 1) * w]; output_data[h * w - 1] = output_data[h * w - 1]
< 0 ? 0 : output_data[h * w - 1];
}
for (int i = 1; i < h - 1; ++i) {
output_data[i * w] =
w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1]
+ w11 * input_data[i * w] + w12 * input_data[i * w + 1] + w21 *
input_data[i * w + w] + w22 * input_data[i * w + w + 1]; output_data[i *
w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] + w01 * input_data[i
* w + w - 1 - w] + w10 * input_data[i * w + w - 1 - 1] + w11 *
input_data[i * w + w - 1] + w20 * input_data[i * w + w - 1 + w - 1] + w21
* input_data[i * w + w - 1 + w]; output_data[i * w] = output_data[i * w]
* newscale_data[j] + newbias_data[j]; output_data[i * w + w - 1] =
output_data[i * w + w - 1] * newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data[i * w] = output_data[i * w] < 0 ? 0 : output_data[i
* w]; output_data[i * w + w - 1] = output_data[i * w + w - 1] < 0 ? 0 :
output_data[i * w + w - 1];
}
}
// top 1 row and bottom 1 row
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1,
tmp2, tmp3, tmp4, tmp5, out0; in0 = vld1q_f32(input_tmp); in2 =
vld1q_f32(input_tmp + w); const float *input_tmp_end = input_tmp + (h -
2) * w; in4 = vld1q_f32(input_tmp_end); in6 = vld1q_f32(input_tmp_end +
w); int c_mid = w_mid; auto output_ptr = output_data + 1; for (; c_mid >
3; c_mid -= 4) { in1 = vld1q_f32(input_tmp + 4); in3 =
vld1q_f32(input_tmp + w + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + w + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr + (h - 1) * w, out0);
// can optimize to each 8 stride.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
      // top right pad
      float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]);
      float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]);

      tmp0 = vextq_f32(in0, pad0, 1);
      tmp1 = vextq_f32(in0, pad0, 2);
      tmp2 = vextq_f32(in2, pad1, 1);
      tmp3 = vextq_f32(in2, pad1, 2);

      out0 = vmulq_n_f32(in0, w10);
      out0 = vmlaq_n_f32(out0, tmp0, w11);
      out0 = vmlaq_n_f32(out0, tmp1, w12);
      out0 = vmlaq_n_f32(out0, in2, w20);
      out0 = vmlaq_n_f32(out0, tmp2, w21);
      out0 = vmlaq_n_f32(out0, tmp3, w22);
      out0 = vmlaq_f32(vnewbias, vnewscale, out0);
      if (if_relu) {
        out0 = vmaxq_f32(out0, vzero);
      }
      for (int i = 0; i < c_mid; ++i) {
        if (i == 0) {
          vst1q_lane_f32(output_ptr + i, out0, 0);
        }
        if (i == 1) {
          vst1q_lane_f32(output_ptr + i, out0, 1);
        }
        if (i == 2) {
          vst1q_lane_f32(output_ptr + i, out0, 2);
        }
      }

      // bottom right pad
      float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]);
      float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]);

      tmp0 = vextq_f32(in4, pad2, 1);
      tmp1 = vextq_f32(in4, pad2, 2);
      tmp2 = vextq_f32(in6, pad3, 1);
      tmp3 = vextq_f32(in6, pad3, 2);

      out0 = vmulq_n_f32(in4, w00);
      out0 = vmlaq_n_f32(out0, tmp0, w01);
      out0 = vmlaq_n_f32(out0, tmp1, w02);
      out0 = vmlaq_n_f32(out0, in6, w10);
      out0 = vmlaq_n_f32(out0, tmp2, w11);
      out0 = vmlaq_n_f32(out0, tmp3, w12);
      out0 = vmlaq_f32(vnewbias, vnewscale, out0);
      if (if_relu) {
        out0 = vmaxq_f32(out0, vzero);
      }
      for (int i = 0; i < c_mid; ++i) {
        if (i == 0) {
          vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0);
        }
        if (i == 1) {
          vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1);
        }
        if (i == 2) {
          vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2);
        }
      }

        if (output_w_remain > 0) {
          float32x4_t _row00 = vld1q_f32(input_ptr0);
          float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
          float32x4_t _row10 = vld1q_f32(input_ptr1);
          float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);

          float32x4_t _ext01 = vextq_f32(_row00, _row01, 1);
          float32x4_t _ext02 = vextq_f32(_row00, _row01, 2);
          float32x4_t _ext03 = vextq_f32(_row01, _row01, 1);
          float32x4_t _ext04 = vextq_f32(_row01, _row01, 2);

          _result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0);
          _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1);
          _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0);
          _result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0);
          _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1);
          _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0);

          _ext01 = vextq_f32(_row10, _row11, 1);
          _ext02 = vextq_f32(_row10, _row11, 2);
          _ext03 = vextq_f32(_row11, _row11, 1);
          _ext04 = vextq_f32(_row11, _row11, 2);

          _result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0);
          _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1);
          _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0);
          _result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0);
          _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1);
          _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0);

          _row00 = vld1q_f32(input_ptr2);
          _row01 = vld1q_f32(input_ptr2 + 4);

          _ext01 = vextq_f32(_row00, _row01, 1);
          _ext02 = vextq_f32(_row00, _row01, 2);
          _ext03 = vextq_f32(_row01, _row01, 1);
          _ext04 = vextq_f32(_row01, _row01, 2);

          _result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0);
          _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1);
          _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0);
          _result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0);
          _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1);
          _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0);

          switch (output_w_remain) {
            case 5:
              vst1q_lane_f32(output_ptr0 + 4, _result1, 0);
            case 4:
              vst1q_f32(output_ptr0, _result0);
              break;
            case 3:
              vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
            case 2:
              vst1_f32(output_ptr0, vget_low_f32(_result0));
              break;
            case 1:
              vst1q_lane_f32(output_ptr0, _result0, 0);
              break;
// mid
for (int i = 0; i < h - 2; ++i) {
auto output_ptr = output_data + (i + 1) * w + 1;
input_tmp = input_data + i * w;
auto in0_tmp = vld1q_f32(input_tmp);
auto in2_tmp = vld1q_f32(input_tmp + w);
auto in4_tmp = vld1q_f32(input_tmp + w + w);
c_mid = w_mid;
for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + w + 4);
auto in5_tmp = vld1q_f32(input_tmp + w + w + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
output_ptr += 4;
input_tmp += 4;
in0_tmp = in1_tmp;
in2_tmp = in3_tmp;
in4_tmp = in5_tmp;
}
float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]);
float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]);
tmp0 = vextq_f32(in0_tmp, pad0, 1);
tmp1 = vextq_f32(in0_tmp, pad0, 2);
tmp2 = vextq_f32(in2_tmp, pad1, 1);
tmp3 = vextq_f32(in2_tmp, pad1, 2);
tmp4 = vextq_f32(in4_tmp, pad2, 1);
tmp5 = vextq_f32(in4_tmp, pad2, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
}
output_data += hxw;
input_data += hxw;
filter_data_tmp += 9;
}
          }
          input_ptr0 += output_w_remain;
          input_ptr1 += output_w_remain;
          input_ptr2 += output_w_remain;
          output_ptr0 += output_w_remain;
        }
        // pad right
        if (padding_w) {
          float32x2_t row0 = vld1_f32(input_ptr0);
          float32x2_t row1 = vld1_f32(input_ptr1);
          float32x2_t row2 = vld1_f32(input_ptr2);
          float32x2_t zero = vdup_n_f32(0.f);
          float32x2_t acc;
          for (int w = valid_w_end; w < output_w; ++w) {
            int padding = w + 3 - (padding_w + input_w);
            if (padding >= 3) {
              *output_ptr0 = 0.f;

    }
  }
  */
#endif
}

/// w!=h not fix
void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
                                   const framework::Tensor *filter,
                                   framework::Tensor *output,
                                   const framework::Tensor *new_scale,
                                   const framework::Tensor *new_bias,
                                   bool if_relu) {
#if __ARM_NEON
  const int batch_size = input->dims()[0];
  const int input_height = input->dims()[2];
const int input_width = input->dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int _kernel_size = 3;
const int stride_height = 2;
const int stride_width = 2;
const int padding_height = 1;
const int padding_width = 1;
const float zero = 0;
const int input_channel_stride = input_height * input_width;
const int output_channel_stride = output_height * output_width;
const int filter_channel_stride = 9;
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->mutable_data<float>();
const int input_batch_stride = output_channels * input_channel_stride;
const int output_batch_stride = output_channels * output_channel_stride;
const int filter_batch_stride = output_channels * output_channel_stride;
const float *pos1, *pos2, *pos3, *filter1, *filter2, *filter3, *output_ptr;
int hstart, wstart, hend, wend;
float result;
for (int i = 0; i < batch_size; ++i) {
for (int c = 0; c < output_channels; ++c) {
filter1 = filter_data;
filter2 = filter1 + 3;
filter3 = filter2 + 3;
for (int ph = 0; ph < output_height; ph++) {
for (int pw = 0; pw < output_width; pw++) {
hstart = ph * stride_height - padding_height;
wstart = pw * stride_width - padding_width;
hend = std::min(hstart + _kernel_size, input_height + padding_height);
wend = std::min(wstart + _kernel_size, input_width + padding_width);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, input_height);
wend = std::min(wend, input_width);
pos1 = input_data + hstart * input_width + wstart;
pos2 = input_data + (hstart + 1) * input_width + wstart;
pos3 = input_data + (hstart + 2) * input_width + wstart;
output_ptr = output_data + ph * output_width + pw;
if (hend - hstart != 3 || wend - wstart != 3) {
result = 0;
float fake_input[9] = {0};
if (hstart == 0 && wstart == 0) {
// 左上角
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend && k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k -
(3 - wend)];
}
}
}
} else if (hstart == 0 && wend == input_width) {
// 右上角
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend && k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k + wstart];
}
}
}
} else if (hend == input_height && wstart == 0) {
// 左下角
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - 1 - hstart && k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k - (3 - wend)];
}
}
}
} else if (hend == input_height && wend == input_width) {
// 右下角
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - hstart - 1 &&
k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
} else if (hstart == 0) {
// 顶部
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k + wstart];
}
}
}
} else if (hend == input_height) {
// 底部
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - hstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
} else if (wstart == 0) {
// 左侧
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width +
(k - (3 - wend))];
}
}
}
} else if (wend == input_width) {
// 右侧
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
}
for (int l = 0; l < 9; ++l) {
result += fake_input[l] * filter1[l];
}
output_data[ph * output_width + pw] =
newscale_data[c] * result + newbias_data[c];
if (if_relu) {
output_data[ph * output_width + pw] =
output_data[ph * output_width + pw] < 0
? 0
: output_data[ph * output_width + pw];
}
          } else {
            const float32x4_t data1 = vld1q_f32(pos1);
            const float32x4_t data2 = vld1q_f32(pos2);
            const float32x4_t data3 = vld1q_f32(pos3);
            const float32x4_t v_filter1 = vld1q_f32(filter1);
            const float32x4_t v_filter2 = vld1q_f32(filter2);
            const float32x4_t v_filter3 = vld1q_f32(filter3);
            float32x4_t mula = vmulq_f32(data1, v_filter1);
            mula = vmlaq_f32(mula, data2, v_filter2);
            mula = vmlaq_f32(mula, data3, v_filter3);
            float32x2_t res = vpadd_f32(
                vget_high_f32(vsetq_lane_f32(0, mula, 3)), vget_low_f32(mula));
            res = vpadd_f32(res, res);
            output_data[ph * output_width + pw] =
                vget_lane_f32(res, 0) * newscale_data[c] + newbias_data[c];
            if (if_relu) {
              output_data[ph * output_width + pw] =
                  output_data[ph * output_width + pw] < 0
                      ? 0
                      : output_data[ph * output_width + pw];
            }
          }
        }
      }
      input_data += input_channel_stride;
      output_data += output_channel_stride;
      filter_data += filter_channel_stride;
    }
    input_data += input_batch_stride;
    output_data += output_batch_stride;
  }
#endif
}

            } else {
              acc = vmul_f32(row0, vget_low_f32(_ker[0]));
              acc = vmla_f32(acc, row1, vget_low_f32(_ker[1]));
              acc = vmla_f32(acc, row2, vget_low_f32(_ker[2]));
              float32x2_t sum = vpadd_f32(acc, acc);
              vst1_lane_f32(output_ptr0, sum, 0);
              row0 = vext_f32(row0, zero, 1);
              row1 = vext_f32(row1, zero, 1);
              row2 = vext_f32(row2, zero, 1);
            }
            output_ptr0++;
          }
        }
      }
      // pad bottom
      for (int h = valid_h_end; h < output_h; ++h) {
        DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
                                        input_w, padding_h, padding_w, output_w,
                                        output_ptr, _ker);
      }
    }
  }
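// For reference, the NEON kernels above can be checked against a plain scalar
// depthwise 3x3 convolution. The sketch below is illustrative only: it assumes a
// single channel of an NCHW float tensor with zero padding, and the helper name
// and signature are hypothetical rather than part of paddle_mobile.
inline void DepthwiseConv3x3Reference(const float *input, const float *filter,
                                      float *output, int input_h, int input_w,
                                      int output_h, int output_w, int padding_h,
                                      int padding_w, int stride) {
  for (int oh = 0; oh < output_h; ++oh) {
    for (int ow = 0; ow < output_w; ++ow) {
      float acc = 0.f;
      for (int kh = 0; kh < 3; ++kh) {
        for (int kw = 0; kw < 3; ++kw) {
          int ih = oh * stride - padding_h + kh;
          int iw = ow * stride - padding_w + kw;
          // Taps that fall outside the input contribute zero (zero padding).
          if (ih >= 0 && ih < input_h && iw >= 0 && iw < input_w) {
            acc += input[ih * input_w + iw] * filter[kh * 3 + kw];
          }
        }
      }
      output[oh * output_w + ow] = acc;
    }
  }
}
// Calling this with stride = 1 or stride = 2 on one channel gives a slow but
// simple oracle for DepthwiseConv3x3S1 / DepthwiseConv3x3S2 when chasing the
// armv8 boundary bugs this change targets.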
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
                            const framework::Tensor *filter,
                            framework::Tensor *output, framework::Tensor *bias,
                            bool if_bias, bool if_relu) {
#if __ARM_NEON
  const float *input_data = input->data<float>();
  const float *filter_data = filter->data<float>();
  float *output_data = output->mutable_data<float>();
  const float *bias_data;
  if (if_bias) {
    bias_data = bias->data<float>();
  }

  const int in_h = static_cast<int>(input->dims()[2]);
  const int in_w = static_cast<int>(input->dims()[3]);
  const int out_h = static_cast<int>(output->dims()[2]);
  const int out_w = static_cast<int>(output->dims()[3]);
  const int out_l = out_h;
  const int in_l = in_h;
  const int inhxw = in_h * in_w;
  const int outhxw = out_h * out_w;
  /// todo : fix if_pad when w != h
  const int if_pad_r = in_w - 1 == (out_w - 1) * 2 ? 1 : 0;
  const int if_pad_b = in_h - 1 == (out_h - 1) * 2 ? 1 : 0;
  const int batch_size = static_cast<int>(input->dims()[0]);
  const int c = static_cast<int>(input->dims()[1]);
  const float *input_row_ptr;
  float *output_row_ptr;

  const int w_times = (out_w - 2) / 3;

  float32x4_t vbias = vdupq_n_f32(0.0);

  float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
  float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
  int out2in_mid;
  float32x4_t zero = vdupq_n_f32(0.0);
  for (int b = batch_size; b > 0; --b) {
    const float *filter_data_tmp = filter_data;
    for (int j = 0; j < c; ++j) {
      auto output_data_tmp = output_data + j * out_h * out_w;
      auto input_data_tmp = input_data + j * in_h * in_w;
      auto input_const = input_data_tmp;

      if (if_bias) {
        vbias = vdupq_n_f32(bias_data[j]);
      }

template <>
void DepthwiseConv3x3S2<float, float>(const framework::Tensor &input,
                                      const framework::Tensor &filter,
                                      const std::vector<int> &paddings,
                                      framework::Tensor *output) {
  const float *input_data = input.data<float>();
  const float *filter_data = filter.data<float>();
  float *out_data = output->mutable_data<float>();
  int input_h = input.dims()[2];
  int input_w = input.dims()[3];
  int output_h = output->dims()[2];
  int output_w = output->dims()[3];
  int padding_h = paddings[0];
  int padding_w = paddings[1];
  int image_size = input_h * input_w;
  int out_image_size = output_h * output_w;
  int valid_h_start = (padding_h + 1) / 2;
  int valid_h_end = (input_h + padding_h - 1) / 2;
  int valid_h = valid_h_end - valid_h_start;
  int valid_w_start = (padding_w + 1) / 2;
  int valid_w_end = (input_w + padding_w - 1) / 2;
  int valid_w = valid_w_end - valid_w_start;
  int input_w_start = 2 * valid_w_start - padding_w;

  #pragma omp parallel for
  for (int g = 0; g < input.dims()[1]; ++g) {
    const float *input_ptr = input_data + g * image_size;
    const float *filter_ptr = filter_data + g * 9;
    float *output_ptr = out_data + g * out_image_size;

    const float *filter_ptr0 = filter_ptr;
    const float *filter_ptr1 = filter_ptr0 + 3;
    const float *filter_ptr2 = filter_ptr1 + 3;
    float32x4_t _ker[3];
    _ker[0] = vld1q_f32(filter_ptr0);
    _ker[1] = vld1q_f32(filter_ptr1);
    _ker[2] = vld1q_f32(filter_ptr2);

    // pad top
    for (int h = 0; h < valid_h_start; ++h) {
      DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
                                      input_w, padding_h, padding_w, output_w,
                                      output_ptr, _ker);
    }
    // valid 2x4
    int output_w_tiles = valid_w / 4;
    int output_w_remain = valid_w - output_w_tiles * 4;
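    // Each iteration of the loop below computes two output rows (h and h + 1)
    // from five consecutive input rows; vld2q_f32 de-interleaves eight input
    // columns into even/odd lanes, so one tile yields four stride-2 outputs.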
for (int h = valid_h_start; h < valid_h_end - 1; h += 2) {
float w00 = filter_data_tmp[0]; const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
float w01 = filter_data_tmp[1]; const float *input_ptr1 = input_ptr0 + input_w;
float w02 = filter_data_tmp[2]; const float *input_ptr2 = input_ptr1 + input_w;
float w10 = filter_data_tmp[3]; const float *input_ptr3 = input_ptr2 + input_w;
float w11 = filter_data_tmp[4]; const float *input_ptr4 = input_ptr3 + input_w;
float w12 = filter_data_tmp[5]; float *output_ptr0 = output_ptr + h * output_w;
float w20 = filter_data_tmp[6]; float *output_ptr1 = output_ptr0 + output_w;
float w21 = filter_data_tmp[7]; // pad left
float w22 = filter_data_tmp[8]; if (padding_w) {
for (int w = valid_w_start - 1; w >= 0; --w) {
int h_mid = 0; int padding = padding_w - (w << 1);
if (padding >= 3) {
for (; h_mid < out_h - 1; h_mid++) { output_ptr0[w] = 0;
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; output_ptr1[w] = 0;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
if (h_mid == 0) {
elewise_res1 = zero;
elewise_res0 = zero;
elewise_res2 = zero;
} else { } else {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); float32x4_t row0 = vld1q_f32(input_ptr0 - padding);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); float32x4_t row1 = vld1q_f32(input_ptr1 - padding);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); float32x4_t row2 = vld1q_f32(input_ptr2 - padding);
} float32x4_t row3 = vld1q_f32(input_ptr3 - padding);
input_buff_mid = vld2q_f32(input_row_ptr); float32x4_t row4 = vld1q_f32(input_ptr4 - padding);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); float32x4_t acc0 = vmulq_f32(row0, _ker[0]);
float32x4_t acc1 = vmulq_f32(row2, _ker[0]);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); acc0 = vmlaq_f32(acc0, row1, _ker[1]);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); acc1 = vmlaq_f32(acc1, row3, _ker[1]);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); acc0 = vmlaq_f32(acc0, row2, _ker[2]);
acc1 = vmlaq_f32(acc1, row4, _ker[2]);
elewise_res1 = float sum0 = vgetq_lane_f32(acc0, 2);
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); float sum1 = vgetq_lane_f32(acc1, 2);
elewise_res0 = if (padding == 1) {
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); sum0 += vgetq_lane_f32(acc0, 1);
elewise_res2 = sum1 += vgetq_lane_f32(acc1, 1);
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); }
output_ptr0[w] = sum0;
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), output_ptr1[w] = sum1;
vaddq_f32(elewise_res0, elewise_res1));
res3 = vaddq_f32(res3, vbias);
if (if_relu) {
res3 = vmaxq_f32(res3, zero);
}
vst1q_f32(output_row_ptr, res3);
input_row_ptr += 6;
output_row_ptr += 3;
}
}
clock();
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
input_buff_mid = vld2q_f32(input_row_ptr);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
if (!if_pad_b) {
elewise_res1 =
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
elewise_res0 =
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
}
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
vaddq_f32(elewise_res0, elewise_res1));
res3 = vaddq_f32(res3, vbias);
if (if_relu) {
res3 = vmaxq_f32(res3, zero);
}
if ((w4 != w_times)) {
vst1q_f32(output_row_ptr, res3);
} else {
if (out_w - 2 - w_times * 3 == 1) {
vst1q_lane_f32(output_row_ptr, res3, 0);
} else if (out_w - 2 - w_times * 3 == 2) {
vst1q_lane_f32(output_row_ptr, res3, 0);
vst1q_lane_f32(output_row_ptr + 1, res3, 1);
} }
} }
input_row_ptr += 6; input_ptr0 += input_w_start;
output_row_ptr += 3; input_ptr1 += input_w_start;
input_ptr2 += input_w_start;
input_ptr3 += input_w_start;
input_ptr4 += input_w_start;
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
} }
// valid
// leftTop, rightTop, leftBottom, rightBottom float32x4_t _result0, _result1, _ext;
int lt = 0; for (int loop = 0; loop < output_w_tiles; ++loop) {
int rt = out_w - 1; float32x4x2_t _row0 = vld2q_f32(input_ptr0);
int lb = out_w * (out_h - 1); float32x4x2_t _row1 = vld2q_f32(input_ptr1);
int rb = out_h * out_w - 1;
_ext = vextq_f32(_row0.val[0], _ext, 1);
output_data_tmp[lt] = input_const[0] * w11 + input_const[1] * w12 + _ext = vsetq_lane_f32(input_ptr0[8], _ext, 3);
input_const[in_w] * w21 + _result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
input_const[in_w + 1] * w22; _result0 =
vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1);
out2in_mid = (out_w - 1) * 2; _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0);
output_data_tmp[rt] =
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + _ext = vextq_f32(_row1.val[0], _ext, 1);
w20 * input_const[out2in_mid + in_w - 1] + _ext = vsetq_lane_f32(input_ptr1[8], _ext, 3);
w21 * input_const[out2in_mid + in_w] + _result0 =
(1 - if_pad_r) * (w12 * input_const[out2in_mid + 1] + vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0);
w22 * input_const[out2in_mid + in_w + 1]); _result0 =
vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1);
out2in_mid = (out_h - 1) * 2 * in_w; _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0);
output_data_tmp[lb] = _row0 = vld2q_f32(input_ptr2);
w01 * input_const[out2in_mid - in_w] + _row1 = vld2q_f32(input_ptr3);
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] + _ext = vextq_f32(_row0.val[0], _ext, 1);
(1 - if_pad_b) * (w21 * input_const[out2in_mid + in_w] + _ext = vsetq_lane_f32(input_ptr2[8], _ext, 3);
w22 * input_const[out2in_mid + in_w + 1]); _result0 =
out2in_mid = (out_h - 1) * 2 * in_w + (out_w - 1) * 2; vmlaq_lane_f32(_result0, _row0.val[0], vget_low_f32(_ker[2]), 0);
_result0 =
output_data_tmp[rb] = vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[2]), 1);
w00 * input_const[out2in_mid - in_w - 1] + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0);
w01 * input_const[out2in_mid - in_w] + _result1 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + _result1 =
(1 - if_pad_r) * (w20 * input_const[out2in_mid + in_w - 1] + vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[0]), 1);
w21 * input_const[out2in_mid + in_w]) + _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[0]), 0);
(1 - if_pad_b) * (w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1]) + _ext = vextq_f32(_row1.val[0], _ext, 1);
(1 - if_pad_r) * (1 - if_pad_b) * w22 * _ext = vsetq_lane_f32(input_ptr3[8], _ext, 3);
input_const[out2in_mid + in_w + 1]; _result1 =
if (if_bias) { vmlaq_lane_f32(_result1, _row1.val[0], vget_low_f32(_ker[1]), 0);
output_data_tmp[lt] += bias_data[j]; _result1 =
output_data_tmp[rt] += bias_data[j]; vmlaq_lane_f32(_result1, _row1.val[1], vget_low_f32(_ker[1]), 1);
output_data_tmp[lb] += bias_data[j]; _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[1]), 0);
output_data_tmp[rb] += bias_data[j];
_row0 = vld2q_f32(input_ptr4);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr4[8], _ext, 3);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[0], vget_low_f32(_ker[2]), 0);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[2]), 0);
vst1q_f32(output_ptr0, _result0);
vst1q_f32(output_ptr1, _result1);
input_ptr0 += 8;
input_ptr1 += 8;
input_ptr2 += 8;
input_ptr3 += 8;
input_ptr4 += 8;
output_ptr0 += 4;
output_ptr1 += 4;
} }
if (if_relu) { // remain w
output_data_tmp[lt] = output_data_tmp[lt] < 0 ? 0 : output_data_tmp[lt]; if (output_w_remain > 0) {
output_data_tmp[rt] = output_data_tmp[rt] < 0 ? 0 : output_data_tmp[rt]; float32x4x2_t _row0 = vld2q_f32(input_ptr0);
output_data_tmp[lb] = output_data_tmp[lb] < 0 ? 0 : output_data_tmp[lb]; float32x4x2_t _row1 = vld2q_f32(input_ptr1);
output_data_tmp[rb] = output_data_tmp[rb] < 0 ? 0 : output_data_tmp[rb];
} _ext = vextq_f32(_row0.val[0], _ext, 1);
for (int i = 1; i < out_h - 1; i++) { _ext = vsetq_lane_f32(input_ptr0[8], _ext, 3);
out2in_mid = i * 2 * in_w; _result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
int left = i * out_w; _result0 =
output_data_tmp[left] = w01 * input_const[out2in_mid - in_w] + vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1);
w02 * input_const[out2in_mid - in_w + 1] + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0);
w11 * input_const[out2in_mid] +
w12 * input_const[out2in_mid + 1] + _ext = vextq_f32(_row1.val[0], _ext, 1);
w21 * input_const[out2in_mid + in_w] + _ext = vsetq_lane_f32(input_ptr1[8], _ext, 3);
w22 * input_const[out2in_mid + in_w + 1]; _result0 =
vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0);
out2in_mid = i * 2 * in_w + (out_w - 1) * 2; _result0 =
int right = i * out_w + out_w - 1; vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1);
output_data_tmp[right] = _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0);
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] + _row0 = vld2q_f32(input_ptr2);
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + _row1 = vld2q_f32(input_ptr3);
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] + _ext = vextq_f32(_row0.val[0], _ext, 1);
(1 - if_pad_r) * (w02 * input_const[out2in_mid - in_w + 1] + _ext = vsetq_lane_f32(input_ptr2[8], _ext, 3);
w12 * input_const[out2in_mid + 1] + _result0 =
w22 * input_const[out2in_mid + in_w + 1]); vmlaq_lane_f32(_result0, _row0.val[0], vget_low_f32(_ker[2]), 0);
if (if_bias) { _result0 =
output_data_tmp[left] += bias_data[j]; vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[2]), 1);
output_data_tmp[right] += bias_data[j]; _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0);
} _result1 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
if (if_relu) { _result1 =
output_data_tmp[left] = vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[0]), 1);
output_data_tmp[left] < 0 ? 0 : output_data_tmp[left]; _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[0]), 0);
output_data_tmp[right] =
output_data_tmp[right] < 0 ? 0 : output_data_tmp[right]; _ext = vextq_f32(_row1.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr3[8], _ext, 3);
_result1 =
vmlaq_lane_f32(_result1, _row1.val[0], vget_low_f32(_ker[1]), 0);
_result1 =
vmlaq_lane_f32(_result1, _row1.val[1], vget_low_f32(_ker[1]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[1]), 0);
_row0 = vld2q_f32(input_ptr4);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr4[8], _ext, 3);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[0], vget_low_f32(_ker[2]), 0);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[2]), 0);
switch (output_w_remain) {
case 3:
vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
vst1q_lane_f32(output_ptr1 + 2, _result1, 2);
case 2:
vst1_f32(output_ptr0, vget_low_f32(_result0));
vst1_f32(output_ptr1, vget_low_f32(_result1));
break;
case 1:
vst1q_lane_f32(output_ptr0, _result0, 0);
vst1q_lane_f32(output_ptr1, _result1, 0);
break;
} }
input_ptr0 += output_w_remain * 2;
input_ptr1 += output_w_remain * 2;
input_ptr2 += output_w_remain * 2;
input_ptr3 += output_w_remain * 2;
input_ptr4 += output_w_remain * 2;
output_ptr0 += output_w_remain;
output_ptr1 += output_w_remain;
} }
filter_data_tmp += 9; // pad right
} if (padding_w > 0) {
input_data += inhxw * c; float32x4_t row0 = vld1q_f32(input_ptr0);
output_data += outhxw * c; float32x4_t row1 = vld1q_f32(input_ptr1);
} float32x4_t row2 = vld1q_f32(input_ptr2);
#endif float32x4_t row3 = vld1q_f32(input_ptr3);
} float32x4_t row4 = vld1q_f32(input_ptr4);
float32x4_t acc0, acc1;
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, for (int w = valid_w_end; w < output_w; ++w) {
const framework::Tensor *filter, int padding = 2 * w + 3 - (padding_w + input_w);
framework::Tensor *output, if (padding >= 3) {
const framework::Tensor *new_scale, *output_ptr0 = 0;
const framework::Tensor *new_bias, *output_ptr1 = 0;
bool if_relu) {
#if __ARM_NEON
// #ifdef _OPENMP
// const float *newscale_data = new_scale->data<float>();
// const float *newbias_data = new_bias->data<float>();
//
// const int batch_size = static_cast<int>(input->dims()[0]);
// const int input_channel = static_cast<int>(input->dims()[1]);
//
// const int input_height = static_cast<int>(input->dims()[2]);
// const int input_width = static_cast<int>(input->dims()[3]);
// const int output_height = static_cast<int>(output->dims()[2]);
// const int output_width = static_cast<int>(output->dims()[3]);
// const int inhxw = input_height * input_width;
// const int outhxw = output_height * output_width;
//
// float32x4_t zero = vdupq_n_f32(0.0);
// for (int b = 0; b < batch_size; b++) {
// #pragma omp parallel for
// for (int c = 0; c < input_channel; c++) {
// const float *filter_data = filter->data<float>() + c * 9;
// const float *input_data = input->data<float>() + c * inhxw;
// float *output_data = output->data<float>() + c * outhxw;
// float32x4_t vnewbias = vdupq_n_f32(newbias_data[c]);
// float32x4_t vnewscale = vdupq_n_f32(newscale_data[c]);
//
// float w00 = filter_data[0];
// float w01 = filter_data[1];
// float w02 = filter_data[2];
// float w10 = filter_data[3];
// float w11 = filter_data[4];
// float w12 = filter_data[5];
// float w20 = filter_data[6];
// float w21 = filter_data[7];
// float w22 = filter_data[8];
//
// int m;
// for (m = 1; m < output_width - 2; m = m + 3) {
// float *output_ptr = output_data + m;
// float32x4x2_t input_buff_mid{}, input_buff_bottom{};
// float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
// input_buff_mid = vld2q_f32(input_data + (2 * m - 1));
// input_buff_bottom = vld2q_f32(input_data + input_width + (2 * m -
// 1));
//
// in0 = input_buff_mid.val[0];
// tmp0 = input_buff_mid.val[1];
// tmp1 = vextq_f32(in0, zero, 1);
//
// in2 = input_buff_bottom.val[0];
// tmp2 = input_buff_bottom.val[1];
// tmp3 = vextq_f32(in2, zero, 1);
//
// out0 = vmulq_n_f32(in0, w10);
// out0 = vmlaq_n_f32(out0, tmp0, w11);
// out0 = vmlaq_n_f32(out0, tmp1, w12);
// out0 = vmlaq_n_f32(out0, in2, w20);
// out0 = vmlaq_n_f32(out0, tmp2, w21);
// out0 = vmlaq_n_f32(out0, tmp3, w22);
// out0 = vmlaq_f32(vnewbias, vnewscale, out0);
// if (if_relu) {
// out0 = vmaxq_f32(out0, zero);
// }
// vst1q_lane_f32(output_ptr, out0, 0);
// vst1q_lane_f32(output_ptr + 1, out0, 1);
// vst1q_lane_f32(output_ptr + 2, out0, 2);
// }
// for (m = 1; m < output_width - 2; m += 3) {
// }
// for (int j = m; j < output_width; j++) {
// output_data[j] = input_data[2 * j - 1] * w10 + input_data[2 * j] *
// w11 +
// input_data[2 * j + 1] * w12 +
// input_data[2 * j - 1 + input_width] * w20 +
// input_data[2 * j + input_width] * w21 +
// input_data[2 * j + 1 + input_width] * w22;
// output_data[j] = newscale_data[c] * output_data[j] +
// newbias_data[c]; if (if_relu) {
// output_data[j] = output_data[j] < 0 ? 0 : output_data[j];
// }
// }
//
// for (int i = 1; i < output_height; i += 1) {
// for (int m = 1; m < output_width - 2; m += 3) {
// float *output_ptr = output_data + i * output_width + m;
// float32x4x2_t input_buff_top{}, input_buff_mid{},
// input_buff_bottom{}; float32x4_t in0, in1, in2, in3, in4, in5,
// tmp0, tmp1, tmp2, tmp3,
// tmp4, tmp5, out0;
// input_buff_top =
// vld2q_f32(input_data + (2 * i - 1) * input_width + (2 * m -
// 1));
// input_buff_mid =
// vld2q_f32(input_data + (2 * i) * input_width + (2 * m - 1));
// input_buff_bottom =
// vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m -
// 1));
//
// in0 = input_buff_top.val[0];
// tmp0 = input_buff_top.val[1];
// tmp1 = vextq_f32(in0, zero, 1);
//
// in2 = input_buff_mid.val[0];
// tmp2 = input_buff_mid.val[1];
// tmp3 = vextq_f32(in2, zero, 1);
//
// in4 = input_buff_bottom.val[0];
// tmp4 = input_buff_bottom.val[1];
// tmp5 = vextq_f32(in4, zero, 1);
//
// out0 = vmulq_n_f32(in0, w00);
// out0 = vmlaq_n_f32(out0, tmp0, w01);
// out0 = vmlaq_n_f32(out0, tmp1, w02);
// out0 = vmlaq_n_f32(out0, in2, w10);
// out0 = vmlaq_n_f32(out0, tmp2, w11);
// out0 = vmlaq_n_f32(out0, tmp3, w12);
// out0 = vmlaq_n_f32(out0, in4, w20);
// out0 = vmlaq_n_f32(out0, tmp4, w21);
// out0 = vmlaq_n_f32(out0, tmp5, w22);
// out0 = vmlaq_f32(vnewbias, vnewscale, out0);
// if (if_relu) {
// out0 = vmaxq_f32(out0, zero);
// }
// vst1q_lane_f32(output_ptr, out0, 0);
// vst1q_lane_f32(output_ptr + 1, out0, 1);
// vst1q_lane_f32(output_ptr + 2, out0, 2);
// }
// int m;
// for (m = 1; m < output_width - 2; m += 3) {
// }
// for (int j = m; j < output_width; j++) {
// output_data[i * output_width + j] =
// input_data[(2 * i - 1) * input_width + 2 * j - 1] * w00 +
// input_data[(2 * i - 1) * input_width + 2 * j] * w01 +
// input_data[(2 * i - 1) * input_width + 2 * j + 1] * w02 +
// input_data[(2 * i) * input_width + 2 * j - 1] * w10 +
// input_data[(2 * i) * input_width + 2 * j] * w11 +
// input_data[(2 * i) * input_width + 2 * j + 1] * w12 +
// input_data[(2 * i + 1) * input_width + 2 * j - 1] * w20 +
// input_data[(2 * i + 1) * input_width + 2 * j] * w21 +
// input_data[(2 * i + 1) * input_width + 2 * j + 1] * w22;
// output_data[i * output_width + j] =
// newscale_data[c] * output_data[i * output_width + j] +
// newbias_data[c];
// if (if_relu) {
// output_data[i * output_width + j] =
// output_data[i * output_width + j] < 0
// ? 0
// : output_data[i * output_width + j];
// }
// }
// }
// output_data[0] = input_data[0] * w11 + input_data[1] * w12 +
// input_data[input_height] * w21 +
// input_data[input_height + 1] * w22;
//
// output_data[0] = newscale_data[c] * output_data[0] + newbias_data[c];
// if (if_relu) {
// output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
// }
// for (int i = 1; i < output_height; i++) {
// output_data[i * output_width] =
// input_data[(2 * i - 1) * input_width] * w01 +
// input_data[(2 * i - 1) * input_width + 1] * w02 +
// input_data[(2 * i) * input_width] * w11 +
// input_data[(2 * i) * input_width + 1] * w12 +
// input_data[(2 * i + 1) * input_width] * w21 +
// input_data[(2 * i + 1) * input_width + 1] * w22;
//
// output_data[i * output_width] =
// newscale_data[c] * output_data[i * output_width] +
// newbias_data[c];
// if (if_relu) {
// output_data[i * output_width] = output_data[i * output_width] < 0
// ? 0
// : output_data[i *
// output_width];
// }
// }
// }
// }
//
// #else
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->mutable_data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const int in_h = static_cast<int>(input->dims()[2]);
const int in_w = static_cast<int>(input->dims()[3]);
const int out_h = static_cast<int>(output->dims()[2]);
const int out_w = static_cast<int>(output->dims()[3]);
// const int out_l = out_h;
// const int in_l = in_h;
const int inhxw = in_h * in_w;
const int outhxw = out_h * out_w;
/// todo : fix if_pad when w != h
const int if_pad_r = in_w - 1 == (out_w - 1) * 2 ? 1 : 0;
const int if_pad_b = in_h - 1 == (out_h - 1) * 2 ? 1 : 0;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int w_times = (out_w - 2) / 3;
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = batch_size; b > 0; --b) {
#pragma omp parallel for
for (int j = 0; j < c; j++) {
const float *input_row_ptr;
float *output_row_ptr;
float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
int out2in_mid;
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
auto output_data_tmp = output_data + j * out_h * out_w;
auto input_data_tmp = input_data + j * in_h * in_w;
auto input_const = input_data_tmp;
const float *filter_data_tmp = filter_data + 9 * j;
vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]);
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
int h_mid = 0;
for (; h_mid < out_h - 1; h_mid++) {
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
if (h_mid == 0) {
elewise_res1 = zero;
elewise_res0 = zero;
elewise_res2 = zero;
} else { } else {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); acc0 = vmulq_f32(row0, _ker[0]);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); acc1 = vmulq_f32(row2, _ker[0]);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); acc0 = vmlaq_f32(acc0, row1, _ker[1]);
} acc1 = vmlaq_f32(acc1, row3, _ker[1]);
input_buff_mid = vld2q_f32(input_row_ptr); acc0 = vmlaq_f32(acc0, row2, _ker[2]);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); acc1 = vmlaq_f32(acc1, row4, _ker[2]);
float sum0 = vgetq_lane_f32(acc0, 0);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); float sum1 = vgetq_lane_f32(acc1, 0);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); if (padding == 1) {
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); sum0 += vgetq_lane_f32(acc0, 1);
sum1 += vgetq_lane_f32(acc1, 1);
elewise_res1 = }
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); *output_ptr0 = sum0;
elewise_res0 = *output_ptr1 = sum1;
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
vaddq_f32(elewise_res0, elewise_res1));
res3 = vmlaq_f32(vnewbias, vnewscale, res3);
if (if_relu) {
res3 = vmaxq_f32(res3, zero);
} }
vst1q_lane_f32(output_row_ptr, res3, 0); output_ptr0++;
vst1q_lane_f32(output_row_ptr + 1, res3, 1); output_ptr1++;
vst1q_lane_f32(output_row_ptr + 2, res3, 2);
input_row_ptr += 6;
output_row_ptr += 3;
} }
} }
clock(); }
// remain height
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; int start_h = valid_h_start + (valid_h & 0xfffe);
output_row_ptr = output_data_tmp + 1 + h_mid * out_w; if (start_h < valid_h_end) {
const float *input_ptr0 = input_ptr + (2 * start_h - padding_h) * input_w;
for (int w4 = 0; w4 < w_times + 1; w4++) { const float *input_ptr1 = input_ptr0 + input_w;
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); const float *input_ptr2 = input_ptr1 + input_w;
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); float *output_ptr0 = output_ptr + start_h * output_w;
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); // pad left
if (padding_w) {
input_buff_mid = vld2q_f32(input_row_ptr); for (int w = valid_w_start - 1; w >= 0; --w) {
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); int padding = padding_w - (w << 1);
if (padding >= 3) {
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); output_ptr0[w] = 0;
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); } else {
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); float32x4_t row0 = vld1q_f32(input_ptr0 - padding);
float32x4_t row1 = vld1q_f32(input_ptr1 - padding);
if (!if_pad_b) { float32x4_t row2 = vld1q_f32(input_ptr2 - padding);
elewise_res1 = float32x4_t acc0 = vmulq_f32(row0, _ker[0]);
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); acc0 = vmlaq_f32(acc0, row1, _ker[1]);
elewise_res0 = acc0 = vmlaq_f32(acc0, row2, _ker[2]);
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); float sum0 = vgetq_lane_f32(acc0, 2);
elewise_res2 = if (padding == 1) {
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); sum0 += vgetq_lane_f32(acc0, 1);
} }
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), output_ptr0[w] = sum0;
vaddq_f32(elewise_res0, elewise_res1));
res3 = vmlaq_f32(vnewbias, vnewscale, res3);
if (if_relu) {
res3 = vmaxq_f32(res3, zero);
}
if ((w4 != w_times)) {
vst1q_lane_f32(output_row_ptr, res3, 0);
vst1q_lane_f32(output_row_ptr + 1, res3, 1);
vst1q_lane_f32(output_row_ptr + 2, res3, 2);
} else {
if (out_w - 2 - w_times * 3 == 1) {
vst1q_lane_f32(output_row_ptr, res3, 0);
} else if (out_w - 2 - w_times * 3 == 2) {
vst1q_lane_f32(output_row_ptr, res3, 0);
vst1q_lane_f32(output_row_ptr + 1, res3, 1);
} }
} }
input_row_ptr += 6; input_ptr0 += input_w_start;
output_row_ptr += 3; input_ptr1 += input_w_start;
input_ptr2 += input_w_start;
output_ptr0 += valid_w_start;
} }
// valid
output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 + float32x4_t _result0, _ext;
input_const[in_w] * w21 + for (int loop = 0; loop < output_w_tiles; ++loop) {
input_const[in_w + 1] * w22; float32x4x2_t _row0 = vld2q_f32(input_ptr0);
float32x4x2_t _row1 = vld2q_f32(input_ptr1);
out2in_mid = (out_w - 1) * 2; float32x4x2_t _row2 = vld2q_f32(input_ptr2);
output_data_tmp[out_w - 1] =
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + _ext = vextq_f32(_row0.val[0], _ext, 1);
w20 * input_const[out2in_mid + in_w - 1] + _ext = vsetq_lane_f32(input_ptr0[8], _ext, 3);
w21 * input_const[out2in_mid + in_w] + _result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
(1 - if_pad_r) * (w12 * input_const[out2in_mid + 1] + _result0 =
w22 * input_const[out2in_mid + in_w + 1]); vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0);
out2in_mid = (out_h - 1) * 2 * in_w;
_ext = vextq_f32(_row1.val[0], _ext, 1);
output_data_tmp[out_w * (out_h - 1)] = _ext = vsetq_lane_f32(input_ptr1[8], _ext, 3);
w01 * input_const[out2in_mid - in_w] + _result0 =
w02 * input_const[out2in_mid - in_w + 1] + vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0);
w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] + _result0 =
(1 - if_pad_b) * (w21 * input_const[out2in_mid + in_w] + vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1);
          w22 * input_const[out2in_mid + in_w + 1]);
      out2in_mid = (out_h - 1) * 2 * in_w + (out_w - 1) * 2;
      output_data_tmp[out_h * out_w - 1] =
          w00 * input_const[out2in_mid - in_w - 1] +
          w01 * input_const[out2in_mid - in_w] +
          w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
          (1 - if_pad_r) * (w20 * input_const[out2in_mid + in_w - 1] +
                            w21 * input_const[out2in_mid + in_w]) +
          (1 - if_pad_b) * (w02 * input_const[out2in_mid - in_w + 1] +
                            w12 * input_const[out2in_mid + 1]) +
          (1 - if_pad_r) * (1 - if_pad_b) * w22 *
              input_const[out2in_mid + in_w + 1];
      output_data_tmp[0] =
          output_data_tmp[0] * newscale_data[j] + newbias_data[j];
      output_data_tmp[out_w - 1] =
          output_data_tmp[out_w - 1] * newscale_data[j] + newbias_data[j];
      output_data_tmp[out_w * (out_h - 1)] =
          output_data_tmp[out_w * (out_h - 1)] * newscale_data[j] +
          newbias_data[j];
      output_data_tmp[out_h * out_w - 1] =
          output_data_tmp[out_h * out_w - 1] * newscale_data[j] +
          newbias_data[j];
      if (if_relu) {
        output_data_tmp[0] = output_data_tmp[0] < 0 ? 0 : output_data_tmp[0];
        output_data_tmp[out_w - 1] =
            output_data_tmp[out_w - 1] < 0 ? 0 : output_data_tmp[out_w - 1];
        output_data_tmp[out_w * (out_h - 1)] =
            output_data_tmp[out_w * (out_h - 1)] < 0
                ? 0
                : output_data_tmp[out_w * (out_h - 1)];
        output_data_tmp[out_h * out_w - 1] =
            output_data_tmp[out_h * out_w - 1] < 0
                ? 0
                : output_data_tmp[out_h * out_w - 1];
      }
      for (int i = 1; i < out_h - 1; i++) {
        out2in_mid = i * 2 * in_w;
        output_data_tmp[i * out_w] = w01 * input_const[out2in_mid - in_w] +
                                     w02 * input_const[out2in_mid - in_w + 1] +
                                     w11 * input_const[out2in_mid] +
                                     w12 * input_const[out2in_mid + 1] +
                                     w21 * input_const[out2in_mid + in_w] +
                                     w22 * input_const[out2in_mid + in_w + 1];

        out2in_mid = i * 2 * in_w + (out_w - 1) * 2;
        output_data_tmp[i * out_w + out_w - 1] =
            w00 * input_const[out2in_mid - in_w - 1] +
            w01 * input_const[out2in_mid - in_w] +
            w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
            w20 * input_const[out2in_mid + in_w - 1] +
            w21 * input_const[out2in_mid + in_w] +
            (1 - if_pad_r) * (w02 * input_const[out2in_mid - in_w + 1] +
                              w12 * input_const[out2in_mid + 1] +
                              w22 * input_const[out2in_mid + in_w + 1]);
        output_data_tmp[i * out_w] =
            output_data_tmp[i * out_w] * newscale_data[j] + newbias_data[j];
        output_data_tmp[i * out_w + out_w - 1] =
            output_data_tmp[i * out_w + out_w - 1] * newscale_data[j] +
            newbias_data[j];
        if (if_relu) {
          output_data_tmp[i * out_w] =
              output_data_tmp[i * out_w] < 0 ? 0 : output_data_tmp[i * out_w];
          output_data_tmp[i * out_w + out_w - 1] =
              output_data_tmp[i * out_w + out_w - 1] < 0
                  ? 0
                  : output_data_tmp[i * out_w + out_w - 1];
        }
      }
    }
    input_data += inhxw * c;
    output_data += outhxw * c;
  }
  // #endif
#endif
}

void DepthwiseConv3x3s2p0(const framework::Tensor *input,
                          const framework::Tensor *filter,
                          framework::Tensor *output, framework::Tensor *bias,
                          bool if_bias, bool if_relu) {
#if __ARM_NEON
  const int batch_size = static_cast<int>(input->dims()[0]);
  const int input_channel = static_cast<int>(input->dims()[1]);

  const int input_height = static_cast<int>(input->dims()[2]);
  const int input_width = static_cast<int>(input->dims()[3]);
  const int output_height = static_cast<int>(output->dims()[2]);
  const int output_width = static_cast<int>(output->dims()[3]);
  const int inhxw = input_height * input_width;
  const int outhxw = output_height * output_width;
  output->mutable_data<float>();
  float32x4_t zero = vdupq_n_f32(0.0);
  for (int b = 0; b < batch_size; b++) {
#pragma omp parallel for
    for (int c = 0; c < input_channel; c++) {
      const float *filter_data = filter->data<float>() + c * 9;
      const float *input_data = input->data<float>() + c * inhxw;
      const float *bias_data;
      float32x4_t biasv;
      if (if_bias) {
        bias_data = bias->data<float>() + c;
        biasv = vld1q_dup_f32(bias_data);
      }
      float *output_data = output->data<float>() + c * outhxw;
      float w00 = filter_data[0];
      float w01 = filter_data[1];
      float w02 = filter_data[2];
      float w10 = filter_data[3];
      float w11 = filter_data[4];
      float w12 = filter_data[5];
      float w20 = filter_data[6];
      float w21 = filter_data[7];
      float w22 = filter_data[8];
      for (int i = 0; i < output_height; i += 1) {
        for (int m = 0; m < output_width - 2; m += 3) {
          float *output_ptr = output_data + i * output_width + m;
          float32x4x2_t input_buff_top{}, input_buff_mid{}, input_buff_bottom{};
          float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3,
              tmp4, tmp5, out0;
          input_buff_top =
              vld2q_f32(input_data + (2 * i) * input_width + (2 * m));
          input_buff_mid =
              vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m));
          input_buff_bottom =
              vld2q_f32(input_data + (2 * i + 2) * input_width + (2 * m));
          in0 = input_buff_top.val[0];
          tmp0 = input_buff_top.val[1];
          tmp1 = vextq_f32(in0, zero, 1);
          in2 = input_buff_mid.val[0];
          tmp2 = input_buff_mid.val[1];
          tmp3 = vextq_f32(in2, zero, 1);
          in4 = input_buff_bottom.val[0];
          tmp4 = input_buff_bottom.val[1];
          tmp5 = vextq_f32(in4, zero, 1);
          out0 = vmulq_n_f32(in0, w00);
          out0 = vmlaq_n_f32(out0, tmp0, w01);
          out0 = vmlaq_n_f32(out0, tmp1, w02);
          out0 = vmlaq_n_f32(out0, in2, w10);
          out0 = vmlaq_n_f32(out0, tmp2, w11);
          out0 = vmlaq_n_f32(out0, tmp3, w12);
          out0 = vmlaq_n_f32(out0, in4, w20);
          out0 = vmlaq_n_f32(out0, tmp4, w21);
          out0 = vmlaq_n_f32(out0, tmp5, w22);
          if (if_bias) {
            out0 = vaddq_f32(out0, biasv);
          }
          if (if_relu) {
            out0 = vmaxq_f32(out0, zero);
          }
          vst1q_lane_f32(output_ptr, out0, 0);
          vst1q_lane_f32(output_ptr + 1, out0, 1);
          vst1q_lane_f32(output_ptr + 2, out0, 2);
        }
        int m;
        for (m = 0; m < output_width - 2; m += 3) {
        }
        for (int j = m; j < output_width; j++) {
          int index = i * output_width + j;
          output_data[index] =
              input_data[(2 * i) * input_width + 2 * j] * w00 +
              input_data[(2 * i) * input_width + 2 * j + 1] * w01 +
              input_data[(2 * i) * input_width + 2 * j + 2] * w02 +
              input_data[(2 * i + 1) * input_width + 2 * j] * w10 +
              input_data[(2 * i + 1) * input_width + 2 * j + 1] * w11 +
              input_data[(2 * i + 1) * input_width + 2 * j + 2] * w12 +
              input_data[(2 * i + 2) * input_width + 2 * j] * w20 +
              input_data[(2 * i + 2) * input_width + 2 * j + 1] * w21 +
              input_data[(2 * i + 2) * input_width + 2 * j + 2] * w22;
          if (if_bias) {
            output_data[index] += *bias_data;
          }
          if (if_relu) {
            output_data[index] =
                output_data[index] < 0 ? 0 : output_data[index];
          }
        }
      }
    }
  }
#endif
}

        _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0);

        _ext = vextq_f32(_row2.val[0], _ext, 1);
        _ext = vsetq_lane_f32(input_ptr2[8], _ext, 3);
        _result0 =
            vmlaq_lane_f32(_result0, _row2.val[0], vget_low_f32(_ker[2]), 0);
        _result0 =
            vmlaq_lane_f32(_result0, _row2.val[1], vget_low_f32(_ker[2]), 1);
        _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0);

        vst1q_f32(output_ptr0, _result0);

        input_ptr0 += 8;
        input_ptr1 += 8;
        input_ptr2 += 8;
        output_ptr0 += 4;
      }
      // remain w
      if (output_w_remain > 0) {
        float32x4x2_t _row0 = vld2q_f32(input_ptr0);
        float32x4x2_t _row1 = vld2q_f32(input_ptr1);
        float32x4x2_t _row2 = vld2q_f32(input_ptr2);

        _ext = vextq_f32(_row0.val[0], _ext, 1);
        _ext = vsetq_lane_f32(input_ptr0[8], _ext, 3);
        _result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
        _result0 =
            vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1);
        _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0);

        _ext = vextq_f32(_row1.val[0], _ext, 1);
        _ext = vsetq_lane_f32(input_ptr1[8], _ext, 3);
        _result0 =
            vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0);
        _result0 =
            vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1);
        _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0);

        _ext = vextq_f32(_row2.val[0], _ext, 1);
        _ext = vsetq_lane_f32(input_ptr2[8], _ext, 3);
        _result0 =
            vmlaq_lane_f32(_result0, _row2.val[0], vget_low_f32(_ker[2]), 0);
        _result0 =
            vmlaq_lane_f32(_result0, _row2.val[1], vget_low_f32(_ker[2]), 1);
        _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0);

        switch (output_w_remain) {
          case 3:
            vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
          case 2:
            vst1_f32(output_ptr0, vget_low_f32(_result0));
            break;
          case 1:
            vst1q_lane_f32(output_ptr0, _result0, 0);
            break;
        }
        input_ptr0 += output_w_remain * 2;
        input_ptr1 += output_w_remain * 2;
        input_ptr2 += output_w_remain * 2;
        output_ptr0 += output_w_remain;
      }
      // pad right
      if (padding_w) {
        float32x4_t row0 = vld1q_f32(input_ptr0);
        float32x4_t row1 = vld1q_f32(input_ptr1);
        float32x4_t row2 = vld1q_f32(input_ptr2);
        float32x4_t acc0;
        for (int w = valid_w_end; w < output_w; ++w) {
          int padding = 2 * w + 3 - (padding_w + input_w);
          if (padding >= 3) {
            *output_ptr0 = 0;
          } else {
            acc0 = vmulq_f32(row0, _ker[0]);
            acc0 = vmlaq_f32(acc0, row1, _ker[1]);
            acc0 = vmlaq_f32(acc0, row2, _ker[2]);
            float sum0 = vgetq_lane_f32(acc0, 0);
            if (padding == 1) {
              sum0 += vgetq_lane_f32(acc0, 1);
            }
            *output_ptr0 = sum0;
          }
          output_ptr0++;
        }
      }
    }
    // pad bottom
    for (int h = valid_h_end; h < output_h; ++h) {
      DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
                                      input_w, padding_h, padding_w, output_w,
                                      output_ptr, _ker);
    }
  }
}
}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
#endif // __ARM_NEON__
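The right-edge handling in the new stride-2 path derives the number of kernel taps that fall outside the input directly from the output column index (`int padding = 2 * w + 3 - (padding_w + input_w)`). Below is a small self-contained sketch of that arithmetic; the helper name and the sample sizes are hypothetical, only the formula itself comes from the patch.

#include <algorithm>
#include <cassert>

// Illustration only: for output column `w` of a stride-2, 3-wide window,
// count how many of the three taps land past the right edge of the input.
// This equals `2 * w + 3 - (padding_w + input_w)` clamped to [0, 3], which is
// the value the kernel above branches on (== 1 keeps two taps, >= 3 keeps none).
inline int TapsPastRightEdge(int w, int padding_w, int input_w) {
  int window_start = 2 * w - padding_w;      // leftmost input column touched
  int window_end = window_start + 2;         // rightmost input column touched
  int overrun = window_end - (input_w - 1);  // columns beyond the last valid one
  return std::max(0, std::min(3, overrun));
}

int main() {
  // input_w = 7, padding_w = 1 gives output_w = 4; only the last output
  // column overruns the input, by exactly one tap.
  assert(TapsPastRightEdge(2, 1, 7) == 0);
  assert(TapsPastRightEdge(3, 1, 7) == 1);
  return 0;
}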
...@@ -23,48 +23,6 @@ namespace paddle_mobile {
namespace operators {
namespace math {
void DepthwiseConv3x3(const framework::Tensor *input,
const std::vector<int> &strides,
const std::vector<int> &paddings,
const framework::Tensor *filter, framework::Tensor *bias,
framework::Tensor *output, bool if_bias);
void DepthwiseConv3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu);
void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu);
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu);
// TODO(hjchen2) need to be implemented
// template<typename Itype, typename Otype>
// void DepthwiseConv3x3(const framework::Tensor *input,
......
...@@ -31,16 +31,14 @@ void cblas_sgemm(const bool transA, const bool transB, const int M, const int N,
  //   return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C);
  // }
  CPUInfo *info = CPUInfo::Info();
  GemmExecutor<SgemmStrategy> exec(info, transA, transB, M, N, K);
  GemmExecutor<SgemmStrategy> exec(transA, transB, M, N, K);
  exec(alpha, A, lda, B, ldb, beta, C, ldc);
}

void cblas_sgemv(const bool trans, const int M, const int N, const float alpha,
                 const float *A, const int lda, const float *B,
                 const float beta, float *C) {
  CPUInfo *info = CPUInfo::Info();
  GemvExecutor<SgemvStrategy> exec(info, trans, M, N);
  GemvExecutor<SgemvStrategy> exec(trans, M, N);
  exec(alpha, A, lda, B, beta, C);
}
......
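For context, a minimal call through this wrapper might look like the sketch below. Only `(transA, transB, M, N, ...)` and the executor invocation are visible in this hunk, so the remaining parameter order `(K, alpha, A, lda, B, ldb, beta, C, ldc)` is an assumption to verify against the header.

#include <vector>

// Hypothetical usage sketch, assuming row-major layout and the parameter
// order (transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
// also assumes it is called from inside the same namespace as cblas_sgemm.
void SgemmExample() {
  const int M = 4, N = 8, K = 16;
  std::vector<float> A(M * K, 1.f);  // M x K, lda = K
  std::vector<float> B(K * N, 2.f);  // K x N, ldb = N
  std::vector<float> C(M * N, 0.f);  // M x N, ldc = N
  // C = 1.0 * A * B + 0.0 * C
  cblas_sgemm(false, false, M, N, K, 1.f, A.data(), K, B.data(), N, 0.f,
              C.data(), N);
}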
...@@ -19,7 +19,6 @@ limitations under the License. */
#include <omp.h>
#endif
#include <sys/time.h>
#include <iostream>
#include "common/log.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm/cpu_info.h"
...@@ -29,6 +28,8 @@ namespace paddle_mobile {
namespace operators {
namespace math {
static CPUInfo *info = CPUInfo::Info();
int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; }
unsigned int ResetL1Cache(const unsigned int L1_size, const int thread_num,
                          const int N, const int K) {
...@@ -62,15 +63,9 @@ class GemmExecutor : public Executor {
  typedef typename Strategy::Otype Otype;
 public:
  GemmExecutor(const CPUInfo *info, const bool transA, const bool transB,
               const int M, const int N, const int K)
      : Executor(),
        info_(info),
        transA_(transA),
        transB_(transB),
        M_(M),
        N_(N),
        K_(K) {
  GemmExecutor(const bool transA, const bool transB, const int M, const int N,
               const int K)
      : Executor(), transA_(transA), transB_(transB), M_(M), N_(N), K_(K) {
    unsigned int L1_size = 0;
    unsigned int L2_size = 0;
    if (M_ > N_) {
...@@ -212,8 +207,6 @@ class GemmExecutor : public Executor {
  virtual ~GemmExecutor() {}
 private:
  const CPUInfo *info_;
  const unsigned int M_;
  const unsigned int N_;
  const unsigned int K_;
...@@ -242,8 +235,8 @@ class GemvExecutor : public Executor {
  typedef typename Strategy::Otype Otype;
 public:
  GemvExecutor(const CPUInfo *info, const bool transA, const int M, const int N)
      : Executor(), info_(info), M_(M), N_(N) {}
  GemvExecutor(const bool transA, const int M, const int N)
      : Executor(), M_(M), N_(N) {}
  void operator()(const float alpha, const Itype *A, const int lda,
                  const Itype *B, const float beta, Otype *C) {
...@@ -253,8 +246,6 @@ class GemvExecutor : public Executor {
  virtual ~GemvExecutor() {}
 private:
  const CPUInfo *const info_;
  const unsigned int M_;
  const unsigned int N_;
......
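A note on the constructor change above: the cache-topology handle is no longer injected into each executor; the implementation file now keeps a single file-scope `static CPUInfo *info`. If static-initialization order ever became a concern, a function-local static would be a common alternative. The sketch below is not part of the patch and assumes `CPUInfo::Info()` is cheap and idempotent.

// Hypothetical alternative: defer the lookup to first use (Meyers singleton),
// which is initialized lazily and is thread-safe since C++11.
static CPUInfo *GetCPUInfo() {
  static CPUInfo *cached = CPUInfo::Info();
  return cached;
}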
...@@ -44,7 +44,17 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height,
  for (int i = start_height; i < end_height; i += stride_h) {
    if (stride_w == 1) {
      memcpy(col_data, im_data, extract * sizeof(float));
      // memcpy(col_data, im_data, extract * sizeof(float));
int s = 0;
#if __ARM_NEON
for (; s < extract - 3; s += 4) {
float32x4_t img = vld1q_f32(im_data + s);
vst1q_f32(col_data + s, img);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s];
}
    } else if (stride_w == 2) {
      int s = 0;
#if __ARM_NEON
...@@ -109,325 +119,7 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
  const float *im_data = im.data<float>();
  float *col_data = col->data<float>();
#if __ARM_NEON
  const int osize = col_height;
  if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
const int isize = im_height;
bool pad1 = padding[0] > 0;
bool pad2 =
(pad1 && padding[1] &&
(((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 1 : 0));
int fill = isize % 2;
if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 &&
dilation[0] == 1 && im_height > 2 && im_height == im_width) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
int mk4 = osize % 4;
float *col0 = col_data + 0 * oosize + 2 * osize + 2;
float *col1 = col_data + 1 * oosize + 2 * osize + 1;
float *col2 = col_data + 2 * oosize + 2 * osize;
float *col3 = col_data + 3 * oosize + osize + 2;
float *col4 = col_data + 4 * oosize + osize + 1;
float *col5 = col_data + 5 * oosize + osize;
float *col6 = col_data + 6 * oosize + 2;
float *col7 = col_data + 7 * oosize + 1;
float *col8 = col_data + 8 * oosize;
float32x4_t im1;
const float *im_tmp_data = im_data + osize + 1;
int rrsize = oosize - osize - 1;
int nr4 = rrsize / 4;
int mr4 = rrsize % 4;
for (int i = 0; i < nr4; ++i) {
im1 = vld1q_f32(im_tmp_data);
vst1q_f32(col0, im1);
vst1q_f32(col1, im1);
vst1q_f32(col2, im1);
vst1q_f32(col3, im1);
vst1q_f32(col4, im1);
vst1q_f32(col5, im1);
vst1q_f32(col6, im1);
vst1q_f32(col7, im1);
vst1q_f32(col8, im1);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
col6 += 4;
col7 += 4;
col8 += 4;
im_tmp_data += 4;
}
for (int i = 0; i < mr4; ++i) {
*col0 = *im_tmp_data;
*col1 = *im_tmp_data;
*col2 = *im_tmp_data;
*col3 = *im_tmp_data;
*col4 = *im_tmp_data;
*col5 = *im_tmp_data;
*col6 = *im_tmp_data;
*col7 = *im_tmp_data;
*col8 = *im_tmp_data;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
col6++;
col7++;
col8++;
im_tmp_data++;
}
im_tmp_data = im_data + 1;
col0 = col_data + 0 * oosize + osize + 2;
col1 = col_data + 1 * oosize + osize + 1;
col2 = col_data + 2 * oosize + osize;
col3 = col_data + 3 * oosize + 2;
col4 = col_data + 4 * oosize + 1;
col5 = col_data + 5 * oosize;
for (int i = 0; i < nk4; i++) {
im1 = vld1q_f32(im_tmp_data);
vst1q_f32(col0, im1);
vst1q_f32(col1, im1);
vst1q_f32(col2, im1);
vst1q_f32(col3, im1);
vst1q_f32(col4, im1);
vst1q_f32(col5, im1);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
im_tmp_data += 4;
}
for (int i = 0; i < mk4; i++) {
*col0 = *im_tmp_data;
*col1 = *im_tmp_data;
*col2 = *im_tmp_data;
*col3 = *im_tmp_data;
*col4 = *im_tmp_data;
*col5 = *im_tmp_data;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
im_tmp_data++;
}
// fill 0 1 11;
for (int i = 0; i < osize; ++i) {
col_data[0 * oosize + i * osize] = 0.0;
col_data[3 * oosize + i * osize] = 0.0;
col_data[6 * oosize + i * osize] = 0.0;
col_data[2 * oosize + osize - 1 + i * osize] = 0.0;
col_data[5 * oosize + osize - 1 + i * osize] = 0.0;
col_data[8 * oosize + osize - 1 + i * osize] = 0.0;
}
col_data[0 * oosize + osize + 1] = im_data[0];
col_data[3 * oosize + 1] = im_data[0];
col_data[6 * oosize + 1] = im_data[osize];
col_data[1 * oosize + osize] = im_data[0];
col_data[4 * oosize] = im_data[0];
col_data[7 * oosize] = im_data[osize];
float32x4_t zero4;
zero4 = vdupq_n_f32(0.0);
auto col_z0 = col_data;
auto col_z1 = col_data + oosize;
auto col_z2 = col_data + 2 * oosize;
auto col_z6 = col_data + 6 * oosize + osize * (osize - 1);
auto col_z7 = col_data + 7 * oosize + osize * (osize - 1);
auto col_z8 = col_data + 8 * oosize + osize * (osize - 1);
for (int i = 0; i < nk4; ++i) {
vst1q_f32(col_z0, zero4);
vst1q_f32(col_z1, zero4);
vst1q_f32(col_z2, zero4);
vst1q_f32(col_z6, zero4);
vst1q_f32(col_z7, zero4);
vst1q_f32(col_z8, zero4);
col_z0 += 4;
col_z1 += 4;
col_z2 += 4;
col_z6 += 4;
col_z7 += 4;
col_z8 += 4;
}
for (int i = 0; i < mk4; ++i) {
col_z0[i] = 0.0;
col_z1[i] = 0.0;
col_z2[i] = 0.0;
col_z6[i] = 0.0;
col_z7[i] = 0.0;
col_z8[i] = 0.0;
}
col_data += 9 * oosize;
im_data += isize * isize;
}
} else if (stride[0] == 2 && filter_height == 3 && pad1 && dilation[0] == 1 &&
im_height > 2 && im_height == im_width) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
int mk4 = osize % 4;
// 3 2 3 1 0 1 3 2 3
float *col0 = col_data + 0 * oosize + osize + 1;
float *col1 = col_data + 1 * oosize + osize;
float *col2 = col_data + 2 * oosize + osize;
float *col3 = col_data + 3 * oosize + 1;
float *col4 = col_data + 4 * oosize;
float *col5 = col_data + 5 * oosize;
float *col6 = col_data + 6 * oosize + 1;
float *col7 = col_data + 7 * oosize;
float *col8 = col_data + 8 * oosize;
float32x4x2_t im01;
float32x4x2_t im23;
const float *im_tmp_data0 = im_data;
const float *im_tmp_data2 = im_data + isize;
for (int j = 0; j < osize; ++j) {
for (int i = 0; i < nk4; ++i) {
im01 = vld2q_f32(im_tmp_data0);
im23 = vld2q_f32(im_tmp_data2);
vst1q_f32(col0, im23.val[1]);
vst1q_f32(col1, im23.val[0]);
vst1q_f32(col2, im23.val[1]);
vst1q_f32(col3, im01.val[1]);
vst1q_f32(col4, im01.val[0]);
vst1q_f32(col5, im01.val[1]);
vst1q_f32(col6, im23.val[1]);
vst1q_f32(col7, im23.val[0]);
vst1q_f32(col8, im23.val[1]);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
col6 += 4;
col7 += 4;
col8 += 4;
im_tmp_data0 += 8;
im_tmp_data2 += 8;
}
const float *im_tmp_data1 = im_tmp_data0 + 1;
const float *im_tmp_data3 = im_tmp_data2 + 1;
for (int i = 0; i < mk4; ++i) {
*col0 = *im_tmp_data3;
*col1 = *im_tmp_data2;
*col2 = *im_tmp_data3;
*col3 = *im_tmp_data1;
*col4 = *im_tmp_data0;
*col5 = *im_tmp_data1;
*col6 = *im_tmp_data3;
*col7 = *im_tmp_data2;
*col8 = *im_tmp_data3;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
col6++;
col7++;
col8++;
im_tmp_data0 += 2;
im_tmp_data1 += 2;
im_tmp_data2 += 2;
im_tmp_data3 += 2;
}
im_tmp_data0 += (isize - fill);
im_tmp_data2 += (isize - fill);
}
for (int i = 0; i < osize; ++i) {
col_data[0 * oosize + i * osize] = 0.0;
col_data[3 * oosize + i * osize] = 0.0;
col_data[6 * oosize + i * osize] = 0.0;
if (pad2) {
col_data[2 * oosize + osize - 1 + i * osize] = 0.0;
col_data[5 * oosize + osize - 1 + i * osize] = 0.0;
col_data[8 * oosize + osize - 1 + i * osize] = 0.0;
}
}
float32x4_t zero4;
zero4 = vdupq_n_f32(0.0);
auto col_z0 = col_data;
auto col_z1 = col_data + oosize;
auto col_z2 = col_data + 2 * oosize;
auto col_z6 = col_data + 6 * oosize + osize * (osize - 1);
auto col_z7 = col_data + 7 * oosize + osize * (osize - 1);
auto col_z8 = col_data + 8 * oosize + osize * (osize - 1);
for (int i = 0; i < nk4; ++i) {
vst1q_f32(col_z0, zero4);
vst1q_f32(col_z1, zero4);
vst1q_f32(col_z2, zero4);
if (pad2) {
vst1q_f32(col_z6, zero4);
vst1q_f32(col_z7, zero4);
vst1q_f32(col_z8, zero4);
}
col_z0 += 4;
col_z1 += 4;
col_z2 += 4;
col_z6 += 4;
col_z7 += 4;
col_z8 += 4;
}
for (int i = 0; i < mk4; ++i) {
col_z0[i] = 0.0;
col_z1[i] = 0.0;
col_z2[i] = 0.0;
if (pad2) {
col_z6[i] = 0.0;
col_z7[i] = 0.0;
col_z8[i] = 0.0;
}
}
col_data[1 * oosize + osize] = im_data[isize];
for (int i = 1; i < osize; ++i) {
col_data[3 * oosize + i] = im_data[(i - 1) * stride[0] + 1];
}
col_data[4 * oosize] = im_data[0];
col_data[7 * oosize] = im_data[isize];
col_data += 9 * oosize;
im_data += isize * isize;
}
} else if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
    int im_spatial_size = im_height * im_width;
    int col_spatial_size = col_height * col_width;
    // pad 0
......
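The `stride_w == 1` branch of `ExtractToImg` above now copies with explicit NEON loads and stores plus a scalar tail instead of `memcpy`. The same pattern as a standalone sketch (the helper name is hypothetical; it is guarded so it still compiles without NEON):

#if __ARM_NEON
#include <arm_neon.h>
#endif

// Copy `n` floats from src to dst, four lanes at a time where NEON is
// available, finishing any remainder with a scalar loop.
inline void CopyFloats(const float *src, float *dst, int n) {
  int s = 0;
#if __ARM_NEON
  for (; s < n - 3; s += 4) {
    vst1q_f32(dst + s, vld1q_f32(src + s));
  }
#endif
  for (; s < n; ++s) {
    dst[s] = src[s];
  }
}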
...@@ -441,10 +441,8 @@ class ConvParam : public OpParam {
  enum ExecMode {
    EXEC_INVALID = 0,
    EXEC_GEMM_FLOAT,
    EXEC_DEPTHWISE3x3S1P1_FLOAT,
    EXEC_DEPTHWISE3x3S2P0_FLOAT,
    EXEC_DEPTHWISE3x3S2P1_FLOAT,
    EXEC_DEPTHWISE3x3_FLOAT,
    EXEC_DEPTHWISE3x3S1_FLOAT,
    EXEC_DEPTHWISE3x3S2_FLOAT,
    EXEC_WINOGRAD3X3_FLOAT,
    EXEC_WINOGRAD5X5_FLOAT,
    EXEC_DEPTHWISE5x5_FLOAT,
......
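The four padding-specific 3x3 depthwise modes collapse into a stride-keyed pair. The actual selection happens in the common conv kernel initialization, which is outside this hunk; the sketch below only illustrates how such a dispatch could be keyed on stride, with a hypothetical helper name.

#include <vector>

// Hypothetical sketch: choose the float 3x3 depthwise mode from the stride.
// The real selection also checks filter shape, groups, and data type.
inline ConvParam<CPU>::ExecMode PickDepthwise3x3Mode(
    const std::vector<int> &strides) {
  if (strides[0] == 1 && strides[1] == 1) {
    return ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT;
  }
  if (strides[0] == 2 && strides[1] == 2) {
    return ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT;
  }
  return ConvParam<CPU>::EXEC_GEMM_FLOAT;  // fall back to the generic path
}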