提交 68e8dc4a 编写于 作者: H hjchen2

Support winograd algo to speed up 3x3 convlution operator

上级 4af571c4
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "operators/math/math_function.h"
#include "operators/math/pad.h"
#include "operators/math/vol2col.h"
#include "operators/math/winograd/winograd.h"
#include "operators/op_param.h"
namespace paddle_mobile {
......@@ -116,6 +117,34 @@ inline void ConvBasic(const ConvParam<CPU> &param) {
}
}
inline void BatchConv3x3Winograd(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor *filter = param.Filter();
Tensor *output = param.Output();
output->mutable_data<float>();
int batch_size = input->dims()[0];
int groups = param.Groups();
const std::vector<int> &paddings = param.Paddings();
math::PadFunctor<CPU, float> pad;
Tensor input_pad;
for (int i = 0; i < batch_size; ++i) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
if (paddings[0] == 0 && paddings[1] == 0) {
input_pad = in_batch;
} else {
framework::DDim pad_shape = in_batch.dims();
pad_shape[2] += 2 * paddings[0];
pad_shape[3] += 2 * paddings[1];
input_pad.mutable_data<float>(pad_shape);
pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1],
&input_pad);
}
math::winograd_f6k3(input_pad, *filter, &out_batch);
}
}
template <typename P>
void ConvCompute(const ConvParam<CPU> &param) {
if (param.Input()->type() == typeid(int8_t)) {
......@@ -133,6 +162,12 @@ void ConvCompute(const ConvParam<CPU> &param) {
param.Filter()->dims()[2] == 3) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
} else if (param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Strides()[0] == param.Strides()[1] &&
param.Dilations()[0] == param.Dilations()[1] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Dilations()[0] == 1 && param.Input()->dims()[1] > 16) {
BatchConv3x3Winograd(param);
} else {
ConvBasic<float, float>(param);
}
......
......@@ -249,7 +249,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
#if __ARM_NEON
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
float *output_data = output->mutable_data<float>();
const float *bias_data;
if (if_bias) {
bias_data = bias->data<float>();
......
......@@ -21,10 +21,12 @@ namespace math {
template <typename T>
class PadFunctor<CPU, T> {
public:
void operator()(const framework::Tensor &input, const int pad_h,
const int pad_w, framework::Tensor *output) {
void operator()(const framework::Tensor &input, const int pad_top,
const int pad_bottom, const int pad_left, const int pad_right,
framework::Tensor *output) {
const T *in_data = input.data<T>();
T *out_data = output->mutable_data<T>();
// should check output shape is valid for such pad parameters
const framework::DDim &input_shape = input.dims();
const framework::DDim &output_shape = output->dims();
// fill output with 0
......@@ -32,13 +34,13 @@ class PadFunctor<CPU, T> {
// should make sure the shape of output is match with input
for (int i = 0; i < input_shape[0]; ++i) {
for (int c = 0; c < input_shape[1]; ++c) {
out_data += pad_h * output_shape[3];
out_data += pad_top * output_shape[3];
for (int h = 0; h < input_shape[2]; ++h) {
memcpy(out_data + pad_w, in_data, sizeof(T) * input_shape[3]);
memcpy(out_data + pad_left, in_data, sizeof(T) * input_shape[3]);
out_data += output_shape[3];
in_data += input_shape[3];
}
out_data += pad_h * output_shape[3];
out_data += pad_bottom * output_shape[3];
}
}
}
......
......@@ -22,8 +22,9 @@ namespace math {
template <typename DeviceType, typename T>
class PadFunctor {
public:
void operator()(const framework::Tensor &input, const int pad_h,
const int pad_w, framework::Tensor *output);
void operator()(const framework::Tensor &input, const int pad_top,
const int pad_bottom, const int pad_left, const int pad_right,
framework::Tensor *output);
};
} // namespace math
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONV_OP
#include "operators/math/winograd/winograd.h"
#include "operators/math/winograd/winograd_transform.h"
namespace paddle_mobile {
namespace operators {
namespace math {
// F(2X2, 3X3)
void winograd_f2k3(const framework::Tensor &input,
const framework::Tensor &weight, framework::Tensor *output) {
}
// F(6X6, 3X3)
void winograd_f6k3(const framework::Tensor &input,
const framework::Tensor &weight, framework::Tensor *output) {
framework::Tensor transformed_input;
framework::Tensor transformed_weight;
// transform weight
winograd_transform_weight<8, 3>(weight, &transformed_weight);
// tile input and transform
winograd_transform_input<8, 3>(input, &transformed_input);
// caculate output
winograd_transform_output<8, 3>(transformed_input, transformed_weight,
output);
}
// F(4X4, 5X5)
void winograd_f4k5(const framework::Tensor &input,
const framework::Tensor &weight, framework::Tensor *output) {
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONV_OP
#pragma once
#include "framework/tensor.h"
namespace paddle_mobile {
namespace operators {
namespace math {
// F(2X2, 3X3)
void winograd_f2k3(const framework::Tensor &input,
const framework::Tensor &weight, framework::Tensor *output);
// F(6X6, 3X3)
void winograd_f6k3(const framework::Tensor &input,
const framework::Tensor &weight, framework::Tensor *output);
// F(4X4, 5X5)
void winograd_f4k5(const framework::Tensor &input,
const framework::Tensor &weight, framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONV_OP
#pragma once
#include "framework/tensor.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <int tile, int kernel>
void winograd_transform_weight(const framework::Tensor &weight,
framework::Tensor *output);
template <int tile, int kernel>
void winograd_transform_input(const framework::Tensor &input,
framework::Tensor *output);
template <int tile, int kernel>
void winograd_transform_output(const framework::Tensor &input,
const framework::Tensor &weight,
framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// refered from https://arxiv.org/abs/1509.09308 and
// https://github.com/andravin/wincnn
#pragma once
#include <arm_neon.h>
#include "operators/math/pad.h"
#include "operators/math/winograd/winograd_transform.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <>
void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
framework::Tensor *output) {
// weight shape is [out_channel, in_channel, kernel_h, kernel_w]
int out_channel = weight.dims()[0];
int in_channel = weight.dims()[1];
// reshape and alloc transformed weight
// framework::DDim transformed_shape = framework::make_ddim(
// std::vector<int>{out_channel, (in_channel + 3) / 4, 8, 32});
framework::DDim transformed_shape = framework::make_ddim(
std::vector<int>{(out_channel + 3) / 4, 64, in_channel, 4});
float *outptr = output->mutable_data<float>(transformed_shape);
memset(outptr, 0, output->numel() * sizeof(float));
/*
* w0 = g0
* w1 = ((g0 + g2) + g1) * (-2.0 / 9)
* w2 = ((g0 + g2) - g1) * (-2.0 / 9)
* w3 = ((g0 + 4 * g2) + 2 * g1) * (1.0 / 90)
* w4 = ((g0 + 4 * g2) - 2 * g1) * (1.0 / 90)
* w5 = ((g2 + 4 * g0) + 2 * g1) * (1.0 / 180)
* w6 = ((g2 + 4 * g0) - 2 * g1) * (1.0 / 180)
* w7 = g2
*/
/*
const float *inptr = weight.data<float>();
for (int oc = 0; oc < out_channel; ++oc) {
for (int ic = 0; ic < in_channel; ++ic) {
size_t offset = oc * in_channel + ic;
float *kout = outptr + offset * 64;
const float *k = inptr + offset * 9;
float gw[8][3];
for (int i = 0; i < 3; ++i) {
float k0 = k[i];
float k1 = k[3 + i];
float k2 = k[6 + i];
float d0 = k0 + k2;
float d1 = k0 + 4 * k2;
float d2 = 4 * k0 + k2;
float d3 = 2 * k1;
gw[0][i] = k0;
gw[1][i] = -2.f/9 * (d0 + k1); // -2.f/9 * (k0 + k1 + k2)
gw[2][i] = -2.f/9 * (d0 - k1); // -2.f/9 * (k0 - k1 + k2)
gw[3][i] = 1.f/90 * (d1 + d3); // 1.f/90 * (k0 + 2 * k1 + 4 * k2)
gw[4][i] = 1.f/90 * (d1 - d3); // 1.f/90 * (k0 - 2 * k1 + 4 * k2)
gw[5][i] = 8.f/45 * (d2 + d3); // 8.f/45 * (4 * k0 + 2 * k1 + k2)
gw[6][i] = 8.f/45 * (d2 - d3); // 8.f/45 * (4 * k0 - 2 * k1 + k2)
gw[7][i] = k2;
}
for (int i = 0; i < 8; ++i, kout += 8) {
float k0 = gw[i][0];
float k1 = gw[i][1];
float k2 = gw[i][2];
float d0 = k0 + k2;
float d1 = k0 + 4 * k2;
float d2 = 4 * k0 + k2;
float d3 = 2 * k1;
kout[0] = gw[i][0];
kout[1] = -2.f/9 * (d0 + k1); // -2.f/9 * (k0 + k1 + k2)
kout[2] = -2.f/9 * (d0 - k1); // -2.f/9 * (k0 - k1 + k2)
kout[3] = 1.f/90 * (d1 + d3); // 1.f/90 * (k0 + 2 * k1 + 4 * k2)
kout[4] = 1.f/90 * (d1 - d3); // 1.f/90 * (k0 - 2 * k1 + 4 * k2)
kout[5] = 8.f/45 * (d2 + d3); // 8.f/45 * (4 * k0 + 2 * k1 + k2)
kout[6] = 8.f/45 * (d2 - d3); // 8.f/45 * (4 * k0 - 2 * k1 + k2)
kout[7] = gw[i][2];
}
}
}
*/
}
template <>
void winograd_transform_input<8, 3>(const framework::Tensor &input,
framework::Tensor *output) {
// pack input to [8 * roundup(h/6), 8 * roundup(w/6), channel] tiles
int channel = input.dims()[1];
int height = input.dims()[2];
int width = input.dims()[3];
int h_tiles = (height + 3) / 6; // (height + 5 - 2) / 6
int w_tiles = (width + 3) / 6; // (width + 5 - 2) / 6
int tiles = (h_tiles * w_tiles + 7) / 8;
framework::DDim transformed_shape =
framework::make_ddim(std::vector<int>{tiles, 64, channel, 8});
float *outptr = output->mutable_data<float>(transformed_shape);
memset(outptr, 0, output->numel() * sizeof(float));
const float *inptr = input.data<float>();
int inter_h = (height - 2) / 6;
int inter_w = (width - 2) / 6;
int remain_h = height - (inter_h * 6);
int remain_w = width - (inter_w * 6);
framework::Tensor input_pad;
if (remain_h > 2 || remain_w > 2) {
inter_h += (remain_h > 2);
inter_w += (remain_w > 2);
height = (inter_h - 1) * 6 + 8;
width = (inter_w - 1) * 6 + 8;
framework::DDim input_shape =
framework::make_ddim(std::vector<int>{1, channel, height, width});
PadFunctor<CPU, float> pad;
inptr = input_pad.mutable_data<float>(input_shape);
pad(input, 0, height - input.dims()[2], 0, width - input.dims()[3],
&input_pad);
}
size_t image_size = height * width;
const float transform_matrix[8] = {5.25f, -5.f, -4.25f, -2.5f,
2.f, -1.25f, 0.5f, 0.25f};
#pragma omp parallel for
for (int c = 0; c < channel; c += 4) {
const float *in = inptr + c * image_size;
float d_bt[64 * 4]; // d * B_t
for (int h = 0; h < h_tiles; ++h) {
for (int w = 0; w < w_tiles; ++w) {
const float *in0 = in + (h * width + w) * 6;
const float *in1 = in0 + image_size;
const float *in2 = in1 + image_size;
const float *in3 = in2 + image_size;
float *d_bt_ptr = d_bt;
asm volatile(
"mov r0, #8 \n"
"vld1.32 {d0-d3}, [%[tm_ptr]] \n"
// row loop
"loop_r_%=: \n"
"vld1.32 {d4-d7}, [%[in0]], %[width] \n"
"vld1.32 {d8-d11}, [%[in1]], %[width] \n"
"vld1.32 {d12-d15}, [%[in2]], %[width] \n"
"vld1.32 {d16-d19}, [%[in3]], %[width] \n"
"vtrn.32 q2, q4 \n"
"vtrn.32 q3, q5 \n"
"vtrn.32 q6, q8 \n"
"vtrn.32 q7, q9 \n"
"vswp.32 d5, d12 \n"
"vswp.32 d9, d16 \n"
"vswp.32 d7, d14 \n"
"vswp.32 d11, d18 \n" // q2: d0, q3: d1,
// ..., q9: d7
"vsub.f32 q10, q2, q8 \n"
"vsub.f32 q11, q6, q4 \n"
"vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 -
// d2) * 5.25
"vst1.32 {d20-d21}, [%[d_bt]]! \n"
"vadd.f32 q10, q4, q8 \n"
"vadd.f32 q11, q3, q7 \n"
"vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 +
// d6
"vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 +
// d5
"vadd.f32 q12, q10, q11 \n"
"vsub.f32 q13, q10, q11 \n"
"vst1.32 {d24-d27}, [%[d_bt]]! \n"
"vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2
"vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1
"vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6
"vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 *
// d5
"vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6
// - 1.25 * d4
"vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 *
// d5 - 2.5 * d3
"vadd.f32 q12, q10, q11 \n"
"vsub.f32 q13, q10, q11 \n"
"vst1.32 {d24-d27}, [%[d_bt]]! \n"
"vmul.f32 q10, q4, d2[0] \n" // 2 * d2
"vmul.f32 q11, q3, d2[0] \n" // 2 * d1
"vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 *
// d4
"vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 *
// d3
"vmla.f32 q10, q8, d3[0] \n" // 2 * d1 - 2.5 *
// d3 + 0.5 * d6
"vmla.f32 q11, q7, d3[0] \n" // 2 * d2 - 2.5 *
// d4 + 0.5 * d5
"vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3
// + d6
"vadd.f32 q12, q10, q11 \n"
"vsub.f32 q13, q10, q11 \n"
"vst1.32 {d24-d27}, [%[d_bt]]! \n"
"vsub.f32 q10, q9, q3 \n"
"vsub.f32 q11, q5, q7 \n"
"vmla.f32 q10, q11, d0[0] \n"
"vst1.32 {d20-d21}, [%[d_bt]]! \n"
"subs r0, #1 \n"
"bne loop_r_%= \n"
: [d_bt] "+r"(d_bt_ptr), [in0] "+r"(in0), [in1] "+r"(in1),
[in2] "+r"(in2), [in3] "+r"(in3)
: [tm_ptr] "r"((float *)transform_matrix), [width] "r"(width)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
float *ptr0 = d_bt;
float *ptr1 = ptr0 + 32;
float *ptr2 = ptr1 + 32;
float *ptr3 = ptr2 + 32;
float *ptr4 = ptr3 + 32;
float *ptr5 = ptr4 + 32;
float *ptr6 = ptr5 + 32;
float *ptr7 = ptr6 + 32;
int tile_id = h * w_tiles + w;
int block_id = tile_id >> 3;
int pack_id = tile_id & 0x7;
// (tiles / 8, 64, channel, 8)
float *out0 = outptr + (block_id * 64 * channel + c) * 8 + pack_id;
int steps = channel * 8;
asm volatile(
"vld1.32 {d0-d3}, [%[tm_ptr]] \n"
"mov r0, 4 \n"
"loop_col_%=: \n"
// col 0:
"vld1.32 {d4-d5}, [%[ptr0]]! \n" // q2: d0
"vld1.32 {d6-d7}, [%[ptr1]]! \n" // q3: d1
"vld1.32 {d8-d9}, [%[ptr2]]! \n" // q4: d2
"vld1.32 {d10-d11}, [%[ptr3]]! \n" // q5: d3
"vld1.32 {d12-d13}, [%[ptr4]]! \n" // q6: d4
"vld1.32 {d14-d15}, [%[ptr5]]! \n" // q7: d5
"vld1.32 {d16-d17}, [%[ptr6]]! \n" // q8: d6
"vld1.32 {d18-d19}, [%[ptr7]]! \n" // q9: d7
"vsub.f32 q10, q2, q8 \n" // d0 - d6
"vsub.f32 q11, q6, q4 \n" // d4 - d2
"vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 -
// d2) * 5.25
"vst1.32 {d20-d21}, [%[out0]], %[steps] \n"
"vadd.f32 q10, q4, q8 \n"
"vadd.f32 q11, q3, q7 \n"
"vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 +
// d6
"vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 +
// d5
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24-d25}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24-d25}, [%[out0]], %[steps] \n"
"vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2
"vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1
"vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6
"vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 *
// d5
"vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6
// - 1.25 * d4
"vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 *
// d5 - 2.5 * d3
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24-d25}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24-d25}, [%[out0]], %[steps] \n"
"vmul.f32 q10, q4, d2[0] \n" // 2 * d2
"vmul.f32 q11, q3, d2[0] \n" // 2 * d1
"vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 *
// d4
"vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 *
// d3
"vmla.f32 q10, q8, d3[0] \n" // 2 * d1 - 2.5 *
// d3 + 0.5 * d6
"vmla.f32 q11, q7, d3[0] \n" // 2 * d2 - 2.5 *
// d4 + 0.5 * d5
"vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3
// + d6
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24-d25}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24-d25}, [%[out0]], %[steps] \n"
"vsub.f32 q10, q9, q3 \n"
"vsub.f32 q11, q5, q7 \n"
"vmla.f32 q10, q11, d0[0] \n"
"vst1.32 {d24-d25}, [%[out0]], %[steps] \n"
// col 1:
"vld1.32 {d4-d5}, [%[ptr0]]! \n" // q2: d0
"vld1.32 {d6-d7}, [%[ptr1]]! \n" // q3: d1
"vld1.32 {d8-d9}, [%[ptr2]]! \n" // q4: d2
"vld1.32 {d10-d11}, [%[ptr3]]! \n" // q5: d3
"vld1.32 {d12-d13}, [%[ptr4]]! \n" // q6: d4
"vld1.32 {d14-d15}, [%[ptr5]]! \n" // q7: d5
"vld1.32 {d16-d17}, [%[ptr6]]! \n" // q8: d6
"vld1.32 {d18-d19}, [%[ptr7]]! \n" // q9: d7
"vsub.f32 q10, q2, q8 \n" // d0 - d6
"vsub.f32 q11, q6, q4 \n" // d4 - d2
"vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 -
// d2) * 5.25
"vst1.32 {d20-d21}, [%[out0]], %[steps] \n"
"vadd.f32 q10, q4, q8 \n"
"vadd.f32 q11, q3, q7 \n"
"vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 +
// d6
"vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 +
// d5
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24-d25}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24-d25}, [%[out0]], %[steps] \n"
"vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2
"vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1
"vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6
"vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 *
// d5
"vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6
// - 1.25 * d4
"vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 *
// d5 - 2.5 * d3
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24-d25}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24-d25}, [%[out0]], %[steps] \n"
"vmul.f32 q10, q4, d2[0] \n" // 2 * d2
"vmul.f32 q11, q3, d2[0] \n" // 2 * d1
"vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 *
// d4
"vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 *
// d3
"vmla.f32 q10, q8, d3[0] \n" // 2 * d1 - 2.5 *
// d3 + 0.5 * d6
"vmla.f32 q11, q7, d3[0] \n" // 2 * d2 - 2.5 *
// d4 + 0.5 * d5
"vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3
// + d6
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24-d25}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24-d25}, [%[out0]], %[steps] \n"
"vsub.f32 q10, q9, q3 \n"
"vsub.f32 q11, q5, q7 \n"
"vmla.f32 q10, q11, d0[0] \n"
"vst1.32 {d24-d25}, [%[out0]], %[steps] \n"
"subs r0, #1 \n"
"bne loop_col_%= \n"
: [out0] "+r"(out0), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3), [ptr4] "+r"(ptr4),
[ptr5] "+r"(ptr5), [ptr6] "+r"(ptr6), [ptr7] "+r"(ptr7)
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
}
}
}
// remainer channels
int remain_c_start = ((channel >> 2) << 2);
for (int c = remain_c_start; c < channel; ++c) {
const float *in = inptr + c * image_size;
float d_bt[64]; // d * B_t
for (int h = 0; h < h_tiles; ++h) {
for (int w = 0; w < w_tiles; ++w) {
const float *in0 = in + (h * width + w) * 6;
const float *in1 = in0 + width;
const float *in2 = in1 + width;
const float *in3 = in2 + width;
float *d_bt_ptr = d_bt;
int steps = 4 * width;
asm volatile(
"vld1.32 {d0-d3}, [%[tm_ptr]] \n"
"mov r0, #2 \n"
// row loop
"loop_r_%=: \n"
"vld1.32 {d4-d7}, [%[in0]], %[steps] \n"
"vld1.32 {d8-d11}, [%[in1]], %[steps] \n"
"vld1.32 {d12-d15}, [%[in2]], %[steps] \n"
"vld1.32 {d16-d19}, [%[in3]], %[steps] \n"
"vtrn.32 q2, q4 \n"
"vtrn.32 q3, q5 \n"
"vtrn.32 q6, q8 \n"
"vtrn.32 q7, q9 \n"
"vswp.32 d5, d12 \n"
"vswp.32 d9, d16 \n"
"vswp.32 d7, d14 \n"
"vswp.32 d11, d18 \n" // q2: d0, q3: d1,
// ..., q9: d7
"vsub.f32 q10, q2, q8 \n"
"vsub.f32 q11, q6, q4 \n"
"vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 -
// d2) * 5.25"
"vst1.32 {d20-d21}, [%[d_bt]]! \n"
"vadd.f32 q10, q4, q8 \n"
"vadd.f32 q11, q3, q7 \n"
"vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 +
// d6
"vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 +
// d5
"vadd.f32 q12, q10, q11 \n"
"vsub.f32 q13, q10, q11 \n"
"vst1.32 {d24-d27}, [%[d_bt]]! \n"
"vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2
"vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1
"vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6
"vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 *
// d5
"vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6
// - 1.25 * d4
"vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 *
// d5 - 2.5 * d3
"vadd.f32 q12, q10, q11 \n"
"vsub.f32 q13, q10, q11 \n"
"vst1.32 {d24-d27}, [%[d_bt]]! \n"
"vmul.f32 q10, q4, d2[0] \n" // 2 * d2
"vmul.f32 q11, q3, d2[0] \n" // 2 * d1
"vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 *
// d4
"vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 *
// d3
"vmla.f32 q10, q8, d3[0] \n" // 2 * d1 - 2.5 *
// d3 + 0.5 * d6
"vmla.f32 q11, q7, d3[0] \n" // 2 * d2 - 2.5 *
// d4 + 0.5 * d5
"vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3
// + d6
"vadd.f32 q12, q10, q11 \n"
"vsub.f32 q13, q10, q11 \n"
"vst1.32 {d24-d27}, [%[d_bt]]! \n"
"vsub.f32 q10, q9, q3 \n"
"vsub.f32 q11, q5, q7 \n"
"vmla.f32 q10, q11, d0[0] \n"
"vst1.32 {d20-d21}, [%[d_bt]]! \n"
"subs r0, #1 \n"
"bne loop_r_%= \n"
: [d_bt] "+r"(d_bt_ptr), [in0] "+r"(in0), [in1] "+r"(in1),
[in2] "+r"(in2), [in3] "+r"(in3)
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
float *ptr0 = d_bt;
float *ptr1 = ptr0 + 32;
int tile_id = h * w_tiles + w;
int block_id = tile_id >> 3;
int pack_id = tile_id & 0x7;
// (tiles / 8, 64, channel, 8)
float *out0 = outptr + (block_id * 64 * channel + c) * 8 + pack_id;
steps = channel * 8;
asm volatile(
"mov r0, #2 \n"
"vld1.32 {d0-d3}, [%[tm_ptr]] \n"
// row loop
"loop_r_%=: \n"
"vld1.32 {d4-d5}, [%[ptr0]]! \n" // q2: d0
"vld1.32 {d6-d7}, [%[ptr0]]! \n" // q3: d1
"vld1.32 {d8-d9}, [%[ptr0]]! \n" // q4: d2
"vld1.32 {d10-d11}, [%[ptr0]]! \n" // q5: d3
"vld1.32 {d12-d13}, [%[ptr1]]! \n" // q6: d4
"vld1.32 {d14-d15}, [%[ptr1]]! \n" // q7: d5
"vld1.32 {d16-d17}, [%[ptr1]]! \n" // q8: d6
"vld1.32 {d18-d19}, [%[ptr1]]! \n" // q9: d7
"vtrn.32 q2, q3 \n"
"vtrn.32 q4, q5 \n"
"vtrn.32 q6, q7 \n"
"vtrn.32 q8, q9 \n"
"vswp.32 d5, d8 \n"
"vswp.32 d7, d10 \n"
"vswp.32 d13, d16 \n"
"vswp.32 d15, d18 \n"
"vsub.f32 q10, q2, q8 \n" // d0 - d6
"vsub.f32 q11, q6, q4 \n" // d4 - d2
"vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 -
// d2) * 5.25
"vst1.32 {d20[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d20[1]}, [%[out0]], %[steps] \n"
"vst1.32 {d21[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d21[1]}, [%[out0]], %[steps] \n"
"vadd.f32 q10, q4, q8 \n"
"vadd.f32 q11, q3, q7 \n"
"vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 +
// d6
"vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 +
// d5
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d24[1]}, [%[out0]], %[steps] \n"
"vst1.32 {d25[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d24[1]}, [%[out0]], %[steps] \n"
"vst1.32 {d25[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2
"vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1
"vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6
"vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 *
// d5
"vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6
// - 1.25 * d4
"vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 *
// d5 - 2.5 * d3
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d24[1]}, [%[out0]], %[steps] \n"
"vst1.32 {d25[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d24[1]}, [%[out0]], %[steps] \n"
"vst1.32 {d25[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vmul.f32 q10, q4, d2[0] \n" // 2 * d2
"vmul.f32 q11, q3, d2[0] \n" // 2 * d1
"vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 *
// d4
"vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 *
// d3
"vmla.f32 q10, q8, d3[0] \n" // 2 * d1 - 2.5 *
// d3 + 0.5 * d6
"vmla.f32 q11, q7, d3[0] \n" // 2 * d2 - 2.5 *
// d4 + 0.5 * d5
"vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3
// + d6
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d24[1]}, [%[out0]], %[steps] \n"
"vst1.32 {d25[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d24[1]}, [%[out0]], %[steps] \n"
"vst1.32 {d25[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q10, q9, q3 \n"
"vsub.f32 q11, q5, q7 \n"
"vmla.f32 q10, q11, d0[0] \n"
"vst1.32 {d20[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d20[1]}, [%[out0]], %[steps] \n"
"vst1.32 {d21[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d21[1]}, [%[out0]], %[steps] \n"
"subs r0, #1 \n"
"bne loop_r_%= \n"
: [out0] "+r"(out0), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1)
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
}
}
}
}
template <>
void winograd_transform_output<8, 3>(const framework::Tensor &input,
const framework::Tensor &weight,
framework::Tensor *output) {
// weight shape is [out_channel/4, 64, in_channel, 4],
// input shape is [hw/8, 64, in_channel, 8]
int in_channel = input.dims()[2];
int tiles = input.dims()[0];
int out_channel = weight.dims()[0];
// compute U*V first
framework::Tensor uv_trans;
framework::DDim shape =
framework::make_ddim(std::vector<int>{4 * out_channel, 8 * tiles, 64});
float *uv_trans_ptr = uv_trans.mutable_data<float>(shape);
memset(uv_trans_ptr, 0, uv_trans.numel() * sizeof(float));
const float *input_ptr = input.data<float>();
const float *weight_ptr = weight.data<float>();
for (int i = 0; i < out_channel; ++i) {
float *uv_ptr = uv_trans_ptr + (i * tiles * 64 * 32);
for (int k = 0; k < 64; ++k) {
const float *w_ptr = weight_ptr + (i * 64 + k) * in_channel * 4;
for (int j = 0; j < tiles; ++j) {
const float *in_ptr = input_ptr + (k * tiles + j) * in_channel * 8;
float *out0 = uv_ptr + (8 * j) * 64 + k; // out channel 0
float *out1 = out0 + 8 * tiles * 64; // out channel 1
float *out2 = out1 + 8 * tiles * 64; // out channel 2
float *out3 = out2 + 8 * tiles * 64; // out channel 3
int inter_channel = in_channel >> 1;
int remain_channel = in_channel & 0x1;
int steps = 64;
asm volatile(
"veor q8, q8, q8 \n"
"veor q9, q9, q9 \n"
"veor q10, q10, q10 \n"
"veor q11, q11, q11 \n"
"veor q12, q12, q12 \n"
"veor q13, q13, q13 \n"
"veor q14, q14, q14 \n"
"veor q15, q15, q15 \n"
// loop 2 channels
"cmp %[inter_channel], #0 \n"
"ble cmp_remain_%= \n"
"loop_4c_%=: \n"
"vld1.32 {d0-d3}, [%[w_ptr]]! \n"
"vld1.32 {d4-d7}, [%[in_ptr]]! \n"
"vmla.f32 q8, q2, d0[0] \n"
"vmla.f32 q9, q3, d0[0] \n"
"vmla.f32 q10, q2, d0[1] \n"
"vmla.f32 q11, q3, d0[1] \n"
"vmla.f32 q12, q2, d1[0] \n"
"vmla.f32 q13, q3, d1[0] \n"
"vmla.f32 q14, q2, d1[1] \n"
"vmla.f32 q15, q3, d1[1] \n"
"vld1.32 {d8-d11}, [%[in_ptr]]! \n"
"vmla.f32 q8, q4, d2[0] \n"
"vmla.f32 q9, q5, d2[0] \n"
"vmla.f32 q10, q4, d2[1] \n"
"vmla.f32 q11, q5, d2[1] \n"
"vmla.f32 q12, q4, d3[0] \n"
"vmla.f32 q13, q5, d3[0] \n"
"vmla.f32 q14, q4, d3[1] \n"
"vmla.f32 q15, q5, d3[1] \n"
"subs %[inter_channel], #1 \n"
"bne loop_4c_%= \n"
// cmp remain channel > 0
"cmp_remain_%=: \n"
"cmp %[remain_channel], #0 \n"
"ble store_res_%= \n"
// loop 1 channel
"loop_c_%=: \n"
"vld1.32 {d0-d1}, [%[w_ptr]]! \n"
"vld1.32 {d4-d7}, [%[in_ptr]]! \n"
"vmla.f32 q8, q2, d0[0] \n"
"vmla.f32 q9, q3, d0[0] \n"
"vmla.f32 q10, q2, d0[1] \n"
"vmla.f32 q11, q3, d0[1] \n"
"vmla.f32 q12, q2, d1[0] \n"
"vmla.f32 q13, q3, d1[0] \n"
"vmla.f32 q14, q2, d1[1] \n"
"vmla.f32 q15, q3, d1[1] \n"
"subs %[remain_channel], #1 \n"
"bne loop_c_%= \n"
"store_res_%=: \n"
"vst1.32 {d16[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d16[1]}, [%[out0]], %[steps] \n"
"vst1.32 {d17[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d17[1]}, [%[out0]], %[steps] \n"
"vst1.32 {d18[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d18[1]}, [%[out0]], %[steps] \n"
"vst1.32 {d19[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d19[1]}, [%[out0]], %[steps] \n"
"vst1.32 {d20[0]}, [%[out1]], %[steps] \n"
"vst1.32 {d20[1]}, [%[out1]], %[steps] \n"
"vst1.32 {d21[0]}, [%[out1]], %[steps] \n"
"vst1.32 {d21[1]}, [%[out1]], %[steps] \n"
"vst1.32 {d22[0]}, [%[out1]], %[steps] \n"
"vst1.32 {d22[1]}, [%[out1]], %[steps] \n"
"vst1.32 {d23[0]}, [%[out1]], %[steps] \n"
"vst1.32 {d23[1]}, [%[out1]], %[steps] \n"
"vst1.32 {d24[0]}, [%[out2]], %[steps] \n"
"vst1.32 {d24[1]}, [%[out2]], %[steps] \n"
"vst1.32 {d25[0]}, [%[out2]], %[steps] \n"
"vst1.32 {d25[1]}, [%[out2]], %[steps] \n"
"vst1.32 {d26[0]}, [%[out2]], %[steps] \n"
"vst1.32 {d26[1]}, [%[out2]], %[steps] \n"
"vst1.32 {d27[0]}, [%[out2]], %[steps] \n"
"vst1.32 {d27[1]}, [%[out2]], %[steps] \n"
"vst1.32 {d28[0]}, [%[out3]], %[steps] \n"
"vst1.32 {d28[1]}, [%[out3]], %[steps] \n"
"vst1.32 {d29[0]}, [%[out3]], %[steps] \n"
"vst1.32 {d29[1]}, [%[out3]], %[steps] \n"
"vst1.32 {d30[0]}, [%[out3]], %[steps] \n"
"vst1.32 {d30[1]}, [%[out3]], %[steps] \n"
"vst1.32 {d31[0]}, [%[out3]], %[steps] \n"
"vst1.32 {d31[1]}, [%[out3]], %[steps] \n"
: [w_ptr] "+r"(w_ptr), [in_ptr] "+r"(in_ptr), [out0] "+r"(out0),
[out1] "+r"(out1), [out2] "+r"(out2), [out3] "+r"(out3),
[remain_channel] "+r"(remain_channel),
[inter_channel] "+r"(inter_channel)
: [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
}
}
}
/*
* s0 = m0 + (m1 + m2) + (m3 + m4) + 32 * (m5 + m6)
* s1 = (m1 - m2) + 2 * (m3 - m4) + 16 * (m5 - m6)
* s2 = (m1 + m2) + 4 * (m3 + m4) + 8 * (m5 + m6)
* s3 = (m1 - m2) + 8 * (m3 - m4) + 4 * (m5 - m6)
* s4 = (m1 + m2) + 16 * (m3 + m4) + 2 * (m5 + m6)
* s5 = (m1 - m2) + 32 * (m3 - m4) + (m5 - m6) + m7
*/
int out_h = output->dims()[2];
int out_w = output->dims()[3];
int h_tiles = (out_h + 5) / 6;
int w_tiles = (out_w + 5) / 6;
int remain_h = out_h - out_h / 6 * 6;
int remain_w = out_w - out_w / 6 * 6;
float *output_ptr = output->mutable_data<float>();
out_channel = output->dims()[1];
float transform_matrix[8] = {2.f, 4.f, 8.f, 16.f};
for (int oc = 0; oc < out_channel; ++oc) {
float at_m[48];
float output_tmp[36];
const float *uv_ptr = uv_trans_ptr + oc * h_tiles * w_tiles * 64;
for (int tile_h = 0; tile_h < h_tiles; ++tile_h) {
for (int tile_w = 0; tile_w < w_tiles; ++tile_w) {
float *at_m_ptr = at_m;
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
"mov r0, #2 \n"
"loop_%=: \n"
"vld1.32 {d2-d5}, [%[uv_ptr]]! \n"
"vld1.32 {d6-d9}, [%[uv_ptr]]! \n"
"vld1.32 {d10-d13}, [%[uv_ptr]]! \n"
"vld1.32 {d14-d17}, [%[uv_ptr]]! \n"
"vtrn.32 q1, q3 \n"
"vtrn.32 q2, q4 \n"
"vtrn.32 q5, q7 \n"
"vtrn.32 q6, q8 \n"
"vswp.32 d3, d10 \n" // q1: m0, q5: m2
"vswp.32 d7, d14 \n" // q3: m1, q7: m3
"vswp.32 d5, d12 \n" // q2: m4, q6: m6
"vswp.32 d9, d16 \n" // q4: m5, q8: m7
"vadd.f32 q9, q3, q5 \n" // m1 + m2
"vadd.f32 q10, q7, q2 \n" // m3 + m4
"vadd.f32 q11, q4, q6 \n" // m5 + m6
"vsub.f32 q12, q3, q5 \n" // m1 - m2
"vsub.f32 q13, q7, q2 \n" // m3 - m4
"vsub.f32 q14, q4, q6 \n" // m5 - m6
"vmul.f32 q2, q13, d0[0] \n" // 2 * (m3 - m4)
"vmul.f32 q3, q11, d0[0] \n" // 2 * (m5 + m6)
"vadd.f32 q15, q1, q9 \n"
"vadd.f32 q15, q15, q10 \n"
"vmla.f32 q15, q3, d1[1] \n"
"vst1.32 {d30-d31}, [%[at_m_ptr]]! \n"
"vadd.f32 q15, q12, q2 \n"
"vmla.f32 q15, q14, d1[1] \n"
"vst1.32 {d30-d31}, [%[at_m_ptr]]! \n"
"vmov.32 q15, q9 \n"
"vmla.f32 q15, q10, d0[1] \n"
"vmla.f32 q15, q11, d1[0] \n"
"vst1.32 {d30-d31}, [%[at_m_ptr]]! \n"
"vmov.32 q15, q12 \n"
"vmla.f32 q15, q13, d1[0] \n"
"vmla.f32 q15, q14, d0[1] \n"
"vst1.32 {d30-d31}, [%[at_m_ptr]]! \n"
"vadd.f32 q15, q9, q3 \n"
"vmla.f32 q15, q10, d1[1] \n"
"vst1.32 {d30-d31}, [%[at_m_ptr]]! \n"
"vadd.f32 q15, q12, q8 \n"
"vmla.f32 q15, q2, d1[1] \n"
"vst1.32 {d30-d31}, [%[at_m_ptr]]! \n"
"subs r0, #1 \n"
"bne loop_%= \n"
: [uv_ptr] "+r"(uv_ptr), [at_m_ptr] "+r"(at_m_ptr)
: [tm_ptr] "r"((float *)transform_matrix)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
float *at_m_ptr0 = at_m;
float *at_m_ptr1 = at_m + 24;
if ((remain_w > 0 && tile_w == w_tiles - 1) ||
(remain_h > 0 && tile_h == h_tiles - 1)) {
float *out_ptr0 = output_tmp;
float *out_ptr1 = output_tmp + 6;
float *out_ptr2 = output_tmp + 12;
float *out_ptr3 = output_tmp + 18;
float *out_ptr4 = output_tmp + 24;
float *out_ptr5 = output_tmp + 30;
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
// process 4 rows
"vld1.32 {d2-d5}, [%[at_m_ptr0]]! \n" // q1: m0, q2: m1
"vld1.32 {d6-d9}, [%[at_m_ptr0]]! \n" // q3: m2, q4: m3
"vld1.32 {d10-d13}, [%[at_m_ptr1]]! \n" // q5: m4, q6: m5
"vld1.32 {d14-d17}, [%[at_m_ptr1]]! \n" // q7: m6, q8: m7
"vtrn.32 q1, q2 \n"
"vtrn.32 q3, q4 \n"
"vtrn.32 q5, q6 \n"
"vtrn.32 q7, q8 \n"
"vswp.32 d3, d6 \n"
"vswp.32 d5, d8 \n"
"vswp.32 d11, d14 \n"
"vswp.32 d13, d16 \n"
"vadd.f32 q9, q2, q3 \n" // m1 + m2
"vadd.f32 q10, q4, q5 \n" // m3 + m4
"vadd.f32 q11, q6, q7 \n" // m5 + m6
"vsub.f32 q12, q2, q3 \n" // m1 - m2
"vsub.f32 q13, q4, q5 \n" // m3 - m4
"vsub.f32 q14, q6, q7 \n" // m5 - m6
"vmul.f32 q6, q13, d0[0] \n" // 2 * (m3 - m4)
"vmul.f32 q7, q11, d0[0] \n" // 2 * (m5 + m6)
"vadd.f32 q1, q1, q9 \n"
"vadd.f32 q1, q1, q10 \n"
"vmla.f32 q1, q3, d1[1] \n"
"vadd.f32 q2, q12, q6 \n"
"vmla.f32 q2, q14, d1[1] \n"
"vmov.32 q3, q9 \n"
"vmla.f32 q3, q10, d0[1] \n"
"vmla.f32 q3, q11, d1[0] \n"
"vmov.32 q4, q12 \n"
"vmla.f32 q4, q13, d1[0] \n"
"vmla.f32 q4, q14, d0[1] \n"
"vtrn.32 q1, q2 \n"
"vtrn.32 q3, q4 \n"
"vswp.32 d3, d6 \n"
"vswp.32 d5, d8 \n"
"vst1.32 {d2-d3}, [%[out_ptr0]]! \n"
"vst1.32 {d4-d5}, [%[out_ptr1]]! \n"
"vst1.32 {d6-d7}, [%[out_ptr2]]! \n"
"vst1.32 {d8-d9}, [%[out_ptr3]]! \n"
"vadd.f32 q1, q9, q7 \n"
"vmla.f32 q1, q10, d1[1] \n"
"vadd.f32 q2, q12, q8 \n"
"vmla.f32 q2, q6, d1[1] \n"
"vtrn.32 q1, q2 \n"
"vst1.32 {d2}, [%[out_ptr0]]! \n"
"vst1.32 {d4}, [%[out_ptr1]]! \n"
"vst1.32 {d3}, [%[out_ptr2]]! \n"
"vst1.32 {d5}, [%[out_ptr3]]! \n"
// remain 2 rows
"vld1.32 {d2-d5}, [%[at_m_ptr0]]! \n" // d2: m0, d3: m2,
// d4: m1, d5: m3
"vld1.32 {d6-d9}, [%[at_m_ptr1]]! \n" // d6: m4, d7: m6,
// d8: m5, d9: m7
"vtrn.32 q1, q2 \n"
"vtrn.32 q3, q4 \n"
"vadd.f32 d10, d4, d3 \n" // m1 + m2
"vadd.f32 d11, d5, d6 \n" // m3 + m4
"vadd.f32 d12, d8, d7 \n" // m5 + m6
"vsub.f32 d13, d4, d3 \n" // m1 - m2
"vsub.f32 d14, d5, d6 \n" // m3 - m4
"vsub.f32 d15, d8, d7 \n" // m5 - m6
"vmul.f32 d16, d14, d0[0] \n" // 2 * (m3 - m4)
"vmul.f32 d17, d12, d0[0] \n" // 2 * (m5 + m6)
"vadd.f32 d18, d2, d10 \n"
"vadd.f32 d18, d18, d11 \n"
"vmla.f32 d18, d17, d1[1] \n"
"vadd.f32 d20, d13, d16 \n"
"vmla.f32 d20, d15, d1[1] \n"
"vmov.32 d19, d10 \n"
"vmla.f32 d19, d11, d0[1] \n"
"vmla.f32 d19, d20, d1[0] \n"
"vmov.32 d21, d13 \n"
"vmla.f32 d21, d14, d1[0] \n"
"vmla.f32 d21, d15, d0[1] \n"
"vtrn.32 d18, d19 \n"
"vtrn.32 d20, d21 \n"
"vst1.32 {d18-d19}, [%[out_ptr4]]! \n"
"vst1.32 {d20-d21}, [%[out_ptr5]]! \n"
"vadd.f32 d18, d10, d17 \n"
"vmla.f32 d18, d11, d1[1] \n"
"vadd.f32 d19, d13, d9 \n"
"vmla.f32 d19, d16, d1[1] \n"
"vtrn.32 d18, d19 \n"
"vst1.32 {d18}, [%[out_ptr4]]! \n"
"vst1.32 {d19}, [%[out_ptr5]]! \n"
: [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1),
[out_ptr2] "+r"(out_ptr2), [out_ptr3] "+r"(out_ptr3),
[out_ptr4] "+r"(out_ptr4), [out_ptr5] "+r"(out_ptr5),
[at_m_ptr0] "+r"(at_m_ptr0), [at_m_ptr1] "+r"(at_m_ptr1)
: [tm_ptr] "r"((float *)transform_matrix)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w;
float *out_ptr = output_ptr + offset;
int remain_row = (tile_h < h_tiles - 1) ? 6 : remain_h;
int remain_col = (tile_w < w_tiles - 1) ? 6 : remain_w;
for (int i = 0; i < remain_row; ++i, out_ptr += out_w) {
memcpy(out_ptr, output_tmp + i * 6, remain_col * sizeof(float));
}
} else {
size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w;
float *out_ptr0 = output_ptr + offset;
float *out_ptr1 = out_ptr0 + out_w;
float *out_ptr2 = out_ptr1 + out_w;
float *out_ptr3 = out_ptr2 + out_w;
float *out_ptr4 = out_ptr3 + out_w;
float *out_ptr5 = out_ptr4 + out_w;
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
// process 4 rows
"vld1.32 {d2-d5}, [%[at_m_ptr0]]! \n" // q1: m0, q2: m1
"vld1.32 {d6-d9}, [%[at_m_ptr0]]! \n" // q3: m2, q4: m3
"vld1.32 {d10-d13}, [%[at_m_ptr1]]! \n" // q5: m4, q6: m5
"vld1.32 {d14-d17}, [%[at_m_ptr1]]! \n" // q7: m6, q8: m7
"vtrn.32 q1, q2 \n"
"vtrn.32 q3, q4 \n"
"vtrn.32 q5, q6 \n"
"vtrn.32 q7, q8 \n"
"vswp.32 d3, d6 \n"
"vswp.32 d5, d8 \n"
"vswp.32 d11, d14 \n"
"vswp.32 d13, d16 \n"
"vadd.f32 q9, q2, q3 \n" // m1 + m2
"vadd.f32 q10, q4, q5 \n" // m3 + m4
"vadd.f32 q11, q6, q7 \n" // m5 + m6
"vsub.f32 q12, q2, q3 \n" // m1 - m2
"vsub.f32 q13, q4, q5 \n" // m3 - m4
"vsub.f32 q14, q6, q7 \n" // m5 - m6
"vmul.f32 q6, q13, d0[0] \n" // 2 * (m3 - m4)
"vmul.f32 q7, q11, d0[0] \n" // 2 * (m5 + m6)
"vadd.f32 q1, q1, q9 \n"
"vadd.f32 q1, q1, q10 \n"
"vmla.f32 q1, q3, d1[1] \n"
"vadd.f32 q2, q12, q6 \n"
"vmla.f32 q2, q14, d1[1] \n"
"vmov.32 q3, q9 \n"
"vmla.f32 q3, q10, d0[1] \n"
"vmla.f32 q3, q11, d1[0] \n"
"vmov.32 q4, q12 \n"
"vmla.f32 q4, q13, d1[0] \n"
"vmla.f32 q4, q14, d0[1] \n"
"vtrn.32 q1, q2 \n"
"vtrn.32 q3, q4 \n"
"vswp.32 d3, d6 \n"
"vswp.32 d5, d8 \n"
"vst1.32 {d2-d3}, [%[out_ptr0]]! \n"
"vst1.32 {d4-d5}, [%[out_ptr1]]! \n"
"vst1.32 {d6-d7}, [%[out_ptr2]]! \n"
"vst1.32 {d8-d9}, [%[out_ptr3]]! \n"
"vadd.f32 q1, q9, q7 \n"
"vmla.f32 q1, q10, d1[1] \n"
"vadd.f32 q2, q12, q8 \n"
"vmla.f32 q2, q6, d1[1] \n"
"vtrn.32 q1, q2 \n"
"vst1.32 {d2}, [%[out_ptr0]]! \n"
"vst1.32 {d4}, [%[out_ptr1]]! \n"
"vst1.32 {d3}, [%[out_ptr2]]! \n"
"vst1.32 {d5}, [%[out_ptr3]]! \n"
// remain 2 rows
"vld1.32 {d2-d5}, [%[at_m_ptr0]]! \n" // d2: m0, d3: m2,
// d4: m1, d5: m3
"vld1.32 {d6-d9}, [%[at_m_ptr1]]! \n" // d6: m4, d7: m6,
// d8: m5, d9: m7
"vtrn.32 q1, q2 \n"
"vtrn.32 q3, q4 \n"
"vadd.f32 d10, d4, d3 \n" // m1 + m2
"vadd.f32 d11, d5, d6 \n" // m3 + m4
"vadd.f32 d12, d8, d7 \n" // m5 + m6
"vsub.f32 d13, d4, d3 \n" // m1 - m2
"vsub.f32 d14, d5, d6 \n" // m3 - m4
"vsub.f32 d15, d8, d7 \n" // m5 - m6
"vmul.f32 d16, d14, d0[0] \n" // 2 * (m3 - m4)
"vmul.f32 d17, d12, d0[0] \n" // 2 * (m5 + m6)
"vadd.f32 d18, d2, d10 \n"
"vadd.f32 d18, d18, d11 \n"
"vmla.f32 d18, d17, d1[1] \n"
"vadd.f32 d20, d13, d16 \n"
"vmla.f32 d20, d15, d1[1] \n"
"vmov.32 d19, d10 \n"
"vmla.f32 d19, d11, d0[1] \n"
"vmla.f32 d19, d20, d1[0] \n"
"vmov.32 d21, d13 \n"
"vmla.f32 d21, d14, d1[0] \n"
"vmla.f32 d21, d15, d0[1] \n"
"vtrn.32 d18, d19 \n"
"vtrn.32 d20, d21 \n"
"vst1.32 {d18-d19}, [%[out_ptr4]]! \n"
"vst1.32 {d20-d21}, [%[out_ptr5]]! \n"
"vadd.f32 d18, d10, d17 \n"
"vmla.f32 d18, d11, d1[1] \n"
"vadd.f32 d19, d13, d9 \n"
"vmla.f32 d19, d16, d1[1] \n"
"vtrn.32 d18, d19 \n"
"vst1.32 {d18}, [%[out_ptr4]]! \n"
"vst1.32 {d19}, [%[out_ptr5]]! \n"
: [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1),
[out_ptr2] "+r"(out_ptr2), [out_ptr3] "+r"(out_ptr3),
[out_ptr4] "+r"(out_ptr4), [out_ptr5] "+r"(out_ptr5),
[at_m_ptr0] "+r"(at_m_ptr0), [at_m_ptr1] "+r"(at_m_ptr1)
: [tm_ptr] "r"((float *)transform_matrix)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
}
}
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册