diff --git a/mace/core/buffer.h b/mace/core/buffer.h index f4b252a776296b1e065816c3a9b6288d13d03837..083c3e3002604940c8c3af82e5dd0e0a0a25f693 100644 --- a/mace/core/buffer.h +++ b/mace/core/buffer.h @@ -56,6 +56,8 @@ class BufferBase { virtual void Clear() = 0; + virtual void Clear(index_t size) = 0; + virtual index_t offset() const { return 0; } template @@ -175,7 +177,11 @@ class Buffer : public BufferBase { bool OnHost() const { return allocator_->OnHost(); } void Clear() { - memset(reinterpret_cast(raw_mutable_data()), 0, size_); + Clear(size_); + } + + void Clear(index_t size) { + memset(reinterpret_cast(raw_mutable_data()), 0, size); } protected: @@ -277,6 +283,11 @@ class Image : public BufferBase { MACE_NOT_IMPLEMENTED; } + void Clear(index_t size) { + MACE_UNUSED(size); + MACE_NOT_IMPLEMENTED; + } + private: Allocator *allocator_; std::vector shape_; @@ -381,7 +392,11 @@ class BufferSlice : public BufferBase { bool OnHost() const { return buffer_->OnHost(); } void Clear() { - memset(raw_mutable_data(), 0, size_); + Clear(size_); + } + + void Clear(index_t size) { + memset(raw_mutable_data(), 0, size); } private: diff --git a/mace/core/tensor.h b/mace/core/tensor.h index 87a6cb3c1c9ff9f3712c50e07c8c6e0d69f5cf61..d50a223cac90d1368c623b2a32c9395ce668fc5f 100644 --- a/mace/core/tensor.h +++ b/mace/core/tensor.h @@ -208,7 +208,7 @@ class Tensor { inline void Clear() { MACE_CHECK_NOTNULL(buffer_); - buffer_->Clear(); + buffer_->Clear(raw_size()); } inline void Reshape(const std::vector &shape) { diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index 7f8258280728872c32f96cfc49316e69d370eb08..4cf1224be09a3df7681a704a668b4cfce60b4a7e 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -298,7 +298,6 @@ struct Conv2dFunctor : Conv2dFunctorBase { output_shape.data()); } output->Resize(output_shape); - output->Clear(); index_t batch = output->dim(0); index_t channels = output->dim(1); @@ -418,7 +417,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { if (extra_input_width != padded_input_width) { pad_right += (extra_input_width - padded_input_width); } - } else { + } else if (!use_neon_1x1_s1) { extra_output_height = height; extra_input_height = std::max(padded_input_height, (extra_output_height - 1) * stride_h @@ -605,7 +604,6 @@ struct Conv2dFunctor : Conv2dFunctorBase { const Tensor *pad_input_ptr = input; if (extra_input_height != input_height || extra_input_width != input_width) { - padded_input.Clear(); ConstructNCHWInputWithSpecificPadding(input, pad_top, pad_bottom, @@ -615,13 +613,17 @@ struct Conv2dFunctor : Conv2dFunctorBase { pad_input_ptr = &padded_input; } + // TODO(libin): don't need clear after bias is integrated in each conv Tensor *pad_output_ptr = output; if (extra_output_height != height || extra_output_width != width) { padded_output.Reshape({batch, channels, extra_output_height, extra_output_width}); padded_output.Clear(); pad_output_ptr = &padded_output; + } else if (!use_neon_1x1_s1) { + output->Clear(); } + const float *pad_input_data = pad_input_ptr->data(); float *pad_output_data = pad_output_ptr->mutable_data(); diff --git a/mace/kernels/conv_pool_2d_util.cc b/mace/kernels/conv_pool_2d_util.cc index b0f5229f94de2be770f68eb42d538170da8fa7ca..07c72cb35eece69e9b2cefae8b82841dac397524 100644 --- a/mace/kernels/conv_pool_2d_util.cc +++ b/mace/kernels/conv_pool_2d_util.cc @@ -377,6 +377,7 @@ void ConstructNCHWInputWithSpecificPadding(const Tensor *input_tensor, std::vector output_shape( {batch, channels, height + pad_height, width + pad_width}); output_tensor->Resize(output_shape); + output_tensor->Clear(); Tensor::MappingGuard padded_output_mapper(output_tensor); float *output_data = output_tensor->mutable_data(); diff --git a/mace/kernels/gemm.cc b/mace/kernels/gemm.cc index 0fae44de1fa7c1c77195cfd6c93140c6e60c1d05..178e0720a13555d00387ccc57a87f065cd746b42 100644 --- a/mace/kernels/gemm.cc +++ b/mace/kernels/gemm.cc @@ -140,8 +140,8 @@ inline void GemmTile(const float *A, #endif #if defined(MACE_ENABLE_NEON) && defined(__aarch64__) - for (h = 0; h + 7 < height; h += 8) { - for (k = 0; k + 7 < K; k += 8) { + for (h = 0; h < height - 7; h += 8) { + for (k = 0; k < K - 7; k += 8) { const float *a_ptr = A + (h * stride_k + k); #ifdef __clang__ int nw = width >> 2; @@ -185,156 +185,150 @@ inline void GemmTile(const float *A, float *c_ptr7 = C + (h + 7) * stride_w; asm volatile( - "0: \n" - - "prfm pldl1keep, [%1, #128] \n" - "ld1 {v24.4s}, [%1] \n" - - // load b: 0-7 - "prfm pldl1keep, [%9, #128] \n" - "ld1 {v16.4s}, [%9], #16 \n" - - "prfm pldl1keep, [%10, #128] \n" - "ld1 {v17.4s}, [%10], #16 \n" - - "prfm pldl1keep, [%11, #128] \n" - "ld1 {v18.4s}, [%11], #16 \n" - - "prfm pldl1keep, [%12, #128] \n" - "ld1 {v19.4s}, [%12], #16 \n" - - "prfm pldl1keep, [%2, #128] \n" - "ld1 {v25.4s}, [%2] \n" - - "prfm pldl1keep, [%13, #128] \n" - "ld1 {v20.4s}, [%13], #16 \n" - - "prfm pldl1keep, [%14, #128] \n" - "ld1 {v21.4s}, [%14], #16 \n" - - "prfm pldl1keep, [%15, #128] \n" - "ld1 {v22.4s}, [%15], #16 \n" - - "prfm pldl1keep, [%16, #128] \n" - "ld1 {v23.4s}, [%16], #16 \n" - - "prfm pldl1keep, [%3, #128] \n" - "ld1 {v26.4s}, [%3] \n" - - "fmla v24.4s, v16.4s, %34.s[0] \n" - "fmla v24.4s, v17.4s, %34.s[1] \n" - "fmla v24.4s, v18.4s, %34.s[2] \n" - "fmla v24.4s, v19.4s, %34.s[3] \n" - - "fmla v24.4s, v20.4s, %35.s[0] \n" - "fmla v24.4s, v21.4s, %35.s[1] \n" - "fmla v24.4s, v22.4s, %35.s[2] \n" - "fmla v24.4s, v23.4s, %35.s[3] \n" - - "st1 {v24.4s}, [%1], #16 \n" - - "fmla v25.4s, v16.4s, %36.s[0] \n" - "fmla v25.4s, v17.4s, %36.s[1] \n" - "fmla v25.4s, v18.4s, %36.s[2] \n" - "fmla v25.4s, v19.4s, %36.s[3] \n" - - "fmla v25.4s, v20.4s, %37.s[0] \n" - "fmla v25.4s, v21.4s, %37.s[1] \n" - "fmla v25.4s, v22.4s, %37.s[2] \n" - "fmla v25.4s, v23.4s, %37.s[3] \n" - - "prfm pldl1keep, [%4, #128] \n" - "ld1 {v24.4s}, [%4] \n" - - "st1 {v25.4s}, [%2], #16 \n" - - "fmla v26.4s, v16.4s, %38.s[0] \n" - "fmla v26.4s, v17.4s, %38.s[1] \n" - "fmla v26.4s, v18.4s, %38.s[2] \n" - "fmla v26.4s, v19.4s, %38.s[3] \n" + "prfm pldl1keep, [%9, #128] \n" + "ld1 {v16.4s}, [%9], #16 \n" - "fmla v26.4s, v20.4s, %39.s[0] \n" - "fmla v26.4s, v21.4s, %39.s[1] \n" - "fmla v26.4s, v22.4s, %39.s[2] \n" - "fmla v26.4s, v23.4s, %39.s[3] \n" + "prfm pldl1keep, [%1, #128] \n" + "ld1 {v18.4s}, [%1] \n" - "prfm pldl1keep, [%5, #128] \n" - "ld1 {v25.4s}, [%5] \n" + "prfm pldl1keep, [%2, #128] \n" + "ld1 {v19.4s}, [%2] \n" - "st1 {v26.4s}, [%3], #16 \n" - - "fmla v24.4s, v16.4s, %40.s[0] \n" - "fmla v24.4s, v17.4s, %40.s[1] \n" - "fmla v24.4s, v18.4s, %40.s[2] \n" - "fmla v24.4s, v19.4s, %40.s[3] \n" - - "fmla v24.4s, v20.4s, %41.s[0] \n" - "fmla v24.4s, v21.4s, %41.s[1] \n" - "fmla v24.4s, v22.4s, %41.s[2] \n" - "fmla v24.4s, v23.4s, %41.s[3] \n" - - "prfm pldl1keep, [%6, #128] \n" - "ld1 {v26.4s}, [%6] \n" - - "st1 {v24.4s}, [%4], #16 \n" - - "fmla v25.4s, v16.4s, %42.s[0] \n" - "fmla v25.4s, v17.4s, %42.s[1] \n" - "fmla v25.4s, v18.4s, %42.s[2] \n" - "fmla v25.4s, v19.4s, %42.s[3] \n" - - "fmla v25.4s, v20.4s, %43.s[0] \n" - "fmla v25.4s, v21.4s, %43.s[1] \n" - "fmla v25.4s, v22.4s, %43.s[2] \n" - "fmla v25.4s, v23.4s, %43.s[3] \n" - - "prfm pldl1keep, [%7, #128] \n" - "ld1 {v24.4s}, [%7] \n" - - "st1 {v25.4s}, [%5], #16 \n" - - "fmla v26.4s, v16.4s, %44.s[0] \n" - "fmla v26.4s, v17.4s, %44.s[1] \n" - "fmla v26.4s, v18.4s, %44.s[2] \n" - "fmla v26.4s, v19.4s, %44.s[3] \n" - - "fmla v26.4s, v20.4s, %45.s[0] \n" - "fmla v26.4s, v21.4s, %45.s[1] \n" - "fmla v26.4s, v22.4s, %45.s[2] \n" - "fmla v26.4s, v23.4s, %45.s[3] \n" - - "prfm pldl1keep, [%8, #128] \n" - "ld1 {v25.4s}, [%8] \n" - - "st1 {v26.4s}, [%6], #16 \n" - - "fmla v24.4s, v16.4s, %46.s[0] \n" - "fmla v24.4s, v17.4s, %46.s[1] \n" - "fmla v24.4s, v18.4s, %46.s[2] \n" - "fmla v24.4s, v19.4s, %46.s[3] \n" - - "fmla v24.4s, v20.4s, %47.s[0] \n" - "fmla v24.4s, v21.4s, %47.s[1] \n" - "fmla v24.4s, v22.4s, %47.s[2] \n" - "fmla v24.4s, v23.4s, %47.s[3] \n" - - "st1 {v24.4s}, [%7], #16 \n" - - "fmla v25.4s, v16.4s, %48.s[0] \n" - "fmla v25.4s, v17.4s, %48.s[1] \n" - "fmla v25.4s, v18.4s, %48.s[2] \n" - "fmla v25.4s, v19.4s, %48.s[3] \n" - - "fmla v25.4s, v20.4s, %49.s[0] \n" - "fmla v25.4s, v21.4s, %49.s[1] \n" - "fmla v25.4s, v22.4s, %49.s[2] \n" - "fmla v25.4s, v23.4s, %49.s[3] \n" - - "st1 {v25.4s}, [%8], #16 \n" + "0: \n" - "subs %w0, %w0, #1 \n" - "bne 0b \n" - : "=r"(nw), // 0 + "prfm pldl1keep, [%3, #128] \n" + "ld1 {v20.4s}, [%3] \n" + "prfm pldl1keep, [%4, #128] \n" + "ld1 {v21.4s}, [%4] \n" + "prfm pldl1keep, [%5, #128] \n" + "ld1 {v22.4s}, [%5] \n" + "prfm pldl1keep, [%6, #128] \n" + "ld1 {v23.4s}, [%6] \n" + "prfm pldl1keep, [%7, #128] \n" + "ld1 {v24.4s}, [%7] \n" + "prfm pldl1keep, [%8, #128] \n" + "ld1 {v25.4s}, [%8] \n" + "prfm pldl1keep, [%10, #128] \n" + "ld1 {v17.4s}, [%10], #16 \n" + + "fmla v18.4s, v16.4s, %34.s[0] \n" + "fmla v19.4s, v16.4s, %35.s[0] \n" + "fmla v20.4s, v16.4s, %36.s[0] \n" + "fmla v21.4s, v16.4s, %37.s[0] \n" + + "fmla v22.4s, v16.4s, %38.s[0] \n" + "fmla v23.4s, v16.4s, %39.s[0] \n" + "fmla v24.4s, v16.4s, %40.s[0] \n" + "fmla v25.4s, v16.4s, %41.s[0] \n" + + "fmla v18.4s, v17.4s, %34.s[1] \n" + "fmla v19.4s, v17.4s, %35.s[1] \n" + "fmla v20.4s, v17.4s, %36.s[1] \n" + "fmla v21.4s, v17.4s, %37.s[1] \n" + + "prfm pldl1keep, [%11, #128] \n" + "ld1 {v16.4s}, [%11], #16 \n" + + "fmla v22.4s, v17.4s, %38.s[1] \n" + "fmla v23.4s, v17.4s, %39.s[1] \n" + "fmla v24.4s, v17.4s, %40.s[1] \n" + "fmla v25.4s, v17.4s, %41.s[1] \n" + + "fmla v18.4s, v16.4s, %34.s[2] \n" + "fmla v19.4s, v16.4s, %35.s[2] \n" + "fmla v20.4s, v16.4s, %36.s[2] \n" + "fmla v21.4s, v16.4s, %37.s[2] \n" + + "prfm pldl1keep, [%12, #128] \n" + "ld1 {v17.4s}, [%12], #16 \n" + + "fmla v22.4s, v16.4s, %38.s[2] \n" + "fmla v23.4s, v16.4s, %39.s[2] \n" + "fmla v24.4s, v16.4s, %40.s[2] \n" + "fmla v25.4s, v16.4s, %41.s[2] \n" + + "fmla v18.4s, v17.4s, %34.s[3] \n" + "fmla v19.4s, v17.4s, %35.s[3] \n" + "fmla v20.4s, v17.4s, %36.s[3] \n" + "fmla v21.4s, v17.4s, %37.s[3] \n" + + "prfm pldl1keep, [%13, #128] \n" + "ld1 {v16.4s}, [%13], #16 \n" + + "fmla v22.4s, v17.4s, %38.s[3] \n" + "fmla v23.4s, v17.4s, %39.s[3] \n" + "fmla v24.4s, v17.4s, %40.s[3] \n" + "fmla v25.4s, v17.4s, %41.s[3] \n" + + "fmla v18.4s, v16.4s, %42.s[0] \n" + "fmla v19.4s, v16.4s, %43.s[0] \n" + "fmla v20.4s, v16.4s, %44.s[0] \n" + "fmla v21.4s, v16.4s, %45.s[0] \n" + + "prfm pldl1keep, [%14, #128] \n" + "ld1 {v17.4s}, [%14], #16 \n" + + "fmla v22.4s, v16.4s, %46.s[0] \n" + "fmla v23.4s, v16.4s, %47.s[0] \n" + "fmla v24.4s, v16.4s, %48.s[0] \n" + "fmla v25.4s, v16.4s, %49.s[0] \n" + + "fmla v18.4s, v17.4s, %42.s[1] \n" + "fmla v19.4s, v17.4s, %43.s[1] \n" + "fmla v20.4s, v17.4s, %44.s[1] \n" + "fmla v21.4s, v17.4s, %45.s[1] \n" + + "prfm pldl1keep, [%15, #128] \n" + "ld1 {v16.4s}, [%15], #16 \n" + + "fmla v22.4s, v17.4s, %46.s[1] \n" + "fmla v23.4s, v17.4s, %47.s[1] \n" + "fmla v24.4s, v17.4s, %48.s[1] \n" + "fmla v25.4s, v17.4s, %49.s[1] \n" + + "fmla v18.4s, v16.4s, %42.s[2] \n" + "fmla v19.4s, v16.4s, %43.s[2] \n" + "fmla v20.4s, v16.4s, %44.s[2] \n" + "fmla v21.4s, v16.4s, %45.s[2] \n" + + "prfm pldl1keep, [%16, #128] \n" + "ld1 {v17.4s}, [%16], #16 \n" + + "fmla v22.4s, v16.4s, %46.s[2] \n" + "fmla v23.4s, v16.4s, %47.s[2] \n" + "fmla v24.4s, v16.4s, %48.s[2] \n" + "fmla v25.4s, v16.4s, %49.s[2] \n" + + "fmla v18.4s, v17.4s, %42.s[3] \n" + "fmla v19.4s, v17.4s, %43.s[3] \n" + "fmla v20.4s, v17.4s, %44.s[3] \n" + "fmla v21.4s, v17.4s, %45.s[3] \n" + + "st1 {v18.4s}, [%1], #16 \n" + "st1 {v19.4s}, [%2], #16 \n" + "st1 {v20.4s}, [%3], #16 \n" + "st1 {v21.4s}, [%4], #16 \n" + + "fmla v22.4s, v17.4s, %46.s[3] \n" + "fmla v23.4s, v17.4s, %47.s[3] \n" + "fmla v24.4s, v17.4s, %48.s[3] \n" + "fmla v25.4s, v17.4s, %49.s[3] \n" + + "st1 {v22.4s}, [%5], #16 \n" + "st1 {v23.4s}, [%6], #16 \n" + "st1 {v24.4s}, [%7], #16 \n" + "st1 {v25.4s}, [%8], #16 \n" + + "prfm pldl1keep, [%9, #128] \n" + "ld1 {v16.4s}, [%9], #16 \n" + "prfm pldl1keep, [%1, #128] \n" + "ld1 {v18.4s}, [%1] \n" + "prfm pldl1keep, [%2, #128] \n" + "ld1 {v19.4s}, [%2] \n" + + "subs %w0, %w0, #1 \n" + "bne 0b \n" + : "=r"(nw), // 0 "=r"(c_ptr0), // 1 "=r"(c_ptr1), // 2 "=r"(c_ptr2), // 3 @@ -351,7 +345,7 @@ inline void GemmTile(const float *A, "=r"(b_ptr5), // 14 "=r"(b_ptr6), // 15 "=r"(b_ptr7) // 16 - : "0"(nw), // 17 + : "0"(nw), // 17 "1"(c_ptr0), // 18 "2"(c_ptr1), // 19 "3"(c_ptr2), // 20 @@ -369,20 +363,20 @@ inline void GemmTile(const float *A, "15"(b_ptr6), // 32 "16"(b_ptr7), // 33 "w"(a0), // 34 - "w"(a1), // 35 - "w"(a2), // 36 - "w"(a3), // 37 - "w"(a4), // 38 - "w"(a5), // 39 - "w"(a6), // 40 - "w"(a7), // 41 - "w"(a8), // 42 - "w"(a9), // 43 - "w"(a10), // 44 - "w"(a11), // 45 - "w"(a12), // 46 - "w"(a13), // 47 - "w"(a14), // 48 + "w"(a2), // 35 + "w"(a4), // 36 + "w"(a6), // 37 + "w"(a8), // 38 + "w"(a10), // 39 + "w"(a12), // 40 + "w"(a14), // 41 + "w"(a1), // 42 + "w"(a3), // 43 + "w"(a5), // 44 + "w"(a7), // 45 + "w"(a9), // 46 + "w"(a11), // 47 + "w"(a13), // 48 "w"(a15) // 49 : "cc", "memory", "v16", @@ -585,7 +579,6 @@ void Gemm(const float *A, } memset(C, 0, sizeof(float) * batch * height * width); - // It is better to use large block size if it fits for fast cache. // Assume l1 cache size is 32k, we load three blocks at a time (A, B, C), // the block size should be sqrt(32k / sizeof(T) / 3).