提交 ab199ae0 编写于 作者: H hjchen2

Add winograd implementation for arm64

上级 a8b775ec
......@@ -79,12 +79,12 @@ void ConvAddBNReluKernel<CPU, float>::Compute(
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvBNReluBasic<FusionConvAddBNReluParam<CPU>>(param);
break;
......
......@@ -11,11 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADD_OP
#include "operators/kernel/conv_add_kernel.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_add_arm_func.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
namespace paddle_mobile {
namespace operators {
......@@ -47,12 +49,12 @@ void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam<CPU> &param) {
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvAddBasic(param);
break;
......
......@@ -46,11 +46,11 @@ void ConvAddReluKernel<CPU, float>::Compute(
DepthwiseConv5x5<float, float>(param);
math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvAddReluBasic<FusionConvAddReluParam<CPU>>(param);
break;
......
......@@ -79,12 +79,12 @@ void ConvBNAddReluKernel<CPU, float>::Compute(
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvBNReluBasic<FusionConvBNAddReluParam<CPU>>(param);
break;
......
......@@ -78,12 +78,12 @@ void ConvBNReluKernel<CPU, float>::Compute(
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvBNReluBasic<FusionConvBNReluParam<CPU>>(param);
break;
......
......@@ -53,6 +53,7 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
} else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT;
#endif
} else if (conv3x3 && !depth3x3 &&
param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
......@@ -68,7 +69,6 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
param->transformed_filter_ = new framework::LoDTensor;
operators::math::winograd_transform_weight<8, 3>(
*param->Filter(), param->transformed_filter_);
#endif
} else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
}
......
......@@ -55,10 +55,10 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
GemmConv<float, float>(param);
break;
......
......@@ -15,10 +15,10 @@ limitations under the License. */
// Inspired by https://arxiv.org/abs/1509.09308 and refered from nnpack and ncnn
// project.
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#ifdef CONV_OP
#ifndef __aarch64__
#include <arm_neon.h>
#include "operators/math/pad.h"
#include "operators/math/winograd/winograd_transform.h"
......@@ -51,6 +51,10 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
const float transform_matrix[8] = {2.f, -2.f / 9, 1.f / 90, 1.f / 180};
const float *inptr = weight.data<float>();
#if __aarch64__
int remain_start = 0;
#else
int remain_start = out_channel & 0xFFFC;
#pragma omp parallel for
......@@ -256,6 +260,7 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
"q13", "r0");
}
}
#endif // __aarch64__
// remain output channel
#pragma omp parallel for
......@@ -358,6 +363,90 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
const float *in2 = in1 + width;
const float *in3 = in2 + width;
float *d_bt_ptr = d_bt;
#if __aarch64__
int steps = 4 * width;
float32x4_t _q0 = vld1q_f32(transform_matrix);
float32x4_t _q1 = vld1q_f32(transform_matrix + 4);
for (int l = 0; l < 2; ++l) {
float32x4x2_t _q23, _q45, _q67, _q89;
_q23.val[0] = vld1q_f32(in0);
_q45.val[0] = vld1q_f32(in0 + 4);
_q23.val[1] = vld1q_f32(in1);
_q45.val[1] = vld1q_f32(in1 + 4);
_q67.val[0] = vld1q_f32(in2);
_q89.val[0] = vld1q_f32(in2 + 4);
_q67.val[1] = vld1q_f32(in3);
_q89.val[1] = vld1q_f32(in3 + 4);
_q23 = vtrnq_f32(_q23.val[0], _q23.val[1]);
_q45 = vtrnq_f32(_q45.val[0], _q45.val[1]);
_q67 = vtrnq_f32(_q67.val[0], _q67.val[1]);
_q89 = vtrnq_f32(_q89.val[0], _q89.val[1]);
float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[0]),
vget_low_f32(_q67.val[0]));
float32x4_t _q4 = vcombine_f32(vget_low_f32(_q23.val[1]),
vget_low_f32(_q67.val[1]));
float32x4_t _q3 = vcombine_f32(vget_low_f32(_q45.val[0]),
vget_low_f32(_q89.val[0]));
float32x4_t _q5 = vcombine_f32(vget_low_f32(_q45.val[1]),
vget_low_f32(_q89.val[1]));
float32x4_t _q6 = vcombine_f32(vget_high_f32(_q23.val[0]),
vget_high_f32(_q67.val[0]));
float32x4_t _q8 = vcombine_f32(vget_high_f32(_q23.val[1]),
vget_high_f32(_q67.val[1]));
float32x4_t _q7 = vcombine_f32(vget_high_f32(_q45.val[0]),
vget_high_f32(_q89.val[0]));
float32x4_t _q9 = vcombine_f32(vget_high_f32(_q45.val[1]),
vget_high_f32(_q89.val[1]));
float32x4_t _q10 = vsubq_f32(_q2, _q7);
float32x4_t _q11 = vsubq_f32(_q3, _q6);
_q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0);
vst1q_f32(d_bt_ptr, _q10);
_q10 = vaddq_f32(_q6, _q7);
_q11 = vaddq_f32(_q4, _q5);
_q10 = vmlaq_lane_f32(_q10, _q3, vget_high_f32(_q0), 0);
_q11 = vmlaq_lane_f32(_q11, _q8, vget_high_f32(_q0), 0);
float32x4_t _q12 = vaddq_f32(_q10, _q11);
float32x4_t _q13 = vsubq_f32(_q10, _q11);
vst1q_f32(d_bt_ptr + 4, _q12);
vst1q_f32(d_bt_ptr + 8, _q13);
_q10 = vmulq_lane_f32(_q6, vget_high_f32(_q1), 1);
_q11 = vmulq_lane_f32(_q4, vget_high_f32(_q1), 0);
_q10 = vaddq_f32(_q10, _q7);
_q11 = vmlaq_lane_f32(_q11, _q5, vget_low_f32(_q1), 0);
_q10 = vmlaq_lane_f32(_q10, _q3, vget_low_f32(_q1), 1);
_q11 = vmlaq_lane_f32(_q11, _q8, vget_high_f32(_q0), 1);
_q12 = vaddq_f32(_q10, _q11);
_q13 = vsubq_f32(_q10, _q11);
vst1q_f32(d_bt_ptr + 12, _q12);
vst1q_f32(d_bt_ptr + 16, _q13);
_q10 = vmulq_lane_f32(_q6, vget_low_f32(_q1), 0);
_q11 = vmulq_lane_f32(_q4, vget_low_f32(_q1), 0);
_q10 = vmlaq_lane_f32(_q10, _q3, vget_high_f32(_q0), 1);
_q11 = vmlaq_lane_f32(_q11, _q8, vget_high_f32(_q0), 1);
_q10 = vmlaq_lane_f32(_q10, _q7, vget_high_f32(_q1), 0);
_q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q1), 0);
_q10 = vmulq_lane_f32(_q10, vget_low_f32(_q1), 0);
_q12 = vaddq_f32(_q10, _q11);
_q13 = vsubq_f32(_q10, _q11);
vst1q_f32(d_bt_ptr + 20, _q12);
vst1q_f32(d_bt_ptr + 24, _q13);
_q10 = vsubq_f32(_q9, _q4);
_q11 = vsubq_f32(_q8, _q5);
_q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0);
vst1q_f32(d_bt_ptr + 28, _q10);
in0 += steps;
in1 += steps;
in2 += steps;
in3 += steps;
d_bt_ptr += 32;
}
#else
int steps = 4 * width * sizeof(float);
asm volatile(
"vld1.32 {d0-d3}, [%[tm_ptr]] \n"
......@@ -434,7 +523,7 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
#endif // __aarch64__
float *ptr0 = d_bt;
float *ptr1 = ptr0 + 32;
int tile_indics = h * w_tiles + w;
......@@ -450,6 +539,120 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
float *out5 = out4 + channel * 8;
float *out6 = out5 + channel * 8;
float *out7 = out6 + channel * 8;
#if __aarch64__
steps = 8 * channel * 8;
for (int l = 0; l < 2; ++l) {
float32x4x2_t _q23, _q45, _q67, _q89;
_q23.val[0] = vld1q_f32(ptr0);
_q23.val[0] = vld1q_f32(ptr0 + 4);
_q45.val[1] = vld1q_f32(ptr0 + 8);
_q45.val[1] = vld1q_f32(ptr0 + 12);
_q67.val[0] = vld1q_f32(ptr1);
_q67.val[0] = vld1q_f32(ptr1 + 4);
_q89.val[1] = vld1q_f32(ptr1 + 8);
_q89.val[1] = vld1q_f32(ptr1 + 12);
_q23 = vtrnq_f32(_q23.val[0], _q23.val[1]);
_q45 = vtrnq_f32(_q45.val[0], _q45.val[1]);
_q67 = vtrnq_f32(_q67.val[0], _q67.val[1]);
_q89 = vtrnq_f32(_q89.val[0], _q89.val[1]);
float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[0]),
vget_low_f32(_q45.val[0]));
float32x4_t _q4 = vcombine_f32(vget_high_f32(_q23.val[0]),
vget_high_f32(_q45.val[0]));
float32x4_t _q3 = vcombine_f32(vget_low_f32(_q23.val[1]),
vget_low_f32(_q45.val[1]));
float32x4_t _q5 = vcombine_f32(vget_high_f32(_q23.val[1]),
vget_high_f32(_q45.val[1]));
float32x4_t _q6 = vcombine_f32(vget_low_f32(_q67.val[0]),
vget_low_f32(_q89.val[0]));
float32x4_t _q8 = vcombine_f32(vget_high_f32(_q67.val[0]),
vget_high_f32(_q89.val[0]));
float32x4_t _q7 = vcombine_f32(vget_low_f32(_q67.val[1]),
vget_low_f32(_q89.val[1]));
float32x4_t _q9 = vcombine_f32(vget_high_f32(_q67.val[1]),
vget_high_f32(_q89.val[1]));
float32x4_t _q10 = vsubq_f32(_q2, _q8);
float32x4_t _q11 = vsubq_f32(_q6, _q4);
_q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0);
vst1q_lane_f32(out0, _q10, 0);
vst1q_lane_f32(out0 + steps, _q10, 1);
vst1q_lane_f32(out0 + 2 * steps, _q10, 2);
vst1q_lane_f32(out0 + 3 * steps, _q10, 3);
_q10 = vaddq_f32(_q4, _q8);
_q11 = vaddq_f32(_q3, _q7);
_q10 = vmlaq_lane_f32(_q10, _q6, vget_high_f32(_q0), 0);
_q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q0), 0);
float32x4_t _q12 = vaddq_f32(_q10, _q11);
vst1q_lane_f32(out1, _q12, 0);
vst1q_lane_f32(out1 + steps, _q12, 1);
vst1q_lane_f32(out1 + 2 * steps, _q12, 2);
vst1q_lane_f32(out1 + 3 * steps, _q12, 3);
_q12 = vsubq_f32(_q10, _q11);
vst1q_lane_f32(out2, _q12, 0);
vst1q_lane_f32(out2 + steps, _q12, 1);
vst1q_lane_f32(out2 + 2 * steps, _q12, 2);
vst1q_lane_f32(out2 + 3 * steps, _q12, 3);
_q10 = vmulq_lane_f32(_q4, vget_high_f32(_q1), 1);
_q11 = vmulq_lane_f32(_q3, vget_high_f32(_q1), 0);
_q10 = vaddq_f32(_q10, _q8);
_q11 = vmlaq_lane_f32(_q11, _q7, vget_low_f32(_q1), 0);
_q10 = vmlaq_lane_f32(_q10, _q6, vget_low_f32(_q1), 1);
_q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q0), 1);
_q12 = vaddq_f32(_q10, _q11);
vst1q_lane_f32(out3, _q12, 0);
vst1q_lane_f32(out3 + steps, _q12, 1);
vst1q_lane_f32(out3 + 2 * steps, _q12, 2);
vst1q_lane_f32(out3 + 3 * steps, _q12, 3);
_q12 = vsubq_f32(_q10, _q11);
vst1q_lane_f32(out4, _q12, 0);
vst1q_lane_f32(out4 + steps, _q12, 1);
vst1q_lane_f32(out4 + 2 * steps, _q12, 2);
vst1q_lane_f32(out4 + 3 * steps, _q12, 3);
_q10 = vmulq_lane_f32(_q4, vget_low_f32(_q1), 0);
_q11 = vmulq_lane_f32(_q3, vget_low_f32(_q1), 0);
_q10 = vmlaq_lane_f32(_q10, _q6, vget_high_f32(_q0), 1);
_q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q0), 1);
_q10 = vmlaq_lane_f32(_q10, _q8, vget_high_f32(_q1), 0);
_q11 = vmlaq_lane_f32(_q11, _q7, vget_high_f32(_q1), 0);
_q10 = vmulq_lane_f32(_q10, vget_low_f32(_q1), 0);
_q12 = vaddq_f32(_q10, _q11);
vst1q_lane_f32(out5, _q12, 0);
vst1q_lane_f32(out5 + steps, _q12, 1);
vst1q_lane_f32(out5 + 2 * steps, _q12, 2);
vst1q_lane_f32(out5 + 3 * steps, _q12, 3);
_q12 = vsubq_f32(_q10, _q11);
vst1q_lane_f32(out6, _q12, 0);
vst1q_lane_f32(out6 + steps, _q12, 1);
vst1q_lane_f32(out6 + 2 * steps, _q12, 2);
vst1q_lane_f32(out6 + 3 * steps, _q12, 3);
_q10 = vsubq_f32(_q9, _q3);
_q11 = vsubq_f32(_q5, _q7);
_q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0);
vst1q_lane_f32(out7, _q10, 0);
vst1q_lane_f32(out7 + steps, _q10, 1);
vst1q_lane_f32(out7 + 2 * steps, _q10, 2);
vst1q_lane_f32(out7 + 3 * steps, _q10, 3);
ptr0 += 16;
ptr1 += 16;
out0 += 4 * steps;
out1 += 4 * steps;
out2 += 4 * steps;
out3 += 4 * steps;
out4 += 4 * steps;
out5 += 4 * steps;
out6 += 4 * steps;
out7 += 4 * steps;
}
#else
steps = 8 * channel * 8 * sizeof(float);
asm volatile(
"mov r0, #2 \n"
......@@ -555,6 +758,7 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
#endif // __aarch64__
}
}
}
......@@ -587,6 +791,71 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
const float *in_ptr = input_ptr + (j * 64 + k) * in_channel * 8;
int inter_channel = in_channel >> 1;
int remain_channel = in_channel & 0x1;
#if __aarch64__
asm volatile(
"dup v8.4s, wzr \n"
"dup v9.4s, wzr \n"
"dup v10.4s, wzr \n"
"dup v11.4s, wzr \n"
"dup v12.4s, wzr \n"
"dup v13.4s, wzr \n"
"dup v14.4s, wzr \n"
"dup v15.4s, wzr \n"
"cmp %[inter], #0 \n"
"ble loop_1c_%= \n"
// loop 2 channels
"loop_2c_%=: \n"
"ld1 {v0.4s, v1.4s}, [%[w_ptr]], #32 \n"
"ld1 {v2.4s, v3.4s}, [%[in_ptr]], #32 \n"
"ld1 {v4.4s, v5.4s}, [%[in_ptr]], #32 \n"
"fmla v8.4s, v2.4s, v0.s[0] \n"
"fmla v9.4s, v3.4s, v0.s[0] \n"
"fmla v10.4s, v2.4s, v0.s[1] \n"
"fmla v11.4s, v3.4s, v0.s[1] \n"
"fmla v12.4s, v2.4s, v0.s[2] \n"
"fmla v13.4s, v3.4s, v0.s[2] \n"
"fmla v14.4s, v2.4s, v0.s[3] \n"
"fmla v15.4s, v3.4s, v0.s[3] \n"
"fmla v8.4s, v4.4s, v1.s[0] \n"
"fmla v9.4s, v5.4s, v1.s[0] \n"
"fmla v10.4s, v4.4s, v1.s[1] \n"
"fmla v11.4s, v5.4s, v1.s[1] \n"
"fmla v12.4s, v4.4s, v1.s[2] \n"
"fmla v13.4s, v5.4s, v1.s[2] \n"
"fmla v14.4s, v4.4s, v1.s[3] \n"
"fmla v15.4s, v5.4s, v1.s[3] \n"
"subs %[inter], %[inter], #1 \n"
"bne loop_2c_%= \n"
// loop 1 channel
"loop_1c_%=: \n"
"cmp %[remain], #0 \n"
"ble store_res_%= \n"
"ld1 {v0.4s, v1.4s}, [%[w_ptr]], #32 \n"
"ld1 {v2.4s, v3.4s}, [%[in_ptr]], #32 \n"
"fmla v8.4s, v2.4s, v0.s[0] \n"
"fmla v9.4s, v3.4s, v0.s[0] \n"
"fmla v10.4s, v2.4s, v0.s[1] \n"
"fmla v11.4s, v3.4s, v0.s[1] \n"
"fmla v12.4s, v2.4s, v0.s[2] \n"
"fmla v13.4s, v3.4s, v0.s[2] \n"
"fmla v14.4s, v2.4s, v0.s[3] \n"
"fmla v15.4s, v3.4s, v0.s[3] \n"
"store_res_%=: \n"
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[uv_ptr]], #64 \n"
"st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[uv_ptr]], #64 \n"
: [w_ptr] "+r"(w_ptr), [in_ptr] "+r"(in_ptr), [uv_ptr] "+r"(uv_ptr),
[inter] "+r"(inter_channel)
: [remain] "r"(remain_channel)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
"v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15");
#else
asm volatile(
"veor q8, q8, q8 \n"
"veor q9, q9, q9 \n"
......@@ -651,6 +920,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
: [remain_channel] "r"(remain_channel)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif // __aarch64__
}
}
}
......@@ -686,6 +956,116 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
int tile_block = tile_indics >> 3;
int block_indics = tile_indics & 0x7;
const float *uv_ptr0 = uv_ptr + tile_block * 64 * 32 + block_indics;
#if __aarch64__
float32x4_t _q0 = vld1q_f32(transform_matrix);
for (int l = 0; l < 2; ++l) {
float32x4_t _q1, _q2, _q3, _q4, _q5, _q6, _q7, _q8;
_q1 = vsetq_lane_f32(*uv_ptr0, _q1, 0);
uv_ptr0 += 32;
_q3 = vsetq_lane_f32(*uv_ptr0, _q3, 0);
uv_ptr0 += 32;
_q5 = vsetq_lane_f32(*uv_ptr0, _q5, 0);
uv_ptr0 += 32;
_q7 = vsetq_lane_f32(*uv_ptr0, _q7, 0);
uv_ptr0 += 32;
_q2 = vsetq_lane_f32(*uv_ptr0, _q2, 0);
uv_ptr0 += 32;
_q4 = vsetq_lane_f32(*uv_ptr0, _q4, 0);
uv_ptr0 += 32;
_q6 = vsetq_lane_f32(*uv_ptr0, _q6, 0);
uv_ptr0 += 32;
_q8 = vsetq_lane_f32(*uv_ptr0, _q8, 0);
uv_ptr0 += 32;
_q1 = vsetq_lane_f32(*uv_ptr0, _q1, 1);
uv_ptr0 += 32;
_q3 = vsetq_lane_f32(*uv_ptr0, _q3, 1);
uv_ptr0 += 32;
_q5 = vsetq_lane_f32(*uv_ptr0, _q5, 1);
uv_ptr0 += 32;
_q7 = vsetq_lane_f32(*uv_ptr0, _q7, 1);
uv_ptr0 += 32;
_q2 = vsetq_lane_f32(*uv_ptr0, _q2, 1);
uv_ptr0 += 32;
_q4 = vsetq_lane_f32(*uv_ptr0, _q4, 1);
uv_ptr0 += 32;
_q6 = vsetq_lane_f32(*uv_ptr0, _q6, 1);
uv_ptr0 += 32;
_q8 = vsetq_lane_f32(*uv_ptr0, _q8, 1);
uv_ptr0 += 32;
_q1 = vsetq_lane_f32(*uv_ptr0, _q1, 2);
uv_ptr0 += 32;
_q3 = vsetq_lane_f32(*uv_ptr0, _q3, 2);
uv_ptr0 += 32;
_q5 = vsetq_lane_f32(*uv_ptr0, _q5, 2);
uv_ptr0 += 32;
_q7 = vsetq_lane_f32(*uv_ptr0, _q7, 2);
uv_ptr0 += 32;
_q2 = vsetq_lane_f32(*uv_ptr0, _q2, 2);
uv_ptr0 += 32;
_q4 = vsetq_lane_f32(*uv_ptr0, _q4, 2);
uv_ptr0 += 32;
_q6 = vsetq_lane_f32(*uv_ptr0, _q6, 2);
uv_ptr0 += 32;
_q8 = vsetq_lane_f32(*uv_ptr0, _q8, 2);
uv_ptr0 += 32;
_q1 = vsetq_lane_f32(*uv_ptr0, _q1, 3);
uv_ptr0 += 32;
_q3 = vsetq_lane_f32(*uv_ptr0, _q3, 3);
uv_ptr0 += 32;
_q5 = vsetq_lane_f32(*uv_ptr0, _q5, 3);
uv_ptr0 += 32;
_q7 = vsetq_lane_f32(*uv_ptr0, _q7, 3);
uv_ptr0 += 32;
_q2 = vsetq_lane_f32(*uv_ptr0, _q2, 3);
uv_ptr0 += 32;
_q4 = vsetq_lane_f32(*uv_ptr0, _q4, 3);
uv_ptr0 += 32;
_q6 = vsetq_lane_f32(*uv_ptr0, _q6, 3);
uv_ptr0 += 32;
_q8 = vsetq_lane_f32(*uv_ptr0, _q8, 3);
uv_ptr0 += 32;
float32x4_t _q9 = vaddq_f32(_q3, _q5);
float32x4_t _q10 = vaddq_f32(_q7, _q2);
float32x4_t _q11 = vaddq_f32(_q4, _q6);
float32x4_t _q12 = vsubq_f32(_q3, _q5);
float32x4_t _q13 = vsubq_f32(_q7, _q2);
float32x4_t _q14 = vsubq_f32(_q4, _q6);
_q2 = vmulq_lane_f32(_q13, vget_low_f32(_q0), 0);
_q3 = vmulq_lane_f32(_q11, vget_low_f32(_q0), 0);
float32x4_t _q15 = vaddq_f32(_q1, _q9);
_q15 = vaddq_f32(_q15, _q10);
_q15 = vmlaq_lane_f32(_q15, _q3, vget_high_f32(_q0), 1);
vst1q_f32(at_m_ptr, _q15);
_q15 = vaddq_f32(_q12, _q2);
_q15 = vmlaq_lane_f32(_q15, _q14, vget_high_f32(_q0), 1);
vst1q_f32(at_m_ptr + 4, _q15);
_q15 = vmlaq_lane_f32(_q9, _q10, vget_low_f32(_q0), 1);
_q15 = vmlaq_lane_f32(_q15, _q11, vget_high_f32(_q0), 0);
vst1q_f32(at_m_ptr + 8, _q15);
_q15 = vmlaq_lane_f32(_q12, _q13, vget_high_f32(_q0), 0);
_q15 = vmlaq_lane_f32(_q15, _q14, vget_low_f32(_q0), 1);
vst1q_f32(at_m_ptr + 12, _q15);
_q15 = vaddq_f32(_q9, _q3);
_q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_q0), 1);
vst1q_f32(at_m_ptr + 16, _q15);
_q15 = vaddq_f32(_q12, _q8);
_q15 = vaddq_f32(_q15, _q14);
_q15 = vmlaq_lane_f32(_q15, _q2, vget_high_f32(_q0), 1);
vst1q_f32(at_m_ptr + 20, _q15);
at_m_ptr += 24;
}
#else
int steps = 32 * sizeof(float);
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
......@@ -771,6 +1151,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0");
#endif // __aarch64__
float *at_m_ptr0 = at_m;
float *at_m_ptr1 = at_m + 24;
......@@ -782,6 +1163,133 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
float *out_ptr3 = output_tmp + 18;
float *out_ptr4 = output_tmp + 24;
float *out_ptr5 = output_tmp + 30;
#if __aarch64__
float32x4_t _q0 = vld1q_f32(transform_matrix);
float32x4x2_t _q23, _q45, _q67, _q89;
_q23.val[0] = vld1q_f32(at_m_ptr0);
_q23.val[0] = vld1q_f32(at_m_ptr0 + 4);
_q45.val[1] = vld1q_f32(at_m_ptr0 + 8);
_q45.val[1] = vld1q_f32(at_m_ptr0 + 12);
_q67.val[0] = vld1q_f32(at_m_ptr1);
_q67.val[0] = vld1q_f32(at_m_ptr1 + 4);
_q89.val[1] = vld1q_f32(at_m_ptr1 + 8);
_q89.val[1] = vld1q_f32(at_m_ptr1 + 12);
_q23 = vtrnq_f32(_q23.val[0], _q23.val[1]);
_q45 = vtrnq_f32(_q45.val[0], _q45.val[1]);
_q67 = vtrnq_f32(_q67.val[0], _q67.val[1]);
_q89 = vtrnq_f32(_q89.val[0], _q89.val[1]);
float32x4_t _q1 = vcombine_f32(vget_low_f32(_q23.val[0]),
vget_low_f32(_q45.val[0]));
float32x4_t _q3 = vcombine_f32(vget_high_f32(_q23.val[0]),
vget_high_f32(_q45.val[0]));
float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[1]),
vget_low_f32(_q45.val[1]));
float32x4_t _q4 = vcombine_f32(vget_high_f32(_q23.val[1]),
vget_high_f32(_q45.val[1]));
float32x4_t _q5 = vcombine_f32(vget_low_f32(_q67.val[0]),
vget_low_f32(_q89.val[0]));
float32x4_t _q7 = vcombine_f32(vget_high_f32(_q67.val[0]),
vget_high_f32(_q89.val[0]));
float32x4_t _q6 = vcombine_f32(vget_low_f32(_q67.val[1]),
vget_low_f32(_q89.val[1]));
float32x4_t _q8 = vcombine_f32(vget_high_f32(_q67.val[1]),
vget_high_f32(_q89.val[1]));
float32x4_t _q9 = vaddq_f32(_q2, _q3);
float32x4_t _q10 = vaddq_f32(_q4, _q5);
float32x4_t _q11 = vaddq_f32(_q6, _q7);
float32x4_t _q12 = vsubq_f32(_q2, _q3);
float32x4_t _q13 = vsubq_f32(_q4, _q5);
float32x4_t _q14 = vsubq_f32(_q6, _q7);
_q6 = vmulq_lane_f32(_q13, vget_low_f32(_q0), 0);
_q7 = vmulq_lane_f32(_q11, vget_low_f32(_q0), 0);
_q1 = vaddq_f32(_q1, _q9);
_q1 = vaddq_f32(_q1, _q10);
_q1 = vmlaq_lane_f32(_q1, _q7, vget_high_f32(_q0), 1);
_q2 = vaddq_f32(_q12, _q6);
_q2 = vmlaq_lane_f32(_q2, _q14, vget_high_f32(_q0), 1);
_q3 = vmlaq_lane_f32(_q9, _q10, vget_low_f32(_q0), 1);
_q3 = vmlaq_lane_f32(_q3, _q11, vget_high_f32(_q0), 0);
_q4 = vmlaq_lane_f32(_q12, _q13, vget_high_f32(_q0), 0);
_q4 = vmlaq_lane_f32(_q4, _q14, vget_low_f32(_q0), 1);
_q23 = vtrnq_f32(_q1, _q2);
_q45 = vtrnq_f32(_q3, _q4);
vst1_f32(out_ptr0, vget_low_f32(_q23.val[0]));
vst1_f32(out_ptr0 + 2, vget_low_f32(_q45.val[0]));
vst1_f32(out_ptr1, vget_low_f32(_q23.val[1]));
vst1_f32(out_ptr1 + 2, vget_low_f32(_q45.val[1]));
vst1_f32(out_ptr2, vget_high_f32(_q23.val[0]));
vst1_f32(out_ptr2 + 2, vget_high_f32(_q45.val[0]));
vst1_f32(out_ptr3, vget_high_f32(_q23.val[1]));
vst1_f32(out_ptr3 + 2, vget_high_f32(_q45.val[1]));
_q1 = vaddq_f32(_q9, _q7);
_q1 = vmlaq_lane_f32(_q1, _q10, vget_high_f32(_q0), 1);
_q2 = vaddq_f32(_q12, _q8);
_q2 = vmlaq_lane_f32(_q2, _q6, vget_high_f32(_q0), 1);
_q23 = vtrnq_f32(_q1, _q2);
vst1_f32(out_ptr0 + 4, vget_low_f32(_q23.val[0]));
vst1_f32(out_ptr1 + 4, vget_low_f32(_q23.val[1]));
vst1_f32(out_ptr2 + 4, vget_high_f32(_q23.val[0]));
vst1_f32(out_ptr3 + 4, vget_high_f32(_q23.val[1]));
// remain 2 rows
_q1 = vld1q_f32(at_m_ptr0 + 16);
_q2 = vld1q_f32(at_m_ptr0 + 20);
_q3 = vld1q_f32(at_m_ptr1 + 16);
_q4 = vld1q_f32(at_m_ptr1 + 20);
_q23 = vtrnq_f32(_q1, _q2);
_q45 = vtrnq_f32(_q3, _q4);
float32x2_t _d2 = vget_low_f32(_q23.val[0]);
float32x2_t _d3 = vget_high_f32(_q23.val[0]);
float32x2_t _d4 = vget_low_f32(_q23.val[1]);
float32x2_t _d5 = vget_high_f32(_q23.val[1]);
float32x2_t _d6 = vget_low_f32(_q45.val[0]);
float32x2_t _d7 = vget_high_f32(_q45.val[0]);
float32x2_t _d8 = vget_low_f32(_q45.val[1]);
float32x2_t _d9 = vget_high_f32(_q45.val[1]);
float32x2_t _d10 = vadd_f32(_d4, _d3);
float32x2_t _d11 = vadd_f32(_d5, _d6);
float32x2_t _d12 = vadd_f32(_d8, _d7);
float32x2_t _d13 = vsub_f32(_d4, _d3);
float32x2_t _d14 = vsub_f32(_d5, _d6);
float32x2_t _d15 = vsub_f32(_d8, _d7);
float32x2_t _d16 = vmul_lane_f32(_d14, vget_low_f32(_q0), 0);
float32x2_t _d17 = vmul_lane_f32(_d12, vget_low_f32(_q0), 0);
float32x2_t _d18 = vadd_f32(_d2, _d10);
float32x2_t _d20 = vadd_f32(_d13, _d16);
float32x2_t _d19 = vmla_lane_f32(_d10, _d11, vget_low_f32(_q0), 1);
float32x2_t _d21 = vmla_lane_f32(_d13, _d14, vget_high_f32(_q0), 0);
_d18 = vadd_f32(_d18, _d11);
_d18 = vmla_lane_f32(_d18, _d17, vget_high_f32(_q0), 1);
_d20 = vmla_lane_f32(_d20, _d15, vget_high_f32(_q0), 1);
_d19 = vmla_lane_f32(_d19, _d12, vget_high_f32(_q0), 0);
_d21 = vmla_lane_f32(_d21, _d15, vget_low_f32(_q0), 1);
float32x2x2_t _d18d20 = vtrn_f32(_d18, _d20);
float32x2x2_t _d19d21 = vtrn_f32(_d19, _d21);
vst1_f32(out_ptr4, _d18d20.val[0]);
vst1_f32(out_ptr4 + 2, _d19d21.val[0]);
vst1_f32(out_ptr5, _d18d20.val[1]);
vst1_f32(out_ptr5 + 2, _d19d21.val[1]);
_d18 = vadd_f32(_d10, _d17);
_d18 = vmla_lane_f32(_d18, _d11, vget_high_f32(_q0), 1);
_d20 = vadd_f32(_d13, _d9);
_d20 = vadd_f32(_d20, _d15);
_d20 = vmla_lane_f32(_d20, _d16, vget_high_f32(_q0), 1);
_d18d20 = vtrn_f32(_d18, _d20);
vst1_f32(out_ptr4 + 4, _d18);
vst1_f32(out_ptr5 + 4, _d20);
#else
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
// process 4 rows
......@@ -898,6 +1406,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
: [tm_ptr] "r"((float *)transform_matrix)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif // __aarch64__
size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w;
float *out_ptr = output_ptr + offset;
int remain_row = out_h - 6 * tile_h;
......@@ -915,6 +1424,130 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
float *out_ptr3 = out_ptr2 + out_w;
float *out_ptr4 = out_ptr3 + out_w;
float *out_ptr5 = out_ptr4 + out_w;
#if __aarch64__
float32x4_t _q0 = vld1q_f32(transform_matrix);
float32x4x2_t _q23, _q45, _q67, _q89;
_q23.val[0] = vld1q_f32(at_m_ptr0);
_q23.val[0] = vld1q_f32(at_m_ptr0 + 4);
_q45.val[1] = vld1q_f32(at_m_ptr0 + 8);
_q45.val[1] = vld1q_f32(at_m_ptr0 + 12);
_q67.val[0] = vld1q_f32(at_m_ptr1);
_q67.val[0] = vld1q_f32(at_m_ptr1 + 4);
_q89.val[1] = vld1q_f32(at_m_ptr1 + 8);
_q89.val[1] = vld1q_f32(at_m_ptr1 + 12);
_q23 = vtrnq_f32(_q23.val[0], _q23.val[1]);
_q45 = vtrnq_f32(_q45.val[0], _q45.val[1]);
_q67 = vtrnq_f32(_q67.val[0], _q67.val[1]);
_q89 = vtrnq_f32(_q89.val[0], _q89.val[1]);
float32x4_t _q1 = vcombine_f32(vget_low_f32(_q23.val[0]),
vget_low_f32(_q45.val[0]));
float32x4_t _q3 = vcombine_f32(vget_high_f32(_q23.val[0]),
vget_high_f32(_q45.val[0]));
float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[1]),
vget_low_f32(_q45.val[1]));
float32x4_t _q4 = vcombine_f32(vget_high_f32(_q23.val[1]),
vget_high_f32(_q45.val[1]));
float32x4_t _q5 = vcombine_f32(vget_low_f32(_q67.val[0]),
vget_low_f32(_q89.val[0]));
float32x4_t _q7 = vcombine_f32(vget_high_f32(_q67.val[0]),
vget_high_f32(_q89.val[0]));
float32x4_t _q6 = vcombine_f32(vget_low_f32(_q67.val[1]),
vget_low_f32(_q89.val[1]));
float32x4_t _q8 = vcombine_f32(vget_high_f32(_q67.val[1]),
vget_high_f32(_q89.val[1]));
float32x4_t _q9 = vaddq_f32(_q2, _q3);
float32x4_t _q10 = vaddq_f32(_q4, _q5);
float32x4_t _q11 = vaddq_f32(_q6, _q7);
float32x4_t _q12 = vsubq_f32(_q2, _q3);
float32x4_t _q13 = vsubq_f32(_q4, _q5);
float32x4_t _q14 = vsubq_f32(_q6, _q7);
_q6 = vmulq_lane_f32(_q13, vget_low_f32(_q0), 0);
_q7 = vmulq_lane_f32(_q11, vget_low_f32(_q0), 0);
_q1 = vaddq_f32(_q1, _q9);
_q1 = vaddq_f32(_q1, _q10);
_q1 = vmlaq_lane_f32(_q1, _q7, vget_high_f32(_q0), 1);
_q2 = vaddq_f32(_q12, _q6);
_q2 = vmlaq_lane_f32(_q2, _q14, vget_high_f32(_q0), 1);
_q3 = vmlaq_lane_f32(_q9, _q10, vget_low_f32(_q0), 1);
_q3 = vmlaq_lane_f32(_q3, _q11, vget_high_f32(_q0), 0);
_q4 = vmlaq_lane_f32(_q12, _q13, vget_high_f32(_q0), 0);
_q4 = vmlaq_lane_f32(_q4, _q14, vget_low_f32(_q0), 1);
_q23 = vtrnq_f32(_q1, _q2);
_q45 = vtrnq_f32(_q3, _q4);
vst1_f32(out_ptr0, vget_low_f32(_q23.val[0]));
vst1_f32(out_ptr0 + 2, vget_low_f32(_q45.val[0]));
vst1_f32(out_ptr1, vget_low_f32(_q23.val[1]));
vst1_f32(out_ptr1 + 2, vget_low_f32(_q45.val[1]));
vst1_f32(out_ptr2, vget_high_f32(_q23.val[0]));
vst1_f32(out_ptr2 + 2, vget_high_f32(_q45.val[0]));
vst1_f32(out_ptr3, vget_high_f32(_q23.val[1]));
vst1_f32(out_ptr3 + 2, vget_high_f32(_q45.val[1]));
_q1 = vaddq_f32(_q9, _q7);
_q1 = vmlaq_lane_f32(_q1, _q10, vget_high_f32(_q0), 1);
_q2 = vaddq_f32(_q12, _q8);
_q2 = vmlaq_lane_f32(_q2, _q6, vget_high_f32(_q0), 1);
_q23 = vtrnq_f32(_q1, _q2);
vst1_f32(out_ptr0 + 4, vget_low_f32(_q23.val[0]));
vst1_f32(out_ptr1 + 4, vget_low_f32(_q23.val[1]));
vst1_f32(out_ptr2 + 4, vget_high_f32(_q23.val[0]));
vst1_f32(out_ptr3 + 4, vget_high_f32(_q23.val[1]));
// remain 2 rows
_q1 = vld1q_f32(at_m_ptr0 + 16);
_q2 = vld1q_f32(at_m_ptr0 + 20);
_q3 = vld1q_f32(at_m_ptr1 + 16);
_q4 = vld1q_f32(at_m_ptr1 + 20);
_q23 = vtrnq_f32(_q1, _q2);
_q45 = vtrnq_f32(_q3, _q4);
float32x2_t _d2 = vget_low_f32(_q23.val[0]);
float32x2_t _d3 = vget_high_f32(_q23.val[0]);
float32x2_t _d4 = vget_low_f32(_q23.val[1]);
float32x2_t _d5 = vget_high_f32(_q23.val[1]);
float32x2_t _d6 = vget_low_f32(_q45.val[0]);
float32x2_t _d7 = vget_high_f32(_q45.val[0]);
float32x2_t _d8 = vget_low_f32(_q45.val[1]);
float32x2_t _d9 = vget_high_f32(_q45.val[1]);
float32x2_t _d10 = vadd_f32(_d4, _d3);
float32x2_t _d11 = vadd_f32(_d5, _d6);
float32x2_t _d12 = vadd_f32(_d8, _d7);
float32x2_t _d13 = vsub_f32(_d4, _d3);
float32x2_t _d14 = vsub_f32(_d5, _d6);
float32x2_t _d15 = vsub_f32(_d8, _d7);
float32x2_t _d16 = vmul_lane_f32(_d14, vget_low_f32(_q0), 0);
float32x2_t _d17 = vmul_lane_f32(_d12, vget_low_f32(_q0), 0);
float32x2_t _d18 = vadd_f32(_d2, _d10);
float32x2_t _d20 = vadd_f32(_d13, _d16);
float32x2_t _d19 = vmla_lane_f32(_d10, _d11, vget_low_f32(_q0), 1);
float32x2_t _d21 = vmla_lane_f32(_d13, _d14, vget_high_f32(_q0), 0);
_d18 = vadd_f32(_d18, _d11);
_d18 = vmla_lane_f32(_d18, _d17, vget_high_f32(_q0), 1);
_d20 = vmla_lane_f32(_d20, _d15, vget_high_f32(_q0), 1);
_d19 = vmla_lane_f32(_d19, _d12, vget_high_f32(_q0), 0);
_d21 = vmla_lane_f32(_d21, _d15, vget_low_f32(_q0), 1);
float32x2x2_t _d18d20 = vtrn_f32(_d18, _d20);
float32x2x2_t _d19d21 = vtrn_f32(_d19, _d21);
vst1_f32(out_ptr4, _d18d20.val[0]);
vst1_f32(out_ptr4 + 2, _d19d21.val[0]);
vst1_f32(out_ptr5, _d18d20.val[1]);
vst1_f32(out_ptr5 + 2, _d19d21.val[1]);
_d18 = vadd_f32(_d10, _d17);
_d18 = vmla_lane_f32(_d18, _d11, vget_high_f32(_q0), 1);
_d20 = vadd_f32(_d13, _d9);
_d20 = vadd_f32(_d20, _d15);
_d20 = vmla_lane_f32(_d20, _d16, vget_high_f32(_q0), 1);
_d18d20 = vtrn_f32(_d18, _d20);
vst1_f32(out_ptr4 + 4, _d18);
vst1_f32(out_ptr5 + 4, _d20);
#else
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
// process 4 rows
......@@ -1031,6 +1664,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
: [tm_ptr] "r"((float *)transform_matrix)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif // __aarch64__
}
}
}
......@@ -1041,5 +1675,5 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
} // namespace operators
} // namespace paddle_mobile
#endif // __aarch64__
#endif // CONV_OP
#endif // __ARM_NEON__
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// We refer https://github.com/andravin/wincnn to access the winograd transform
// matrixs
#ifdef CONV_OP
#ifdef __aarch64__
#include "operators/math/winograd/winograd_transform.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <>
void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
                                     framework::Tensor *output) {
  // Transforms every 3x3 kernel into its 8x8 Winograd-domain form
  // (8 taps per dimension; output tile is 8 - 3 + 1 = 6), i.e. U = G*g*G^T,
  // stored flat as 64 floats per (oc, ic) kernel.
  // weight layout: [out_channel, in_channel, kernel_h(3), kernel_w(3)]
  const int out_channel = weight.dims()[0];
  const int in_channel = weight.dims()[1];
  // Reshape and allocate the transformed weight: [out_channel, in_channel, 64].
  framework::DDim transformed_shape =
      framework::make_ddim(std::vector<int>{out_channel, in_channel, 64});
  float *outptr = output->mutable_data<float>(transformed_shape);
  const float *inptr = weight.data<float>();
  // Applies the 8x3 Winograd G matrix to the triple (w0, w1, w2) and writes
  // the 8 resulting taps to dst[0..7]. Shared by the row and column passes.
  auto transform3 = [](float w0, float w1, float w2, float *dst) {
    float sum02 = w0 + w2;   // w0 + w2
    float acc14 = w0 + 4 * w2;  // w0 + 4*w2
    float acc41 = w2 + 4 * w0;  // 4*w0 + w2
    float twice = 2 * w1;       // 2*w1
    dst[0] = w0;
    dst[1] = -2.f / 9 * (sum02 + w1);     // -2/9 * (w0 + w1 + w2)
    dst[2] = -2.f / 9 * (sum02 - w1);     // -2/9 * (w0 - w1 + w2)
    dst[3] = 1.f / 90 * (acc14 + twice);  // 1/90 * (w0 + 2*w1 + 4*w2)
    dst[4] = 1.f / 90 * (acc14 - twice);  // 1/90 * (w0 - 2*w1 + 4*w2)
    dst[5] = 1.f / 180 * (acc41 + twice); // 1/180 * (4*w0 + 2*w1 + w2)
    dst[6] = 1.f / 180 * (acc41 - twice); // 1/180 * (4*w0 - 2*w1 + w2)
    dst[7] = w2;
  };
  for (int oc = 0; oc < out_channel; ++oc) {
    for (int ic = 0; ic < in_channel; ++ic) {
      size_t offset = oc * in_channel + ic;
      float *kout = outptr + offset * 64;
      const float *k = inptr + offset * 9;
      // Intermediate g*G^T: one transformed 8-tap vector per kernel row.
      float mid[3][8];
      for (int row = 0; row < 3; ++row) {
        transform3(k[3 * row], k[3 * row + 1], k[3 * row + 2], mid[row]);
      }
      // Second pass down each column of the intermediate completes G*g*G^T;
      // column `col` of the 8x8 result lands at kout[8*col .. 8*col+7].
      for (int col = 0; col < 8; ++col) {
        transform3(mid[0][col], mid[1][col], mid[2][col], kout + 8 * col);
      }
    }
  }
}
template <>
void winograd_transform_input<8, 3>(const framework::Tensor &input,
                                    framework::Tensor *output) {
  // Winograd F(6x6, 3x3) input transform: cut each channel into 8x8 patches
  // taken at stride 6 (adjacent patches overlap by 2 pixels), then replace
  // each patch d with B_T * d * B in place.
  // NOTE(review): only dims()[1..3] are read, so this presumably requires
  // the batch dimension dims()[0] to be 1 — confirm against callers.
  // tile input to [c, roundup(h/6), roundup(w/6), 64] and do transformation
  int channel = input.dims()[1];
  int height = input.dims()[2];
  int width = input.dims()[3];
  // number of tiles needed to cover the (height-2) x (width-2) valid region
  int h_tiles = (height + 3) / 6;  // (height + 5 - 2) / 6
  int w_tiles = (width + 3) / 6;   // (width + 5 - 2) / 6
  framework::DDim transformed_shape =
      framework::make_ddim(std::vector<int>{channel, h_tiles, w_tiles, 64});
  float *outptr = output->mutable_data<float>(transformed_shape);
  // zero-fill so partial tiles at the right/bottom borders are zero-padded
  memset(outptr, 0, channel * h_tiles * w_tiles * 64 * sizeof(float));
  const float *inptr = input.data<float>();
  // pack input to tiles
  for (int c = 0; c < channel; ++c) {
    // full tiles per row/column and the leftover pixels past them
    int inter_h = (height - 2) / 6;
    int inter_w = (width - 2) / 6;
    int remain_h = height - (inter_h * 6);
    int remain_w = width - (inter_w * 6);
    // in0..in7 walk 8 consecutive input rows — one tile height
    const float *in0 = inptr + c * height * width;
    const float *in1 = in0 + width;
    const float *in2 = in1 + width;
    const float *in3 = in2 + width;
    const float *in4 = in3 + width;
    const float *in5 = in4 + width;
    const float *in6 = in5 + width;
    const float *in7 = in6 + width;
    float *out = outptr + c * h_tiles * w_tiles * 64;
    for (int h = 0; h < inter_h; ++h) {
      for (int w = 0; w < inter_w; ++w) {
        // copy one full 8x8 patch, row by row, into the 64-float tile
        memcpy(out, in0, 8 * sizeof(float));
        memcpy(out + 8, in1, 8 * sizeof(float));
        memcpy(out + 16, in2, 8 * sizeof(float));
        memcpy(out + 24, in3, 8 * sizeof(float));
        memcpy(out + 32, in4, 8 * sizeof(float));
        memcpy(out + 40, in5, 8 * sizeof(float));
        memcpy(out + 48, in6, 8 * sizeof(float));
        memcpy(out + 56, in7, 8 * sizeof(float));
        // tile stride is 6, so neighboring tiles overlap by 2 columns
        in0 += 6;
        in1 += 6;
        in2 += 6;
        in3 += 6;
        in4 += 6;
        in5 += 6;
        in6 += 6;
        in7 += 6;
        out += 64;
      }
      // remain width: a partial tile exists only if it is wider than the
      // 2-column overlap, i.e. contributes at least one output column;
      // the uncopied tail of each tile row stays zero from the memset
      if (remain_w > 2) {
        memcpy(out, in0, remain_w * sizeof(float));
        memcpy(out + 8, in1, remain_w * sizeof(float));
        memcpy(out + 16, in2, remain_w * sizeof(float));
        memcpy(out + 24, in3, remain_w * sizeof(float));
        memcpy(out + 32, in4, remain_w * sizeof(float));
        memcpy(out + 40, in5, remain_w * sizeof(float));
        memcpy(out + 48, in6, remain_w * sizeof(float));
        memcpy(out + 56, in7, remain_w * sizeof(float));
        out += 64;
      }
      // move to the start of the next tile band: finish the current row
      // (remain_w columns left) plus 5 full rows = net advance of 6 rows
      in0 += 5 * width + remain_w;
      in1 += 5 * width + remain_w;
      in2 += 5 * width + remain_w;
      in3 += 5 * width + remain_w;
      in4 += 5 * width + remain_w;
      in5 += 5 * width + remain_w;
      in6 += 5 * width + remain_w;
      in7 += 5 * width + remain_w;
    }
    // remain height: partial tiles along the bottom edge; rows are addressed
    // as in0 + rh * width, so only the first cursor is needed here
    if (remain_h > 2) {
      for (int w = 0; w < inter_w; ++w) {
        for (int rh = 0; rh < remain_h; ++rh) {
          memcpy(out + rh * 8, in0 + rh * width, 8 * sizeof(float));
        }
        out += 64;
        in0 += 6;
      }
      // remain width: bottom-right corner tile, cropped in both directions
      if (remain_w > 2) {
        for (int rh = 0; rh < remain_h; ++rh) {
          memcpy(out + rh * 8, in0 + rh * width, remain_w * sizeof(float));
        }
      }
    }
  }
  // transform tiles, compute B_T * d(c, b) * B
  // (coefficients 5.25, 4.25, ... come from the F(6,3) B_T matrix, wincnn)
  for (int c = 0; c < channel; ++c) {
    for (int tile = 0; tile < h_tiles * w_tiles; ++tile) {
      float *out = outptr + (c * h_tiles * w_tiles + tile) * 64;
      // compute B_T * d(c, b); bd holds the intermediate 8x8 result
      float bd[8][8];
      for (int i = 0; i < 8; ++i) {
        float d0 = out[8 * i + 0];
        float d1 = out[8 * i + 1];
        float d2 = out[8 * i + 2];
        float d3 = out[8 * i + 3];
        float d4 = out[8 * i + 4];
        float d5 = out[8 * i + 5];
        float d6 = out[8 * i + 6];
        float d7 = out[8 * i + 7];
        bd[i][0] = d0 - d6 + (d4 - d2) * 5.25;
        // v1/v2 are shared even/odd sub-expressions of the next two rows
        float v1 = d2 - 4.25 * d4 + d6;
        float v2 = d1 - 4.25 * d3 + d5;
        // d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6
        bd[i][1] = v1 + v2;
        // -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6
        bd[i][2] = v1 - v2;
        v1 = 0.25 * d2 - 1.25 * d4 + d6;
        v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5;
        // 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6
        bd[i][3] = v1 + v2;
        // -0.5 * d1 + 0.25 * d2 + 2.5 * d3 - 1.25 * d4 - 2 * d5 + d6
        bd[i][4] = v1 - v2;
        v1 = 4 * d2 - 5 * d4 + d6;
        v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5;
        // 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6
        bd[i][5] = v1 + v2;
        // -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6
        bd[i][6] = v1 - v2;
        bd[i][7] = d7 - d1 + (d3 - d5) * 5.25;
      }
      // compute B_T * d(c, b) * B: apply the same 1-D transform to the
      // columns of bd, writing the final tile back over the packed input
      for (int i = 0; i < 8; ++i, out += 8) {
        float d0 = bd[0][i];
        float d1 = bd[1][i];
        float d2 = bd[2][i];
        float d3 = bd[3][i];
        float d4 = bd[4][i];
        float d5 = bd[5][i];
        float d6 = bd[6][i];
        float d7 = bd[7][i];
        out[0] = d0 - d6 + (d4 - d2) * 5.25;
        float v1 = d2 - 4.25 * d4 + d6;
        float v2 = d1 - 4.25 * d3 + d5;
        // d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6
        out[1] = v1 + v2;
        // -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6
        out[2] = v1 - v2;
        v1 = 0.25 * d2 - 1.25 * d4 + d6;
        v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5;
        // 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6
        out[3] = v1 + v2;
        // -0.5 * d1 + 0.25 * d2 + 2.5 * d3 - 1.25 * d4 - 2 * d5 + d6
        out[4] = v1 - v2;
        v1 = 4 * d2 - 5 * d4 + d6;
        v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5;
        // 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6
        out[5] = v1 + v2;
        // -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6
        out[6] = v1 - v2;
        out[7] = d7 - d1 + (d3 - d5) * 5.25;
      }
    }
  }
}
template <>
void winograd_transform_output<8, 3>(const framework::Tensor &input,
                                     const framework::Tensor &weight,
                                     framework::Tensor *output) {
  // Final Winograd stage: accumulate the element-wise product of transformed
  // input and transformed weight over input channels, inverse-transform each
  // 8x8 tile with A_T * M * A, and scatter the resulting 6x6 tiles into the
  // output feature map, cropping at the right/bottom borders.
  // NOTE(review): only output->dims()[2..3] are read, so this presumably
  // assumes batch size 1 — confirm against callers.
  // input shape is [in_channel, h_tiles, w_tiles, 64]
  // weight shape is [out_channel, in_channel, 64]
  int in_channel = input.dims()[0];
  int h_tiles = input.dims()[1];
  int w_tiles = input.dims()[2];
  int tiles = h_tiles * w_tiles;
  int out_channel = weight.dims()[0];
  // compute U*V first: M(oc, tile) = sum over ic of the element-wise
  // product of the two 64-float Winograd-domain tiles
  framework::Tensor output_m;
  framework::DDim shape =
      framework::make_ddim(std::vector<int>{out_channel, tiles, 64});
  float *output_m_ptr = output_m.mutable_data<float>(shape);
  memset(output_m_ptr, 0, output_m.numel() * sizeof(float));
  const float *input_ptr = input.data<float>();
  const float *weight_ptr = weight.data<float>();
  for (int i = 0; i < out_channel; ++i) {
    for (int j = 0; j < tiles; ++j) {
      const float *w_ptr = weight_ptr + i * in_channel * 64;
      const float *in_ptr = input_ptr + j * 64;
      float *m_ptr = output_m_ptr + (i * tiles + j) * 64;
      for (int c = 0; c < in_channel; ++c) {
        for (int k = 0; k < 64; ++k) {
          m_ptr[k] += w_ptr[k] * in_ptr[k];
        }
        w_ptr += 64;
        in_ptr += tiles * 64;  // same tile index, next input channel
      }
    }
  }
  // inverse-transform every tile in place: m <- A_T * m * A. Only the first
  // 6 values of the first 6 rows are meaningful afterwards (row stride stays
  // 8; m[6]/m[7] of each row are left stale and never copied out).
  for (int oc = 0; oc < out_channel; ++oc) {
    for (int tile = 0; tile < tiles; ++tile) {
      float *m = output_m_ptr + (oc * tiles + tile) * 64;
      // compute A_T * m; am holds the intermediate 6x8 result
      // (A_T coefficients come from the F(6,3) construction, see wincnn)
      float am[6][8];
      for (int i = 0; i < 8; ++i) {
        float d0 = m[i * 8 + 0];
        float d1 = m[i * 8 + 1];
        float d2 = m[i * 8 + 2];
        float d3 = m[i * 8 + 3];
        float d4 = m[i * 8 + 4];
        float d5 = m[i * 8 + 5];
        float d6 = m[i * 8 + 6];
        float d7 = m[i * 8 + 7];
        // pairwise sums/differences shared by all six output rows
        float v0 = d1 + d2;
        float v1 = d1 - d2;
        float v2 = d3 + d4;
        float v3 = d3 - d4;
        float v4 = d5 + d6;
        float v5 = d5 - d6;
        am[0][i] = d0 + v0 + v2 + 32 * v4;
        am[1][i] = v1 + 2 * v3 + 16 * v5;
        am[2][i] = v0 + 4 * v2 + 8 * v4;
        am[3][i] = v1 + 8 * v3 + 4 * v5;
        am[4][i] = v0 + 16 * v2 + 2 * v4;
        am[5][i] = v1 + 32 * v3 + v5 + d7;
      }
      // compute A_T * m * A: same 1-D reduction applied along each row of am
      for (int i = 0; i < 6; ++i, m += 8) {
        float d0 = am[i][0];
        float d1 = am[i][1];
        float d2 = am[i][2];
        float d3 = am[i][3];
        float d4 = am[i][4];
        float d5 = am[i][5];
        float d6 = am[i][6];
        float d7 = am[i][7];
        float v0 = d1 + d2;
        float v1 = d1 - d2;
        float v2 = d3 + d4;
        float v3 = d3 - d4;
        float v4 = d5 + d6;
        float v5 = d5 - d6;
        m[0] = d0 + v0 + v2 + 32 * v4;
        m[1] = v1 + 2 * v3 + 16 * v5;
        m[2] = v0 + 4 * v2 + 8 * v4;
        m[3] = v1 + 8 * v3 + 4 * v5;
        m[4] = v0 + 16 * v2 + 2 * v4;
        m[5] = v1 + 32 * v3 + v5 + d7;
      }
    }
  }
  int out_h = output->dims()[2];
  int out_w = output->dims()[3];
  float *output_ptr = output->mutable_data<float>();
  // copy valid region to final output
  for (int oc = 0; oc < out_channel; ++oc) {
    // full 6x6 tiles per row/column and the cropped leftovers at the edges
    int inter_h = out_h / 6;
    int inter_w = out_w / 6;
    int remain_h = out_h - inter_h * 6;
    int remain_w = out_w - inter_w * 6;
    // out_ptr0..5 walk 6 consecutive output rows — one tile height
    float *out_ptr0 = output_ptr + oc * out_h * out_w;
    float *out_ptr1 = out_ptr0 + out_w;
    float *out_ptr2 = out_ptr1 + out_w;
    float *out_ptr3 = out_ptr2 + out_w;
    float *out_ptr4 = out_ptr3 + out_w;
    float *out_ptr5 = out_ptr4 + out_w;
    const float *m_ptr = output_m_ptr + oc * tiles * 64;
    for (int tile_h = 0; tile_h < inter_h; ++tile_h) {
      for (int tile_w = 0; tile_w < inter_w; ++tile_w) {
        // tile rows are stored with stride 8; copy 6 values from each
        const float *m = m_ptr + (tile_h * w_tiles + tile_w) * 64;
        memcpy(out_ptr0, m, 6 * sizeof(float));
        memcpy(out_ptr1, m + 8, 6 * sizeof(float));
        memcpy(out_ptr2, m + 16, 6 * sizeof(float));
        memcpy(out_ptr3, m + 24, 6 * sizeof(float));
        memcpy(out_ptr4, m + 32, 6 * sizeof(float));
        memcpy(out_ptr5, m + 40, 6 * sizeof(float));
        out_ptr0 += 6;
        out_ptr1 += 6;
        out_ptr2 += 6;
        out_ptr3 += 6;
        out_ptr4 += 6;
        out_ptr5 += 6;
      }
      // remain w: rightmost tile of this band, cropped to remain_w columns
      if (remain_w > 0) {
        const float *m = m_ptr + (tile_h * w_tiles + inter_w) * 64;
        memcpy(out_ptr0, m, remain_w * sizeof(float));
        memcpy(out_ptr1, m + 8, remain_w * sizeof(float));
        memcpy(out_ptr2, m + 16, remain_w * sizeof(float));
        memcpy(out_ptr3, m + 24, remain_w * sizeof(float));
        memcpy(out_ptr4, m + 32, remain_w * sizeof(float));
        memcpy(out_ptr5, m + 40, remain_w * sizeof(float));
        out_ptr0 += remain_w;
        out_ptr1 += remain_w;
        out_ptr2 += remain_w;
        out_ptr3 += remain_w;
        out_ptr4 += remain_w;
        out_ptr5 += remain_w;
      }
      // cursors are at the end of their rows; skip 5 more rows so each one
      // lands at the start of the next 6-row band
      out_ptr0 += 5 * out_w;
      out_ptr1 += 5 * out_w;
      out_ptr2 += 5 * out_w;
      out_ptr3 += 5 * out_w;
      out_ptr4 += 5 * out_w;
      out_ptr5 += 5 * out_w;
    }
    // remain h: bottom band, cropped to remain_h rows; rows are addressed
    // via out_ptr0 + rh * out_w, so only the first cursor is needed
    if (remain_h > 0) {
      for (int tile_w = 0; tile_w < inter_w; ++tile_w) {
        const float *m = m_ptr + (inter_h * w_tiles + tile_w) * 64;
        for (int rh = 0; rh < remain_h; ++rh) {
          memcpy(out_ptr0 + rh * out_w, m + rh * 8, 6 * sizeof(float));
        }
        out_ptr0 += 6;
      }
      // bottom-right corner tile, cropped in both directions
      if (remain_w > 0) {
        const float *m = m_ptr + (inter_h * w_tiles + inter_w) * 64;
        for (int rh = 0; rh < remain_h; ++rh) {
          memcpy(out_ptr0 + rh * out_w, m + rh * 8, remain_w * sizeof(float));
        }
      }
    }
  }
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif // __aarch64__
#endif // CONV_OP
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册