From 523bbccac1dec294b2cf8279f66cdb826974262e Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 17 Dec 2018 23:52:12 +0800 Subject: [PATCH] Fix winograd if input height != width --- .../math/winograd/winograd_transform_f6k3.cpp | 51 +-- .../winograd_transform_f6k3_arm64.cpp | 400 ++++++++++++++++-- 2 files changed, 385 insertions(+), 66 deletions(-) diff --git a/src/operators/math/winograd/winograd_transform_f6k3.cpp b/src/operators/math/winograd/winograd_transform_f6k3.cpp index d9a7cb3b51..937050ebbd 100644 --- a/src/operators/math/winograd/winograd_transform_f6k3.cpp +++ b/src/operators/math/winograd/winograd_transform_f6k3.cpp @@ -327,8 +327,8 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, int channel = input.dims()[1]; int height = input.dims()[2]; int width = input.dims()[3]; - int h_tiles = (height + 3) / 6; // (height - 8 + 5 + 6) / 6 - int w_tiles = (width + 3) / 6; // (width - 8 + 5 + 6) / 6 + int h_tiles = (height + 3) / 6; // (height - 2 + 5) / 6 + int w_tiles = (width + 3) / 6; // (width - 2 + 5) / 6 int tiles = (h_tiles * w_tiles + 7) / 8; framework::DDim transformed_shape = framework::make_ddim(std::vector{tiles, 64, channel, 8}); @@ -336,16 +336,10 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, memset(outptr, 0, output->numel() * sizeof(float)); const float *inptr = input.data(); - int inter_h = (height - 2) / 6; - int inter_w = (width - 2) / 6; - int remain_h = height - (inter_h * 6); - int remain_w = width - (inter_w * 6); + height = h_tiles * 6 + 2; + width = w_tiles * 6 + 2; framework::Tensor input_pad; - if (remain_h > 2 || remain_w > 2) { - inter_h += (remain_h > 2); - inter_w += (remain_w > 2); - height = (inter_h - 1) * 6 + 8; - width = (inter_w - 1) * 6 + 8; + if (height > input.dims()[2] || width > input.dims()[3]) { framework::DDim input_shape = framework::make_ddim(std::vector{1, channel, height, width}); PadFunctor pad; @@ -878,8 +872,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, framework::Tensor *output) { // weight shape is [out_channel/4, 64, in_channel, 4], // input shape is [hw/8, 64, in_channel, 8] - int in_channel = input.dims()[2]; int tiles = input.dims()[0]; + int in_channel = input.dims()[2]; int out_channel = weight.dims()[0]; // compute U*V first @@ -887,7 +881,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, framework::DDim shape = framework::make_ddim(std::vector{out_channel, tiles, 64, 32}); float *uv_trans_ptr = uv_trans.mutable_data(shape); - memset(uv_trans_ptr, 0, uv_trans.numel() * sizeof(float)); const float *input_ptr = input.data(); const float *weight_ptr = weight.data(); @@ -910,7 +903,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, "veor q14, q14, q14 \n" "veor q15, q15, q15 \n" - "b store_res_%= \n" + "cmp %[inter_channel], #0 \n" + "ble loop_1c_%= \n" // loop 2 channels "loop_2c_%=: \n" "vld1.32 {d0-d3}, [%[w_ptr]]! \n" @@ -936,13 +930,14 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, "subs %[inter_channel], #1 \n" "bne loop_2c_%= \n" - "mov pc, lr \n" // loop 1 channel - "loop_c_%=: \n" + "loop_1c_%=: \n" + "cmp %[remain_channel], #0 \n" + "ble store_res_%= \n" + "vld1.32 {d0-d1}, [%[w_ptr]]! \n" "vld1.32 {d4-d7}, [%[in_ptr]]! \n" - "vmla.f32 q8, q2, d0[0] \n" "vmla.f32 q9, q3, d0[0] \n" "vmla.f32 q10, q2, d0[1] \n" @@ -952,28 +947,16 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, "vmla.f32 q14, q2, d1[1] \n" "vmla.f32 q15, q3, d1[1] \n" - "subs %[remain_channel], #1 \n" - "bne loop_c_%= \n" - "mov pc, lr \n" - "store_res_%=: \n" - "cmp %[inter_channel], #0 \n" - "it gt \n" - "blgt loop_2c_%= \n" - "cmp %[remain_channel], #0 \n" - "it gt \n" - "blgt loop_c_%= \n" - "vst1.32 {d16-d19}, [%[uv_ptr]]! \n" "vst1.32 {d20-d23}, [%[uv_ptr]]! \n" "vst1.32 {d24-d27}, [%[uv_ptr]]! \n" "vst1.32 {d28-d31}, [%[uv_ptr]]! \n" : [w_ptr] "+r"(w_ptr), [in_ptr] "+r"(in_ptr), [uv_ptr] "+r"(uv_ptr), - [remain_channel] "+r"(remain_channel), [inter_channel] "+r"(inter_channel) - : + : [remain_channel] "r"(remain_channel) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "pc", "lr"); + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); } } } @@ -1223,8 +1206,10 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w; float *out_ptr = output_ptr + offset; - int remain_row = (tile_h < h_tiles - 1) ? 6 : remain_h; - int remain_col = (tile_w < w_tiles - 1) ? 6 : remain_w; + int remain_row = out_h - 6 * tile_h; + int remain_col = out_w - 6 * tile_w; + remain_row = (remain_row > 6) ? 6 : remain_row; + remain_col = (remain_col > 6) ? 6 : remain_col; for (int i = 0; i < remain_row; ++i, out_ptr += out_w) { memcpy(out_ptr, output_tmp + i * 6, remain_col * sizeof(float)); } diff --git a/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp b/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp index 2ac25daffc..5ef9c194f2 100644 --- a/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp +++ b/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp @@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -// Inspired by https://arxiv.org/abs/1509.09308 and refered from nnpack and ncnn -// project. +// We refer https://github.com/andravin/wincnn to access the winograd transform +// matrixs #ifdef CONV_OP - #ifdef __aarch64__ -#include "operators/math/pad.h" #include "operators/math/winograd/winograd_transform.h" namespace paddle_mobile { @@ -29,46 +27,382 @@ namespace math { template <> void winograd_transform_weight<8, 3>(const framework::Tensor &weight, framework::Tensor *output) { - /* - * w0 = g0 - * w1 = ((g0 + g2) + g1) * (-2.0 / 9) - * w2 = ((g0 + g2) - g1) * (-2.0 / 9) - * w3 = ((g0 + 4 * g2) + 2 * g1) * (1.0 / 90) - * w4 = ((g0 + 4 * g2) - 2 * g1) * (1.0 / 90) - * w5 = ((g2 + 4 * g0) + 2 * g1) * (1.0 / 180) - * w6 = ((g2 + 4 * g0) - 2 * g1) * (1.0 / 180) - * w7 = g2 - */ - // TODO(hjchen2) - PADDLE_MOBILE_THROW_EXCEPTION( - "Winograd for arm v8 has not been implemented."); + // weight shape is [out_channel, in_channel, kernel_h, kernel_w] + int out_channel = weight.dims()[0]; + int in_channel = weight.dims()[1]; + // reshape and alloc transformed weight + framework::DDim transformed_shape = + framework::make_ddim(std::vector{out_channel, in_channel, 64}); + float *outptr = output->mutable_data(transformed_shape); + const float *inptr = weight.data(); + for (int oc = 0; oc < out_channel; ++oc) { + for (int ic = 0; ic < in_channel; ++ic) { + size_t offset = oc * in_channel + ic; + float *kout = outptr + offset * 64; + const float *k = inptr + offset * 9; + + float gw[3][8]; + for (int i = 0; i < 3; ++i, k += 3) { + float g0 = k[0]; + float g1 = k[1]; + float g2 = k[2]; + float d0 = g0 + g2; + float d1 = g0 + 4 * g2; + float d2 = g2 + 4 * g0; + float d3 = 2 * g1; + gw[i][0] = g0; + gw[i][1] = -2.f / 9 * (d0 + g1); // -2.f/9 * (g0 + g1 + g2) + gw[i][2] = -2.f / 9 * (d0 - g1); // -2.f/9 * (g0 - g1 + g2) + gw[i][3] = 1.f / 90 * (d1 + d3); // 1.f/90 * (g0 + 2 * g1 + 4 * g2) + gw[i][4] = 1.f / 90 * (d1 - d3); // 1.f/90 * (g0 - 2 * g1 + 4 * g2) + gw[i][5] = 1.f / 180 * (d2 + d3); // 1.f/180 * (4 * g0 + 2 * g1 + g2) + gw[i][6] = 1.f / 180 * (d2 - d3); // 1.f/180 * (4 * g0 - 2 * g1 + g2) + gw[i][7] = g2; + } + for (int i = 0; i < 8; ++i, kout += 8) { + float g0 = gw[0][i]; + float g1 = gw[1][i]; + float g2 = gw[2][i]; + float d0 = g0 + g2; + float d1 = g0 + 4 * g2; + float d2 = g2 + 4 * g0; + float d3 = 2 * g1; + kout[0] = g0; + kout[1] = -2.f / 9 * (d0 + g1); // -2.f/9 * (k0 + k1 + k2) + kout[2] = -2.f / 9 * (d0 - g1); // -2.f/9 * (k0 - k1 + k2) + kout[3] = 1.f / 90 * (d1 + d3); // 1.f/90 * (k0 + 2 * k1 + 4 * k2) + kout[4] = 1.f / 90 * (d1 - d3); // 1.f/90 * (k0 - 2 * k1 + 4 * k2) + kout[5] = 1.f / 180 * (d2 + d3); // 8.f/45 * (4 * k0 + 2 * k1 + k2) + kout[6] = 1.f / 180 * (d2 - d3); // 8.f/45 * (4 * k0 - 2 * k1 + k2) + kout[7] = g2; + } + } + } } template <> void winograd_transform_input<8, 3>(const framework::Tensor &input, framework::Tensor *output) { - /* - * x0 = (d0 - d6) + (d4 - d2) * 5.25 - * x1 = (d2 + d6) - 4.25 * (d4 + d3) + (d1 + d5) - * x2 = (d2 + d6) - 4.25 * (d4 - d3) - (d1 + d5) - * x3 = (0.25 * d2 - 1.25 * d4 + d6) + (0.5 * d1 - 2.5 * d3 + 2 * d5) - * x4 = (0.25 * d2 - 1.25 * d4 + d6) - (0.5 * d1 - 2.5 * d3 + 2 * d5) - * x5 = (4 * d2 - 5 * d4 + d6) + (2 * d1 - 2.5 * d3 + 0.5 * d5) - * x6 = (4 * d2 - 5 * d4 + d6) - (2 * d1 - 2.5 * d3 + 0.5 * d5) - * x7 = (d7 - d1) + (d3 - d5) * 5.25 - */ - // TODO(hjchen2) - PADDLE_MOBILE_THROW_EXCEPTION( - "Winograd for arm v8 has not been implemented."); + // tile input to [c, roundup(h/6), roundup(w/6), 64] and do transformation + int channel = input.dims()[1]; + int height = input.dims()[2]; + int width = input.dims()[3]; + int h_tiles = (height + 3) / 6; // (height + 5 - 2) / 6 + int w_tiles = (width + 3) / 6; // (width + 5 - 2) / 6 + framework::DDim transformed_shape = + framework::make_ddim(std::vector{channel, h_tiles, w_tiles, 64}); + float *outptr = output->mutable_data(transformed_shape); + memset(outptr, 0, channel * h_tiles * w_tiles * 64 * sizeof(float)); + const float *inptr = input.data(); + // pack input to tiles + for (int c = 0; c < channel; ++c) { + int inter_h = (height - 2) / 6; + int inter_w = (width - 2) / 6; + int remain_h = height - (inter_h * 6); + int remain_w = width - (inter_w * 6); + const float *in0 = inptr + c * height * width; + const float *in1 = in0 + width; + const float *in2 = in1 + width; + const float *in3 = in2 + width; + const float *in4 = in3 + width; + const float *in5 = in4 + width; + const float *in6 = in5 + width; + const float *in7 = in6 + width; + float *out = outptr + c * h_tiles * w_tiles * 64; + + for (int h = 0; h < inter_h; ++h) { + for (int w = 0; w < inter_w; ++w) { + memcpy(out, in0, 8 * sizeof(float)); + memcpy(out + 8, in1, 8 * sizeof(float)); + memcpy(out + 16, in2, 8 * sizeof(float)); + memcpy(out + 24, in3, 8 * sizeof(float)); + memcpy(out + 32, in4, 8 * sizeof(float)); + memcpy(out + 40, in5, 8 * sizeof(float)); + memcpy(out + 48, in6, 8 * sizeof(float)); + memcpy(out + 56, in7, 8 * sizeof(float)); + in0 += 6; + in1 += 6; + in2 += 6; + in3 += 6; + in4 += 6; + in5 += 6; + in6 += 6; + in7 += 6; + out += 64; + } + // remain width + if (remain_w > 2) { + memcpy(out, in0, remain_w * sizeof(float)); + memcpy(out + 8, in1, remain_w * sizeof(float)); + memcpy(out + 16, in2, remain_w * sizeof(float)); + memcpy(out + 24, in3, remain_w * sizeof(float)); + memcpy(out + 32, in4, remain_w * sizeof(float)); + memcpy(out + 40, in5, remain_w * sizeof(float)); + memcpy(out + 48, in6, remain_w * sizeof(float)); + memcpy(out + 56, in7, remain_w * sizeof(float)); + out += 64; + } + in0 += 5 * width + remain_w; + in1 += 5 * width + remain_w; + in2 += 5 * width + remain_w; + in3 += 5 * width + remain_w; + in4 += 5 * width + remain_w; + in5 += 5 * width + remain_w; + in6 += 5 * width + remain_w; + in7 += 5 * width + remain_w; + } + // remain height + if (remain_h > 2) { + for (int w = 0; w < inter_w; ++w) { + for (int rh = 0; rh < remain_h; ++rh) { + memcpy(out + rh * 8, in0 + rh * width, 8 * sizeof(float)); + } + out += 64; + in0 += 6; + } + // remain width + if (remain_w > 2) { + for (int rh = 0; rh < remain_h; ++rh) { + memcpy(out + rh * 8, in0 + rh * width, remain_w * sizeof(float)); + } + } + } + } + // transform tiles, compute B_T * d(c, b) * B + for (int c = 0; c < channel; ++c) { + for (int tile = 0; tile < h_tiles * w_tiles; ++tile) { + float *out = outptr + (c * h_tiles * w_tiles + tile) * 64; + // compute B_T * d(c, b) + float bd[8][8]; + for (int i = 0; i < 8; ++i) { + float d0 = out[8 * i + 0]; + float d1 = out[8 * i + 1]; + float d2 = out[8 * i + 2]; + float d3 = out[8 * i + 3]; + float d4 = out[8 * i + 4]; + float d5 = out[8 * i + 5]; + float d6 = out[8 * i + 6]; + float d7 = out[8 * i + 7]; + + bd[i][0] = d0 - d6 + (d4 - d2) * 5.25; + float v1 = d2 - 4.25 * d4 + d6; + float v2 = d1 - 4.25 * d3 + d5; + // d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6 + bd[i][1] = v1 + v2; + // -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6 + bd[i][2] = v1 - v2; + v1 = 0.25 * d2 - 1.25 * d4 + d6; + v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5; + // 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6 + bd[i][3] = v1 + v2; + // -0.5 * d1 + 0.25 * d2 + 2.5 * d3 - 1.25 * d4 - 2 * d5 + d6 + bd[i][4] = v1 - v2; + v1 = 4 * d2 - 5 * d4 + d6; + v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5; + // 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6 + bd[i][5] = v1 + v2; + // -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6 + bd[i][6] = v1 - v2; + bd[i][7] = d7 - d1 + (d3 - d5) * 5.25; + } + // compute B_T * d(c, b) * B + for (int i = 0; i < 8; ++i, out += 8) { + float d0 = bd[0][i]; + float d1 = bd[1][i]; + float d2 = bd[2][i]; + float d3 = bd[3][i]; + float d4 = bd[4][i]; + float d5 = bd[5][i]; + float d6 = bd[6][i]; + float d7 = bd[7][i]; + + out[0] = d0 - d6 + (d4 - d2) * 5.25; + float v1 = d2 - 4.25 * d4 + d6; + float v2 = d1 - 4.25 * d3 + d5; + // d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6 + out[1] = v1 + v2; + // -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6 + out[2] = v1 - v2; + v1 = 0.25 * d2 - 1.25 * d4 + d6; + v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5; + // 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6 + out[3] = v1 + v2; + // -0.5 * d1 + 0.25 * d2 + 2.5 * d3 - 1.25 * d4 - 2 * d5 + d6 + out[4] = v1 - v2; + v1 = 4 * d2 - 5 * d4 + d6; + v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5; + // 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6 + out[5] = v1 + v2; + // -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6 + out[6] = v1 - v2; + out[7] = d7 - d1 + (d3 - d5) * 5.25; + } + } + } } template <> void winograd_transform_output<8, 3>(const framework::Tensor &input, const framework::Tensor &weight, framework::Tensor *output) { - // TODO(hjchen2) - PADDLE_MOBILE_THROW_EXCEPTION( - "Winograd for arm v8 has not been implemented."); + // input shape is [in_channel, h_tiles, w_tiles, 64] + // weight shape is [out_channel, in_channel, 64] + int in_channel = input.dims()[0]; + int h_tiles = input.dims()[1]; + int w_tiles = input.dims()[2]; + int tiles = h_tiles * w_tiles; + int out_channel = weight.dims()[0]; + // compute U*V first + framework::Tensor output_m; + framework::DDim shape = + framework::make_ddim(std::vector{out_channel, tiles, 64}); + float *output_m_ptr = output_m.mutable_data(shape); + memset(output_m_ptr, 0, output_m.numel() * sizeof(float)); + const float *input_ptr = input.data(); + const float *weight_ptr = weight.data(); + for (int i = 0; i < out_channel; ++i) { + for (int j = 0; j < tiles; ++j) { + const float *w_ptr = weight_ptr + i * in_channel * 64; + const float *in_ptr = input_ptr + j * 64; + float *m_ptr = output_m_ptr + (i * tiles + j) * 64; + for (int c = 0; c < in_channel; ++c) { + for (int k = 0; k < 64; ++k) { + m_ptr[k] += w_ptr[k] * in_ptr[k]; + } + w_ptr += 64; + in_ptr += tiles * 64; + } + } + } + + for (int oc = 0; oc < out_channel; ++oc) { + for (int tile = 0; tile < tiles; ++tile) { + float *m = output_m_ptr + (oc * tiles + tile) * 64; + // compute A_T * m + float am[6][8]; + for (int i = 0; i < 8; ++i) { + float d0 = m[i * 8 + 0]; + float d1 = m[i * 8 + 1]; + float d2 = m[i * 8 + 2]; + float d3 = m[i * 8 + 3]; + float d4 = m[i * 8 + 4]; + float d5 = m[i * 8 + 5]; + float d6 = m[i * 8 + 6]; + float d7 = m[i * 8 + 7]; + float v0 = d1 + d2; + float v1 = d1 - d2; + float v2 = d3 + d4; + float v3 = d3 - d4; + float v4 = d5 + d6; + float v5 = d5 - d6; + + am[0][i] = d0 + v0 + v2 + 32 * v4; + am[1][i] = v1 + 2 * v3 + 16 * v5; + am[2][i] = v0 + 4 * v2 + 8 * v4; + am[3][i] = v1 + 8 * v3 + 4 * v5; + am[4][i] = v0 + 16 * v2 + 2 * v4; + am[5][i] = v1 + 32 * v3 + v5 + d7; + } + // compute A_T * m * A + for (int i = 0; i < 6; ++i, m += 8) { + float d0 = am[i][0]; + float d1 = am[i][1]; + float d2 = am[i][2]; + float d3 = am[i][3]; + float d4 = am[i][4]; + float d5 = am[i][5]; + float d6 = am[i][6]; + float d7 = am[i][7]; + float v0 = d1 + d2; + float v1 = d1 - d2; + float v2 = d3 + d4; + float v3 = d3 - d4; + float v4 = d5 + d6; + float v5 = d5 - d6; + + m[0] = d0 + v0 + v2 + 32 * v4; + m[1] = v1 + 2 * v3 + 16 * v5; + m[2] = v0 + 4 * v2 + 8 * v4; + m[3] = v1 + 8 * v3 + 4 * v5; + m[4] = v0 + 16 * v2 + 2 * v4; + m[5] = v1 + 32 * v3 + v5 + d7; + } + } + } + + int out_h = output->dims()[2]; + int out_w = output->dims()[3]; + float *output_ptr = output->mutable_data(); + // copy valid region to final output + for (int oc = 0; oc < out_channel; ++oc) { + int inter_h = out_h / 6; + int inter_w = out_w / 6; + int remain_h = out_h - inter_h * 6; + int remain_w = out_w - inter_w * 6; + + float *out_ptr0 = output_ptr + oc * out_h * out_w; + float *out_ptr1 = out_ptr0 + out_w; + float *out_ptr2 = out_ptr1 + out_w; + float *out_ptr3 = out_ptr2 + out_w; + float *out_ptr4 = out_ptr3 + out_w; + float *out_ptr5 = out_ptr4 + out_w; + const float *m_ptr = output_m_ptr + oc * tiles * 64; + for (int tile_h = 0; tile_h < inter_h; ++tile_h) { + for (int tile_w = 0; tile_w < inter_w; ++tile_w) { + const float *m = m_ptr + (tile_h * w_tiles + tile_w) * 64; + memcpy(out_ptr0, m, 6 * sizeof(float)); + memcpy(out_ptr1, m + 8, 6 * sizeof(float)); + memcpy(out_ptr2, m + 16, 6 * sizeof(float)); + memcpy(out_ptr3, m + 24, 6 * sizeof(float)); + memcpy(out_ptr4, m + 32, 6 * sizeof(float)); + memcpy(out_ptr5, m + 40, 6 * sizeof(float)); + out_ptr0 += 6; + out_ptr1 += 6; + out_ptr2 += 6; + out_ptr3 += 6; + out_ptr4 += 6; + out_ptr5 += 6; + } + // remain w + if (remain_w > 0) { + const float *m = m_ptr + (tile_h * w_tiles + inter_w) * 64; + memcpy(out_ptr0, m, remain_w * sizeof(float)); + memcpy(out_ptr1, m + 8, remain_w * sizeof(float)); + memcpy(out_ptr2, m + 16, remain_w * sizeof(float)); + memcpy(out_ptr3, m + 24, remain_w * sizeof(float)); + memcpy(out_ptr4, m + 32, remain_w * sizeof(float)); + memcpy(out_ptr5, m + 40, remain_w * sizeof(float)); + out_ptr0 += remain_w; + out_ptr1 += remain_w; + out_ptr2 += remain_w; + out_ptr3 += remain_w; + out_ptr4 += remain_w; + out_ptr5 += remain_w; + } + out_ptr0 += 5 * out_w; + out_ptr1 += 5 * out_w; + out_ptr2 += 5 * out_w; + out_ptr3 += 5 * out_w; + out_ptr4 += 5 * out_w; + out_ptr5 += 5 * out_w; + } + // remain h + if (remain_h > 0) { + for (int tile_w = 0; tile_w < inter_w; ++tile_w) { + const float *m = m_ptr + (inter_h * w_tiles + tile_w) * 64; + for (int rh = 0; rh < remain_h; ++rh) { + memcpy(out_ptr0 + rh * out_w, m + rh * 8, 6 * sizeof(float)); + } + out_ptr0 += 6; + } + if (remain_w > 0) { + const float *m = m_ptr + (inter_h * w_tiles + inter_w) * 64; + for (int rh = 0; rh < remain_h; ++rh) { + memcpy(out_ptr0 + rh * out_w, m + rh * 8, remain_w * sizeof(float)); + } + } + } + } } } // namespace math -- GitLab