From 961dddd9dddee20be4ebaaa20c15cf01b77ca749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Wed, 12 Sep 2018 19:21:11 +0800 Subject: [PATCH] Replace gemm to sgemm --- mace/core/buffer.h | 9 +++- mace/kernels/arm/conv_2d_neon.h | 5 ++- mace/kernels/arm/conv_2d_neon_1x1.cc | 21 ++++++--- mace/kernels/arm/conv_winograd.cc | 62 +++++++++++++++----------- mace/kernels/arm/conv_winograd.h | 9 +++- mace/kernels/arm/conv_winograd_test.cc | 3 +- mace/kernels/conv_2d.h | 19 +++++++- mace/kernels/matmul.h | 5 ++- mace/kernels/sgemm.cc | 2 - 9 files changed, 93 insertions(+), 42 deletions(-) diff --git a/mace/core/buffer.h b/mace/core/buffer.h index c57a1714..ba43e96c 100644 --- a/mace/core/buffer.h +++ b/mace/core/buffer.h @@ -469,6 +469,7 @@ class ScratchBuffer: public Buffer { MaceStatus GrowSize(index_t size) { if (size > size_) { + MACE_CHECK(offset_ == 0, "scratch is being used, cannot grow size"); return Resize(size); } return MaceStatus::MACE_SUCCESS; @@ -487,8 +488,12 @@ class ScratchBuffer: public Buffer { return slice; } - void Rewind() { - offset_ = 0; + void Rewind(index_t offset = 0) { + offset_ = offset; + } + + index_t offset() const { + return offset_; } private: diff --git a/mace/kernels/arm/conv_2d_neon.h b/mace/kernels/arm/conv_2d_neon.h index 7c7f7a77..bf0e1023 100644 --- a/mace/kernels/arm/conv_2d_neon.h +++ b/mace/kernels/arm/conv_2d_neon.h @@ -16,6 +16,7 @@ #define MACE_KERNELS_ARM_CONV_2D_NEON_H_ #include "mace/core/types.h" +#include "mace/kernels/sgemm.h" namespace mace { namespace kernels { @@ -27,7 +28,9 @@ void Conv2dNeonK1x1S1(const float *input, const index_t width, const index_t in_channels, const index_t out_channels, - float *output); + float *output, + SGemm *sgemm, + ScratchBuffer *scratch_buffer); void Conv2dNeonK3x3S1(const float *input, const float *filter, diff --git a/mace/kernels/arm/conv_2d_neon_1x1.cc b/mace/kernels/arm/conv_2d_neon_1x1.cc index 28aa6e4f..21554d90 100644 --- a/mace/kernels/arm/conv_2d_neon_1x1.cc +++ b/mace/kernels/arm/conv_2d_neon_1x1.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "mace/kernels/arm/conv_2d_neon.h" -#include "mace/kernels/gemm.h" namespace mace { namespace kernels { @@ -25,11 +24,23 @@ void Conv2dNeonK1x1S1(const float *input, const index_t width, const index_t in_channels, const index_t out_channels, - float *output) { + float *output, + SGemm *sgemm, + ScratchBuffer *scratch_buffer) { for (index_t b = 0; b < batch; ++b) { - Gemm(filter, input + b * in_channels * height * width, 1, out_channels, - in_channels, height * width, - output + b * out_channels * height * width); + sgemm->Run(filter, + input + b * in_channels * height * width, + 1, + out_channels, + in_channels, + in_channels, + height * width, + false, + false, + true, + false, + output + b * out_channels * height * width, + scratch_buffer); } } diff --git a/mace/kernels/arm/conv_winograd.cc b/mace/kernels/arm/conv_winograd.cc index b074ced7..d115e4e5 100644 --- a/mace/kernels/arm/conv_winograd.cc +++ b/mace/kernels/arm/conv_winograd.cc @@ -12,13 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include #include "mace/kernels/arm/conv_winograd.h" #include "mace/kernels/gemm.h" -#include "mace/utils/logging.h" -#include "mace/utils/utils.h" namespace mace { namespace kernels { @@ -247,30 +244,38 @@ void BatchGemm(const float *input, index_t out_channels, index_t tile_count, int out_tile_size, - float *output) { - const index_t filter_stride = out_channels * in_channels; + float *output, + SGemm *sgemm, + ScratchBuffer *scratch_buffer) { const int in_tile_area = (out_tile_size + 2) * (out_tile_size + 2); const index_t in_batch_size = in_tile_area * in_channels * tile_count; - const index_t in_stride = in_channels * tile_count; const index_t out_batch_size = in_tile_area * out_channels * tile_count; - const index_t out_stride = out_channels * tile_count; - if (batch == 1) { - Gemm(filter, input, in_tile_area, out_channels, in_channels, tile_count, - output); - } else { -#pragma omp parallel for collapse(2) - for (int b = 0; b < batch; ++b) { - for (int i = 0; i < in_tile_area; ++i) { - const float *in_ptr = input + b * in_batch_size + i * in_stride; - const float *filter_ptr = filter + i * filter_stride; - float *out_ptr = output + b * out_batch_size + i * out_stride; - Gemm(filter_ptr, in_ptr, 1, out_channels, /* rows */ - in_channels, /* K */ - tile_count, /* cols */ - out_ptr); - } + index_t scratch_buffer_offset = 0; + if (scratch_buffer) { + scratch_buffer_offset = scratch_buffer->offset(); + } + // 'batch' is not gemm batch, 'in_tile_area' is. gemm is not thread safe, + // so we loop batch using single thread. + // Scratch buffer should be rewind to the initial position to use same + // scratch memory for each batch. + for (int b = 0; b < batch; ++b) { + if (scratch_buffer) { + scratch_buffer->Rewind(scratch_buffer_offset); } + sgemm->Run(filter, + input + b * in_batch_size, + in_tile_area, + out_channels, + in_channels, + in_channels, + tile_count, + false, + false, + true, + false, + output + b * out_batch_size, + scratch_buffer); } } @@ -613,7 +618,9 @@ void WinoGradConv3x3s1(const float *input, const int out_tile_size, float *transformed_input, float *transformed_output, - float *output) { + float *output, + SGemm *sgemm, + ScratchBuffer *scratch_buffer) { index_t out_height = in_height - 2; index_t out_width = in_width - 2; index_t tile_height_count = @@ -636,7 +643,8 @@ void WinoGradConv3x3s1(const float *input, } BatchGemm(transformed_input, transformed_filter, batch, in_channels, - out_channels, tile_count, out_tile_size, transformed_output); + out_channels, tile_count, out_tile_size, transformed_output, + sgemm, scratch_buffer); switch (out_tile_size) { case 2: @@ -660,7 +668,9 @@ void WinoGradConv3x3s1(const float *input, const index_t in_channels, const index_t out_channels, const int out_tile_size, - float *output) { + float *output, + SGemm *sgemm, + ScratchBuffer *scratch_buffer) { index_t out_height = in_height - 2; index_t out_width = in_width - 2; index_t tile_height_count = @@ -692,7 +702,7 @@ void WinoGradConv3x3s1(const float *input, WinoGradConv3x3s1(input, transformed_filter, batch, in_height, in_width, in_channels, out_channels, out_tile_size, transformed_input, - transformed_output, output); + transformed_output, output, sgemm, scratch_buffer); delete[] transformed_input; delete[] transformed_filter; diff --git a/mace/kernels/arm/conv_winograd.h b/mace/kernels/arm/conv_winograd.h index 558fea9d..7e274b77 100644 --- a/mace/kernels/arm/conv_winograd.h +++ b/mace/kernels/arm/conv_winograd.h @@ -20,6 +20,7 @@ #endif #include "mace/core/types.h" +#include "mace/kernels/sgemm.h" namespace mace { namespace kernels { @@ -42,7 +43,9 @@ void WinoGradConv3x3s1(const float *input, const index_t in_channels, const index_t out_channels, const int out_tile_size, - float *output); + float *output, + SGemm *sgemm, + ScratchBuffer *scratch_buffer); void WinoGradConv3x3s1(const float *input, const float *transformed_filter, @@ -54,7 +57,9 @@ void WinoGradConv3x3s1(const float *input, const int out_tile_size, float *transformed_input, float *transformed_output, - float *output); + float *output, + SGemm *sgemm, + ScratchBuffer *scratch_buffer); void ConvRef3x3s1(const float *input, const float *filter, diff --git a/mace/kernels/arm/conv_winograd_test.cc b/mace/kernels/arm/conv_winograd_test.cc index 13135432..ccb4f118 100644 --- a/mace/kernels/arm/conv_winograd_test.cc +++ b/mace/kernels/arm/conv_winograd_test.cc @@ -65,9 +65,10 @@ TEST(ConvWinogradTest, winograd) { kernels::ConvRef3x3s1(input_data, filter_data, batch, in_height, in_width, in_channels, out_channels, output_data_ref); + SGemm sgemm; kernels::WinoGradConv3x3s1(input_data, filter_data, batch, in_height, in_width, in_channels, out_channels, 6, - output_data); + output_data, &sgemm, nullptr); // test for (index_t i = 0; i < output_size; ++i) { diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index 024644f3..0568d9b3 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -483,6 +483,16 @@ struct Conv2dFunctor : Conv2dFunctorBase { * sizeof(float); total_scratch_size += padded_output_size; } + // scratch for sgemm + if (use_neon_1x1_s1) { + total_scratch_size += + (input_batch * input_height * input_width + * (input_channels + channels)) * sizeof(float); + } else if (use_winograd) { + total_scratch_size += + (transformed_input_size + transformed_output_size) * sizeof(float); + } + // Init scratch buffer scratch_->Rewind(); scratch_->GrowSize(total_scratch_size); @@ -547,7 +557,9 @@ struct Conv2dFunctor : Conv2dFunctorBase { winograd_out_tile_size, transformed_input_data, transformed_output_data, - pad_output); + pad_output, + &sgemm_, + scratch_); }; } else if (use_neon_3x3_s1) { conv_func = [=](const float *pad_input, float *pad_output) { @@ -574,7 +586,9 @@ struct Conv2dFunctor : Conv2dFunctorBase { extra_input_width, input_channels, channels, - pad_output); + pad_output, + &sgemm_, + scratch_); }; } else if (use_neon_5x5_s1) { conv_func = [=](const float *pad_input, float *pad_output) { @@ -722,6 +736,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { Tensor transformed_filter_; bool is_filter_transformed_; ScratchBuffer *scratch_; + SGemm sgemm_; }; template<> diff --git a/mace/kernels/matmul.h b/mace/kernels/matmul.h index 9c5292d2..6f239de2 100644 --- a/mace/kernels/matmul.h +++ b/mace/kernels/matmul.h @@ -89,6 +89,9 @@ struct MatMulFunctor : OpKernel { const index_t height_b = B->dim(rank - 2); const index_t width_b = B->dim(rank - 1); + auto scratch_buffer = context_->workspace()->GetScratchBuffer(D); + scratch_buffer->Rewind(); + sgemm_.Run(a_ptr_base, b_ptr_base, batch, @@ -101,7 +104,7 @@ struct MatMulFunctor : OpKernel { A->is_weight(), B->is_weight(), c_ptr_base, - context_->workspace()->GetScratchBuffer(D)); + scratch_buffer); return MACE_SUCCESS; } diff --git a/mace/kernels/sgemm.cc b/mace/kernels/sgemm.cc index ab8310d1..825172a2 100644 --- a/mace/kernels/sgemm.cc +++ b/mace/kernels/sgemm.cc @@ -44,7 +44,6 @@ void SGemm::operator()(const MatrixMap &lhs, } if (scratch_buffer != nullptr) { - scratch_buffer->Rewind(); index_t total_size = result->size(); if (!lhs.is_const()) { total_size += lhs.size(); @@ -54,7 +53,6 @@ void SGemm::operator()(const MatrixMap &lhs, } scratch_buffer->GrowSize(total_size * sizeof(float)); - scratch_buffer->Rewind(); if (!lhs.is_const()) { packed_lhs_.reset(new Tensor(scratch_buffer->Scratch( lhs.size() * sizeof(float)), DT_FLOAT)); -- GitLab