diff --git a/mace/kernels/arm/conv_winograd.cc b/mace/kernels/arm/conv_winograd.cc index e73061e3a1160bebed2da0ac17cf0a5474ae00f5..6a3b520b7a579b69d6dfa9378c47fab92dc765cd 100644 --- a/mace/kernels/arm/conv_winograd.cc +++ b/mace/kernels/arm/conv_winograd.cc @@ -24,7 +24,7 @@ namespace mace { namespace kernels { namespace { -// NCHW => TNCB (T: in tile pixels, B: tile indices) +// NCHW => NTCB (T: in tile pixels, B: tile indices) void TransformInput4x4(const float *input, const index_t batch, const index_t in_height, @@ -32,87 +32,95 @@ void TransformInput4x4(const float *input, const index_t in_channels, const index_t tile_count, float *output) { - const index_t stride = batch * in_channels * tile_count; + const index_t stride = in_channels * tile_count; const index_t in_height_width = in_height * in_width; + const index_t input_batch_size = in_height_width * in_channels; + const index_t output_batch_size = 16 * in_channels * tile_count; -#pragma omp parallel for - for (index_t nc = 0; nc < batch * in_channels; ++nc) { - index_t tile_index = nc * tile_count; - for (index_t h = 0; h < in_height - 2; h += 2) { - for (index_t w = 0; w < in_width - 2; w += 2) { - float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, - d15; - float s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, - s15; - - // load tile data - const index_t tile_offset = nc * in_height_width + h * in_width + w; - d0 = input[tile_offset]; - d1 = input[tile_offset + 1]; - d2 = input[tile_offset + 2]; - d3 = input[tile_offset + 3]; - - d4 = input[tile_offset + in_width]; - d5 = input[tile_offset + in_width + 1]; - d6 = input[tile_offset + in_width + 2]; - d7 = input[tile_offset + in_width + 3]; - - d8 = input[tile_offset + 2 * in_width]; - d9 = input[tile_offset + 2 * in_width + 1]; - d10 = input[tile_offset + 2 * in_width + 2]; - d11 = input[tile_offset + 2 * in_width + 3]; - - d12 = input[tile_offset + 3 * in_width]; - d13 = input[tile_offset + 3 * in_width + 1]; - d14 = input[tile_offset + 3 * in_width + 2]; - d15 = input[tile_offset + 3 * in_width + 3]; - - // s = BT * d * B - s0 = (d0 - d8) - (d2 - d10); - s1 = (d1 - d9) + (d2 - d10); - s2 = (d2 - d10) - (d1 - d9); - s3 = (d1 - d9) - (d3 - d11); - s4 = (d4 + d8) - (d6 + d10); - s5 = (d5 + d9) + (d6 + d10); - s6 = (d6 + d10) - (d5 + d9); - s7 = (d5 + d9) - (d7 + d11); - s8 = (d8 - d4) - (d10 - d6); - s9 = (d9 - d5) + (d10 - d6); - s10 = (d10 - d6) - (d9 - d5); - s11 = (d9 - d5) - (d11 - d7); - s12 = (d4 - d12) - (d6 - d14); - s13 = (d5 - d13) + (d6 - d14); - s14 = (d6 - d14) - (d5 - d13); - s15 = (d5 - d13) - (d7 - d15); - - // store output - output[tile_index + 0 * stride] = s0; - output[tile_index + 1 * stride] = s1; - output[tile_index + 2 * stride] = s2; - output[tile_index + 3 * stride] = s3; - - output[tile_index + 4 * stride] = s4; - output[tile_index + 5 * stride] = s5; - output[tile_index + 6 * stride] = s6; - output[tile_index + 7 * stride] = s7; - - output[tile_index + 8 * stride] = s8; - output[tile_index + 9 * stride] = s9; - output[tile_index + 10 * stride] = s10; - output[tile_index + 11 * stride] = s11; - - output[tile_index + 12 * stride] = s12; - output[tile_index + 13 * stride] = s13; - output[tile_index + 14 * stride] = s14; - output[tile_index + 15 * stride] = s15; - - ++tile_index; +#pragma omp parallel for collapse(2) + for (index_t n = 0; n < batch; ++n) { + for (index_t c = 0; c < in_channels; ++c) { + index_t tile_index = 0; + for (index_t h = 0; h < in_height - 2; h += 2) { + for (index_t w = 0; w < in_width - 2; w += 2) { + float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, + d15; + float s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, + s15; + + // load tile data + const float *input_ptr = + input + n * input_batch_size + c * in_height_width + h * in_width + + w; + d0 = input_ptr[0]; + d1 = input_ptr[1]; + d2 = input_ptr[2]; + d3 = input_ptr[3]; + + d4 = input_ptr[in_width]; + d5 = input_ptr[in_width + 1]; + d6 = input_ptr[in_width + 2]; + d7 = input_ptr[in_width + 3]; + + d8 = input_ptr[2 * in_width]; + d9 = input_ptr[2 * in_width + 1]; + d10 = input_ptr[2 * in_width + 2]; + d11 = input_ptr[2 * in_width + 3]; + + d12 = input_ptr[3 * in_width]; + d13 = input_ptr[3 * in_width + 1]; + d14 = input_ptr[3 * in_width + 2]; + d15 = input_ptr[3 * in_width + 3]; + + // s = BT * d * B + s0 = (d0 - d8) - (d2 - d10); + s1 = (d1 - d9) + (d2 - d10); + s2 = (d2 - d10) - (d1 - d9); + s3 = (d1 - d9) - (d3 - d11); + s4 = (d4 + d8) - (d6 + d10); + s5 = (d5 + d9) + (d6 + d10); + s6 = (d6 + d10) - (d5 + d9); + s7 = (d5 + d9) - (d7 + d11); + s8 = (d8 - d4) - (d10 - d6); + s9 = (d9 - d5) + (d10 - d6); + s10 = (d10 - d6) - (d9 - d5); + s11 = (d9 - d5) - (d11 - d7); + s12 = (d4 - d12) - (d6 - d14); + s13 = (d5 - d13) + (d6 - d14); + s14 = (d6 - d14) - (d5 - d13); + s15 = (d5 - d13) - (d7 - d15); + + // store output + float *output_ptr = + output + n * output_batch_size + c * tile_count + tile_index; + output_ptr[0] = s0; + output_ptr[1 * stride] = s1; + output_ptr[2 * stride] = s2; + output_ptr[3 * stride] = s3; + + output_ptr[4 * stride] = s4; + output_ptr[5 * stride] = s5; + output_ptr[6 * stride] = s6; + output_ptr[7 * stride] = s7; + + output_ptr[8 * stride] = s8; + output_ptr[9 * stride] = s9; + output_ptr[10 * stride] = s10; + output_ptr[11 * stride] = s11; + + output_ptr[12 * stride] = s12; + output_ptr[13 * stride] = s13; + output_ptr[14 * stride] = s14; + output_ptr[15 * stride] = s15; + + ++tile_index; + } } } } } -// NCHW => TNCB (T: in tile pixels, B: tile indices) +// NCHW => NTCB (T: in tile pixels, B: tile indices) /** * BT = ⎡1 0 -21/4 0 21/4 0 -1 0⎤ @@ -146,85 +154,94 @@ void TransformInput8x8(const float *input, const index_t in_channels, const index_t tile_count, float *output) { - const index_t stride = batch * in_channels * tile_count; + const index_t stride = in_channels * tile_count; const index_t in_height_width = in_height * in_width; + const index_t input_batch_size = in_height_width * in_channels; + const index_t output_batch_size = 64 * in_channels * tile_count; -#pragma omp parallel for - for (index_t nc = 0; nc < batch * in_channels; ++nc) { - index_t tile_index = nc * tile_count; - float s[8][8]; - for (index_t h = 0; h < in_height - 2; h += 6) { - for (index_t w = 0; w < in_width - 2; w += 6) { - index_t tile_offset = nc * in_height_width + h * in_width + w; - for (int i = 0; i < 8; ++i) { - float d0, d1, d2, d3, d4, d5, d6, d7; - d0 = input[tile_offset]; - d1 = input[tile_offset + 1]; - d2 = input[tile_offset + 2]; - d3 = input[tile_offset + 3]; - d4 = input[tile_offset + 4]; - d5 = input[tile_offset + 5]; - d6 = input[tile_offset + 6]; - d7 = input[tile_offset + 7]; - - s[i][0] = d0 - d6 + (d4 - d2) * 5.25; - s[i][7] = d7 - d1 + (d3 - d5) * 5.25; - - float u = d2 + d6 - d4 * 4.25; - float v = d1 + d5 - d3 * 4.25; - s[i][1] = u + v; - s[i][2] = u - v; - - u = d6 + d2 * 0.25 - d4 * 1.25; - v = d1 * 0.5 - d3 * 2.5 + d5 * 2; - s[i][3] = u + v; - s[i][4] = u - v; - - u = d6 + (d2 - d4 * 1.25) * 4; - v = d1 * 2 - d3 * 2.5 + d5 * 0.5; - s[i][5] = u + v; - s[i][6] = u - v; - - tile_offset += in_width; - } +#pragma omp parallel for collapse(2) + for (index_t n = 0; n < batch; ++n) { + for (index_t c = 0; c < in_channels; ++c) { + index_t tile_index = 0; + float s[8][8]; + for (index_t h = 0; h < in_height - 2; h += 6) { + for (index_t w = 0; w < in_width - 2; w += 6) { + const float *input_ptr = + input + n * input_batch_size + c * in_height_width + h * in_width + + w; + + for (int i = 0; i < 8; ++i) { + float d0, d1, d2, d3, d4, d5, d6, d7; + d0 = input_ptr[0]; + d1 = input_ptr[1]; + d2 = input_ptr[2]; + d3 = input_ptr[3]; + d4 = input_ptr[4]; + d5 = input_ptr[5]; + d6 = input_ptr[6]; + d7 = input_ptr[7]; + + s[i][0] = d0 - d6 + (d4 - d2) * 5.25; + s[i][7] = d7 - d1 + (d3 - d5) * 5.25; + + float u = d2 + d6 - d4 * 4.25; + float v = d1 + d5 - d3 * 4.25; + s[i][1] = u + v; + s[i][2] = u - v; + + u = d6 + d2 * 0.25 - d4 * 1.25; + v = d1 * 0.5 - d3 * 2.5 + d5 * 2; + s[i][3] = u + v; + s[i][4] = u - v; + + u = d6 + (d2 - d4 * 1.25) * 4; + v = d1 * 2 - d3 * 2.5 + d5 * 0.5; + s[i][5] = u + v; + s[i][6] = u - v; + + input_ptr += in_width; + } - for (int i = 0; i < 8; ++i) { - float d0, d1, d2, d3, d4, d5, d6, d7; - d0 = s[0][i]; - d1 = s[1][i]; - d2 = s[2][i]; - d3 = s[3][i]; - d4 = s[4][i]; - d5 = s[5][i]; - d6 = s[6][i]; - d7 = s[7][i]; - - output[tile_index + i * stride] = d0 - d6 + (d4 - d2) * 5.25; - output[tile_index + (56 + i) * stride] = d7 - d1 + (d3 - d5) * 5.25; - - float u = d2 + d6 - d4 * 4.25; - float v = d1 + d5 - d3 * 4.25; - output[tile_index + (8 + i) * stride] = u + v; - output[tile_index + (16 + i) * stride] = u - v; - - u = d6 + d2 * 0.25 - d4 * 1.25; - v = d1 * 0.5 - d3 * 2.5 + d5 * 2; - output[tile_index + (24 + i) * stride] = u + v; - output[tile_index + (32 + i) * stride] = u - v; - - u = d6 + (d2 - d4 * 1.25) * 4; - v = d1 * 2 - d3 * 2.5 + d5 * 0.5; - output[tile_index + (40 + i) * stride] = u + v; - output[tile_index + (48 + i) * stride] = u - v; - } + float *output_ptr = + output + n * output_batch_size + c * tile_count + tile_index; + for (int i = 0; i < 8; ++i) { + float d0, d1, d2, d3, d4, d5, d6, d7; + d0 = s[0][i]; + d1 = s[1][i]; + d2 = s[2][i]; + d3 = s[3][i]; + d4 = s[4][i]; + d5 = s[5][i]; + d6 = s[6][i]; + d7 = s[7][i]; + + output_ptr[i * stride] = d0 - d6 + (d4 - d2) * 5.25; + output_ptr[(56 + i) * stride] = d7 - d1 + (d3 - d5) * 5.25; + + float u = d2 + d6 - d4 * 4.25; + float v = d1 + d5 - d3 * 4.25; + output_ptr[(8 + i) * stride] = u + v; + output_ptr[(16 + i) * stride] = u - v; + + u = d6 + d2 * 0.25 - d4 * 1.25; + v = d1 * 0.5 - d3 * 2.5 + d5 * 2; + output_ptr[(24 + i) * stride] = u + v; + output_ptr[(32 + i) * stride] = u - v; + + u = d6 + (d2 - d4 * 1.25) * 4; + v = d1 * 2 - d3 * 2.5 + d5 * 0.5; + output_ptr[(40 + i) * stride] = u + v; + output_ptr[(48 + i) * stride] = u - v; + } - ++tile_index; + ++tile_index; + } } } } } -// TOC * TNCB => TNOB +// TOC * NTCB => NTOB void BatchGemm(const float *input, const float *filter, index_t batch, @@ -233,12 +250,13 @@ void BatchGemm(const float *input, index_t tile_count, int out_tile_size, float *output) { - const index_t in_stride = batch * in_channels * tile_count; - const index_t in_channels_tile_count = in_channels * tile_count; const index_t filter_stride = out_channels * in_channels; - const index_t out_stride = batch * out_channels * tile_count; - const index_t out_channels_tile_count = out_channels * tile_count; const int in_tile_area = (out_tile_size + 2) * (out_tile_size + 2); + const index_t in_batch_size = in_tile_area * in_channels * tile_count; + const index_t in_stride = in_channels * tile_count; + const index_t out_batch_size = in_tile_area * out_channels * tile_count; + const index_t out_stride = out_channels * tile_count; + if (batch == 1) { Gemm(filter, input, @@ -248,12 +266,13 @@ void BatchGemm(const float *input, tile_count, output); } else { - for (int i = 0; i < in_tile_area; ++i) { - for (int b = 0; b < batch; ++b) { +#pragma omp parallel for collapse(2) + for (int b = 0; b < batch; ++b) { + for (int i = 0; i < in_tile_area; ++i) { const float - *in_ptr = input + i * in_stride + b * in_channels_tile_count; + *in_ptr = input + b * in_batch_size + i * in_stride; const float *filter_ptr = filter + i * filter_stride; - float *out_ptr = output + i * out_stride + b * out_channels_tile_count; + float *out_ptr = output + b * out_batch_size + i * out_stride; Gemm(filter_ptr, in_ptr, 1, @@ -266,7 +285,7 @@ void BatchGemm(const float *input, } } -// TNOB => ToNOB => NOHoWo +// NTOB => NToOB => NOHoWo void TransformOutput4x4(const float *input, index_t batch, index_t out_height, @@ -274,65 +293,74 @@ void TransformOutput4x4(const float *input, index_t out_channels, index_t tile_count, float *output) { - const index_t in_stride = batch * out_channels * tile_count; - -#pragma omp parallel for - for (index_t nm = 0; nm < batch * out_channels; ++nm) { - index_t tile_offset = nm * tile_count; - for (index_t h = 0; h < out_height; h += 2) { - for (index_t w = 0; w < out_width; w += 2) { - float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, - d15; - float s0, s1, s2, s3, s4, s5, s6, s7; - float v0, v1, v2, v3; - - d0 = input[tile_offset + 0 * in_stride]; - d1 = input[tile_offset + 1 * in_stride]; - d2 = input[tile_offset + 2 * in_stride]; - d3 = input[tile_offset + 3 * in_stride]; - - d4 = input[tile_offset + 4 * in_stride]; - d5 = input[tile_offset + 5 * in_stride]; - d6 = input[tile_offset + 6 * in_stride]; - d7 = input[tile_offset + 7 * in_stride]; - - d8 = input[tile_offset + 8 * in_stride]; - d9 = input[tile_offset + 9 * in_stride]; - d10 = input[tile_offset + 10 * in_stride]; - d11 = input[tile_offset + 11 * in_stride]; - - d12 = input[tile_offset + 12 * in_stride]; - d13 = input[tile_offset + 13 * in_stride]; - d14 = input[tile_offset + 14 * in_stride]; - d15 = input[tile_offset + 15 * in_stride]; - - s0 = d0 + d1 + d2; - s1 = d1 - d2 - d3; - s2 = d4 + d5 + d6; - s3 = d5 - d6 - d7; - s4 = d8 + d9 + d10; - s5 = d9 - d10 - d11; - s6 = d12 + d13 + d14; - s7 = d13 - d14 - d15; - - v0 = s0 + s2 + s4; - v1 = s1 + s3 + s5; - v2 = s2 - s4 - s6; - v3 = s3 - s5 - s7; - - index_t out_offset = nm * out_height * out_width + h * out_width + w; - output[out_offset] = v0; - output[out_offset + 1] = v1; - output[out_offset + out_width] = v2; - output[out_offset + out_width + 1] = v3; - - ++tile_offset; + const index_t stride = out_channels * tile_count; + const index_t input_batch_size = 16 * stride; + const index_t out_image_size = out_height * out_width; + const index_t output_batch_size = out_channels * out_image_size; + +#pragma omp parallel for collapse(2) + for (index_t n = 0; n < batch; ++n) { + for (index_t m = 0; m < out_channels; ++m) { + index_t tile_offset = 0; + for (index_t h = 0; h < out_height; h += 2) { + for (index_t w = 0; w < out_width; w += 2) { + float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, + d15; + float s0, s1, s2, s3, s4, s5, s6, s7; + float v0, v1, v2, v3; + + const float *input_ptr = + input + n * input_batch_size + m * tile_count + tile_offset; + d0 = input_ptr[0]; + d1 = input_ptr[1 * stride]; + d2 = input_ptr[2 * stride]; + d3 = input_ptr[3 * stride]; + + d4 = input_ptr[4 * stride]; + d5 = input_ptr[5 * stride]; + d6 = input_ptr[6 * stride]; + d7 = input_ptr[7 * stride]; + + d8 = input_ptr[8 * stride]; + d9 = input_ptr[9 * stride]; + d10 = input_ptr[10 * stride]; + d11 = input_ptr[11 * stride]; + + d12 = input_ptr[12 * stride]; + d13 = input_ptr[13 * stride]; + d14 = input_ptr[14 * stride]; + d15 = input_ptr[15 * stride]; + + s0 = d0 + d1 + d2; + s1 = d1 - d2 - d3; + s2 = d4 + d5 + d6; + s3 = d5 - d6 - d7; + s4 = d8 + d9 + d10; + s5 = d9 - d10 - d11; + s6 = d12 + d13 + d14; + s7 = d13 - d14 - d15; + + v0 = s0 + s2 + s4; + v1 = s1 + s3 + s5; + v2 = s2 - s4 - s6; + v3 = s3 - s5 - s7; + + float *output_ptr = + output + n * output_batch_size + m * out_image_size + h * out_width + + w; + output_ptr[0] = v0; + output_ptr[1] = v1; + output_ptr[out_width] = v2; + output_ptr[out_width + 1] = v3; + + ++tile_offset; + } } } } } -// TNOB => ToNOB => NOHoWo +// NTOB => NToOB => NOHoWo /** * AT = ⎡1 1 1 1 1 32 32 0⎤ @@ -362,72 +390,81 @@ void TransformOutput8x8(const float *input, index_t out_channels, index_t tile_count, float *output) { - const index_t in_stride = batch * out_channels * tile_count; - -#pragma omp parallel for - for (index_t nm = 0; nm < batch * out_channels; ++nm) { - index_t tile_offset = nm * tile_count; - float s[8][6]; - for (index_t h = 0; h < out_height; h += 6) { - for (index_t w = 0; w < out_width; w += 6) { - index_t tile_offset_tmp = tile_offset; - for (int i = 0; i < 8; ++i) { - float d0, d1, d2, d3, d4, d5, d6, d7; - d0 = input[tile_offset_tmp + 0 * in_stride]; - d1 = input[tile_offset_tmp + 1 * in_stride]; - d2 = input[tile_offset_tmp + 2 * in_stride]; - d3 = input[tile_offset_tmp + 3 * in_stride]; - d4 = input[tile_offset_tmp + 4 * in_stride]; - d5 = input[tile_offset_tmp + 5 * in_stride]; - d6 = input[tile_offset_tmp + 6 * in_stride]; - d7 = input[tile_offset_tmp + 7 * in_stride]; - - float u = d1 + d2; - float v = d1 - d2; - float w = d3 + d4; - float x = d3 - d4; - float y = d5 + d6; - float z = d5 - d6; - - s[i][0] = d0 + u + w + y * 32; - s[i][1] = v + x + x + z * 16; - s[i][2] = u + w * 4 + y * 8; - s[i][3] = v + x * 8 + z * 4; - s[i][4] = u + w * 16 + y + y; - s[i][5] = v + x * 32 + z + d7; - - tile_offset_tmp += 8 * in_stride; - } + const index_t stride = out_channels * tile_count; + const index_t input_batch_size = 64 * stride; + const index_t out_image_size = out_height * out_width; + const index_t output_batch_size = out_channels * out_image_size; - index_t out_offset = nm * out_height * out_width + h * out_width + w; - - for (int i = 0; i < 6; ++i) { - float d0, d1, d2, d3, d4, d5, d6, d7; - d0 = s[0][i]; - d1 = s[1][i]; - d2 = s[2][i]; - d3 = s[3][i]; - d4 = s[4][i]; - d5 = s[5][i]; - d6 = s[6][i]; - d7 = s[7][i]; - - float u = d1 + d2; - float v = d1 - d2; - float w = d3 + d4; - float x = d3 - d4; - float y = d5 + d6; - float z = d5 - d6; - - output[out_offset + 0 * out_width + i] = d0 + u + w + y * 32; - output[out_offset + 1 * out_width + i] = v + x + x + z * 16; - output[out_offset + 2 * out_width + i] = u + w * 4 + y * 8; - output[out_offset + 3 * out_width + i] = v + x * 8 + z * 4; - output[out_offset + 4 * out_width + i] = u + w * 16 + y + y; - output[out_offset + 5 * out_width + i] = v + x * 32 + z + d7; - } +#pragma omp parallel for collapse(2) + for (index_t n = 0; n < batch; ++n) { + for (index_t m = 0; m < out_channels; ++m) { + index_t tile_offset = 0; + float s[8][6]; + for (index_t h = 0; h < out_height; h += 6) { + for (index_t w = 0; w < out_width; w += 6) { + const float *input_ptr = + input + n * input_batch_size + m * tile_count + tile_offset; + for (int i = 0; i < 8; ++i) { + float d0, d1, d2, d3, d4, d5, d6, d7; + + d0 = input_ptr[0]; + d1 = input_ptr[1 * stride]; + d2 = input_ptr[2 * stride]; + d3 = input_ptr[3 * stride]; + d4 = input_ptr[4 * stride]; + d5 = input_ptr[5 * stride]; + d6 = input_ptr[6 * stride]; + d7 = input_ptr[7 * stride]; + + float u = d1 + d2; + float v = d1 - d2; + float w = d3 + d4; + float x = d3 - d4; + float y = d5 + d6; + float z = d5 - d6; + + s[i][0] = d0 + u + w + y * 32; + s[i][1] = v + x + x + z * 16; + s[i][2] = u + w * 4 + y * 8; + s[i][3] = v + x * 8 + z * 4; + s[i][4] = u + w * 16 + y + y; + s[i][5] = v + x * 32 + z + d7; + + input_ptr += 8 * stride; + } - ++tile_offset; + float *output_ptr = + output + n * output_batch_size + m * out_image_size + h * out_width + + w; + + for (int i = 0; i < 6; ++i) { + float d0, d1, d2, d3, d4, d5, d6, d7; + d0 = s[0][i]; + d1 = s[1][i]; + d2 = s[2][i]; + d3 = s[3][i]; + d4 = s[4][i]; + d5 = s[5][i]; + d6 = s[6][i]; + d7 = s[7][i]; + + float u = d1 + d2; + float v = d1 - d2; + float w = d3 + d4; + float x = d3 - d4; + float y = d5 + d6; + float z = d5 - d6; + + output_ptr[i] = d0 + u + w + y * 32; + output_ptr[1 * out_width + i] = v + x + x + z * 16; + output_ptr[2 * out_width + i] = u + w * 4 + y * 8; + output_ptr[3 * out_width + i] = v + x * 8 + z * 4; + output_ptr[4 * out_width + i] = u + w * 16 + y + y; + output_ptr[5 * out_width + i] = v + x * 32 + z + d7; + } + + ++tile_offset; + } } } } @@ -448,7 +485,7 @@ void TransformFilter4x4(const float *filter, for (index_t c = 0; c < in_channels; ++c) { float g0, g1, g2, g3, g4, g5, g6, g7, g8; float s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, - s15; + s15; // load filter index_t filter_offset = (m * in_channels + c) * 9; @@ -537,14 +574,14 @@ void TransformFilter8x8(const float *filter, const index_t stride = out_channels * in_channels; const float G[8][3] = { - {1.0f, 0.0f, 0.0f}, - {-2.0f / 9, -2.0f / 9, -2.0f / 9}, - {-2.0f / 9, 2.0f / 9, -2.0f / 9}, - {1.0f / 90, 1.0f / 45, 2.0f / 45}, - {1.0f / 90, -1.0f / 45, 2.0f / 45}, - {1.0f / 45, 1.0f / 90, 1.0f / 180}, - {1.0f / 45, -1.0f / 90, 1.0f / 180}, - {0.0f, 0.0f, 1.0f} + {1.0f, 0.0f, 0.0f}, + {-2.0f / 9, -2.0f / 9, -2.0f / 9}, + {-2.0f / 9, 2.0f / 9, -2.0f / 9}, + {1.0f / 90, 1.0f / 45, 2.0f / 45}, + {1.0f / 90, -1.0f / 45, 2.0f / 45}, + {1.0f / 45, 1.0f / 90, 1.0f / 180}, + {1.0f / 45, -1.0f / 90, 1.0f / 180}, + {0.0f, 0.0f, 1.0f} }; #pragma omp parallel for collapse(2) @@ -575,7 +612,7 @@ void TransformFilter8x8(const float *filter, for (int i = 0; i < 8; ++i) { for (int j = 0; j < 8; ++j) { output[output_offset + (i * 8 + j) * stride] = - G[i][0] * s[0][j] + G[i][1] * s[1][j] + G[i][2] * s[2][j]; + G[i][0] * s[0][j] + G[i][1] * s[1][j] + G[i][2] * s[2][j]; } } } diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index a208653333bdea04dfa81303cbd9b78a5b8aa5a8..b795e12768e57d96fcfbf194d50beeec02ae4333 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -165,6 +165,10 @@ BM_CONV_2D(1, 32, 256, 256, 3, 3, 1, 4, VALID, 32); BM_CONV_2D(1, 128, 56, 56, 1, 1, 1, 1, SAME, 128); BM_CONV_2D(1, 1024, 7, 7, 1, 1, 1, 1, SAME, 1024); + +BM_CONV_2D(64, 32, 34, 34, 3, 3, 1, 1, VALID, 32); +BM_CONV_2D(1, 32, 34, 34, 3, 3, 1, 1, VALID, 32); + } // namespace test } // namespace ops } // namespace mace