提交 97f820c2 编写于 作者: 吴承辉

Merge branch 'winograd6x6' into 'master'

Reorder winograd compute order

See merge request !511
......@@ -24,7 +24,7 @@ namespace mace {
namespace kernels {
namespace {
// NCHW => TNCB (T: in tile pixels, B: tile indices)
// NCHW => NTCB (T: in tile pixels, B: tile indices)
void TransformInput4x4(const float *input,
const index_t batch,
const index_t in_height,
......@@ -32,12 +32,15 @@ void TransformInput4x4(const float *input,
const index_t in_channels,
const index_t tile_count,
float *output) {
const index_t stride = batch * in_channels * tile_count;
const index_t stride = in_channels * tile_count;
const index_t in_height_width = in_height * in_width;
const index_t input_batch_size = in_height_width * in_channels;
const index_t output_batch_size = 16 * in_channels * tile_count;
#pragma omp parallel for
for (index_t nc = 0; nc < batch * in_channels; ++nc) {
index_t tile_index = nc * tile_count;
#pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) {
for (index_t c = 0; c < in_channels; ++c) {
index_t tile_index = 0;
for (index_t h = 0; h < in_height - 2; h += 2) {
for (index_t w = 0; w < in_width - 2; w += 2) {
float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14,
......@@ -46,26 +49,28 @@ void TransformInput4x4(const float *input,
s15;
// load tile data
const index_t tile_offset = nc * in_height_width + h * in_width + w;
d0 = input[tile_offset];
d1 = input[tile_offset + 1];
d2 = input[tile_offset + 2];
d3 = input[tile_offset + 3];
d4 = input[tile_offset + in_width];
d5 = input[tile_offset + in_width + 1];
d6 = input[tile_offset + in_width + 2];
d7 = input[tile_offset + in_width + 3];
d8 = input[tile_offset + 2 * in_width];
d9 = input[tile_offset + 2 * in_width + 1];
d10 = input[tile_offset + 2 * in_width + 2];
d11 = input[tile_offset + 2 * in_width + 3];
d12 = input[tile_offset + 3 * in_width];
d13 = input[tile_offset + 3 * in_width + 1];
d14 = input[tile_offset + 3 * in_width + 2];
d15 = input[tile_offset + 3 * in_width + 3];
const float *input_ptr =
input + n * input_batch_size + c * in_height_width + h * in_width
+ w;
d0 = input_ptr[0];
d1 = input_ptr[1];
d2 = input_ptr[2];
d3 = input_ptr[3];
d4 = input_ptr[in_width];
d5 = input_ptr[in_width + 1];
d6 = input_ptr[in_width + 2];
d7 = input_ptr[in_width + 3];
d8 = input_ptr[2 * in_width];
d9 = input_ptr[2 * in_width + 1];
d10 = input_ptr[2 * in_width + 2];
d11 = input_ptr[2 * in_width + 3];
d12 = input_ptr[3 * in_width];
d13 = input_ptr[3 * in_width + 1];
d14 = input_ptr[3 * in_width + 2];
d15 = input_ptr[3 * in_width + 3];
// s = BT * d * B
s0 = (d0 - d8) - (d2 - d10);
......@@ -86,33 +91,36 @@ void TransformInput4x4(const float *input,
s15 = (d5 - d13) - (d7 - d15);
// store output
output[tile_index + 0 * stride] = s0;
output[tile_index + 1 * stride] = s1;
output[tile_index + 2 * stride] = s2;
output[tile_index + 3 * stride] = s3;
output[tile_index + 4 * stride] = s4;
output[tile_index + 5 * stride] = s5;
output[tile_index + 6 * stride] = s6;
output[tile_index + 7 * stride] = s7;
output[tile_index + 8 * stride] = s8;
output[tile_index + 9 * stride] = s9;
output[tile_index + 10 * stride] = s10;
output[tile_index + 11 * stride] = s11;
output[tile_index + 12 * stride] = s12;
output[tile_index + 13 * stride] = s13;
output[tile_index + 14 * stride] = s14;
output[tile_index + 15 * stride] = s15;
float *output_ptr =
output + n * output_batch_size + c * tile_count + tile_index;
output_ptr[0] = s0;
output_ptr[1 * stride] = s1;
output_ptr[2 * stride] = s2;
output_ptr[3 * stride] = s3;
output_ptr[4 * stride] = s4;
output_ptr[5 * stride] = s5;
output_ptr[6 * stride] = s6;
output_ptr[7 * stride] = s7;
output_ptr[8 * stride] = s8;
output_ptr[9 * stride] = s9;
output_ptr[10 * stride] = s10;
output_ptr[11 * stride] = s11;
output_ptr[12 * stride] = s12;
output_ptr[13 * stride] = s13;
output_ptr[14 * stride] = s14;
output_ptr[15 * stride] = s15;
++tile_index;
}
}
}
}
}
// NCHW => TNCB (T: in tile pixels, B: tile indices)
// NCHW => NTCB (T: in tile pixels, B: tile indices)
/**
* BT =
⎡1 0 -21/4 0 21/4 0 -1 0⎤
......@@ -146,26 +154,32 @@ void TransformInput8x8(const float *input,
const index_t in_channels,
const index_t tile_count,
float *output) {
const index_t stride = batch * in_channels * tile_count;
const index_t stride = in_channels * tile_count;
const index_t in_height_width = in_height * in_width;
const index_t input_batch_size = in_height_width * in_channels;
const index_t output_batch_size = 64 * in_channels * tile_count;
#pragma omp parallel for
for (index_t nc = 0; nc < batch * in_channels; ++nc) {
index_t tile_index = nc * tile_count;
#pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) {
for (index_t c = 0; c < in_channels; ++c) {
index_t tile_index = 0;
float s[8][8];
for (index_t h = 0; h < in_height - 2; h += 6) {
for (index_t w = 0; w < in_width - 2; w += 6) {
index_t tile_offset = nc * in_height_width + h * in_width + w;
const float *input_ptr =
input + n * input_batch_size + c * in_height_width + h * in_width
+ w;
for (int i = 0; i < 8; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7;
d0 = input[tile_offset];
d1 = input[tile_offset + 1];
d2 = input[tile_offset + 2];
d3 = input[tile_offset + 3];
d4 = input[tile_offset + 4];
d5 = input[tile_offset + 5];
d6 = input[tile_offset + 6];
d7 = input[tile_offset + 7];
d0 = input_ptr[0];
d1 = input_ptr[1];
d2 = input_ptr[2];
d3 = input_ptr[3];
d4 = input_ptr[4];
d5 = input_ptr[5];
d6 = input_ptr[6];
d7 = input_ptr[7];
s[i][0] = d0 - d6 + (d4 - d2) * 5.25;
s[i][7] = d7 - d1 + (d3 - d5) * 5.25;
......@@ -185,9 +199,11 @@ void TransformInput8x8(const float *input,
s[i][5] = u + v;
s[i][6] = u - v;
tile_offset += in_width;
input_ptr += in_width;
}
float *output_ptr =
output + n * output_batch_size + c * tile_count + tile_index;
for (int i = 0; i < 8; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7;
d0 = s[0][i];
......@@ -199,32 +215,33 @@ void TransformInput8x8(const float *input,
d6 = s[6][i];
d7 = s[7][i];
output[tile_index + i * stride] = d0 - d6 + (d4 - d2) * 5.25;
output[tile_index + (56 + i) * stride] = d7 - d1 + (d3 - d5) * 5.25;
output_ptr[i * stride] = d0 - d6 + (d4 - d2) * 5.25;
output_ptr[(56 + i) * stride] = d7 - d1 + (d3 - d5) * 5.25;
float u = d2 + d6 - d4 * 4.25;
float v = d1 + d5 - d3 * 4.25;
output[tile_index + (8 + i) * stride] = u + v;
output[tile_index + (16 + i) * stride] = u - v;
output_ptr[(8 + i) * stride] = u + v;
output_ptr[(16 + i) * stride] = u - v;
u = d6 + d2 * 0.25 - d4 * 1.25;
v = d1 * 0.5 - d3 * 2.5 + d5 * 2;
output[tile_index + (24 + i) * stride] = u + v;
output[tile_index + (32 + i) * stride] = u - v;
output_ptr[(24 + i) * stride] = u + v;
output_ptr[(32 + i) * stride] = u - v;
u = d6 + (d2 - d4 * 1.25) * 4;
v = d1 * 2 - d3 * 2.5 + d5 * 0.5;
output[tile_index + (40 + i) * stride] = u + v;
output[tile_index + (48 + i) * stride] = u - v;
output_ptr[(40 + i) * stride] = u + v;
output_ptr[(48 + i) * stride] = u - v;
}
++tile_index;
}
}
}
}
}
// TOC * TNCB => TNOB
// TOC * NTCB => NTOB
void BatchGemm(const float *input,
const float *filter,
index_t batch,
......@@ -233,12 +250,13 @@ void BatchGemm(const float *input,
index_t tile_count,
int out_tile_size,
float *output) {
const index_t in_stride = batch * in_channels * tile_count;
const index_t in_channels_tile_count = in_channels * tile_count;
const index_t filter_stride = out_channels * in_channels;
const index_t out_stride = batch * out_channels * tile_count;
const index_t out_channels_tile_count = out_channels * tile_count;
const int in_tile_area = (out_tile_size + 2) * (out_tile_size + 2);
const index_t in_batch_size = in_tile_area * in_channels * tile_count;
const index_t in_stride = in_channels * tile_count;
const index_t out_batch_size = in_tile_area * out_channels * tile_count;
const index_t out_stride = out_channels * tile_count;
if (batch == 1) {
Gemm(filter,
input,
......@@ -248,12 +266,13 @@ void BatchGemm(const float *input,
tile_count,
output);
} else {
for (int i = 0; i < in_tile_area; ++i) {
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int i = 0; i < in_tile_area; ++i) {
const float
*in_ptr = input + i * in_stride + b * in_channels_tile_count;
*in_ptr = input + b * in_batch_size + i * in_stride;
const float *filter_ptr = filter + i * filter_stride;
float *out_ptr = output + i * out_stride + b * out_channels_tile_count;
float *out_ptr = output + b * out_batch_size + i * out_stride;
Gemm(filter_ptr,
in_ptr,
1,
......@@ -266,7 +285,7 @@ void BatchGemm(const float *input,
}
}
// TNOB => ToNOB => NOHoWo
// NTOB => NToOB => NOHoWo
void TransformOutput4x4(const float *input,
index_t batch,
index_t out_height,
......@@ -274,11 +293,15 @@ void TransformOutput4x4(const float *input,
index_t out_channels,
index_t tile_count,
float *output) {
const index_t in_stride = batch * out_channels * tile_count;
const index_t stride = out_channels * tile_count;
const index_t input_batch_size = 16 * stride;
const index_t out_image_size = out_height * out_width;
const index_t output_batch_size = out_channels * out_image_size;
#pragma omp parallel for
for (index_t nm = 0; nm < batch * out_channels; ++nm) {
index_t tile_offset = nm * tile_count;
#pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) {
for (index_t m = 0; m < out_channels; ++m) {
index_t tile_offset = 0;
for (index_t h = 0; h < out_height; h += 2) {
for (index_t w = 0; w < out_width; w += 2) {
float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14,
......@@ -286,25 +309,27 @@ void TransformOutput4x4(const float *input,
float s0, s1, s2, s3, s4, s5, s6, s7;
float v0, v1, v2, v3;
d0 = input[tile_offset + 0 * in_stride];
d1 = input[tile_offset + 1 * in_stride];
d2 = input[tile_offset + 2 * in_stride];
d3 = input[tile_offset + 3 * in_stride];
const float *input_ptr =
input + n * input_batch_size + m * tile_count + tile_offset;
d0 = input_ptr[0];
d1 = input_ptr[1 * stride];
d2 = input_ptr[2 * stride];
d3 = input_ptr[3 * stride];
d4 = input[tile_offset + 4 * in_stride];
d5 = input[tile_offset + 5 * in_stride];
d6 = input[tile_offset + 6 * in_stride];
d7 = input[tile_offset + 7 * in_stride];
d4 = input_ptr[4 * stride];
d5 = input_ptr[5 * stride];
d6 = input_ptr[6 * stride];
d7 = input_ptr[7 * stride];
d8 = input[tile_offset + 8 * in_stride];
d9 = input[tile_offset + 9 * in_stride];
d10 = input[tile_offset + 10 * in_stride];
d11 = input[tile_offset + 11 * in_stride];
d8 = input_ptr[8 * stride];
d9 = input_ptr[9 * stride];
d10 = input_ptr[10 * stride];
d11 = input_ptr[11 * stride];
d12 = input[tile_offset + 12 * in_stride];
d13 = input[tile_offset + 13 * in_stride];
d14 = input[tile_offset + 14 * in_stride];
d15 = input[tile_offset + 15 * in_stride];
d12 = input_ptr[12 * stride];
d13 = input_ptr[13 * stride];
d14 = input_ptr[14 * stride];
d15 = input_ptr[15 * stride];
s0 = d0 + d1 + d2;
s1 = d1 - d2 - d3;
......@@ -320,19 +345,22 @@ void TransformOutput4x4(const float *input,
v2 = s2 - s4 - s6;
v3 = s3 - s5 - s7;
index_t out_offset = nm * out_height * out_width + h * out_width + w;
output[out_offset] = v0;
output[out_offset + 1] = v1;
output[out_offset + out_width] = v2;
output[out_offset + out_width + 1] = v3;
float *output_ptr =
output + n * output_batch_size + m * out_image_size + h * out_width
+ w;
output_ptr[0] = v0;
output_ptr[1] = v1;
output_ptr[out_width] = v2;
output_ptr[out_width + 1] = v3;
++tile_offset;
}
}
}
}
}
// TNOB => ToNOB => NOHoWo
// NTOB => NToOB => NOHoWo
/**
* AT =
⎡1 1 1 1 1 32 32 0⎤
......@@ -362,25 +390,31 @@ void TransformOutput8x8(const float *input,
index_t out_channels,
index_t tile_count,
float *output) {
const index_t in_stride = batch * out_channels * tile_count;
const index_t stride = out_channels * tile_count;
const index_t input_batch_size = 64 * stride;
const index_t out_image_size = out_height * out_width;
const index_t output_batch_size = out_channels * out_image_size;
#pragma omp parallel for
for (index_t nm = 0; nm < batch * out_channels; ++nm) {
index_t tile_offset = nm * tile_count;
#pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) {
for (index_t m = 0; m < out_channels; ++m) {
index_t tile_offset = 0;
float s[8][6];
for (index_t h = 0; h < out_height; h += 6) {
for (index_t w = 0; w < out_width; w += 6) {
index_t tile_offset_tmp = tile_offset;
const float *input_ptr =
input + n * input_batch_size + m * tile_count + tile_offset;
for (int i = 0; i < 8; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7;
d0 = input[tile_offset_tmp + 0 * in_stride];
d1 = input[tile_offset_tmp + 1 * in_stride];
d2 = input[tile_offset_tmp + 2 * in_stride];
d3 = input[tile_offset_tmp + 3 * in_stride];
d4 = input[tile_offset_tmp + 4 * in_stride];
d5 = input[tile_offset_tmp + 5 * in_stride];
d6 = input[tile_offset_tmp + 6 * in_stride];
d7 = input[tile_offset_tmp + 7 * in_stride];
d0 = input_ptr[0];
d1 = input_ptr[1 * stride];
d2 = input_ptr[2 * stride];
d3 = input_ptr[3 * stride];
d4 = input_ptr[4 * stride];
d5 = input_ptr[5 * stride];
d6 = input_ptr[6 * stride];
d7 = input_ptr[7 * stride];
float u = d1 + d2;
float v = d1 - d2;
......@@ -396,10 +430,12 @@ void TransformOutput8x8(const float *input,
s[i][4] = u + w * 16 + y + y;
s[i][5] = v + x * 32 + z + d7;
tile_offset_tmp += 8 * in_stride;
input_ptr += 8 * stride;
}
index_t out_offset = nm * out_height * out_width + h * out_width + w;
float *output_ptr =
output + n * output_batch_size + m * out_image_size + h * out_width
+ w;
for (int i = 0; i < 6; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7;
......@@ -419,18 +455,19 @@ void TransformOutput8x8(const float *input,
float y = d5 + d6;
float z = d5 - d6;
output[out_offset + 0 * out_width + i] = d0 + u + w + y * 32;
output[out_offset + 1 * out_width + i] = v + x + x + z * 16;
output[out_offset + 2 * out_width + i] = u + w * 4 + y * 8;
output[out_offset + 3 * out_width + i] = v + x * 8 + z * 4;
output[out_offset + 4 * out_width + i] = u + w * 16 + y + y;
output[out_offset + 5 * out_width + i] = v + x * 32 + z + d7;
output_ptr[i] = d0 + u + w + y * 32;
output_ptr[1 * out_width + i] = v + x + x + z * 16;
output_ptr[2 * out_width + i] = u + w * 4 + y * 8;
output_ptr[3 * out_width + i] = v + x * 8 + z * 4;
output_ptr[4 * out_width + i] = u + w * 16 + y + y;
output_ptr[5 * out_width + i] = v + x * 32 + z + d7;
}
++tile_offset;
}
}
}
}
}
} // namespace
......
......@@ -165,6 +165,10 @@ BM_CONV_2D(1, 32, 256, 256, 3, 3, 1, 4, VALID, 32);
BM_CONV_2D(1, 128, 56, 56, 1, 1, 1, 1, SAME, 128);
BM_CONV_2D(1, 1024, 7, 7, 1, 1, 1, 1, SAME, 1024);
BM_CONV_2D(64, 32, 34, 34, 3, 3, 1, 1, VALID, 32);
BM_CONV_2D(1, 32, 34, 34, 3, 3, 1, 1, VALID, 32);
} // namespace test
} // namespace ops
} // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册