提交 02f7529f 编写于 作者: H hjchen2

complete kernel transform function

上级 68e8dc4a
...@@ -166,7 +166,7 @@ void ConvCompute(const ConvParam<CPU> &param) { ...@@ -166,7 +166,7 @@ void ConvCompute(const ConvParam<CPU> &param) {
param.Strides()[0] == param.Strides()[1] && param.Strides()[0] == param.Strides()[1] &&
param.Dilations()[0] == param.Dilations()[1] && param.Dilations()[0] == param.Dilations()[1] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Dilations()[0] == 1 && param.Input()->dims()[1] > 16) { param.Dilations()[0] == 1 && param.Input()->dims()[1] >= 16) {
BatchConv3x3Winograd(param); BatchConv3x3Winograd(param);
} else { } else {
ConvBasic<float, float>(param); ConvBasic<float, float>(param);
......
...@@ -12,11 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,11 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
// refered from https://arxiv.org/abs/1509.09308 and // Inspired by https://arxiv.org/abs/1509.09308 and
// https://github.com/andravin/wincnn // https://github.com/andravin/wincnn
#pragma once // Adapted from the nnpack and ncnn projects
#include <arm_neon.h>
#include "operators/math/pad.h" #include "operators/math/pad.h"
#include "operators/math/winograd/winograd_transform.h" #include "operators/math/winograd/winograd_transform.h"
...@@ -27,16 +26,6 @@ namespace math { ...@@ -27,16 +26,6 @@ namespace math {
template <> template <>
void winograd_transform_weight<8, 3>(const framework::Tensor &weight, void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
framework::Tensor *output) { framework::Tensor *output) {
// weight shape is [out_channel, in_channel, kernel_h, kernel_w]
int out_channel = weight.dims()[0];
int in_channel = weight.dims()[1];
// reshape and alloc transformed weight
// framework::DDim transformed_shape = framework::make_ddim(
// std::vector<int>{out_channel, (in_channel + 3) / 4, 8, 32});
framework::DDim transformed_shape = framework::make_ddim(
std::vector<int>{(out_channel + 3) / 4, 64, in_channel, 4});
float *outptr = output->mutable_data<float>(transformed_shape);
memset(outptr, 0, output->numel() * sizeof(float));
/* /*
* w0 = g0 * w0 = g0
* w1 = ((g0 + g2) + g1) * (-2.0 / 9) * w1 = ((g0 + g2) + g1) * (-2.0 / 9)
...@@ -47,51 +36,276 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight, ...@@ -47,51 +36,276 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
* w6 = ((g2 + 4 * g0) - 2 * g1) * (1.0 / 180) * w6 = ((g2 + 4 * g0) - 2 * g1) * (1.0 / 180)
* w7 = g2 * w7 = g2
*/ */
/* // weight shape is [out_channel, in_channel, kernel_h, kernel_w]
const float *inptr = weight.data<float>(); int out_channel = weight.dims()[0];
for (int oc = 0; oc < out_channel; ++oc) { int in_channel = weight.dims()[1];
for (int ic = 0; ic < in_channel; ++ic) { // reshape and alloc transformed weight
size_t offset = oc * in_channel + ic; framework::DDim transformed_shape = framework::make_ddim(
float *kout = outptr + offset * 64; std::vector<int>{(out_channel + 3) / 4, 64, in_channel, 4});
const float *k = inptr + offset * 9; float *trans_outptr = output->mutable_data<float>(transformed_shape);
float gw[8][3]; memset(trans_outptr, 0, output->numel() * sizeof(float));
for (int i = 0; i < 3; ++i) {
float k0 = k[i]; const float transform_matrix[8] = {2.f, -2.f / 9, 1.f / 90, 1.f / 180};
float k1 = k[3 + i]; const float *inptr = weight.data<float>();
float k2 = k[6 + i]; int remain_start = out_channel & 0xFFFC;
float d0 = k0 + k2; #ifdef __aarch64__
float d1 = k0 + 4 * k2; remain_start = 0;
float d2 = 4 * k0 + k2; #else
float d3 = 2 * k1; for (int oc = 0; oc < out_channel; oc += 4) {
gw[0][i] = k0; float gw[96]; // gw[3][8][4]
gw[1][i] = -2.f/9 * (d0 + k1); // -2.f/9 * (k0 + k1 + k2) const float *inptr0 = inptr + oc * in_channel * 9; //
gw[2][i] = -2.f/9 * (d0 - k1); // -2.f/9 * (k0 - k1 + k2) const float *inptr1 = inptr + (oc + 1) * in_channel * 9;
gw[3][i] = 1.f/90 * (d1 + d3); // 1.f/90 * (k0 + 2 * k1 + 4 * k2) const float *inptr2 = inptr + (oc + 2) * in_channel * 9;
gw[4][i] = 1.f/90 * (d1 - d3); // 1.f/90 * (k0 - 2 * k1 + 4 * k2) const float *inptr3 = inptr + (oc + 3) * in_channel * 9;
gw[5][i] = 8.f/45 * (d2 + d3); // 8.f/45 * (4 * k0 + 2 * k1 + k2) // oc * 64 * in_channel
gw[6][i] = 8.f/45 * (d2 - d3); // 8.f/45 * (4 * k0 - 2 * k1 + k2) float *outptr = trans_outptr + ((oc * in_channel) << 6);
gw[7][i] = k2; for (int ic = 0; ic < in_channel; ++ic) {
} float *gw_ptr = gw;
for (int i = 0; i < 8; ++i, kout += 8) { asm volatile(
float k0 = gw[i][0]; "vld1.32 {d0-d1}, [%[tm_ptr]] \n"
float k1 = gw[i][1];
float k2 = gw[i][2]; "mov r0, #6 \n"
float d0 = k0 + k2; "vld1.32 {d2-d5}, [%[inptr0]], r0 \n"
float d1 = k0 + 4 * k2; "vld1.32 {d6-d9}, [%[inptr1]], r0 \n"
float d2 = 4 * k0 + k2; "vld1.32 {d10-d13}, [%[inptr2]], r0 \n"
float d3 = 2 * k1; "vld1.32 {d14-d17}, [%[inptr3]], r0 \n"
kout[0] = gw[i][0]; "vtrn.32 q1, q3 \n"
kout[1] = -2.f/9 * (d0 + k1); // -2.f/9 * (k0 + k1 + k2) "vtrn.32 q2, q4 \n"
kout[2] = -2.f/9 * (d0 - k1); // -2.f/9 * (k0 - k1 + k2) "vtrn.32 q5, q7 \n"
kout[3] = 1.f/90 * (d1 + d3); // 1.f/90 * (k0 + 2 * k1 + 4 * k2) "vtrn.32 q6, q8 \n"
kout[4] = 1.f/90 * (d1 - d3); // 1.f/90 * (k0 - 2 * k1 + 4 * k2) "vswp.32 d3, d10 \n"
kout[5] = 8.f/45 * (d2 + d3); // 8.f/45 * (4 * k0 + 2 * k1 + k2) "vswp.32 d7, d14 \n"
kout[6] = 8.f/45 * (d2 - d3); // 8.f/45 * (4 * k0 - 2 * k1 + k2) "vswp.32 d5, d12 \n"
kout[7] = gw[i][2]; "vswp.32 d9, d16 \n"
}
// q1: g0, q3: g1, q5: g2
"vst1.32 {d2-d3}, [%[gw_ptr]]! \n"
"vadd.f32 q9, q1, q5 \n"
"vadd.f32 q10, q9, q3 \n"
"vsub.f32 q11, q9, q3 \n"
"vmul.f32 q10, q10, d0[1] \n"
"vst1.32 {d20-d21}, [%[gw_ptr]]! \n"
"vmul.f32 q11, q11, d0[1] \n"
"vst1.32 {d22-d23}, [%[gw_ptr]]! \n"
"vmul.f32 q9, q1, d0[0] \n"
"vmul.f32 q9, q9, d0[0] \n" // 4 * g0
"vmul.f32 q10, q3, d0[0] \n" // 2 * g1
"vmul.f32 q11, q5, d0[0] \n"
"vmul.f32 q11, q11, d0[0] \n" // 4 * g2
"vadd.f32 q12, q1, q11 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vadd.f32 q12, q5, q9 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vst1.32 {d10-d11}, [%[gw_ptr]]! \n"
// q7: g0, q2: g1, q4: g2
"vst1.32 {d14-d15}, [%[gw_ptr]]! \n"
"vadd.f32 q9, q7, q4 \n"
"vadd.f32 q10, q9, q2 \n"
"vsub.f32 q11, q9, q2 \n"
"vmul.f32 q10, q10, d0[1] \n"
"vst1.32 {d20-d21}, [%[gw_ptr]]! \n"
"vmul.f32 q11, q11, d0[1] \n"
"vst1.32 {d22-d23}, [%[gw_ptr]]! \n"
"vmul.f32 q9, q7, d0[0] \n"
"vmul.f32 q9, q9, d0[0] \n" // 4 * g0
"vmul.f32 q10, q2, d0[0] \n" // 2 * g1
"vmul.f32 q11, q4, d0[0] \n"
"vmul.f32 q11, q11, d0[0] \n" // 4 * g2
"vadd.f32 q12, q7, q11 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vadd.f32 q12, q4, q9 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vst1.32 {d8-d9}, [%[gw_ptr]]! \n"
"mov r0, #3 \n"
"vld1.32 {d2-d5}, [%[inptr0]], r0 \n"
"vld1.32 {d6-d9}, [%[inptr1]], r0 \n"
"vld1.32 {d10-d13}, [%[inptr2]], r0 \n"
"vld1.32 {d14-d17}, [%[inptr3]], r0 \n"
"vtrn.32 q1, q3 \n"
"vtrn.32 q2, q4 \n"
"vtrn.32 q5, q7 \n"
"vtrn.32 q6, q8 \n"
"vswp.32 d3, d10 \n"
"vswp.32 d7, d14 \n"
"vswp.32 d5, d12 \n"
"vswp.32 d9, d16 \n"
// q1: g0, q3: g1, q5: g2
"vst1.32 {d2-d3}, [%[gw_ptr]]! \n"
"vadd.f32 q9, q1, q5 \n"
"vadd.f32 q10, q9, q3 \n"
"vsub.f32 q11, q9, q3 \n"
"vmul.f32 q10, q10, d0[1] \n"
"vst1.32 {d20-d21}, [%[gw_ptr]]! \n"
"vmul.f32 q11, q11, d0[1] \n"
"vst1.32 {d22-d23}, [%[gw_ptr]]! \n"
"vmul.f32 q9, q1, d0[0] \n"
"vmul.f32 q9, q9, d0[0] \n" // 4 * g0
"vmul.f32 q10, q3, d0[0] \n" // 2 * g1
"vmul.f32 q11, q5, d0[0] \n"
"vmul.f32 q11, q11, d0[0] \n" // 4 * g2
"vadd.f32 q12, q1, q11 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vadd.f32 q12, q5, q9 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vst1.32 {d10-d11}, [%[gw_ptr]]! \n"
: [gw_ptr] "+r"(gw_ptr), [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
[inptr2] "+r"(inptr0), [inptr3] "+r"(inptr3)
: [tm_ptr] "r"((float *)transform_matrix)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
float *gw_ptr0 = gw;
float *gw_ptr1 = gw + 32;
float *gw_ptr2 = gw + 64;
float *outptr0 = outptr + (ic << 2); // ic * 4
int steps = (in_channel << 2); // in_channel * 4
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
"mov r0, #8 \n"
"loop_8_%=: \n"
"vld1.32 {d2-d3}, [%[gw_ptr0]]! \n"
"vld1.32 {d4-d5}, [%[gw_ptr1]]! \n"
"vld1.32 {d6-d7}, [%[gw_ptr2]]! \n"
// q1: g0, q2: g1, q3: g2
"vst1.32 {d2-d3}, [%[outptr0]], %[steps] \n"
"vadd.f32 q9, q1, q3 \n"
"vadd.f32 q10, q9, q2 \n"
"vsub.f32 q11, q9, q2 \n"
"vmul.f32 q10, q10, d0[1] \n"
"vst1.32 {d20-d21}, [%[outptr0]], %[steps] \n"
"vmul.f32 q11, q11, d0[1] \n"
"vst1.32 {d22-d23}, [%[outptr0]], %[steps] \n"
"vmul.f32 q9, q1, d0[0] \n"
"vmul.f32 q9, q9, d0[0] \n" // 4 * g0
"vmul.f32 q10, q2, d0[0] \n" // 2 * g1
"vmul.f32 q11, q3, d0[0] \n"
"vmul.f32 q11, q11, d0[0] \n" // 4 * g2
"vadd.f32 q12, q1, q11 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n"
"vadd.f32 q12, q2, q9 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n"
"vst1.32 {d6-d7}, [%[outptr0]], %[steps] \n"
"subs r0, #1 \n"
"bne loop_8_%= \n"
: [outptr0] "+r"(outptr0), [gw_ptr0] "+r"(gw_ptr0),
[gw_ptr1] "+r"(gw_ptr1), [gw_ptr2] "+r"(gw_ptr2)
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q9", "q10", "q11", "q12",
"q13", "r0");
}
}
#endif // __aarch64__
// remain output channel
for (int oc = remain_start; oc < out_channel; ++oc) {
float gw[8][3]; // gw[3][8]
const float *inptr0 = inptr + oc * in_channel * 9; //
// oc * 64 * in_channel + oc % 4
int offset = ((oc * in_channel) << 6) + oc & 0x3;
int steps = (in_channel << 2); // in_channel * 4
float *outptr = trans_outptr + offset;
for (int ic = 0; ic < in_channel; ++ic) {
for (int i = 0; i < 3; ++i) {
float g0 = inptr0[i];
float g1 = inptr0[3 + i];
float g2 = inptr0[6 + i];
float d0 = g0 + g2;
float d1 = g0 + 4 * g2;
float d2 = g2 + 4 * g0;
float d3 = 2 * g1;
gw[0][i] = g0;
gw[1][i] = -2.f / 9 * (d0 + g1); // -2.f/9 * (g0 + g1 + g2)
gw[2][i] = -2.f / 9 * (d0 - g1); // -2.f/9 * (g0 - g1 + g2)
gw[3][i] = 1.f / 90 * (d1 + d3); // 1.f/90 * (g0 + 2 * g1 + 4 * g2)
gw[4][i] = 1.f / 90 * (d1 - d3); // 1.f/90 * (g0 - 2 * g1 + 4 * g2)
gw[5][i] = 8.f / 45 * (d2 + d3); // 8.f/45 * (4 * g0 + 2 * g1 + g2)
gw[6][i] = 8.f / 45 * (d2 - d3); // 8.f/45 * (4 * g0 - 2 * g1 + g2)
gw[7][i] = g2;
}
inptr0 += 9;
outptr += ic * 4;
for (int i = 0; i < 8; ++i) {
float g0 = gw[i][0];
float g1 = gw[i][1];
float g2 = gw[i][2];
float d0 = g0 + g2;
float d1 = g0 + 4 * g2;
float d2 = g2 + 4 * g0;
float d3 = 2 * g1;
int offset = i * 8 * steps;
outptr[offset] = gw[i][0];
outptr[offset + 1 * steps] = -2.f / 9 * (d0 + g1);
outptr[offset + 2 * steps] = -2.f / 9 * (d0 - g1);
outptr[offset + 3 * steps] = 1.f / 90 * (d1 + d3);
outptr[offset + 4 * steps] = 1.f / 90 * (d1 - d3);
outptr[offset + 5 * steps] = 8.f / 45 * (d2 + d3);
outptr[offset + 6 * steps] = 8.f / 45 * (d2 - d3);
outptr[offset + 7 * steps] = gw[i][2];
} }
} }
*/ }
} }
template <> template <>
...@@ -721,7 +935,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, ...@@ -721,7 +935,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
} }
} }
} }
/* /*
* s0 = m0 + (m1 + m2) + (m3 + m4) + 32 * (m5 + m6) * s0 = m0 + (m1 + m2) + (m3 + m4) + 32 * (m5 + m6)
* s1 = (m1 - m2) + 2 * (m3 - m4) + 16 * (m5 - m6) * s1 = (m1 - m2) + 2 * (m3 - m4) + 16 * (m5 - m6)
...@@ -741,8 +954,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, ...@@ -741,8 +954,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
float transform_matrix[8] = {2.f, 4.f, 8.f, 16.f}; float transform_matrix[8] = {2.f, 4.f, 8.f, 16.f};
for (int oc = 0; oc < out_channel; ++oc) { for (int oc = 0; oc < out_channel; ++oc) {
float at_m[48]; float at_m[48]; // [6][8]
float output_tmp[36]; float output_tmp[36]; // [6][6], temporarily stores results
const float *uv_ptr = uv_trans_ptr + oc * h_tiles * w_tiles * 64; const float *uv_ptr = uv_trans_ptr + oc * h_tiles * w_tiles * 64;
for (int tile_h = 0; tile_h < h_tiles; ++tile_h) { for (int tile_h = 0; tile_h < h_tiles; ++tile_h) {
for (int tile_w = 0; tile_w < w_tiles; ++tile_w) { for (int tile_w = 0; tile_w < w_tiles; ++tile_w) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册