diff --git a/mace/core/buffer.h b/mace/core/buffer.h index c57a1714aa91e469e5e2d6ec6de392f8ca868821..ba43e96c4da03e1f77987e6e0cf5be03b02f2595 100644 --- a/mace/core/buffer.h +++ b/mace/core/buffer.h @@ -469,6 +469,7 @@ class ScratchBuffer: public Buffer { MaceStatus GrowSize(index_t size) { if (size > size_) { + MACE_CHECK(offset_ == 0, "scratch is being used, cannot grow size"); return Resize(size); } return MaceStatus::MACE_SUCCESS; @@ -487,8 +488,12 @@ class ScratchBuffer: public Buffer { return slice; } - void Rewind() { - offset_ = 0; + void Rewind(index_t offset = 0) { + offset_ = offset; + } + + index_t offset() const { + return offset_; } private: diff --git a/mace/kernels/arm/conv_2d_neon.h b/mace/kernels/arm/conv_2d_neon.h index 7c7f7a776a1b6ada47c00c7c0c52c32e3bc1ed68..bf0e1023b30df4158314ecfc88ed84448feda557 100644 --- a/mace/kernels/arm/conv_2d_neon.h +++ b/mace/kernels/arm/conv_2d_neon.h @@ -16,6 +16,7 @@ #define MACE_KERNELS_ARM_CONV_2D_NEON_H_ #include "mace/core/types.h" +#include "mace/kernels/sgemm.h" namespace mace { namespace kernels { @@ -27,7 +28,9 @@ void Conv2dNeonK1x1S1(const float *input, const index_t width, const index_t in_channels, const index_t out_channels, - float *output); + float *output, + SGemm *sgemm, + ScratchBuffer *scratch_buffer); void Conv2dNeonK3x3S1(const float *input, const float *filter, diff --git a/mace/kernels/arm/conv_2d_neon_1x1.cc b/mace/kernels/arm/conv_2d_neon_1x1.cc index 28aa6e4f824342fa4bf6e5d624718eca19428a45..21554d90b57cbc8c5ab7df9b4a011e35bd69129f 100644 --- a/mace/kernels/arm/conv_2d_neon_1x1.cc +++ b/mace/kernels/arm/conv_2d_neon_1x1.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "mace/kernels/arm/conv_2d_neon.h" -#include "mace/kernels/gemm.h" namespace mace { namespace kernels { @@ -25,11 +24,23 @@ void Conv2dNeonK1x1S1(const float *input, const index_t width, const index_t in_channels, const index_t out_channels, - float *output) { + float *output, + SGemm *sgemm, + ScratchBuffer *scratch_buffer) { for (index_t b = 0; b < batch; ++b) { - Gemm(filter, input + b * in_channels * height * width, 1, out_channels, - in_channels, height * width, - output + b * out_channels * height * width); + sgemm->Run(filter, + input + b * in_channels * height * width, + 1, + out_channels, + in_channels, + in_channels, + height * width, + false, + false, + true, + false, + output + b * out_channels * height * width, + scratch_buffer); } } diff --git a/mace/kernels/arm/conv_winograd.cc b/mace/kernels/arm/conv_winograd.cc index b074ced7b85965cab360e77f17b7f9a836398fee..d115e4e5000a04b86b8aa9023c32b9c3cd4c97f4 100644 --- a/mace/kernels/arm/conv_winograd.cc +++ b/mace/kernels/arm/conv_winograd.cc @@ -12,13 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include #include "mace/kernels/arm/conv_winograd.h" #include "mace/kernels/gemm.h" -#include "mace/utils/logging.h" -#include "mace/utils/utils.h" namespace mace { namespace kernels { @@ -247,30 +244,38 @@ void BatchGemm(const float *input, index_t out_channels, index_t tile_count, int out_tile_size, - float *output) { - const index_t filter_stride = out_channels * in_channels; + float *output, + SGemm *sgemm, + ScratchBuffer *scratch_buffer) { const int in_tile_area = (out_tile_size + 2) * (out_tile_size + 2); const index_t in_batch_size = in_tile_area * in_channels * tile_count; - const index_t in_stride = in_channels * tile_count; const index_t out_batch_size = in_tile_area * out_channels * tile_count; - const index_t out_stride = out_channels * tile_count; - if (batch == 1) { - Gemm(filter, input, in_tile_area, out_channels, in_channels, tile_count, - output); - } else { -#pragma omp parallel for collapse(2) - for (int b = 0; b < batch; ++b) { - for (int i = 0; i < in_tile_area; ++i) { - const float *in_ptr = input + b * in_batch_size + i * in_stride; - const float *filter_ptr = filter + i * filter_stride; - float *out_ptr = output + b * out_batch_size + i * out_stride; - Gemm(filter_ptr, in_ptr, 1, out_channels, /* rows */ - in_channels, /* K */ - tile_count, /* cols */ - out_ptr); - } + index_t scratch_buffer_offset = 0; + if (scratch_buffer) { + scratch_buffer_offset = scratch_buffer->offset(); + } + // 'batch' is not gemm batch, 'in_tile_area' is. gemm is not thread safe, + // so we loop batch using single thread. + // Scratch buffer should be rewind to the initial position to use same + // scratch memory for each batch. + for (int b = 0; b < batch; ++b) { + if (scratch_buffer) { + scratch_buffer->Rewind(scratch_buffer_offset); } + sgemm->Run(filter, + input + b * in_batch_size, + in_tile_area, + out_channels, + in_channels, + in_channels, + tile_count, + false, + false, + true, + false, + output + b * out_batch_size, + scratch_buffer); } } @@ -613,7 +618,9 @@ void WinoGradConv3x3s1(const float *input, const int out_tile_size, float *transformed_input, float *transformed_output, - float *output) { + float *output, + SGemm *sgemm, + ScratchBuffer *scratch_buffer) { index_t out_height = in_height - 2; index_t out_width = in_width - 2; index_t tile_height_count = @@ -636,7 +643,8 @@ void WinoGradConv3x3s1(const float *input, } BatchGemm(transformed_input, transformed_filter, batch, in_channels, - out_channels, tile_count, out_tile_size, transformed_output); + out_channels, tile_count, out_tile_size, transformed_output, + sgemm, scratch_buffer); switch (out_tile_size) { case 2: @@ -660,7 +668,9 @@ void WinoGradConv3x3s1(const float *input, const index_t in_channels, const index_t out_channels, const int out_tile_size, - float *output) { + float *output, + SGemm *sgemm, + ScratchBuffer *scratch_buffer) { index_t out_height = in_height - 2; index_t out_width = in_width - 2; index_t tile_height_count = @@ -692,7 +702,7 @@ void WinoGradConv3x3s1(const float *input, WinoGradConv3x3s1(input, transformed_filter, batch, in_height, in_width, in_channels, out_channels, out_tile_size, transformed_input, - transformed_output, output); + transformed_output, output, sgemm, scratch_buffer); delete[] transformed_input; delete[] transformed_filter; diff --git a/mace/kernels/arm/conv_winograd.h b/mace/kernels/arm/conv_winograd.h index 558fea9deddd8edd923088d585cb1e556e83dd8c..7e274b777166ea6cb379145dc29b90b72ce571ba 100644 --- a/mace/kernels/arm/conv_winograd.h +++ b/mace/kernels/arm/conv_winograd.h @@ -20,6 +20,7 @@ #endif #include "mace/core/types.h" +#include "mace/kernels/sgemm.h" namespace mace { namespace kernels { @@ -42,7 +43,9 @@ void WinoGradConv3x3s1(const float *input, const index_t in_channels, const index_t out_channels, const int out_tile_size, - float *output); + float *output, + SGemm *sgemm, + ScratchBuffer *scratch_buffer); void WinoGradConv3x3s1(const float *input, const float *transformed_filter, @@ -54,7 +57,9 @@ void WinoGradConv3x3s1(const float *input, const int out_tile_size, float *transformed_input, float *transformed_output, - float *output); + float *output, + SGemm *sgemm, + ScratchBuffer *scratch_buffer); void ConvRef3x3s1(const float *input, const float *filter, diff --git a/mace/kernels/arm/conv_winograd_test.cc b/mace/kernels/arm/conv_winograd_test.cc index 1313543220580b896965bc3ef240e31b6edc3b09..ccb4f1181d41c2f96972ccc61fecb165aa602129 100644 --- a/mace/kernels/arm/conv_winograd_test.cc +++ b/mace/kernels/arm/conv_winograd_test.cc @@ -65,9 +65,10 @@ TEST(ConvWinogradTest, winograd) { kernels::ConvRef3x3s1(input_data, filter_data, batch, in_height, in_width, in_channels, out_channels, output_data_ref); + SGemm sgemm; kernels::WinoGradConv3x3s1(input_data, filter_data, batch, in_height, in_width, in_channels, out_channels, 6, - output_data); + output_data, &sgemm, nullptr); // test for (index_t i = 0; i < output_size; ++i) { diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index 024644f33f7f8a871ac2ab9dba6c4931a54aa821..0568d9b345c602c1cd10d98a72df266db88841fe 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -483,6 +483,16 @@ struct Conv2dFunctor : Conv2dFunctorBase { * sizeof(float); total_scratch_size += padded_output_size; } + // scratch for sgemm + if (use_neon_1x1_s1) { + total_scratch_size += + (input_batch * input_height * input_width + * (input_channels + channels)) * sizeof(float); + } else if (use_winograd) { + total_scratch_size += + (transformed_input_size + transformed_output_size) * sizeof(float); + } + // Init scratch buffer scratch_->Rewind(); scratch_->GrowSize(total_scratch_size); @@ -547,7 +557,9 @@ struct Conv2dFunctor : Conv2dFunctorBase { winograd_out_tile_size, transformed_input_data, transformed_output_data, - pad_output); + pad_output, + &sgemm_, + scratch_); }; } else if (use_neon_3x3_s1) { conv_func = [=](const float *pad_input, float *pad_output) { @@ -574,7 +586,9 @@ struct Conv2dFunctor : Conv2dFunctorBase { extra_input_width, input_channels, channels, - pad_output); + pad_output, + &sgemm_, + scratch_); }; } else if (use_neon_5x5_s1) { conv_func = [=](const float *pad_input, float *pad_output) { @@ -722,6 +736,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { Tensor transformed_filter_; bool is_filter_transformed_; ScratchBuffer *scratch_; + SGemm sgemm_; }; template<> diff --git a/mace/kernels/matmul.h b/mace/kernels/matmul.h index 9c5292d2ac9f74cf29de36c2dc0f75502e875cdd..6f239de2127902f86ba29c556b78ea3e7ecab8c8 100644 --- a/mace/kernels/matmul.h +++ b/mace/kernels/matmul.h @@ -89,6 +89,9 @@ struct MatMulFunctor : OpKernel { const index_t height_b = B->dim(rank - 2); const index_t width_b = B->dim(rank - 1); + auto scratch_buffer = context_->workspace()->GetScratchBuffer(D); + scratch_buffer->Rewind(); + sgemm_.Run(a_ptr_base, b_ptr_base, batch, @@ -101,7 +104,7 @@ struct MatMulFunctor : OpKernel { A->is_weight(), B->is_weight(), c_ptr_base, - context_->workspace()->GetScratchBuffer(D)); + scratch_buffer); return MACE_SUCCESS; } diff --git a/mace/kernels/sgemm.cc b/mace/kernels/sgemm.cc index a5f15fd0abe5c5d09e6f175dc31c38372eb4f774..753cdc2161e6b1e746fa9fcd9d45c21aab5dd4fa 100644 --- a/mace/kernels/sgemm.cc +++ b/mace/kernels/sgemm.cc @@ -44,7 +44,6 @@ void SGemm::operator()(const MatrixMap &lhs, } if (scratch_buffer != nullptr) { - scratch_buffer->Rewind(); index_t total_size = result->size(); if (!lhs.is_const()) { total_size += lhs.size(); @@ -54,7 +53,6 @@ void SGemm::operator()(const MatrixMap &lhs, } scratch_buffer->GrowSize(total_size * sizeof(float)); - scratch_buffer->Rewind(); if (!lhs.is_const()) { packed_lhs_.reset(new Tensor(scratch_buffer->Scratch( lhs.size() * sizeof(float)), DT_FLOAT));