Commit a8b775ec authored by H hjchen2

Refactor depthwise conv3x3 and fix its bugs for armv8

Parent 54a801d5
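At a glance, the refactor replaces the padding-specific fused kernels (DepthwiseConv3x3s1p1, DepthwiseConvAddBNRelu3x3s2p1v2, DepthwiseConv3x3s2p0 and the generic DepthwiseConv3x3 fallback) with two stride-specialized templates, DepthwiseConv3x3S1 and DepthwiseConv3x3S2, which take the paddings explicitly; bias, folded batch-norm scale and ReLU are then applied in a separate channel-wise pass. A condensed before/after of the call shape, excerpted from the hunks below (not a standalone program):

// Before: one fused kernel per (stride, padding) combination, with bool flags.
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
                                    param.Output(), param.NewScale(),
                                    param.NewBias(), /*if_relu=*/true);

// After: a stride-specialized convolution that handles padding itself,
// followed by an explicit channel-wise scale/bias (+ ReLU) pass.
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
                                       param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                param.NewBias(), param.Output());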
...@@ -61,25 +61,15 @@ template <>
void ConvAddBNReluKernel<CPU, float>::Compute(
    const FusionConvAddBNReluParam<CPU> &param) {
  switch (param.ExecMode()) {
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
      math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
                                          param.Output(), param.NewScale(),
                                          param.NewBias(), true);
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
      math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
                                            param.Output(), param.NewScale(),
                                            param.NewBias(), true);
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
                                 nullptr, false, false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                      param.NewBias(), param.Output());
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
                             param.Filter(), nullptr, param.Output(), false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                      param.NewBias(), param.Output());
      break;
......
...@@ -14,6 +14,7 @@ limitations under the License. */
#ifdef FUSION_CONVADD_OP
#include "operators/kernel/conv_add_kernel.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_add_arm_func.h"
namespace paddle_mobile {
...@@ -21,12 +22,44 @@ namespace operators {
template <>
bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam<CPU> *param) {
  InitBaseConvKernel(param);
  return true;
}
template <>
void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam<CPU> &param) {
  ConvAddCompute<float>(param);
  switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvAddBasic(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvAddKernel<CPU, float>;
......
...@@ -31,21 +31,14 @@ template <>
void ConvAddReluKernel<CPU, float>::Compute(
    const FusionConvAddReluParam<CPU> &param) {
  switch (param.ExecMode()) {
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
      math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
                                 param.Bias(), true, true);
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
      math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
                                   param.Output(), param.Bias(), true, true);
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
                                 param.Bias(), true, true);
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
                             param.Filter(), nullptr, param.Output(), false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
      break;
#ifndef __aarch64__
......
...@@ -16,7 +16,8 @@ limitations under the License. */
#include "operators/kernel/conv_bn_add_relu_kernel.h"
#include <cmath>
#include "operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
namespace paddle_mobile {
namespace operators {
...@@ -51,13 +52,46 @@ bool ConvBNAddReluKernel<CPU, float>::Init(
  }
  param->SetNewScale(new_scale);
  param->SetNewBias(new_bias);
  InitBaseConvKernel(param);
  return true;
}
template <>
void ConvBNAddReluKernel<CPU, float>::Compute(
    const FusionConvBNAddReluParam<CPU> &param) {
  ConvBNAddReluCompute<float>(param);
  switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvBNReluBasic<FusionConvBNAddReluParam<CPU>>(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvBNAddReluKernel<CPU, float>;
......
...@@ -60,25 +60,15 @@ template <>
void ConvBNReluKernel<CPU, float>::Compute(
    const FusionConvBNReluParam<CPU> &param) {
  switch (param.ExecMode()) {
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
      math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
                                          param.Output(), param.NewScale(),
                                          param.NewBias(), true);
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
      math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
                                            param.Output(), param.NewScale(),
                                            param.NewBias(), true);
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
                                 nullptr, false, false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                      param.NewBias(), param.Output());
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
                             param.Filter(), nullptr, param.Output(), false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                      param.NewBias(), param.Output());
      break;
......
...@@ -44,29 +44,25 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
#endif  // __aarch64__
  } else {
    if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
        param->Strides()[0] == 1 && param->Paddings()[0] == 1 &&
        param->Paddings()[0] == param->Paddings()[1]) {
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT;
    } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
               param->Strides()[0] == 2 && param->Paddings()[0] == 0 &&
               param->Paddings()[0] == param->Paddings()[1]) {
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT;
        param->Strides()[0] == 1) {
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT;
    } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
               param->Strides()[0] == 2 && param->Paddings()[0] == 1 &&
               param->Paddings()[0] == param->Paddings()[1]) {
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
    } else if (depth3x3) {
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT;
               param->Strides()[0] == 2) {
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT;
#ifndef __aarch64__
    } else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
               param->Strides()[0] == 1) {
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT;
    } else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
    } else if (conv3x3 && !depth3x3 &&
               param->Strides()[0] == param->Strides()[1] &&
               param->Dilations()[0] == param->Dilations()[1] &&
               param->Strides()[0] == 1 && param->Dilations()[0] == 1 /* &&
               param->Output()->dims()[1] >= 16 &&
               param->Input()->dims()[1] >= 16 &&
               param->Input()->dims()[2] <= 140 */ /* refered from ncnn */) {
               param->Strides()[0] == 1 && param->Dilations()[0] == 1
#if 0
               && param->Output()->dims()[1] >= 16 &&
               param->Input()->dims()[1] >= 16 &&
               param->Input()->dims()[2] <= 140 */ /* refered from ncnn */
#endif
               ) {
      param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
      // transform weight
      param->transformed_filter_ = new framework::LoDTensor;
......
...@@ -18,6 +18,8 @@ limitations under the License. */
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include <iostream>
namespace paddle_mobile {
namespace operators {
...@@ -41,21 +43,13 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
      DepthwiseConv5x5<int8_t, int32_t>(param);
      break;
#endif  // __aarch64__
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
      math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
                                 nullptr, false, false);
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
      math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
                                   param.Output(), nullptr, false, false);
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
                                 nullptr, false, false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
                             param.Filter(), nullptr, param.Output(), false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      break;
#ifndef __aarch64__
    case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
......
...@@ -60,25 +60,15 @@ template <>
void DWConvBNReluKernel<CPU, float>::Compute(
    const FusionDWConvBNReluParam<CPU> &param) {
  switch (param.ExecMode()) {
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
      math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
                                          param.Output(), param.NewScale(),
                                          param.NewBias(), true);
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
      math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
                                            param.Output(), param.NewScale(),
                                            param.NewBias(), true);
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
                                 nullptr, false, false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                      param.NewBias(), param.Output());
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
                             param.Filter(), nullptr, param.Output(), false);
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                      param.NewBias(), param.Output());
      break;
......
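Every float case above follows the same two-step pattern: the stride-specialized depthwise convolution writes the raw result into Output(), and AddChannelWise or ScaleAddChannelWise then folds in the bias (and, for the BN-fused kernels, the folded batch-norm scale) together with the optional ReLU. A rough scalar model of that second step for NCHW data; this is a hypothetical illustration of the arithmetic, not the paddle-mobile routine:

#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical model of math::ScaleAddChannelWise<RELU>:
// y[c][i] = max(scale[c] * x[c][i] + bias[c], 0) over each HxW plane.
// AddChannelWise<IDENTITY> is the same with scale fixed to 1 and no ReLU.
void ScaleAddChannelWiseRelu(const std::vector<float> &x, std::size_t channels,
                             std::size_t plane_size,
                             const std::vector<float> &scale,
                             const std::vector<float> &bias,
                             std::vector<float> *y) {
  y->resize(x.size());
  for (std::size_t c = 0; c < channels; ++c) {
    for (std::size_t i = 0; i < plane_size; ++i) {
      const float v = scale[c] * x[c * plane_size + i] + bias[c];
      (*y)[c * plane_size + i] = std::max(v, 0.0f);
    }
  }
}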
...@@ -115,35 +115,6 @@ void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
  }
}
template <typename P>
void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
param.Output()->mutable_data<float>();
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, false);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), param.Bias(),
// param.Output(), false);
if (param.Paddings()[0] == 0) {
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, false);
} else {
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.Bias(), true, false);
}
} else {
ConvAddBasic(param);
}
}
}  // namespace operators
}  // namespace paddle_mobile
......
...@@ -115,31 +115,6 @@ void ConvBNAddReluBasic(const FusionConvBNAddReluParam<CPU> &param) {
    }
  }
}
template <typename P>
void ConvBNAddReluCompute(const FusionConvBNAddReluParam<CPU> &param) {
Tensor Bias;
Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else {
ConvBNAddReluBasic(param);
}
}
}  // namespace operators
}  // namespace paddle_mobile
......
...@@ -99,7 +99,6 @@ void ConvTransposeCompute(const ConvTransposeParam<CPU> &param) {
          std::vector<int>{paddings[0], paddings[1], paddings[0],
                           paddings[1]},
          &out_slice);
    } else if (data_dim == 3U) {
      col2vol(col, dilations, strides, paddings, &out_slice);
    }
......
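The last hunk rewrites operators/math/depthwise_conv3x3.cpp around the two stride-specialized kernels. One armv8-related detail: vpaddq_f32 is an AArch64-only intrinsic, so under #ifndef __aarch64__ the new file defines its own fallback built from vpadd_f32 (visible at the top of the hunk). A self-contained sketch of that emulation; my_vpaddq_f32 and the small test driver are illustrative only (the patch names its helper vpaddq_f32), and it has to be compiled for an ARM target with NEON:

#include <arm_neon.h>
#include <cstdio>

// AArch64 has vpaddq_f32 natively; on 32-bit ARM it can be built from
// vpadd_f32: result = {r0[0]+r0[1], r0[2]+r0[3], r1[0]+r1[1], r1[2]+r1[3]}.
inline float32x4_t my_vpaddq_f32(float32x4_t r0, float32x4_t r1) {
  float32x2_t sum0 = vpadd_f32(vget_low_f32(r0), vget_high_f32(r0));
  float32x2_t sum1 = vpadd_f32(vget_low_f32(r1), vget_high_f32(r1));
  return vcombine_f32(sum0, sum1);
}

int main() {
  const float a[4] = {1.f, 2.f, 3.f, 4.f};
  const float b[4] = {5.f, 6.f, 7.f, 8.f};
  float out[4];
  vst1q_f32(out, my_vpaddq_f32(vld1q_f32(a), vld1q_f32(b)));
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // 3 7 11 15
  return 0;
}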
...@@ -12,2066 +12,1042 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include "operators/math/depthwise_conv3x3.h"
#include <vector>
#if __ARM_NEON
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
namespace math {
#ifndef __aarch64__
inline float32x4_t vpaddq_f32(float32x4_t r0, float32x4_t r1) {
  float32x2_t sum0 = vpadd_f32(vget_low_f32(r0), vget_high_f32(r0));
  float32x2_t sum1 = vpadd_f32(vget_low_f32(r1), vget_high_f32(r1));
  return vcombine_f32(sum0, sum1);
}
#endif
void DepthwiseConv3x3(const framework::Tensor *input,
                      const std::vector<int> &strides,
                      const std::vector<int> &paddings,
                      const framework::Tensor *filter, framework::Tensor *bias,
                      framework::Tensor *output, bool if_bias) {
const int batch_size = input->dims()[0];
const int input_height = input->dims()[2];
const int input_width = input->dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int _kernel_size = 3;
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const float zero = 0;
const int input_channel_stride = input_height * input_width;
const int output_channel_stride = output_height * output_width;
const int filter_channel_stride = 9;
const float *input_ptr = input->data<float>();
const float *filter_ptr = filter->data<float>();
if (if_bias) {
math::expand_bias(*bias, 1, output->dims());
output->ShareDataWith(*bias);
}
float *output_ptr = output->mutable_data<float>();
const float *pos1, *pos2, *pos3, *filter1, *filter2, *filter3, *output_ptr2;
int hstart, wstart, hend, wend;
float result;
for (int i = 0; i < batch_size; ++i) {
#pragma omp parallel for
for (int c = 0; c < output_channels; ++c) {
const float *input_data =
input_ptr + (i * output_channels + c) * input_channel_stride;
float *output_data =
output_ptr + (i * output_channels + c) * output_channel_stride;
filter1 = filter_ptr + c * filter_channel_stride;
filter2 = filter1 + 3;
filter3 = filter2 + 3;
for (int ph = 0; ph < output_height; ph++) {
for (int pw = 0; pw < output_width; pw++) {
hstart = ph * stride_height - padding_height;
wstart = pw * stride_width - padding_width;
hend = std::min(hstart + _kernel_size, input_height + padding_height);
wend = std::min(wstart + _kernel_size, input_width + padding_width);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, input_height);
wend = std::min(wend, input_width);
pos1 = input_data + hstart * input_width + wstart;
pos2 = input_data + (hstart + 1) * input_width + wstart;
pos3 = input_data + (hstart + 2) * input_width + wstart;
output_ptr2 = output_data + ph * output_width + pw;
if (hend - hstart != 3 || wend - wstart != 3) {
result = 0;
float fake_input[9] = {0};
if (hstart == 0 && wstart == 0) {
            // top-left corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend && k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k -
(3 - wend)];
}
}
}
} else if (hstart == 0 && wend == input_width) {
            // top-right corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend && k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k + wstart];
}
}
}
} else if (hend == input_height && wstart == 0) {
            // bottom-left corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - 1 - hstart && k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k - (3 - wend)];
}
}
}
} else if (hend == input_height && wend == input_width) {
            // bottom-right corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - hstart - 1 &&
k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
} else if (hstart == 0) {
            // top edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k + wstart];
}
}
}
} else if (hend == input_height) {
            // bottom edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - hstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
} else if (wstart == 0) {
            // left edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width +
(k - (3 - wend))];
}
}
}
} else if (wend == input_width) {
            // right edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
}
for (int l = 0; l < 9; ++l) {
result += fake_input[l] * filter1[l];
}
if (if_bias) {
output_data[ph * output_width + pw] += result;
} else {
output_data[ph * output_width + pw] = result;
}
} else {
#if __ARM_NEON
#if __aarch64__
const float32x4_t data1 = vld1q_f32(pos1);
const float32x4_t data2 = vld1q_f32(pos2);
const float32x4_t data3 = vld1q_f32(pos3);
const float32x4_t v_filter1 = vld1q_f32(filter1);
const float32x4_t v_filter2 = vld1q_f32(filter2);
const float32x4_t v_filter3 = vld1q_f32(filter3);
float32x4_t mula = vmulq_f32(data1, v_filter1);
mula = vmlaq_f32(mula, data2, v_filter2);
mula = vmlaq_f32(mula, data3, v_filter3);
float32x2_t res = vpadd_f32(
vget_high_f32(vsetq_lane_f32(0, mula, 3)), vget_low_f32(mula));
res = vpadd_f32(res, res);
if (if_bias) {
output_data[ph * output_width + pw] += vget_lane_f32(res, 0);
} else {
output_data[ph * output_width + pw] = vget_lane_f32(res, 0);
}
#else
asm volatile(
"vld1.32 {q1}, [%[pos1]] \n\t"
"vld1.32 {q4}, [%[filter1]] \n\t"
"vmov.f32 q0, #0.0 \n\t"
"vld1.32 {q2}, [%[pos2]] \n\t"
"vld1.32 {q5}, [%[filter2]] \n\t"
"vmla.f32 q0, q1, q4 \n\t"
"vld1.32 {q3}, [%[pos3]] \n\t"
"vld1.32 {q6}, [%[filter3]] \n\t"
"vmla.f32 q0, q2, q5 \n\t"
"vmla.f32 q0, q3, q6 \n\t"
"vmov.f32 d1[1], %[zero] \n\t"
"vadd.f32 d4, d0, d1 \n\t"
"vadd.f32 s10, s8, s9 \n\t"
"vst1.32 {d5[0]},[%[output_ptr]] \n\t"
:
: [input_data] "r"(input_data), [pos1] "r"(pos1),
[pos2] "r"(pos2), [pos3] "r"(pos3), [filter1] "r"(filter1),
[filter2] "r"(filter2), [filter3] "r"(filter3),
[output_ptr] "r"(output_ptr2), [zero] "r"(zero)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
#endif // __aarch64__
#else
#endif // __ARM_NEON
}
}
}
}
}
}
void DepthwiseConv3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu) {
#if __ARM_NEON
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]);
const int hxw = h * w;
// const int l = h;
// leftTop, rightTop, leftBottom, rightBottom
const int lt = 0;
const int rt = w - 1;
const int lb = (h - 1) * w;
const int rb = h * w - 1;
const float *bias_data;
if (if_bias) {
bias_data = bias->data<float>();
}
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = 0; b < batch_size; ++b) {
#pragma omp parallel for
for (int j = 0; j < c; ++j) {
const float *filter_data_tmp = filter->data<float>() + j * 9;
const float *input_data = input->data<float>() + j * hxw;
float *output_data = output->mutable_data<float>() + j * hxw;
float32x4_t vbias;
if (if_bias) {
vbias = vdupq_n_f32(bias_data[j]);
}
int w_mid = w - 2; // l=1->l_mid=-1,l=2->l_mid=0
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
output_data[lt] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[w] + w22 * input_data[w + 1];
output_data[rt] = w10 * input_data[w - 2] + w11 * input_data[w - 1] +
w20 * input_data[2 * w - 2] +
w21 * input_data[2 * w - 1];
output_data[lb] =
w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] +
w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1];
output_data[rb] =
w00 * input_data[h * w - w - 2] + w01 * input_data[h * w - w - 1] +
w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1];
if (if_bias) {
output_data[lt] += bias_data[j];
output_data[rt] += bias_data[j];
output_data[lb] += bias_data[j];
output_data[rb] += bias_data[j];
}
if (if_relu) {
output_data[lt] = output_data[lt] < 0 ? 0 : output_data[lt];
output_data[rt] = output_data[rt] < 0 ? 0 : output_data[rt];
output_data[lb] = output_data[lb] < 0 ? 0 : output_data[lb];
output_data[rb] = output_data[rb] < 0 ? 0 : output_data[rb];
}
for (int i = 1; i < h - 1; ++i) {
int left = i * w;
int right = i * w + w - 1;
output_data[left] =
w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] +
w11 * input_data[i * w] + w12 * input_data[i * w + 1] +
w21 * input_data[i * w + w] + w22 * input_data[i * w + w + 1];
output_data[right] = w00 * input_data[i * w + w - 1 - w - 1] +
w01 * input_data[i * w + w - 1 - w] +
w10 * input_data[i * w + w - 1 - 1] +
w11 * input_data[i * w + w - 1] +
w20 * input_data[i * w + w - 1 + w - 1] +
w21 * input_data[i * w + w - 1 + w];
if (if_bias) {
output_data[left] += bias_data[j];
output_data[right] += bias_data[j];
}
if (if_relu) {
output_data[left] = output_data[left] < 0 ? 0 : output_data[left];
output_data[right] = output_data[right] < 0 ? 0 : output_data[right];
}
}
// top 1 row and bottom 1 row
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, out0;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + w);
const float *input_tmp_end = input_tmp + (h - 2) * w;
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + w);
int c_mid = w_mid;
auto output_ptr = output_data + 1;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + w + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + w + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
vst1q_f32(output_ptr + (h - 1) * w, out0);
// can optimize to each 8 stride.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
// top right pad
float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
// bottom right pad
float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]);
float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
tmp2 = vextq_f32(in6, pad3, 1);
tmp3 = vextq_f32(in6, pad3, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2);
}
}
// mid
for (int i = 0; i < h - 2; ++i) {
auto output_ptr = output_data + (i + 1) * w + 1;
input_tmp = input_data + i * w;
auto in0_tmp = vld1q_f32(input_tmp);
auto in2_tmp = vld1q_f32(input_tmp + w);
auto in4_tmp = vld1q_f32(input_tmp + w + w);
c_mid = w_mid;
for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + w + 4);
auto in5_tmp = vld1q_f32(input_tmp + w + w + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
vst1q_f32(output_ptr, out0);
output_ptr += 4;
input_tmp += 4;
in0_tmp = in1_tmp;
in2_tmp = in3_tmp;
in4_tmp = in5_tmp;
}
float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]);
float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]);
tmp0 = vextq_f32(in0_tmp, pad0, 1);
tmp1 = vextq_f32(in0_tmp, pad0, 2);
tmp2 = vextq_f32(in2_tmp, pad1, 1);
tmp3 = vextq_f32(in2_tmp, pad1, 2);
tmp4 = vextq_f32(in4_tmp, pad2, 1);
tmp5 = vextq_f32(in4_tmp, pad2, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
}
}
}
#endif
}
void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu) {
#if __ARM_NEON
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->mutable_data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const int batch_size = static_cast<int>(input->dims()[0]);
const int input_channel = static_cast<int>(input->dims()[1]);
const int input_height = static_cast<int>(input->dims()[2]);
const int input_width = static_cast<int>(input->dims()[3]);
const int output_height = static_cast<int>(output->dims()[2]);
const int output_width = static_cast<int>(output->dims()[3]);
const int hxw = input_height * input_width;
// const int l = input_height;
const int h = input_height;
const int w = input_width;
float32x4_t vzero = vdupq_n_f32(0);
for (int b = 0; b < batch_size; b++) {
#pragma omp parallel for
for (int c = 0; c < input_channel; c++) {
const float *filter_data = filter->data<float>() + c * 9;
const float *input_data = input->data<float>() + c * hxw;
float *output_data = output->data<float>() + c * hxw;
float32x4_t vnewbias = vdupq_n_f32(newbias_data[c]);
float32x4_t vnewscale = vdupq_n_f32(newscale_data[c]);
float w00 = filter_data[0];
float w01 = filter_data[1];
float w02 = filter_data[2];
float w10 = filter_data[3];
float w11 = filter_data[4];
float w12 = filter_data[5];
float w20 = filter_data[6];
float w21 = filter_data[7];
float w22 = filter_data[8];
for (int i = 1; i < output_height - 1; i++) {
float *output_ptr;
float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3, tmp4,
tmp5, out0;
for (int m = 1; m < output_width - 4; m += 4) {
output_ptr = output_data + i * output_width + m;
in0 = vld1q_f32(input_data + (i - 1) * input_width + m - 1);
in1 = vld1q_f32(input_data + (i - 1) * input_width + m + 3);
in2 = vld1q_f32(input_data + i * input_width + m - 1);
in3 = vld1q_f32(input_data + i * input_width + m + 3);
in4 = vld1q_f32(input_data + (i + 1) * input_width + m - 1);
in5 = vld1q_f32(input_data + (i + 1) * input_width + m + 3);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
tmp4 = vextq_f32(in4, in5, 1);
tmp5 = vextq_f32(in4, in5, 2);
out0 = vmulq_n_f32(in0, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
}
int m;
for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
}
for (int j = m; j < output_width - 1; j++) {
output_data[i * output_width + j] =
input_data[(i - 1) * input_width + j - 1] * w00 +
input_data[(i - 1) * input_width + j] * w01 +
input_data[(i - 1) * input_width + j + 1] * w02 +
input_data[(i)*input_width + j - 1] * w10 +
input_data[(i)*input_width + j] * w11 +
input_data[(i)*input_width + j + 1] * w12 +
input_data[(i + 1) * input_width + j - 1] * w20 +
input_data[(i + 1) * input_width + j] * w21 +
input_data[(i + 1) * input_width + j + 1] * w22;
output_data[i * output_width + j] =
newscale_data[c] * output_data[i * output_width + j] +
newbias_data[c];
if (if_relu) {
output_data[i * output_width + j] =
output_data[i * output_width + j] < 0
? 0
: output_data[i * output_width + j];
}
}
}
output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[w] + w22 * input_data[w + 1];
output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w - 1] +
w20 * input_data[2 * w - 2] +
w21 * input_data[2 * w - 1];
output_data[(h - 1) * w] =
w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] +
w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1];
output_data[h * w - 1] =
w00 * input_data[h * w - w - 2] + w01 * input_data[h * w - w - 1] +
w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1];
output_data[0] = output_data[0] * newscale_data[c] + newbias_data[c];
output_data[w - 1] =
output_data[w - 1] * newscale_data[c] + newbias_data[c];
output_data[(h - 1) * w] =
output_data[(h - 1) * w] * newscale_data[c] + newbias_data[c];
output_data[h * w - 1] =
output_data[h * w - 1] * newscale_data[c] + newbias_data[c];
if (if_relu) {
output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
output_data[w - 1] = output_data[w - 1] < 0 ? 0 : output_data[w - 1];
output_data[(h - 1) * w] =
output_data[(h - 1) * w] < 0 ? 0 : output_data[(h - 1) * w];
output_data[h * w - 1] =
output_data[h * w - 1] < 0 ? 0 : output_data[h * w - 1];
}
for (int i = 1; i < h - 1; ++i) {
output_data[i * w] =
w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] +
w11 * input_data[i * w] + w12 * input_data[i * w + 1] +
w21 * input_data[i * w + w] + w22 * input_data[i * w + w + 1];
output_data[i * w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] +
w01 * input_data[i * w + w - 1 - w] +
w10 * input_data[i * w + w - 1 - 1] +
w11 * input_data[i * w + w - 1] +
w20 * input_data[i * w + w - 1 + w - 1] +
w21 * input_data[i * w + w - 1 + w];
output_data[i * w] =
output_data[i * w] * newscale_data[c] + newbias_data[c];
output_data[i * w + w - 1] =
output_data[i * w + w - 1] * newscale_data[c] + newbias_data[c];
if (if_relu) {
output_data[i * w] = output_data[i * w] < 0 ? 0 : output_data[i * w];
output_data[i * w + w - 1] =
output_data[i * w + w - 1] < 0 ? 0 : output_data[i * w + w - 1];
}
}
int m;
for (m = 1; m < output_width - 4; m += 4) {
float *output_ptr = output_data + m;
float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
in0 = vld1q_f32(input_data + m - 1);
in1 = vld1q_f32(input_data + m + 3);
in2 = vld1q_f32(input_data + input_width + m - 1);
in3 = vld1q_f32(input_data + input_width + m + 3);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
}
for (m = 1; (m + 3) < output_width - 1; m += 4) {
}
for (int j = m; j < output_width - 1; j++) {
output_data[j] = input_data[j - 1] * w10 + input_data[j] * w11 +
input_data[j + 1] * w12 +
input_data[input_width + j - 1] * w20 +
input_data[input_width + j] * w21 +
input_data[input_width + j + 1] * w22;
output_data[j] = output_data[j] * newscale_data[c] + newbias_data[c];
if (if_relu) {
output_data[j] = output_data[j] < 0 ? 0 : output_data[j];
}
}
for (m = 1; m < output_width - 4; m += 4) {
float *output_ptr =
output_data + (output_height - 1) * output_width + m;
float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
in0 = vld1q_f32(input_data + (output_height - 2) * input_width + m - 1);
in1 = vld1q_f32(input_data + (output_height - 2) * input_width + m + 3);
in2 = vld1q_f32(input_data + (output_height - 1) * input_width + m - 1);
in3 = vld1q_f32(input_data + (output_height - 1) * input_width + m + 3);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
}
for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
}
for (int j = m; j < output_width - 1; j++) {
output_data[(output_height - 1) * input_width + j] =
input_data[(output_height - 2) * input_width + j - 1] * w00 +
input_data[(output_height - 2) * input_width + j] * w01 +
input_data[(output_height - 2) * input_width + j + 1] * w02 +
input_data[(output_height - 1) * input_width + j - 1] * w10 +
input_data[(output_height - 1) * input_width + j] * w11 +
input_data[(output_height - 1) * input_width + j + 1] * w12;
output_data[(output_height - 1) * output_width + j] =
output_data[(output_height - 1) * output_width + j] *
newscale_data[c] +
newbias_data[c];
if (if_relu) {
output_data[(output_height - 1) * output_width + j] =
output_data[(output_height - 1) * output_width + j] < 0
? 0
: output_data[(output_height - 1) * output_width + j];
}
}
}
}
/*
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]);
// const int l = h;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int hxw = h * w;
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
float32x4_t vzero = vdupq_n_f32(0);
for (int b = 0; b < batch_size; ++b) {
const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) {
vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]);
int w_mid = w - 2; // l=1->l_mid=-1,l=2->l_mid=0
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[w] + w22 * input_data[w + 1];
output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w -
1] + w20 * input_data[2 * w - 2] + w21 * input_data[2 * w - 1];
output_data[(h - 1) * w] =
w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w +
1] + w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1];
output_data[h * w - 1] = w00 * input_data[h*w-w-2] +
w01 * input_data[h*w-w-1] +
w10 * input_data[h * w - 2] +
w11 * input_data[h * w - 1];
output_data[0] = output_data[0] * newscale_data[j] +
newbias_data[j]; output_data[w - 1] = output_data[w - 1] *
newscale_data[j] + newbias_data[j]; output_data[(h - 1) * w] =
output_data[(h - 1) * w] * newscale_data[j] + newbias_data[j];
output_data[h * w - 1] =
output_data[h * w - 1] * newscale_data[j] + newbias_data[j];
if (if_relu) {
output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
output_data[w - 1] = output_data[w - 1] < 0 ? 0 : output_data[w -
1]; output_data[(h - 1) * w] = output_data[(h - 1) * w] < 0 ? 0 :
output_data[(h - 1) * w]; output_data[h * w - 1] = output_data[h * w - 1]
< 0 ? 0 : output_data[h * w - 1];
}
for (int i = 1; i < h - 1; ++i) {
output_data[i * w] =
w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1]
+ w11 * input_data[i * w] + w12 * input_data[i * w + 1] + w21 *
input_data[i * w + w] + w22 * input_data[i * w + w + 1]; output_data[i *
w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] + w01 * input_data[i
* w + w - 1 - w] + w10 * input_data[i * w + w - 1 - 1] + w11 *
input_data[i * w + w - 1] + w20 * input_data[i * w + w - 1 + w - 1] + w21
* input_data[i * w + w - 1 + w]; output_data[i * w] = output_data[i * w]
* newscale_data[j] + newbias_data[j]; output_data[i * w + w - 1] =
output_data[i * w + w - 1] * newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data[i * w] = output_data[i * w] < 0 ? 0 : output_data[i
* w]; output_data[i * w + w - 1] = output_data[i * w + w - 1] < 0 ? 0 :
output_data[i * w + w - 1];
}
}
template <int Stride = 1>
inline void Depth3x3NormalRowLoadInput(const float *input, float32x4_t *y) {
  y[0] = vld1q_f32(input);
  y[2] = vld1q_f32(input + 4);
  y[1] = vextq_f32(y[0], y[2], 1);
  y[2] = vextq_f32(y[0], y[2], 2);
}
      // top 1 row and bottom 1 row
      const float *input_tmp = input_data;
      float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1,
  tmp2, tmp3, tmp4, tmp5, out0; in0 = vld1q_f32(input_tmp); in2 =
  vld1q_f32(input_tmp + w); const float *input_tmp_end = input_tmp + (h -
  2) * w; in4 = vld1q_f32(input_tmp_end); in6 = vld1q_f32(input_tmp_end +
  w); int c_mid = w_mid; auto output_ptr = output_data + 1; for (; c_mid >
  3; c_mid -= 4) { in1 = vld1q_f32(input_tmp + 4); in3 =
  vld1q_f32(input_tmp + w + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + w + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr + (h - 1) * w, out0);
// can optimize to each 8 stride.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
// top right pad
float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
// bottom right pad
float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]);
float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
tmp2 = vextq_f32(in6, pad3, 1);
tmp3 = vextq_f32(in6, pad3, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2);
}
}
// mid
for (int i = 0; i < h - 2; ++i) {
auto output_ptr = output_data + (i + 1) * w + 1;
input_tmp = input_data + i * w;
auto in0_tmp = vld1q_f32(input_tmp);
auto in2_tmp = vld1q_f32(input_tmp + w);
auto in4_tmp = vld1q_f32(input_tmp + w + w);
c_mid = w_mid;
for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + w + 4);
auto in5_tmp = vld1q_f32(input_tmp + w + w + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
output_ptr += 4;
input_tmp += 4;
in0_tmp = in1_tmp;
in2_tmp = in3_tmp;
in4_tmp = in5_tmp;
}
float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]);
float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]);
tmp0 = vextq_f32(in0_tmp, pad0, 1);
tmp1 = vextq_f32(in0_tmp, pad0, 2);
tmp2 = vextq_f32(in2_tmp, pad1, 1);
tmp3 = vextq_f32(in2_tmp, pad1, 2);
tmp4 = vextq_f32(in4_tmp, pad2, 1);
tmp5 = vextq_f32(in4_tmp, pad2, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
}
output_data += hxw;
input_data += hxw;
filter_data_tmp += 9;
}
}
*/
#endif
}
template <>
inline void Depth3x3NormalRowLoadInput<2>(const float *input, float32x4_t *y) {
  float32x4x2_t x = vld2q_f32(input);
  y[0] = x.val[0];
  y[1] = x.val[1];
  y[2] = vextq_f32(y[0], y[0], 1);
  y[2] = vsetq_lane_f32(input[8], y[2], 3);
}
/// w!=h not fix
void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
                                   const framework::Tensor *filter,
                                   framework::Tensor *output,
                                   const framework::Tensor *new_scale,
                                   const framework::Tensor *new_bias,
                                   bool if_relu) {
#if __ARM_NEON
const int batch_size = input->dims()[0];
const int input_height = input->dims()[2];
const int input_width = input->dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int _kernel_size = 3;
const int stride_height = 2;
const int stride_width = 2;
const int padding_height = 1;
const int padding_width = 1;
const float zero = 0;
const int input_channel_stride = input_height * input_width;
const int output_channel_stride = output_height * output_width;
const int filter_channel_stride = 9;
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->mutable_data<float>();
const int input_batch_stride = output_channels * input_channel_stride;
const int output_batch_stride = output_channels * output_channel_stride;
const int filter_batch_stride = output_channels * output_channel_stride;
const float *pos1, *pos2, *pos3, *filter1, *filter2, *filter3, *output_ptr;
int hstart, wstart, hend, wend;
float result;
for (int i = 0; i < batch_size; ++i) {
for (int c = 0; c < output_channels; ++c) {
filter1 = filter_data;
filter2 = filter1 + 3;
filter3 = filter2 + 3;
for (int ph = 0; ph < output_height; ph++) {
for (int pw = 0; pw < output_width; pw++) {
hstart = ph * stride_height - padding_height;
wstart = pw * stride_width - padding_width;
hend = std::min(hstart + _kernel_size, input_height + padding_height);
wend = std::min(wstart + _kernel_size, input_width + padding_width);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, input_height);
wend = std::min(wend, input_width);
pos1 = input_data + hstart * input_width + wstart;
pos2 = input_data + (hstart + 1) * input_width + wstart;
pos3 = input_data + (hstart + 2) * input_width + wstart;
output_ptr = output_data + ph * output_width + pw;
if (hend - hstart != 3 || wend - wstart != 3) {
result = 0;
float fake_input[9] = {0};
if (hstart == 0 && wstart == 0) {
            // top-left corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend && k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k -
(3 - wend)];
}
}
}
} else if (hstart == 0 && wend == input_width) {
            // top-right corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend && k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k + wstart];
}
}
}
} else if (hend == input_height && wstart == 0) {
            // bottom-left corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - 1 - hstart && k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k - (3 - wend)];
}
}
}
} else if (hend == input_height && wend == input_width) {
            // bottom-right corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - hstart - 1 &&
k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
} else if (hstart == 0) {
            // top edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k + wstart];
}
}
}
} else if (hend == input_height) {
            // bottom edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - hstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
} else if (wstart == 0) {
            // left edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width +
(k - (3 - wend))];
}
}
}
} else if (wend == input_width) {
            // right edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
}
for (int l = 0; l < 9; ++l) {
result += fake_input[l] * filter1[l];
}
output_data[ph * output_width + pw] =
newscale_data[c] * result + newbias_data[c];
if (if_relu) {
output_data[ph * output_width + pw] =
output_data[ph * output_width + pw] < 0
? 0
: output_data[ph * output_width + pw];
}
} else {
const float32x4_t data1 = vld1q_f32(pos1);
const float32x4_t data2 = vld1q_f32(pos2);
const float32x4_t data3 = vld1q_f32(pos3);
const float32x4_t v_filter1 = vld1q_f32(filter1);
const float32x4_t v_filter2 = vld1q_f32(filter2);
const float32x4_t v_filter3 = vld1q_f32(filter3);
float32x4_t mula = vmulq_f32(data1, v_filter1);
mula = vmlaq_f32(mula, data2, v_filter2);
mula = vmlaq_f32(mula, data3, v_filter3);
float32x2_t res = vpadd_f32(
vget_high_f32(vsetq_lane_f32(0, mula, 3)), vget_low_f32(mula));
res = vpadd_f32(res, res);
output_data[ph * output_width + pw] =
vget_lane_f32(res, 0) * newscale_data[c] + newbias_data[c];
if (if_relu) {
output_data[ph * output_width + pw] =
output_data[ph * output_width + pw] < 0
? 0
: output_data[ph * output_width + pw];
}
}
}
}
input_data += input_channel_stride;
output_data += output_channel_stride;
filter_data += filter_channel_stride;
}
input_data += input_batch_stride;
output_data += output_batch_stride;
}
#endif
} }
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, #define DEPTHWISE_CONV3X3_NORMAL_BORDER(start, end) \
const framework::Tensor *filter, for (int w = start; w < end; ++w) { \
framework::Tensor *output, framework::Tensor *bias, const int w_in_start = -padding_w + w * Stride_w; \
bool if_bias, bool if_relu) { const int w_in_end = w_in_start + 3; \
#if __ARM_NEON const int w_start = w_in_start > 0 ? w_in_start : 0; \
const float *input_data = input->data<float>(); const int w_end = w_in_end < input_w ? w_in_end : input_w; \
const float *filter_data = filter->data<float>(); float value = 0; \
float *output_data = output->mutable_data<float>(); for (int h_in = h_start; h_in < h_end; ++h_in) { \
const float *bias_data; for (int w_in = w_start; w_in < w_end; ++w_in) { \
if (if_bias) { value += filter[(h_in - h_in_start) * 3 + (w_in - w_in_start)] * \
bias_data = bias->data<float>(); input[h_in * input_w + w_in]; \
} } \
} \
const int in_h = static_cast<int>(input->dims()[2]); output_ptr[w] = value; \
const int in_w = static_cast<int>(input->dims()[3]); }
const int out_h = static_cast<int>(output->dims()[2]);
const int out_w = static_cast<int>(output->dims()[3]); template <int Stride_h, int Stride_w>
const int out_l = out_h; inline void DepthwiseConv3x3NormalRow(const float *input, const float *filter,
const int in_l = in_h; const int h_output, const int input_h,
const int inhxw = in_h * in_w; const int input_w, const int padding_h,
const int outhxw = out_h * out_w; const int padding_w, const int output_w,
/// todo : fix if_pad when w != h float *output, float32x4_t *ker) {
const int if_pad_r = in_w - 1 == (out_w - 1) * 2 ? 1 : 0; const int h_in_start = -padding_h + h_output * Stride_h;
const int if_pad_b = in_h - 1 == (out_h - 1) * 2 ? 1 : 0; const int h_in_end = h_in_start + 3;
const int batch_size = static_cast<int>(input->dims()[0]); const int h_start = h_in_start > 0 ? h_in_start : 0;
const int c = static_cast<int>(input->dims()[1]); const int h_end = h_in_end < input_h ? h_in_end : input_h;
const float *input_row_ptr;
float *output_row_ptr; const int valid_w_start = (padding_w + Stride_w - 1) / Stride_w;
const int valid_w_end = (input_w + padding_w - 3) / Stride_w + 1;
const int w_times = (out_w - 2) / 3; // const int valid_w_end = output_w - valid_w_start;
float *output_ptr = output + h_output * output_w;
float32x4_t vbias = vdupq_n_f32(0.0); // border left
DEPTHWISE_CONV3X3_NORMAL_BORDER(0, valid_w_start)
float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1]; // middle
float32x4_t elewise_res0, elewise_res1, elewise_res2, res3; int output_tiles = (valid_w_end - valid_w_start) >> 2;
int out2in_mid; float32x4_t _sum, _x[3];
float32x4_t zero = vdupq_n_f32(0.0); // valid w
for (int b = batch_size; b > 0; --b) { for (int w = 0; w < output_tiles * 4; w += 4) {
const float *filter_data_tmp = filter_data; _sum = vdupq_n_f32(0.f);
for (int j = 0; j < c; ++j) { int output_offset = valid_w_start + w;
auto output_data_tmp = output_data + j * out_h * out_w; int input_w_offset = output_offset * Stride_w - padding_w;
auto input_data_tmp = input_data + j * in_h * in_w; for (int h_in = h_start; h_in < h_end; ++h_in) {
auto input_const = input_data_tmp; int index = h_in - h_in_start;
Depth3x3NormalRowLoadInput<Stride_w>(
if (if_bias) { input + h_in * input_w + input_w_offset, _x);
vbias = vdupq_n_f32(bias_data[j]); _sum = vmlaq_lane_f32(_sum, _x[0], vget_low_f32(ker[index]), 0);
} _sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 1);
_sum = vmlaq_lane_f32(_sum, _x[2], vget_high_f32(ker[index]), 0);
}
vst1q_f32(output_ptr + output_offset, _sum);
}
// remain valid w
int remain = (valid_w_end - valid_w_start) & 0x3;
if (remain > 0) {
_sum = vdupq_n_f32(0.f);
int remain_start = valid_w_start + (output_tiles << 2);
int input_w_offset = remain_start * Stride_w - padding_w;
float *output_ptr0 = output_ptr + remain_start;
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
Depth3x3NormalRowLoadInput<Stride_w>(
input + h_in * input_w + input_w_offset, _x);
_sum = vmlaq_lane_f32(_sum, _x[0], vget_low_f32(ker[index]), 0);
_sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 1);
_sum = vmlaq_lane_f32(_sum, _x[2], vget_high_f32(ker[index]), 0);
}
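    // Store only the `remain` (1-3) live lanes of _sum; case 3 intentionally
    // falls through to case 2 so the low pair is written as well.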
switch (remain) {
case 3:
vst1q_lane_f32(output_ptr0 + 2, _sum, 2);
case 2:
vst1_f32(output_ptr0, vget_low_f32(_sum));
break;
case 1:
vst1_lane_f32(output_ptr0, vget_low_f32(_sum), 0);
break;
}
}
// border right
DEPTHWISE_CONV3X3_NORMAL_BORDER(valid_w_end, output_w)
}
float w00 = filter_data_tmp[0]; template <>
float w01 = filter_data_tmp[1]; void DepthwiseConv3x3S1<float, float>(const framework::Tensor &input,
float w02 = filter_data_tmp[2]; const framework::Tensor &filter,
float w10 = filter_data_tmp[3]; const std::vector<int> &paddings,
float w11 = filter_data_tmp[4]; framework::Tensor *output) {
float w12 = filter_data_tmp[5]; const float *input_data = input.data<float>();
float w20 = filter_data_tmp[6]; const float *filter_data = filter.data<float>();
float w21 = filter_data_tmp[7]; float *out_data = output->mutable_data<float>();
float w22 = filter_data_tmp[8]; int input_h = input.dims()[2];
int input_w = input.dims()[3];
int h_mid = 0; int output_h = output->dims()[2];
int output_w = output->dims()[3];
for (; h_mid < out_h - 1; h_mid++) { int padding_h = paddings[0];
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; int padding_w = paddings[1];
output_row_ptr = output_data_tmp + 1 + h_mid * out_w; int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
for (int w4 = 0; w4 < w_times + 1; w4++) { int valid_h_start = padding_h;
if (h_mid == 0) { int valid_h_end = output_h - valid_h_start;
elewise_res1 = zero; int valid_h = valid_h_end - valid_h_start;
elewise_res0 = zero; int valid_w_start = padding_w;
elewise_res2 = zero; int valid_w_end = output_w - valid_w_start;
int valid_w = valid_w_end - valid_w_start;
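  // Each depthwise group (channel) is independent, so channels are processed in parallel.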
#pragma omp parallel for
for (int g = 0; g < input.dims()[1]; ++g) {
const float *input_ptr = input_data + g * image_size;
const float *filter_ptr = filter_data + g * 9;
float *output_ptr = out_data + g * out_image_size;
const float *filter_ptr0 = filter_ptr;
const float *filter_ptr1 = filter_ptr0 + 3;
const float *filter_ptr2 = filter_ptr1 + 3;
float32x4_t _ker[3];
_ker[0] = vld1q_f32(filter_ptr0);
_ker[1] = vld1q_f32(filter_ptr1);
_ker[2] = vld1q_f32(filter_ptr2);
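    // Each 3-tap filter row is loaded into a 4-lane vector; only lanes 0-2
    // ever contribute to a stored result.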
// pad top
for (int h = 0; h < valid_h_start; ++h) {
DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker);
}
// output 2x6
int output_w_tiles = valid_w / 6;
int output_w_remain = valid_w - output_w_tiles * 6;
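    // Main body: each iteration reads 8 consecutive columns from 4 input rows
    // and writes a 2x6 output tile, then advances by 6 columns.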
for (int h = valid_h_start; h < valid_h_end - 1; h += 2) {
const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
const float *input_ptr3 = input_ptr2 + input_w;
float *output_ptr0 = output_ptr + h * output_w;
float *output_ptr1 = output_ptr0 + output_w;
// pad left
if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
float32x4_t row3 = vld1q_f32(input_ptr3);
float32x4_t zero = vdupq_n_f32(0.f);
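        // Shift each row right by one lane (zero into lane 0) to model the
        // zero-padded column left of the image; one more shift per iteration
        // of the loop below handles wider left padding.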
row0 = vextq_f32(zero, row0, 3);
row1 = vextq_f32(zero, row1, 3);
row2 = vextq_f32(zero, row2, 3);
row3 = vextq_f32(zero, row3, 3);
float32x4_t acc0, acc1;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 3) {
output_ptr0[w] = 0.f;
output_ptr1[w] = 0.f;
} else { } else {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); acc0 = vmulq_f32(row0, _ker[0]);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); acc0 = vmlaq_f32(acc0, row1, _ker[1]);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); acc0 = vmlaq_f32(acc0, row2, _ker[2]);
} acc0 = vextq_f32(acc0, acc0, 1);
input_buff_mid = vld2q_f32(input_row_ptr); acc1 = vmulq_f32(row1, _ker[0]);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); acc1 = vmlaq_f32(acc1, row2, _ker[1]);
acc1 = vmlaq_f32(acc1, row3, _ker[2]);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); acc1 = vextq_f32(acc1, acc1, 1);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); float32x2_t sum = vpadd_f32(vget_low_f32(acc0), vget_low_f32(acc1));
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); vst1_lane_f32(output_ptr0 + w, sum, 0);
vst1_lane_f32(output_ptr1 + w, sum, 1);
elewise_res1 =
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); row0 = vextq_f32(zero, row0, 3);
elewise_res0 = row1 = vextq_f32(zero, row1, 3);
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); row2 = vextq_f32(zero, row2, 3);
elewise_res2 = row3 = vextq_f32(zero, row3, 3);
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); }
}
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), output_ptr0 += valid_w_start;
vaddq_f32(elewise_res0, elewise_res1)); output_ptr1 += valid_w_start;
res3 = vaddq_f32(res3, vbias); }
if (if_relu) { // valid
res3 = vmaxq_f32(res3, zero); float32x4_t _result0, _result1, _result2, _result3;
} for (int loop = 0; loop < output_w_tiles; ++loop) {
vst1q_f32(output_row_ptr, res3); float32x4_t _row00 = vld1q_f32(input_ptr0);
float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
input_row_ptr += 6; float32x4_t _row10 = vld1q_f32(input_ptr1);
output_row_ptr += 3; float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);
}
} float32x4_t _ext01 = vextq_f32(_row00, _row01, 1);
clock(); float32x4_t _ext02 = vextq_f32(_row00, _row01, 2);
float32x4_t _ext03 = vextq_f32(_row01, _row01, 1);
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; float32x4_t _ext04 = vextq_f32(_row01, _row01, 2);
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
_result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0);
for (int w4 = 0; w4 < w_times + 1; w4++) { _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1);
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); _result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0);
input_buff_mid = vld2q_f32(input_row_ptr);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); _ext01 = vextq_f32(_row10, _row11, 1);
_ext02 = vextq_f32(_row10, _row11, 2);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); _ext03 = vextq_f32(_row11, _row11, 1);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); _ext04 = vextq_f32(_row11, _row11, 2);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
_result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0);
if (!if_pad_b) { _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1);
elewise_res1 = _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0);
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); _result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0);
elewise_res0 = _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1);
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); _result2 = vmulq_lane_f32(_row10, vget_low_f32(_ker[0]), 0);
} _result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[0]), 1);
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), _result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[0]), 0);
vaddq_f32(elewise_res0, elewise_res1)); _result3 = vmulq_lane_f32(_row11, vget_low_f32(_ker[0]), 0);
res3 = vaddq_f32(res3, vbias); _result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[0]), 1);
if (if_relu) { _result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[0]), 0);
res3 = vmaxq_f32(res3, zero);
} _row00 = vld1q_f32(input_ptr2);
_row01 = vld1q_f32(input_ptr2 + 4);
if ((w4 != w_times)) { _row10 = vld1q_f32(input_ptr3);
vst1q_f32(output_row_ptr, res3); _row11 = vld1q_f32(input_ptr3 + 4);
_ext01 = vextq_f32(_row00, _row01, 1);
_ext02 = vextq_f32(_row00, _row01, 2);
_ext03 = vextq_f32(_row01, _row01, 1);
_ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0);
_result2 = vmlaq_lane_f32(_result2, _row00, vget_low_f32(_ker[1]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[1]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[1]), 0);
_result3 = vmlaq_lane_f32(_result3, _row01, vget_low_f32(_ker[1]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[1]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[1]), 0);
_ext01 = vextq_f32(_row10, _row11, 1);
_ext02 = vextq_f32(_row10, _row11, 2);
_ext03 = vextq_f32(_row11, _row11, 1);
_ext04 = vextq_f32(_row11, _row11, 2);
_result2 = vmlaq_lane_f32(_result2, _row10, vget_low_f32(_ker[2]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[2]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[2]), 0);
_result3 = vmlaq_lane_f32(_result3, _row11, vget_low_f32(_ker[2]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[2]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[2]), 0);
vst1q_f32(output_ptr0, _result0);
vst1_f32(output_ptr0 + 4, vget_low_f32(_result1));
vst1q_f32(output_ptr1, _result2);
vst1_f32(output_ptr1 + 4, vget_low_f32(_result3));
input_ptr0 += 6;
input_ptr1 += 6;
input_ptr2 += 6;
input_ptr3 += 6;
output_ptr0 += 6;
output_ptr1 += 6;
}
// remain w
if (output_w_remain > 0) {
float32x4_t _row00 = vld1q_f32(input_ptr0);
float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
float32x4_t _row10 = vld1q_f32(input_ptr1);
float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);
float32x4_t _ext01 = vextq_f32(_row00, _row01, 1);
float32x4_t _ext02 = vextq_f32(_row00, _row01, 2);
float32x4_t _ext03 = vextq_f32(_row01, _row01, 1);
float32x4_t _ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0);
_result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0);
_ext01 = vextq_f32(_row10, _row11, 1);
_ext02 = vextq_f32(_row10, _row11, 2);
_ext03 = vextq_f32(_row11, _row11, 1);
_ext04 = vextq_f32(_row11, _row11, 2);
_result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0);
_result2 = vmulq_lane_f32(_row10, vget_low_f32(_ker[0]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[0]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[0]), 0);
_result3 = vmulq_lane_f32(_row11, vget_low_f32(_ker[0]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[0]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[0]), 0);
_row00 = vld1q_f32(input_ptr2);
_row01 = vld1q_f32(input_ptr2 + 4);
_row10 = vld1q_f32(input_ptr3);
_row11 = vld1q_f32(input_ptr3 + 4);
_ext01 = vextq_f32(_row00, _row01, 1);
_ext02 = vextq_f32(_row00, _row01, 2);
_ext03 = vextq_f32(_row01, _row01, 1);
_ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0);
_result2 = vmlaq_lane_f32(_result2, _row00, vget_low_f32(_ker[1]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[1]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[1]), 0);
_result3 = vmlaq_lane_f32(_result3, _row01, vget_low_f32(_ker[1]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[1]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[1]), 0);
_ext01 = vextq_f32(_row10, _row11, 1);
_ext02 = vextq_f32(_row10, _row11, 2);
_ext03 = vextq_f32(_row11, _row11, 1);
_ext04 = vextq_f32(_row11, _row11, 2);
_result2 = vmlaq_lane_f32(_result2, _row10, vget_low_f32(_ker[2]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[2]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[2]), 0);
_result3 = vmlaq_lane_f32(_result3, _row11, vget_low_f32(_ker[2]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[2]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[2]), 0);
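        // Write only the remaining 1-5 outputs of each row; cases 5->4 and
        // 3->2 fall through on purpose.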
switch (output_w_remain) {
case 5:
vst1q_lane_f32(output_ptr0 + 4, _result1, 0);
vst1q_lane_f32(output_ptr1 + 4, _result3, 0);
case 4:
vst1q_f32(output_ptr0, _result0);
vst1q_f32(output_ptr1, _result2);
break;
case 3:
vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
vst1q_lane_f32(output_ptr1 + 2, _result2, 2);
case 2:
vst1_f32(output_ptr0, vget_low_f32(_result0));
vst1_f32(output_ptr1, vget_low_f32(_result2));
break;
case 1:
vst1q_lane_f32(output_ptr0, _result0, 0);
vst1q_lane_f32(output_ptr1, _result2, 0);
break;
}
input_ptr0 += output_w_remain;
input_ptr1 += output_w_remain;
input_ptr2 += output_w_remain;
input_ptr3 += output_w_remain;
output_ptr0 += output_w_remain;
output_ptr1 += output_w_remain;
}
// pad right
if (padding_w) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t row3 = vld1_f32(input_ptr3);
float32x2_t zero = vdup_n_f32(0.f);
float32x2_t acc0, acc1;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0.f;
*output_ptr1 = 0.f;
} else { } else {
if (out_w - 2 - w_times * 3 == 1) { acc0 = vmul_f32(row0, vget_low_f32(_ker[0]));
vst1q_lane_f32(output_row_ptr, res3, 0); acc0 = vmla_f32(acc0, row1, vget_low_f32(_ker[1]));
} else if (out_w - 2 - w_times * 3 == 2) { acc0 = vmla_f32(acc0, row2, vget_low_f32(_ker[2]));
vst1q_lane_f32(output_row_ptr, res3, 0); acc1 = vmul_f32(row1, vget_low_f32(_ker[0]));
vst1q_lane_f32(output_row_ptr + 1, res3, 1); acc1 = vmla_f32(acc1, row2, vget_low_f32(_ker[1]));
} acc1 = vmla_f32(acc1, row3, vget_low_f32(_ker[2]));
} float32x2_t sum = vpadd_f32(acc0, acc1);
input_row_ptr += 6; vst1_lane_f32(output_ptr0, sum, 0);
output_row_ptr += 3; vst1_lane_f32(output_ptr1, sum, 1);
} row0 = vext_f32(row0, zero, 1);
row1 = vext_f32(row1, zero, 1);
// leftTop, rightTop, leftBottom, rightBottom row2 = vext_f32(row2, zero, 1);
int lt = 0; row3 = vext_f32(row3, zero, 1);
int rt = out_w - 1; }
int lb = out_w * (out_h - 1); output_ptr0++;
int rb = out_h * out_w - 1; output_ptr1++;
}
output_data_tmp[lt] = input_const[0] * w11 + input_const[1] * w12 + }
input_const[in_w] * w21 + }
input_const[in_w + 1] * w22; // remain height
int start_h = valid_h_start + (valid_h & 0xfffe);
out2in_mid = (out_w - 1) * 2; if (start_h < valid_h_end) {
output_data_tmp[rt] = const float *input_ptr0 = input_ptr + (start_h - padding_h) * input_w;
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + const float *input_ptr1 = input_ptr0 + input_w;
w20 * input_const[out2in_mid + in_w - 1] + const float *input_ptr2 = input_ptr1 + input_w;
w21 * input_const[out2in_mid + in_w] + float *output_ptr0 = output_ptr + start_h * output_w;
(1 - if_pad_r) * (w12 * input_const[out2in_mid + 1] + // pad left
w22 * input_const[out2in_mid + in_w + 1]); if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
out2in_mid = (out_h - 1) * 2 * in_w; float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
output_data_tmp[lb] = float32x4_t zero = vdupq_n_f32(0.f);
w01 * input_const[out2in_mid - in_w] + row0 = vextq_f32(zero, row0, 3);
w02 * input_const[out2in_mid - in_w + 1] + row1 = vextq_f32(zero, row1, 3);
w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] + row2 = vextq_f32(zero, row2, 3);
(1 - if_pad_b) * (w21 * input_const[out2in_mid + in_w] + float32x4_t acc;
w22 * input_const[out2in_mid + in_w + 1]); for (int w = valid_w_start - 1; w >= 0; --w) {
out2in_mid = (out_h - 1) * 2 * in_w + (out_w - 1) * 2; int padding = padding_w - w;
if (padding >= 3) {
output_data_tmp[rb] = output_ptr0[w] = 0.f;
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
(1 - if_pad_r) * (w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w]) +
(1 - if_pad_b) * (w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1]) +
(1 - if_pad_r) * (1 - if_pad_b) * w22 *
input_const[out2in_mid + in_w + 1];
if (if_bias) {
output_data_tmp[lt] += bias_data[j];
output_data_tmp[rt] += bias_data[j];
output_data_tmp[lb] += bias_data[j];
output_data_tmp[rb] += bias_data[j];
}
if (if_relu) {
output_data_tmp[lt] = output_data_tmp[lt] < 0 ? 0 : output_data_tmp[lt];
output_data_tmp[rt] = output_data_tmp[rt] < 0 ? 0 : output_data_tmp[rt];
output_data_tmp[lb] = output_data_tmp[lb] < 0 ? 0 : output_data_tmp[lb];
output_data_tmp[rb] = output_data_tmp[rb] < 0 ? 0 : output_data_tmp[rb];
}
for (int i = 1; i < out_h - 1; i++) {
out2in_mid = i * 2 * in_w;
int left = i * out_w;
output_data_tmp[left] = w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] +
w12 * input_const[out2in_mid + 1] +
w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1];
out2in_mid = i * 2 * in_w + (out_w - 1) * 2;
int right = i * out_w + out_w - 1;
output_data_tmp[right] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad_r) * (w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
if (if_bias) {
output_data_tmp[left] += bias_data[j];
output_data_tmp[right] += bias_data[j];
}
if (if_relu) {
output_data_tmp[left] =
output_data_tmp[left] < 0 ? 0 : output_data_tmp[left];
output_data_tmp[right] =
output_data_tmp[right] < 0 ? 0 : output_data_tmp[right];
}
}
filter_data_tmp += 9;
}
input_data += inhxw * c;
output_data += outhxw * c;
}
#endif
}
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu) {
#if __ARM_NEON
// #ifdef _OPENMP
// const float *newscale_data = new_scale->data<float>();
// const float *newbias_data = new_bias->data<float>();
//
// const int batch_size = static_cast<int>(input->dims()[0]);
// const int input_channel = static_cast<int>(input->dims()[1]);
//
// const int input_height = static_cast<int>(input->dims()[2]);
// const int input_width = static_cast<int>(input->dims()[3]);
// const int output_height = static_cast<int>(output->dims()[2]);
// const int output_width = static_cast<int>(output->dims()[3]);
// const int inhxw = input_height * input_width;
// const int outhxw = output_height * output_width;
//
// float32x4_t zero = vdupq_n_f32(0.0);
// for (int b = 0; b < batch_size; b++) {
// #pragma omp parallel for
// for (int c = 0; c < input_channel; c++) {
// const float *filter_data = filter->data<float>() + c * 9;
// const float *input_data = input->data<float>() + c * inhxw;
// float *output_data = output->data<float>() + c * outhxw;
// float32x4_t vnewbias = vdupq_n_f32(newbias_data[c]);
// float32x4_t vnewscale = vdupq_n_f32(newscale_data[c]);
//
// float w00 = filter_data[0];
// float w01 = filter_data[1];
// float w02 = filter_data[2];
// float w10 = filter_data[3];
// float w11 = filter_data[4];
// float w12 = filter_data[5];
// float w20 = filter_data[6];
// float w21 = filter_data[7];
// float w22 = filter_data[8];
//
// int m;
// for (m = 1; m < output_width - 2; m = m + 3) {
// float *output_ptr = output_data + m;
// float32x4x2_t input_buff_mid{}, input_buff_bottom{};
// float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
// input_buff_mid = vld2q_f32(input_data + (2 * m - 1));
// input_buff_bottom = vld2q_f32(input_data + input_width + (2 * m -
// 1));
//
// in0 = input_buff_mid.val[0];
// tmp0 = input_buff_mid.val[1];
// tmp1 = vextq_f32(in0, zero, 1);
//
// in2 = input_buff_bottom.val[0];
// tmp2 = input_buff_bottom.val[1];
// tmp3 = vextq_f32(in2, zero, 1);
//
// out0 = vmulq_n_f32(in0, w10);
// out0 = vmlaq_n_f32(out0, tmp0, w11);
// out0 = vmlaq_n_f32(out0, tmp1, w12);
// out0 = vmlaq_n_f32(out0, in2, w20);
// out0 = vmlaq_n_f32(out0, tmp2, w21);
// out0 = vmlaq_n_f32(out0, tmp3, w22);
// out0 = vmlaq_f32(vnewbias, vnewscale, out0);
// if (if_relu) {
// out0 = vmaxq_f32(out0, zero);
// }
// vst1q_lane_f32(output_ptr, out0, 0);
// vst1q_lane_f32(output_ptr + 1, out0, 1);
// vst1q_lane_f32(output_ptr + 2, out0, 2);
// }
// for (m = 1; m < output_width - 2; m += 3) {
// }
// for (int j = m; j < output_width; j++) {
// output_data[j] = input_data[2 * j - 1] * w10 + input_data[2 * j] *
// w11 +
// input_data[2 * j + 1] * w12 +
// input_data[2 * j - 1 + input_width] * w20 +
// input_data[2 * j + input_width] * w21 +
// input_data[2 * j + 1 + input_width] * w22;
// output_data[j] = newscale_data[c] * output_data[j] +
// newbias_data[c]; if (if_relu) {
// output_data[j] = output_data[j] < 0 ? 0 : output_data[j];
// }
// }
//
// for (int i = 1; i < output_height; i += 1) {
// for (int m = 1; m < output_width - 2; m += 3) {
// float *output_ptr = output_data + i * output_width + m;
// float32x4x2_t input_buff_top{}, input_buff_mid{},
// input_buff_bottom{}; float32x4_t in0, in1, in2, in3, in4, in5,
// tmp0, tmp1, tmp2, tmp3,
// tmp4, tmp5, out0;
// input_buff_top =
// vld2q_f32(input_data + (2 * i - 1) * input_width + (2 * m -
// 1));
// input_buff_mid =
// vld2q_f32(input_data + (2 * i) * input_width + (2 * m - 1));
// input_buff_bottom =
// vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m -
// 1));
//
// in0 = input_buff_top.val[0];
// tmp0 = input_buff_top.val[1];
// tmp1 = vextq_f32(in0, zero, 1);
//
// in2 = input_buff_mid.val[0];
// tmp2 = input_buff_mid.val[1];
// tmp3 = vextq_f32(in2, zero, 1);
//
// in4 = input_buff_bottom.val[0];
// tmp4 = input_buff_bottom.val[1];
// tmp5 = vextq_f32(in4, zero, 1);
//
// out0 = vmulq_n_f32(in0, w00);
// out0 = vmlaq_n_f32(out0, tmp0, w01);
// out0 = vmlaq_n_f32(out0, tmp1, w02);
// out0 = vmlaq_n_f32(out0, in2, w10);
// out0 = vmlaq_n_f32(out0, tmp2, w11);
// out0 = vmlaq_n_f32(out0, tmp3, w12);
// out0 = vmlaq_n_f32(out0, in4, w20);
// out0 = vmlaq_n_f32(out0, tmp4, w21);
// out0 = vmlaq_n_f32(out0, tmp5, w22);
// out0 = vmlaq_f32(vnewbias, vnewscale, out0);
// if (if_relu) {
// out0 = vmaxq_f32(out0, zero);
// }
// vst1q_lane_f32(output_ptr, out0, 0);
// vst1q_lane_f32(output_ptr + 1, out0, 1);
// vst1q_lane_f32(output_ptr + 2, out0, 2);
// }
// int m;
// for (m = 1; m < output_width - 2; m += 3) {
// }
// for (int j = m; j < output_width; j++) {
// output_data[i * output_width + j] =
// input_data[(2 * i - 1) * input_width + 2 * j - 1] * w00 +
// input_data[(2 * i - 1) * input_width + 2 * j] * w01 +
// input_data[(2 * i - 1) * input_width + 2 * j + 1] * w02 +
// input_data[(2 * i) * input_width + 2 * j - 1] * w10 +
// input_data[(2 * i) * input_width + 2 * j] * w11 +
// input_data[(2 * i) * input_width + 2 * j + 1] * w12 +
// input_data[(2 * i + 1) * input_width + 2 * j - 1] * w20 +
// input_data[(2 * i + 1) * input_width + 2 * j] * w21 +
// input_data[(2 * i + 1) * input_width + 2 * j + 1] * w22;
// output_data[i * output_width + j] =
// newscale_data[c] * output_data[i * output_width + j] +
// newbias_data[c];
// if (if_relu) {
// output_data[i * output_width + j] =
// output_data[i * output_width + j] < 0
// ? 0
// : output_data[i * output_width + j];
// }
// }
// }
// output_data[0] = input_data[0] * w11 + input_data[1] * w12 +
// input_data[input_height] * w21 +
// input_data[input_height + 1] * w22;
//
// output_data[0] = newscale_data[c] * output_data[0] + newbias_data[c];
// if (if_relu) {
// output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
// }
// for (int i = 1; i < output_height; i++) {
// output_data[i * output_width] =
// input_data[(2 * i - 1) * input_width] * w01 +
// input_data[(2 * i - 1) * input_width + 1] * w02 +
// input_data[(2 * i) * input_width] * w11 +
// input_data[(2 * i) * input_width + 1] * w12 +
// input_data[(2 * i + 1) * input_width] * w21 +
// input_data[(2 * i + 1) * input_width + 1] * w22;
//
// output_data[i * output_width] =
// newscale_data[c] * output_data[i * output_width] +
// newbias_data[c];
// if (if_relu) {
// output_data[i * output_width] = output_data[i * output_width] < 0
// ? 0
// : output_data[i *
// output_width];
// }
// }
// }
// }
//
// #else
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->mutable_data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const int in_h = static_cast<int>(input->dims()[2]);
const int in_w = static_cast<int>(input->dims()[3]);
const int out_h = static_cast<int>(output->dims()[2]);
const int out_w = static_cast<int>(output->dims()[3]);
// const int out_l = out_h;
// const int in_l = in_h;
const int inhxw = in_h * in_w;
const int outhxw = out_h * out_w;
  // TODO: fix if_pad when w != h
const int if_pad_r = in_w - 1 == (out_w - 1) * 2 ? 1 : 0;
const int if_pad_b = in_h - 1 == (out_h - 1) * 2 ? 1 : 0;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int w_times = (out_w - 2) / 3;
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = batch_size; b > 0; --b) {
#pragma omp parallel for
for (int j = 0; j < c; j++) {
const float *input_row_ptr;
float *output_row_ptr;
float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
int out2in_mid;
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
auto output_data_tmp = output_data + j * out_h * out_w;
auto input_data_tmp = input_data + j * in_h * in_w;
auto input_const = input_data_tmp;
const float *filter_data_tmp = filter_data + 9 * j;
vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]);
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
int h_mid = 0;
for (; h_mid < out_h - 1; h_mid++) {
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
if (h_mid == 0) {
elewise_res1 = zero;
elewise_res0 = zero;
elewise_res2 = zero;
} else { } else {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); acc = vmulq_f32(row0, _ker[0]);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); acc = vmlaq_f32(acc, row1, _ker[1]);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); acc = vmlaq_f32(acc, row2, _ker[2]);
} acc = vextq_f32(acc, acc, 1);
input_buff_mid = vld2q_f32(input_row_ptr); float32x2_t sum = vpadd_f32(vget_low_f32(acc), vget_low_f32(acc));
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); vst1_lane_f32(output_ptr0 + w, sum, 0);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); row0 = vextq_f32(zero, row0, 3);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); row1 = vextq_f32(zero, row1, 3);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); row2 = vextq_f32(zero, row2, 3);
}
elewise_res1 = }
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); output_ptr0 += valid_w_start;
elewise_res0 = }
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); // valid
elewise_res2 = float32x4_t _result0, _result1;
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); for (int loop = 0; loop < output_w_tiles; ++loop) {
float32x4_t _row00 = vld1q_f32(input_ptr0);
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
vaddq_f32(elewise_res0, elewise_res1)); float32x4_t _row10 = vld1q_f32(input_ptr1);
res3 = vmlaq_f32(vnewbias, vnewscale, res3); float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);
if (if_relu) { float32x4_t _ext01 = vextq_f32(_row00, _row01, 1);
res3 = vmaxq_f32(res3, zero); float32x4_t _ext02 = vextq_f32(_row00, _row01, 2);
} float32x4_t _ext03 = vextq_f32(_row01, _row01, 1);
vst1q_lane_f32(output_row_ptr, res3, 0); float32x4_t _ext04 = vextq_f32(_row01, _row01, 2);
vst1q_lane_f32(output_row_ptr + 1, res3, 1);
vst1q_lane_f32(output_row_ptr + 2, res3, 2); _result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1);
input_row_ptr += 6; _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0);
output_row_ptr += 3; _result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0);
} _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1);
} _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0);
clock();
_ext01 = vextq_f32(_row10, _row11, 1);
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; _ext02 = vextq_f32(_row10, _row11, 2);
output_row_ptr = output_data_tmp + 1 + h_mid * out_w; _ext03 = vextq_f32(_row11, _row11, 1);
_ext04 = vextq_f32(_row11, _row11, 2);
for (int w4 = 0; w4 < w_times + 1; w4++) {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); _result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0);
input_buff_mid = vld2q_f32(input_row_ptr); _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); _row00 = vld1q_f32(input_ptr2);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); _row01 = vld1q_f32(input_ptr2 + 4);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
_ext01 = vextq_f32(_row00, _row01, 1);
if (!if_pad_b) { _ext02 = vextq_f32(_row00, _row01, 2);
elewise_res1 = _ext03 = vextq_f32(_row01, _row01, 1);
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); _ext04 = vextq_f32(_row01, _row01, 2);
elewise_res0 =
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); _result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0);
elewise_res2 = _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1);
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0);
} _result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0);
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1);
vaddq_f32(elewise_res0, elewise_res1)); _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0);
res3 = vmlaq_f32(vnewbias, vnewscale, res3);
vst1q_f32(output_ptr0, _result0);
if (if_relu) { vst1_f32(output_ptr0 + 4, vget_low_f32(_result1));
res3 = vmaxq_f32(res3, zero);
} input_ptr0 += 6;
if ((w4 != w_times)) { input_ptr1 += 6;
vst1q_lane_f32(output_row_ptr, res3, 0); input_ptr2 += 6;
vst1q_lane_f32(output_row_ptr + 1, res3, 1); output_ptr0 += 6;
vst1q_lane_f32(output_row_ptr + 2, res3, 2); }
if (output_w_remain > 0) {
float32x4_t _row00 = vld1q_f32(input_ptr0);
float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
float32x4_t _row10 = vld1q_f32(input_ptr1);
float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);
float32x4_t _ext01 = vextq_f32(_row00, _row01, 1);
float32x4_t _ext02 = vextq_f32(_row00, _row01, 2);
float32x4_t _ext03 = vextq_f32(_row01, _row01, 1);
float32x4_t _ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0);
_result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0);
_ext01 = vextq_f32(_row10, _row11, 1);
_ext02 = vextq_f32(_row10, _row11, 2);
_ext03 = vextq_f32(_row11, _row11, 1);
_ext04 = vextq_f32(_row11, _row11, 2);
_result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0);
_row00 = vld1q_f32(input_ptr2);
_row01 = vld1q_f32(input_ptr2 + 4);
_ext01 = vextq_f32(_row00, _row01, 1);
_ext02 = vextq_f32(_row00, _row01, 2);
_ext03 = vextq_f32(_row01, _row01, 1);
_ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0);
switch (output_w_remain) {
case 5:
vst1q_lane_f32(output_ptr0 + 4, _result1, 0);
case 4:
vst1q_f32(output_ptr0, _result0);
break;
case 3:
vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
case 2:
vst1_f32(output_ptr0, vget_low_f32(_result0));
break;
case 1:
vst1q_lane_f32(output_ptr0, _result0, 0);
break;
}
input_ptr0 += output_w_remain;
input_ptr1 += output_w_remain;
input_ptr2 += output_w_remain;
output_ptr0 += output_w_remain;
}
// pad right
if (padding_w) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t zero = vdup_n_f32(0.f);
float32x2_t acc;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0.f;
} else { } else {
if (out_w - 2 - w_times * 3 == 1) { acc = vmul_f32(row0, vget_low_f32(_ker[0]));
vst1q_lane_f32(output_row_ptr, res3, 0); acc = vmla_f32(acc, row1, vget_low_f32(_ker[1]));
} else if (out_w - 2 - w_times * 3 == 2) { acc = vmla_f32(acc, row2, vget_low_f32(_ker[2]));
vst1q_lane_f32(output_row_ptr, res3, 0); float32x2_t sum = vpadd_f32(acc, acc);
vst1q_lane_f32(output_row_ptr + 1, res3, 1); vst1_lane_f32(output_ptr0, sum, 0);
} row0 = vext_f32(row0, zero, 1);
row1 = vext_f32(row1, zero, 1);
row2 = vext_f32(row2, zero, 1);
} }
input_row_ptr += 6; output_ptr0++;
output_row_ptr += 3;
} }
output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 +
input_const[in_w] * w21 +
input_const[in_w + 1] * w22;
out2in_mid = (out_w - 1) * 2;
output_data_tmp[out_w - 1] =
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad_r) * (w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_h - 1) * 2 * in_w;
output_data_tmp[out_w * (out_h - 1)] =
w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] +
(1 - if_pad_b) * (w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_h - 1) * 2 * in_w + (out_w - 1) * 2;
output_data_tmp[out_h * out_w - 1] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
(1 - if_pad_r) * (w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w]) +
(1 - if_pad_b) * (w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1]) +
(1 - if_pad_r) * (1 - if_pad_b) * w22 *
input_const[out2in_mid + in_w + 1];
output_data_tmp[0] =
output_data_tmp[0] * newscale_data[j] + newbias_data[j];
output_data_tmp[out_w - 1] =
output_data_tmp[out_w - 1] * newscale_data[j] + newbias_data[j];
output_data_tmp[out_w * (out_h - 1)] =
output_data_tmp[out_w * (out_h - 1)] * newscale_data[j] +
newbias_data[j];
output_data_tmp[out_h * out_w - 1] =
output_data_tmp[out_h * out_w - 1] * newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data_tmp[0] = output_data_tmp[0] < 0 ? 0 : output_data_tmp[0];
output_data_tmp[out_w - 1] =
output_data_tmp[out_w - 1] < 0 ? 0 : output_data_tmp[out_w - 1];
output_data_tmp[out_w * (out_h - 1)] =
output_data_tmp[out_w * (out_h - 1)] < 0
? 0
: output_data_tmp[out_w * (out_h - 1)];
output_data_tmp[out_h * out_w - 1] =
output_data_tmp[out_h * out_w - 1] < 0
? 0
: output_data_tmp[out_h * out_w - 1];
}
for (int i = 1; i < out_h - 1; i++) {
out2in_mid = i * 2 * in_w;
output_data_tmp[i * out_w] = w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] +
w12 * input_const[out2in_mid + 1] +
w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1];
out2in_mid = i * 2 * in_w + (out_w - 1) * 2;
output_data_tmp[i * out_w + out_w - 1] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad_r) * (w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
output_data_tmp[i * out_w] =
output_data_tmp[i * out_w] * newscale_data[j] + newbias_data[j];
output_data_tmp[i * out_w + out_w - 1] =
output_data_tmp[i * out_w + out_w - 1] * newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data_tmp[i * out_w] =
output_data_tmp[i * out_w] < 0 ? 0 : output_data_tmp[i * out_w];
output_data_tmp[i * out_w + out_w - 1] =
output_data_tmp[i * out_w + out_w - 1] < 0
? 0
: output_data_tmp[i * out_w + out_w - 1];
} }
} }
// pad bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker);
} }
input_data += inhxw * c;
output_data += outhxw * c;
} }
// #endif
#endif
} }
void DepthwiseConv3x3s2p0(const framework::Tensor *input, template <>
const framework::Tensor *filter, void DepthwiseConv3x3S2<float, float>(const framework::Tensor &input,
framework::Tensor *output, framework::Tensor *bias, const framework::Tensor &filter,
bool if_bias, bool if_relu) { const std::vector<int> &paddings,
#if __ARM_NEON framework::Tensor *output) {
const int batch_size = static_cast<int>(input->dims()[0]); const float *input_data = input.data<float>();
const int input_channel = static_cast<int>(input->dims()[1]); const float *filter_data = filter.data<float>();
float *out_data = output->mutable_data<float>();
const int input_height = static_cast<int>(input->dims()[2]); int input_h = input.dims()[2];
const int input_width = static_cast<int>(input->dims()[3]); int input_w = input.dims()[3];
const int output_height = static_cast<int>(output->dims()[2]); int output_h = output->dims()[2];
const int output_width = static_cast<int>(output->dims()[3]); int output_w = output->dims()[3];
const int inhxw = input_height * input_width; int padding_h = paddings[0];
const int outhxw = output_height * output_width; int padding_w = paddings[1];
output->mutable_data<float>(); int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
float32x4_t zero = vdupq_n_f32(0.0); int valid_h_start = (padding_h + 1) / 2;
for (int b = 0; b < batch_size; b++) { int valid_h_end = (input_h + padding_h - 1) / 2;
#pragma omp parallel for int valid_h = valid_h_end - valid_h_start;
for (int c = 0; c < input_channel; c++) { int valid_w_start = (padding_w + 1) / 2;
const float *filter_data = filter->data<float>() + c * 9; int valid_w_end = (input_w + padding_w - 1) / 2;
const float *input_data = input->data<float>() + c * inhxw; int valid_w = valid_w_end - valid_w_start;
const float *bias_data; int input_w_start = 2 * valid_w_start - padding_w;
float32x4_t biasv;
if (if_bias) { #pragma omp parallel for
bias_data = bias->data<float>() + c; for (int g = 0; g < input.dims()[1]; ++g) {
biasv = vld1q_dup_f32(bias_data); const float *input_ptr = input_data + g * image_size;
} const float *filter_ptr = filter_data + g * 9;
float *output_data = output->data<float>() + c * outhxw; float *output_ptr = out_data + g * out_image_size;
float w00 = filter_data[0];
float w01 = filter_data[1]; const float *filter_ptr0 = filter_ptr;
float w02 = filter_data[2]; const float *filter_ptr1 = filter_ptr0 + 3;
float w10 = filter_data[3]; const float *filter_ptr2 = filter_ptr1 + 3;
float w11 = filter_data[4]; float32x4_t _ker[3];
float w12 = filter_data[5]; _ker[0] = vld1q_f32(filter_ptr0);
float w20 = filter_data[6]; _ker[1] = vld1q_f32(filter_ptr1);
float w21 = filter_data[7]; _ker[2] = vld1q_f32(filter_ptr2);
float w22 = filter_data[8];
for (int i = 0; i < output_height; i += 1) { // pad top
for (int m = 0; m < output_width - 2; m += 3) { for (int h = 0; h < valid_h_start; ++h) {
float *output_ptr = output_data + i * output_width + m; DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
float32x4x2_t input_buff_top{}, input_buff_mid{}, input_buff_bottom{}; input_w, padding_h, padding_w, output_w,
float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3, output_ptr, _ker);
tmp4, tmp5, out0; }
input_buff_top = // valid 2x4
vld2q_f32(input_data + (2 * i) * input_width + (2 * m)); int output_w_tiles = valid_w / 4;
input_buff_mid = int output_w_remain = valid_w - output_w_tiles * 4;
vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m)); for (int h = valid_h_start; h < valid_h_end - 1; h += 2) {
input_buff_bottom = const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
vld2q_f32(input_data + (2 * i + 2) * input_width + (2 * m)); const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
in0 = input_buff_top.val[0]; const float *input_ptr3 = input_ptr2 + input_w;
tmp0 = input_buff_top.val[1]; const float *input_ptr4 = input_ptr3 + input_w;
tmp1 = vextq_f32(in0, zero, 1); float *output_ptr0 = output_ptr + h * output_w;
float *output_ptr1 = output_ptr0 + output_w;
in2 = input_buff_mid.val[0]; // pad left
tmp2 = input_buff_mid.val[1]; if (padding_w) {
tmp3 = vextq_f32(in2, zero, 1); for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - (w << 1);
in4 = input_buff_bottom.val[0]; if (padding >= 3) {
tmp4 = input_buff_bottom.val[1]; output_ptr0[w] = 0;
tmp5 = vextq_f32(in4, zero, 1); output_ptr1[w] = 0;
} else {
out0 = vmulq_n_f32(in0, w00); float32x4_t row0 = vld1q_f32(input_ptr0 - padding);
out0 = vmlaq_n_f32(out0, tmp0, w01); float32x4_t row1 = vld1q_f32(input_ptr1 - padding);
out0 = vmlaq_n_f32(out0, tmp1, w02); float32x4_t row2 = vld1q_f32(input_ptr2 - padding);
out0 = vmlaq_n_f32(out0, in2, w10); float32x4_t row3 = vld1q_f32(input_ptr3 - padding);
out0 = vmlaq_n_f32(out0, tmp2, w11); float32x4_t row4 = vld1q_f32(input_ptr4 - padding);
out0 = vmlaq_n_f32(out0, tmp3, w12); float32x4_t acc0 = vmulq_f32(row0, _ker[0]);
out0 = vmlaq_n_f32(out0, in4, w20); float32x4_t acc1 = vmulq_f32(row2, _ker[0]);
out0 = vmlaq_n_f32(out0, tmp4, w21); acc0 = vmlaq_f32(acc0, row1, _ker[1]);
out0 = vmlaq_n_f32(out0, tmp5, w22); acc1 = vmlaq_f32(acc1, row3, _ker[1]);
if (if_bias) { acc0 = vmlaq_f32(acc0, row2, _ker[2]);
out0 = vaddq_f32(out0, biasv); acc1 = vmlaq_f32(acc1, row4, _ker[2]);
} float sum0 = vgetq_lane_f32(acc0, 2);
if (if_relu) { float sum1 = vgetq_lane_f32(acc1, 2);
out0 = vmaxq_f32(out0, zero); if (padding == 1) {
} sum0 += vgetq_lane_f32(acc0, 1);
vst1q_lane_f32(output_ptr, out0, 0); sum1 += vgetq_lane_f32(acc1, 1);
vst1q_lane_f32(output_ptr + 1, out0, 1); }
vst1q_lane_f32(output_ptr + 2, out0, 2); output_ptr0[w] = sum0;
} output_ptr1[w] = sum1;
int m; }
for (m = 0; m < output_width - 2; m += 3) { }
input_ptr0 += input_w_start;
input_ptr1 += input_w_start;
input_ptr2 += input_w_start;
input_ptr3 += input_w_start;
input_ptr4 += input_w_start;
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
}
// valid
float32x4_t _result0, _result1, _ext;
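      // Stride-2 main body: vld2q_f32 de-interleaves 8 input columns into even
      // lanes (tap 0) and odd lanes (tap 1); tap 2 is rebuilt by shifting the
      // even lanes and inserting input[8] into the last lane.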
for (int loop = 0; loop < output_w_tiles; ++loop) {
float32x4x2_t _row0 = vld2q_f32(input_ptr0);
float32x4x2_t _row1 = vld2q_f32(input_ptr1);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr0[8], _ext, 3);
_result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0);
_ext = vextq_f32(_row1.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr1[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0);
_row0 = vld2q_f32(input_ptr2);
_row1 = vld2q_f32(input_ptr3);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr2[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[0], vget_low_f32(_ker[2]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0);
_result1 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[0]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[0]), 0);
_ext = vextq_f32(_row1.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr3[8], _ext, 3);
_result1 =
vmlaq_lane_f32(_result1, _row1.val[0], vget_low_f32(_ker[1]), 0);
_result1 =
vmlaq_lane_f32(_result1, _row1.val[1], vget_low_f32(_ker[1]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[1]), 0);
_row0 = vld2q_f32(input_ptr4);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr4[8], _ext, 3);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[0], vget_low_f32(_ker[2]), 0);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[2]), 0);
vst1q_f32(output_ptr0, _result0);
vst1q_f32(output_ptr1, _result1);
input_ptr0 += 8;
input_ptr1 += 8;
input_ptr2 += 8;
input_ptr3 += 8;
input_ptr4 += 8;
output_ptr0 += 4;
output_ptr1 += 4;
}
// remain w
if (output_w_remain > 0) {
float32x4x2_t _row0 = vld2q_f32(input_ptr0);
float32x4x2_t _row1 = vld2q_f32(input_ptr1);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr0[8], _ext, 3);
_result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0);
_ext = vextq_f32(_row1.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr1[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0);
_row0 = vld2q_f32(input_ptr2);
_row1 = vld2q_f32(input_ptr3);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr2[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[0], vget_low_f32(_ker[2]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0);
_result1 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[0]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[0]), 0);
_ext = vextq_f32(_row1.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr3[8], _ext, 3);
_result1 =
vmlaq_lane_f32(_result1, _row1.val[0], vget_low_f32(_ker[1]), 0);
_result1 =
vmlaq_lane_f32(_result1, _row1.val[1], vget_low_f32(_ker[1]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[1]), 0);
_row0 = vld2q_f32(input_ptr4);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr4[8], _ext, 3);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[0], vget_low_f32(_ker[2]), 0);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[2]), 0);
switch (output_w_remain) {
case 3:
vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
vst1q_lane_f32(output_ptr1 + 2, _result1, 2);
case 2:
vst1_f32(output_ptr0, vget_low_f32(_result0));
vst1_f32(output_ptr1, vget_low_f32(_result1));
break;
case 1:
vst1q_lane_f32(output_ptr0, _result0, 0);
vst1q_lane_f32(output_ptr1, _result1, 0);
break;
}
input_ptr0 += output_w_remain * 2;
input_ptr1 += output_w_remain * 2;
input_ptr2 += output_w_remain * 2;
input_ptr3 += output_w_remain * 2;
input_ptr4 += output_w_remain * 2;
output_ptr0 += output_w_remain;
output_ptr1 += output_w_remain;
}
// pad right
if (padding_w > 0) {
float32x4_t row0 = vld1q_f32(input_ptr0);
float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
float32x4_t row3 = vld1q_f32(input_ptr3);
float32x4_t row4 = vld1q_f32(input_ptr4);
float32x4_t acc0, acc1;
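        // With stride 2 only the first right-border column can still cover
        // 1-2 valid taps; any farther column lies entirely in the padding and
        // is written as zero.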
for (int w = valid_w_end; w < output_w; ++w) {
int padding = 2 * w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0;
*output_ptr1 = 0;
} else {
acc0 = vmulq_f32(row0, _ker[0]);
acc1 = vmulq_f32(row2, _ker[0]);
acc0 = vmlaq_f32(acc0, row1, _ker[1]);
acc1 = vmlaq_f32(acc1, row3, _ker[1]);
acc0 = vmlaq_f32(acc0, row2, _ker[2]);
acc1 = vmlaq_f32(acc1, row4, _ker[2]);
float sum0 = vgetq_lane_f32(acc0, 0);
float sum1 = vgetq_lane_f32(acc1, 0);
if (padding == 1) {
sum0 += vgetq_lane_f32(acc0, 1);
sum1 += vgetq_lane_f32(acc1, 1);
}
*output_ptr0 = sum0;
*output_ptr1 = sum1;
}
output_ptr0++;
output_ptr1++;
}
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xfffe);
if (start_h < valid_h_end) {
const float *input_ptr0 = input_ptr + (2 * start_h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
float *output_ptr0 = output_ptr + start_h * output_w;
// pad left
if (padding_w) {
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - (w << 1);
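          // padding = number of the 3 kernel taps that fall left of the first input column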
if (padding >= 3) {
output_ptr0[w] = 0;
} else {
float32x4_t row0 = vld1q_f32(input_ptr0 - padding);
float32x4_t row1 = vld1q_f32(input_ptr1 - padding);
float32x4_t row2 = vld1q_f32(input_ptr2 - padding);
float32x4_t acc0 = vmulq_f32(row0, _ker[0]);
acc0 = vmlaq_f32(acc0, row1, _ker[1]);
acc0 = vmlaq_f32(acc0, row2, _ker[2]);
float sum0 = vgetq_lane_f32(acc0, 2);
if (padding == 1) {
sum0 += vgetq_lane_f32(acc0, 1);
}
output_ptr0[w] = sum0;
}
}
input_ptr0 += input_w_start;
input_ptr1 += input_w_start;
input_ptr2 += input_w_start;
output_ptr0 += valid_w_start;
}
// valid
float32x4_t _result0, _ext;
for (int loop = 0; loop < output_w_tiles; ++loop) {
float32x4x2_t _row0 = vld2q_f32(input_ptr0);
float32x4x2_t _row1 = vld2q_f32(input_ptr1);
float32x4x2_t _row2 = vld2q_f32(input_ptr2);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr0[8], _ext, 3);
_result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0);
_ext = vextq_f32(_row1.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr1[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0);
_ext = vextq_f32(_row2.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr2[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row2.val[0], vget_low_f32(_ker[2]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row2.val[1], vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0);
vst1q_f32(output_ptr0, _result0);
input_ptr0 += 8;
input_ptr1 += 8;
input_ptr2 += 8;
output_ptr0 += 4;
}
// remain w
if (output_w_remain > 0) {
float32x4x2_t _row0 = vld2q_f32(input_ptr0);
float32x4x2_t _row1 = vld2q_f32(input_ptr1);
float32x4x2_t _row2 = vld2q_f32(input_ptr2);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr0[8], _ext, 3);
_result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0);
_ext = vextq_f32(_row1.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr1[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0);
_ext = vextq_f32(_row2.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr2[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row2.val[0], vget_low_f32(_ker[2]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row2.val[1], vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0);
switch (output_w_remain) {
case 3:
vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
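            // fall through on purpose: lanes 0 and 1 are stored by the case 2 path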
case 2:
vst1_f32(output_ptr0, vget_low_f32(_result0));
break;
case 1:
vst1q_lane_f32(output_ptr0, _result0, 0);
break;
}
input_ptr0 += output_w_remain * 2;
input_ptr1 += output_w_remain * 2;
input_ptr2 += output_w_remain * 2;
output_ptr0 += output_w_remain;
}
// pad right
if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
float32x4_t acc0;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = 2 * w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0;
} else {
acc0 = vmulq_f32(row0, _ker[0]);
acc0 = vmlaq_f32(acc0, row1, _ker[1]);
acc0 = vmlaq_f32(acc0, row2, _ker[2]);
float sum0 = vgetq_lane_f32(acc0, 0);
if (padding == 1) {
sum0 += vgetq_lane_f32(acc0, 1);
          }
          *output_ptr0 = sum0;
        }
        output_ptr0++;
      }
    }
  }
// pad bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker);
  }
}
#endif
}

}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
#endif // __ARM_NEON__
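For readers cross-checking the NEON path above, here is a minimal scalar sketch of what one output element of the stride-2 branch computes (zero padding, single channel). The function name and parameters are illustrative only and are not part of this commit:

// Reference-only sketch: one output element of a 3x3, stride-2 depthwise conv
// with zero padding. Names are illustrative, not from this repository.
inline float DepthwiseConv3x3S2RefElem(const float *in, const float *k,
                                       int in_h, int in_w, int pad_h, int pad_w,
                                       int oh, int ow) {
  float acc = 0.f;
  for (int kh = 0; kh < 3; ++kh) {
    for (int kw = 0; kw < 3; ++kw) {
      int ih = 2 * oh - pad_h + kh;  // stride 2 in height
      int iw = 2 * ow - pad_w + kw;  // stride 2 in width
      if (ih >= 0 && ih < in_h && iw >= 0 && iw < in_w) {
        acc += in[ih * in_w + iw] * k[kh * 3 + kw];
      }
    }
  }
  return acc;
}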
...@@ -23,48 +23,6 @@ namespace paddle_mobile {
namespace operators {
namespace math {
void DepthwiseConv3x3(const framework::Tensor *input,
const std::vector<int> &strides,
const std::vector<int> &paddings,
const framework::Tensor *filter, framework::Tensor *bias,
framework::Tensor *output, bool if_bias);
void DepthwiseConv3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu);
void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu);
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu);
// TODO(hjchen2) need to be implemented
// template<typename Itype, typename Otype>
// void DepthwiseConv3x3(const framework::Tensor *input,
......
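For reference, the removed helpers above collapse into two templated entry points. The sketch below is inferred from the call sites in the conv kernels and is a hedged approximation, not the literal contents of the header:

// Assumed shape of the replacement API (see the DepthwiseConv3x3S1/S2 call sites):
template <typename Itype, typename Otype>
void DepthwiseConv3x3S1(const framework::Tensor &input,
                        const framework::Tensor &filter,
                        const std::vector<int> &paddings,
                        framework::Tensor *output);

template <typename Itype, typename Otype>
void DepthwiseConv3x3S2(const framework::Tensor &input,
                        const framework::Tensor &filter,
                        const std::vector<int> &paddings,
                        framework::Tensor *output);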
...@@ -31,16 +31,14 @@ void cblas_sgemm(const bool transA, const bool transB, const int M, const int N,
  // return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C);
  // }
  CPUInfo *info = CPUInfo::Info();
  GemmExecutor<SgemmStrategy> exec(info, transA, transB, M, N, K);
  GemmExecutor<SgemmStrategy> exec(transA, transB, M, N, K);
  exec(alpha, A, lda, B, ldb, beta, C, ldc);
}

void cblas_sgemv(const bool trans, const int M, const int N, const float alpha,
                 const float *A, const int lda, const float *B,
                 const float beta, float *C) {
  CPUInfo *info = CPUInfo::Info();
  GemvExecutor<SgemvStrategy> exec(info, trans, M, N);
  GemvExecutor<SgemvStrategy> exec(trans, M, N);
  exec(alpha, A, lda, B, beta, C);
}
......
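Call sites of cblas_sgemm and cblas_sgemv are unaffected by dropping the CPUInfo argument; the executors now pick it up internally. A small usage sketch, with values and row-major leading dimensions assumed for illustration:

// Illustrative only: C(2x2) = 1.0f * A(2x3) * B(3x2) + 0.0f * C, no transposes.
void ExampleSgemm() {
  float A[6] = {1, 2, 3, 4, 5, 6};
  float B[6] = {1, 0, 0, 1, 1, 1};
  float C[4] = {0};
  cblas_sgemm(/*transA=*/false, /*transB=*/false, /*M=*/2, /*N=*/2, /*K=*/3,
              /*alpha=*/1.f, A, /*lda=*/3, B, /*ldb=*/2, /*beta=*/0.f, C,
              /*ldc=*/2);
}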
...@@ -19,7 +19,6 @@ limitations under the License. */
#include <omp.h>
#endif
#include <sys/time.h>
#include <iostream>
#include "common/log.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm/cpu_info.h"
...@@ -29,6 +28,8 @@ namespace paddle_mobile {
namespace operators {
namespace math {

static CPUInfo *info = CPUInfo::Info();

int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; }
unsigned int ResetL1Cache(const unsigned int L1_size, const int thread_num,
                          const int N, const int K) {
...@@ -62,15 +63,9 @@ class GemmExecutor : public Executor {
  typedef typename Strategy::Otype Otype;

 public:
  GemmExecutor(const CPUInfo *info, const bool transA, const bool transB,
               const int M, const int N, const int K)
      : Executor(),
        info_(info),
        transA_(transA),
        transB_(transB),
        M_(M),
        N_(N),
        K_(K) {
  GemmExecutor(const bool transA, const bool transB, const int M, const int N,
               const int K)
      : Executor(), transA_(transA), transB_(transB), M_(M), N_(N), K_(K) {
    unsigned int L1_size = 0;
    unsigned int L2_size = 0;
    if (M_ > N_) {
...@@ -212,8 +207,6 @@ class GemmExecutor : public Executor {
  virtual ~GemmExecutor() {}

 private:
  const CPUInfo *info_;
  const unsigned int M_;
  const unsigned int N_;
  const unsigned int K_;
...@@ -242,8 +235,8 @@ class GemvExecutor : public Executor {
  typedef typename Strategy::Otype Otype;

 public:
  GemvExecutor(const CPUInfo *info, const bool transA, const int M, const int N)
      : Executor(), info_(info), M_(M), N_(N) {}
  GemvExecutor(const bool transA, const int M, const int N)
      : Executor(), M_(M), N_(N) {}

  void operator()(const float alpha, const Itype *A, const int lda,
                  const Itype *B, const float beta, Otype *C) {
...@@ -253,8 +246,6 @@ class GemvExecutor : public Executor {
  virtual ~GemvExecutor() {}

 private:
  const CPUInfo *const info_;
  const unsigned int M_;
  const unsigned int N_;
......
...@@ -44,7 +44,17 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height,
  for (int i = start_height; i < end_height; i += stride_h) {
    if (stride_w == 1) {
      memcpy(col_data, im_data, extract * sizeof(float));
      // memcpy(col_data, im_data, extract * sizeof(float));
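      // vectorized copy of `extract` contiguous floats, 4 at a time; the scalar
      // loop below handles the remaining tail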
int s = 0;
#if __ARM_NEON
for (; s < extract - 3; s += 4) {
float32x4_t img = vld1q_f32(im_data + s);
vst1q_f32(col_data + s, img);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s];
}
    } else if (stride_w == 2) {
      int s = 0;
#if __ARM_NEON
...@@ -109,325 +119,7 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
  const float *im_data = im.data<float>();
  float *col_data = col->data<float>();
#if __ARM_NEON
  const int osize = col_height;
const int isize = im_height;
bool pad1 = padding[0] > 0;
bool pad2 =
(pad1 && padding[1] &&
(((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 1 : 0));
int fill = isize % 2;
if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 &&
dilation[0] == 1 && im_height > 2 && im_height == im_width) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
int mk4 = osize % 4;
float *col0 = col_data + 0 * oosize + 2 * osize + 2;
float *col1 = col_data + 1 * oosize + 2 * osize + 1;
float *col2 = col_data + 2 * oosize + 2 * osize;
float *col3 = col_data + 3 * oosize + osize + 2;
float *col4 = col_data + 4 * oosize + osize + 1;
float *col5 = col_data + 5 * oosize + osize;
float *col6 = col_data + 6 * oosize + 2;
float *col7 = col_data + 7 * oosize + 1;
float *col8 = col_data + 8 * oosize;
float32x4_t im1;
const float *im_tmp_data = im_data + osize + 1;
int rrsize = oosize - osize - 1;
int nr4 = rrsize / 4;
int mr4 = rrsize % 4;
for (int i = 0; i < nr4; ++i) {
im1 = vld1q_f32(im_tmp_data);
vst1q_f32(col0, im1);
vst1q_f32(col1, im1);
vst1q_f32(col2, im1);
vst1q_f32(col3, im1);
vst1q_f32(col4, im1);
vst1q_f32(col5, im1);
vst1q_f32(col6, im1);
vst1q_f32(col7, im1);
vst1q_f32(col8, im1);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
col6 += 4;
col7 += 4;
col8 += 4;
im_tmp_data += 4;
}
for (int i = 0; i < mr4; ++i) {
*col0 = *im_tmp_data;
*col1 = *im_tmp_data;
*col2 = *im_tmp_data;
*col3 = *im_tmp_data;
*col4 = *im_tmp_data;
*col5 = *im_tmp_data;
*col6 = *im_tmp_data;
*col7 = *im_tmp_data;
*col8 = *im_tmp_data;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
col6++;
col7++;
col8++;
im_tmp_data++;
}
im_tmp_data = im_data + 1;
col0 = col_data + 0 * oosize + osize + 2;
col1 = col_data + 1 * oosize + osize + 1;
col2 = col_data + 2 * oosize + osize;
col3 = col_data + 3 * oosize + 2;
col4 = col_data + 4 * oosize + 1;
col5 = col_data + 5 * oosize;
for (int i = 0; i < nk4; i++) {
im1 = vld1q_f32(im_tmp_data);
vst1q_f32(col0, im1);
vst1q_f32(col1, im1);
vst1q_f32(col2, im1);
vst1q_f32(col3, im1);
vst1q_f32(col4, im1);
vst1q_f32(col5, im1);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
im_tmp_data += 4;
}
for (int i = 0; i < mk4; i++) {
*col0 = *im_tmp_data;
*col1 = *im_tmp_data;
*col2 = *im_tmp_data;
*col3 = *im_tmp_data;
*col4 = *im_tmp_data;
*col5 = *im_tmp_data;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
im_tmp_data++;
}
// fill 0 1 11;
for (int i = 0; i < osize; ++i) {
col_data[0 * oosize + i * osize] = 0.0;
col_data[3 * oosize + i * osize] = 0.0;
col_data[6 * oosize + i * osize] = 0.0;
col_data[2 * oosize + osize - 1 + i * osize] = 0.0;
col_data[5 * oosize + osize - 1 + i * osize] = 0.0;
col_data[8 * oosize + osize - 1 + i * osize] = 0.0;
}
col_data[0 * oosize + osize + 1] = im_data[0];
col_data[3 * oosize + 1] = im_data[0];
col_data[6 * oosize + 1] = im_data[osize];
col_data[1 * oosize + osize] = im_data[0];
col_data[4 * oosize] = im_data[0];
col_data[7 * oosize] = im_data[osize];
float32x4_t zero4;
zero4 = vdupq_n_f32(0.0);
auto col_z0 = col_data;
auto col_z1 = col_data + oosize;
auto col_z2 = col_data + 2 * oosize;
auto col_z6 = col_data + 6 * oosize + osize * (osize - 1);
auto col_z7 = col_data + 7 * oosize + osize * (osize - 1);
auto col_z8 = col_data + 8 * oosize + osize * (osize - 1);
for (int i = 0; i < nk4; ++i) {
vst1q_f32(col_z0, zero4);
vst1q_f32(col_z1, zero4);
vst1q_f32(col_z2, zero4);
vst1q_f32(col_z6, zero4);
vst1q_f32(col_z7, zero4);
vst1q_f32(col_z8, zero4);
col_z0 += 4;
col_z1 += 4;
col_z2 += 4;
col_z6 += 4;
col_z7 += 4;
col_z8 += 4;
}
for (int i = 0; i < mk4; ++i) {
col_z0[i] = 0.0;
col_z1[i] = 0.0;
col_z2[i] = 0.0;
col_z6[i] = 0.0;
col_z7[i] = 0.0;
col_z8[i] = 0.0;
}
col_data += 9 * oosize;
im_data += isize * isize;
}
} else if (stride[0] == 2 && filter_height == 3 && pad1 && dilation[0] == 1 &&
im_height > 2 && im_height == im_width) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
int mk4 = osize % 4;
// 3 2 3 1 0 1 3 2 3
float *col0 = col_data + 0 * oosize + osize + 1;
float *col1 = col_data + 1 * oosize + osize;
float *col2 = col_data + 2 * oosize + osize;
float *col3 = col_data + 3 * oosize + 1;
float *col4 = col_data + 4 * oosize;
float *col5 = col_data + 5 * oosize;
float *col6 = col_data + 6 * oosize + 1;
float *col7 = col_data + 7 * oosize;
float *col8 = col_data + 8 * oosize;
float32x4x2_t im01;
float32x4x2_t im23;
const float *im_tmp_data0 = im_data;
const float *im_tmp_data2 = im_data + isize;
for (int j = 0; j < osize; ++j) {
for (int i = 0; i < nk4; ++i) {
im01 = vld2q_f32(im_tmp_data0);
im23 = vld2q_f32(im_tmp_data2);
vst1q_f32(col0, im23.val[1]);
vst1q_f32(col1, im23.val[0]);
vst1q_f32(col2, im23.val[1]);
vst1q_f32(col3, im01.val[1]);
vst1q_f32(col4, im01.val[0]);
vst1q_f32(col5, im01.val[1]);
vst1q_f32(col6, im23.val[1]);
vst1q_f32(col7, im23.val[0]);
vst1q_f32(col8, im23.val[1]);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
col6 += 4;
col7 += 4;
col8 += 4;
im_tmp_data0 += 8;
im_tmp_data2 += 8;
}
const float *im_tmp_data1 = im_tmp_data0 + 1;
const float *im_tmp_data3 = im_tmp_data2 + 1;
for (int i = 0; i < mk4; ++i) {
*col0 = *im_tmp_data3;
*col1 = *im_tmp_data2;
*col2 = *im_tmp_data3;
*col3 = *im_tmp_data1;
*col4 = *im_tmp_data0;
*col5 = *im_tmp_data1;
*col6 = *im_tmp_data3;
*col7 = *im_tmp_data2;
*col8 = *im_tmp_data3;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
col6++;
col7++;
col8++;
im_tmp_data0 += 2;
im_tmp_data1 += 2;
im_tmp_data2 += 2;
im_tmp_data3 += 2;
}
im_tmp_data0 += (isize - fill);
im_tmp_data2 += (isize - fill);
}
for (int i = 0; i < osize; ++i) {
col_data[0 * oosize + i * osize] = 0.0;
col_data[3 * oosize + i * osize] = 0.0;
col_data[6 * oosize + i * osize] = 0.0;
if (pad2) {
col_data[2 * oosize + osize - 1 + i * osize] = 0.0;
col_data[5 * oosize + osize - 1 + i * osize] = 0.0;
col_data[8 * oosize + osize - 1 + i * osize] = 0.0;
}
}
float32x4_t zero4;
zero4 = vdupq_n_f32(0.0);
auto col_z0 = col_data;
auto col_z1 = col_data + oosize;
auto col_z2 = col_data + 2 * oosize;
auto col_z6 = col_data + 6 * oosize + osize * (osize - 1);
auto col_z7 = col_data + 7 * oosize + osize * (osize - 1);
auto col_z8 = col_data + 8 * oosize + osize * (osize - 1);
for (int i = 0; i < nk4; ++i) {
vst1q_f32(col_z0, zero4);
vst1q_f32(col_z1, zero4);
vst1q_f32(col_z2, zero4);
if (pad2) {
vst1q_f32(col_z6, zero4);
vst1q_f32(col_z7, zero4);
vst1q_f32(col_z8, zero4);
}
col_z0 += 4;
col_z1 += 4;
col_z2 += 4;
col_z6 += 4;
col_z7 += 4;
col_z8 += 4;
}
for (int i = 0; i < mk4; ++i) {
col_z0[i] = 0.0;
col_z1[i] = 0.0;
col_z2[i] = 0.0;
if (pad2) {
col_z6[i] = 0.0;
col_z7[i] = 0.0;
col_z8[i] = 0.0;
}
}
col_data[1 * oosize + osize] = im_data[isize];
for (int i = 1; i < osize; ++i) {
col_data[3 * oosize + i] = im_data[(i - 1) * stride[0] + 1];
}
col_data[4 * oosize] = im_data[0];
col_data[7 * oosize] = im_data[isize];
col_data += 9 * oosize;
im_data += isize * isize;
}
  } else if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
  if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
    int im_spatial_size = im_height * im_width;
    int col_spatial_size = col_height * col_width;
    // pad 0
......
...@@ -441,10 +441,8 @@ class ConvParam : public OpParam {
  enum ExecMode {
    EXEC_INVALID = 0,
    EXEC_GEMM_FLOAT,
    EXEC_DEPTHWISE3x3S1P1_FLOAT,
    EXEC_DEPTHWISE3x3S2P0_FLOAT,
    EXEC_DEPTHWISE3x3S2P1_FLOAT,
    EXEC_DEPTHWISE3x3_FLOAT,
    EXEC_DEPTHWISE3x3S1_FLOAT,
    EXEC_DEPTHWISE3x3S2_FLOAT,
    EXEC_WINOGRAD3X3_FLOAT,
    EXEC_WINOGRAD5X5_FLOAT,
    EXEC_DEPTHWISE5x5_FLOAT,
......
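With the padded variants folded away, the kernel Init paths only need the stride to pick a depthwise 3x3 mode. A minimal sketch of that selection, assuming a helper like the one below; the real predicates live in conv_common.h and may differ:

// Illustrative only; InitBaseConvKernel's actual logic is in conv_common.h.
inline ConvParam<CPU>::ExecMode SelectDepthwise3x3Mode(int stride) {
  return stride == 1 ? ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT
                     : ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT;
}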