diff --git a/mace/kernels/arm/conv_2d.cc b/mace/kernels/arm/conv_2d.cc index d9dd849e4f8d04535a466d0ced470c10afb8f016..7fc16cda27b4c3c93d491aa9caf6372247df6e96 100644 --- a/mace/kernels/arm/conv_2d.cc +++ b/mace/kernels/arm/conv_2d.cc @@ -7,6 +7,7 @@ // winograd is always superior to neon impl during benchmark #define USE_WINOGRAD 1 +#define WINOGRAD_OUT_TILE_SIZE 6 namespace mace { namespace kernels { @@ -164,9 +165,9 @@ void Conv2dFunctor::operator()(const Tensor *input, && stride_w == 1 && dilation_h == 1 && dilation_w == 1 && input_channels >= 8 && channels >= 8) { - extra_output_height = RoundUp(height, 2); + extra_output_height = RoundUp(height, WINOGRAD_OUT_TILE_SIZE); extra_input_height = std::max(padded_input_height, extra_output_height + 2); - extra_output_width = RoundUp(width, 2); + extra_output_width = RoundUp(width, WINOGRAD_OUT_TILE_SIZE); extra_input_width = std::max(padded_input_width, extra_output_width + 2); if (extra_input_height != padded_input_height) { pad_bottom += (extra_input_height - padded_input_height); @@ -175,12 +176,15 @@ void Conv2dFunctor::operator()(const Tensor *input, pad_right += (extra_input_width - padded_input_width); } - index_t tile_height_count = (extra_output_height + 1) / 2; - index_t tile_width_count = (extra_output_width + 1) / 2; + index_t tile_height_count = extra_output_height / WINOGRAD_OUT_TILE_SIZE; + index_t tile_width_count = extra_output_width / WINOGRAD_OUT_TILE_SIZE; index_t tile_count = tile_height_count * tile_width_count; - transformed_input_.Resize({16, batch, input_channels, tile_count}); - transformed_filter_.Resize({16, channels, input_channels}); - transformed_output_.Resize({16, batch, channels, tile_count}); + index_t in_tile_area = + (WINOGRAD_OUT_TILE_SIZE + 2) * (WINOGRAD_OUT_TILE_SIZE + 2); + transformed_input_.Resize({in_tile_area, batch, input_channels, + tile_count}); + transformed_filter_.Resize({in_tile_area, channels, input_channels}); + transformed_output_.Resize({in_tile_area, batch, channels, tile_count}); conv_func = [=](const float *pad_input, float *pad_output) { WinoGradConv3x3s1(pad_input, @@ -190,6 +194,7 @@ void Conv2dFunctor::operator()(const Tensor *input, extra_input_width, input_channels, channels, + WINOGRAD_OUT_TILE_SIZE, transformed_input_.mutable_data(), transformed_filter_.mutable_data(), transformed_output_.mutable_data(), diff --git a/mace/kernels/arm/conv_winograd.cc b/mace/kernels/arm/conv_winograd.cc index 15c79a7d7bc6f376c3486f693bde26cc83794914..272e3c5227645ff228113c6d226765617aee724b 100644 --- a/mace/kernels/arm/conv_winograd.cc +++ b/mace/kernels/arm/conv_winograd.cc @@ -8,19 +8,20 @@ #include "mace/kernels/arm/conv_winograd.h" #include "mace/kernels/gemm.h" #include "mace/utils/utils.h" +#include "mace/utils/logging.h" namespace mace { namespace kernels { namespace { // NCHW => TNCB (T: in tile pixels, B: tile indices) -void TransformInput(const float *input, - const index_t batch, - const index_t in_height, - const index_t in_width, - const index_t in_channels, - const index_t tile_count, - float *output) { +void TransformInput4x4(const float *input, + const index_t batch, + const index_t in_height, + const index_t in_width, + const index_t in_channels, + const index_t tile_count, + float *output) { const index_t stride = batch * in_channels * tile_count; const index_t in_height_width = in_height * in_width; @@ -101,12 +102,124 @@ void TransformInput(const float *input, } } +// NCHW => TNCB (T: in tile pixels, B: tile indices) +/** + * BT = +⎡1 0 -21/4 0 21/4 0 -1 0⎤ +⎢ ⎥ +⎢0 1 1 -17/4 -17/4 1 1 0⎥ +⎢ ⎥ +⎢0 -1 1 17/4 -17/4 -1 1 0⎥ +⎢ ⎥ +⎢0 1/2 1/4 -5/2 -5/4 2 1 0⎥ +⎢ ⎥ +⎢0 -1/2 1/4 5/2 -5/4 -2 1 0⎥ +⎢ ⎥ +⎢0 2 4 -5/2 -5 1/2 1 0⎥ +⎢ ⎥ +⎢0 -2 4 5/2 -5 -1/2 1 0⎥ +⎢ ⎥ +⎣0 -1 0 21/4 0 -21/4 0 1⎦ + + * @param input + * @param batch + * @param in_height + * @param in_width + * @param in_channels + * @param tile_count + * @param output + */ +void TransformInput8x8(const float *input, + const index_t batch, + const index_t in_height, + const index_t in_width, + const index_t in_channels, + const index_t tile_count, + float *output) { + const index_t stride = batch * in_channels * tile_count; + const index_t in_height_width = in_height * in_width; + +#pragma omp parallel for + for (index_t nc = 0; nc < batch * in_channels; ++nc) { + index_t tile_index = nc * tile_count; + float s[8][8]; + for (index_t h = 0; h < in_height - 2; h += 6) { + for (index_t w = 0; w < in_width - 2; w += 6) { + index_t tile_offset = nc * in_height_width + h * in_width + w; + for (int i = 0; i < 8; ++i) { + float d0, d1, d2, d3, d4, d5, d6, d7; + d0 = input[tile_offset]; + d1 = input[tile_offset + 1]; + d2 = input[tile_offset + 2]; + d3 = input[tile_offset + 3]; + d4 = input[tile_offset + 4]; + d5 = input[tile_offset + 5]; + d6 = input[tile_offset + 6]; + d7 = input[tile_offset + 7]; + + s[i][0] = d0 - d6 + (d4 - d2) * 5.25; + s[i][7] = d7 - d1 + (d3 - d5) * 5.25; + + float u = d2 + d6 - d4 * 4.25; + float v = d1 + d5 - d3 * 4.25; + s[i][1] = u + v; + s[i][2] = u - v; + + u = d6 + d2 * 0.25 - d4 * 1.25; + v = d1 * 0.5 - d3 * 2.5 + d5 * 2; + s[i][3] = u + v; + s[i][4] = u - v; + + u = d6 + (d2 - d4 * 1.25) * 4; + v = d1 * 2 - d3 * 2.5 + d5 * 0.5; + s[i][5] = u + v; + s[i][6] = u - v; + + tile_offset += in_width; + } + + for (int i = 0; i < 8; ++i) { + float d0, d1, d2, d3, d4, d5, d6, d7; + d0 = s[0][i]; + d1 = s[1][i]; + d2 = s[2][i]; + d3 = s[3][i]; + d4 = s[4][i]; + d5 = s[5][i]; + d6 = s[6][i]; + d7 = s[7][i]; + + output[tile_index + i * stride] = d0 - d6 + (d4 - d2) * 5.25; + output[tile_index + (56 + i) * stride] = d7 - d1 + (d3 - d5) * 5.25; + + float u = d2 + d6 - d4 * 4.25; + float v = d1 + d5 - d3 * 4.25; + output[tile_index + (8 + i) * stride] = u + v; + output[tile_index + (16 + i) * stride] = u - v; + + u = d6 + d2 * 0.25 - d4 * 1.25; + v = d1 * 0.5 - d3 * 2.5 + d5 * 2; + output[tile_index + (24 + i) * stride] = u + v; + output[tile_index + (32 + i) * stride] = u - v; + + u = d6 + (d2 - d4 * 1.25) * 4; + v = d1 * 2 - d3 * 2.5 + d5 * 0.5; + output[tile_index + (40 + i) * stride] = u + v; + output[tile_index + (48 + i) * stride] = u - v; + } + + ++tile_index; + } + } + } +} + // OCHW => TOC // no need to optimize, it will exist in converter -void TransformFilter(const float *filter, - const index_t in_channels, - const index_t out_channels, - float *output) { +void TransformFilter4x4(const float *filter, + const index_t in_channels, + const index_t out_channels, + float *output) { const index_t stride = out_channels * in_channels; #pragma omp parallel for collapse(2) @@ -171,6 +284,83 @@ void TransformFilter(const float *filter, } } +// OCHW => TOC +// no need to optimize, it will exist in converter +/** + * G = +⎡ 1 0 0 ⎤ +⎢ ⎥ +⎢-2/9 -2/9 -2/9 ⎥ +⎢ ⎥ +⎢-2/9 2/9 -2/9 ⎥ +⎢ ⎥ +⎢1/90 1/45 2/45 ⎥ +⎢ ⎥ +⎢1/90 -1/45 2/45 ⎥ +⎢ ⎥ +⎢1/45 1/90 1/180⎥ +⎢ ⎥ +⎢1/45 -1/90 1/180⎥ +⎢ ⎥ +⎣ 0 0 1 ⎦ + * + * @param filter + * @param in_channels + * @param out_channels + * @param output + */ +void TransformFilter8x8(const float *filter, + const index_t in_channels, + const index_t out_channels, + float *output) { + const index_t stride = out_channels * in_channels; + + const float G[8][3] = { + {1.0f, 0.0f, 0.0f}, + {-2.0f / 9, -2.0f / 9, -2.0f / 9}, + {-2.0f / 9, 2.0f / 9, -2.0f / 9}, + {1.0f / 90, 1.0f / 45, 2.0f / 45}, + {1.0f / 90, -1.0f / 45, 2.0f / 45}, + {1.0f / 45, 1.0f / 90, 1.0f / 180}, + {1.0f / 45, -1.0f / 90, 1.0f / 180}, + {0.0f, 0.0f, 1.0f} + }; + +#pragma omp parallel for collapse(2) + for (index_t m = 0; m < out_channels; ++m) { + for (index_t c = 0; c < in_channels; ++c) { + // load filter + index_t filter_offset = (m * in_channels + c) * 9; + float g0, g1, g2, g3, g4, g5, g6, g7, g8; + g0 = filter[filter_offset]; + g1 = filter[filter_offset + 1]; + g2 = filter[filter_offset + 2]; + g3 = filter[filter_offset + 3]; + g4 = filter[filter_offset + 4]; + g5 = filter[filter_offset + 5]; + g6 = filter[filter_offset + 6]; + g7 = filter[filter_offset + 7]; + g8 = filter[filter_offset + 8]; + + float s[3][8]; + for (int i = 0; i < 8; ++i) { + s[0][i] = g0 * G[i][0] + g1 * G[i][1] + g2 * G[i][2]; + s[1][i] = g3 * G[i][0] + g4 * G[i][1] + g5 * G[i][2]; + s[2][i] = g6 * G[i][0] + g7 * G[i][1] + g8 * G[i][2]; + } + + // store output + index_t output_offset = m * in_channels + c; + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + output[output_offset + (i * 8 + j) * stride] = + G[i][0] * s[0][j] + G[i][1] * s[1][j] + G[i][2] * s[2][j]; + } + } + } + } +} + // TOC * TNCB => TNOB void BatchGemm(const float *input, const float *filter, @@ -178,17 +368,24 @@ void BatchGemm(const float *input, index_t in_channels, index_t out_channels, index_t tile_count, + int out_tile_size, float *output) { const index_t in_stride = batch * in_channels * tile_count; const index_t in_channels_tile_count = in_channels * tile_count; const index_t filter_stride = out_channels * in_channels; const index_t out_stride = batch * out_channels * tile_count; const index_t out_channels_tile_count = out_channels * tile_count; - + const int in_tile_area = (out_tile_size + 2) * (out_tile_size + 2); if (batch == 1) { - Gemm(filter, input, 16, out_channels, in_channels, tile_count, output); + Gemm(filter, + input, + in_tile_area, + out_channels, + in_channels, + tile_count, + output); } else { - for (int i = 0; i < 16; ++i) { + for (int i = 0; i < in_tile_area; ++i) { for (int b = 0; b < batch; ++b) { const float *in_ptr = input + i * in_stride + b * in_channels_tile_count; @@ -207,13 +404,13 @@ void BatchGemm(const float *input, } // TNOB => ToNOB => NOHoWo -void TransformOutput(const float *input, - index_t batch, - index_t out_height, - index_t out_width, - index_t out_channels, - index_t tile_count, - float *output) { +void TransformOutput4x4(const float *input, + index_t batch, + index_t out_height, + index_t out_width, + index_t out_channels, + index_t tile_count, + float *output) { const index_t in_stride = batch * out_channels * tile_count; #pragma omp parallel for @@ -271,6 +468,107 @@ void TransformOutput(const float *input, } } } + +// TNOB => ToNOB => NOHoWo +/** + * AT = +⎡1 1 1 1 1 32 32 0⎤ +⎢ ⎥ +⎢0 1 -1 2 -2 16 -16 0⎥ +⎢ ⎥ +⎢0 1 1 4 4 8 8 0⎥ +⎢ ⎥ +⎢0 1 -1 8 -8 4 -4 0⎥ +⎢ ⎥ +⎢0 1 1 16 16 2 2 0⎥ +⎢ ⎥ +⎣0 1 -1 32 -32 1 -1 1⎦ + * + * @param input + * @param batch + * @param out_height + * @param out_width + * @param out_channels + * @param tile_count + * @param output + */ +void TransformOutput8x8(const float *input, + index_t batch, + index_t out_height, + index_t out_width, + index_t out_channels, + index_t tile_count, + float *output) { + const index_t in_stride = batch * out_channels * tile_count; + +#pragma omp parallel for + for (index_t nm = 0; nm < batch * out_channels; ++nm) { + index_t tile_offset = nm * tile_count; + float s[8][6]; + for (index_t h = 0; h < out_height; h += 6) { + for (index_t w = 0; w < out_width; w += 6) { + index_t tile_offset_tmp = tile_offset; + for (int i = 0; i < 8; ++i) { + float d0, d1, d2, d3, d4, d5, d6, d7; + d0 = input[tile_offset_tmp + 0 * in_stride]; + d1 = input[tile_offset_tmp + 1 * in_stride]; + d2 = input[tile_offset_tmp + 2 * in_stride]; + d3 = input[tile_offset_tmp + 3 * in_stride]; + d4 = input[tile_offset_tmp + 4 * in_stride]; + d5 = input[tile_offset_tmp + 5 * in_stride]; + d6 = input[tile_offset_tmp + 6 * in_stride]; + d7 = input[tile_offset_tmp + 7 * in_stride]; + + float u = d1 + d2; + float v = d1 - d2; + float w = d3 + d4; + float x = d3 - d4; + float y = d5 + d6; + float z = d5 - d6; + + s[i][0] = d0 + u + w + y * 32; + s[i][1] = v + x + x + z * 16; + s[i][2] = u + w * 4 + y * 8; + s[i][3] = v + x * 8 + z * 4; + s[i][4] = u + w * 16 + y + y; + s[i][5] = v + x * 32 + z + d7; + + tile_offset_tmp += 8 * in_stride; + } + + index_t out_offset = nm * out_height * out_width + h * out_width + w; + + for (int i = 0; i < 6; ++i) { + float d0, d1, d2, d3, d4, d5, d6, d7; + d0 = s[0][i]; + d1 = s[1][i]; + d2 = s[2][i]; + d3 = s[3][i]; + d4 = s[4][i]; + d5 = s[5][i]; + d6 = s[6][i]; + d7 = s[7][i]; + + float u = d1 + d2; + float v = d1 - d2; + float w = d3 + d4; + float x = d3 - d4; + float y = d5 + d6; + float z = d5 - d6; + + output[out_offset + 0 * out_width + i] = d0 + u + w + y * 32; + output[out_offset + 1 * out_width + i] = v + x + x + z * 16; + output[out_offset + 2 * out_width + i] = u + w * 4 + y * 8; + output[out_offset + 3 * out_width + i] = v + x * 8 + z * 4; + output[out_offset + 4 * out_width + i] = u + w * 16 + y + y; + output[out_offset + 5 * out_width + i] = v + x * 32 + z + d7; + } + + ++tile_offset; + } + } + } +} } // namespace void WinoGradConv3x3s1(const float *input, @@ -280,6 +578,7 @@ void WinoGradConv3x3s1(const float *input, const index_t in_width, const index_t in_channels, const index_t out_channels, + const int out_tile_size, float *transformed_input, float *transformed_filter, float *transformed_output, @@ -287,22 +586,52 @@ void WinoGradConv3x3s1(const float *input, float *output) { index_t out_height = in_height - 2; index_t out_width = in_width - 2; - index_t tile_height_count = (out_height + 1) / 2; - index_t tile_width_count = (out_width + 1) / 2; + index_t tile_height_count = + RoundUpDiv(out_height, static_cast(out_tile_size)); + index_t tile_width_count = + RoundUpDiv(out_width, static_cast(out_tile_size)); index_t tile_count = tile_height_count * tile_width_count; - TransformInput(input, - batch, - in_height, - in_width, - in_channels, - tile_count, - transformed_input); + switch (out_tile_size) { + case 2: + TransformInput4x4(input, + batch, + in_height, + in_width, + in_channels, + tile_count, + transformed_input); + break; + case 6: + TransformInput8x8(input, + batch, + in_height, + in_width, + in_channels, + tile_count, + transformed_input); + break; + default:MACE_NOT_IMPLEMENTED; + } // TODO(liyin): put it in model converter, but do not worry, it is fast and // will only do once if (!is_filter_transformed) { - TransformFilter(filter, in_channels, out_channels, transformed_filter); + switch (out_tile_size) { + case 2: + TransformFilter4x4(filter, + in_channels, + out_channels, + transformed_filter); + break; + case 6: + TransformFilter8x8(filter, + in_channels, + out_channels, + transformed_filter); + break; + default:MACE_NOT_IMPLEMENTED; + } } BatchGemm(transformed_input, @@ -311,15 +640,30 @@ void WinoGradConv3x3s1(const float *input, in_channels, out_channels, tile_count, + out_tile_size, transformed_output); - TransformOutput(transformed_output, - batch, - out_height, - out_width, - out_channels, - tile_count, - output); + switch (out_tile_size) { + case 2: + TransformOutput4x4(transformed_output, + batch, + out_height, + out_width, + out_channels, + tile_count, + output); + break; + case 6: + TransformOutput8x8(transformed_output, + batch, + out_height, + out_width, + out_channels, + tile_count, + output); + break; + default:MACE_NOT_IMPLEMENTED; + } } void WinoGradConv3x3s1(const float *input, @@ -329,16 +673,21 @@ void WinoGradConv3x3s1(const float *input, const index_t in_width, const index_t in_channels, const index_t out_channels, + const int out_tile_size, float *output) { index_t out_height = in_height - 2; index_t out_width = in_width - 2; - index_t tile_height_count = (out_height + 1) / 2; - index_t tile_width_count = (out_width + 1) / 2; + index_t tile_height_count = + RoundUpDiv(out_height, static_cast(out_tile_size)); + index_t tile_width_count = + RoundUpDiv(out_width, static_cast(out_tile_size)); index_t tile_count = tile_height_count * tile_width_count; - - index_t transformed_input_size = 16 * batch * in_channels * tile_count; - index_t transformed_filter_size = 16 * out_channels * in_channels; - index_t transformed_output_size = 16 * batch * out_channels * tile_count; + index_t in_tile_area = (out_tile_size + 2) * (out_tile_size + 2); + index_t transformed_input_size = + in_tile_area * batch * in_channels * tile_count; + index_t transformed_filter_size = in_tile_area * out_channels * in_channels; + index_t + transformed_output_size = in_tile_area * batch * out_channels * tile_count; float *transformed_input = new float[transformed_input_size]; // TNCB float *transformed_filter = new float[transformed_filter_size]; // TOC @@ -351,6 +700,7 @@ void WinoGradConv3x3s1(const float *input, in_width, in_channels, out_channels, + out_tile_size, transformed_input, transformed_filter, transformed_output, @@ -362,7 +712,6 @@ void WinoGradConv3x3s1(const float *input, delete[]transformed_output; } - void ConvRef3x3s1(const float *input, const float *filter, const index_t batch, @@ -391,7 +740,8 @@ void ConvRef3x3s1(const float *input, ((b * in_channels + c) * in_height + ih) * in_width + iw; index_t filter_offset = (((m * in_channels) + c) * 3 + kh) * 3 + kw; - output[out_offset] += input[in_offset] * filter[filter_offset]; + output[out_offset] += + input[in_offset] * filter[filter_offset]; } } } diff --git a/mace/kernels/arm/conv_winograd.h b/mace/kernels/arm/conv_winograd.h index 0b288dd158a1eaace3ffa34ce1b3891ca8f90acc..d058a29c159f07ddb92a984f6e93ba38926c3f62 100644 --- a/mace/kernels/arm/conv_winograd.h +++ b/mace/kernels/arm/conv_winograd.h @@ -21,6 +21,7 @@ void WinoGradConv3x3s1(const float *input, const index_t in_width, const index_t in_channels, const index_t out_channels, + const int out_tile_size, float *output); void WinoGradConv3x3s1(const float *input, @@ -30,6 +31,7 @@ void WinoGradConv3x3s1(const float *input, const index_t in_width, const index_t in_channels, const index_t out_channels, + const int out_tile_size, float *transformed_input, float *transformed_filter, float *transformed_output, diff --git a/mace/kernels/arm/conv_winograd_test.cc b/mace/kernels/arm/conv_winograd_test.cc index 52be053bd9e5cb2d57a5c2d3c1b1fce322752996..4cb591ec1b6ae422fa963ec1505aed8df2dc874c 100644 --- a/mace/kernels/arm/conv_winograd_test.cc +++ b/mace/kernels/arm/conv_winograd_test.cc @@ -58,11 +58,12 @@ TEST(ConvWinogradTest, winograd) { in_width, in_channels, out_channels, + 6, output_data); // test for (index_t i = 0; i < output_size; ++i) { - EXPECT_NEAR(output_data_ref[i], output_data[i], 0.1); + EXPECT_NEAR(output_data_ref[i], output_data[i], 0.1) << " with index " << i; } delete[]input_data;