Commit cbc8c575 authored by H hjchen2

complete kernel transform function

Parent b6a50b91
......@@ -166,7 +166,7 @@ void ConvCompute(const ConvParam<CPU> &param) {
param.Strides()[0] == param.Strides()[1] &&
param.Dilations()[0] == param.Dilations()[1] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Dilations()[0] == 1 && param.Input()->dims()[1] > 16) {
param.Dilations()[0] == 1 && param.Input()->dims()[1] >= 16) {
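// Use the Winograd F(6x6, 3x3) kernel only for square 3x3 filters with
// stride 1, dilation 1, and at least 16 input channels.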
BatchConv3x3Winograd(param);
} else {
ConvBasic<float, float>(param);
......
......@@ -12,11 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// refered from https://arxiv.org/abs/1509.09308 and
// Inspired by https://arxiv.org/abs/1509.09308 and
// https://github.com/andravin/wincnn
#pragma once
// Referenced from the nnpack and ncnn projects
#include <arm_neon.h>
#include "operators/math/pad.h"
#include "operators/math/winograd/winograd_transform.h"
......@@ -27,16 +26,6 @@ namespace math {
template <>
void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
framework::Tensor *output) {
// weight shape is [out_channel, in_channel, kernel_h, kernel_w]
int out_channel = weight.dims()[0];
int in_channel = weight.dims()[1];
// reshape and alloc transformed weight
// framework::DDim transformed_shape = framework::make_ddim(
// std::vector<int>{out_channel, (in_channel + 3) / 4, 8, 32});
framework::DDim transformed_shape = framework::make_ddim(
std::vector<int>{(out_channel + 3) / 4, 64, in_channel, 4});
float *outptr = output->mutable_data<float>(transformed_shape);
memset(outptr, 0, output->numel() * sizeof(float));
/*
* w0 = g0
* w1 = ((g0 + g2) + g1) * (-2.0 / 9)
......@@ -47,51 +36,276 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
* w6 = ((g2 + 4 * g0) - 2 * g1) * (1.0 / 180)
* w7 = g2
*/
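// For reference, the formulas above are the rows of the F(6x6, 3x3)
// weight-transform matrix G (rows w2..w5 inferred from the scalar code
// below), applied to every 3x3 kernel g as U = G * g * G^T:
//   G = [  1       0       0    ]
//       [ -2/9   -2/9    -2/9   ]
//       [ -2/9    2/9    -2/9   ]
//       [  1/90   1/45    2/45  ]
//       [  1/90  -1/45    2/45  ]
//       [  1/45   1/90    1/180 ]
//       [  1/45  -1/90    1/180 ]
//       [  0       0       1    ]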
/*
const float *inptr = weight.data<float>();
for (int oc = 0; oc < out_channel; ++oc) {
for (int ic = 0; ic < in_channel; ++ic) {
size_t offset = oc * in_channel + ic;
float *kout = outptr + offset * 64;
const float *k = inptr + offset * 9;
float gw[8][3];
for (int i = 0; i < 3; ++i) {
float k0 = k[i];
float k1 = k[3 + i];
float k2 = k[6 + i];
float d0 = k0 + k2;
float d1 = k0 + 4 * k2;
float d2 = 4 * k0 + k2;
float d3 = 2 * k1;
gw[0][i] = k0;
gw[1][i] = -2.f/9 * (d0 + k1); // -2.f/9 * (k0 + k1 + k2)
gw[2][i] = -2.f/9 * (d0 - k1); // -2.f/9 * (k0 - k1 + k2)
gw[3][i] = 1.f/90 * (d1 + d3); // 1.f/90 * (k0 + 2 * k1 + 4 * k2)
gw[4][i] = 1.f/90 * (d1 - d3); // 1.f/90 * (k0 - 2 * k1 + 4 * k2)
gw[5][i] = 8.f/45 * (d2 + d3); // 8.f/45 * (4 * k0 + 2 * k1 + k2)
gw[6][i] = 8.f/45 * (d2 - d3); // 8.f/45 * (4 * k0 - 2 * k1 + k2)
gw[7][i] = k2;
}
for (int i = 0; i < 8; ++i, kout += 8) {
float k0 = gw[i][0];
float k1 = gw[i][1];
float k2 = gw[i][2];
float d0 = k0 + k2;
float d1 = k0 + 4 * k2;
float d2 = 4 * k0 + k2;
float d3 = 2 * k1;
kout[0] = gw[i][0];
kout[1] = -2.f/9 * (d0 + k1); // -2.f/9 * (k0 + k1 + k2)
kout[2] = -2.f/9 * (d0 - k1); // -2.f/9 * (k0 - k1 + k2)
kout[3] = 1.f/90 * (d1 + d3); // 1.f/90 * (k0 + 2 * k1 + 4 * k2)
kout[4] = 1.f/90 * (d1 - d3); // 1.f/90 * (k0 - 2 * k1 + 4 * k2)
kout[5] = 8.f/45 * (d2 + d3); // 8.f/45 * (4 * k0 + 2 * k1 + k2)
kout[6] = 8.f/45 * (d2 - d3); // 8.f/45 * (4 * k0 - 2 * k1 + k2)
kout[7] = gw[i][2];
}
// weight shape is [out_channel, in_channel, kernel_h, kernel_w]
int out_channel = weight.dims()[0];
int in_channel = weight.dims()[1];
// reshape and alloc transformed weight
framework::DDim transformed_shape = framework::make_ddim(
std::vector<int>{(out_channel + 3) / 4, 64, in_channel, 4});
float *trans_outptr = output->mutable_data<float>(transformed_shape);
memset(trans_outptr, 0, output->numel() * sizeof(float));
const float transform_matrix[8] = {2.f, -2.f / 9, 1.f / 90, 1.f / 180};
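// The "vld1.32 {d0-d1}" below loads these four constants, so the lane
// references in the assembly mean: d0[0] = 2, d0[1] = -2/9, d1[0] = 1/90,
// d1[1] = 1/180.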
const float *inptr = weight.data<float>();
int remain_start = out_channel & 0xFFFC;
#ifdef __aarch64__
remain_start = 0;
#else
for (int oc = 0; oc < remain_start; oc += 4) {
float gw[96]; // gw[3][8][4]
const float *inptr0 = inptr + oc * in_channel * 9;
const float *inptr1 = inptr + (oc + 1) * in_channel * 9;
const float *inptr2 = inptr + (oc + 2) * in_channel * 9;
const float *inptr3 = inptr + (oc + 3) * in_channel * 9;
// oc * 64 * in_channel
float *outptr = trans_outptr + ((oc * in_channel) << 6);
for (int ic = 0; ic < in_channel; ++ic) {
float *gw_ptr = gw;
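// Two passes per (4 output channels, 1 input channel): the first asm block
// applies the 1-D transform to each of the three kernel rows and packs the
// results into gw as [3 rows][8 transformed values][4 output channels]; the
// second asm block applies the same transform across rows and scatters the
// resulting 8x8 tile into the [out_channel/4][64][in_channel][4] layout.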
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
"mov r0, #6 \n"
"vld1.32 {d2-d5}, [%[inptr0]], r0 \n"
"vld1.32 {d6-d9}, [%[inptr1]], r0 \n"
"vld1.32 {d10-d13}, [%[inptr2]], r0 \n"
"vld1.32 {d14-d17}, [%[inptr3]], r0 \n"
"vtrn.32 q1, q3 \n"
"vtrn.32 q2, q4 \n"
"vtrn.32 q5, q7 \n"
"vtrn.32 q6, q8 \n"
"vswp.32 d3, d10 \n"
"vswp.32 d7, d14 \n"
"vswp.32 d5, d12 \n"
"vswp.32 d9, d16 \n"
// q1: g0, q3: g1, q5: g2
"vst1.32 {d2-d3}, [%[gw_ptr]]! \n"
"vadd.f32 q9, q1, q5 \n"
"vadd.f32 q10, q9, q3 \n"
"vsub.f32 q11, q9, q3 \n"
"vmul.f32 q10, q10, d0[1] \n"
"vst1.32 {d20-d21}, [%[gw_ptr]]! \n"
"vmul.f32 q11, q11, d0[1] \n"
"vst1.32 {d22-d23}, [%[gw_ptr]]! \n"
"vmul.f32 q9, q1, d0[0] \n"
"vmul.f32 q9, q9, d0[0] \n" // 4 * g0
"vmul.f32 q10, q3, d0[0] \n" // 2 * g1
"vmul.f32 q11, q5, d0[0] \n"
"vmul.f32 q11, q11, d0[0] \n" // 4 * g2
"vadd.f32 q12, q1, q11 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vadd.f32 q12, q5, q9 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vst1.32 {d10-d11}, [%[gw_ptr]]! \n"
// q7: g0, q2: g1, q4: g2
"vst1.32 {d14-d15}, [%[gw_ptr]]! \n"
"vadd.f32 q9, q7, q4 \n"
"vadd.f32 q10, q9, q2 \n"
"vsub.f32 q11, q9, q2 \n"
"vmul.f32 q10, q10, d0[1] \n"
"vst1.32 {d20-d21}, [%[gw_ptr]]! \n"
"vmul.f32 q11, q11, d0[1] \n"
"vst1.32 {d22-d23}, [%[gw_ptr]]! \n"
"vmul.f32 q9, q7, d0[0] \n"
"vmul.f32 q9, q9, d0[0] \n" // 4 * g0
"vmul.f32 q10, q2, d0[0] \n" // 2 * g1
"vmul.f32 q11, q4, d0[0] \n"
"vmul.f32 q11, q11, d0[0] \n" // 4 * g2
"vadd.f32 q12, q7, q11 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vadd.f32 q12, q4, q9 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vst1.32 {d8-d9}, [%[gw_ptr]]! \n"
"mov r0, #3 \n"
"vld1.32 {d2-d5}, [%[inptr0]], r0 \n"
"vld1.32 {d6-d9}, [%[inptr1]], r0 \n"
"vld1.32 {d10-d13}, [%[inptr2]], r0 \n"
"vld1.32 {d14-d17}, [%[inptr3]], r0 \n"
"vtrn.32 q1, q3 \n"
"vtrn.32 q2, q4 \n"
"vtrn.32 q5, q7 \n"
"vtrn.32 q6, q8 \n"
"vswp.32 d3, d10 \n"
"vswp.32 d7, d14 \n"
"vswp.32 d5, d12 \n"
"vswp.32 d9, d16 \n"
// q1: g0, q3: g1, q5: g2
"vst1.32 {d2-d3}, [%[gw_ptr]]! \n"
"vadd.f32 q9, q1, q5 \n"
"vadd.f32 q10, q9, q3 \n"
"vsub.f32 q11, q9, q3 \n"
"vmul.f32 q10, q10, d0[1] \n"
"vst1.32 {d20-d21}, [%[gw_ptr]]! \n"
"vmul.f32 q11, q11, d0[1] \n"
"vst1.32 {d22-d23}, [%[gw_ptr]]! \n"
"vmul.f32 q9, q1, d0[0] \n"
"vmul.f32 q9, q9, d0[0] \n" // 4 * g0
"vmul.f32 q10, q3, d0[0] \n" // 2 * g1
"vmul.f32 q11, q5, d0[0] \n"
"vmul.f32 q11, q11, d0[0] \n" // 4 * g2
"vadd.f32 q12, q1, q11 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vadd.f32 q12, q5, q9 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vst1.32 {d10-d11}, [%[gw_ptr]]! \n"
: [gw_ptr] "+r"(gw_ptr), [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
[inptr2] "+r"(inptr0), [inptr3] "+r"(inptr3)
: [tm_ptr] "r"((float *)transform_matrix)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
float *gw_ptr0 = gw;
float *gw_ptr1 = gw + 32;
float *gw_ptr2 = gw + 64;
float *outptr0 = outptr + (ic << 2); // ic * 4
int steps = (in_channel << 2); // in_channel * 4
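// Each of the 8 iterations below reads one transformed value per kernel row
// (4 output channels wide), applies the 1-D transform across rows, and
// stores the 8 results with stride `steps` into the packed output.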
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
"mov r0, #8 \n"
"loop_8_%=: \n"
"vld1.32 {d2-d3}, [%[gw_ptr0]]! \n"
"vld1.32 {d4-d5}, [%[gw_ptr1]]! \n"
"vld1.32 {d6-d7}, [%[gw_ptr2]]! \n"
// q1: g0, q2: g1, q3: g2
"vst1.32 {d2-d3}, [%[outptr0]], %[steps] \n"
"vadd.f32 q9, q1, q3 \n"
"vadd.f32 q10, q9, q2 \n"
"vsub.f32 q11, q9, q2 \n"
"vmul.f32 q10, q10, d0[1] \n"
"vst1.32 {d20-d21}, [%[outptr0]], %[steps] \n"
"vmul.f32 q11, q11, d0[1] \n"
"vst1.32 {d22-d23}, [%[outptr0]], %[steps] \n"
"vmul.f32 q9, q1, d0[0] \n"
"vmul.f32 q9, q9, d0[0] \n" // 4 * g0
"vmul.f32 q10, q2, d0[0] \n" // 2 * g1
"vmul.f32 q11, q3, d0[0] \n"
"vmul.f32 q11, q11, d0[0] \n" // 4 * g2
"vadd.f32 q12, q1, q11 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n"
"vadd.f32 q12, q2, q9 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n"
"vst1.32 {d6-d7}, [%[outptr0]], %[steps] \n"
"subs r0, #1 \n"
"bne loop_8_%= \n"
: [outptr0] "+r"(outptr0), [gw_ptr0] "+r"(gw_ptr0),
[gw_ptr1] "+r"(gw_ptr1), [gw_ptr2] "+r"(gw_ptr2)
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q9", "q10", "q11", "q12",
"q13", "r0");
}
}
#endif // __aarch64__
// remain output channel
for (int oc = remain_start; oc < out_channel; ++oc) {
float gw[8][3];  // 8 transformed values for each of the 3 kernel columns
const float *inptr0 = inptr + oc * in_channel * 9;
// (oc / 4) * 64 * in_channel * 4 + oc % 4
int offset = (((oc & 0xFFFC) * in_channel) << 6) + (oc & 0x3);
int steps = (in_channel << 2); // in_channel * 4
float *outptr = trans_outptr + offset;
for (int ic = 0; ic < in_channel; ++ic) {
for (int i = 0; i < 3; ++i) {
float g0 = inptr0[i];
float g1 = inptr0[3 + i];
float g2 = inptr0[6 + i];
float d0 = g0 + g2;
float d1 = g0 + 4 * g2;
float d2 = g2 + 4 * g0;
float d3 = 2 * g1;
gw[0][i] = g0;
gw[1][i] = -2.f / 9 * (d0 + g1); // -2.f/9 * (g0 + g1 + g2)
gw[2][i] = -2.f / 9 * (d0 - g1); // -2.f/9 * (g0 - g1 + g2)
gw[3][i] = 1.f / 90 * (d1 + d3); // 1.f/90 * (g0 + 2 * g1 + 4 * g2)
gw[4][i] = 1.f / 90 * (d1 - d3); // 1.f/90 * (g0 - 2 * g1 + 4 * g2)
gw[5][i] = 1.f / 180 * (d2 + d3);  // 1.f/180 * (4 * g0 + 2 * g1 + g2)
gw[6][i] = 1.f / 180 * (d2 - d3);  // 1.f/180 * (4 * g0 - 2 * g1 + g2)
gw[7][i] = g2;
}
inptr0 += 9;
for (int i = 0; i < 8; ++i) {
float g0 = gw[i][0];
float g1 = gw[i][1];
float g2 = gw[i][2];
float d0 = g0 + g2;
float d1 = g0 + 4 * g2;
float d2 = g2 + 4 * g0;
float d3 = 2 * g1;
int offset = i * 8 * steps + (ic << 2);  // (i * 8) * in_channel * 4 + ic * 4
outptr[offset] = gw[i][0];
outptr[offset + 1 * steps] = -2.f / 9 * (d0 + g1);
outptr[offset + 2 * steps] = -2.f / 9 * (d0 - g1);
outptr[offset + 3 * steps] = 1.f / 90 * (d1 + d3);
outptr[offset + 4 * steps] = 1.f / 90 * (d1 - d3);
outptr[offset + 5 * steps] = 1.f / 180 * (d2 + d3);
outptr[offset + 6 * steps] = 1.f / 180 * (d2 - d3);
outptr[offset + 7 * steps] = gw[i][2];
}
}
*/
}
}
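// A minimal scalar sketch of the same per-filter transform, independent of
// the packed output layout above: it computes U = G * g * G^T for a single
// 3x3 kernel g, using the 1/180 scaling documented in the comment block of
// winograd_transform_weight. The helper name is illustrative; the production
// path is the NEON code above.
inline void winograd_f6k3_transform_kernel_ref(const float g[9],
float u[64]) {
// G is the 8x3 weight-transform matrix for F(6x6, 3x3).
static const float G[8][3] = {
{1.f, 0.f, 0.f},
{-2.f / 9, -2.f / 9, -2.f / 9},
{-2.f / 9, 2.f / 9, -2.f / 9},
{1.f / 90, 1.f / 45, 2.f / 45},
{1.f / 90, -1.f / 45, 2.f / 45},
{1.f / 45, 1.f / 90, 1.f / 180},
{1.f / 45, -1.f / 90, 1.f / 180},
{0.f, 0.f, 1.f}};
float Gg[8][3];  // G * g, 8x3
for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 3; ++j) {
Gg[i][j] = G[i][0] * g[j] + G[i][1] * g[3 + j] + G[i][2] * g[6 + j];
}
}
// U = (G * g) * G^T, 8x8, stored row-major
for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 8; ++j) {
u[i * 8 + j] =
Gg[i][0] * G[j][0] + Gg[i][1] * G[j][1] + Gg[i][2] * G[j][2];
}
}
}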
template <>
......@@ -721,7 +935,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
}
}
}
/*
* s0 = m0 + (m1 + m2) + (m3 + m4) + 32 * (m5 + m6)
* s1 = (m1 - m2) + 2 * (m3 - m4) + 16 * (m5 - m6)
......@@ -741,8 +954,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
float transform_matrix[8] = {2.f, 4.f, 8.f, 16.f};
for (int oc = 0; oc < out_channel; ++oc) {
float at_m[48];
float output_tmp[36];
float at_m[48]; // [6][8]
float output_tmp[36];  // [6][6], temporarily stores results
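// at_m holds the first (row-wise) pass of the inverse transform A^T applied
// to one 8x8 tile (6x8); output_tmp then receives the column-wise pass,
// giving the 6x6 output block for that tile.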
const float *uv_ptr = uv_trans_ptr + oc * h_tiles * w_tiles * 64;
for (int tile_h = 0; tile_h < h_tiles; ++tile_h) {
for (int tile_w = 0; tile_w < w_tiles; ++tile_w) {
......