Commit 523bbcca authored by: H hjchen2

Fix winograd if input height != width

Parent ce169c24
@@ -327,8 +327,8 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
  int channel = input.dims()[1];
  int height = input.dims()[2];
  int width = input.dims()[3];
-  int h_tiles = (height + 3) / 6;  // (height - 8 + 5 + 6) / 6
-  int w_tiles = (width + 3) / 6;   // (width - 8 + 5 + 6) / 6
+  int h_tiles = (height + 3) / 6;  // (height - 2 + 5) / 6
+  int w_tiles = (width + 3) / 6;   // (width - 2 + 5) / 6
  int tiles = (h_tiles * w_tiles + 7) / 8;
  framework::DDim transformed_shape =
      framework::make_ddim(std::vector<int>{tiles, 64, channel, 8});
@@ -336,16 +336,10 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
  memset(outptr, 0, output->numel() * sizeof(float));
  const float *inptr = input.data<float>();
-  int inter_h = (height - 2) / 6;
-  int inter_w = (width - 2) / 6;
-  int remain_h = height - (inter_h * 6);
-  int remain_w = width - (inter_w * 6);
+  height = h_tiles * 6 + 2;
+  width = w_tiles * 6 + 2;
  framework::Tensor input_pad;
-  if (remain_h > 2 || remain_w > 2) {
-    inter_h += (remain_h > 2);
-    inter_w += (remain_w > 2);
-    height = (inter_h - 1) * 6 + 8;
-    width = (inter_w - 1) * 6 + 8;
+  if (height > input.dims()[2] || width > input.dims()[3]) {
    framework::DDim input_shape =
        framework::make_ddim(std::vector<int>{1, channel, height, width});
    PadFunctor<CPU, float> pad;
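A minimal standalone sketch of the tiling arithmetic this hunk switches to (the input sizes below are made up for illustration and are not from the patch): the padded height and width are now derived independently from h_tiles and w_tiles, which is what makes non-square inputs work.

```cpp
// Hypothetical example: each 8x8 input tile yields a 6x6 output tile, so the
// padded input must be (6 * tiles + 2) in each dimension, computed separately
// for height and width.
#include <cstdio>

int main() {
  int height = 14, width = 21;     // a non-square input
  int h_tiles = (height + 3) / 6;  // ceil((height - 2) / 6) == 2
  int w_tiles = (width + 3) / 6;   // ceil((width - 2) / 6)  == 4
  int padded_h = h_tiles * 6 + 2;  // 14, no height padding needed
  int padded_w = w_tiles * 6 + 2;  // 26, width is padded from 21
  bool need_pad = padded_h > height || padded_w > width;
  printf("tiles: %dx%d, padded: %dx%d, pad: %s\n", h_tiles, w_tiles, padded_h,
         padded_w, need_pad ? "yes" : "no");
  return 0;
}
```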
@@ -878,8 +872,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
                                     framework::Tensor *output) {
  // weight shape is [out_channel/4, 64, in_channel, 4],
  // input shape is [hw/8, 64, in_channel, 8]
-  int in_channel = input.dims()[2];
  int tiles = input.dims()[0];
+  int in_channel = input.dims()[2];
  int out_channel = weight.dims()[0];
  // compute U*V first
@@ -887,7 +881,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
  framework::DDim shape =
      framework::make_ddim(std::vector<int>{out_channel, tiles, 64, 32});
  float *uv_trans_ptr = uv_trans.mutable_data<float>(shape);
-  memset(uv_trans_ptr, 0, uv_trans.numel() * sizeof(float));
  const float *input_ptr = input.data<float>();
  const float *weight_ptr = weight.data<float>();
@@ -910,7 +903,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
        "veor      q14, q14, q14            \n"
        "veor      q15, q15, q15            \n"
-        "b         store_res_%=             \n"
+        "cmp       %[inter_channel], #0     \n"
+        "ble       loop_1c_%=               \n"
        // loop 2 channels
        "loop_2c_%=:                        \n"
        "vld1.32   {d0-d3}, [%[w_ptr]]!     \n"
@@ -936,13 +930,14 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
        "subs      %[inter_channel], #1     \n"
        "bne       loop_2c_%=               \n"
-        "mov       pc, lr                   \n"
        // loop 1 channel
-        "loop_c_%=:                         \n"
+        "loop_1c_%=:                        \n"
+        "cmp       %[remain_channel], #0    \n"
+        "ble       store_res_%=             \n"
        "vld1.32   {d0-d1}, [%[w_ptr]]!     \n"
        "vld1.32   {d4-d7}, [%[in_ptr]]!    \n"
        "vmla.f32  q8, q2, d0[0]            \n"
        "vmla.f32  q9, q3, d0[0]            \n"
        "vmla.f32  q10, q2, d0[1]           \n"
@@ -952,28 +947,16 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
        "vmla.f32  q14, q2, d1[1]           \n"
        "vmla.f32  q15, q3, d1[1]           \n"
-        "subs      %[remain_channel], #1    \n"
-        "bne       loop_c_%=                \n"
-        "mov       pc, lr                   \n"
        "store_res_%=:                      \n"
-        "cmp       %[inter_channel], #0     \n"
-        "it        gt                       \n"
-        "blgt      loop_2c_%=               \n"
-        "cmp       %[remain_channel], #0    \n"
-        "it        gt                       \n"
-        "blgt      loop_c_%=                \n"
        "vst1.32   {d16-d19}, [%[uv_ptr]]!  \n"
        "vst1.32   {d20-d23}, [%[uv_ptr]]!  \n"
        "vst1.32   {d24-d27}, [%[uv_ptr]]!  \n"
        "vst1.32   {d28-d31}, [%[uv_ptr]]!  \n"
        : [w_ptr] "+r"(w_ptr), [in_ptr] "+r"(in_ptr), [uv_ptr] "+r"(uv_ptr),
-          [remain_channel] "+r"(remain_channel),
          [inter_channel] "+r"(inter_channel)
-        :
+        : [remain_channel] "r"(remain_channel)
        : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
-          "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "pc", "lr");
+          "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
      }
    }
  }
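For readers not fluent in NEON, here is a scalar sketch of what this block accumulates (the helper name and array layouts are illustrative assumptions, not from the patch). The restructuring replaces the removed `mov pc, lr` pseudo-subroutine returns with a straight-line `cmp`/`ble` chain: run the 2-channel loop `inter_channel` times if non-zero, handle at most one remaining channel, then store.

```cpp
// Scalar equivalent of the q8-q15 accumulation above (hypothetical helper):
// for one of the 64 transform positions, a 4-output-channel x 8-tile block is
// accumulated over all input channels, two channels per iteration in the asm
// plus an optional single remainder channel.
void accumulate_block(const float *w_ptr,   // assumed layout [in_channel][4]
                      const float *in_ptr,  // assumed layout [in_channel][8]
                      int in_channel, float acc[4][8]) {
  for (int oc = 0; oc < 4; ++oc) {
    for (int t = 0; t < 8; ++t) acc[oc][t] = 0.f;  // the veor q8..q15 block
  }
  for (int c = 0; c < in_channel; ++c) {  // loop_2c_ / loop_1c_ in the asm
    for (int oc = 0; oc < 4; ++oc) {
      for (int t = 0; t < 8; ++t) {
        acc[oc][t] += w_ptr[c * 4 + oc] * in_ptr[c * 8 + t];  // vmla.f32
      }
    }
  }
}
```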
@@ -1223,8 +1206,10 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
          "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
      size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w;
      float *out_ptr = output_ptr + offset;
-      int remain_row = (tile_h < h_tiles - 1) ? 6 : remain_h;
-      int remain_col = (tile_w < w_tiles - 1) ? 6 : remain_w;
+      int remain_row = out_h - 6 * tile_h;
+      int remain_col = out_w - 6 * tile_w;
+      remain_row = (remain_row > 6) ? 6 : remain_row;
+      remain_col = (remain_col > 6) ? 6 : remain_col;
      for (int i = 0; i < remain_row; ++i, out_ptr += out_w) {
        memcpy(out_ptr, output_tmp + i * 6, remain_col * sizeof(float));
      }
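A worked instance of the new clamping logic (out_h = 16, hence h_tiles = 3, is an assumed value for illustration): the number of valid rows written for tile_h = 0, 1, 2 is

```latex
\mathrm{remain\_row}(t) = \min(\mathrm{out\_h} - 6t,\ 6):\qquad
\mathrm{remain\_row}(0) = 6,\quad
\mathrm{remain\_row}(1) = 6,\quad
\mathrm{remain\_row}(2) = \min(16 - 12,\ 6) = 4.
```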
......
@@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-// Inspired by https://arxiv.org/abs/1509.09308 and refered from nnpack and ncnn
-// project.
+// We refer to https://github.com/andravin/wincnn for the winograd transform
+// matrices.
#ifdef CONV_OP
#ifdef __aarch64__
-#include "operators/math/pad.h"
#include "operators/math/winograd/winograd_transform.h"
namespace paddle_mobile {
@@ -29,46 +27,382 @@ namespace math {
template <>
void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
                                     framework::Tensor *output) {
-  /*
-   * w0 = g0
-   * w1 = ((g0 + g2) + g1) * (-2.0 / 9)
-   * w2 = ((g0 + g2) - g1) * (-2.0 / 9)
-   * w3 = ((g0 + 4 * g2) + 2 * g1) * (1.0 / 90)
-   * w4 = ((g0 + 4 * g2) - 2 * g1) * (1.0 / 90)
-   * w5 = ((g2 + 4 * g0) + 2 * g1) * (1.0 / 180)
-   * w6 = ((g2 + 4 * g0) - 2 * g1) * (1.0 / 180)
-   * w7 = g2
-   */
-  // TODO(hjchen2)
-  PADDLE_MOBILE_THROW_EXCEPTION(
-      "Winograd for arm v8 has not been implemented.");
// weight shape is [out_channel, in_channel, kernel_h, kernel_w]
int out_channel = weight.dims()[0];
int in_channel = weight.dims()[1];
// reshape and alloc transformed weight
framework::DDim transformed_shape =
framework::make_ddim(std::vector<int>{out_channel, in_channel, 64});
float *outptr = output->mutable_data<float>(transformed_shape);
const float *inptr = weight.data<float>();
for (int oc = 0; oc < out_channel; ++oc) {
for (int ic = 0; ic < in_channel; ++ic) {
size_t offset = oc * in_channel + ic;
float *kout = outptr + offset * 64;
const float *k = inptr + offset * 9;
float gw[3][8];
for (int i = 0; i < 3; ++i, k += 3) {
float g0 = k[0];
float g1 = k[1];
float g2 = k[2];
float d0 = g0 + g2;
float d1 = g0 + 4 * g2;
float d2 = g2 + 4 * g0;
float d3 = 2 * g1;
gw[i][0] = g0;
gw[i][1] = -2.f / 9 * (d0 + g1); // -2.f/9 * (g0 + g1 + g2)
gw[i][2] = -2.f / 9 * (d0 - g1); // -2.f/9 * (g0 - g1 + g2)
gw[i][3] = 1.f / 90 * (d1 + d3); // 1.f/90 * (g0 + 2 * g1 + 4 * g2)
gw[i][4] = 1.f / 90 * (d1 - d3); // 1.f/90 * (g0 - 2 * g1 + 4 * g2)
gw[i][5] = 1.f / 180 * (d2 + d3); // 1.f/180 * (4 * g0 + 2 * g1 + g2)
gw[i][6] = 1.f / 180 * (d2 - d3); // 1.f/180 * (4 * g0 - 2 * g1 + g2)
gw[i][7] = g2;
}
for (int i = 0; i < 8; ++i, kout += 8) {
float g0 = gw[0][i];
float g1 = gw[1][i];
float g2 = gw[2][i];
float d0 = g0 + g2;
float d1 = g0 + 4 * g2;
float d2 = g2 + 4 * g0;
float d3 = 2 * g1;
kout[0] = g0;
kout[1] = -2.f / 9 * (d0 + g1); // -2.f/9 * (k0 + k1 + k2)
kout[2] = -2.f / 9 * (d0 - g1); // -2.f/9 * (k0 - k1 + k2)
kout[3] = 1.f / 90 * (d1 + d3); // 1.f/90 * (k0 + 2 * k1 + 4 * k2)
kout[4] = 1.f / 90 * (d1 - d3); // 1.f/90 * (k0 - 2 * k1 + 4 * k2)
kout[5] = 1.f / 180 * (d2 + d3); // 1.f/180 * (4 * k0 + 2 * k1 + k2)
kout[6] = 1.f / 180 * (d2 - d3); // 1.f/180 * (4 * k0 - 2 * k1 + k2)
kout[7] = g2;
}
}
}
}
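As a reference for the coefficients used in both passes above: this is the F(6x6, 3x3) weight transform U = G g G^T, with G read directly off the gw[i][j] / kout[j] expressions in the code (a sketch for readability, not part of the patch).

```latex
G =
\begin{bmatrix}
    1 &     0 &     0 \\
 -2/9 &  -2/9 &  -2/9 \\
 -2/9 &   2/9 &  -2/9 \\
 1/90 &  1/45 &  2/45 \\
 1/90 & -1/45 &  2/45 \\
 1/45 &  1/90 & 1/180 \\
 1/45 & -1/90 & 1/180 \\
    0 &     0 &     1
\end{bmatrix},
\qquad
U = G\,g\,G^{T} \in \mathbb{R}^{8\times 8}
\ \text{for each } 3\times 3 \text{ kernel } g.
```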
template <>
void winograd_transform_input<8, 3>(const framework::Tensor &input,
                                    framework::Tensor *output) {
-  /*
-   * x0 = (d0 - d6) + (d4 - d2) * 5.25
-   * x1 = (d2 + d6) - 4.25 * (d4 + d3) + (d1 + d5)
-   * x2 = (d2 + d6) - 4.25 * (d4 - d3) - (d1 + d5)
-   * x3 = (0.25 * d2 - 1.25 * d4 + d6) + (0.5 * d1 - 2.5 * d3 + 2 * d5)
-   * x4 = (0.25 * d2 - 1.25 * d4 + d6) - (0.5 * d1 - 2.5 * d3 + 2 * d5)
-   * x5 = (4 * d2 - 5 * d4 + d6) + (2 * d1 - 2.5 * d3 + 0.5 * d5)
-   * x6 = (4 * d2 - 5 * d4 + d6) - (2 * d1 - 2.5 * d3 + 0.5 * d5)
-   * x7 = (d7 - d1) + (d3 - d5) * 5.25
-   */
-  // TODO(hjchen2)
-  PADDLE_MOBILE_THROW_EXCEPTION(
-      "Winograd for arm v8 has not been implemented.");
// tile input to [c, roundup(h/6), roundup(w/6), 64] and do transformation
int channel = input.dims()[1];
int height = input.dims()[2];
int width = input.dims()[3];
int h_tiles = (height + 3) / 6; // (height + 5 - 2) / 6
int w_tiles = (width + 3) / 6; // (width + 5 - 2) / 6
framework::DDim transformed_shape =
framework::make_ddim(std::vector<int>{channel, h_tiles, w_tiles, 64});
float *outptr = output->mutable_data<float>(transformed_shape);
memset(outptr, 0, channel * h_tiles * w_tiles * 64 * sizeof(float));
const float *inptr = input.data<float>();
// pack input to tiles
for (int c = 0; c < channel; ++c) {
int inter_h = (height - 2) / 6;
int inter_w = (width - 2) / 6;
int remain_h = height - (inter_h * 6);
int remain_w = width - (inter_w * 6);
const float *in0 = inptr + c * height * width;
const float *in1 = in0 + width;
const float *in2 = in1 + width;
const float *in3 = in2 + width;
const float *in4 = in3 + width;
const float *in5 = in4 + width;
const float *in6 = in5 + width;
const float *in7 = in6 + width;
float *out = outptr + c * h_tiles * w_tiles * 64;
for (int h = 0; h < inter_h; ++h) {
for (int w = 0; w < inter_w; ++w) {
memcpy(out, in0, 8 * sizeof(float));
memcpy(out + 8, in1, 8 * sizeof(float));
memcpy(out + 16, in2, 8 * sizeof(float));
memcpy(out + 24, in3, 8 * sizeof(float));
memcpy(out + 32, in4, 8 * sizeof(float));
memcpy(out + 40, in5, 8 * sizeof(float));
memcpy(out + 48, in6, 8 * sizeof(float));
memcpy(out + 56, in7, 8 * sizeof(float));
in0 += 6;
in1 += 6;
in2 += 6;
in3 += 6;
in4 += 6;
in5 += 6;
in6 += 6;
in7 += 6;
out += 64;
}
// remain width
if (remain_w > 2) {
memcpy(out, in0, remain_w * sizeof(float));
memcpy(out + 8, in1, remain_w * sizeof(float));
memcpy(out + 16, in2, remain_w * sizeof(float));
memcpy(out + 24, in3, remain_w * sizeof(float));
memcpy(out + 32, in4, remain_w * sizeof(float));
memcpy(out + 40, in5, remain_w * sizeof(float));
memcpy(out + 48, in6, remain_w * sizeof(float));
memcpy(out + 56, in7, remain_w * sizeof(float));
out += 64;
}
in0 += 5 * width + remain_w;
in1 += 5 * width + remain_w;
in2 += 5 * width + remain_w;
in3 += 5 * width + remain_w;
in4 += 5 * width + remain_w;
in5 += 5 * width + remain_w;
in6 += 5 * width + remain_w;
in7 += 5 * width + remain_w;
}
// remain height
if (remain_h > 2) {
for (int w = 0; w < inter_w; ++w) {
for (int rh = 0; rh < remain_h; ++rh) {
memcpy(out + rh * 8, in0 + rh * width, 8 * sizeof(float));
}
out += 64;
in0 += 6;
}
// remain width
if (remain_w > 2) {
for (int rh = 0; rh < remain_h; ++rh) {
memcpy(out + rh * 8, in0 + rh * width, remain_w * sizeof(float));
}
}
}
}
// transform tiles, compute B_T * d(c, b) * B
for (int c = 0; c < channel; ++c) {
for (int tile = 0; tile < h_tiles * w_tiles; ++tile) {
float *out = outptr + (c * h_tiles * w_tiles + tile) * 64;
// compute B_T * d(c, b)
float bd[8][8];
for (int i = 0; i < 8; ++i) {
float d0 = out[8 * i + 0];
float d1 = out[8 * i + 1];
float d2 = out[8 * i + 2];
float d3 = out[8 * i + 3];
float d4 = out[8 * i + 4];
float d5 = out[8 * i + 5];
float d6 = out[8 * i + 6];
float d7 = out[8 * i + 7];
bd[i][0] = d0 - d6 + (d4 - d2) * 5.25;
float v1 = d2 - 4.25 * d4 + d6;
float v2 = d1 - 4.25 * d3 + d5;
// d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6
bd[i][1] = v1 + v2;
// -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6
bd[i][2] = v1 - v2;
v1 = 0.25 * d2 - 1.25 * d4 + d6;
v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5;
// 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6
bd[i][3] = v1 + v2;
// -0.5 * d1 + 0.25 * d2 + 2.5 * d3 - 1.25 * d4 - 2 * d5 + d6
bd[i][4] = v1 - v2;
v1 = 4 * d2 - 5 * d4 + d6;
v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5;
// 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6
bd[i][5] = v1 + v2;
// -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6
bd[i][6] = v1 - v2;
bd[i][7] = d7 - d1 + (d3 - d5) * 5.25;
}
// compute B_T * d(c, b) * B
for (int i = 0; i < 8; ++i, out += 8) {
float d0 = bd[0][i];
float d1 = bd[1][i];
float d2 = bd[2][i];
float d3 = bd[3][i];
float d4 = bd[4][i];
float d5 = bd[5][i];
float d6 = bd[6][i];
float d7 = bd[7][i];
out[0] = d0 - d6 + (d4 - d2) * 5.25;
float v1 = d2 - 4.25 * d4 + d6;
float v2 = d1 - 4.25 * d3 + d5;
// d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6
out[1] = v1 + v2;
// -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6
out[2] = v1 - v2;
v1 = 0.25 * d2 - 1.25 * d4 + d6;
v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5;
// 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6
out[3] = v1 + v2;
// -0.5 * d1 + 0.25 * d2 + 2.5 * d3 - 1.25 * d4 - 2 * d5 + d6
out[4] = v1 - v2;
v1 = 4 * d2 - 5 * d4 + d6;
v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5;
// 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6
out[5] = v1 + v2;
// -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6
out[6] = v1 - v2;
out[7] = d7 - d1 + (d3 - d5) * 5.25;
}
}
}
}
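For reference, the per-row and per-column input transform computed above is V = B^T d B, where B^T (read off the bd[i][j] expressions; shown here as a sketch, not part of the patch) is:

```latex
B^{T} =
\begin{bmatrix}
 1 &    0 & -5.25 &     0 &  5.25 &     0 & -1 & 0 \\
 0 &    1 &     1 & -4.25 & -4.25 &     1 &  1 & 0 \\
 0 &   -1 &     1 &  4.25 & -4.25 &    -1 &  1 & 0 \\
 0 &  0.5 &  0.25 &  -2.5 & -1.25 &     2 &  1 & 0 \\
 0 & -0.5 &  0.25 &   2.5 & -1.25 &    -2 &  1 & 0 \\
 0 &    2 &     4 &  -2.5 &    -5 &   0.5 &  1 & 0 \\
 0 &   -2 &     4 &   2.5 &    -5 &  -0.5 &  1 & 0 \\
 0 &   -1 &     0 &  5.25 &     0 & -5.25 &  0 & 1
\end{bmatrix}
```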
template <>
void winograd_transform_output<8, 3>(const framework::Tensor &input,
                                     const framework::Tensor &weight,
                                     framework::Tensor *output) {
-  // TODO(hjchen2)
-  PADDLE_MOBILE_THROW_EXCEPTION(
-      "Winograd for arm v8 has not been implemented.");
// input shape is [in_channel, h_tiles, w_tiles, 64]
// weight shape is [out_channel, in_channel, 64]
int in_channel = input.dims()[0];
int h_tiles = input.dims()[1];
int w_tiles = input.dims()[2];
int tiles = h_tiles * w_tiles;
int out_channel = weight.dims()[0];
// compute U*V first
framework::Tensor output_m;
framework::DDim shape =
framework::make_ddim(std::vector<int>{out_channel, tiles, 64});
float *output_m_ptr = output_m.mutable_data<float>(shape);
memset(output_m_ptr, 0, output_m.numel() * sizeof(float));
const float *input_ptr = input.data<float>();
const float *weight_ptr = weight.data<float>();
for (int i = 0; i < out_channel; ++i) {
for (int j = 0; j < tiles; ++j) {
const float *w_ptr = weight_ptr + i * in_channel * 64;
const float *in_ptr = input_ptr + j * 64;
float *m_ptr = output_m_ptr + (i * tiles + j) * 64;
for (int c = 0; c < in_channel; ++c) {
for (int k = 0; k < 64; ++k) {
m_ptr[k] += w_ptr[k] * in_ptr[k];
}
w_ptr += 64;
in_ptr += tiles * 64;
}
}
}
for (int oc = 0; oc < out_channel; ++oc) {
for (int tile = 0; tile < tiles; ++tile) {
float *m = output_m_ptr + (oc * tiles + tile) * 64;
// compute A_T * m
float am[6][8];
for (int i = 0; i < 8; ++i) {
float d0 = m[i * 8 + 0];
float d1 = m[i * 8 + 1];
float d2 = m[i * 8 + 2];
float d3 = m[i * 8 + 3];
float d4 = m[i * 8 + 4];
float d5 = m[i * 8 + 5];
float d6 = m[i * 8 + 6];
float d7 = m[i * 8 + 7];
float v0 = d1 + d2;
float v1 = d1 - d2;
float v2 = d3 + d4;
float v3 = d3 - d4;
float v4 = d5 + d6;
float v5 = d5 - d6;
am[0][i] = d0 + v0 + v2 + 32 * v4;
am[1][i] = v1 + 2 * v3 + 16 * v5;
am[2][i] = v0 + 4 * v2 + 8 * v4;
am[3][i] = v1 + 8 * v3 + 4 * v5;
am[4][i] = v0 + 16 * v2 + 2 * v4;
am[5][i] = v1 + 32 * v3 + v5 + d7;
}
// compute A_T * m * A
for (int i = 0; i < 6; ++i, m += 8) {
float d0 = am[i][0];
float d1 = am[i][1];
float d2 = am[i][2];
float d3 = am[i][3];
float d4 = am[i][4];
float d5 = am[i][5];
float d6 = am[i][6];
float d7 = am[i][7];
float v0 = d1 + d2;
float v1 = d1 - d2;
float v2 = d3 + d4;
float v3 = d3 - d4;
float v4 = d5 + d6;
float v5 = d5 - d6;
m[0] = d0 + v0 + v2 + 32 * v4;
m[1] = v1 + 2 * v3 + 16 * v5;
m[2] = v0 + 4 * v2 + 8 * v4;
m[3] = v1 + 8 * v3 + 4 * v5;
m[4] = v0 + 16 * v2 + 2 * v4;
m[5] = v1 + 32 * v3 + v5 + d7;
}
}
}
int out_h = output->dims()[2];
int out_w = output->dims()[3];
float *output_ptr = output->mutable_data<float>();
// copy valid region to final output
for (int oc = 0; oc < out_channel; ++oc) {
int inter_h = out_h / 6;
int inter_w = out_w / 6;
int remain_h = out_h - inter_h * 6;
int remain_w = out_w - inter_w * 6;
float *out_ptr0 = output_ptr + oc * out_h * out_w;
float *out_ptr1 = out_ptr0 + out_w;
float *out_ptr2 = out_ptr1 + out_w;
float *out_ptr3 = out_ptr2 + out_w;
float *out_ptr4 = out_ptr3 + out_w;
float *out_ptr5 = out_ptr4 + out_w;
const float *m_ptr = output_m_ptr + oc * tiles * 64;
for (int tile_h = 0; tile_h < inter_h; ++tile_h) {
for (int tile_w = 0; tile_w < inter_w; ++tile_w) {
const float *m = m_ptr + (tile_h * w_tiles + tile_w) * 64;
memcpy(out_ptr0, m, 6 * sizeof(float));
memcpy(out_ptr1, m + 8, 6 * sizeof(float));
memcpy(out_ptr2, m + 16, 6 * sizeof(float));
memcpy(out_ptr3, m + 24, 6 * sizeof(float));
memcpy(out_ptr4, m + 32, 6 * sizeof(float));
memcpy(out_ptr5, m + 40, 6 * sizeof(float));
out_ptr0 += 6;
out_ptr1 += 6;
out_ptr2 += 6;
out_ptr3 += 6;
out_ptr4 += 6;
out_ptr5 += 6;
}
// remain w
if (remain_w > 0) {
const float *m = m_ptr + (tile_h * w_tiles + inter_w) * 64;
memcpy(out_ptr0, m, remain_w * sizeof(float));
memcpy(out_ptr1, m + 8, remain_w * sizeof(float));
memcpy(out_ptr2, m + 16, remain_w * sizeof(float));
memcpy(out_ptr3, m + 24, remain_w * sizeof(float));
memcpy(out_ptr4, m + 32, remain_w * sizeof(float));
memcpy(out_ptr5, m + 40, remain_w * sizeof(float));
out_ptr0 += remain_w;
out_ptr1 += remain_w;
out_ptr2 += remain_w;
out_ptr3 += remain_w;
out_ptr4 += remain_w;
out_ptr5 += remain_w;
}
out_ptr0 += 5 * out_w;
out_ptr1 += 5 * out_w;
out_ptr2 += 5 * out_w;
out_ptr3 += 5 * out_w;
out_ptr4 += 5 * out_w;
out_ptr5 += 5 * out_w;
}
// remain h
if (remain_h > 0) {
for (int tile_w = 0; tile_w < inter_w; ++tile_w) {
const float *m = m_ptr + (inter_h * w_tiles + tile_w) * 64;
for (int rh = 0; rh < remain_h; ++rh) {
memcpy(out_ptr0 + rh * out_w, m + rh * 8, 6 * sizeof(float));
}
out_ptr0 += 6;
}
if (remain_w > 0) {
const float *m = m_ptr + (inter_h * w_tiles + inter_w) * 64;
for (int rh = 0; rh < remain_h; ++rh) {
memcpy(out_ptr0 + rh * out_w, m + rh * 8, remain_w * sizeof(float));
}
}
}
}
}
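And the inverse (output) transform above is Y = A^T M A applied to each 8x8 product tile M, where A^T (read off the am[i] expressions; a reference sketch, not part of the patch) is:

```latex
A^{T} =
\begin{bmatrix}
 1 & 1 &  1 &  1 &   1 & 32 &  32 & 0 \\
 0 & 1 & -1 &  2 &  -2 & 16 & -16 & 0 \\
 0 & 1 &  1 &  4 &   4 &  8 &   8 & 0 \\
 0 & 1 & -1 &  8 &  -8 &  4 &  -4 & 0 \\
 0 & 1 &  1 & 16 &  16 &  2 &   2 & 0 \\
 0 & 1 & -1 & 32 & -32 &  1 &  -1 & 1
\end{bmatrix}
```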
}  // namespace math
......