提交 a2ca646b 编写于 作者: 李寅

Reorder winograd compute order

上级 f15a122d
...@@ -24,7 +24,7 @@ namespace mace { ...@@ -24,7 +24,7 @@ namespace mace {
namespace kernels { namespace kernels {
namespace { namespace {
// NCHW => TNCB (T: in tile pixels, B: tile indices) // NCHW => NTCB (T: in tile pixels, B: tile indices)
void TransformInput4x4(const float *input, void TransformInput4x4(const float *input,
const index_t batch, const index_t batch,
const index_t in_height, const index_t in_height,
...@@ -32,12 +32,15 @@ void TransformInput4x4(const float *input, ...@@ -32,12 +32,15 @@ void TransformInput4x4(const float *input,
const index_t in_channels, const index_t in_channels,
const index_t tile_count, const index_t tile_count,
float *output) { float *output) {
const index_t stride = batch * in_channels * tile_count; const index_t stride = in_channels * tile_count;
const index_t in_height_width = in_height * in_width; const index_t in_height_width = in_height * in_width;
const index_t input_batch_size = in_height_width * in_channels;
const index_t output_batch_size = 16 * in_channels * tile_count;
#pragma omp parallel for #pragma omp parallel for collapse(2)
for (index_t nc = 0; nc < batch * in_channels; ++nc) { for (index_t n = 0; n < batch; ++n) {
index_t tile_index = nc * tile_count; for (index_t c = 0; c < in_channels; ++c) {
index_t tile_index = 0;
for (index_t h = 0; h < in_height - 2; h += 2) { for (index_t h = 0; h < in_height - 2; h += 2) {
for (index_t w = 0; w < in_width - 2; w += 2) { for (index_t w = 0; w < in_width - 2; w += 2) {
float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14,
...@@ -46,26 +49,28 @@ void TransformInput4x4(const float *input, ...@@ -46,26 +49,28 @@ void TransformInput4x4(const float *input,
s15; s15;
// load tile data // load tile data
const index_t tile_offset = nc * in_height_width + h * in_width + w; const float *input_ptr =
d0 = input[tile_offset]; input + n * input_batch_size + c * in_height_width + h * in_width
d1 = input[tile_offset + 1]; + w;
d2 = input[tile_offset + 2]; d0 = input_ptr[0];
d3 = input[tile_offset + 3]; d1 = input_ptr[1];
d2 = input_ptr[2];
d4 = input[tile_offset + in_width]; d3 = input_ptr[3];
d5 = input[tile_offset + in_width + 1];
d6 = input[tile_offset + in_width + 2]; d4 = input_ptr[in_width];
d7 = input[tile_offset + in_width + 3]; d5 = input_ptr[in_width + 1];
d6 = input_ptr[in_width + 2];
d8 = input[tile_offset + 2 * in_width]; d7 = input_ptr[in_width + 3];
d9 = input[tile_offset + 2 * in_width + 1];
d10 = input[tile_offset + 2 * in_width + 2]; d8 = input_ptr[2 * in_width];
d11 = input[tile_offset + 2 * in_width + 3]; d9 = input_ptr[2 * in_width + 1];
d10 = input_ptr[2 * in_width + 2];
d12 = input[tile_offset + 3 * in_width]; d11 = input_ptr[2 * in_width + 3];
d13 = input[tile_offset + 3 * in_width + 1];
d14 = input[tile_offset + 3 * in_width + 2]; d12 = input_ptr[3 * in_width];
d15 = input[tile_offset + 3 * in_width + 3]; d13 = input_ptr[3 * in_width + 1];
d14 = input_ptr[3 * in_width + 2];
d15 = input_ptr[3 * in_width + 3];
// s = BT * d * B // s = BT * d * B
s0 = (d0 - d8) - (d2 - d10); s0 = (d0 - d8) - (d2 - d10);
...@@ -86,33 +91,36 @@ void TransformInput4x4(const float *input, ...@@ -86,33 +91,36 @@ void TransformInput4x4(const float *input,
s15 = (d5 - d13) - (d7 - d15); s15 = (d5 - d13) - (d7 - d15);
// store output // store output
output[tile_index + 0 * stride] = s0; float *output_ptr =
output[tile_index + 1 * stride] = s1; output + n * output_batch_size + c * tile_count + tile_index;
output[tile_index + 2 * stride] = s2; output_ptr[0] = s0;
output[tile_index + 3 * stride] = s3; output_ptr[1 * stride] = s1;
output_ptr[2 * stride] = s2;
output[tile_index + 4 * stride] = s4; output_ptr[3 * stride] = s3;
output[tile_index + 5 * stride] = s5;
output[tile_index + 6 * stride] = s6; output_ptr[4 * stride] = s4;
output[tile_index + 7 * stride] = s7; output_ptr[5 * stride] = s5;
output_ptr[6 * stride] = s6;
output[tile_index + 8 * stride] = s8; output_ptr[7 * stride] = s7;
output[tile_index + 9 * stride] = s9;
output[tile_index + 10 * stride] = s10; output_ptr[8 * stride] = s8;
output[tile_index + 11 * stride] = s11; output_ptr[9 * stride] = s9;
output_ptr[10 * stride] = s10;
output[tile_index + 12 * stride] = s12; output_ptr[11 * stride] = s11;
output[tile_index + 13 * stride] = s13;
output[tile_index + 14 * stride] = s14; output_ptr[12 * stride] = s12;
output[tile_index + 15 * stride] = s15; output_ptr[13 * stride] = s13;
output_ptr[14 * stride] = s14;
output_ptr[15 * stride] = s15;
++tile_index; ++tile_index;
} }
} }
} }
}
} }
// NCHW => TNCB (T: in tile pixels, B: tile indices) // NCHW => NTCB (T: in tile pixels, B: tile indices)
/** /**
* BT = * BT =
⎡1 0 -21/4 0 21/4 0 -1 0⎤ ⎡1 0 -21/4 0 21/4 0 -1 0⎤
...@@ -146,26 +154,32 @@ void TransformInput8x8(const float *input, ...@@ -146,26 +154,32 @@ void TransformInput8x8(const float *input,
const index_t in_channels, const index_t in_channels,
const index_t tile_count, const index_t tile_count,
float *output) { float *output) {
const index_t stride = batch * in_channels * tile_count; const index_t stride = in_channels * tile_count;
const index_t in_height_width = in_height * in_width; const index_t in_height_width = in_height * in_width;
const index_t input_batch_size = in_height_width * in_channels;
const index_t output_batch_size = 64 * in_channels * tile_count;
#pragma omp parallel for #pragma omp parallel for collapse(2)
for (index_t nc = 0; nc < batch * in_channels; ++nc) { for (index_t n = 0; n < batch; ++n) {
index_t tile_index = nc * tile_count; for (index_t c = 0; c < in_channels; ++c) {
index_t tile_index = 0;
float s[8][8]; float s[8][8];
for (index_t h = 0; h < in_height - 2; h += 6) { for (index_t h = 0; h < in_height - 2; h += 6) {
for (index_t w = 0; w < in_width - 2; w += 6) { for (index_t w = 0; w < in_width - 2; w += 6) {
index_t tile_offset = nc * in_height_width + h * in_width + w; const float *input_ptr =
input + n * input_batch_size + c * in_height_width + h * in_width
+ w;
for (int i = 0; i < 8; ++i) { for (int i = 0; i < 8; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7; float d0, d1, d2, d3, d4, d5, d6, d7;
d0 = input[tile_offset]; d0 = input_ptr[0];
d1 = input[tile_offset + 1]; d1 = input_ptr[1];
d2 = input[tile_offset + 2]; d2 = input_ptr[2];
d3 = input[tile_offset + 3]; d3 = input_ptr[3];
d4 = input[tile_offset + 4]; d4 = input_ptr[4];
d5 = input[tile_offset + 5]; d5 = input_ptr[5];
d6 = input[tile_offset + 6]; d6 = input_ptr[6];
d7 = input[tile_offset + 7]; d7 = input_ptr[7];
s[i][0] = d0 - d6 + (d4 - d2) * 5.25; s[i][0] = d0 - d6 + (d4 - d2) * 5.25;
s[i][7] = d7 - d1 + (d3 - d5) * 5.25; s[i][7] = d7 - d1 + (d3 - d5) * 5.25;
...@@ -185,9 +199,11 @@ void TransformInput8x8(const float *input, ...@@ -185,9 +199,11 @@ void TransformInput8x8(const float *input,
s[i][5] = u + v; s[i][5] = u + v;
s[i][6] = u - v; s[i][6] = u - v;
tile_offset += in_width; input_ptr += in_width;
} }
float *output_ptr =
output + n * output_batch_size + c * tile_count + tile_index;
for (int i = 0; i < 8; ++i) { for (int i = 0; i < 8; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7; float d0, d1, d2, d3, d4, d5, d6, d7;
d0 = s[0][i]; d0 = s[0][i];
...@@ -199,32 +215,33 @@ void TransformInput8x8(const float *input, ...@@ -199,32 +215,33 @@ void TransformInput8x8(const float *input,
d6 = s[6][i]; d6 = s[6][i];
d7 = s[7][i]; d7 = s[7][i];
output[tile_index + i * stride] = d0 - d6 + (d4 - d2) * 5.25; output_ptr[i * stride] = d0 - d6 + (d4 - d2) * 5.25;
output[tile_index + (56 + i) * stride] = d7 - d1 + (d3 - d5) * 5.25; output_ptr[(56 + i) * stride] = d7 - d1 + (d3 - d5) * 5.25;
float u = d2 + d6 - d4 * 4.25; float u = d2 + d6 - d4 * 4.25;
float v = d1 + d5 - d3 * 4.25; float v = d1 + d5 - d3 * 4.25;
output[tile_index + (8 + i) * stride] = u + v; output_ptr[(8 + i) * stride] = u + v;
output[tile_index + (16 + i) * stride] = u - v; output_ptr[(16 + i) * stride] = u - v;
u = d6 + d2 * 0.25 - d4 * 1.25; u = d6 + d2 * 0.25 - d4 * 1.25;
v = d1 * 0.5 - d3 * 2.5 + d5 * 2; v = d1 * 0.5 - d3 * 2.5 + d5 * 2;
output[tile_index + (24 + i) * stride] = u + v; output_ptr[(24 + i) * stride] = u + v;
output[tile_index + (32 + i) * stride] = u - v; output_ptr[(32 + i) * stride] = u - v;
u = d6 + (d2 - d4 * 1.25) * 4; u = d6 + (d2 - d4 * 1.25) * 4;
v = d1 * 2 - d3 * 2.5 + d5 * 0.5; v = d1 * 2 - d3 * 2.5 + d5 * 0.5;
output[tile_index + (40 + i) * stride] = u + v; output_ptr[(40 + i) * stride] = u + v;
output[tile_index + (48 + i) * stride] = u - v; output_ptr[(48 + i) * stride] = u - v;
} }
++tile_index; ++tile_index;
} }
} }
} }
}
} }
// TOC * TNCB => TNOB // TOC * NTCB => NTOB
void BatchGemm(const float *input, void BatchGemm(const float *input,
const float *filter, const float *filter,
index_t batch, index_t batch,
...@@ -233,12 +250,13 @@ void BatchGemm(const float *input, ...@@ -233,12 +250,13 @@ void BatchGemm(const float *input,
index_t tile_count, index_t tile_count,
int out_tile_size, int out_tile_size,
float *output) { float *output) {
const index_t in_stride = batch * in_channels * tile_count;
const index_t in_channels_tile_count = in_channels * tile_count;
const index_t filter_stride = out_channels * in_channels; const index_t filter_stride = out_channels * in_channels;
const index_t out_stride = batch * out_channels * tile_count;
const index_t out_channels_tile_count = out_channels * tile_count;
const int in_tile_area = (out_tile_size + 2) * (out_tile_size + 2); const int in_tile_area = (out_tile_size + 2) * (out_tile_size + 2);
const index_t in_batch_size = in_tile_area * in_channels * tile_count;
const index_t in_stride = in_channels * tile_count;
const index_t out_batch_size = in_tile_area * out_channels * tile_count;
const index_t out_stride = out_channels * tile_count;
if (batch == 1) { if (batch == 1) {
Gemm(filter, Gemm(filter,
input, input,
...@@ -248,12 +266,13 @@ void BatchGemm(const float *input, ...@@ -248,12 +266,13 @@ void BatchGemm(const float *input,
tile_count, tile_count,
output); output);
} else { } else {
for (int i = 0; i < in_tile_area; ++i) { #pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) { for (int b = 0; b < batch; ++b) {
for (int i = 0; i < in_tile_area; ++i) {
const float const float
*in_ptr = input + i * in_stride + b * in_channels_tile_count; *in_ptr = input + b * in_batch_size + i * in_stride;
const float *filter_ptr = filter + i * filter_stride; const float *filter_ptr = filter + i * filter_stride;
float *out_ptr = output + i * out_stride + b * out_channels_tile_count; float *out_ptr = output + b * out_batch_size + i * out_stride;
Gemm(filter_ptr, Gemm(filter_ptr,
in_ptr, in_ptr,
1, 1,
...@@ -266,7 +285,7 @@ void BatchGemm(const float *input, ...@@ -266,7 +285,7 @@ void BatchGemm(const float *input,
} }
} }
// TNOB => ToNOB => NOHoWo // NTOB => NToOB => NOHoWo
void TransformOutput4x4(const float *input, void TransformOutput4x4(const float *input,
index_t batch, index_t batch,
index_t out_height, index_t out_height,
...@@ -274,11 +293,15 @@ void TransformOutput4x4(const float *input, ...@@ -274,11 +293,15 @@ void TransformOutput4x4(const float *input,
index_t out_channels, index_t out_channels,
index_t tile_count, index_t tile_count,
float *output) { float *output) {
const index_t in_stride = batch * out_channels * tile_count; const index_t stride = out_channels * tile_count;
const index_t input_batch_size = 16 * stride;
const index_t out_image_size = out_height * out_width;
const index_t output_batch_size = out_channels * out_image_size;
#pragma omp parallel for #pragma omp parallel for collapse(2)
for (index_t nm = 0; nm < batch * out_channels; ++nm) { for (index_t n = 0; n < batch; ++n) {
index_t tile_offset = nm * tile_count; for (index_t m = 0; m < out_channels; ++m) {
index_t tile_offset = 0;
for (index_t h = 0; h < out_height; h += 2) { for (index_t h = 0; h < out_height; h += 2) {
for (index_t w = 0; w < out_width; w += 2) { for (index_t w = 0; w < out_width; w += 2) {
float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14,
...@@ -286,25 +309,27 @@ void TransformOutput4x4(const float *input, ...@@ -286,25 +309,27 @@ void TransformOutput4x4(const float *input,
float s0, s1, s2, s3, s4, s5, s6, s7; float s0, s1, s2, s3, s4, s5, s6, s7;
float v0, v1, v2, v3; float v0, v1, v2, v3;
d0 = input[tile_offset + 0 * in_stride]; const float *input_ptr =
d1 = input[tile_offset + 1 * in_stride]; input + n * input_batch_size + m * tile_count + tile_offset;
d2 = input[tile_offset + 2 * in_stride]; d0 = input_ptr[0];
d3 = input[tile_offset + 3 * in_stride]; d1 = input_ptr[1 * stride];
d2 = input_ptr[2 * stride];
d3 = input_ptr[3 * stride];
d4 = input[tile_offset + 4 * in_stride]; d4 = input_ptr[4 * stride];
d5 = input[tile_offset + 5 * in_stride]; d5 = input_ptr[5 * stride];
d6 = input[tile_offset + 6 * in_stride]; d6 = input_ptr[6 * stride];
d7 = input[tile_offset + 7 * in_stride]; d7 = input_ptr[7 * stride];
d8 = input[tile_offset + 8 * in_stride]; d8 = input_ptr[8 * stride];
d9 = input[tile_offset + 9 * in_stride]; d9 = input_ptr[9 * stride];
d10 = input[tile_offset + 10 * in_stride]; d10 = input_ptr[10 * stride];
d11 = input[tile_offset + 11 * in_stride]; d11 = input_ptr[11 * stride];
d12 = input[tile_offset + 12 * in_stride]; d12 = input_ptr[12 * stride];
d13 = input[tile_offset + 13 * in_stride]; d13 = input_ptr[13 * stride];
d14 = input[tile_offset + 14 * in_stride]; d14 = input_ptr[14 * stride];
d15 = input[tile_offset + 15 * in_stride]; d15 = input_ptr[15 * stride];
s0 = d0 + d1 + d2; s0 = d0 + d1 + d2;
s1 = d1 - d2 - d3; s1 = d1 - d2 - d3;
...@@ -320,19 +345,22 @@ void TransformOutput4x4(const float *input, ...@@ -320,19 +345,22 @@ void TransformOutput4x4(const float *input,
v2 = s2 - s4 - s6; v2 = s2 - s4 - s6;
v3 = s3 - s5 - s7; v3 = s3 - s5 - s7;
index_t out_offset = nm * out_height * out_width + h * out_width + w; float *output_ptr =
output[out_offset] = v0; output + n * output_batch_size + m * out_image_size + h * out_width
output[out_offset + 1] = v1; + w;
output[out_offset + out_width] = v2; output_ptr[0] = v0;
output[out_offset + out_width + 1] = v3; output_ptr[1] = v1;
output_ptr[out_width] = v2;
output_ptr[out_width + 1] = v3;
++tile_offset; ++tile_offset;
} }
} }
} }
}
} }
// TNOB => ToNOB => NOHoWo // NTOB => NToOB => NOHoWo
/** /**
* AT = * AT =
⎡1 1 1 1 1 32 32 0⎤ ⎡1 1 1 1 1 32 32 0⎤
...@@ -362,25 +390,31 @@ void TransformOutput8x8(const float *input, ...@@ -362,25 +390,31 @@ void TransformOutput8x8(const float *input,
index_t out_channels, index_t out_channels,
index_t tile_count, index_t tile_count,
float *output) { float *output) {
const index_t in_stride = batch * out_channels * tile_count; const index_t stride = out_channels * tile_count;
const index_t input_batch_size = 64 * stride;
const index_t out_image_size = out_height * out_width;
const index_t output_batch_size = out_channels * out_image_size;
#pragma omp parallel for #pragma omp parallel for collapse(2)
for (index_t nm = 0; nm < batch * out_channels; ++nm) { for (index_t n = 0; n < batch; ++n) {
index_t tile_offset = nm * tile_count; for (index_t m = 0; m < out_channels; ++m) {
index_t tile_offset = 0;
float s[8][6]; float s[8][6];
for (index_t h = 0; h < out_height; h += 6) { for (index_t h = 0; h < out_height; h += 6) {
for (index_t w = 0; w < out_width; w += 6) { for (index_t w = 0; w < out_width; w += 6) {
index_t tile_offset_tmp = tile_offset; const float *input_ptr =
input + n * input_batch_size + m * tile_count + tile_offset;
for (int i = 0; i < 8; ++i) { for (int i = 0; i < 8; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7; float d0, d1, d2, d3, d4, d5, d6, d7;
d0 = input[tile_offset_tmp + 0 * in_stride];
d1 = input[tile_offset_tmp + 1 * in_stride]; d0 = input_ptr[0];
d2 = input[tile_offset_tmp + 2 * in_stride]; d1 = input_ptr[1 * stride];
d3 = input[tile_offset_tmp + 3 * in_stride]; d2 = input_ptr[2 * stride];
d4 = input[tile_offset_tmp + 4 * in_stride]; d3 = input_ptr[3 * stride];
d5 = input[tile_offset_tmp + 5 * in_stride]; d4 = input_ptr[4 * stride];
d6 = input[tile_offset_tmp + 6 * in_stride]; d5 = input_ptr[5 * stride];
d7 = input[tile_offset_tmp + 7 * in_stride]; d6 = input_ptr[6 * stride];
d7 = input_ptr[7 * stride];
float u = d1 + d2; float u = d1 + d2;
float v = d1 - d2; float v = d1 - d2;
...@@ -396,10 +430,12 @@ void TransformOutput8x8(const float *input, ...@@ -396,10 +430,12 @@ void TransformOutput8x8(const float *input,
s[i][4] = u + w * 16 + y + y; s[i][4] = u + w * 16 + y + y;
s[i][5] = v + x * 32 + z + d7; s[i][5] = v + x * 32 + z + d7;
tile_offset_tmp += 8 * in_stride; input_ptr += 8 * stride;
} }
index_t out_offset = nm * out_height * out_width + h * out_width + w; float *output_ptr =
output + n * output_batch_size + m * out_image_size + h * out_width
+ w;
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7; float d0, d1, d2, d3, d4, d5, d6, d7;
...@@ -419,18 +455,19 @@ void TransformOutput8x8(const float *input, ...@@ -419,18 +455,19 @@ void TransformOutput8x8(const float *input,
float y = d5 + d6; float y = d5 + d6;
float z = d5 - d6; float z = d5 - d6;
output[out_offset + 0 * out_width + i] = d0 + u + w + y * 32; output_ptr[i] = d0 + u + w + y * 32;
output[out_offset + 1 * out_width + i] = v + x + x + z * 16; output_ptr[1 * out_width + i] = v + x + x + z * 16;
output[out_offset + 2 * out_width + i] = u + w * 4 + y * 8; output_ptr[2 * out_width + i] = u + w * 4 + y * 8;
output[out_offset + 3 * out_width + i] = v + x * 8 + z * 4; output_ptr[3 * out_width + i] = v + x * 8 + z * 4;
output[out_offset + 4 * out_width + i] = u + w * 16 + y + y; output_ptr[4 * out_width + i] = u + w * 16 + y + y;
output[out_offset + 5 * out_width + i] = v + x * 32 + z + d7; output_ptr[5 * out_width + i] = v + x * 32 + z + d7;
} }
++tile_offset; ++tile_offset;
} }
} }
} }
}
} }
} // namespace } // namespace
......
...@@ -165,6 +165,10 @@ BM_CONV_2D(1, 32, 256, 256, 3, 3, 1, 4, VALID, 32); ...@@ -165,6 +165,10 @@ BM_CONV_2D(1, 32, 256, 256, 3, 3, 1, 4, VALID, 32);
BM_CONV_2D(1, 128, 56, 56, 1, 1, 1, 1, SAME, 128); BM_CONV_2D(1, 128, 56, 56, 1, 1, 1, 1, SAME, 128);
BM_CONV_2D(1, 1024, 7, 7, 1, 1, 1, 1, SAME, 1024); BM_CONV_2D(1, 1024, 7, 7, 1, 1, 1, 1, SAME, 1024);
BM_CONV_2D(64, 32, 34, 34, 3, 3, 1, 1, VALID, 32);
BM_CONV_2D(1, 32, 34, 34, 3, 3, 1, 1, VALID, 32);
} // namespace test } // namespace test
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册