提交 3582b9a0 编写于 作者: H hjchen2

Add depthwise conv5x5 armv8 implementation

上级 ab199ae0
......@@ -73,13 +73,11 @@ void ConvAddBNReluKernel<CPU, float>::Compute(
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
......
......@@ -43,13 +43,11 @@ void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam<CPU> &param) {
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
......
......@@ -41,12 +41,10 @@ void ConvAddReluKernel<CPU, float>::Compute(
param.Paddings(), param.Output());
math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
......
......@@ -73,13 +73,11 @@ void ConvBNAddReluKernel<CPU, float>::Compute(
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
......
......@@ -72,13 +72,11 @@ void ConvBNReluKernel<CPU, float>::Compute(
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
......
......@@ -49,11 +49,9 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
} else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 2) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT;
#ifndef __aarch64__
} else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT;
#endif
} else if (conv3x3 && !depth3x3 &&
param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
......
......@@ -51,11 +51,9 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
break;
......
......@@ -164,6 +164,7 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) {
}
#ifndef __aarch64__
// int8 DepthwiseConv3x3
template <typename Itype, typename Otype>
inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
......@@ -188,6 +189,7 @@ inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
}
}
}
#endif // __aarch64__
template <typename Itype, typename Otype>
inline void DepthwiseConv5x5(const ConvParam<CPU> &param) {
......@@ -210,7 +212,6 @@ inline void DepthwiseConv5x5(const ConvParam<CPU> &param) {
GemmConv<Itype, Otype>(param);
}
}
#endif // __aarch64__
template <typename ParamType>
void ConvAddReluBasic(const ParamType &param) {
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(__ARM_NEON__) && !defined(__aarch64__)
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include "operators/math/depthwise_conv5x5.h"
#include <arm_neon.h>
......@@ -243,7 +243,224 @@ void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
}
// valid
// valid
// #if __aarch64__
#if 0
float32x4_t _q14, _q15;
for (int loop = 0; loop = output_w_tiles; ++loop) {
float32x4_t _q7 = vld1q_f32(input_ptr0);
float32x4_t _q8 = vld1q_f32(input_ptr0 + 4);
float32x4_t _q9 = vld1q_f32(input_ptr1);
float32x4_t _q10 = vld1q_f32(input_ptr1 + 4);
float32x4_t _q11 = vld1q_f32(input_ptr2);
float32x4_t _q12 = vld1q_f32(input_ptr2 + 4);
_q14 = vmulq_lane_f32(_q7, vget_low_f32(_ker[5]), 0);
float32x4_t _q13 = vextq_f32(_q7, _q8, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 0);
_q13 = vextq_f32(_q7, _q8, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 1);
_q13 = vextq_f32(_q7, _q8, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[0]), 0);
_q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[0]), 1);
_q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[5]), 1);
_q15 = vmulq_lane_f32(_q9, vget_low_f32(_ker[5]), 0);
_q13 = vextq_f32(_q9, _q10, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[0]), 0);
_q13 = vextq_f32(_q9, _q10, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 1);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[0]), 1);
_q13 = vextq_f32(_q9, _q10, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[1]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[0]), 0);
_q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[1]), 1);
_q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_ker[0]), 1);
_q14 = vmlaq_lane_f32(_q14, _q11, vget_high_f32(_ker[5]), 0);
_q15 = vmlaq_lane_f32(_q15, _q11, vget_low_f32(_ker[5]), 1);
_q13 = vextq_f32(_q11, _q12, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[1]), 0);
_q13 = vextq_f32(_q11, _q12, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 1);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[1]), 1);
_q13 = vextq_f32(_q11, _q12, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[2]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[1]), 0);
_q14 = vmlaq_lane_f32(_q14, _q12, vget_high_f32(_ker[2]), 1);
_q15 = vmlaq_lane_f32(_q15, _q12, vget_high_f32(_ker[1]), 1);
_q7 = vld1q_f32(input_ptr3);
_q8 = vld1q_f32(input_ptr3 + 4);
_q9 = vld1q_f32(input_ptr4);
_q10 = vld1q_f32(input_ptr4 + 4);
_q11 = vld1q_f32(input_ptr5);
_q12 = vld1q_f32(input_ptr5 + 4);
_q14 = vmlaq_lane_f32(_q14, _q7, vget_high_f32(_ker[5]), 1);
_q15 = vmlaq_lane_f32(_q15, _q7, vget_high_f32(_ker[5]), 0);
_q13 = vextq_f32(_q7, _q8, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[2]), 0);
_q13 = vextq_f32(_q7, _q8, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 1);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[2]), 1);
_q13 = vextq_f32(_q7, _q8, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[3]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[2]), 0);
_q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[3]), 1);
_q15 = vmlaq_lane_f32(_q15, _q8, vget_high_f32(_ker[2]), 1);
_q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[6]), 0);
_q15 = vmlaq_lane_f32(_q15, _q9, vget_high_f32(_ker[5]), 1);
_q13 = vextq_f32(_q9, _q10, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[3]), 0);
_q13 = vextq_f32(_q9, _q10, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 1);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[3]), 1);
_q13 = vextq_f32(_q9, _q10, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[4]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[3]), 0);
_q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[4]), 1);
_q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_ker[3]), 1);
_q15 = vmlaq_lane_f32(_q15, _q11, vget_low_f32(_ker[6]), 0);
_q13 = vextq_f32(_q11, _q12, 1);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[4]), 0);
_q13 = vextq_f32(_q11, _q12, 2);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[4]), 1);
_q13 = vextq_f32(_q11, _q12, 3);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[4]), 0);
_q15 = vmlaq_lane_f32(_q15, _q12, vget_high_f32(_ker[4]), 1);
vst1q_f32(output_ptr0, _q14);
vst1q_f32(output_ptr1, _q15);
input_ptr0 += 4;
input_ptr1 += 4;
input_ptr2 += 4;
input_ptr3 += 4;
input_ptr4 += 4;
input_ptr5 += 4;
output_ptr0 += 4;
output_ptr1 += 4;
}
// remain w
if (output_w_remain > 0) {
float32x4_t _q7 = vld1q_f32(input_ptr0);
float32x4_t _q8 = vld1q_f32(input_ptr0 + 4);
float32x4_t _q9 = vld1q_f32(input_ptr1);
float32x4_t _q10 = vld1q_f32(input_ptr1 + 4);
float32x4_t _q11 = vld1q_f32(input_ptr2);
float32x4_t _q12 = vld1q_f32(input_ptr2 + 4);
_q14 = vmulq_lane_f32(_q7, vget_low_f32(_ker[5]), 0);
float32x4_t _q13 = vextq_f32(_q7, _q8, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 0);
_q13 = vextq_f32(_q7, _q8, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 1);
_q13 = vextq_f32(_q7, _q8, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[0]), 0);
_q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[0]), 1);
_q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[5]), 1);
_q15 = vmulq_lane_f32(_q9, vget_low_f32(_ker[5]), 0);
_q13 = vextq_f32(_q9, _q10, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[0]), 0);
_q13 = vextq_f32(_q9, _q10, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 1);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[0]), 1);
_q13 = vextq_f32(_q9, _q10, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[1]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[0]), 0);
_q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[1]), 1);
_q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_ker[0]), 1);
_q14 = vmlaq_lane_f32(_q14, _q11, vget_high_f32(_ker[5]), 0);
_q15 = vmlaq_lane_f32(_q15, _q11, vget_low_f32(_ker[5]), 1);
_q13 = vextq_f32(_q11, _q12, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[1]), 0);
_q13 = vextq_f32(_q11, _q12, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 1);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[1]), 1);
_q13 = vextq_f32(_q11, _q12, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[2]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[1]), 0);
_q14 = vmlaq_lane_f32(_q14, _q12, vget_high_f32(_ker[2]), 1);
_q15 = vmlaq_lane_f32(_q15, _q12, vget_high_f32(_ker[1]), 1);
_q7 = vld1q_f32(input_ptr3);
_q8 = vld1q_f32(input_ptr3 + 4);
_q9 = vld1q_f32(input_ptr4);
_q10 = vld1q_f32(input_ptr4 + 4);
_q11 = vld1q_f32(input_ptr5);
_q12 = vld1q_f32(input_ptr5 + 4);
_q14 = vmlaq_lane_f32(_q14, _q7, vget_high_f32(_ker[5]), 1);
_q15 = vmlaq_lane_f32(_q15, _q7, vget_high_f32(_ker[5]), 0);
_q13 = vextq_f32(_q7, _q8, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[2]), 0);
_q13 = vextq_f32(_q7, _q8, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 1);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[2]), 1);
_q13 = vextq_f32(_q7, _q8, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[3]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[2]), 0);
_q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[3]), 1);
_q15 = vmlaq_lane_f32(_q15, _q8, vget_high_f32(_ker[2]), 1);
_q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[6]), 0);
_q15 = vmlaq_lane_f32(_q15, _q9, vget_high_f32(_ker[5]), 1);
_q13 = vextq_f32(_q9, _q10, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[3]), 0);
_q13 = vextq_f32(_q9, _q10, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 1);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[3]), 1);
_q13 = vextq_f32(_q9, _q10, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[4]), 0);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[3]), 0);
_q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[4]), 1);
_q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_ker[3]), 1);
_q15 = vmlaq_lane_f32(_q15, _q11, vget_low_f32(_ker[6]), 0);
_q13 = vextq_f32(_q11, _q12, 1);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[4]), 0);
_q13 = vextq_f32(_q11, _q12, 2);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[4]), 1);
_q13 = vextq_f32(_q11, _q12, 3);
_q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[4]), 0);
_q15 = vmlaq_lane_f32(_q15, _q12, vget_high_f32(_ker[4]), 1);
switch (output_w_remain) {
case 3:
vst1q_lane_f32(output_ptr0 + 2, _q14, 2);
vst1q_lane_f32(output_ptr1 + 2, _q15, 2);
case 2:
vst1_f32(output_ptr0, vget_low_f32(_q14));
vst1_f32(output_ptr1, vget_low_f32(_q15));
break;
case 1:
vst1q_lane_f32(output_ptr0, _q14, 0);
vst1q_lane_f32(output_ptr1, _q15, 0);
break;
}
input_ptr0 += output_w_remain;
input_ptr1 += output_w_remain;
input_ptr2 += output_w_remain;
input_ptr3 += output_w_remain;
input_ptr4 += output_w_remain;
input_ptr5 += output_w_remain;
output_ptr0 += output_w_remain;
output_ptr1 += output_w_remain;
}
#else
int loop = output_w_tiles;
asm volatile(
"cmp %[loop], #0 \n"
......@@ -443,6 +660,7 @@ void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
[kr4] "w"(_ker[4]), [ker0] "w"(_ker[5]), [ker1] "w"(_ker[6])
: "cc", "memory", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
"q15", "r0");
#endif // __aarch64__
// pad right
if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
......@@ -540,7 +758,155 @@ void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
}
output_ptr0 += valid_w_start;
}
// valid
// valid
// #if __aarch64__
#if 0
float32x4_t _q14;
for (int loop = 0; loop = output_w_tiles; ++loop) {
float32x4_t _q7 = vld1q_f32(input_ptr0);
float32x4_t _q8 = vld1q_f32(input_ptr0 + 4);
float32x4_t _q9 = vld1q_f32(input_ptr1);
float32x4_t _q10 = vld1q_f32(input_ptr1 + 4);
float32x4_t _q11 = vld1q_f32(input_ptr2);
float32x4_t _q12 = vld1q_f32(input_ptr2 + 4);
_q14 = vmulq_lane_f32(_q7, vget_low_f32(_ker[5]), 0);
float32x4_t _q13 = vextq_f32(_q7, _q8, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 0);
_q13 = vextq_f32(_q7, _q8, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 1);
_q13 = vextq_f32(_q7, _q8, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[0]), 0);
_q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[0]), 1);
_q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[5]), 1);
_q13 = vextq_f32(_q9, _q10, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 0);
_q13 = vextq_f32(_q9, _q10, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 1);
_q13 = vextq_f32(_q9, _q10, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[1]), 0);
_q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[1]), 1);
_q14 = vmlaq_lane_f32(_q14, _q11, vget_high_f32(_ker[5]), 0);
_q13 = vextq_f32(_q11, _q12, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 0);
_q13 = vextq_f32(_q11, _q12, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 1);
_q13 = vextq_f32(_q11, _q12, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[2]), 0);
_q14 = vmlaq_lane_f32(_q14, _q12, vget_high_f32(_ker[2]), 1);
_q7 = vld1q_f32(input_ptr3);
_q8 = vld1q_f32(input_ptr3 + 4);
_q9 = vld1q_f32(input_ptr4);
_q10 = vld1q_f32(input_ptr4 + 4);
_q14 = vmlaq_lane_f32(_q14, _q7, vget_high_f32(_ker[5]), 1);
_q13 = vextq_f32(_q7, _q8, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 0);
_q13 = vextq_f32(_q7, _q8, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 1);
_q13 = vextq_f32(_q7, _q8, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[3]), 0);
_q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[3]), 1);
_q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[6]), 0);
_q13 = vextq_f32(_q9, _q10, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 0);
_q13 = vextq_f32(_q9, _q10, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 1);
_q13 = vextq_f32(_q9, _q10, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[4]), 0);
_q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[4]), 1);
vst1q_f32(output_ptr0, _q14);
input_ptr0 += 4;
input_ptr1 += 4;
input_ptr2 += 4;
input_ptr3 += 4;
input_ptr4 += 4;
output_ptr0 += 4;
}
// remain w
if (output_w_remain > 0) {
float32x4_t _q7 = vld1q_f32(input_ptr0);
float32x4_t _q8 = vld1q_f32(input_ptr0 + 4);
float32x4_t _q9 = vld1q_f32(input_ptr1);
float32x4_t _q10 = vld1q_f32(input_ptr1 + 4);
float32x4_t _q11 = vld1q_f32(input_ptr2);
float32x4_t _q12 = vld1q_f32(input_ptr2 + 4);
_q14 = vmulq_lane_f32(_q7, vget_low_f32(_ker[5]), 0);
float32x4_t _q13 = vextq_f32(_q7, _q8, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 0);
_q13 = vextq_f32(_q7, _q8, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 1);
_q13 = vextq_f32(_q7, _q8, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[0]), 0);
_q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[0]), 1);
_q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[5]), 1);
_q13 = vextq_f32(_q9, _q10, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 0);
_q13 = vextq_f32(_q9, _q10, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 1);
_q13 = vextq_f32(_q9, _q10, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[1]), 0);
_q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[1]), 1);
_q14 = vmlaq_lane_f32(_q14, _q11, vget_high_f32(_ker[5]), 0);
_q13 = vextq_f32(_q11, _q12, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 0);
_q13 = vextq_f32(_q11, _q12, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 1);
_q13 = vextq_f32(_q11, _q12, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[2]), 0);
_q14 = vmlaq_lane_f32(_q14, _q12, vget_high_f32(_ker[2]), 1);
_q7 = vld1q_f32(input_ptr3);
_q8 = vld1q_f32(input_ptr3 + 4);
_q9 = vld1q_f32(input_ptr4);
_q10 = vld1q_f32(input_ptr4 + 4);
_q14 = vmlaq_lane_f32(_q14, _q7, vget_high_f32(_ker[5]), 1);
_q13 = vextq_f32(_q7, _q8, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 0);
_q13 = vextq_f32(_q7, _q8, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 1);
_q13 = vextq_f32(_q7, _q8, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[3]), 0);
_q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[3]), 1);
_q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[6]), 0);
_q13 = vextq_f32(_q9, _q10, 1);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 0);
_q13 = vextq_f32(_q9, _q10, 2);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 1);
_q13 = vextq_f32(_q9, _q10, 3);
_q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[4]), 0);
_q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[4]), 1);
switch (output_w_remain) {
case 3:
vst1q_lane_f32(output_ptr0 + 2, _q14, 2);
case 2:
vst1_f32(output_ptr0, vget_low_f32(_q14));
break;
case 1:
vst1q_lane_f32(output_ptr0, _q14, 0);
break;
}
input_ptr0 += output_w_remain;
input_ptr1 += output_w_remain;
input_ptr2 += output_w_remain;
input_ptr3 += output_w_remain;
input_ptr4 += output_w_remain;
output_ptr0 += output_w_remain;
}
#else
int loop = output_w_tiles;
asm volatile(
"cmp %[loop], #0 \n"
......@@ -676,6 +1042,7 @@ void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
[kr4] "w"(_ker[4]), [ker0] "w"(_ker[5]), [ker1] "w"(_ker[6])
: "cc", "memory", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
"q15", "r0");
#endif // __aarch64__
// pad right
if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
......
......@@ -544,12 +544,12 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
for (int l = 0; l < 2; ++l) {
float32x4x2_t _q23, _q45, _q67, _q89;
_q23.val[0] = vld1q_f32(ptr0);
_q23.val[0] = vld1q_f32(ptr0 + 4);
_q45.val[1] = vld1q_f32(ptr0 + 8);
_q23.val[1] = vld1q_f32(ptr0 + 4);
_q45.val[0] = vld1q_f32(ptr0 + 8);
_q45.val[1] = vld1q_f32(ptr0 + 12);
_q67.val[0] = vld1q_f32(ptr1);
_q67.val[0] = vld1q_f32(ptr1 + 4);
_q89.val[1] = vld1q_f32(ptr1 + 8);
_q67.val[1] = vld1q_f32(ptr1 + 4);
_q89.val[0] = vld1q_f32(ptr1 + 8);
_q89.val[1] = vld1q_f32(ptr1 + 12);
_q23 = vtrnq_f32(_q23.val[0], _q23.val[1]);
_q45 = vtrnq_f32(_q45.val[0], _q45.val[1]);
......@@ -1167,12 +1167,12 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
float32x4_t _q0 = vld1q_f32(transform_matrix);
float32x4x2_t _q23, _q45, _q67, _q89;
_q23.val[0] = vld1q_f32(at_m_ptr0);
_q23.val[0] = vld1q_f32(at_m_ptr0 + 4);
_q45.val[1] = vld1q_f32(at_m_ptr0 + 8);
_q23.val[1] = vld1q_f32(at_m_ptr0 + 4);
_q45.val[0] = vld1q_f32(at_m_ptr0 + 8);
_q45.val[1] = vld1q_f32(at_m_ptr0 + 12);
_q67.val[0] = vld1q_f32(at_m_ptr1);
_q67.val[0] = vld1q_f32(at_m_ptr1 + 4);
_q89.val[1] = vld1q_f32(at_m_ptr1 + 8);
_q67.val[1] = vld1q_f32(at_m_ptr1 + 4);
_q89.val[0] = vld1q_f32(at_m_ptr1 + 8);
_q89.val[1] = vld1q_f32(at_m_ptr1 + 12);
_q23 = vtrnq_f32(_q23.val[0], _q23.val[1]);
_q45 = vtrnq_f32(_q45.val[0], _q45.val[1]);
......@@ -1231,6 +1231,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
_q1 = vaddq_f32(_q9, _q7);
_q1 = vmlaq_lane_f32(_q1, _q10, vget_high_f32(_q0), 1);
_q2 = vaddq_f32(_q12, _q8);
_q2 = vaddq_f32(_q2, _q14);
_q2 = vmlaq_lane_f32(_q2, _q6, vget_high_f32(_q0), 1);
_q23 = vtrnq_f32(_q1, _q2);
vst1_f32(out_ptr0 + 4, vget_low_f32(_q23.val[0]));
......@@ -1287,8 +1288,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
_d20 = vadd_f32(_d20, _d15);
_d20 = vmla_lane_f32(_d20, _d16, vget_high_f32(_q0), 1);
_d18d20 = vtrn_f32(_d18, _d20);
vst1_f32(out_ptr4 + 4, _d18);
vst1_f32(out_ptr5 + 4, _d20);
vst1_f32(out_ptr4 + 4, _d18d20.val[0]);
vst1_f32(out_ptr5 + 4, _d18d20.val[1]);
#else
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
......@@ -1428,12 +1429,12 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
float32x4_t _q0 = vld1q_f32(transform_matrix);
float32x4x2_t _q23, _q45, _q67, _q89;
_q23.val[0] = vld1q_f32(at_m_ptr0);
_q23.val[0] = vld1q_f32(at_m_ptr0 + 4);
_q45.val[1] = vld1q_f32(at_m_ptr0 + 8);
_q23.val[1] = vld1q_f32(at_m_ptr0 + 4);
_q45.val[0] = vld1q_f32(at_m_ptr0 + 8);
_q45.val[1] = vld1q_f32(at_m_ptr0 + 12);
_q67.val[0] = vld1q_f32(at_m_ptr1);
_q67.val[0] = vld1q_f32(at_m_ptr1 + 4);
_q89.val[1] = vld1q_f32(at_m_ptr1 + 8);
_q67.val[1] = vld1q_f32(at_m_ptr1 + 4);
_q89.val[0] = vld1q_f32(at_m_ptr1 + 8);
_q89.val[1] = vld1q_f32(at_m_ptr1 + 12);
_q23 = vtrnq_f32(_q23.val[0], _q23.val[1]);
_q45 = vtrnq_f32(_q45.val[0], _q45.val[1]);
......@@ -1489,6 +1490,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
_q1 = vaddq_f32(_q9, _q7);
_q1 = vmlaq_lane_f32(_q1, _q10, vget_high_f32(_q0), 1);
_q2 = vaddq_f32(_q12, _q8);
_q2 = vaddq_f32(_q2, _q14);
_q2 = vmlaq_lane_f32(_q2, _q6, vget_high_f32(_q0), 1);
_q23 = vtrnq_f32(_q1, _q2);
vst1_f32(out_ptr0 + 4, vget_low_f32(_q23.val[0]));
......@@ -1545,8 +1547,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
_d20 = vadd_f32(_d20, _d15);
_d20 = vmla_lane_f32(_d20, _d16, vget_high_f32(_q0), 1);
_d18d20 = vtrn_f32(_d18, _d20);
vst1_f32(out_ptr4 + 4, _d18);
vst1_f32(out_ptr5 + 4, _d20);
vst1_f32(out_ptr4 + 4, _d18d20.val[0]);
vst1_f32(out_ptr5 + 4, _d18d20.val[1]);
#else
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册