提交 523bbcca 编写于 作者: H hjchen2

Fix winograd if input height != width

上级 ce169c24
......@@ -327,8 +327,8 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
int channel = input.dims()[1];
int height = input.dims()[2];
int width = input.dims()[3];
int h_tiles = (height + 3) / 6; // (height - 8 + 5 + 6) / 6
int w_tiles = (width + 3) / 6; // (width - 8 + 5 + 6) / 6
int h_tiles = (height + 3) / 6; // (height - 2 + 5) / 6
int w_tiles = (width + 3) / 6; // (width - 2 + 5) / 6
int tiles = (h_tiles * w_tiles + 7) / 8;
framework::DDim transformed_shape =
framework::make_ddim(std::vector<int>{tiles, 64, channel, 8});
......@@ -336,16 +336,10 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
memset(outptr, 0, output->numel() * sizeof(float));
const float *inptr = input.data<float>();
int inter_h = (height - 2) / 6;
int inter_w = (width - 2) / 6;
int remain_h = height - (inter_h * 6);
int remain_w = width - (inter_w * 6);
height = h_tiles * 6 + 2;
width = w_tiles * 6 + 2;
framework::Tensor input_pad;
if (remain_h > 2 || remain_w > 2) {
inter_h += (remain_h > 2);
inter_w += (remain_w > 2);
height = (inter_h - 1) * 6 + 8;
width = (inter_w - 1) * 6 + 8;
if (height > input.dims()[2] || width > input.dims()[3]) {
framework::DDim input_shape =
framework::make_ddim(std::vector<int>{1, channel, height, width});
PadFunctor<CPU, float> pad;
......@@ -878,8 +872,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
framework::Tensor *output) {
// weight shape is [out_channel/4, 64, in_channel, 4],
// input shape is [hw/8, 64, in_channel, 8]
int in_channel = input.dims()[2];
int tiles = input.dims()[0];
int in_channel = input.dims()[2];
int out_channel = weight.dims()[0];
// compute U*V first
......@@ -887,7 +881,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
framework::DDim shape =
framework::make_ddim(std::vector<int>{out_channel, tiles, 64, 32});
float *uv_trans_ptr = uv_trans.mutable_data<float>(shape);
memset(uv_trans_ptr, 0, uv_trans.numel() * sizeof(float));
const float *input_ptr = input.data<float>();
const float *weight_ptr = weight.data<float>();
......@@ -910,7 +903,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
"veor q14, q14, q14 \n"
"veor q15, q15, q15 \n"
"b store_res_%= \n"
"cmp %[inter_channel], #0 \n"
"ble loop_1c_%= \n"
// loop 2 channels
"loop_2c_%=: \n"
"vld1.32 {d0-d3}, [%[w_ptr]]! \n"
......@@ -936,13 +930,14 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
"subs %[inter_channel], #1 \n"
"bne loop_2c_%= \n"
"mov pc, lr \n"
// loop 1 channel
"loop_c_%=: \n"
"loop_1c_%=: \n"
"cmp %[remain_channel], #0 \n"
"ble store_res_%= \n"
"vld1.32 {d0-d1}, [%[w_ptr]]! \n"
"vld1.32 {d4-d7}, [%[in_ptr]]! \n"
"vmla.f32 q8, q2, d0[0] \n"
"vmla.f32 q9, q3, d0[0] \n"
"vmla.f32 q10, q2, d0[1] \n"
......@@ -952,28 +947,16 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
"vmla.f32 q14, q2, d1[1] \n"
"vmla.f32 q15, q3, d1[1] \n"
"subs %[remain_channel], #1 \n"
"bne loop_c_%= \n"
"mov pc, lr \n"
"store_res_%=: \n"
"cmp %[inter_channel], #0 \n"
"it gt \n"
"blgt loop_2c_%= \n"
"cmp %[remain_channel], #0 \n"
"it gt \n"
"blgt loop_c_%= \n"
"vst1.32 {d16-d19}, [%[uv_ptr]]! \n"
"vst1.32 {d20-d23}, [%[uv_ptr]]! \n"
"vst1.32 {d24-d27}, [%[uv_ptr]]! \n"
"vst1.32 {d28-d31}, [%[uv_ptr]]! \n"
: [w_ptr] "+r"(w_ptr), [in_ptr] "+r"(in_ptr), [uv_ptr] "+r"(uv_ptr),
[remain_channel] "+r"(remain_channel),
[inter_channel] "+r"(inter_channel)
:
: [remain_channel] "r"(remain_channel)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "pc", "lr");
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
}
}
}
......@@ -1223,8 +1206,10 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w;
float *out_ptr = output_ptr + offset;
int remain_row = (tile_h < h_tiles - 1) ? 6 : remain_h;
int remain_col = (tile_w < w_tiles - 1) ? 6 : remain_w;
int remain_row = out_h - 6 * tile_h;
int remain_col = out_w - 6 * tile_w;
remain_row = (remain_row > 6) ? 6 : remain_row;
remain_col = (remain_col > 6) ? 6 : remain_col;
for (int i = 0; i < remain_row; ++i, out_ptr += out_w) {
memcpy(out_ptr, output_tmp + i * 6, remain_col * sizeof(float));
}
......
......@@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Inspired by https://arxiv.org/abs/1509.09308 and refered from nnpack and ncnn
// project.
// We refer https://github.com/andravin/wincnn to access the winograd transform
// matrixs
#ifdef CONV_OP
#ifdef __aarch64__
#include "operators/math/pad.h"
#include "operators/math/winograd/winograd_transform.h"
namespace paddle_mobile {
......@@ -29,46 +27,382 @@ namespace math {
template <>
void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
framework::Tensor *output) {
/*
* w0 = g0
* w1 = ((g0 + g2) + g1) * (-2.0 / 9)
* w2 = ((g0 + g2) - g1) * (-2.0 / 9)
* w3 = ((g0 + 4 * g2) + 2 * g1) * (1.0 / 90)
* w4 = ((g0 + 4 * g2) - 2 * g1) * (1.0 / 90)
* w5 = ((g2 + 4 * g0) + 2 * g1) * (1.0 / 180)
* w6 = ((g2 + 4 * g0) - 2 * g1) * (1.0 / 180)
* w7 = g2
*/
// TODO(hjchen2)
PADDLE_MOBILE_THROW_EXCEPTION(
"Winograd for arm v8 has not been implemented.");
// weight shape is [out_channel, in_channel, kernel_h, kernel_w]
int out_channel = weight.dims()[0];
int in_channel = weight.dims()[1];
// reshape and alloc transformed weight
framework::DDim transformed_shape =
framework::make_ddim(std::vector<int>{out_channel, in_channel, 64});
float *outptr = output->mutable_data<float>(transformed_shape);
const float *inptr = weight.data<float>();
for (int oc = 0; oc < out_channel; ++oc) {
for (int ic = 0; ic < in_channel; ++ic) {
size_t offset = oc * in_channel + ic;
float *kout = outptr + offset * 64;
const float *k = inptr + offset * 9;
float gw[3][8];
for (int i = 0; i < 3; ++i, k += 3) {
float g0 = k[0];
float g1 = k[1];
float g2 = k[2];
float d0 = g0 + g2;
float d1 = g0 + 4 * g2;
float d2 = g2 + 4 * g0;
float d3 = 2 * g1;
gw[i][0] = g0;
gw[i][1] = -2.f / 9 * (d0 + g1); // -2.f/9 * (g0 + g1 + g2)
gw[i][2] = -2.f / 9 * (d0 - g1); // -2.f/9 * (g0 - g1 + g2)
gw[i][3] = 1.f / 90 * (d1 + d3); // 1.f/90 * (g0 + 2 * g1 + 4 * g2)
gw[i][4] = 1.f / 90 * (d1 - d3); // 1.f/90 * (g0 - 2 * g1 + 4 * g2)
gw[i][5] = 1.f / 180 * (d2 + d3); // 1.f/180 * (4 * g0 + 2 * g1 + g2)
gw[i][6] = 1.f / 180 * (d2 - d3); // 1.f/180 * (4 * g0 - 2 * g1 + g2)
gw[i][7] = g2;
}
for (int i = 0; i < 8; ++i, kout += 8) {
float g0 = gw[0][i];
float g1 = gw[1][i];
float g2 = gw[2][i];
float d0 = g0 + g2;
float d1 = g0 + 4 * g2;
float d2 = g2 + 4 * g0;
float d3 = 2 * g1;
kout[0] = g0;
kout[1] = -2.f / 9 * (d0 + g1); // -2.f/9 * (k0 + k1 + k2)
kout[2] = -2.f / 9 * (d0 - g1); // -2.f/9 * (k0 - k1 + k2)
kout[3] = 1.f / 90 * (d1 + d3); // 1.f/90 * (k0 + 2 * k1 + 4 * k2)
kout[4] = 1.f / 90 * (d1 - d3); // 1.f/90 * (k0 - 2 * k1 + 4 * k2)
kout[5] = 1.f / 180 * (d2 + d3); // 8.f/45 * (4 * k0 + 2 * k1 + k2)
kout[6] = 1.f / 180 * (d2 - d3); // 8.f/45 * (4 * k0 - 2 * k1 + k2)
kout[7] = g2;
}
}
}
}
template <>
void winograd_transform_input<8, 3>(const framework::Tensor &input,
framework::Tensor *output) {
/*
* x0 = (d0 - d6) + (d4 - d2) * 5.25
* x1 = (d2 + d6) - 4.25 * (d4 + d3) + (d1 + d5)
* x2 = (d2 + d6) - 4.25 * (d4 - d3) - (d1 + d5)
* x3 = (0.25 * d2 - 1.25 * d4 + d6) + (0.5 * d1 - 2.5 * d3 + 2 * d5)
* x4 = (0.25 * d2 - 1.25 * d4 + d6) - (0.5 * d1 - 2.5 * d3 + 2 * d5)
* x5 = (4 * d2 - 5 * d4 + d6) + (2 * d1 - 2.5 * d3 + 0.5 * d5)
* x6 = (4 * d2 - 5 * d4 + d6) - (2 * d1 - 2.5 * d3 + 0.5 * d5)
* x7 = (d7 - d1) + (d3 - d5) * 5.25
*/
// TODO(hjchen2)
PADDLE_MOBILE_THROW_EXCEPTION(
"Winograd for arm v8 has not been implemented.");
// tile input to [c, roundup(h/6), roundup(w/6), 64] and do transformation
int channel = input.dims()[1];
int height = input.dims()[2];
int width = input.dims()[3];
int h_tiles = (height + 3) / 6; // (height + 5 - 2) / 6
int w_tiles = (width + 3) / 6; // (width + 5 - 2) / 6
framework::DDim transformed_shape =
framework::make_ddim(std::vector<int>{channel, h_tiles, w_tiles, 64});
float *outptr = output->mutable_data<float>(transformed_shape);
memset(outptr, 0, channel * h_tiles * w_tiles * 64 * sizeof(float));
const float *inptr = input.data<float>();
// pack input to tiles
for (int c = 0; c < channel; ++c) {
int inter_h = (height - 2) / 6;
int inter_w = (width - 2) / 6;
int remain_h = height - (inter_h * 6);
int remain_w = width - (inter_w * 6);
const float *in0 = inptr + c * height * width;
const float *in1 = in0 + width;
const float *in2 = in1 + width;
const float *in3 = in2 + width;
const float *in4 = in3 + width;
const float *in5 = in4 + width;
const float *in6 = in5 + width;
const float *in7 = in6 + width;
float *out = outptr + c * h_tiles * w_tiles * 64;
for (int h = 0; h < inter_h; ++h) {
for (int w = 0; w < inter_w; ++w) {
memcpy(out, in0, 8 * sizeof(float));
memcpy(out + 8, in1, 8 * sizeof(float));
memcpy(out + 16, in2, 8 * sizeof(float));
memcpy(out + 24, in3, 8 * sizeof(float));
memcpy(out + 32, in4, 8 * sizeof(float));
memcpy(out + 40, in5, 8 * sizeof(float));
memcpy(out + 48, in6, 8 * sizeof(float));
memcpy(out + 56, in7, 8 * sizeof(float));
in0 += 6;
in1 += 6;
in2 += 6;
in3 += 6;
in4 += 6;
in5 += 6;
in6 += 6;
in7 += 6;
out += 64;
}
// remain width
if (remain_w > 2) {
memcpy(out, in0, remain_w * sizeof(float));
memcpy(out + 8, in1, remain_w * sizeof(float));
memcpy(out + 16, in2, remain_w * sizeof(float));
memcpy(out + 24, in3, remain_w * sizeof(float));
memcpy(out + 32, in4, remain_w * sizeof(float));
memcpy(out + 40, in5, remain_w * sizeof(float));
memcpy(out + 48, in6, remain_w * sizeof(float));
memcpy(out + 56, in7, remain_w * sizeof(float));
out += 64;
}
in0 += 5 * width + remain_w;
in1 += 5 * width + remain_w;
in2 += 5 * width + remain_w;
in3 += 5 * width + remain_w;
in4 += 5 * width + remain_w;
in5 += 5 * width + remain_w;
in6 += 5 * width + remain_w;
in7 += 5 * width + remain_w;
}
// remain height
if (remain_h > 2) {
for (int w = 0; w < inter_w; ++w) {
for (int rh = 0; rh < remain_h; ++rh) {
memcpy(out + rh * 8, in0 + rh * width, 8 * sizeof(float));
}
out += 64;
in0 += 6;
}
// remain width
if (remain_w > 2) {
for (int rh = 0; rh < remain_h; ++rh) {
memcpy(out + rh * 8, in0 + rh * width, remain_w * sizeof(float));
}
}
}
}
// transform tiles, compute B_T * d(c, b) * B
for (int c = 0; c < channel; ++c) {
for (int tile = 0; tile < h_tiles * w_tiles; ++tile) {
float *out = outptr + (c * h_tiles * w_tiles + tile) * 64;
// compute B_T * d(c, b)
float bd[8][8];
for (int i = 0; i < 8; ++i) {
float d0 = out[8 * i + 0];
float d1 = out[8 * i + 1];
float d2 = out[8 * i + 2];
float d3 = out[8 * i + 3];
float d4 = out[8 * i + 4];
float d5 = out[8 * i + 5];
float d6 = out[8 * i + 6];
float d7 = out[8 * i + 7];
bd[i][0] = d0 - d6 + (d4 - d2) * 5.25;
float v1 = d2 - 4.25 * d4 + d6;
float v2 = d1 - 4.25 * d3 + d5;
// d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6
bd[i][1] = v1 + v2;
// -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6
bd[i][2] = v1 - v2;
v1 = 0.25 * d2 - 1.25 * d4 + d6;
v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5;
// 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6
bd[i][3] = v1 + v2;
// -0.5 * d1 + 0.25 * d2 + 2.5 * d3 - 1.25 * d4 - 2 * d5 + d6
bd[i][4] = v1 - v2;
v1 = 4 * d2 - 5 * d4 + d6;
v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5;
// 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6
bd[i][5] = v1 + v2;
// -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6
bd[i][6] = v1 - v2;
bd[i][7] = d7 - d1 + (d3 - d5) * 5.25;
}
// compute B_T * d(c, b) * B
for (int i = 0; i < 8; ++i, out += 8) {
float d0 = bd[0][i];
float d1 = bd[1][i];
float d2 = bd[2][i];
float d3 = bd[3][i];
float d4 = bd[4][i];
float d5 = bd[5][i];
float d6 = bd[6][i];
float d7 = bd[7][i];
out[0] = d0 - d6 + (d4 - d2) * 5.25;
float v1 = d2 - 4.25 * d4 + d6;
float v2 = d1 - 4.25 * d3 + d5;
// d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6
out[1] = v1 + v2;
// -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6
out[2] = v1 - v2;
v1 = 0.25 * d2 - 1.25 * d4 + d6;
v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5;
// 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6
out[3] = v1 + v2;
// -0.5 * d1 + 0.25 * d2 + 2.5 * d3 - 1.25 * d4 - 2 * d5 + d6
out[4] = v1 - v2;
v1 = 4 * d2 - 5 * d4 + d6;
v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5;
// 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6
out[5] = v1 + v2;
// -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6
out[6] = v1 - v2;
out[7] = d7 - d1 + (d3 - d5) * 5.25;
}
}
}
}
template <>
void winograd_transform_output<8, 3>(const framework::Tensor &input,
const framework::Tensor &weight,
framework::Tensor *output) {
// TODO(hjchen2)
PADDLE_MOBILE_THROW_EXCEPTION(
"Winograd for arm v8 has not been implemented.");
// input shape is [in_channel, h_tiles, w_tiles, 64]
// weight shape is [out_channel, in_channel, 64]
int in_channel = input.dims()[0];
int h_tiles = input.dims()[1];
int w_tiles = input.dims()[2];
int tiles = h_tiles * w_tiles;
int out_channel = weight.dims()[0];
// compute U*V first
framework::Tensor output_m;
framework::DDim shape =
framework::make_ddim(std::vector<int>{out_channel, tiles, 64});
float *output_m_ptr = output_m.mutable_data<float>(shape);
memset(output_m_ptr, 0, output_m.numel() * sizeof(float));
const float *input_ptr = input.data<float>();
const float *weight_ptr = weight.data<float>();
for (int i = 0; i < out_channel; ++i) {
for (int j = 0; j < tiles; ++j) {
const float *w_ptr = weight_ptr + i * in_channel * 64;
const float *in_ptr = input_ptr + j * 64;
float *m_ptr = output_m_ptr + (i * tiles + j) * 64;
for (int c = 0; c < in_channel; ++c) {
for (int k = 0; k < 64; ++k) {
m_ptr[k] += w_ptr[k] * in_ptr[k];
}
w_ptr += 64;
in_ptr += tiles * 64;
}
}
}
for (int oc = 0; oc < out_channel; ++oc) {
for (int tile = 0; tile < tiles; ++tile) {
float *m = output_m_ptr + (oc * tiles + tile) * 64;
// compute A_T * m
float am[6][8];
for (int i = 0; i < 8; ++i) {
float d0 = m[i * 8 + 0];
float d1 = m[i * 8 + 1];
float d2 = m[i * 8 + 2];
float d3 = m[i * 8 + 3];
float d4 = m[i * 8 + 4];
float d5 = m[i * 8 + 5];
float d6 = m[i * 8 + 6];
float d7 = m[i * 8 + 7];
float v0 = d1 + d2;
float v1 = d1 - d2;
float v2 = d3 + d4;
float v3 = d3 - d4;
float v4 = d5 + d6;
float v5 = d5 - d6;
am[0][i] = d0 + v0 + v2 + 32 * v4;
am[1][i] = v1 + 2 * v3 + 16 * v5;
am[2][i] = v0 + 4 * v2 + 8 * v4;
am[3][i] = v1 + 8 * v3 + 4 * v5;
am[4][i] = v0 + 16 * v2 + 2 * v4;
am[5][i] = v1 + 32 * v3 + v5 + d7;
}
// compute A_T * m * A
for (int i = 0; i < 6; ++i, m += 8) {
float d0 = am[i][0];
float d1 = am[i][1];
float d2 = am[i][2];
float d3 = am[i][3];
float d4 = am[i][4];
float d5 = am[i][5];
float d6 = am[i][6];
float d7 = am[i][7];
float v0 = d1 + d2;
float v1 = d1 - d2;
float v2 = d3 + d4;
float v3 = d3 - d4;
float v4 = d5 + d6;
float v5 = d5 - d6;
m[0] = d0 + v0 + v2 + 32 * v4;
m[1] = v1 + 2 * v3 + 16 * v5;
m[2] = v0 + 4 * v2 + 8 * v4;
m[3] = v1 + 8 * v3 + 4 * v5;
m[4] = v0 + 16 * v2 + 2 * v4;
m[5] = v1 + 32 * v3 + v5 + d7;
}
}
}
int out_h = output->dims()[2];
int out_w = output->dims()[3];
float *output_ptr = output->mutable_data<float>();
// copy valid region to final output
for (int oc = 0; oc < out_channel; ++oc) {
int inter_h = out_h / 6;
int inter_w = out_w / 6;
int remain_h = out_h - inter_h * 6;
int remain_w = out_w - inter_w * 6;
float *out_ptr0 = output_ptr + oc * out_h * out_w;
float *out_ptr1 = out_ptr0 + out_w;
float *out_ptr2 = out_ptr1 + out_w;
float *out_ptr3 = out_ptr2 + out_w;
float *out_ptr4 = out_ptr3 + out_w;
float *out_ptr5 = out_ptr4 + out_w;
const float *m_ptr = output_m_ptr + oc * tiles * 64;
for (int tile_h = 0; tile_h < inter_h; ++tile_h) {
for (int tile_w = 0; tile_w < inter_w; ++tile_w) {
const float *m = m_ptr + (tile_h * w_tiles + tile_w) * 64;
memcpy(out_ptr0, m, 6 * sizeof(float));
memcpy(out_ptr1, m + 8, 6 * sizeof(float));
memcpy(out_ptr2, m + 16, 6 * sizeof(float));
memcpy(out_ptr3, m + 24, 6 * sizeof(float));
memcpy(out_ptr4, m + 32, 6 * sizeof(float));
memcpy(out_ptr5, m + 40, 6 * sizeof(float));
out_ptr0 += 6;
out_ptr1 += 6;
out_ptr2 += 6;
out_ptr3 += 6;
out_ptr4 += 6;
out_ptr5 += 6;
}
// remain w
if (remain_w > 0) {
const float *m = m_ptr + (tile_h * w_tiles + inter_w) * 64;
memcpy(out_ptr0, m, remain_w * sizeof(float));
memcpy(out_ptr1, m + 8, remain_w * sizeof(float));
memcpy(out_ptr2, m + 16, remain_w * sizeof(float));
memcpy(out_ptr3, m + 24, remain_w * sizeof(float));
memcpy(out_ptr4, m + 32, remain_w * sizeof(float));
memcpy(out_ptr5, m + 40, remain_w * sizeof(float));
out_ptr0 += remain_w;
out_ptr1 += remain_w;
out_ptr2 += remain_w;
out_ptr3 += remain_w;
out_ptr4 += remain_w;
out_ptr5 += remain_w;
}
out_ptr0 += 5 * out_w;
out_ptr1 += 5 * out_w;
out_ptr2 += 5 * out_w;
out_ptr3 += 5 * out_w;
out_ptr4 += 5 * out_w;
out_ptr5 += 5 * out_w;
}
// remain h
if (remain_h > 0) {
for (int tile_w = 0; tile_w < inter_w; ++tile_w) {
const float *m = m_ptr + (inter_h * w_tiles + tile_w) * 64;
for (int rh = 0; rh < remain_h; ++rh) {
memcpy(out_ptr0 + rh * out_w, m + rh * 8, 6 * sizeof(float));
}
out_ptr0 += 6;
}
if (remain_w > 0) {
const float *m = m_ptr + (inter_h * w_tiles + inter_w) * 64;
for (int rh = 0; rh < remain_h; ++rh) {
memcpy(out_ptr0 + rh * out_w, m + rh * 8, remain_w * sizeof(float));
}
}
}
}
}
} // namespace math
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册