提交 97f820c2 编写于 作者: 吴承辉

Merge branch 'winograd6x6' into 'master'

Reorder winograd compute order

See merge request !511
...@@ -24,7 +24,7 @@ namespace mace { ...@@ -24,7 +24,7 @@ namespace mace {
namespace kernels { namespace kernels {
namespace { namespace {
// NCHW => TNCB (T: in tile pixels, B: tile indices) // NCHW => NTCB (T: in tile pixels, B: tile indices)
void TransformInput4x4(const float *input, void TransformInput4x4(const float *input,
const index_t batch, const index_t batch,
const index_t in_height, const index_t in_height,
...@@ -32,87 +32,95 @@ void TransformInput4x4(const float *input, ...@@ -32,87 +32,95 @@ void TransformInput4x4(const float *input,
const index_t in_channels, const index_t in_channels,
const index_t tile_count, const index_t tile_count,
float *output) { float *output) {
const index_t stride = batch * in_channels * tile_count; const index_t stride = in_channels * tile_count;
const index_t in_height_width = in_height * in_width; const index_t in_height_width = in_height * in_width;
const index_t input_batch_size = in_height_width * in_channels;
const index_t output_batch_size = 16 * in_channels * tile_count;
#pragma omp parallel for #pragma omp parallel for collapse(2)
for (index_t nc = 0; nc < batch * in_channels; ++nc) { for (index_t n = 0; n < batch; ++n) {
index_t tile_index = nc * tile_count; for (index_t c = 0; c < in_channels; ++c) {
for (index_t h = 0; h < in_height - 2; h += 2) { index_t tile_index = 0;
for (index_t w = 0; w < in_width - 2; w += 2) { for (index_t h = 0; h < in_height - 2; h += 2) {
float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, for (index_t w = 0; w < in_width - 2; w += 2) {
d15; float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14,
float s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, d15;
s15; float s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14,
s15;
// load tile data
const index_t tile_offset = nc * in_height_width + h * in_width + w; // load tile data
d0 = input[tile_offset]; const float *input_ptr =
d1 = input[tile_offset + 1]; input + n * input_batch_size + c * in_height_width + h * in_width
d2 = input[tile_offset + 2]; + w;
d3 = input[tile_offset + 3]; d0 = input_ptr[0];
d1 = input_ptr[1];
d4 = input[tile_offset + in_width]; d2 = input_ptr[2];
d5 = input[tile_offset + in_width + 1]; d3 = input_ptr[3];
d6 = input[tile_offset + in_width + 2];
d7 = input[tile_offset + in_width + 3]; d4 = input_ptr[in_width];
d5 = input_ptr[in_width + 1];
d8 = input[tile_offset + 2 * in_width]; d6 = input_ptr[in_width + 2];
d9 = input[tile_offset + 2 * in_width + 1]; d7 = input_ptr[in_width + 3];
d10 = input[tile_offset + 2 * in_width + 2];
d11 = input[tile_offset + 2 * in_width + 3]; d8 = input_ptr[2 * in_width];
d9 = input_ptr[2 * in_width + 1];
d12 = input[tile_offset + 3 * in_width]; d10 = input_ptr[2 * in_width + 2];
d13 = input[tile_offset + 3 * in_width + 1]; d11 = input_ptr[2 * in_width + 3];
d14 = input[tile_offset + 3 * in_width + 2];
d15 = input[tile_offset + 3 * in_width + 3]; d12 = input_ptr[3 * in_width];
d13 = input_ptr[3 * in_width + 1];
// s = BT * d * B d14 = input_ptr[3 * in_width + 2];
s0 = (d0 - d8) - (d2 - d10); d15 = input_ptr[3 * in_width + 3];
s1 = (d1 - d9) + (d2 - d10);
s2 = (d2 - d10) - (d1 - d9); // s = BT * d * B
s3 = (d1 - d9) - (d3 - d11); s0 = (d0 - d8) - (d2 - d10);
s4 = (d4 + d8) - (d6 + d10); s1 = (d1 - d9) + (d2 - d10);
s5 = (d5 + d9) + (d6 + d10); s2 = (d2 - d10) - (d1 - d9);
s6 = (d6 + d10) - (d5 + d9); s3 = (d1 - d9) - (d3 - d11);
s7 = (d5 + d9) - (d7 + d11); s4 = (d4 + d8) - (d6 + d10);
s8 = (d8 - d4) - (d10 - d6); s5 = (d5 + d9) + (d6 + d10);
s9 = (d9 - d5) + (d10 - d6); s6 = (d6 + d10) - (d5 + d9);
s10 = (d10 - d6) - (d9 - d5); s7 = (d5 + d9) - (d7 + d11);
s11 = (d9 - d5) - (d11 - d7); s8 = (d8 - d4) - (d10 - d6);
s12 = (d4 - d12) - (d6 - d14); s9 = (d9 - d5) + (d10 - d6);
s13 = (d5 - d13) + (d6 - d14); s10 = (d10 - d6) - (d9 - d5);
s14 = (d6 - d14) - (d5 - d13); s11 = (d9 - d5) - (d11 - d7);
s15 = (d5 - d13) - (d7 - d15); s12 = (d4 - d12) - (d6 - d14);
s13 = (d5 - d13) + (d6 - d14);
// store output s14 = (d6 - d14) - (d5 - d13);
output[tile_index + 0 * stride] = s0; s15 = (d5 - d13) - (d7 - d15);
output[tile_index + 1 * stride] = s1;
output[tile_index + 2 * stride] = s2; // store output
output[tile_index + 3 * stride] = s3; float *output_ptr =
output + n * output_batch_size + c * tile_count + tile_index;
output[tile_index + 4 * stride] = s4; output_ptr[0] = s0;
output[tile_index + 5 * stride] = s5; output_ptr[1 * stride] = s1;
output[tile_index + 6 * stride] = s6; output_ptr[2 * stride] = s2;
output[tile_index + 7 * stride] = s7; output_ptr[3 * stride] = s3;
output[tile_index + 8 * stride] = s8; output_ptr[4 * stride] = s4;
output[tile_index + 9 * stride] = s9; output_ptr[5 * stride] = s5;
output[tile_index + 10 * stride] = s10; output_ptr[6 * stride] = s6;
output[tile_index + 11 * stride] = s11; output_ptr[7 * stride] = s7;
output[tile_index + 12 * stride] = s12; output_ptr[8 * stride] = s8;
output[tile_index + 13 * stride] = s13; output_ptr[9 * stride] = s9;
output[tile_index + 14 * stride] = s14; output_ptr[10 * stride] = s10;
output[tile_index + 15 * stride] = s15; output_ptr[11 * stride] = s11;
++tile_index; output_ptr[12 * stride] = s12;
output_ptr[13 * stride] = s13;
output_ptr[14 * stride] = s14;
output_ptr[15 * stride] = s15;
++tile_index;
}
} }
} }
} }
} }
// NCHW => TNCB (T: in tile pixels, B: tile indices) // NCHW => NTCB (T: in tile pixels, B: tile indices)
/** /**
* BT = * BT =
⎡1 0 -21/4 0 21/4 0 -1 0⎤ ⎡1 0 -21/4 0 21/4 0 -1 0⎤
...@@ -146,85 +154,94 @@ void TransformInput8x8(const float *input, ...@@ -146,85 +154,94 @@ void TransformInput8x8(const float *input,
const index_t in_channels, const index_t in_channels,
const index_t tile_count, const index_t tile_count,
float *output) { float *output) {
const index_t stride = batch * in_channels * tile_count; const index_t stride = in_channels * tile_count;
const index_t in_height_width = in_height * in_width; const index_t in_height_width = in_height * in_width;
const index_t input_batch_size = in_height_width * in_channels;
const index_t output_batch_size = 64 * in_channels * tile_count;
#pragma omp parallel for #pragma omp parallel for collapse(2)
for (index_t nc = 0; nc < batch * in_channels; ++nc) { for (index_t n = 0; n < batch; ++n) {
index_t tile_index = nc * tile_count; for (index_t c = 0; c < in_channels; ++c) {
float s[8][8]; index_t tile_index = 0;
for (index_t h = 0; h < in_height - 2; h += 6) { float s[8][8];
for (index_t w = 0; w < in_width - 2; w += 6) { for (index_t h = 0; h < in_height - 2; h += 6) {
index_t tile_offset = nc * in_height_width + h * in_width + w; for (index_t w = 0; w < in_width - 2; w += 6) {
for (int i = 0; i < 8; ++i) { const float *input_ptr =
float d0, d1, d2, d3, d4, d5, d6, d7; input + n * input_batch_size + c * in_height_width + h * in_width
d0 = input[tile_offset]; + w;
d1 = input[tile_offset + 1];
d2 = input[tile_offset + 2]; for (int i = 0; i < 8; ++i) {
d3 = input[tile_offset + 3]; float d0, d1, d2, d3, d4, d5, d6, d7;
d4 = input[tile_offset + 4]; d0 = input_ptr[0];
d5 = input[tile_offset + 5]; d1 = input_ptr[1];
d6 = input[tile_offset + 6]; d2 = input_ptr[2];
d7 = input[tile_offset + 7]; d3 = input_ptr[3];
d4 = input_ptr[4];
s[i][0] = d0 - d6 + (d4 - d2) * 5.25; d5 = input_ptr[5];
s[i][7] = d7 - d1 + (d3 - d5) * 5.25; d6 = input_ptr[6];
d7 = input_ptr[7];
float u = d2 + d6 - d4 * 4.25;
float v = d1 + d5 - d3 * 4.25; s[i][0] = d0 - d6 + (d4 - d2) * 5.25;
s[i][1] = u + v; s[i][7] = d7 - d1 + (d3 - d5) * 5.25;
s[i][2] = u - v;
float u = d2 + d6 - d4 * 4.25;
u = d6 + d2 * 0.25 - d4 * 1.25; float v = d1 + d5 - d3 * 4.25;
v = d1 * 0.5 - d3 * 2.5 + d5 * 2; s[i][1] = u + v;
s[i][3] = u + v; s[i][2] = u - v;
s[i][4] = u - v;
u = d6 + d2 * 0.25 - d4 * 1.25;
u = d6 + (d2 - d4 * 1.25) * 4; v = d1 * 0.5 - d3 * 2.5 + d5 * 2;
v = d1 * 2 - d3 * 2.5 + d5 * 0.5; s[i][3] = u + v;
s[i][5] = u + v; s[i][4] = u - v;
s[i][6] = u - v;
u = d6 + (d2 - d4 * 1.25) * 4;
tile_offset += in_width; v = d1 * 2 - d3 * 2.5 + d5 * 0.5;
} s[i][5] = u + v;
s[i][6] = u - v;
input_ptr += in_width;
}
for (int i = 0; i < 8; ++i) { float *output_ptr =
float d0, d1, d2, d3, d4, d5, d6, d7; output + n * output_batch_size + c * tile_count + tile_index;
d0 = s[0][i]; for (int i = 0; i < 8; ++i) {
d1 = s[1][i]; float d0, d1, d2, d3, d4, d5, d6, d7;
d2 = s[2][i]; d0 = s[0][i];
d3 = s[3][i]; d1 = s[1][i];
d4 = s[4][i]; d2 = s[2][i];
d5 = s[5][i]; d3 = s[3][i];
d6 = s[6][i]; d4 = s[4][i];
d7 = s[7][i]; d5 = s[5][i];
d6 = s[6][i];
output[tile_index + i * stride] = d0 - d6 + (d4 - d2) * 5.25; d7 = s[7][i];
output[tile_index + (56 + i) * stride] = d7 - d1 + (d3 - d5) * 5.25;
output_ptr[i * stride] = d0 - d6 + (d4 - d2) * 5.25;
float u = d2 + d6 - d4 * 4.25; output_ptr[(56 + i) * stride] = d7 - d1 + (d3 - d5) * 5.25;
float v = d1 + d5 - d3 * 4.25;
output[tile_index + (8 + i) * stride] = u + v; float u = d2 + d6 - d4 * 4.25;
output[tile_index + (16 + i) * stride] = u - v; float v = d1 + d5 - d3 * 4.25;
output_ptr[(8 + i) * stride] = u + v;
u = d6 + d2 * 0.25 - d4 * 1.25; output_ptr[(16 + i) * stride] = u - v;
v = d1 * 0.5 - d3 * 2.5 + d5 * 2;
output[tile_index + (24 + i) * stride] = u + v; u = d6 + d2 * 0.25 - d4 * 1.25;
output[tile_index + (32 + i) * stride] = u - v; v = d1 * 0.5 - d3 * 2.5 + d5 * 2;
output_ptr[(24 + i) * stride] = u + v;
u = d6 + (d2 - d4 * 1.25) * 4; output_ptr[(32 + i) * stride] = u - v;
v = d1 * 2 - d3 * 2.5 + d5 * 0.5;
output[tile_index + (40 + i) * stride] = u + v; u = d6 + (d2 - d4 * 1.25) * 4;
output[tile_index + (48 + i) * stride] = u - v; v = d1 * 2 - d3 * 2.5 + d5 * 0.5;
} output_ptr[(40 + i) * stride] = u + v;
output_ptr[(48 + i) * stride] = u - v;
}
++tile_index; ++tile_index;
}
} }
} }
} }
} }
// TOC * TNCB => TNOB // TOC * NTCB => NTOB
void BatchGemm(const float *input, void BatchGemm(const float *input,
const float *filter, const float *filter,
index_t batch, index_t batch,
...@@ -233,12 +250,13 @@ void BatchGemm(const float *input, ...@@ -233,12 +250,13 @@ void BatchGemm(const float *input,
index_t tile_count, index_t tile_count,
int out_tile_size, int out_tile_size,
float *output) { float *output) {
const index_t in_stride = batch * in_channels * tile_count;
const index_t in_channels_tile_count = in_channels * tile_count;
const index_t filter_stride = out_channels * in_channels; const index_t filter_stride = out_channels * in_channels;
const index_t out_stride = batch * out_channels * tile_count;
const index_t out_channels_tile_count = out_channels * tile_count;
const int in_tile_area = (out_tile_size + 2) * (out_tile_size + 2); const int in_tile_area = (out_tile_size + 2) * (out_tile_size + 2);
const index_t in_batch_size = in_tile_area * in_channels * tile_count;
const index_t in_stride = in_channels * tile_count;
const index_t out_batch_size = in_tile_area * out_channels * tile_count;
const index_t out_stride = out_channels * tile_count;
if (batch == 1) { if (batch == 1) {
Gemm(filter, Gemm(filter,
input, input,
...@@ -248,12 +266,13 @@ void BatchGemm(const float *input, ...@@ -248,12 +266,13 @@ void BatchGemm(const float *input,
tile_count, tile_count,
output); output);
} else { } else {
for (int i = 0; i < in_tile_area; ++i) { #pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) { for (int b = 0; b < batch; ++b) {
for (int i = 0; i < in_tile_area; ++i) {
const float const float
*in_ptr = input + i * in_stride + b * in_channels_tile_count; *in_ptr = input + b * in_batch_size + i * in_stride;
const float *filter_ptr = filter + i * filter_stride; const float *filter_ptr = filter + i * filter_stride;
float *out_ptr = output + i * out_stride + b * out_channels_tile_count; float *out_ptr = output + b * out_batch_size + i * out_stride;
Gemm(filter_ptr, Gemm(filter_ptr,
in_ptr, in_ptr,
1, 1,
...@@ -266,7 +285,7 @@ void BatchGemm(const float *input, ...@@ -266,7 +285,7 @@ void BatchGemm(const float *input,
} }
} }
// TNOB => ToNOB => NOHoWo // NTOB => NToOB => NOHoWo
void TransformOutput4x4(const float *input, void TransformOutput4x4(const float *input,
index_t batch, index_t batch,
index_t out_height, index_t out_height,
...@@ -274,65 +293,74 @@ void TransformOutput4x4(const float *input, ...@@ -274,65 +293,74 @@ void TransformOutput4x4(const float *input,
index_t out_channels, index_t out_channels,
index_t tile_count, index_t tile_count,
float *output) { float *output) {
const index_t in_stride = batch * out_channels * tile_count; const index_t stride = out_channels * tile_count;
const index_t input_batch_size = 16 * stride;
#pragma omp parallel for const index_t out_image_size = out_height * out_width;
for (index_t nm = 0; nm < batch * out_channels; ++nm) { const index_t output_batch_size = out_channels * out_image_size;
index_t tile_offset = nm * tile_count;
for (index_t h = 0; h < out_height; h += 2) { #pragma omp parallel for collapse(2)
for (index_t w = 0; w < out_width; w += 2) { for (index_t n = 0; n < batch; ++n) {
float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, for (index_t m = 0; m < out_channels; ++m) {
d15; index_t tile_offset = 0;
float s0, s1, s2, s3, s4, s5, s6, s7; for (index_t h = 0; h < out_height; h += 2) {
float v0, v1, v2, v3; for (index_t w = 0; w < out_width; w += 2) {
float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14,
d0 = input[tile_offset + 0 * in_stride]; d15;
d1 = input[tile_offset + 1 * in_stride]; float s0, s1, s2, s3, s4, s5, s6, s7;
d2 = input[tile_offset + 2 * in_stride]; float v0, v1, v2, v3;
d3 = input[tile_offset + 3 * in_stride];
const float *input_ptr =
d4 = input[tile_offset + 4 * in_stride]; input + n * input_batch_size + m * tile_count + tile_offset;
d5 = input[tile_offset + 5 * in_stride]; d0 = input_ptr[0];
d6 = input[tile_offset + 6 * in_stride]; d1 = input_ptr[1 * stride];
d7 = input[tile_offset + 7 * in_stride]; d2 = input_ptr[2 * stride];
d3 = input_ptr[3 * stride];
d8 = input[tile_offset + 8 * in_stride];
d9 = input[tile_offset + 9 * in_stride]; d4 = input_ptr[4 * stride];
d10 = input[tile_offset + 10 * in_stride]; d5 = input_ptr[5 * stride];
d11 = input[tile_offset + 11 * in_stride]; d6 = input_ptr[6 * stride];
d7 = input_ptr[7 * stride];
d12 = input[tile_offset + 12 * in_stride];
d13 = input[tile_offset + 13 * in_stride]; d8 = input_ptr[8 * stride];
d14 = input[tile_offset + 14 * in_stride]; d9 = input_ptr[9 * stride];
d15 = input[tile_offset + 15 * in_stride]; d10 = input_ptr[10 * stride];
d11 = input_ptr[11 * stride];
s0 = d0 + d1 + d2;
s1 = d1 - d2 - d3; d12 = input_ptr[12 * stride];
s2 = d4 + d5 + d6; d13 = input_ptr[13 * stride];
s3 = d5 - d6 - d7; d14 = input_ptr[14 * stride];
s4 = d8 + d9 + d10; d15 = input_ptr[15 * stride];
s5 = d9 - d10 - d11;
s6 = d12 + d13 + d14; s0 = d0 + d1 + d2;
s7 = d13 - d14 - d15; s1 = d1 - d2 - d3;
s2 = d4 + d5 + d6;
v0 = s0 + s2 + s4; s3 = d5 - d6 - d7;
v1 = s1 + s3 + s5; s4 = d8 + d9 + d10;
v2 = s2 - s4 - s6; s5 = d9 - d10 - d11;
v3 = s3 - s5 - s7; s6 = d12 + d13 + d14;
s7 = d13 - d14 - d15;
index_t out_offset = nm * out_height * out_width + h * out_width + w;
output[out_offset] = v0; v0 = s0 + s2 + s4;
output[out_offset + 1] = v1; v1 = s1 + s3 + s5;
output[out_offset + out_width] = v2; v2 = s2 - s4 - s6;
output[out_offset + out_width + 1] = v3; v3 = s3 - s5 - s7;
++tile_offset; float *output_ptr =
output + n * output_batch_size + m * out_image_size + h * out_width
+ w;
output_ptr[0] = v0;
output_ptr[1] = v1;
output_ptr[out_width] = v2;
output_ptr[out_width + 1] = v3;
++tile_offset;
}
} }
} }
} }
} }
// TNOB => ToNOB => NOHoWo // NTOB => NToOB => NOHoWo
/** /**
* AT = * AT =
⎡1 1 1 1 1 32 32 0⎤ ⎡1 1 1 1 1 32 32 0⎤
...@@ -362,72 +390,81 @@ void TransformOutput8x8(const float *input, ...@@ -362,72 +390,81 @@ void TransformOutput8x8(const float *input,
index_t out_channels, index_t out_channels,
index_t tile_count, index_t tile_count,
float *output) { float *output) {
const index_t in_stride = batch * out_channels * tile_count; const index_t stride = out_channels * tile_count;
const index_t input_batch_size = 64 * stride;
#pragma omp parallel for const index_t out_image_size = out_height * out_width;
for (index_t nm = 0; nm < batch * out_channels; ++nm) { const index_t output_batch_size = out_channels * out_image_size;
index_t tile_offset = nm * tile_count;
float s[8][6];
for (index_t h = 0; h < out_height; h += 6) {
for (index_t w = 0; w < out_width; w += 6) {
index_t tile_offset_tmp = tile_offset;
for (int i = 0; i < 8; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7;
d0 = input[tile_offset_tmp + 0 * in_stride];
d1 = input[tile_offset_tmp + 1 * in_stride];
d2 = input[tile_offset_tmp + 2 * in_stride];
d3 = input[tile_offset_tmp + 3 * in_stride];
d4 = input[tile_offset_tmp + 4 * in_stride];
d5 = input[tile_offset_tmp + 5 * in_stride];
d6 = input[tile_offset_tmp + 6 * in_stride];
d7 = input[tile_offset_tmp + 7 * in_stride];
float u = d1 + d2;
float v = d1 - d2;
float w = d3 + d4;
float x = d3 - d4;
float y = d5 + d6;
float z = d5 - d6;
s[i][0] = d0 + u + w + y * 32;
s[i][1] = v + x + x + z * 16;
s[i][2] = u + w * 4 + y * 8;
s[i][3] = v + x * 8 + z * 4;
s[i][4] = u + w * 16 + y + y;
s[i][5] = v + x * 32 + z + d7;
tile_offset_tmp += 8 * in_stride;
}
index_t out_offset = nm * out_height * out_width + h * out_width + w; #pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) {
for (int i = 0; i < 6; ++i) { for (index_t m = 0; m < out_channels; ++m) {
float d0, d1, d2, d3, d4, d5, d6, d7; index_t tile_offset = 0;
d0 = s[0][i]; float s[8][6];
d1 = s[1][i]; for (index_t h = 0; h < out_height; h += 6) {
d2 = s[2][i]; for (index_t w = 0; w < out_width; w += 6) {
d3 = s[3][i]; const float *input_ptr =
d4 = s[4][i]; input + n * input_batch_size + m * tile_count + tile_offset;
d5 = s[5][i]; for (int i = 0; i < 8; ++i) {
d6 = s[6][i]; float d0, d1, d2, d3, d4, d5, d6, d7;
d7 = s[7][i];
d0 = input_ptr[0];
float u = d1 + d2; d1 = input_ptr[1 * stride];
float v = d1 - d2; d2 = input_ptr[2 * stride];
float w = d3 + d4; d3 = input_ptr[3 * stride];
float x = d3 - d4; d4 = input_ptr[4 * stride];
float y = d5 + d6; d5 = input_ptr[5 * stride];
float z = d5 - d6; d6 = input_ptr[6 * stride];
d7 = input_ptr[7 * stride];
output[out_offset + 0 * out_width + i] = d0 + u + w + y * 32;
output[out_offset + 1 * out_width + i] = v + x + x + z * 16; float u = d1 + d2;
output[out_offset + 2 * out_width + i] = u + w * 4 + y * 8; float v = d1 - d2;
output[out_offset + 3 * out_width + i] = v + x * 8 + z * 4; float w = d3 + d4;
output[out_offset + 4 * out_width + i] = u + w * 16 + y + y; float x = d3 - d4;
output[out_offset + 5 * out_width + i] = v + x * 32 + z + d7; float y = d5 + d6;
} float z = d5 - d6;
s[i][0] = d0 + u + w + y * 32;
s[i][1] = v + x + x + z * 16;
s[i][2] = u + w * 4 + y * 8;
s[i][3] = v + x * 8 + z * 4;
s[i][4] = u + w * 16 + y + y;
s[i][5] = v + x * 32 + z + d7;
input_ptr += 8 * stride;
}
++tile_offset; float *output_ptr =
output + n * output_batch_size + m * out_image_size + h * out_width
+ w;
for (int i = 0; i < 6; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7;
d0 = s[0][i];
d1 = s[1][i];
d2 = s[2][i];
d3 = s[3][i];
d4 = s[4][i];
d5 = s[5][i];
d6 = s[6][i];
d7 = s[7][i];
float u = d1 + d2;
float v = d1 - d2;
float w = d3 + d4;
float x = d3 - d4;
float y = d5 + d6;
float z = d5 - d6;
output_ptr[i] = d0 + u + w + y * 32;
output_ptr[1 * out_width + i] = v + x + x + z * 16;
output_ptr[2 * out_width + i] = u + w * 4 + y * 8;
output_ptr[3 * out_width + i] = v + x * 8 + z * 4;
output_ptr[4 * out_width + i] = u + w * 16 + y + y;
output_ptr[5 * out_width + i] = v + x * 32 + z + d7;
}
++tile_offset;
}
} }
} }
} }
...@@ -448,7 +485,7 @@ void TransformFilter4x4(const float *filter, ...@@ -448,7 +485,7 @@ void TransformFilter4x4(const float *filter,
for (index_t c = 0; c < in_channels; ++c) { for (index_t c = 0; c < in_channels; ++c) {
float g0, g1, g2, g3, g4, g5, g6, g7, g8; float g0, g1, g2, g3, g4, g5, g6, g7, g8;
float s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, float s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14,
s15; s15;
// load filter // load filter
index_t filter_offset = (m * in_channels + c) * 9; index_t filter_offset = (m * in_channels + c) * 9;
...@@ -537,14 +574,14 @@ void TransformFilter8x8(const float *filter, ...@@ -537,14 +574,14 @@ void TransformFilter8x8(const float *filter,
const index_t stride = out_channels * in_channels; const index_t stride = out_channels * in_channels;
const float G[8][3] = { const float G[8][3] = {
{1.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 0.0f},
{-2.0f / 9, -2.0f / 9, -2.0f / 9}, {-2.0f / 9, -2.0f / 9, -2.0f / 9},
{-2.0f / 9, 2.0f / 9, -2.0f / 9}, {-2.0f / 9, 2.0f / 9, -2.0f / 9},
{1.0f / 90, 1.0f / 45, 2.0f / 45}, {1.0f / 90, 1.0f / 45, 2.0f / 45},
{1.0f / 90, -1.0f / 45, 2.0f / 45}, {1.0f / 90, -1.0f / 45, 2.0f / 45},
{1.0f / 45, 1.0f / 90, 1.0f / 180}, {1.0f / 45, 1.0f / 90, 1.0f / 180},
{1.0f / 45, -1.0f / 90, 1.0f / 180}, {1.0f / 45, -1.0f / 90, 1.0f / 180},
{0.0f, 0.0f, 1.0f} {0.0f, 0.0f, 1.0f}
}; };
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
...@@ -575,7 +612,7 @@ void TransformFilter8x8(const float *filter, ...@@ -575,7 +612,7 @@ void TransformFilter8x8(const float *filter,
for (int i = 0; i < 8; ++i) { for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 8; ++j) { for (int j = 0; j < 8; ++j) {
output[output_offset + (i * 8 + j) * stride] = output[output_offset + (i * 8 + j) * stride] =
G[i][0] * s[0][j] + G[i][1] * s[1][j] + G[i][2] * s[2][j]; G[i][0] * s[0][j] + G[i][1] * s[1][j] + G[i][2] * s[2][j];
} }
} }
} }
......
...@@ -165,6 +165,10 @@ BM_CONV_2D(1, 32, 256, 256, 3, 3, 1, 4, VALID, 32); ...@@ -165,6 +165,10 @@ BM_CONV_2D(1, 32, 256, 256, 3, 3, 1, 4, VALID, 32);
BM_CONV_2D(1, 128, 56, 56, 1, 1, 1, 1, SAME, 128); BM_CONV_2D(1, 128, 56, 56, 1, 1, 1, 1, SAME, 128);
BM_CONV_2D(1, 1024, 7, 7, 1, 1, 1, 1, SAME, 1024); BM_CONV_2D(1, 1024, 7, 7, 1, 1, 1, 1, SAME, 1024);
BM_CONV_2D(64, 32, 34, 34, 3, 3, 1, 1, VALID, 32);
BM_CONV_2D(1, 32, 34, 34, 3, 3, 1, 1, VALID, 32);
} // namespace test } // namespace test
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册