From 8a480d62033bf3a91a70c0ee3b10fa40c8281c83 Mon Sep 17 00:00:00 2001 From: wuchenghui Date: Wed, 18 Apr 2018 11:29:10 +0800 Subject: [PATCH] move cpu winograd filter transform to graph converter --- mace/kernels/arm/conv_2d.cc | 79 ++++-- mace/kernels/arm/conv_winograd.cc | 339 +++++++++++------------ mace/kernels/arm/conv_winograd.h | 14 +- mace/kernels/conv_2d.h | 7 +- mace/ops/conv_2d.h | 2 + mace/ops/fused_conv_2d.h | 2 + mace/python/tools/caffe_converter_lib.py | 126 ++++++--- mace/python/tools/tf_converter_lib.py | 78 +++++- 8 files changed, 389 insertions(+), 258 deletions(-) diff --git a/mace/kernels/arm/conv_2d.cc b/mace/kernels/arm/conv_2d.cc index 18050433..3b99890e 100644 --- a/mace/kernels/arm/conv_2d.cc +++ b/mace/kernels/arm/conv_2d.cc @@ -15,10 +15,6 @@ #include "mace/kernels/conv_2d.h" #include "mace/kernels/arm/conv_winograd.h" -// winograd is always superior to neon impl during benchmark -#define USE_WINOGRAD 1 -#define WINOGRAD_OUT_TILE_SIZE 6 - namespace mace { namespace kernels { @@ -109,11 +105,21 @@ void Conv2dFunctor::operator()(const Tensor *input, MACE_CHECK_NOTNULL(filter); MACE_CHECK_NOTNULL(output); + std::vector filter_shape(4); + if (is_filter_transformed_) { + // TOC -> OIHW + filter_shape[0] = filter->dim(1); + filter_shape[1] = filter->dim(2); + filter_shape[2] = filter_shape[3] = 3; + } else { + filter_shape = filter->shape(); + } + std::vector output_shape(4); std::vector paddings(2); if (paddings_.empty()) { CalcNCHWPaddingAndOutputSize(input->shape().data(), - filter->shape().data(), + filter_shape.data(), dilations_, strides_, padding_type_, @@ -121,7 +127,7 @@ void Conv2dFunctor::operator()(const Tensor *input, paddings.data()); } else { paddings = paddings_; - CalcNCHWOutputSize(input->shape().data(), filter->shape().data(), + CalcNCHWOutputSize(input->shape().data(), filter_shape.data(), paddings_.data(), dilations_, strides_, RoundType::FLOOR, output_shape.data()); } @@ -138,10 +144,10 @@ void Conv2dFunctor::operator()(const Tensor *input, index_t input_height = input->dim(2); index_t input_width = input->dim(3); - index_t filter_h = filter->dim(2); - index_t filter_w = filter->dim(3); - MACE_CHECK(filter->dim(0) == channels, filter->dim(0), " != ", channels); - MACE_CHECK(filter->dim(1) == input_channels, filter->dim(1), " != ", + index_t filter_h = filter_shape[2]; + index_t filter_w = filter_shape[3]; + MACE_CHECK(filter_shape[0] == channels, filter_shape[0], " != ", channels); + MACE_CHECK(filter_shape[1] == input_channels, filter_shape[1], " != ", input_channels); index_t stride_h = strides_[0]; @@ -171,9 +177,9 @@ void Conv2dFunctor::operator()(const Tensor *input, std::function conv_func; - bool use_winograd = USE_WINOGRAD && filter_h == 3 && filter_w == 3 + bool use_winograd = is_filter_transformed_ || (filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1 - && input_channels >= 8 && channels >= 8; + && input_channels >= 8 && channels >= 8); bool use_neon_3x3_s1 = filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; bool use_neon_3x3_s2 = filter_h == 3 && filter_w == 3 @@ -185,10 +191,17 @@ void Conv2dFunctor::operator()(const Tensor *input, std::vector transformed_output_shape; std::vector transformed_filter_shape; + // When size of input feature map is bigger than 16x16, + // set winograd out tile size to 6 to get higher performance. + index_t winograd_out_tile_size = 2; + if (input_height > 16 && input_width > 16) { + winograd_out_tile_size = 6; + } + if (use_winograd) { - extra_output_height = RoundUp(height, WINOGRAD_OUT_TILE_SIZE); + extra_output_height = RoundUp(height, winograd_out_tile_size); extra_input_height = std::max(padded_input_height, extra_output_height + 2); - extra_output_width = RoundUp(width, WINOGRAD_OUT_TILE_SIZE); + extra_output_width = RoundUp(width, winograd_out_tile_size); extra_input_width = std::max(padded_input_width, extra_output_width + 2); if (extra_input_height != padded_input_height) { pad_bottom += (extra_input_height - padded_input_height); @@ -197,11 +210,11 @@ void Conv2dFunctor::operator()(const Tensor *input, pad_right += (extra_input_width - padded_input_width); } - index_t tile_height_count = extra_output_height / WINOGRAD_OUT_TILE_SIZE; - index_t tile_width_count = extra_output_width / WINOGRAD_OUT_TILE_SIZE; + index_t tile_height_count = extra_output_height / winograd_out_tile_size; + index_t tile_width_count = extra_output_width / winograd_out_tile_size; index_t tile_count = tile_height_count * tile_width_count; index_t in_tile_area = - (WINOGRAD_OUT_TILE_SIZE + 2) * (WINOGRAD_OUT_TILE_SIZE + 2); + (winograd_out_tile_size + 2) * (winograd_out_tile_size + 2); transformed_input_shape.insert(transformed_input_shape.end(), {in_tile_area, batch, input_channels, @@ -281,25 +294,45 @@ void Conv2dFunctor::operator()(const Tensor *input, if (use_winograd) { transformed_input.Resize(transformed_input_shape); transformed_output.Resize(transformed_output_shape); - if (!is_filter_transformed_) { + const float *transformed_filter_ptr; + if (transformed_filter_.dim_size() == 0) { transformed_filter_.Resize(transformed_filter_shape); + if (is_filter_transformed_) { + transformed_filter_ptr = filter_data; + } else { + switch (winograd_out_tile_size) { + case 2: + TransformFilter4x4(filter_data, + filter_shape[1], + filter_shape[0], + transformed_filter_.mutable_data()); + break; + case 6: + TransformFilter8x8(filter_data, + filter_shape[1], + filter_shape[0], + transformed_filter_.mutable_data()); + break; + default:MACE_NOT_IMPLEMENTED; + } + transformed_filter_ptr = transformed_filter_.data(); + } + } else { + transformed_filter_ptr = transformed_filter_.data(); } conv_func = [&](const float *pad_input, float *pad_output) { WinoGradConv3x3s1(pad_input, - filter_data, + transformed_filter_ptr, batch, extra_input_height, extra_input_width, input_channels, channels, - WINOGRAD_OUT_TILE_SIZE, + winograd_out_tile_size, transformed_input.mutable_data(), - transformed_filter_.mutable_data(), transformed_output.mutable_data(), - is_filter_transformed_, pad_output); - is_filter_transformed_ = true; }; } else if (use_neon_3x3_s1) { conv_func = [=](const float *pad_input, float *pad_output) { diff --git a/mace/kernels/arm/conv_winograd.cc b/mace/kernels/arm/conv_winograd.cc index 11ef6947..e73061e3 100644 --- a/mace/kernels/arm/conv_winograd.cc +++ b/mace/kernels/arm/conv_winograd.cc @@ -224,153 +224,6 @@ void TransformInput8x8(const float *input, } } -// OCHW => TOC -// no need to optimize, it will exist in converter -void TransformFilter4x4(const float *filter, - const index_t in_channels, - const index_t out_channels, - float *output) { - const index_t stride = out_channels * in_channels; - -#pragma omp parallel for collapse(2) - for (index_t m = 0; m < out_channels; ++m) { - for (index_t c = 0; c < in_channels; ++c) { - float g0, g1, g2, g3, g4, g5, g6, g7, g8; - float s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, - s15; - - // load filter - index_t filter_offset = (m * in_channels + c) * 9; - g0 = filter[filter_offset]; - g1 = filter[filter_offset + 1]; - g2 = filter[filter_offset + 2]; - g3 = filter[filter_offset + 3]; - g4 = filter[filter_offset + 4]; - g5 = filter[filter_offset + 5]; - g6 = filter[filter_offset + 6]; - g7 = filter[filter_offset + 7]; - g8 = filter[filter_offset + 8]; - - // s = G * g * GT - s0 = g0; - s1 = (g0 + g2 + g1) * 0.5f; - s2 = (g0 + g2 - g1) * 0.5f; - s3 = g2; - s4 = (g0 + g6 + g3) * 0.5f; - s5 = ((g0 + g6 + g3) + (g2 + g8 + g5) + (g1 + g7 + g4)) * 0.25f; - s6 = ((g0 + g6 + g3) + (g2 + g8 + g5) - (g1 + g7 + g4)) * 0.25f; - s7 = (g2 + g8 + g5) * 0.5f; - s8 = (g0 + g6 - g3) * 0.5f; - s9 = ((g0 + g6 - g3) + (g2 + g8 - g5) + (g1 + g7 - g4)) * 0.25f; - s10 = ((g0 + g6 - g3) + (g2 + g8 - g5) - (g1 + g7 - g4)) * 0.25f; - s11 = (g2 + g8 - g5) * 0.5f; - s12 = g6; - s13 = (g6 + g8 + g7) * 0.5f; - s14 = (g6 + g8 - g7) * 0.5f; - s15 = g8; - - // store output - index_t output_offset = m * in_channels + c; - output[output_offset + 0 * stride] = s0; - output[output_offset + 1 * stride] = s1; - output[output_offset + 2 * stride] = s2; - output[output_offset + 3 * stride] = s3; - - output[output_offset + 4 * stride] = s4; - output[output_offset + 5 * stride] = s5; - output[output_offset + 6 * stride] = s6; - output[output_offset + 7 * stride] = s7; - - output[output_offset + 8 * stride] = s8; - output[output_offset + 9 * stride] = s9; - output[output_offset + 10 * stride] = s10; - output[output_offset + 11 * stride] = s11; - - output[output_offset + 12 * stride] = s12; - output[output_offset + 13 * stride] = s13; - output[output_offset + 14 * stride] = s14; - output[output_offset + 15 * stride] = s15; - } - } -} - -// OCHW => TOC -// no need to optimize, it will exist in converter -/** - * G = -⎡ 1 0 0 ⎤ -⎢ ⎥ -⎢-2/9 -2/9 -2/9 ⎥ -⎢ ⎥ -⎢-2/9 2/9 -2/9 ⎥ -⎢ ⎥ -⎢1/90 1/45 2/45 ⎥ -⎢ ⎥ -⎢1/90 -1/45 2/45 ⎥ -⎢ ⎥ -⎢1/45 1/90 1/180⎥ -⎢ ⎥ -⎢1/45 -1/90 1/180⎥ -⎢ ⎥ -⎣ 0 0 1 ⎦ - * - * @param filter - * @param in_channels - * @param out_channels - * @param output - */ -void TransformFilter8x8(const float *filter, - const index_t in_channels, - const index_t out_channels, - float *output) { - const index_t stride = out_channels * in_channels; - - const float G[8][3] = { - {1.0f, 0.0f, 0.0f}, - {-2.0f / 9, -2.0f / 9, -2.0f / 9}, - {-2.0f / 9, 2.0f / 9, -2.0f / 9}, - {1.0f / 90, 1.0f / 45, 2.0f / 45}, - {1.0f / 90, -1.0f / 45, 2.0f / 45}, - {1.0f / 45, 1.0f / 90, 1.0f / 180}, - {1.0f / 45, -1.0f / 90, 1.0f / 180}, - {0.0f, 0.0f, 1.0f} - }; - -#pragma omp parallel for collapse(2) - for (index_t m = 0; m < out_channels; ++m) { - for (index_t c = 0; c < in_channels; ++c) { - // load filter - index_t filter_offset = (m * in_channels + c) * 9; - float g0, g1, g2, g3, g4, g5, g6, g7, g8; - g0 = filter[filter_offset]; - g1 = filter[filter_offset + 1]; - g2 = filter[filter_offset + 2]; - g3 = filter[filter_offset + 3]; - g4 = filter[filter_offset + 4]; - g5 = filter[filter_offset + 5]; - g6 = filter[filter_offset + 6]; - g7 = filter[filter_offset + 7]; - g8 = filter[filter_offset + 8]; - - float s[3][8]; - for (int i = 0; i < 8; ++i) { - s[0][i] = g0 * G[i][0] + g1 * G[i][1] + g2 * G[i][2]; - s[1][i] = g3 * G[i][0] + g4 * G[i][1] + g5 * G[i][2]; - s[2][i] = g6 * G[i][0] + g7 * G[i][1] + g8 * G[i][2]; - } - - // store output - index_t output_offset = m * in_channels + c; - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - output[output_offset + (i * 8 + j) * stride] = - G[i][0] * s[0][j] + G[i][1] * s[1][j] + G[i][2] * s[2][j]; - } - } - } - } -} - // TOC * TNCB => TNOB void BatchGemm(const float *input, const float *filter, @@ -581,8 +434,156 @@ void TransformOutput8x8(const float *input, } } // namespace + +// OCHW => TOC +// no need to optimize, it will exist in converter +void TransformFilter4x4(const float *filter, + const index_t in_channels, + const index_t out_channels, + float *output) { + const index_t stride = out_channels * in_channels; + +#pragma omp parallel for collapse(2) + for (index_t m = 0; m < out_channels; ++m) { + for (index_t c = 0; c < in_channels; ++c) { + float g0, g1, g2, g3, g4, g5, g6, g7, g8; + float s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, + s15; + + // load filter + index_t filter_offset = (m * in_channels + c) * 9; + g0 = filter[filter_offset]; + g1 = filter[filter_offset + 1]; + g2 = filter[filter_offset + 2]; + g3 = filter[filter_offset + 3]; + g4 = filter[filter_offset + 4]; + g5 = filter[filter_offset + 5]; + g6 = filter[filter_offset + 6]; + g7 = filter[filter_offset + 7]; + g8 = filter[filter_offset + 8]; + + // s = G * g * GT + s0 = g0; + s1 = (g0 + g2 + g1) * 0.5f; + s2 = (g0 + g2 - g1) * 0.5f; + s3 = g2; + s4 = (g0 + g6 + g3) * 0.5f; + s5 = ((g0 + g6 + g3) + (g2 + g8 + g5) + (g1 + g7 + g4)) * 0.25f; + s6 = ((g0 + g6 + g3) + (g2 + g8 + g5) - (g1 + g7 + g4)) * 0.25f; + s7 = (g2 + g8 + g5) * 0.5f; + s8 = (g0 + g6 - g3) * 0.5f; + s9 = ((g0 + g6 - g3) + (g2 + g8 - g5) + (g1 + g7 - g4)) * 0.25f; + s10 = ((g0 + g6 - g3) + (g2 + g8 - g5) - (g1 + g7 - g4)) * 0.25f; + s11 = (g2 + g8 - g5) * 0.5f; + s12 = g6; + s13 = (g6 + g8 + g7) * 0.5f; + s14 = (g6 + g8 - g7) * 0.5f; + s15 = g8; + + // store output + index_t output_offset = m * in_channels + c; + output[output_offset + 0 * stride] = s0; + output[output_offset + 1 * stride] = s1; + output[output_offset + 2 * stride] = s2; + output[output_offset + 3 * stride] = s3; + + output[output_offset + 4 * stride] = s4; + output[output_offset + 5 * stride] = s5; + output[output_offset + 6 * stride] = s6; + output[output_offset + 7 * stride] = s7; + + output[output_offset + 8 * stride] = s8; + output[output_offset + 9 * stride] = s9; + output[output_offset + 10 * stride] = s10; + output[output_offset + 11 * stride] = s11; + + output[output_offset + 12 * stride] = s12; + output[output_offset + 13 * stride] = s13; + output[output_offset + 14 * stride] = s14; + output[output_offset + 15 * stride] = s15; + } + } +} + +// OCHW => TOC +// no need to optimize, it will exist in converter +/** + * G = +⎡ 1 0 0 ⎤ +⎢ ⎥ +⎢-2/9 -2/9 -2/9 ⎥ +⎢ ⎥ +⎢-2/9 2/9 -2/9 ⎥ +⎢ ⎥ +⎢1/90 1/45 2/45 ⎥ +⎢ ⎥ +⎢1/90 -1/45 2/45 ⎥ +⎢ ⎥ +⎢1/45 1/90 1/180⎥ +⎢ ⎥ +⎢1/45 -1/90 1/180⎥ +⎢ ⎥ +⎣ 0 0 1 ⎦ + * + * @param filter + * @param in_channels + * @param out_channels + * @param output + */ +void TransformFilter8x8(const float *filter, + const index_t in_channels, + const index_t out_channels, + float *output) { + const index_t stride = out_channels * in_channels; + + const float G[8][3] = { + {1.0f, 0.0f, 0.0f}, + {-2.0f / 9, -2.0f / 9, -2.0f / 9}, + {-2.0f / 9, 2.0f / 9, -2.0f / 9}, + {1.0f / 90, 1.0f / 45, 2.0f / 45}, + {1.0f / 90, -1.0f / 45, 2.0f / 45}, + {1.0f / 45, 1.0f / 90, 1.0f / 180}, + {1.0f / 45, -1.0f / 90, 1.0f / 180}, + {0.0f, 0.0f, 1.0f} + }; + +#pragma omp parallel for collapse(2) + for (index_t m = 0; m < out_channels; ++m) { + for (index_t c = 0; c < in_channels; ++c) { + // load filter + index_t filter_offset = (m * in_channels + c) * 9; + float g0, g1, g2, g3, g4, g5, g6, g7, g8; + g0 = filter[filter_offset]; + g1 = filter[filter_offset + 1]; + g2 = filter[filter_offset + 2]; + g3 = filter[filter_offset + 3]; + g4 = filter[filter_offset + 4]; + g5 = filter[filter_offset + 5]; + g6 = filter[filter_offset + 6]; + g7 = filter[filter_offset + 7]; + g8 = filter[filter_offset + 8]; + + float s[3][8]; + for (int i = 0; i < 8; ++i) { + s[0][i] = g0 * G[i][0] + g1 * G[i][1] + g2 * G[i][2]; + s[1][i] = g3 * G[i][0] + g4 * G[i][1] + g5 * G[i][2]; + s[2][i] = g6 * G[i][0] + g7 * G[i][1] + g8 * G[i][2]; + } + + // store output + index_t output_offset = m * in_channels + c; + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + output[output_offset + (i * 8 + j) * stride] = + G[i][0] * s[0][j] + G[i][1] * s[1][j] + G[i][2] * s[2][j]; + } + } + } + } +} + void WinoGradConv3x3s1(const float *input, - const float *filter, + const float *transformed_filter, const index_t batch, const index_t in_height, const index_t in_width, @@ -590,9 +591,7 @@ void WinoGradConv3x3s1(const float *input, const index_t out_channels, const int out_tile_size, float *transformed_input, - float *transformed_filter, float *transformed_output, - bool is_filter_transformed, float *output) { index_t out_height = in_height - 2; index_t out_width = in_width - 2; @@ -624,26 +623,6 @@ void WinoGradConv3x3s1(const float *input, default:MACE_NOT_IMPLEMENTED; } - // TODO(liyin): put it in model converter, but do not worry, it is fast and - // will only do once - if (!is_filter_transformed) { - switch (out_tile_size) { - case 2: - TransformFilter4x4(filter, - in_channels, - out_channels, - transformed_filter); - break; - case 6: - TransformFilter8x8(filter, - in_channels, - out_channels, - transformed_filter); - break; - default:MACE_NOT_IMPLEMENTED; - } - } - BatchGemm(transformed_input, transformed_filter, batch, @@ -703,8 +682,24 @@ void WinoGradConv3x3s1(const float *input, float *transformed_filter = new float[transformed_filter_size]; // TOC float *transformed_output = new float[transformed_output_size]; + switch (out_tile_size) { + case 2: + TransformFilter4x4(filter, + in_channels, + out_channels, + transformed_filter); + break; + case 6: + TransformFilter8x8(filter, + in_channels, + out_channels, + transformed_filter); + break; + default:MACE_NOT_IMPLEMENTED; + } + WinoGradConv3x3s1(input, - filter, + transformed_filter, batch, in_height, in_width, @@ -712,9 +707,7 @@ void WinoGradConv3x3s1(const float *input, out_channels, out_tile_size, transformed_input, - transformed_filter, transformed_output, - false, output); delete[]transformed_input; diff --git a/mace/kernels/arm/conv_winograd.h b/mace/kernels/arm/conv_winograd.h index de45aa08..558fea9d 100644 --- a/mace/kernels/arm/conv_winograd.h +++ b/mace/kernels/arm/conv_winograd.h @@ -24,6 +24,16 @@ namespace mace { namespace kernels { +void TransformFilter4x4(const float *filter, + const index_t in_channels, + const index_t out_channels, + float *output); + +void TransformFilter8x8(const float *filter, + const index_t in_channels, + const index_t out_channels, + float *output); + void WinoGradConv3x3s1(const float *input, const float *filter, const index_t batch, @@ -35,7 +45,7 @@ void WinoGradConv3x3s1(const float *input, float *output); void WinoGradConv3x3s1(const float *input, - const float *filter, + const float *transformed_filter, const index_t batch, const index_t in_height, const index_t in_width, @@ -43,9 +53,7 @@ void WinoGradConv3x3s1(const float *input, const index_t out_channels, const int out_tile_size, float *transformed_input, - float *transformed_filter, float *transformed_output, - bool is_filter_transformed, float *output); void ConvRef3x3s1(const float *input, diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index cefa520f..63f200f9 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -308,6 +308,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { const int *dilations, const ActivationType activation, const float relux_max_limit, + const bool is_filter_transformed, ScratchBuffer *scratch) : Conv2dFunctorBase(strides, padding_type, @@ -317,7 +318,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { relux_max_limit) {} void operator()(const Tensor *input, // NHWC - const Tensor *filter, // HWOI + const Tensor *filter, // HWOI or TOI const Tensor *bias, Tensor *output, StatsFuture *future) { @@ -434,6 +435,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { const int *dilations, const ActivationType activation, const float relux_max_limit, + const bool is_filter_transformed, ScratchBuffer *scratch) : Conv2dFunctorBase(strides, padding_type, @@ -441,7 +443,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { dilations, activation, relux_max_limit), - is_filter_transformed_(false), + is_filter_transformed_(is_filter_transformed), scratch_(scratch) {} void operator()(const Tensor *input, @@ -463,6 +465,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { const int *dilations, const ActivationType activation, const float relux_max_limit, + const bool is_filter_transformed, ScratchBuffer *scratch) : Conv2dFunctorBase(strides, padding_type, diff --git a/mace/ops/conv_2d.h b/mace/ops/conv_2d.h index 5f692143..690ef002 100644 --- a/mace/ops/conv_2d.h +++ b/mace/ops/conv_2d.h @@ -35,6 +35,8 @@ class Conv2dOp : public ConvPool2dOpBase { this->dilations_.data(), kernels::ActivationType::NOOP, 0.0f, + static_cast(OperatorBase::GetSingleArgument( + "is_filter_transformed", false)), ws->GetScratchBuffer(D)) {} bool Run(StatsFuture *future) override { diff --git a/mace/ops/fused_conv_2d.h b/mace/ops/fused_conv_2d.h index c58854cf..a2a255ef 100644 --- a/mace/ops/fused_conv_2d.h +++ b/mace/ops/fused_conv_2d.h @@ -38,6 +38,8 @@ class FusedConv2dOp : public ConvPool2dOpBase { OperatorBase::GetSingleArgument("activation", "NOOP")), OperatorBase::GetSingleArgument("max_limit", 0.0f), + static_cast(OperatorBase::GetSingleArgument( + "is_filter_transformed", false)), ws->GetScratchBuffer(D)) {} bool Run(StatsFuture *future) override { diff --git a/mace/python/tools/caffe_converter_lib.py b/mace/python/tools/caffe_converter_lib.py index 6f8a95cd..4086cc87 100644 --- a/mace/python/tools/caffe_converter_lib.py +++ b/mace/python/tools/caffe_converter_lib.py @@ -374,6 +374,10 @@ class CaffeConverter(object): return pad, stride, kernel def convert_conv2d(self, op): + use_winograd = False + if self.device == 'neon': + use_winograd = self.check_winograd_conv(op) + param = op.layer.convolution_param is_depthwise = False if param.HasField('group'): @@ -394,7 +398,11 @@ class CaffeConverter(object): else: # OIHW -> HWOI weight_data = op.data[0].transpose((2, 3, 0, 1)) - self.add_tensor(weight_tensor_name, weight_data) + + if self.device == 'neon' and use_winograd: + self.convert_winograd_conv_filter_neon(op, op_def) + else: + self.add_tensor(weight_tensor_name, weight_data) if self.device == 'gpu': buffer_type = "DW_CONV2D_FILTER" \ @@ -438,7 +446,7 @@ class CaffeConverter(object): op.output_shape_map[op.layer.top[0]] = output_shape if len(self.ops_map[final_op.name].children) == 1 and \ - self.ops_map[final_op.name].children[0].type \ + self.ops_map[final_op.name].children[0].type \ in activation_name_map: activation_op = self.ops_map[final_op.name].children[0] if not is_depthwise: @@ -455,15 +463,18 @@ class CaffeConverter(object): self.net_def.op.extend([op_def]) def check_winograd_conv(self, op): - # TODO: support winograd conv on neon - if self.device == 'neon': - return False param = op.layer.convolution_param filter_shape = np.asarray(op.data[0].shape) if self.device != 'neon': filter_shape = filter_shape[[2, 3, 0, 1]] # OIHW -> HWOI paddings, strides, _ = self.add_stride_pad_kernel_arg(param, None) + if param.HasField('group'): + if param.group == op.data[0].shape[0] and op.data[0].shape[1] == 1: + return False # Depthwise conv not support winograd + else: + raise Exception("Mace do not support group convolution yet") + dilations = [1, 1] if len(param.dilation) > 0: if len(param.dilation) == 1: @@ -476,23 +487,60 @@ class CaffeConverter(object): op.get_single_parent().output_shape_map[op.layer.bottom[0]], filter_shape, paddings, strides, dilations, math.floor, input_format) - width = output_shape[0] * ((output_shape[1] + 1) / 2) * (( - output_shape[2] + 1) / 2) if self.winograd and dilations[0] == 1 and \ (dilations[0] == dilations[1]) and \ (strides[0] == 1) and (strides[0] == strides[1]): if self.device == 'gpu': + width = output_shape[0] * ((output_shape[1] + 1) / 2) * \ + ((output_shape[2] + 1) / 2) return filter_shape[0] == 3 and \ - (filter_shape[0] == filter_shape[1]) and \ - (16 * filter_shape[2] < OPENCL_IMAGE_MAX_SIZE) and \ - (16 * filter_shape[3] < OPENCL_IMAGE_MAX_SIZE) and \ - (width < OPENCL_IMAGE_MAX_SIZE) + filter_shape[0] == filter_shape[1] and \ + (16 * filter_shape[2] < OPENCL_IMAGE_MAX_SIZE) and \ + (16 * filter_shape[3] < OPENCL_IMAGE_MAX_SIZE) and \ + (width < OPENCL_IMAGE_MAX_SIZE) elif self.device == 'neon': - return filter_shape[2] == 3 and ( - filter_shape[2] == filter_shape[3]) + return filter_shape[2] == 3 and \ + filter_shape[2] == filter_shape[3] and \ + filter_shape[0] >= 8 and filter_shape[1] >= 8 return False - def convert_winograd_conv(self, op): + def convert_winograd_conv_filter_neon(self, op, op_def): + # Add filter + weight_tensor_name = op.name + '_weight:0' + weight_data = op.data[0] # OIHW + input_shape = op.data[1].shape + if input_shape[2] > 16 and input_shape[3] > 16: + G = np.array([ + [1.0, 0.0, 0.0], + [-2.0 / 9, -2.0 / 9, -2.0 / 9], + [-2.0 / 9, 2.0 / 9, -2.0 / 9], + [1.0 / 90, 1.0 / 45, 2.0 / 45], + [1.0 / 90, -1.0 / 45, 2.0 / 45], + [1.0 / 45, 1.0 / 90, 1.0 / 180], + [1.0 / 45, -1.0 / 90, 1.0 / 180], + [0.0, 0.0, 1.0] + ], dtype=np.float32) + new_shape = [64, weight_data.shape[0], weight_data.shape[1]] # TOC + else: + G = np.array([ + [1.0, 0.0, 0.0], + [0.5, 0.5, 0.5], + [0.5, -0.5, 0.5], + [0.0, 0.0, 1.0], + ], dtype=np.float32) + new_shape = [16, weight_data.shape[0], weight_data.shape[1]] # TOC + new_weight_value = G.dot(weight_data).dot(G.T) # [8, O, I, 8] + new_weight_value = new_weight_value.transpose(0, 3, 1, 2) + new_weight_value = new_weight_value.reshape(new_shape) + + self.add_tensor(weight_tensor_name, new_weight_value) + + op_def.input.extend([weight_tensor_name]) + winograd_transformed_arg = op_def.arg.add() + winograd_transformed_arg.name = 'is_filter_transformed' + winograd_transformed_arg.i = 1 + + def convert_winograd_conv_gpu(self, op): # Add filter weight_tensor_name = op.name + '_weight:0' self.add_tensor(weight_tensor_name, op.data[0]) @@ -504,10 +552,8 @@ class CaffeConverter(object): paddings, strides, _ = self.add_stride_pad_kernel_arg(param, None) filter_shape = np.asarray(op.data[0].shape) - if self.device != 'neon': - filter_shape = filter_shape[[2, 3, 0, 1]] # OIHW -> HWOI - input_format = 'NCHW' if self.device == 'neon' else 'NHWC' + input_format = 'NHWC' output_shape = Shapes.conv_pool_shape( op.get_single_parent().output_shape_map[op.layer.bottom[0]], filter_shape, paddings, strides, [1, 1], math.floor, input_format) @@ -526,16 +572,10 @@ class CaffeConverter(object): wt_output_name = wt_op.name + ":0" wt_op.output.extend([wt_output_name]) wt_output_shape = mace_pb2.OutputShape() - if self.device != 'neon': - wt_output_width = output_shape[0] * (( - output_shape[1] + 1) / 2) * ((output_shape[2] + 1) / 2) - wt_output_shape.dims.extend( - [16, filter_shape[3], wt_output_width, 1]) - else: - wt_output_width = output_shape[0] * (( - output_shape[2] + 1) / 2) * ((output_shape[3] + 1) / 2) - wt_output_shape.dims.extend( - [16, filter_shape[1], wt_output_width, 1]) + wt_output_width = output_shape[0] * (( + output_shape[1] + 1) / 2) * ((output_shape[2] + 1) / 2) + wt_output_shape.dims.extend( + [16, filter_shape[3], wt_output_width, 1]) wt_op.output_shape.extend([wt_output_shape]) # MatMul @@ -549,12 +589,8 @@ class CaffeConverter(object): matmul_output_name = matmul_op.name + ":0" matmul_op.output.extend([matmul_output_name]) matmul_output_shape = mace_pb2.OutputShape() - if self.device != 'neon': - matmul_output_shape.dims.extend( - [16, filter_shape[2], wt_output_width, 1]) - else: - matmul_output_shape.dims.extend( - [16, filter_shape[0], wt_output_width, 1]) + matmul_output_shape.dims.extend( + [16, filter_shape[2], wt_output_width, 1]) matmul_op.output_shape.extend([matmul_output_shape]) # Inverse transform @@ -567,12 +603,10 @@ class CaffeConverter(object): batch_arg.i = output_shape[0] height_arg = iwt_op.arg.add() height_arg.name = 'height' - height_arg.i = output_shape[ - 1] if self.device != 'neon' else output_shape[2] + height_arg.i = output_shape[1] width_arg = iwt_op.arg.add() width_arg.name = 'width' - width_arg.i = output_shape[ - 2] if self.device != 'neon' else output_shape[3] + width_arg.i = output_shape[2] iwt_op.name = op.name + '_inverse_transform' iwt_op.type = 'WinogradInverseTransform' iwt_op.input.extend([matmul_output_name]) @@ -591,7 +625,7 @@ class CaffeConverter(object): self.resolved_ops.add(op.name) if len(self.ops_map[final_op.name].children) == 1 and \ - self.ops_map[final_op.name].children[0].type \ + self.ops_map[final_op.name].children[0].type \ in activation_name_map: activation_op = self.ops_map[final_op.name].children[0] fused_act_arg = iwt_op.arg.add() @@ -645,8 +679,8 @@ class CaffeConverter(object): output_shape = op.get_single_parent().output_shape_map[op.layer.bottom[ 0]] - if len(self.ops_map[final_op.name].children) == 1 \ - and self.ops_map[final_op.name].children[0].type \ + if len(self.ops_map[final_op.name].children) == 1 and \ + self.ops_map[final_op.name].children[0].type \ in activation_name_map: activation_op = self.ops_map[final_op.name].children[0] fused_act_arg = op_def.arg.add() @@ -727,13 +761,15 @@ class CaffeConverter(object): op_def.input.extend([bias_tensor_name]) self.resolved_ops.add(op.name) + input_format = 'NCHW' if self.device == 'neon' else 'NHWC' output_shape = Shapes.fully_connected_shape(input_shape, - weight_data.shape) + weight_data.shape, + input_format) op.output_shape_map[op.layer.top[0]] = output_shape final_op = op if len(self.ops_map[final_op.name].children) == 1 \ - and self.ops_map[final_op.name].children[0].type \ + and self.ops_map[final_op.name].children[0].type \ in activation_name_map: activation_op = self.ops_map[final_op.name].children[0] fused_act_arg = op_def.arg.add() @@ -764,7 +800,7 @@ class CaffeConverter(object): input_shape = op.get_single_parent().output_shape_map[op.layer.bottom[ 0]] if param.HasField('global_pooling') and param.global_pooling: - kernels = [input_shape[1], input_shape[2]] + kernels = [input_shape[2], input_shape[3]] kernel_arg = op_def.arg.add() kernel_arg.name = 'kernels' @@ -1054,8 +1090,8 @@ class CaffeConverter(object): if op.type == 'Input': self.resolved_ops.add(op.name) elif op.type == 'Convolution': - if self.check_winograd_conv(op): - self.convert_winograd_conv(op) + if self.device == 'gpu' and self.check_winograd_conv(op): + self.convert_winograd_conv_gpu(op) else: self.convert_conv2d(op) elif op.type == 'BatchNorm': diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 56b3f04d..c50766cb 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -257,15 +257,19 @@ class TFConverter(object): return False width = output_shape[0] * ((output_shape[1] + 1) / 2) * (( output_shape[2] + 1) / 2) - return self.winograd and op.type != 'DepthwiseConv2dNative' and \ - self.device == 'gpu' and filter_shape[0] == 3 and \ - (filter_shape[0] == filter_shape[1]) and \ - (strides[0] == 1) and (strides[0] == strides[1]) and \ - (16 * filter_shape[2] < OPENCL_IMAGE_MAX_SIZE) and \ - (16 * filter_shape[3] < OPENCL_IMAGE_MAX_SIZE) and \ - (width < OPENCL_IMAGE_MAX_SIZE) - - def convert_winograd_conv(self, op): + if self.winograd and op.type != 'DepthwiseConv2dNative' and \ + filter_shape[0] == 3 and \ + (filter_shape[0] == filter_shape[1]) and \ + (strides[0] == 1) and (strides[0] == strides[1]): + if self.device == 'gpu': + return (16 * filter_shape[2] < OPENCL_IMAGE_MAX_SIZE) and \ + (16 * filter_shape[3] < OPENCL_IMAGE_MAX_SIZE) and \ + (width < OPENCL_IMAGE_MAX_SIZE) + elif self.device == 'neon': + return filter_shape[2] >= 8 and filter_shape[3] >= 8 + return False + + def convert_winograd_conv_gpu(self, op): filter_tensor = get_input_tensor(op, 1) filter_shape = filter_tensor.shape.as_list() output_shape = op.outputs[0].shape.as_list() @@ -355,7 +359,55 @@ class TFConverter(object): self.add_output_shape(final_op.outputs, iwt_op) self.net_def.op.extend([wt_op, matmul_op, iwt_op]) + def convert_conv_winograd_filter_neon(self, op, op_def): + weight_tensor = get_input_tensor(op, 1) + weight_tensor_value = weight_tensor.eval().astype(np.float32) + input_shape = get_input_tensor(op, 0).shape.as_list() + output_channels = weight_tensor_value.shape[3] + input_channels = weight_tensor_value.shape[2] + # HWIO -> OIHW + weight_tensor_value = weight_tensor_value.transpose(3, 2, 0, 1) + if input_shape[2] > 16 and input_shape[3] > 16: + G = np.array([ + [1.0, 0.0, 0.0], + [-2.0 / 9, -2.0 / 9, -2.0 / 9], + [-2.0 / 9, 2.0 / 9, -2.0 / 9], + [1.0 / 90, 1.0 / 45, 2.0 / 45], + [1.0 / 90, -1.0 / 45, 2.0 / 45], + [1.0 / 45, 1.0 / 90, 1.0 / 180], + [1.0 / 45, -1.0 / 90, 1.0 / 180], + [0.0, 0.0, 1.0] + ], dtype=np.float32) + new_shape = [64, output_channels, input_channels] # TOC + else: + G = np.array([ + [1.0, 0.0, 0.0], + [0.5, 0.5, 0.5], + [0.5, -0.5, 0.5], + [0.0, 0.0, 1.0], + ], dtype=np.float32) + new_shape = [16, output_channels, input_channels] # TOC + new_weight_value = G.dot(weight_tensor_value).dot(G.T) # [t, O, I, t] + new_weight_value = new_weight_value.transpose(0, 3, 1, 2) + + new_weight_value = new_weight_value.reshape(new_shape) + new_tensor_name = weight_tensor.name[:-2] + '/winograd_transformed:0' + self.add_tensor(new_tensor_name, new_shape, + tf.float32, new_weight_value) + + winograd_transformed_arg = op_def.arg.add() + winograd_transformed_arg.name = 'is_filter_transformed' + winograd_transformed_arg.i = 1 + + self.unused_tensor.add(weight_tensor.name) + op_def.input.extend([op.inputs[0].name]) + op_def.input.extend([new_tensor_name]) + def convert_conv2d(self, op): + use_winograd = False + if self.device == 'neon': + use_winograd = self.check_winograd_conv(op) + op_def = mace_pb2.OperatorDef() arg = op_def.arg.add() arg.name = 'T' @@ -366,7 +418,7 @@ class TFConverter(object): else: op_def.type = op.type - if self.device == 'neon': + if self.device == 'neon' and not use_winograd: self.transpose_filter_tensor[get_input_tensor( op, 1).name] = (3, 2, 0, 1) elif op.type == 'Conv2D': @@ -381,6 +433,8 @@ class TFConverter(object): output_name = self.add_buffer_to_image( get_input_tensor(op, 1).name, buffer_type) op_def.input.extend([output_name]) + elif self.device == 'neon' and use_winograd: + self.convert_conv_winograd_filter_neon(op, op_def) else: op_def.input.extend( [get_input_tensor(op, i).name for i in range(len(op.inputs))]) @@ -1057,8 +1111,8 @@ class TFConverter(object): elif self.check_conv_to_fc(op): self.convert_global_conv_to_fc(op) elif op.type == 'Conv2D' or op.type == 'DepthwiseConv2dNative': - if self.check_winograd_conv(op): - self.convert_winograd_conv(op) + if self.device == 'gpu' and self.check_winograd_conv(op): + self.convert_winograd_conv_gpu(op) else: self.convert_conv2d(op) elif op.type == 'FusedBatchNorm': -- GitLab