From bcc883368248e0f799659d27a6ed5260e8263e1b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?=
Date: Tue, 12 Feb 2019 14:03:01 +0800
Subject: [PATCH] Optimize gemm

---
 mace/core/allocator.cc | 22 +
 mace/core/allocator.h | 6 +
 mace/core/buffer.h | 4 +-
 mace/core/op_context.cc | 6 +-
 mace/core/op_context.h | 6 +-
 mace/ops/arm/fp32/conv_2d.h | 44 ++
 mace/ops/arm/fp32/conv_2d_1x1.cc | 53 ++
 mace/ops/arm/fp32/conv_2d_1x1.h | 49 ++
 mace/ops/arm/fp32/gemm.cc | 1217 ++++++++++++++++++++++++++++++
 mace/ops/arm/fp32/gemm.h | 156 ++++
 mace/ops/arm/fp32/gemm_test.cc | 115 +++
 mace/ops/arm/fp32/gemv.cc | 61 +-
 mace/ops/arm/fp32/gemv.h | 1 +
 mace/ops/arm/fp32/gemv_test.cc | 32 +-
 mace/ops/arm/q8/gemv.cc | 4 +-
 mace/ops/arm/q8/gemv.h | 4 +
 mace/ops/arm/q8/gemv_test.cc | 48 +-
 mace/ops/common/matrix.h | 107 +++
 mace/ops/conv_2d.cc | 775 ++++++++++---------
 mace/ops/fully_connected.cc | 2 +
 mace/ops/matmul.cc | 340 ++++++---
 mace/ops/matmul_benchmark.cc | 14 +-
 mace/ops/matmul_test.cc | 181 +++--
 mace/ops/ref/conv_2d.cc | 111 +++
 mace/ops/ref/conv_2d.h | 76 ++
 mace/ops/ref/gemm.cc | 116 +++
 mace/ops/ref/gemm.h | 89 +++
 mace/ops/ref/gemv.cc | 15 +-
 mace/ops/ref/gemv.h | 4 +
 mace/ops/sgemm.cc | 82 +-
 mace/ops/sgemm.h | 44 +-
 mace/ops/sgemm_pack_test.cc | 56 +-
 32 files changed, 3134 insertions(+), 706 deletions(-)
 create mode 100644 mace/ops/arm/fp32/conv_2d.h
 create mode 100644 mace/ops/arm/fp32/conv_2d_1x1.cc
 create mode 100644 mace/ops/arm/fp32/conv_2d_1x1.h
 create mode 100644 mace/ops/arm/fp32/gemm.cc
 create mode 100644 mace/ops/arm/fp32/gemm.h
 create mode 100644 mace/ops/arm/fp32/gemm_test.cc
 create mode 100644 mace/ops/common/matrix.h
 create mode 100644 mace/ops/ref/conv_2d.cc
 create mode 100644 mace/ops/ref/conv_2d.h
 create mode 100644 mace/ops/ref/gemm.cc
 create mode 100644 mace/ops/ref/gemm.h

diff --git a/mace/core/allocator.cc b/mace/core/allocator.cc
index 003b1c2c..b5bffabc 100644
--- a/mace/core/allocator.cc
+++ b/mace/core/allocator.cc
@@ -14,6 +14,10 @@
 #include "mace/core/allocator.h"
 
+#include <string.h>
+#include <sys/mman.h>
+#include <unistd.h>
+
 namespace mace {
 
 Allocator *GetCPUAllocator() {
@@ -21,4 +25,22 @@ Allocator *GetCPUAllocator() {
   return &allocator;
 }
 
+void AdviseFree(void *addr, size_t length) {
+  int page_size = sysconf(_SC_PAGESIZE);
+  void *addr_aligned =
+      reinterpret_cast<void *>(
+          (reinterpret_cast<uintptr_t>(addr) + page_size - 1)
+              & (~(page_size - 1)));
+  uintptr_t delta =
+      reinterpret_cast<uintptr_t>(addr_aligned)
+          - reinterpret_cast<uintptr_t>(addr);
+  if (length >= delta + page_size) {
+    size_t len_aligned = (length - delta) & (~(page_size - 1));
+    int ret = madvise(addr_aligned, len_aligned, MADV_DONTNEED);
+    if (ret != 0) {
+      LOG(ERROR) << "Advise free failed: " << strerror(errno);
+    }
+  }
+}
+
 } // namespace mace
diff --git a/mace/core/allocator.h b/mace/core/allocator.h
index 9a0811ae..9c910363 100644
--- a/mace/core/allocator.h
+++ b/mace/core/allocator.h
@@ -40,6 +40,10 @@ constexpr size_t kMaceAlignment = 64;
 constexpr size_t kMaceAlignment = 32;
 #endif
 
+inline index_t PadAlignSize(index_t size) {
+  return (size + kMaceAlignment - 1) & (~(kMaceAlignment - 1));
+}
+
 class Allocator {
  public:
   Allocator() {}
@@ -140,6 +144,8 @@ class CPUAllocator : public Allocator {
 // Global CPU allocator used for CPU/GPU/DSP
 Allocator *GetCPUAllocator();
 
+void AdviseFree(void *addr, size_t length);
+
 } // namespace mace
 
 #endif // MACE_CORE_ALLOCATOR_H_
diff --git a/mace/core/buffer.h b/mace/core/buffer.h
index 149e54bc..66684db1 100644
--- a/mace/core/buffer.h
+++ b/mace/core/buffer.h
@@ -384,7 +384,7 @@ class BufferSlice : public
BufferBase { BufferSlice(const BufferSlice &other) : BufferSlice(other.buffer_, other.offset_, other.size_) {} - ~BufferSlice() { + virtual ~BufferSlice() { if (buffer_ != nullptr && mapped_buf_ != nullptr) { UnMap(); } @@ -506,7 +506,7 @@ class ScratchBuffer: public Buffer { virtual ~ScratchBuffer() {} - MaceStatus GrowSize(index_t size) { + MaceStatus GrowSize(const index_t size) { if (size > size_) { VLOG(1) << "Grow scratch size to: " << size; MACE_CHECK(offset_ == 0, "scratch is being used, cannot grow size"); diff --git a/mace/core/op_context.cc b/mace/core/op_context.cc index 7f2048ed..d0ebeff7 100644 --- a/mace/core/op_context.cc +++ b/mace/core/op_context.cc @@ -25,11 +25,11 @@ void OpContext::set_device(Device *device) { device_ = device; } -Device* OpContext::device() { +Device* OpContext::device() const { return device_; } -Workspace* OpContext::workspace() { +Workspace* OpContext::workspace() const { return ws_; } @@ -37,7 +37,7 @@ void OpContext::set_future(StatsFuture *future) { future_ = future; } -StatsFuture *OpContext::future() { +StatsFuture *OpContext::future() const { return future_; } diff --git a/mace/core/op_context.h b/mace/core/op_context.h index 4f27cb57..26a31dc3 100644 --- a/mace/core/op_context.h +++ b/mace/core/op_context.h @@ -26,11 +26,11 @@ class OpContext { OpContext(Workspace *ws, Device *device); ~OpContext(); void set_device(Device *device); - Device *device(); - Workspace *workspace(); + Device *device() const; + Workspace *workspace() const; void set_future(StatsFuture *future); - StatsFuture *future(); + StatsFuture *future() const; private: Device *device_; Workspace *ws_; diff --git a/mace/ops/arm/fp32/conv_2d.h b/mace/ops/arm/fp32/conv_2d.h new file mode 100644 index 00000000..7d77cf14 --- /dev/null +++ b/mace/ops/arm/fp32/conv_2d.h @@ -0,0 +1,44 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_ARM_FP32_CONV_2D_H_ +#define MACE_OPS_ARM_FP32_CONV_2D_H_ + +#include "mace/public/mace.h" +#include "mace/core/tensor.h" +#include "mace/core/op_context.h" +#include "mace/ops/arm/fp32/gemm.h" + +namespace mace { +namespace ops { +namespace arm { +namespace fp32 { + +class Conv2dBase { + public: + Conv2dBase() = default; + virtual ~Conv2dBase() = default; + virtual MaceStatus Compute( + const OpContext *context, + const Tensor *input, + const Tensor *filter, + Tensor *output) = 0; +}; + +} // namespace fp32 +} // namespace arm +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_ARM_FP32_CONV_2D_H_ diff --git a/mace/ops/arm/fp32/conv_2d_1x1.cc b/mace/ops/arm/fp32/conv_2d_1x1.cc new file mode 100644 index 00000000..b34e19aa --- /dev/null +++ b/mace/ops/arm/fp32/conv_2d_1x1.cc @@ -0,0 +1,53 @@ +// Copyright 2019 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "mace/ops/arm/fp32/conv_2d_1x1.h" + +namespace mace { +namespace ops { +namespace arm { +namespace fp32 { + +MaceStatus Conv2dK1x1::Compute(const OpContext *context, + const Tensor *input, + const Tensor *filter, + Tensor *output) { + index_t batch = input->dim(0); + index_t height = input->dim(2); + index_t width = input->dim(3); + index_t in_channels = input->dim(1); + index_t out_channels = filter->dim(0); + MACE_RETURN_IF_ERROR(output->Resize({batch, out_channels, height, width})); + context->device()->scratch_buffer()->Rewind(); + return gemm_.Compute(context, + filter, + input, + batch, + out_channels, + in_channels, + in_channels, + height * width, + false, + false, + false, + false, + true, + output); +} + +} // namespace fp32 +} // namespace arm +} // namespace ops +} // namespace mace diff --git a/mace/ops/arm/fp32/conv_2d_1x1.h b/mace/ops/arm/fp32/conv_2d_1x1.h new file mode 100644 index 00000000..fd2077ec --- /dev/null +++ b/mace/ops/arm/fp32/conv_2d_1x1.h @@ -0,0 +1,49 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_ARM_FP32_CONV_2D_1X1_H_ +#define MACE_OPS_ARM_FP32_CONV_2D_1X1_H_ + +#include "mace/public/mace.h" +#include "mace/core/tensor.h" +#include "mace/core/op_context.h" +#include "mace/ops/arm/fp32/gemm.h" +#include "mace/ops/arm/fp32/conv_2d.h" + +namespace mace { +namespace ops { +namespace arm { +namespace fp32 { + +class Conv2dK1x1 : public Conv2dBase { + public: + Conv2dK1x1() : gemm_(true) {} + virtual ~Conv2dK1x1() {} + + MaceStatus Compute( + const OpContext *context, + const Tensor *input, + const Tensor *filter, + Tensor *output); + + private: + Gemm gemm_; +}; + +} // namespace fp32 +} // namespace arm +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_ARM_FP32_CONV_2D_1X1_H_ diff --git a/mace/ops/arm/fp32/gemm.cc b/mace/ops/arm/fp32/gemm.cc new file mode 100644 index 00000000..8acde2d1 --- /dev/null +++ b/mace/ops/arm/fp32/gemm.cc @@ -0,0 +1,1217 @@ +// Copyright 2019 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include "mace/ops/arm/fp32/gemm.h"
+
+#include <arm_neon.h>
+#include <algorithm>
+#include <utility>
+
+namespace mace {
+namespace ops {
+namespace arm {
+namespace fp32 {
+
+enum { kNoCache, kCacheLhs, kCacheRhs };
+
+MaceStatus Gemm::Compute(const OpContext *context,
+                         const Tensor *lhs,
+                         const Tensor *rhs,
+                         const index_t batch,
+                         const index_t rows,
+                         const index_t cols,
+                         const index_t depth,
+                         const MatrixMajor lhs_major,
+                         const MatrixMajor rhs_major,
+                         const MatrixMajor output_major,
+                         const bool lhs_batched,
+                         const bool rhs_batched,
+                         Tensor *output) {
+  MACE_UNUSED(context);
+
+  MACE_CHECK(output->size() == batch * rows * cols,
+             "Need resize output tensor before call gemm.");
+  Tensor::MappingGuard lhs_guard(lhs);
+  Tensor::MappingGuard rhs_guard(rhs);
+  Tensor::MappingGuard output_guard(output);
+  const float *lhs_data = lhs->data<float>();
+  const float *rhs_data = rhs->data<float>();
+  float *output_data = output->mutable_data<float>();
+
+#ifdef __aarch64__
+  const index_t row_block_size = 8;
+#else
+  const index_t row_block_size = 4;
+#endif
+  const index_t col_block_size = 8;
+  const index_t depth_block_size = 4;
+  const index_t row_block_count = RoundUpDiv(rows, row_block_size);
+  const index_t col_block_count = RoundUpDiv(cols, col_block_size);
+  const index_t rows_padded = RoundUp(rows, row_block_size);
+  const index_t cols_padded = RoundUp(cols, col_block_size);
+  const index_t depth_padded = RoundUp(depth, depth_block_size);
+
+  ScratchBuffer *scratch = &tmp_scratch_buffer_;
+  if (context != nullptr && context->device()->scratch_buffer() != nullptr) {
+    scratch = context->device()->scratch_buffer();
+  }
+  index_t packed_lhs_size =
+      PadAlignSize(sizeof(float) * rows_padded * depth_padded);
+  index_t packed_rhs_size =
+      PadAlignSize(sizeof(float) * depth_padded * cols_padded);
+  index_t packed_output_size =
+      PadAlignSize(sizeof(float) * rows_padded * cols_padded);
+  // resize to the total size of lhs & rhs & output anyway,
+  // in case we do not cache const tensor for saving memory
+  MACE_RETURN_IF_ERROR(scratch->GrowSize(
+      packed_lhs_size + packed_rhs_size + packed_output_size));
+  float *packed_lhs_data =
+      scratch->Scratch(packed_lhs_size).mutable_data<float>();
+  float *packed_rhs_data =
+      scratch->Scratch(packed_rhs_size).mutable_data<float>();
+  float *packed_output_data =
+      scratch->Scratch(packed_output_size).mutable_data<float>();
+
+  int cache_side = kNoCache;
+  if (cached_ == kCacheLhs) {
+    packed_lhs_data = pack_cache_.mutable_data<float>();
+  } else if (cached_ == kCacheRhs) {
+    packed_rhs_data = pack_cache_.mutable_data<float>();
+  } else if (should_cache_pack_) {
+    if (lhs->is_weight() && !lhs_batched) {
+      cache_side = kCacheLhs;
+      pack_cache_.Resize(packed_lhs_size);
+      packed_lhs_data = pack_cache_.mutable_data<float>();
+    } else if (rhs->is_weight() && !rhs_batched) {
+      cache_side = kCacheRhs;
+      pack_cache_.Resize(packed_rhs_size);
+      packed_rhs_data = pack_cache_.mutable_data<float>();
+    }
+  }
+
+  for (index_t b = 0; b < batch; ++b) {
+    MatrixMap<const float>
+        lhs_matrix
+        (lhs_data + static_cast<index_t>(lhs_batched) * b * rows * depth,
+         lhs_major,
+         rows,
+         depth);
+    MatrixMap<const float>
+        rhs_matrix
+        (rhs_data + static_cast<index_t>(rhs_batched) * b * depth * cols,
+         rhs_major,
+         depth,
+         cols);
+    MatrixMap<float> output_matrix
+        (output_data + b * rows * cols, output_major, rows, cols);
+
+    // pack lhs
+    if (cached_ != kCacheLhs) {
+#pragma omp parallel for schedule(runtime)
+      for (index_t row_block_idx = 0; row_block_idx < row_block_count;
+           ++row_block_idx) {
+        const index_t
start_row = row_block_idx * row_block_size; + const index_t + row_block_len = std::min(row_block_size, rows - start_row); + float *packed_lhs_data_block = + packed_lhs_data + row_block_idx * row_block_size * depth_padded; + PackLhs(lhs_matrix.block(start_row, 0, row_block_len, depth), + packed_lhs_data_block); + } + if (cache_side == kCacheLhs) { + cached_ = kCacheLhs; + if (lhs->UnderlyingBuffer()->OnHost()) { + AdviseFree(reinterpret_cast(const_cast(lhs->data< + float>())), + lhs->raw_size()); + } + } + } + + // pack rhs + if (cached_ != kCacheRhs) { +#pragma omp parallel for schedule(runtime) + for (index_t col_block_idx = 0; col_block_idx < col_block_count; + ++col_block_idx) { + const index_t start_col = col_block_idx * col_block_size; + const index_t + col_block_len = std::min(col_block_size, cols - start_col); + float *packed_rhs_data_block = + packed_rhs_data + col_block_idx * col_block_size * depth_padded; + PackRhs(rhs_matrix.block(0, start_col, depth, col_block_len), + packed_rhs_data_block); + } + if (cache_side == kCacheRhs) { + cached_ = kCacheRhs; + if (rhs->UnderlyingBuffer()->OnHost()) { + AdviseFree(reinterpret_cast(const_cast(rhs->data< + float>())), + rhs->raw_size()); + } + } + } + + // multiply lhs and rhs +#pragma omp parallel for schedule(runtime) + for (index_t row_block_idx = 0; row_block_idx < row_block_count; + ++row_block_idx) { + const index_t start_row = row_block_idx * row_block_size; + const index_t row_block_len = std::min(row_block_size, rows - start_row); + const float *packed_lhs_data_block = + packed_lhs_data + row_block_idx * row_block_size * depth_padded; + + for (index_t col_block_idx = 0; col_block_idx < col_block_count; + ++col_block_idx) { + const index_t start_col = col_block_idx * col_block_size; + const index_t + col_block_len = std::min(col_block_size, cols - start_col); + const float *packed_rhs_data_block = + packed_rhs_data + col_block_idx * col_block_size * depth_padded; + float *packed_output_data_block = + packed_output_data + row_block_idx * row_block_size * cols_padded + + col_block_idx * col_block_size; + ComputeBlock(packed_lhs_data_block, + packed_rhs_data_block, + depth_padded, + packed_output_data_block); + MatrixMap output_block = output_matrix.block(start_row, + start_col, + row_block_len, + col_block_len); + UnpackOutput(packed_output_data_block, &output_block); + } // col_block_idx + } // row_block_idx + } // b + + return MaceStatus::MACE_SUCCESS; +} + +void Gemm::ComputeBlock(const float *packed_lhs_data, + const float *packed_rhs_data, + const index_t depth_padded, + float *packed_output_data) { + /* Ref: + for (index_t r = 0; r < block_size; ++r) { + for (index_t c = 0; c < block_size; ++c) { + float sum = 0; + for (index_t d = 0; d < depth; ++d) { + // (r, d) * (d, c) + sum += packed_lhs_data[d * r_block_size + r] + * packed_rhs_data[d * c_block_size + c]; + } + packed_output_data[r * c_block_size + c] = sum; + } + } + */ + const float *lhs_ptr = packed_lhs_data; + const float *rhs_ptr = packed_rhs_data; + + const index_t depth_block_count = depth_padded / 4; + +#ifdef __aarch64__ + // Register layout: (8x4) x (4,8) + // + // +--------+--------+ + // | v8 ... | v9 ... | + // Rhs +--------+--------+ + // | v10... | v11... | + // +--------+--------+ + // | v12... | v13... | + // +--------+--------+ + // | v14... | v15... | + // +--------+--------+ + // + // Lhs + // + // +----+----+----+----+ - - +--------+--------+ + // | v0 | v2 | v4 | v6 | | v16... | v17... | + // | . | | | | | v18... | v19... | + // | . | | | | | v20... 
| v21... | + // | . | | | | | v22... | v23... | + // +----+----|----+----+ +--------+--------+ + // | v1 | v3 | v5 | v7 | | v24... | v25... | + // | . | | | | | v26... | v27... | + // | . | | | | | v28... | v29... | + // | . | | | | | v30... | v31... | + // +----+----|----+----+ +--------+--------+ + // + // Accumulator + // + + if (depth_block_count > 0) { + index_t r_depth_block_count = depth_block_count; + // just make compiler happy + MACE_UNUSED(r_depth_block_count); + + asm volatile( + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + "dup v18.4s, wzr \n" + "dup v19.4s, wzr \n" + "dup v20.4s, wzr \n" + "dup v21.4s, wzr \n" + "dup v22.4s, wzr \n" + "dup v23.4s, wzr \n" + "dup v24.4s, wzr \n" + "dup v25.4s, wzr \n" + "dup v26.4s, wzr \n" + "dup v27.4s, wzr \n" + "dup v28.4s, wzr \n" + "dup v29.4s, wzr \n" + "dup v30.4s, wzr \n" + "dup v31.4s, wzr \n" + + // prelogue + "ld1 {v0.4s}, [%[lhs_ptr]], #16 \n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16 \n" + "ld1 {v2.4s}, [%[lhs_ptr]], #16 \n" + "ld1 {v3.4s}, [%[lhs_ptr]], #16 \n" + "ld1 {v4.4s}, [%[lhs_ptr]], #16 \n" + "ld1 {v5.4s}, [%[lhs_ptr]], #16 \n" + "ld1 {v6.4s}, [%[lhs_ptr]], #16 \n" + "ld1 {v7.4s}, [%[lhs_ptr]], #16 \n" + + "ld1 {v8.4s}, [%[rhs_ptr]], #16 \n" + "ld1 {v9.4s}, [%[rhs_ptr]], #16 \n" + "ld1 {v10.4s}, [%[rhs_ptr]], #16 \n" + "ld1 {v11.4s}, [%[rhs_ptr]], #16 \n" + "ld1 {v12.4s}, [%[rhs_ptr]], #16 \n" + "ld1 {v13.4s}, [%[rhs_ptr]], #16 \n" + "ld1 {v14.4s}, [%[rhs_ptr]], #16 \n" + "ld1 {v15.4s}, [%[rhs_ptr]], #16 \n" + + "subs %[r_depth_block_count], %[r_depth_block_count], #1 \n" + "beq 1f\n" + + "0: \n" + "fmla v16.4s, v8.4s, v0.s[0] \n" + "fmla v17.4s, v9.4s, v0.s[0] \n" + "fmla v18.4s, v8.4s, v0.s[1] \n" + "fmla v19.4s, v9.4s, v0.s[1] \n" + "fmla v20.4s, v8.4s, v0.s[2] \n" + "fmla v21.4s, v9.4s, v0.s[2] \n" + "fmla v22.4s, v8.4s, v0.s[3] \n" + "fmla v23.4s, v9.4s, v0.s[3] \n" + + "ld1 {v0.4s}, [%[lhs_ptr]], #16 \n" + + "fmla v24.4s, v8.4s, v1.s[0] \n" + "fmla v25.4s, v9.4s, v1.s[0] \n" + "fmla v26.4s, v8.4s, v1.s[1] \n" + "fmla v27.4s, v9.4s, v1.s[1] \n" + "fmla v28.4s, v8.4s, v1.s[2] \n" + "fmla v29.4s, v9.4s, v1.s[2] \n" + "fmla v30.4s, v8.4s, v1.s[3] \n" + "fmla v31.4s, v9.4s, v1.s[3] \n" + + "ld1 {v1.4s}, [%[lhs_ptr]], #16 \n" + "ld1 {v8.4s}, [%[rhs_ptr]], #16 \n" + "ld1 {v9.4s}, [%[rhs_ptr]], #16 \n" + + "fmla v16.4s, v10.4s, v2.s[0] \n" + "fmla v17.4s, v11.4s, v2.s[0] \n" + "fmla v18.4s, v10.4s, v2.s[1] \n" + "fmla v19.4s, v11.4s, v2.s[1] \n" + "fmla v20.4s, v10.4s, v2.s[2] \n" + "fmla v21.4s, v11.4s, v2.s[2] \n" + "fmla v22.4s, v10.4s, v2.s[3] \n" + "fmla v23.4s, v11.4s, v2.s[3] \n" + + "ld1 {v2.4s}, [%[lhs_ptr]], #16 \n" + + "fmla v24.4s, v10.4s, v3.s[0] \n" + "fmla v25.4s, v11.4s, v3.s[0] \n" + "fmla v26.4s, v10.4s, v3.s[1] \n" + "fmla v27.4s, v11.4s, v3.s[1] \n" + "fmla v28.4s, v10.4s, v3.s[2] \n" + "fmla v29.4s, v11.4s, v3.s[2] \n" + "fmla v30.4s, v10.4s, v3.s[3] \n" + "fmla v31.4s, v11.4s, v3.s[3] \n" + + "ld1 {v3.4s}, [%[lhs_ptr]], #16 \n" + "ld1 {v10.4s}, [%[rhs_ptr]], #16 \n" + "ld1 {v11.4s}, [%[rhs_ptr]], #16 \n" + + "fmla v16.4s, v12.4s, v4.s[0] \n" + "fmla v17.4s, v13.4s, v4.s[0] \n" + "fmla v18.4s, v12.4s, v4.s[1] \n" + "fmla v19.4s, v13.4s, v4.s[1] \n" + "fmla v20.4s, v12.4s, v4.s[2] \n" + "fmla v21.4s, v13.4s, v4.s[2] \n" + "fmla v22.4s, v12.4s, v4.s[3] \n" + "fmla v23.4s, v13.4s, v4.s[3] \n" + + "ld1 {v4.4s}, [%[lhs_ptr]], #16 \n" + + "fmla v24.4s, v12.4s, v5.s[0] \n" + "fmla v25.4s, v13.4s, v5.s[0] \n" + "fmla v26.4s, v12.4s, v5.s[1] \n" + "fmla v27.4s, v13.4s, v5.s[1] \n" + "fmla v28.4s, v12.4s, v5.s[2] \n" + 
"fmla v29.4s, v13.4s, v5.s[2] \n" + "fmla v30.4s, v12.4s, v5.s[3] \n" + "fmla v31.4s, v13.4s, v5.s[3] \n" + + "ld1 {v5.4s}, [%[lhs_ptr]], #16 \n" + "ld1 {v12.4s}, [%[rhs_ptr]], #16 \n" + "ld1 {v13.4s}, [%[rhs_ptr]], #16 \n" + + "fmla v16.4s, v14.4s, v6.s[0] \n" + "fmla v17.4s, v15.4s, v6.s[0] \n" + "fmla v18.4s, v14.4s, v6.s[1] \n" + "fmla v19.4s, v15.4s, v6.s[1] \n" + "fmla v20.4s, v14.4s, v6.s[2] \n" + "fmla v21.4s, v15.4s, v6.s[2] \n" + "fmla v22.4s, v14.4s, v6.s[3] \n" + "fmla v23.4s, v15.4s, v6.s[3] \n" + + "ld1 {v6.4s}, [%[lhs_ptr]], #16 \n" + + "subs %[r_depth_block_count], %[r_depth_block_count], #1 \n" + + "fmla v24.4s, v14.4s, v7.s[0] \n" + "fmla v25.4s, v15.4s, v7.s[0] \n" + "fmla v26.4s, v14.4s, v7.s[1] \n" + "fmla v27.4s, v15.4s, v7.s[1] \n" + "fmla v28.4s, v14.4s, v7.s[2] \n" + "fmla v29.4s, v15.4s, v7.s[2] \n" + "fmla v30.4s, v14.4s, v7.s[3] \n" + "fmla v31.4s, v15.4s, v7.s[3] \n" + + "ld1 {v7.4s}, [%[lhs_ptr]], #16 \n" + "ld1 {v14.4s}, [%[rhs_ptr]], #16 \n" + "ld1 {v15.4s}, [%[rhs_ptr]], #16 \n" + + "bne 0b \n" + + // prologue + "1:\n" + "fmla v16.4s, v8.4s, v0.s[0] \n" + "fmla v17.4s, v9.4s, v0.s[0] \n" + "fmla v18.4s, v8.4s, v0.s[1] \n" + "fmla v19.4s, v9.4s, v0.s[1] \n" + "fmla v20.4s, v8.4s, v0.s[2] \n" + "fmla v21.4s, v9.4s, v0.s[2] \n" + "fmla v22.4s, v8.4s, v0.s[3] \n" + "fmla v23.4s, v9.4s, v0.s[3] \n" + + "fmla v24.4s, v8.4s, v1.s[0] \n" + "fmla v25.4s, v9.4s, v1.s[0] \n" + "fmla v26.4s, v8.4s, v1.s[1] \n" + "fmla v27.4s, v9.4s, v1.s[1] \n" + "fmla v28.4s, v8.4s, v1.s[2] \n" + "fmla v29.4s, v9.4s, v1.s[2] \n" + "fmla v30.4s, v8.4s, v1.s[3] \n" + "fmla v31.4s, v9.4s, v1.s[3] \n" + + "fmla v16.4s, v10.4s, v2.s[0] \n" + "fmla v17.4s, v11.4s, v2.s[0] \n" + "fmla v18.4s, v10.4s, v2.s[1] \n" + "fmla v19.4s, v11.4s, v2.s[1] \n" + "fmla v20.4s, v10.4s, v2.s[2] \n" + "fmla v21.4s, v11.4s, v2.s[2] \n" + "fmla v22.4s, v10.4s, v2.s[3] \n" + "fmla v23.4s, v11.4s, v2.s[3] \n" + + "fmla v24.4s, v10.4s, v3.s[0] \n" + "fmla v25.4s, v11.4s, v3.s[0] \n" + "fmla v26.4s, v10.4s, v3.s[1] \n" + "fmla v27.4s, v11.4s, v3.s[1] \n" + "fmla v28.4s, v10.4s, v3.s[2] \n" + "fmla v29.4s, v11.4s, v3.s[2] \n" + "fmla v30.4s, v10.4s, v3.s[3] \n" + "fmla v31.4s, v11.4s, v3.s[3] \n" + + "fmla v16.4s, v12.4s, v4.s[0] \n" + "fmla v17.4s, v13.4s, v4.s[0] \n" + "fmla v18.4s, v12.4s, v4.s[1] \n" + "fmla v19.4s, v13.4s, v4.s[1] \n" + "fmla v20.4s, v12.4s, v4.s[2] \n" + "fmla v21.4s, v13.4s, v4.s[2] \n" + "fmla v22.4s, v12.4s, v4.s[3] \n" + "fmla v23.4s, v13.4s, v4.s[3] \n" + + "fmla v24.4s, v12.4s, v5.s[0] \n" + "fmla v25.4s, v13.4s, v5.s[0] \n" + "fmla v26.4s, v12.4s, v5.s[1] \n" + "fmla v27.4s, v13.4s, v5.s[1] \n" + "fmla v28.4s, v12.4s, v5.s[2] \n" + "fmla v29.4s, v13.4s, v5.s[2] \n" + "fmla v30.4s, v12.4s, v5.s[3] \n" + "fmla v31.4s, v13.4s, v5.s[3] \n" + + "fmla v16.4s, v14.4s, v6.s[0] \n" + "fmla v17.4s, v15.4s, v6.s[0] \n" + "fmla v18.4s, v14.4s, v6.s[1] \n" + "fmla v19.4s, v15.4s, v6.s[1] \n" + "fmla v20.4s, v14.4s, v6.s[2] \n" + "fmla v21.4s, v15.4s, v6.s[2] \n" + "fmla v22.4s, v14.4s, v6.s[3] \n" + "fmla v23.4s, v15.4s, v6.s[3] \n" + + "fmla v24.4s, v14.4s, v7.s[0] \n" + "fmla v25.4s, v15.4s, v7.s[0] \n" + "fmla v26.4s, v14.4s, v7.s[1] \n" + "fmla v27.4s, v15.4s, v7.s[1] \n" + "fmla v28.4s, v14.4s, v7.s[2] \n" + "fmla v29.4s, v15.4s, v7.s[2] \n" + "fmla v30.4s, v14.4s, v7.s[3] \n" + "fmla v31.4s, v15.4s, v7.s[3] \n" + + "st1 {v16.4s}, [%[packed_output_data]], #16 \n" + "st1 {v17.4s}, [%[packed_output_data]], #16 \n" + "st1 {v18.4s}, [%[packed_output_data]], #16 \n" + "st1 {v19.4s}, 
[%[packed_output_data]], #16 \n" + "st1 {v20.4s}, [%[packed_output_data]], #16 \n" + "st1 {v21.4s}, [%[packed_output_data]], #16 \n" + "st1 {v22.4s}, [%[packed_output_data]], #16 \n" + "st1 {v23.4s}, [%[packed_output_data]], #16 \n" + "st1 {v24.4s}, [%[packed_output_data]], #16 \n" + "st1 {v25.4s}, [%[packed_output_data]], #16 \n" + "st1 {v26.4s}, [%[packed_output_data]], #16 \n" + "st1 {v27.4s}, [%[packed_output_data]], #16 \n" + "st1 {v28.4s}, [%[packed_output_data]], #16 \n" + "st1 {v29.4s}, [%[packed_output_data]], #16 \n" + "st1 {v30.4s}, [%[packed_output_data]], #16 \n" + "st1 {v31.4s}, [%[packed_output_data]], #16 \n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), + [rhs_ptr] "+r"(rhs_ptr), + [packed_output_data] "+r"(packed_output_data), + [r_depth_block_count] "+r"(r_depth_block_count) + : // inputs + : // clabbers + "cc", "memory", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } +#else // armeabi-v7a + // Register layout: (4x4) x (4,8) + // + // +--------+--------+ + // | q4 ... | q5 ... | + // Rhs +--------+--------+ + // | q6 ... | q7 ... | + // +--------+--------+ + // | q4 ... | q5 ... | + // +--------+--------+ + // | q6 ... | q7 ... | + // +--------+--------+ + // + // Lhs + // + // +----+----+----+----+ - - +--------+--------+ + // | q0 | q1 | q2 | q3 | | q8... | q9... | + // | . | | | | | q10... | q11... | + // | . | | | | | q12... | q13... | + // | . | | | | | q14... | q15... | + // +----+----+----+----+ +--------+--------+ + // + // Accumulator + // + + if (depth_block_count > 0) { + index_t r_depth_block_count = depth_block_count; + // just make compiler happy + MACE_UNUSED(r_depth_block_count); + + asm volatile( + "mov r0, #0\n" + "vdup.f32 q8, r0 \n" + "vdup.f32 q9, r0 \n" + "vdup.f32 q10, r0 \n" + "vdup.f32 q11, r0 \n" + "vdup.f32 q12, r0 \n" + "vdup.f32 q13, r0 \n" + "vdup.f32 q14, r0 \n" + "vdup.f32 q15, r0 \n" + + // prelogue + "vld1.f32 {d0-d1}, [%[lhs_ptr]]! \n" + "vld1.f32 {d2-d3}, [%[lhs_ptr]]! \n" + "vld1.f32 {d4-d5}, [%[lhs_ptr]]! \n" + "vld1.f32 {d6-d7}, [%[lhs_ptr]]! \n" + + "vld1.f32 {d8-d9}, [%[rhs_ptr]]! \n" + "vld1.f32 {d10-d11}, [%[rhs_ptr]]! \n" + "vld1.f32 {d12-d13}, [%[rhs_ptr]]! \n" + "vld1.f32 {d14-d15}, [%[rhs_ptr]]! \n" + + "subs %[r_depth_block_count], %[r_depth_block_count], #1 \n" + "beq 1f\n" + + "0: \n" + + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q5, d0[0] \n" + "vmla.f32 q10, q4, d0[1] \n" + "vmla.f32 q11, q5, d0[1] \n" + "vmla.f32 q12, q4, d1[0] \n" + "vmla.f32 q13, q5, d1[0] \n" + "vmla.f32 q14, q4, d1[1] \n" + "vmla.f32 q15, q5, d1[1] \n" + + "vld1.f32 {d0-d1}, [%[lhs_ptr]]! \n" + "vld1.f32 {d8-d9}, [%[rhs_ptr]]! \n" + "vld1.f32 {d10-d11}, [%[rhs_ptr]]! \n" + + "vmla.f32 q8, q6, d2[0] \n" + "vmla.f32 q9, q7, d2[0] \n" + "vmla.f32 q10, q6, d2[1] \n" + "vmla.f32 q11, q7, d2[1] \n" + "vmla.f32 q12, q6, d3[0] \n" + "vmla.f32 q13, q7, d3[0] \n" + "vmla.f32 q14, q6, d3[1] \n" + "vmla.f32 q15, q7, d3[1] \n" + + "vld1.f32 {d2-d3}, [%[lhs_ptr]]! \n" + "vld1.f32 {d12-d13}, [%[rhs_ptr]]! \n" + "vld1.f32 {d14-d15}, [%[rhs_ptr]]! \n" + + "vmla.f32 q8, q4, d4[0] \n" + "vmla.f32 q9, q5, d4[0] \n" + "vmla.f32 q10, q4, d4[1] \n" + "vmla.f32 q11, q5, d4[1] \n" + "vmla.f32 q12, q4, d5[0] \n" + "vmla.f32 q13, q5, d5[0] \n" + "vmla.f32 q14, q4, d5[1] \n" + "vmla.f32 q15, q5, d5[1] \n" + + "vld1.f32 {d4-d5}, [%[lhs_ptr]]! \n" + "vld1.f32 {d8-d9}, [%[rhs_ptr]]! 
\n" + "vld1.f32 {d10-d11}, [%[rhs_ptr]]! \n" + + "subs %[r_depth_block_count], %[r_depth_block_count], #1 \n" + + "vmla.f32 q8, q6, d6[0] \n" + "vmla.f32 q9, q7, d6[0] \n" + "vmla.f32 q10, q6, d6[1] \n" + "vmla.f32 q11, q7, d6[1] \n" + "vmla.f32 q12, q6, d7[0] \n" + "vmla.f32 q13, q7, d7[0] \n" + "vmla.f32 q14, q6, d7[1] \n" + "vmla.f32 q15, q7, d7[1] \n" + + "vld1.f32 {d6-d7}, [%[lhs_ptr]]! \n" + "vld1.f32 {d12-d13}, [%[rhs_ptr]]! \n" + "vld1.f32 {d14-d15}, [%[rhs_ptr]]! \n" + + "bne 0b \n" + + // prologue + "1:\n" + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q5, d0[0] \n" + "vmla.f32 q10, q4, d0[1] \n" + "vmla.f32 q11, q5, d0[1] \n" + "vmla.f32 q12, q4, d1[0] \n" + "vmla.f32 q13, q5, d1[0] \n" + "vmla.f32 q14, q4, d1[1] \n" + "vmla.f32 q15, q5, d1[1] \n" + + "vld1.f32 {d8-d9}, [%[rhs_ptr]]! \n" + "vld1.f32 {d10-d11}, [%[rhs_ptr]]! \n" + + "vmla.f32 q8, q6, d2[0] \n" + "vmla.f32 q9, q7, d2[0] \n" + "vmla.f32 q10, q6, d2[1] \n" + "vmla.f32 q11, q7, d2[1] \n" + "vmla.f32 q12, q6, d3[0] \n" + "vmla.f32 q13, q7, d3[0] \n" + "vmla.f32 q14, q6, d3[1] \n" + "vmla.f32 q15, q7, d3[1] \n" + + "vld1.f32 {d12-d13}, [%[rhs_ptr]]! \n" + "vld1.f32 {d14-d15}, [%[rhs_ptr]]! \n" + + "vmla.f32 q8, q4, d4[0] \n" + "vmla.f32 q9, q5, d4[0] \n" + "vmla.f32 q10, q4, d4[1] \n" + "vmla.f32 q11, q5, d4[1] \n" + "vmla.f32 q12, q4, d5[0] \n" + "vmla.f32 q13, q5, d5[0] \n" + "vmla.f32 q14, q4, d5[1] \n" + "vmla.f32 q15, q5, d5[1] \n" + + "vmla.f32 q8, q6, d6[0] \n" + "vmla.f32 q9, q7, d6[0] \n" + "vmla.f32 q10, q6, d6[1] \n" + "vmla.f32 q11, q7, d6[1] \n" + "vmla.f32 q12, q6, d7[0] \n" + "vmla.f32 q13, q7, d7[0] \n" + "vmla.f32 q14, q6, d7[1] \n" + "vmla.f32 q15, q7, d7[1] \n" + + "vst1.f32 {d16-d17}, [%[packed_output_data]]! \n" + "vst1.f32 {d18-d19}, [%[packed_output_data]]! \n" + "vst1.f32 {d20-d21}, [%[packed_output_data]]! \n" + "vst1.f32 {d22-d23}, [%[packed_output_data]]! \n" + "vst1.f32 {d24-d25}, [%[packed_output_data]]! \n" + "vst1.f32 {d26-d27}, [%[packed_output_data]]! \n" + "vst1.f32 {d28-d29}, [%[packed_output_data]]! \n" + "vst1.f32 {d30-d31}, [%[packed_output_data]]! 
\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), + [rhs_ptr] "+r"(rhs_ptr), + [packed_output_data] "+r"(packed_output_data), + [r_depth_block_count] "+r"(r_depth_block_count) + : // inputs + : // clabbers + "cc", "memory", "r0", + "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); + } +#endif +} + +void Gemm::PackLhs(const MatrixMap &lhs, + float *packed_lhs) { +#ifdef __aarch64__ + Pack<8, 4>(lhs, ColMajor, packed_lhs); +#else + Pack<4, 4>(lhs, ColMajor, packed_lhs); +#endif +} + +void Gemm::PackRhs(const MatrixMap &rhs, + float *packed_rhs) { + Pack<8, 4>(rhs, RowMajor, packed_rhs); +} + +void Gemm::UnpackOutput(const float *packed_output, MatrixMap *output) { +#ifdef __aarch64__ + Unpack<8, 8>(packed_output, output); +#else + Unpack<4, 8>(packed_output, output); +#endif +} + +template<> +void Gemm::Pack<4, 4>(const MatrixMap &matrix, + MatrixMajor dst_major, + float *packed_matrix) { + const index_t rows = matrix.rows(); + const index_t cols = matrix.cols(); + + // use the same terminology as GemmLowp: + // depth is depth, width is the opposite dim other than depth + // lhs + index_t width = rows; + index_t depth = cols; + index_t width_stride = matrix.rows_stride(); + index_t depth_stride = matrix.cols_stride(); + if (dst_major == RowMajor) { + // rhs + std::swap(width, depth); + std::swap(width_stride, depth_stride); + } + const float *data = matrix.data(); + float *packed_ptr = packed_matrix; + + const index_t block_size = 4; + const index_t depth_padded = RoundUp(depth, static_cast(4)); + + if (depth_padded > depth) { + memset(packed_ptr + depth * block_size, + 0, + sizeof(float) * (depth_padded - depth) * block_size); + } + + if (dst_major == matrix.matrix_major()) { + if (width < block_size) { + const index_t width_remain = block_size - width; + for (index_t d = 0; d < depth; ++d) { + memcpy(packed_ptr, data, sizeof(float) * width); + memset(packed_ptr + width, 0, sizeof(float) * width_remain); + data += depth_stride; + packed_ptr += block_size; + } + } else { + for (index_t d = 0; d < depth; ++d) { + float32x4_t vi = vld1q_f32(data); + vst1q_f32(packed_ptr, vi); + data += depth_stride; + packed_ptr += block_size; + } + } + } else { + if (width < block_size) { + const index_t width_remain = block_size - width; + for (index_t d = 0; d < depth; ++d) { + for (index_t w = 0; w < width; ++w) { + packed_ptr[w] = data[w * width_stride + d]; + } // w + memset(packed_ptr + width, 0, sizeof(float) * width_remain); + packed_ptr += block_size; + } // d + } else { + const float *data0 = data; + const float *data1 = data + width_stride; + const float *data2 = data1 + width_stride; + const float *data3 = data2 + width_stride; + + const index_t depth_block = depth / 4; + const index_t depth_remain = depth - depth_block * 4; + for (index_t depth_block_idx = 0; depth_block_idx < depth_block; + ++depth_block_idx) { + float32x4_t v0 = vld1q_f32(data0); + float32x4_t v1 = vld1q_f32(data1); + float32x4_t v2 = vld1q_f32(data2); + float32x4_t v3 = vld1q_f32(data3); + float32x4x2_t v02_intertwined = vzipq_f32(v0, v2); + float32x4x2_t v13_intertwined = vzipq_f32(v1, v3); + float32x4x2_t v0123_intertwined = + vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]); + float32x4x2_t v0123n_intertwined = + vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]); + + vst1q_f32(packed_ptr, v0123_intertwined.val[0]); + packed_ptr += 4; + + vst1q_f32(packed_ptr, v0123_intertwined.val[1]); + packed_ptr += 4; + + vst1q_f32(packed_ptr, v0123n_intertwined.val[0]); 
+ packed_ptr += 4; + + vst1q_f32(packed_ptr, v0123n_intertwined.val[1]); + packed_ptr += 4; + + data0 += 4; + data1 += 4; + data2 += 4; + data3 += 4; + } + for (index_t d = 0; d < depth_remain; ++d) { + float32x4_t vi = {*data0, *data1, *data2, *data3}; + vst1q_f32(packed_ptr, vi); + packed_ptr += 4; + + ++data0; + ++data1; + ++data2; + ++data3; + } // d + } + } +} + +template<> +void Gemm::Pack<8, 4>(const MatrixMap &matrix, + MatrixMajor dst_major, + float *packed_matrix) { + const index_t rows = matrix.rows(); + const index_t cols = matrix.cols(); + + // use the same terminology as GemmLowp: + // depth is depth, width is the opposite dim other than depth + // lhs + index_t width = rows; + index_t depth = cols; + index_t width_stride = matrix.rows_stride(); + index_t depth_stride = matrix.cols_stride(); + if (dst_major == RowMajor) { + // rhs + std::swap(width, depth); + std::swap(width_stride, depth_stride); + } + const float *data = matrix.data(); + float *packed_ptr = packed_matrix; + + const index_t block_size = 8; + const index_t depth_padded = RoundUp(depth, static_cast(4)); + + if (depth_padded > depth) { + memset(packed_ptr + depth * block_size, + 0, + sizeof(float) * (depth_padded - depth) * block_size); + } + + if (dst_major == matrix.matrix_major()) { + if (width < block_size) { + const index_t width_remain = block_size - width; + for (index_t d = 0; d < depth; ++d) { + memcpy(packed_ptr, data, sizeof(float) * width); + memset(packed_ptr + width, 0, sizeof(float) * width_remain); + data += depth_stride; + packed_ptr += block_size; + } + } else { + for (index_t d = 0; d < depth; ++d) { + float32x4_t vi = vld1q_f32(data); + vst1q_f32(packed_ptr, vi); + float32x4_t vin = vld1q_f32(data + 4); + vst1q_f32(packed_ptr + 4, vin); + data += depth_stride; + packed_ptr += block_size; + } + } + } else { + if (width < block_size) { + const index_t width_remain = block_size - width; + for (index_t d = 0; d < depth; ++d) { + for (index_t w = 0; w < width; ++w) { + packed_ptr[w] = data[w * width_stride + d]; + } // w + memset(packed_ptr + width, 0, sizeof(float) * width_remain); + packed_ptr += block_size; + } // d + } else { + const float *data0 = data; + const float *data1 = data + width_stride; + const float *data2 = data1 + width_stride; + const float *data3 = data2 + width_stride; + const float *data4 = data3 + width_stride; + const float *data5 = data4 + width_stride; + const float *data6 = data5 + width_stride; + const float *data7 = data6 + width_stride; + + const index_t depth_block = depth / 4; + const index_t depth_remain = depth - depth_block * 4; + for (index_t depth_block_idx = 0; depth_block_idx < depth_block; + ++depth_block_idx) { + float32x4_t v0 = vld1q_f32(data0); + float32x4_t v1 = vld1q_f32(data1); + float32x4_t v2 = vld1q_f32(data2); + float32x4_t v3 = vld1q_f32(data3); + float32x4x2_t v02_intertwined = vzipq_f32(v0, v2); + float32x4x2_t v13_intertwined = vzipq_f32(v1, v3); + float32x4x2_t v0123_intertwined = + vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]); + float32x4x2_t v0123n_intertwined = + vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]); + + float32x4_t v4 = vld1q_f32(data4); + float32x4_t v5 = vld1q_f32(data5); + float32x4_t v6 = vld1q_f32(data6); + float32x4_t v7 = vld1q_f32(data7); + float32x4x2_t v46_intertwined = vzipq_f32(v4, v6); + float32x4x2_t v57_intertwined = vzipq_f32(v5, v7); + float32x4x2_t v4567_intertwined = + vzipq_f32(v46_intertwined.val[0], v57_intertwined.val[0]); + float32x4x2_t v4567n_intertwined = + 
vzipq_f32(v46_intertwined.val[1], v57_intertwined.val[1]); + + vst1q_f32(packed_ptr, v0123_intertwined.val[0]); + packed_ptr += 4; + + vst1q_f32(packed_ptr, v4567_intertwined.val[0]); + packed_ptr += 4; + + vst1q_f32(packed_ptr, v0123_intertwined.val[1]); + packed_ptr += 4; + + vst1q_f32(packed_ptr, v4567_intertwined.val[1]); + packed_ptr += 4; + + vst1q_f32(packed_ptr, v0123n_intertwined.val[0]); + packed_ptr += 4; + + vst1q_f32(packed_ptr, v4567n_intertwined.val[0]); + packed_ptr += 4; + + vst1q_f32(packed_ptr, v0123n_intertwined.val[1]); + packed_ptr += 4; + + vst1q_f32(packed_ptr, v4567n_intertwined.val[1]); + packed_ptr += 4; + + data0 += 4; + data1 += 4; + data2 += 4; + data3 += 4; + data4 += 4; + data5 += 4; + data6 += 4; + data7 += 4; + } + for (index_t d = 0; d < depth_remain; ++d) { + float32x4_t vi = {*data0, *data1, *data2, *data3}; + vst1q_f32(packed_ptr, vi); + packed_ptr += 4; + + float32x4_t vin = {*data4, *data5, *data6, *data7}; + vst1q_f32(packed_ptr, vin); + packed_ptr += 4; + + ++data0; + ++data1; + ++data2; + ++data3; + ++data4; + ++data5; + ++data6; + ++data7; + } // d + } + } +} + +template<> +void Gemm::Unpack<4, 8>(const float *packed_output, MatrixMap *output) { + const index_t rows = output->rows(); + const index_t cols = output->cols(); + index_t row_stride = output->rows_stride(); + index_t col_stride = output->cols_stride(); + + float *output_ptr = output->data(); + const float *packed_ptr = packed_output; + + const index_t block_size = 8; + + // packed_output always has row-major + if (output->matrix_major() == RowMajor) { + if (cols < block_size) { + for (index_t r = 0; r < rows; ++r) { + memcpy(output_ptr, packed_ptr, sizeof(float) * cols); + output_ptr += row_stride; + packed_ptr += block_size; + } + } else { + for (index_t r = 0; r < rows; ++r) { + float32x4_t vi = vld1q_f32(packed_ptr); + vst1q_f32(output_ptr, vi); + float32x4_t vin = vld1q_f32(packed_ptr + 4); + vst1q_f32(output_ptr + 4, vin); + + output_ptr += row_stride; + packed_ptr += block_size; + } + } + } else { + // ColMajor + if (rows < block_size) { + for (index_t c = 0; c < cols; ++c) { + for (index_t r = 0; r < rows; ++r) { + output_ptr[c * col_stride + r] = packed_ptr[r * block_size + c]; + } // r + } // c + } else { + const float *data0 = packed_ptr; + const float *data1 = data0 + block_size; + const float *data2 = data1 + block_size; + const float *data3 = data2 + block_size; + + index_t col_block = cols / 4; + index_t col_remain = cols - col_block * 4; + for (index_t col_block_idx = 0; col_block_idx < col_block; + ++col_block_idx) { + float32x4_t v0 = vld1q_f32(data0); + float32x4_t v1 = vld1q_f32(data1); + float32x4_t v2 = vld1q_f32(data2); + float32x4_t v3 = vld1q_f32(data3); + float32x4x2_t v02_intertwined = vzipq_f32(v0, v2); + float32x4x2_t v13_intertwined = vzipq_f32(v1, v3); + float32x4x2_t v0123_intertwined = + vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]); + float32x4x2_t v0123n_intertwined = + vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]); + + vst1q_f32(output_ptr, v0123_intertwined.val[0]); + output_ptr += col_stride; + + vst1q_f32(output_ptr, v0123_intertwined.val[1]); + output_ptr += col_stride; + + vst1q_f32(output_ptr, v0123n_intertwined.val[0]); + output_ptr += col_stride; + + vst1q_f32(output_ptr, v0123n_intertwined.val[1]); + output_ptr += col_stride; + + data0 += 4; + data1 += 4; + data2 += 4; + data3 += 4; + } + for (index_t c = 0; c < col_remain; ++c) { + float32x4_t vi = {*data0, *data1, *data2, *data3}; + vst1q_f32(output_ptr, vi); + 
output_ptr += col_stride; + + ++data0; + ++data1; + ++data2; + ++data3; + } // d + } + } +} + +template<> +void Gemm::Unpack<8, 8>(const float *packed_output, MatrixMap *output) { + const index_t rows = output->rows(); + const index_t cols = output->cols(); + index_t row_stride = output->rows_stride(); + index_t col_stride = output->cols_stride(); + + float *output_ptr = output->data(); + const float *packed_ptr = packed_output; + + const index_t block_size = 8; + + // packed_output always has row-major + if (output->matrix_major() == RowMajor) { + if (cols < block_size) { + for (index_t r = 0; r < rows; ++r) { + memcpy(output_ptr, packed_ptr, sizeof(float) * cols); + output_ptr += row_stride; + packed_ptr += block_size; + } + } else { + for (index_t r = 0; r < rows; ++r) { + float32x4_t vi = vld1q_f32(packed_ptr); + vst1q_f32(output_ptr, vi); + float32x4_t vin = vld1q_f32(packed_ptr + 4); + vst1q_f32(output_ptr + 4, vin); + + output_ptr += row_stride; + packed_ptr += block_size; + } + } + } else { + // ColMajor + if (rows < block_size) { + for (index_t c = 0; c < cols; ++c) { + for (index_t r = 0; r < rows; ++r) { + output_ptr[c * col_stride + r] = packed_ptr[r * block_size + c]; + } // r + } // c + } else { + const float *data0 = packed_ptr; + const float *data1 = data0 + block_size; + const float *data2 = data1 + block_size; + const float *data3 = data2 + block_size; + const float *data4 = data3 + block_size; + const float *data5 = data4 + block_size; + const float *data6 = data5 + block_size; + const float *data7 = data6 + block_size; + + index_t col_block = cols / 4; + index_t col_remain = cols - col_block * 4; + for (index_t col_block_idx = 0; col_block_idx < col_block; + ++col_block_idx) { + float32x4_t v0 = vld1q_f32(data0); + float32x4_t v1 = vld1q_f32(data1); + float32x4_t v2 = vld1q_f32(data2); + float32x4_t v3 = vld1q_f32(data3); + float32x4x2_t v02_intertwined = vzipq_f32(v0, v2); + float32x4x2_t v13_intertwined = vzipq_f32(v1, v3); + float32x4x2_t v0123_intertwined = + vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]); + float32x4x2_t v0123n_intertwined = + vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]); + + float32x4_t v4 = vld1q_f32(data4); + float32x4_t v5 = vld1q_f32(data5); + float32x4_t v6 = vld1q_f32(data6); + float32x4_t v7 = vld1q_f32(data7); + float32x4x2_t v46_intertwined = vzipq_f32(v4, v6); + float32x4x2_t v57_intertwined = vzipq_f32(v5, v7); + float32x4x2_t v4567_intertwined = + vzipq_f32(v46_intertwined.val[0], v57_intertwined.val[0]); + float32x4x2_t v4567n_intertwined = + vzipq_f32(v46_intertwined.val[1], v57_intertwined.val[1]); + + vst1q_f32(output_ptr, v0123_intertwined.val[0]); + vst1q_f32(output_ptr + 4, v4567_intertwined.val[0]); + output_ptr += col_stride; + + vst1q_f32(output_ptr, v0123_intertwined.val[1]); + vst1q_f32(output_ptr + 4, v4567_intertwined.val[1]); + output_ptr += col_stride; + + vst1q_f32(output_ptr, v0123n_intertwined.val[0]); + vst1q_f32(output_ptr + 4, v4567n_intertwined.val[0]); + output_ptr += col_stride; + + vst1q_f32(output_ptr, v0123n_intertwined.val[1]); + vst1q_f32(output_ptr + 4, v4567n_intertwined.val[1]); + output_ptr += col_stride; + + data0 += 4; + data1 += 4; + data2 += 4; + data3 += 4; + data4 += 4; + data5 += 4; + data6 += 4; + data7 += 4; + } + for (index_t c = 0; c < col_remain; ++c) { + float32x4_t vi = {*data0, *data1, *data2, *data3}; + vst1q_f32(output_ptr, vi); + float32x4_t vin = {*data4, *data5, *data6, *data7}; + vst1q_f32(output_ptr + 4, vin); + output_ptr += col_stride; + + ++data0; + 
++data1; + ++data2; + ++data3; + ++data4; + ++data5; + ++data6; + ++data7; + } // d + } + } +} + +MaceStatus Gemm::Compute(const OpContext *context, + const Tensor *lhs, + const Tensor *rhs, + const index_t batch, + const index_t lhs_rows, + const index_t lhs_cols, + const index_t rhs_rows, + const index_t rhs_cols, + const bool transpose_lhs, + const bool transpose_rhs, + const bool transpose_out, + const bool lhs_batched, + const bool rhs_batched, + Tensor *output) { + index_t rows = transpose_lhs ? lhs_cols : lhs_rows; + index_t depth = transpose_lhs ? lhs_rows : lhs_cols; + index_t cols = transpose_rhs ? rhs_rows : rhs_cols; + index_t depth2 = transpose_rhs ? rhs_cols : rhs_rows; + MACE_CHECK(depth == depth2, + "Matrices that multiply have inconsistent depth dim: ", + depth, + " vs. ", + depth2); + + return Compute(context, + lhs, + rhs, + batch, + rows, + cols, + depth, + transpose_lhs ? ColMajor : RowMajor, + transpose_rhs ? ColMajor : RowMajor, + transpose_out ? ColMajor : RowMajor, + lhs_batched, + rhs_batched, + output); +} + +} // namespace fp32 +} // namespace arm +} // namespace ops +} // namespace mace diff --git a/mace/ops/arm/fp32/gemm.h b/mace/ops/arm/fp32/gemm.h new file mode 100644 index 00000000..f4cfc42b --- /dev/null +++ b/mace/ops/arm/fp32/gemm.h @@ -0,0 +1,156 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_ARM_FP32_GEMM_H_ +#define MACE_OPS_ARM_FP32_GEMM_H_ + +#include "mace/public/mace.h" +#include "mace/core/tensor.h" +#include "mace/core/op_context.h" +#include "mace/ops/common/matrix.h" + +// This implements matrix-matrix multiplication. 
+// In the case of matrix-vector multiplication, use gemv.h/gemv.cc instead + +namespace mace { +namespace ops { +namespace arm { +namespace fp32 { + +class Gemm { + public: + explicit Gemm(const bool should_cache_pack) + : tmp_scratch_buffer_(GetCPUAllocator()), + pack_cache_(GetCPUAllocator()), + should_cache_pack_(should_cache_pack), + cached_(0) {} + Gemm() : Gemm(false) {} + ~Gemm() {} + + MaceStatus Compute( + const OpContext *context, + const Tensor *lhs, + const Tensor *rhs, + const index_t batch, + const index_t rows, + const index_t cols, + const index_t depth, + const MatrixMajor lhs_major, + const MatrixMajor rhs_major, + const MatrixMajor output_major, + const bool lhs_batched, + const bool rhs_batched, + Tensor *output); + + // Original matrix before transpose has row-major + MaceStatus Compute( + const OpContext *context, + const Tensor *lhs, + const Tensor *rhs, + const index_t batch, + const index_t lhs_rows, + const index_t lhs_cols, + const index_t rhs_rows, + const index_t rhs_cols, + const bool transpose_lhs, + const bool transpose_rhs, + const bool transpose_out, + const bool lhs_batched, + const bool rhs_batched, + Tensor *output); + + private: + void ComputeBlock(const float *packed_lhs_data, + const float *packed_rhs_data, + const index_t depth_padded, + float *packed_output_data); + + void PackLhs(const MatrixMap &lhs, + float *packed_lhs); + + void PackRhs(const MatrixMap &rhs, + float *packed_rhs); + + void UnpackOutput(const float *packed_output, + MatrixMap *output); + + template + void Unpack(const float *packed_output, + MatrixMap *output) { + const index_t rows = output->rows(); + const index_t cols = output->cols(); + for (index_t r = 0; r < rows; ++r) { + for (index_t c = 0; c < cols; ++c) { + *output->data(r, c) = packed_output[r * ColBlockSize + c]; + } + } + } + + template + void Pack(const MatrixMap &matrix, + MatrixMajor dst_major, + float *packed_matrix) { + const index_t rows = matrix.rows(); + const index_t cols = matrix.cols(); + index_t depth = cols; + if (dst_major == RowMajor) { + // rhs + depth = rows; + } + const index_t depth_padded = RoundUp(depth, static_cast(4)); + memset(packed_matrix, 0, sizeof(float) * WidthBlockSize * depth_padded); + if (dst_major == ColMajor) { + for (index_t c = 0; c < cols; ++c) { + for (index_t r = 0; r < rows; ++r) { + packed_matrix[c * WidthBlockSize + r] = matrix(r, c); + } + } + } else { + for (index_t r = 0; r < rows; ++r) { + for (index_t c = 0; c < cols; ++c) { + packed_matrix[r * WidthBlockSize + c] = matrix(r, c); + } + } + } + } + + ScratchBuffer tmp_scratch_buffer_; + Buffer pack_cache_; + + bool should_cache_pack_; + int cached_; +}; + +template<> +void Gemm::Pack<4, 4>(const MatrixMap &matrix, + MatrixMajor dst_major, + float *packed_matrix); + +template<> +void Gemm::Pack<8, 4>(const MatrixMap &matrix, + MatrixMajor dst_major, + float *packed_matrix); + +template<> +void Gemm::Unpack<4, 8>(const float *packed_output, MatrixMap *output); + +template<> +void Gemm::Unpack<8, 8>(const float *packed_output, MatrixMap *output); + +} // namespace fp32 +} // namespace arm +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_ARM_FP32_GEMM_H_ diff --git a/mace/ops/arm/fp32/gemm_test.cc b/mace/ops/arm/fp32/gemm_test.cc new file mode 100644 index 00000000..372b3eb6 --- /dev/null +++ b/mace/ops/arm/fp32/gemm_test.cc @@ -0,0 +1,115 @@ +// Copyright 2019 The MACE Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include + +#include "mace/core/tensor.h" +#include "mace/core/op_context.h" +#include "mace/ops/arm/fp32/gemm.h" +#include "mace/ops/ref/gemm.h" +#include "mace/ops/testing/test_utils.h" + +namespace mace { +namespace ops { +namespace test { + +void TestGemmFloat32(const index_t batch, + const index_t rows, + const index_t cols, + const index_t depth, + const MatrixMajor lhs_major, + const MatrixMajor rhs_major, + const MatrixMajor output_major, + const bool lhs_batched, + const bool rhs_batched) { + Tensor lhs(GetCPUAllocator(), DataType::DT_FLOAT); + Tensor rhs(GetCPUAllocator(), DataType::DT_FLOAT); + Tensor output(GetCPUAllocator(), DataType::DT_FLOAT); + lhs.Resize({lhs_batched ? batch : 1, rows, depth}); + rhs.Resize({rhs_batched ? batch : 1, depth, cols}); + output.Resize({batch, rows, cols}); + { + Tensor::MappingGuard lhs_guard(&lhs); + Tensor::MappingGuard rhs_guard(&rhs); + float *lhs_data = lhs.mutable_data(); + float *rhs_data = rhs.mutable_data(); + float *output_data = output.mutable_data(); + GenerateRandomRealTypeData(lhs.shape(), lhs_data); + GenerateRandomRealTypeData(rhs.shape(), rhs_data); + GenerateRandomRealTypeData(output.shape(), output_data); + } + ::mace::ops::arm::fp32::Gemm gemm; + gemm.Compute(nullptr, + &lhs, + &rhs, + batch, + rows, + cols, + depth, + lhs_major, + rhs_major, + output_major, + lhs_batched, + rhs_batched, + &output); + + Tensor expected_output(GetCPUAllocator(), DataType::DT_FLOAT); + expected_output.Resize({batch, rows, cols}); + ::mace::ops::ref::Gemm gemm_ref; + gemm_ref.Compute(nullptr, + &lhs, + &rhs, + batch, + rows, + cols, + depth, + lhs_major, + rhs_major, + output_major, + lhs_batched, + rhs_batched, + &expected_output); + + ExpectTensorNear(expected_output, output); +} + +TEST(ArmGemm, TestGemmFloat32) { + TestGemmFloat32(1, 47, 69, 37, RowMajor, RowMajor, RowMajor, true, true); + TestGemmFloat32(1, 47, 69, 37, RowMajor, RowMajor, ColMajor, true, true); + TestGemmFloat32(1, 47, 69, 37, RowMajor, ColMajor, RowMajor, true, true); + TestGemmFloat32(1, 47, 69, 37, RowMajor, ColMajor, ColMajor, true, true); + TestGemmFloat32(1, 47, 69, 37, ColMajor, RowMajor, RowMajor, true, true); + TestGemmFloat32(1, 47, 69, 37, ColMajor, RowMajor, ColMajor, true, true); + TestGemmFloat32(1, 47, 69, 37, ColMajor, ColMajor, RowMajor, true, true); + TestGemmFloat32(1, 47, 69, 37, ColMajor, ColMajor, ColMajor, true, true); + + TestGemmFloat32(3, 47, 69, 37, RowMajor, RowMajor, RowMajor, true, true); + TestGemmFloat32(3, 47, 69, 37, RowMajor, RowMajor, ColMajor, true, true); + TestGemmFloat32(3, 47, 69, 37, RowMajor, ColMajor, RowMajor, true, true); + TestGemmFloat32(3, 47, 69, 37, RowMajor, ColMajor, ColMajor, true, true); + TestGemmFloat32(3, 47, 69, 37, ColMajor, RowMajor, RowMajor, true, true); + TestGemmFloat32(3, 47, 69, 37, ColMajor, RowMajor, ColMajor, true, true); + TestGemmFloat32(3, 47, 69, 37, ColMajor, ColMajor, RowMajor, true, true); + TestGemmFloat32(3, 47, 69, 37, ColMajor, 
ColMajor, ColMajor, true, true); + + TestGemmFloat32(3, 47, 69, 37, RowMajor, RowMajor, RowMajor, true, false); + TestGemmFloat32(3, 47, 69, 37, RowMajor, RowMajor, RowMajor, false, true); + + TestGemmFloat32(16, 31, 61, 67, RowMajor, ColMajor, RowMajor, true, true); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/arm/fp32/gemv.cc b/mace/ops/arm/fp32/gemv.cc index 39b25bf5..a146de4c 100644 --- a/mace/ops/arm/fp32/gemv.cc +++ b/mace/ops/arm/fp32/gemv.cc @@ -22,6 +22,9 @@ #define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3]) #endif +// Disable unroll by default, since cache set conflict could be significant +// #define MACE_GEMV_UNROLL 1 + namespace mace { namespace ops { namespace arm { @@ -35,14 +38,26 @@ MaceStatus Gemv::Compute(const OpContext *context, const index_t lhs_height, const index_t lhs_width, const bool lhs_batched, + const bool rhs_batched, Tensor *output) { MACE_UNUSED(context); + MACE_CHECK(output->size() == batch * lhs_height, + "Need resize output tensor before call gemv."); + Tensor::MappingGuard lhs_guard(lhs); Tensor::MappingGuard rhs_guard(rhs); Tensor::MappingGuard bias_guard(bias); Tensor::MappingGuard output_guard(output); + const float *lhs_data = lhs->data(); + const float *rhs_data = rhs->data(); + const float *bias_data = nullptr; + if (bias) { + bias_data = bias->data(); + } + float *output_data = output->mutable_data(); + const index_t h_block_size = 4; const index_t h_block_count = RoundUpDiv(lhs_height, h_block_size); const index_t w_block_size = 8; @@ -52,28 +67,20 @@ MaceStatus Gemv::Compute(const OpContext *context, #pragma omp parallel for collapse(2) schedule(runtime) for (index_t b = 0; b < batch; ++b) { for (index_t h_block_idx = 0; h_block_idx < h_block_count; ++h_block_idx) { - // TODO(liyin): it can be put it outside the loop, - // but openmp limits param count - const float *lhs_data = lhs->data(); - const float *rhs_data = rhs->data(); - const float *bias_data = nullptr; - if (bias) { - bias_data = bias->data(); - } - float *output_data = output->mutable_data(); - + const index_t h_start = h_block_idx * h_block_size; const float *lhs_ptr = lhs_data + static_cast(lhs_batched) * b * lhs_height * lhs_width - + lhs_width * h_block_idx * h_block_size; - const float *rhs_ptr = rhs_data + b * lhs_width; + + lhs_width * h_start; + const float *rhs_ptr = + rhs_data + static_cast(rhs_batched) * b * lhs_width; float - *ret_ptr = output_data + b * lhs_height + h_block_idx * h_block_size; + *ret_ptr = output_data + b * lhs_height + h_start; const index_t h_block_len = - std::min(h_block_size, lhs_height - h_block_idx * h_block_size); - const index_t h_offset = h_block_idx * h_block_size; + std::min(h_block_size, lhs_height - h_start); +#ifdef MACE_GEMV_UNROLL if (h_block_len == 4) { float32x4_t vo0 = vdupq_n_f32(0); float32x4_t vo1 = vdupq_n_f32(0); @@ -149,6 +156,11 @@ MaceStatus Gemv::Compute(const OpContext *context, "vmla.f32 %q[vo2], q4, q8\n" "vmla.f32 %q[vo3], q6, q8\n" + "vld1.f32 {d0-d1}, [r1]!\n" + "vld1.f32 {d4-d5}, [r2]!\n" + "vld1.f32 {d8-d9}, [r3]!\n" + "vld1.f32 {d12-d13}, [r4]!\n" + "vld1.f32 {d16-d17}, [r0]!\n" "vmla.f32 %q[vo0], q1, q9\n" "vmla.f32 %q[vo1], q3, q9\n" @@ -157,13 +169,6 @@ MaceStatus Gemv::Compute(const OpContext *context, "subs %[r_w_block_count], #1\n" - - "vld1.f32 {d0-d1}, [r1]!\n" - "vld1.f32 {d4-d5}, [r2]!\n" - "vld1.f32 {d8-d9}, [r3]!\n" - "vld1.f32 {d12-d13}, [r4]!\n" - "vld1.f32 {d16-d17}, [r0]!\n" - "vld1.f32 {d2-d3}, [r1]!\n" "vld1.f32 {d6-d7}, [r2]!\n" "vld1.f32 
{d10-d11}, [r3]!\n" @@ -257,26 +262,30 @@ MaceStatus Gemv::Compute(const OpContext *context, vo = vaddq_f32(vo, vbias); vst1q_f32(ret_ptr, vo); } else { // h_block_len < 4 - // TODO(liyin): handle here case by case (1,2,3) to accelerate +#endif // MACE_GEMV_UNROLL const float *tmp_lhs_ptr = lhs_ptr; const float *tmp_rhs_ptr = rhs_ptr; for (index_t h = 0; h < h_block_len; ++h) { lhs_ptr = tmp_lhs_ptr + h * lhs_width; rhs_ptr = tmp_rhs_ptr; float32x4_t vo0 = vdupq_n_f32(0); + float32x4_t vo0n = vdupq_n_f32(0); for (index_t w = 0; w < w_block_count; ++w) { float32x4_t vr0 = vld1q_f32(rhs_ptr); float32x4_t vr0n = vld1q_f32(rhs_ptr + 4); float32x4_t vl0 = vld1q_f32(lhs_ptr); float32x4_t vl0n = vld1q_f32(lhs_ptr + 4); + + // may cause some precision error depending on the compute order vo0 = vmlaq_f32(vo0, vl0, vr0); - vo0 = vmlaq_f32(vo0, vl0n, vr0n); + vo0n = vmlaq_f32(vo0n, vl0n, vr0n); lhs_ptr += 8; rhs_ptr += 8; } // w - float s0 = vaddvq_f32(vo0) + (bias ? bias_data[h_offset + h] : 0); + vo0 = vaddq_f32(vo0, vo0n); + float s0 = vaddvq_f32(vo0) + (bias ? bias_data[h_start + h] : 0); for (index_t w = 0; w < w_remain; ++w) { s0 += lhs_ptr[0] * rhs_ptr[0]; ++lhs_ptr; @@ -285,7 +294,9 @@ MaceStatus Gemv::Compute(const OpContext *context, ret_ptr[h] = s0; } // h +#ifdef MACE_GEMV_UNROLL } // if +#endif // MACE_GEMV_UNROLL } // h_block_idx } // b diff --git a/mace/ops/arm/fp32/gemv.h b/mace/ops/arm/fp32/gemv.h index 5b1551fa..3210def1 100644 --- a/mace/ops/arm/fp32/gemv.h +++ b/mace/ops/arm/fp32/gemv.h @@ -38,6 +38,7 @@ class Gemv { const index_t lhs_height, const index_t lhs_width, const bool lhs_batched, + const bool rhs_batched, Tensor *output); }; diff --git a/mace/ops/arm/fp32/gemv_test.cc b/mace/ops/arm/fp32/gemv_test.cc index c2d13b38..b6b69254 100644 --- a/mace/ops/arm/fp32/gemv_test.cc +++ b/mace/ops/arm/fp32/gemv_test.cc @@ -28,13 +28,14 @@ namespace test { void TestGemvFloat32(const index_t batch, const index_t height, const index_t width, - const bool lhs_batched) { + const bool lhs_batched, + const bool rhs_batched) { Tensor lhs(GetCPUAllocator(), DataType::DT_FLOAT); Tensor rhs(GetCPUAllocator(), DataType::DT_FLOAT); Tensor bias(GetCPUAllocator(), DataType::DT_FLOAT); Tensor output(GetCPUAllocator(), DataType::DT_FLOAT); lhs.Resize({lhs_batched ? batch : 1, height, width}); - rhs.Resize({batch, width}); + rhs.Resize({rhs_batched ? 
batch : 1, width}); bias.Resize({height}); output.Resize({batch, height}); { @@ -57,6 +58,7 @@ void TestGemvFloat32(const index_t batch, height, width, lhs_batched, + rhs_batched, &output); Tensor expected_output(GetCPUAllocator(), DataType::DT_FLOAT); @@ -70,28 +72,22 @@ void TestGemvFloat32(const index_t batch, height, width, lhs_batched, + rhs_batched, &expected_output); - Tensor::MappingGuard output_guard(&output); - Tensor::MappingGuard expected_guard(&expected_output); - const float *output_data = output.data(); - const float *expected_data = expected_output.data(); - - for (index_t i = 0; i < output.size(); ++i) { - EXPECT_NEAR(expected_data[i], output_data[i], 0.001); - } + ExpectTensorNear(expected_output, output); } TEST(ArmGemv, TestGemvFloat32) { - TestGemvFloat32(1, 16, 4, true); - TestGemvFloat32(1, 16, 256, true); - TestGemvFloat32(2, 16, 256, true); - TestGemvFloat32(3, 63, 257, true); + TestGemvFloat32(1, 16, 4, true, true); + TestGemvFloat32(1, 16, 256, true, true); + TestGemvFloat32(2, 16, 256, true, true); + TestGemvFloat32(3, 63, 257, true, true); - TestGemvFloat32(1, 16, 4, false); - TestGemvFloat32(1, 16, 256, false); - TestGemvFloat32(2, 16, 256, false); - TestGemvFloat32(3, 63, 257, false); + TestGemvFloat32(2, 16, 256, false, true); + TestGemvFloat32(3, 63, 257, false, true); + TestGemvFloat32(2, 16, 256, true, false); + TestGemvFloat32(3, 63, 257, true, false); } } // namespace test diff --git a/mace/ops/arm/q8/gemv.cc b/mace/ops/arm/q8/gemv.cc index 7117dcac..790a1448 100644 --- a/mace/ops/arm/q8/gemv.cc +++ b/mace/ops/arm/q8/gemv.cc @@ -43,6 +43,7 @@ MaceStatus Gemv::Compute(const OpContext *context, const index_t lhs_height, const index_t lhs_width, const bool lhs_batched, + const bool rhs_batched, Tensor *output) { MACE_UNUSED(context); @@ -100,7 +101,8 @@ MaceStatus Gemv::Compute(const OpContext *context, *lhs_ptr = lhs_data + static_cast(lhs_batched) * b * lhs_height * lhs_width + lhs_width * h_block_idx * h_block_size; - const uint8_t *rhs_ptr = rhs_data + b * lhs_width; + const uint8_t *rhs_ptr = + rhs_data + static_cast(rhs_batched) * b * lhs_width; OUTPUT_TYPE *ret_ptr = output_data + b * lhs_height + h_block_idx * h_block_size; diff --git a/mace/ops/arm/q8/gemv.h b/mace/ops/arm/q8/gemv.h index 2269ec98..adcb9590 100644 --- a/mace/ops/arm/q8/gemv.h +++ b/mace/ops/arm/q8/gemv.h @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +// This implements matrix-vector multiplication described as +// https://github.com/google/gemmlowp/blob/master/todo/fast-gemv.txt + #ifndef MACE_OPS_ARM_Q8_GEMV_H_ #define MACE_OPS_ARM_Q8_GEMV_H_ @@ -39,6 +42,7 @@ class Gemv { const index_t lhs_height, const index_t lhs_width, const bool lhs_batched, + const bool rhs_batched, Tensor *output); }; diff --git a/mace/ops/arm/q8/gemv_test.cc b/mace/ops/arm/q8/gemv_test.cc index 10cab216..ced75f64 100644 --- a/mace/ops/arm/q8/gemv_test.cc +++ b/mace/ops/arm/q8/gemv_test.cc @@ -28,7 +28,8 @@ namespace test { void TestGemvInt32(const index_t batch, const index_t height, const index_t width, - const bool lhs_batched) { + const bool lhs_batched, + const bool rhs_batched) { Tensor lhs(GetCPUAllocator(), DataType::DT_UINT8); Tensor rhs(GetCPUAllocator(), DataType::DT_UINT8); Tensor bias(GetCPUAllocator(), DataType::DT_INT32); @@ -38,7 +39,7 @@ void TestGemvInt32(const index_t batch, lhs.SetZeroPoint(0); rhs.SetZeroPoint(0); lhs.Resize({lhs_batched ? 
batch : 1, height, width}); - rhs.Resize({batch, width}); + rhs.Resize({rhs_batched ? batch : 1, width}); bias.Resize({height}); output.Resize({batch, height}); { @@ -62,6 +63,7 @@ void TestGemvInt32(const index_t batch, height, width, lhs_batched, + rhs_batched, &output); Tensor expected_output(GetCPUAllocator(), DataType::DT_INT32); @@ -75,6 +77,7 @@ void TestGemvInt32(const index_t batch, height, width, lhs_batched, + rhs_batched, &expected_output); Tensor::MappingGuard output_guard(&output); @@ -90,7 +93,8 @@ void TestGemvInt32(const index_t batch, void TestGemvUint8(const index_t batch, const index_t height, const index_t width, - const bool lhs_batched) { + const bool lhs_batched, + const bool rhs_batched) { Tensor lhs(GetCPUAllocator(), DataType::DT_UINT8); Tensor rhs(GetCPUAllocator(), DataType::DT_UINT8); Tensor bias(GetCPUAllocator(), DataType::DT_INT32); @@ -127,6 +131,7 @@ void TestGemvUint8(const index_t batch, height, width, lhs_batched, + rhs_batched, &output); Tensor expected_output(GetCPUAllocator(), DataType::DT_INT32); @@ -142,6 +147,7 @@ void TestGemvUint8(const index_t batch, height, width, lhs_batched, + rhs_batched, &expected_output); Tensor::MappingGuard output_guard(&output); @@ -155,27 +161,27 @@ void TestGemvUint8(const index_t batch, } TEST(ArmGemv, TestGemvInt32) { - TestGemvInt32(1, 16, 4, true); - TestGemvInt32(1, 16, 256, true); - TestGemvInt32(2, 16, 256, true); - TestGemvInt32(3, 63, 257, true); - - TestGemvInt32(1, 16, 4, false); - TestGemvInt32(1, 16, 256, false); - TestGemvInt32(2, 16, 256, false); - TestGemvInt32(3, 63, 257, false); + TestGemvInt32(1, 16, 4, true, true); + TestGemvInt32(1, 16, 256, true, true); + TestGemvInt32(2, 16, 256, true, true); + TestGemvInt32(3, 63, 257, true, true); + + TestGemvInt32(2, 16, 256, false, true); + TestGemvInt32(3, 63, 257, false, true); + TestGemvInt32(2, 16, 256, true, false); + TestGemvInt32(3, 63, 257, true, false); } TEST(ArmGemv, TestGemvUint8) { - TestGemvUint8(1, 16, 4, true); - TestGemvUint8(1, 16, 256, true); - TestGemvUint8(2, 16, 256, true); - TestGemvUint8(3, 63, 257, true); - - TestGemvUint8(1, 16, 4, false); - TestGemvUint8(1, 16, 256, false); - TestGemvUint8(2, 16, 256, false); - TestGemvUint8(3, 63, 257, false); + TestGemvUint8(1, 16, 4, true, true); + TestGemvUint8(1, 16, 256, true, true); + TestGemvUint8(2, 16, 256, true, true); + TestGemvUint8(3, 63, 257, true, true); + + TestGemvUint8(2, 16, 256, false, true); + TestGemvUint8(3, 63, 257, false, true); + TestGemvUint8(2, 16, 256, true, false); + TestGemvUint8(3, 63, 257, true, false); } } // namespace test diff --git a/mace/ops/common/matrix.h b/mace/ops/common/matrix.h new file mode 100644 index 00000000..3abd2d3b --- /dev/null +++ b/mace/ops/common/matrix.h @@ -0,0 +1,107 @@ +// Copyright 2019 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
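
Note on the batched GEMV interface changed above: the lhs_batched/rhs_batched flags threaded through arm::fp32::Gemv, arm::q8::Gemv and their tests implement a simple broadcast rule. A non-batched operand is stored once (leading dimension 1) and re-read for every batch, which the kernels express as a zero-or-one multiplier on the per-batch pointer offset, e.g. static_cast(rhs_batched) * b * lhs_width. The sketch below is a standalone plain-C++ illustration of that indexing; the helper name BatchedGemvRef is not part of the patch.

#include <cassert>
#include <cstdint>
#include <vector>

// Reference semantics of the batching flags:
//   lhs: [lhs_batched ? batch : 1, h, w], rhs: [rhs_batched ? batch : 1, w]
//   out: [batch, h]
void BatchedGemvRef(const std::vector<float> &lhs, const std::vector<float> &rhs,
                    int64_t batch, int64_t h, int64_t w,
                    bool lhs_batched, bool rhs_batched, std::vector<float> *out) {
  assert(out->size() == static_cast<size_t>(batch * h));
  for (int64_t b = 0; b < batch; ++b) {
    // A broadcast operand gets a zero multiplier, so every batch reuses block 0.
    const float *lhs_ptr = lhs.data() + static_cast<int64_t>(lhs_batched) * b * h * w;
    const float *rhs_ptr = rhs.data() + static_cast<int64_t>(rhs_batched) * b * w;
    for (int64_t i = 0; i < h; ++i) {
      float s = 0.f;
      for (int64_t j = 0; j < w; ++j) {
        s += lhs_ptr[i * w + j] * rhs_ptr[j];
      }
      (*out)[b * h + i] = s;
    }
  }
}
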
+ + +#ifndef MACE_OPS_COMMON_MATRIX_H_ +#define MACE_OPS_COMMON_MATRIX_H_ + +namespace mace { +namespace ops { + +enum MatrixMajor { + RowMajor, + ColMajor +}; + +inline MatrixMajor TransposeMatrixMajor(const MatrixMajor src_major) { + return src_major == RowMajor ? ColMajor : RowMajor; +} + +template +class MatrixMap { + public: + MatrixMap() + : data_(nullptr), + matrix_major_(RowMajor), + rows_(0), + cols_(0), + stride_(0) {} + MatrixMap(T *data, + const MatrixMajor matrix_major, + const index_t rows, + const index_t cols) : + data_(data), + matrix_major_(matrix_major), + rows_(rows), + cols_(cols), + stride_(matrix_major == ColMajor ? rows : cols) {} + MatrixMap(T *data, + const MatrixMajor matrix_major, + const index_t rows, + const index_t cols, + const index_t stride) : + data_(data), + matrix_major_(matrix_major), + rows_(rows), + cols_(cols), + stride_(stride) {} + MatrixMap(const MatrixMap &other) + : data_(other.data_), + matrix_major_(other.matrix_major_), + rows_(other.rows_), + cols_(other.cols_), + stride_(other.stride_) {} + + MatrixMajor matrix_major() const { return matrix_major_; } + index_t rows() const { return rows_; } + index_t cols() const { return cols_; } + index_t stride() const { return stride_; } + int rows_stride() const { + return matrix_major_ == MatrixMajor::ColMajor ? 1 : stride_; + } + int cols_stride() const { + return matrix_major_ == MatrixMajor::RowMajor ? 1 : stride_; + } + index_t size() const { return rows_ * cols_; } + T *data() const { return data_; } + T *data(int rows, int cols) const { + return data_ + rows * rows_stride() + cols * cols_stride(); + } + T &operator()(int row, int col) const { return *data(row, col); } + MatrixMap block(int start_row, int start_col, int block_rows, + int block_cols) const { + MACE_CHECK(start_row >= 0); + MACE_CHECK(start_row + block_rows <= rows_); + MACE_CHECK(start_col >= 0); + MACE_CHECK(start_col + block_cols <= cols_); + + return MatrixMap(data(start_row, start_col), + matrix_major_, + block_rows, + block_cols, + stride_); + } + + private: + T *data_; + MatrixMajor matrix_major_; + index_t rows_; + index_t cols_; + index_t stride_; +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_COMMON_MATRIX_H_ diff --git a/mace/ops/conv_2d.cc b/mace/ops/conv_2d.cc index d3c06973..19794b38 100644 --- a/mace/ops/conv_2d.cc +++ b/mace/ops/conv_2d.cc @@ -33,6 +33,13 @@ #include "mace/ops/common/conv_pool_2d_util.h" #include "mace/utils/utils.h" +#ifdef MACE_ENABLE_NEON +#include "mace/ops/arm/fp32/conv_2d.h" +#include "mace/ops/arm/fp32/conv_2d_1x1.h" +#else +#include "mace/ops/ref/conv_2d.h" +#endif // MACE_ENABLE_NEON + #ifdef MACE_ENABLE_QUANTIZE #include "mace/ops/gemmlowp_util.h" #include "mace/ops/quantization_util.h" @@ -61,7 +68,8 @@ class Conv2dOp : public ConvPool2dOpBase { relux_max_limit_(Operation::GetOptionalArg("max_limit", 0.0f)), leakyrelu_coefficient_(Operation::GetOptionalArg( "leakyrelu_coefficient", 0.0f)), - is_filter_transformed_(false) {} + is_filter_transformed_(false), + conv2d_delegator_(nullptr) {} MaceStatus Run(OpContext *context) override { const Tensor *input = this->Input(INPUT); @@ -69,9 +77,17 @@ class Conv2dOp : public ConvPool2dOpBase { const Tensor *bias = this->InputSize() >= 3 ? 
this->Input(BIAS) : nullptr; Tensor *output = this->Output(OUTPUT); + index_t input_batch = input->dim(0); + index_t input_channels = input->dim(1); std::vector filter_shape(4); filter_shape = filter->shape(); + index_t stride_h = strides_[0]; + index_t stride_w = strides_[1]; + + index_t dilation_h = dilations_[0]; + index_t dilation_w = dilations_[1]; + std::vector output_shape(4); std::vector paddings(2); if (paddings_.empty()) { @@ -99,403 +115,413 @@ class Conv2dOp : public ConvPool2dOpBase { index_t height = output->dim(2); index_t width = output->dim(3); - index_t input_batch = input->dim(0); - index_t input_channels = input->dim(1); - index_t input_height = input->dim(2); - index_t input_width = input->dim(3); - - index_t filter_h = filter_shape[2]; - index_t filter_w = filter_shape[3]; + MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch"); MACE_CHECK(filter_shape[0] == channels, filter_shape[0], " != ", channels); MACE_CHECK(filter_shape[1] == input_channels, filter_shape[1], " != ", input_channels); - index_t stride_h = strides_[0]; - index_t stride_w = strides_[1]; - - index_t dilation_h = dilations_[0]; - index_t dilation_w = dilations_[1]; - - MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch"); - - index_t padded_input_height = input_height + paddings[0]; - index_t padded_input_width = input_width + paddings[1]; - index_t extra_input_height = padded_input_height; - index_t extra_input_width = padded_input_width; - index_t extra_output_height = height; - index_t extra_output_width = width; - - int pad_top = paddings[0] >> 1; - int pad_bottom = paddings[0] - pad_top; - int pad_left = paddings[1] >> 1; - int pad_right = paddings[1] - pad_left; - - Tensor::MappingGuard input_guard(input); - Tensor::MappingGuard filter_guard(filter); - Tensor::MappingGuard bias_guard(bias); - Tensor::MappingGuard output_guard(output); - - auto filter_data = filter->data(); - auto bias_data = bias == nullptr ? 
nullptr : bias->data(); - auto output_data = output->mutable_data(); - - std::function conv_func; - - bool - use_winograd = filter_h == 3 && filter_w == 3 - && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1 - && input_channels >= 8 && channels >= 8; - bool use_neon_3x3_s1 = filter_h == 3 && filter_w == 3 - && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; - bool use_neon_3x3_s2 = filter_h == 3 && filter_w == 3 - && stride_h == 2 && stride_w == 2 && dilation_h == 1 && dilation_w == 1; - bool use_neon_1x1_s1 = filter_h == 1 && filter_w == 1 - && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; - bool use_neon_5x5_s1 = filter_h == 5 && filter_w == 5 - && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; - bool use_neon_1x7_s1 = filter_h == 1 && filter_w == 7 - && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; - bool use_neon_7x1_s1 = filter_h == 7 && filter_w == 1 - && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; - bool use_neon_7x7_s1 = filter_h == 7 && filter_w == 7 - && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; - bool use_neon_7x7_s2 = filter_h == 7 && filter_w == 7 - && stride_h == 2 && stride_w == 2 && dilation_h == 1 && dilation_w == 1; - bool use_neon_7x7_s3 = filter_h == 7 && filter_w == 7 - && stride_h == 3 && stride_w == 3 && dilation_h == 1 && dilation_w == 1; - bool use_neon_1x15_s1 = filter_h == 1 && filter_w == 15 - && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; - bool use_neon_15x1_s1 = filter_h == 15 && filter_w == 1 - && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; - - std::vector transformed_input_shape; - std::vector transformed_output_shape; - std::vector transformed_filter_shape; - - // When size of input feature map is bigger than 16x16, - // set winograd out tile size to 6 to get higher performance. - index_t winograd_out_tile_size = 2; - if (input_height > 16 && input_width > 16) { - winograd_out_tile_size = 6; - } +#ifdef MACE_ENABLE_NEON + index_t input_height = input->dim(2); + index_t input_width = input->dim(3); + index_t filter_h = filter->dim(2); + index_t filter_w = filter->dim(3); - if (use_winograd) { - extra_output_height = RoundUp(height, winograd_out_tile_size); - extra_input_height = - std::max(padded_input_height, extra_output_height + 2); - extra_output_width = RoundUp(width, winograd_out_tile_size); - extra_input_width = std::max(padded_input_width, extra_output_width + 2); - if (extra_input_height != padded_input_height) { - pad_bottom += (extra_input_height - padded_input_height); + if (filter_h == 1 && filter_w == 1 && stride_h == 1 && stride_w == 1 + && dilation_h == 1 && dilation_w == 1) { + if (conv2d_delegator_.get() == nullptr) { + conv2d_delegator_.reset(new arm::fp32::Conv2dK1x1()); } - if (extra_input_width != padded_input_width) { - pad_right += (extra_input_width - padded_input_width); + conv2d_delegator_->Compute(context, input, filter, output); + } else { + // TODO(liyin): the code below needs to be refactored. 
+ // delegate to each of kernels instead of ruling them all + index_t padded_input_height = input_height + paddings[0]; + index_t padded_input_width = input_width + paddings[1]; + index_t extra_input_height = padded_input_height; + index_t extra_input_width = padded_input_width; + index_t extra_output_height = height; + index_t extra_output_width = width; + + int pad_top = paddings[0] >> 1; + int pad_bottom = paddings[0] - pad_top; + int pad_left = paddings[1] >> 1; + int pad_right = paddings[1] - pad_left; + + Tensor::MappingGuard input_guard(input); + Tensor::MappingGuard filter_guard(filter); + Tensor::MappingGuard output_guard(output); + + auto filter_data = filter->data(); + auto output_data = output->mutable_data(); + + std::function conv_func; + + bool + use_winograd = filter_h == 3 && filter_w == 3 + && stride_h == 1 && stride_w == 1 && dilation_h == 1 + && dilation_w == 1 + && input_channels >= 8 && channels >= 8; + bool use_neon_3x3_s1 = filter_h == 3 && filter_w == 3 + && stride_h == 1 && stride_w == 1 && dilation_h == 1 + && dilation_w == 1; + bool use_neon_3x3_s2 = filter_h == 3 && filter_w == 3 + && stride_h == 2 && stride_w == 2 && dilation_h == 1 + && dilation_w == 1; + bool use_neon_5x5_s1 = filter_h == 5 && filter_w == 5 + && stride_h == 1 && stride_w == 1 && dilation_h == 1 + && dilation_w == 1; + bool use_neon_1x7_s1 = filter_h == 1 && filter_w == 7 + && stride_h == 1 && stride_w == 1 && dilation_h == 1 + && dilation_w == 1; + bool use_neon_7x1_s1 = filter_h == 7 && filter_w == 1 + && stride_h == 1 && stride_w == 1 && dilation_h == 1 + && dilation_w == 1; + bool use_neon_7x7_s1 = filter_h == 7 && filter_w == 7 + && stride_h == 1 && stride_w == 1 && dilation_h == 1 + && dilation_w == 1; + bool use_neon_7x7_s2 = filter_h == 7 && filter_w == 7 + && stride_h == 2 && stride_w == 2 && dilation_h == 1 + && dilation_w == 1; + bool use_neon_7x7_s3 = filter_h == 7 && filter_w == 7 + && stride_h == 3 && stride_w == 3 && dilation_h == 1 + && dilation_w == 1; + bool use_neon_1x15_s1 = filter_h == 1 && filter_w == 15 + && stride_h == 1 && stride_w == 1 && dilation_h == 1 + && dilation_w == 1; + bool use_neon_15x1_s1 = filter_h == 15 && filter_w == 1 + && stride_h == 1 && stride_w == 1 && dilation_h == 1 + && dilation_w == 1; + + std::vector transformed_input_shape; + std::vector transformed_output_shape; + std::vector transformed_filter_shape; + + // When size of input feature map is bigger than 16x16, + // set winograd out tile size to 6 to get higher performance. 
+ index_t winograd_out_tile_size = 2; + if (input_height > 16 && input_width > 16) { + winograd_out_tile_size = 6; } - index_t tile_height_count = extra_output_height / winograd_out_tile_size; - index_t tile_width_count = extra_output_width / winograd_out_tile_size; - index_t tile_count = tile_height_count * tile_width_count; - index_t in_tile_area = - (winograd_out_tile_size + 2) * (winograd_out_tile_size + 2); - - transformed_input_shape.insert(transformed_input_shape.end(), - {in_tile_area, batch, input_channels, - tile_count}); - transformed_output_shape.insert(transformed_output_shape.end(), - {in_tile_area, batch, channels, - tile_count}); - transformed_filter_shape.insert(transformed_filter_shape.end(), - {in_tile_area, channels, input_channels}); - } else { - index_t tile_h, tile_w; - if (use_neon_1x1_s1) { - tile_h = 1; - tile_w = 1; - } else if (use_neon_3x3_s1) { - tile_h = 2; - tile_w = 4; - } else if (use_neon_7x1_s1 || use_neon_15x1_s1) { - tile_h = 4; - tile_w = 1; + if (use_winograd) { + extra_output_height = RoundUp(height, winograd_out_tile_size); + extra_input_height = + std::max(padded_input_height, extra_output_height + 2); + extra_output_width = RoundUp(width, winograd_out_tile_size); + extra_input_width = + std::max(padded_input_width, extra_output_width + 2); + if (extra_input_height != padded_input_height) { + pad_bottom += (extra_input_height - padded_input_height); + } + if (extra_input_width != padded_input_width) { + pad_right += (extra_input_width - padded_input_width); + } + + index_t + tile_height_count = extra_output_height / winograd_out_tile_size; + index_t tile_width_count = extra_output_width / winograd_out_tile_size; + index_t tile_count = tile_height_count * tile_width_count; + index_t in_tile_area = + (winograd_out_tile_size + 2) * (winograd_out_tile_size + 2); + + transformed_input_shape.insert(transformed_input_shape.end(), + {in_tile_area, batch, input_channels, + tile_count}); + transformed_output_shape.insert(transformed_output_shape.end(), + {in_tile_area, batch, channels, + tile_count}); + transformed_filter_shape.insert(transformed_filter_shape.end(), + {in_tile_area, channels, + input_channels}); } else { - tile_h = 1; - tile_w = 4; + index_t tile_h, tile_w; + if (use_neon_3x3_s1) { + tile_h = 2; + tile_w = 4; + } else if (use_neon_7x1_s1 || use_neon_15x1_s1) { + tile_h = 4; + tile_w = 1; + } else { + tile_h = 1; + tile_w = 4; + } + extra_output_height = RoundUp(height, tile_h); + extra_input_height = + std::max(padded_input_height, (extra_output_height - 1) * stride_h + + (filter_h - 1) * dilation_h + 1); + extra_output_width = RoundUp(width, tile_w); + extra_input_width = + std::max(padded_input_width, (extra_output_width - 1) * stride_w + + (filter_w - 1) * dilation_w + 1); + if (extra_input_height != padded_input_height) { + pad_bottom += (extra_input_height - padded_input_height); + } + if (extra_input_width != padded_input_width) { + pad_right += (extra_input_width - padded_input_width); + } } - extra_output_height = RoundUp(height, tile_h); - extra_input_height = - std::max(padded_input_height, (extra_output_height - 1) * stride_h - + (filter_h - 1) * dilation_h + 1); - extra_output_width = RoundUp(width, tile_w); - extra_input_width = - std::max(padded_input_width, (extra_output_width - 1) * stride_w - + (filter_w - 1) * dilation_w + 1); - if (extra_input_height != padded_input_height) { - pad_bottom += (extra_input_height - padded_input_height); + + // decide scratch size before allocate it + index_t total_scratch_size = 0; + 
index_t transformed_input_size = 0; + index_t transformed_output_size = 0; + index_t padded_input_size = 0; + index_t padded_output_size = 0; + if (use_winograd) { + transformed_input_size = + std::accumulate(transformed_input_shape.begin(), + transformed_input_shape.end(), + 1, + std::multiplies()) * sizeof(float); + transformed_output_size = + std::accumulate(transformed_output_shape.begin(), + transformed_output_shape.end(), + 1, + std::multiplies()) * sizeof(float); + total_scratch_size += transformed_input_size + transformed_output_size; } - if (extra_input_width != padded_input_width) { - pad_right += (extra_input_width - padded_input_width); + if (extra_input_height != input_height + || extra_input_width != input_width) { + padded_input_size = + batch * input_channels * (input_height + pad_top + pad_bottom) + * (input_width + pad_left + pad_right) * sizeof(float) + + MACE_EXTRA_BUFFER_PAD_SIZE; + total_scratch_size += padded_input_size; + } + if (extra_output_height != height || extra_output_width != width) { + padded_output_size = + batch * channels * extra_output_height * extra_output_width + * sizeof(float); + total_scratch_size += padded_output_size; } - } - // decide scratch size before allocate it - index_t total_scratch_size = 0; - index_t transformed_input_size = 0; - index_t transformed_output_size = 0; - index_t padded_input_size = 0; - index_t padded_output_size = 0; - if (use_winograd) { - transformed_input_size = - std::accumulate(transformed_input_shape.begin(), - transformed_input_shape.end(), - 1, - std::multiplies()) * sizeof(float); - transformed_output_size = - std::accumulate(transformed_output_shape.begin(), - transformed_output_shape.end(), - 1, - std::multiplies()) * sizeof(float); - total_scratch_size += transformed_input_size + transformed_output_size; - } - if (extra_input_height != input_height - || extra_input_width != input_width) { - padded_input_size = - batch * input_channels * (input_height + pad_top + pad_bottom) - * (input_width + pad_left + pad_right) * sizeof(float) + - MACE_EXTRA_BUFFER_PAD_SIZE; - total_scratch_size += padded_input_size; - } - if (extra_output_height != height || extra_output_width != width) { - padded_output_size = - batch * channels * extra_output_height * extra_output_width - * sizeof(float); - total_scratch_size += padded_output_size; - } - // scratch for sgemm, preoccupy enough buffer - if (use_neon_1x1_s1) { - total_scratch_size += (input_batch * input_height * input_width - * (input_channels + channels)) - * sizeof(float); - } else if (use_winograd) { - total_scratch_size += transformed_input_size + transformed_output_size; - } + if (use_winograd) { + total_scratch_size += transformed_input_size + transformed_output_size; + } - // Init scratch buffer - ScratchBuffer *scratch = context->device()->scratch_buffer(); - scratch->Rewind(); - scratch->GrowSize(total_scratch_size); - Tensor - transformed_input(scratch->Scratch(transformed_input_size), DT_FLOAT); - Tensor - transformed_output(scratch->Scratch(transformed_output_size), DT_FLOAT); - Tensor padded_input(scratch->Scratch(padded_input_size), DT_FLOAT); - Tensor padded_output(scratch->Scratch(padded_output_size), DT_FLOAT); - const index_t extra_input_shape[4] = - {batch, input_channels, extra_input_height, extra_input_width}; - const index_t extra_output_shape[4] = - {batch, channels, extra_output_height, extra_output_width}; - - // make host compiler happy - MACE_UNUSED(extra_input_shape); - MACE_UNUSED(extra_output_shape); - - Tensor transformed_filter; - - // 
decide which convolution function to call - if (use_winograd) { - transformed_input.Reshape(transformed_input_shape); - transformed_output.Reshape(transformed_output_shape); - const float *transformed_filter_data = nullptr; - // filter only needs to be transformed once, set transformed_filter_data - // to null after the first run. - if (!is_filter_transformed_) { - transformed_filter.Resize(transformed_filter_shape); - switch (winograd_out_tile_size) { - case 2: - TransformFilter4x4(filter_data, - filter_shape[1], - filter_shape[0], - transformed_filter.mutable_data()); - break; - case 6: - TransformFilter8x8(filter_data, - filter_shape[1], - filter_shape[0], - transformed_filter.mutable_data()); - break; - default:MACE_NOT_IMPLEMENTED; + // Init scratch buffer + ScratchBuffer *scratch = context->device()->scratch_buffer(); + scratch->Rewind(); + scratch->GrowSize(total_scratch_size); + Tensor + transformed_input(scratch->Scratch(transformed_input_size), DT_FLOAT); + Tensor + transformed_output + (scratch->Scratch(transformed_output_size), DT_FLOAT); + Tensor padded_input(scratch->Scratch(padded_input_size), DT_FLOAT); + Tensor padded_output(scratch->Scratch(padded_output_size), DT_FLOAT); + const index_t extra_input_shape[4] = + {batch, input_channels, extra_input_height, extra_input_width}; + const index_t extra_output_shape[4] = + {batch, channels, extra_output_height, extra_output_width}; + + // make host compiler happy + MACE_UNUSED(extra_input_shape); + MACE_UNUSED(extra_output_shape); + + Tensor transformed_filter; + + // decide which convolution function to call + if (use_winograd) { + transformed_input.Reshape(transformed_input_shape); + transformed_output.Reshape(transformed_output_shape); + const float *transformed_filter_data = nullptr; + // filter only needs to be transformed once, set transformed_filter_data + // to null after the first run. 
+ if (!is_filter_transformed_) { + transformed_filter.Resize(transformed_filter_shape); + switch (winograd_out_tile_size) { + case 2: + TransformFilter4x4(filter_data, + filter_shape[1], + filter_shape[0], + transformed_filter.mutable_data()); + break; + case 6: + TransformFilter8x8(filter_data, + filter_shape[1], + filter_shape[0], + transformed_filter.mutable_data()); + break; + default:MACE_NOT_IMPLEMENTED; + } + transformed_filter_data = transformed_filter.data(); + is_filter_transformed_ = true; } - transformed_filter_data = transformed_filter.data(); - is_filter_transformed_ = true; - } - float *transformed_input_data = transformed_input.mutable_data(); - float *transformed_output_data = transformed_output.mutable_data(); - - conv_func = [=](const float *pad_input, float *pad_output) { - WinoGradConv3x3s1(pad_input, - transformed_filter_data, - batch, - extra_input_height, - extra_input_width, - input_channels, - channels, - winograd_out_tile_size, - transformed_input_data, - transformed_output_data, - pad_output, - &sgemm_, - scratch); - }; - } else if (use_neon_1x1_s1) { - conv_func = [=](const float *pad_input, float *pad_output) { - Conv2dNeonK1x1S1(pad_input, - filter_data, - batch, - extra_input_height, - extra_input_width, - input_channels, - channels, - pad_output, - &sgemm_, - scratch); - }; - } else if (use_neon_3x3_s1) { - conv_func = [=](const float *pad_input, float *pad_output) { - Conv2dNeonK3x3S1(pad_input, - filter_data, - extra_input_shape, - extra_output_shape, - pad_output); - }; - } else if (use_neon_3x3_s2) { - conv_func = [=](const float *pad_input, float *pad_output) { - Conv2dNeonK3x3S2(pad_input, - filter_data, - extra_input_shape, - extra_output_shape, - pad_output); - }; - } else if (use_neon_5x5_s1) { - conv_func = [=](const float *pad_input, float *pad_output) { - Conv2dNeonK5x5S1(pad_input, - filter_data, - extra_input_shape, - extra_output_shape, - pad_output); - }; - } else if (use_neon_1x7_s1) { - conv_func = [=](const float *pad_input, float *pad_output) { - Conv2dNeonK1x7S1(pad_input, - filter_data, - extra_input_shape, - extra_output_shape, - pad_output); - }; - } else if (use_neon_7x1_s1) { - conv_func = [=](const float *pad_input, float *pad_output) { - Conv2dNeonK7x1S1(pad_input, - filter_data, - extra_input_shape, - extra_output_shape, - pad_output); - }; - } else if (use_neon_7x7_s1) { - conv_func = [=](const float *pad_input, float *pad_output) { - Conv2dNeonK7x7S1(pad_input, - filter_data, - extra_input_shape, - extra_output_shape, - pad_output); - }; - } else if (use_neon_7x7_s2) { - conv_func = [=](const float *pad_input, float *pad_output) { - Conv2dNeonK7x7S2(pad_input, - filter_data, - extra_input_shape, - extra_output_shape, - pad_output); - }; - } else if (use_neon_7x7_s3) { - conv_func = [=](const float *pad_input, float *pad_output) { - Conv2dNeonK7x7S3(pad_input, - filter_data, - extra_input_shape, - extra_output_shape, - pad_output); - }; - } else if (use_neon_1x15_s1) { - conv_func = [=](const float *pad_input, float *pad_output) { - Conv2dNeonK1x15S1(pad_input, - filter_data, - extra_input_shape, - extra_output_shape, - pad_output); - }; - } else if (use_neon_15x1_s1) { - conv_func = [=](const float *pad_input, float *pad_output) { - Conv2dNeonK15x1S1(pad_input, - filter_data, - extra_input_shape, - extra_output_shape, - pad_output); - }; - } else { - conv_func = [=](const float *pad_input, float *pad_output) { - Conv2dGeneral(pad_input, - filter_data, - extra_input_shape, - extra_output_shape, - filter_shape.data(), - 
strides_.data(), - dilations_.data(), - pad_output); - }; - } + float *transformed_input_data = transformed_input.mutable_data(); + float + *transformed_output_data = transformed_output.mutable_data(); + + conv_func = [=](const float *pad_input, float *pad_output) { + WinoGradConv3x3s1(pad_input, + transformed_filter_data, + batch, + extra_input_height, + extra_input_width, + input_channels, + channels, + winograd_out_tile_size, + transformed_input_data, + transformed_output_data, + pad_output, + &sgemm_, + scratch); + }; + } else if (use_neon_3x3_s1) { + conv_func = [=](const float *pad_input, float *pad_output) { + Conv2dNeonK3x3S1(pad_input, + filter_data, + extra_input_shape, + extra_output_shape, + pad_output); + }; + } else if (use_neon_3x3_s2) { + conv_func = [=](const float *pad_input, float *pad_output) { + Conv2dNeonK3x3S2(pad_input, + filter_data, + extra_input_shape, + extra_output_shape, + pad_output); + }; + } else if (use_neon_5x5_s1) { + conv_func = [=](const float *pad_input, float *pad_output) { + Conv2dNeonK5x5S1(pad_input, + filter_data, + extra_input_shape, + extra_output_shape, + pad_output); + }; + } else if (use_neon_1x7_s1) { + conv_func = [=](const float *pad_input, float *pad_output) { + Conv2dNeonK1x7S1(pad_input, + filter_data, + extra_input_shape, + extra_output_shape, + pad_output); + }; + } else if (use_neon_7x1_s1) { + conv_func = [=](const float *pad_input, float *pad_output) { + Conv2dNeonK7x1S1(pad_input, + filter_data, + extra_input_shape, + extra_output_shape, + pad_output); + }; + } else if (use_neon_7x7_s1) { + conv_func = [=](const float *pad_input, float *pad_output) { + Conv2dNeonK7x7S1(pad_input, + filter_data, + extra_input_shape, + extra_output_shape, + pad_output); + }; + } else if (use_neon_7x7_s2) { + conv_func = [=](const float *pad_input, float *pad_output) { + Conv2dNeonK7x7S2(pad_input, + filter_data, + extra_input_shape, + extra_output_shape, + pad_output); + }; + } else if (use_neon_7x7_s3) { + conv_func = [=](const float *pad_input, float *pad_output) { + Conv2dNeonK7x7S3(pad_input, + filter_data, + extra_input_shape, + extra_output_shape, + pad_output); + }; + } else if (use_neon_1x15_s1) { + conv_func = [=](const float *pad_input, float *pad_output) { + Conv2dNeonK1x15S1(pad_input, + filter_data, + extra_input_shape, + extra_output_shape, + pad_output); + }; + } else if (use_neon_15x1_s1) { + conv_func = [=](const float *pad_input, float *pad_output) { + Conv2dNeonK15x1S1(pad_input, + filter_data, + extra_input_shape, + extra_output_shape, + pad_output); + }; + } else { + conv_func = [=](const float *pad_input, float *pad_output) { + Conv2dGeneral(pad_input, + filter_data, + extra_input_shape, + extra_output_shape, + filter_shape.data(), + strides_.data(), + dilations_.data(), + pad_output); + }; + } - // pad input and output - const Tensor *pad_input_ptr = input; - if (extra_input_height != input_height - || extra_input_width != input_width) { - MACE_RETURN_IF_ERROR(ConstructNCHWInputWithSpecificPadding( - input, pad_top, pad_bottom, pad_left, pad_right, &padded_input)); - pad_input_ptr = &padded_input; - } + // pad input and output + const Tensor *pad_input_ptr = input; + if (extra_input_height != input_height + || extra_input_width != input_width) { + MACE_RETURN_IF_ERROR(ConstructNCHWInputWithSpecificPadding( + input, pad_top, pad_bottom, pad_left, pad_right, &padded_input)); + pad_input_ptr = &padded_input; + } - // TODO(libin): don't need clear after bias is integrated in each conv - Tensor *pad_output_ptr = output; - if 
(extra_output_height != height || extra_output_width != width) { - padded_output.Reshape({batch, channels, extra_output_height, - extra_output_width}); - padded_output.Clear(); - pad_output_ptr = &padded_output; - } else if (!use_neon_1x1_s1) { - output->Clear(); - } + // TODO(libin): don't need clear after bias is integrated in each conv + Tensor *pad_output_ptr = output; + if (extra_output_height != height || extra_output_width != width) { + padded_output.Reshape({batch, channels, extra_output_height, + extra_output_width}); + padded_output.Clear(); + pad_output_ptr = &padded_output; + } else { + output->Clear(); + } - const float *pad_input_data = pad_input_ptr->data(); - float *pad_output_data = pad_output_ptr->mutable_data(); + const float *pad_input_data = pad_input_ptr->data(); + float *pad_output_data = pad_output_ptr->mutable_data(); - conv_func(pad_input_data, pad_output_data); + conv_func(pad_input_data, pad_output_data); - // unpack output - if (extra_output_height != height || extra_output_width != width) { + // unpack output + if (extra_output_height != height || extra_output_width != width) { #pragma omp parallel for collapse(2) schedule(runtime) - for (index_t b = 0; b < batch; ++b) { - for (index_t c = 0; c < channels; ++c) { - for (index_t h = 0; h < height; ++h) { - memcpy( - output_data + b * channels * height * width + c * height * width - + h * width, - pad_output_data - + b * channels * extra_output_height * extra_output_width - + c * extra_output_height * extra_output_width - + h * extra_output_width, - sizeof(float) * width); + for (index_t b = 0; b < batch; ++b) { + for (index_t c = 0; c < channels; ++c) { + for (index_t h = 0; h < height; ++h) { + memcpy( + output_data + b * channels * height * width + + c * height * width + + h * width, + pad_output_data + + b * channels * extra_output_height * extra_output_width + + c * extra_output_height * extra_output_width + + h * extra_output_width, + sizeof(float) * width); + } } } } } +#else + if (conv2d_delegator_.get() == nullptr) { + conv2d_delegator_.reset(new ref::Conv2d(paddings[0], + paddings[1], + stride_h, + stride_w, + dilation_h, + dilation_w)); + } + conv2d_delegator_->Compute(context, input, filter, output); +#endif + Tensor::MappingGuard bias_guard(bias); + Tensor::MappingGuard output_guard(output); + auto bias_data = bias == nullptr ? 
nullptr : bias->data(); + auto output_data = output->mutable_data(); if (bias_data != nullptr) { const index_t image_size = height * width; #pragma omp parallel for collapse(2) schedule(runtime) @@ -702,13 +728,16 @@ class Conv2dOp : public ConvPool2dOpBase { } // m } // b } - - private: const ActivationType activation_; const float relux_max_limit_; const float leakyrelu_coefficient_; bool is_filter_transformed_; SGemm sgemm_; +#ifdef MACE_ENABLE_NEON + std::unique_ptr conv2d_delegator_; +#else + std::unique_ptr> conv2d_delegator_; +#endif // MACE_ENABLE_NEON private: MACE_OP_INPUT_TAGS(INPUT, FILTER, BIAS); diff --git a/mace/ops/fully_connected.cc b/mace/ops/fully_connected.cc index 185d6278..c82aa8ff 100644 --- a/mace/ops/fully_connected.cc +++ b/mace/ops/fully_connected.cc @@ -102,6 +102,7 @@ class FullyConnectedOp : public FullyConnectedOpBase { output_size, input_size, false, + true, output); Tensor::MappingGuard guard_output(output); float *output_ptr = output->mutable_data(); @@ -162,6 +163,7 @@ class FullyConnectedOp output_size, input_size, false, + true, output); return MaceStatus::MACE_SUCCESS; } diff --git a/mace/ops/matmul.cc b/mace/ops/matmul.cc index 7ae79569..a3aebcb4 100644 --- a/mace/ops/matmul.cc +++ b/mace/ops/matmul.cc @@ -25,7 +25,7 @@ #include "mace/utils/utils.h" #ifdef MACE_ENABLE_NEON - +#include "mace/ops/arm/fp32/gemm.h" #include "mace/ops/arm/fp32/gemv.h" #ifdef MACE_ENABLE_QUANTIZE @@ -33,6 +33,7 @@ #endif // MACE_ENABLE_QUANTIZE #else +#include "mace/ops/ref/gemm.h" #include "mace/ops/ref/gemv.h" #endif // MACE_ENABLE_NEON @@ -58,35 +59,45 @@ class MatMulOpBase : public Operation { inline void Validate() { const Tensor *A = this->Input(INPUT_A); const Tensor *B = this->Input(INPUT_B); - MACE_CHECK(A->dim_size() == B->dim_size() && A->dim_size() >= 2, - "rank(A) should be equal to rank(B), rank should be greater " - "than or equal to 2"); - index_t rank = A->dim_size(); - for (index_t i = 0; i < rank - 2; ++i) { - MACE_CHECK(A->dim(i) == B->dim(i), - "batch dimensions are not equal: ", - A->dim(i), - " vs. ", - B->dim(i)); + const index_t lhs_rank = A->dim_size(); + const index_t rhs_rank = B->dim_size(); + + MACE_CHECK(lhs_rank >= 2 && rhs_rank >= 2, + "rank should be greater than or equal to 2"); + if (lhs_rank == rhs_rank) { + for (index_t i = 0; i < A->dim_size() - 2; ++i) { + MACE_CHECK(A->dim(i) == B->dim(i), + "batch dimensions are not equal: ", + A->dim(i), + " vs. ", + B->dim(i)); + } + } else { + MACE_CHECK(lhs_rank == 2 || rhs_rank == 2, + "Either lhs or rhs matrix should has rank 2 " + "for non-batched matrix multiplication"); } - index_t ak = transpose_a_ ? A->dim(rank - 2) : A->dim(rank - 1); - index_t bk = transpose_b_ ? B->dim(rank - 1) : B->dim(rank - 2); - MACE_CHECK(ak == bk, "the number of A's column ", ak, - " must be equal to B's row ", bk); + + index_t + lhs_depth = transpose_a_ ? A->dim(lhs_rank - 2) : A->dim(lhs_rank - 1); + index_t + rhs_depth = transpose_b_ ? 
B->dim(rhs_rank - 1) : B->dim(rhs_rank - 2); + MACE_CHECK(lhs_depth == rhs_depth, "the number of A's column ", lhs_depth, + " must be equal to B's row ", rhs_depth); } protected: - MACE_OP_INPUT_TAGS(INPUT_A, INPUT_B); + MACE_OP_INPUT_TAGS(INPUT_A, INPUT_B, BIAS); MACE_OP_OUTPUT_TAGS(OUTPUT); bool transpose_a_; bool transpose_b_; }; -template +template class MatMulOp; -template <> +template<> class MatMulOp : public MatMulOpBase { public: explicit MatMulOp(OpConstructContext *context) @@ -94,72 +105,116 @@ class MatMulOp : public MatMulOpBase { MaceStatus Run(OpContext *context) override { Validate(); - const Tensor *A = this->Input(INPUT_A); - const Tensor *B = this->Input(INPUT_B); + const Tensor *lhs = this->Input(INPUT_A); + const Tensor *rhs = this->Input(INPUT_B); + const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr; Tensor *C = this->Output(OUTPUT); - index_t batch; - index_t height; - index_t K; - index_t width; - - index_t rank = A->dim_size(); - height = A->dim(rank - 2); - K = A->dim(rank - 1); - if (transpose_a_) { - std::swap(height, K); - } - if (transpose_b_) { - width = B->dim(rank - 2); + const index_t lhs_rank = lhs->dim_size(); + const index_t lhs_rows = lhs->dim(lhs_rank - 2); + const index_t lhs_cols = lhs->dim(lhs_rank - 1); + const index_t rhs_rank = rhs->dim_size(); + const index_t rhs_rows = rhs->dim(rhs_rank - 2); + const index_t rhs_cols = rhs->dim(rhs_rank - 1); + + const index_t rows = transpose_a_ ? lhs_cols : lhs_rows; + const index_t cols = transpose_b_ ? rhs_rows : rhs_cols; + const index_t depth = transpose_a_ ? lhs_rows : lhs_cols; + const index_t + lhs_batch = + std::accumulate(lhs->shape().begin(), lhs->shape().end() - 2, 1, + std::multiplies()); + const index_t + rhs_batch = + std::accumulate(rhs->shape().begin(), rhs->shape().end() - 2, 1, + std::multiplies()); + index_t batch = 1; + std::vector output_shape; + if (lhs_rank >= rhs_rank) { + output_shape = lhs->shape(); + output_shape[lhs_rank - 2] = rows; + output_shape[lhs_rank - 1] = cols; + batch = lhs_batch; } else { - width = B->dim(rank - 1); + output_shape = rhs->shape(); + output_shape[rhs_rank - 2] = rows; + output_shape[rhs_rank - 1] = cols; + batch = rhs_batch; + } + bool lhs_batched = true; + bool rhs_batched = true; + if (lhs_rank < rhs_rank) { + lhs_batched = false; + } else if (rhs_rank < lhs_rank) { + rhs_batched = false; } - batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1, - std::multiplies()); - std::vector c_shape = A->shape(); - c_shape[rank - 2] = height; - c_shape[rank - 1] = width; - - MACE_RETURN_IF_ERROR(C->Resize(c_shape)); - - Tensor::MappingGuard guarda(A); - Tensor::MappingGuard guardb(B); - Tensor::MappingGuard guardc(C); - const float *a_ptr_base = A->data(); - const float *b_ptr_base = B->data(); - float *c_ptr_base = C->mutable_data(); - - const index_t height_a = A->dim(rank - 2); - const index_t width_a = A->dim(rank - 1); - const index_t height_b = B->dim(rank - 2); - const index_t width_b = B->dim(rank - 1); - - auto scratch_buffer = context->device()->scratch_buffer(); - scratch_buffer->Rewind(); - - sgemm_.Run(a_ptr_base, - b_ptr_base, - batch, - height_a, - width_a, - height_b, - width_b, - transpose_a_, - transpose_b_, - A->is_weight(), - B->is_weight(), - c_ptr_base, - context->device()->scratch_buffer()); - return MaceStatus::MACE_SUCCESS; + MACE_RETURN_IF_ERROR(C->Resize(output_shape)); + + if (rows == 1 && transpose_b_) { + return gemv_.Compute(context, + rhs, + lhs, + bias, + batch, + cols, + depth, + rhs_batched, 
+ lhs_batched, + C); + } else if (cols == 1 && !transpose_a_) { + return gemv_.Compute(context, + lhs, + rhs, + bias, + batch, + rows, + depth, + lhs_batched, + rhs_batched, + C); + } else { + context->device()->scratch_buffer()->Rewind(); + MaceStatus ret = gemm_.Compute(context, + lhs, + rhs, + batch, + lhs_rows, + lhs_cols, + rhs_rows, + rhs_cols, + transpose_a_, + transpose_b_, + false, + lhs_batched, + rhs_batched, + C); + if (bias != nullptr) { + MACE_CHECK(bias->dim_size() == 1 && bias->dim(0) == cols, + "bias' dim should be <= 2."); + Tensor::MappingGuard bias_guard(bias); + Tensor::MappingGuard c_guard(C); + const float *bias_data = bias->data(); + float *c_data = C->mutable_data(); +#pragma omp parallel for collapse(2) schedule(runtime) + for (index_t i = 0; i < batch * rows; ++i) { + for (index_t w = 0; w < cols; ++w) { + c_data[i * cols + w] += bias_data[w]; + } + } + } + + return ret; + } } private: - SGemm sgemm_; #ifdef MACE_ENABLE_NEON + arm::fp32::Gemm gemm_; arm::fp32::Gemv gemv_; #else ref::Gemv gemv_; + ref::Gemm gemm_; #endif // MACE_ENABLE_NEON }; @@ -174,18 +229,36 @@ class MatMulFixpointImpl { void operator()(OpContext *context, const Tensor *A, const Tensor *B, + const index_t batch, const index_t height, const index_t K, const index_t width, + const bool lhs_bached, + const bool rhs_bached, Tensor *C) { - index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1, - std::multiplies()); - #if defined(MACE_ENABLE_NEON) if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) { - gemv_kernel_.Compute(context, A, B, nullptr, batch, height, K, true, C); + gemv_kernel_.Compute(context, + A, + B, + nullptr, + batch, + height, + K, + true, + true, + C); } else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) { - gemv_kernel_.Compute(context, B, A, nullptr, batch, width, K, true, C); + gemv_kernel_.Compute(context, + B, + A, + nullptr, + batch, + width, + K, + true, + true, + C); } else { #endif // MACE_ENABLE_NEON Tensor::MappingGuard guarda(A); @@ -208,9 +281,13 @@ class MatMulFixpointImpl { for (index_t i = 0; i < batch; ++i) { gemmlowp::MatrixMap - a_matrix(a_ptr_base + i * a_size, height, K); + a_matrix(a_ptr_base + static_cast(lhs_bached) * i * a_size, + height, + K); gemmlowp::MatrixMap - b_matrix(b_ptr_base + i * b_size, K, width); + b_matrix(b_ptr_base + static_cast(rhs_bached) * i * b_size, + K, + width); gemmlowp::MatrixMap c_matrix(c_ptr_base + i * c_size, height, width); @@ -234,20 +311,39 @@ class MatMulFixpointImpl { void operator()(OpContext *context, const Tensor *A, const Tensor *B, + const index_t batch, const index_t height, const index_t K, const index_t width, + const bool lhs_bached, + const bool rhs_bached, Tensor *C) { C->SetScale(A->scale() * B->scale()); C->SetZeroPoint(0); - index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1, - std::multiplies()); #if defined(MACE_ENABLE_NEON) if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) { - gemv_kernel_.Compute(context, A, B, nullptr, batch, height, K, true, C); + gemv_kernel_.Compute(context, + A, + B, + nullptr, + batch, + height, + K, + lhs_bached, + rhs_bached, + C); } else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) { - gemv_kernel_.Compute(context, B, A, nullptr, batch, width, K, true, C); + gemv_kernel_.Compute(context, + B, + A, + nullptr, + batch, + width, + K, + lhs_bached, + rhs_bached, + C); } else { #endif // MACE_ENABLE_NEON Tensor::MappingGuard guarda(A); @@ -257,7 +353,8 @@ class MatMulFixpointImpl { auto b_ptr_base 
= B->data(); auto c_ptr_base = C->mutable_data(); auto - gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext(); + gemm_context = + context->device()->cpu_runtime()->GetGemmlowpContext(); MACE_CHECK_NOTNULL(gemm_context); index_t a_size = height * K; @@ -268,9 +365,15 @@ class MatMulFixpointImpl { for (index_t i = 0; i < batch; ++i) { gemmlowp::MatrixMap - a_matrix(a_ptr_base + i * a_size, height, K); + a_matrix + (a_ptr_base + static_cast(lhs_bached) * i * a_size, + height, + K); gemmlowp::MatrixMap - b_matrix(b_ptr_base + i * b_size, K, width); + b_matrix + (b_ptr_base + static_cast(rhs_bached) * i * b_size, + K, + width); gemmlowp::MatrixMap c_matrix(c_ptr_base + i * c_size, height, width); @@ -280,7 +383,6 @@ class MatMulFixpointImpl { -B->zero_point(), output_pipeline); } } - #if defined(MACE_ENABLE_NEON) } @@ -289,44 +391,65 @@ class MatMulFixpointImpl { #endif // MACE_ENABLE_NEON }; -template <> -class MatMulOp: public MatMulOpBase { +template<> +class MatMulOp : public MatMulOpBase { public: explicit MatMulOp(OpConstructContext *context) : MatMulOpBase(context) {} MaceStatus Run(OpContext *context) override { Validate(); - const Tensor *A = this->Input(INPUT_A); - const Tensor *B = this->Input(INPUT_B); + const Tensor *lhs = this->Input(INPUT_A); + const Tensor *rhs = this->Input(INPUT_B); Tensor *C = this->Output(OUTPUT); - index_t rank = A->dim_size(); - index_t height = A->dim(rank - 2); - index_t K = A->dim(rank - 1); - index_t width; - - if (transpose_a_) { - std::swap(height, K); - } - if (transpose_b_) { - width = B->dim(rank - 2); + const index_t lhs_rank = lhs->dim_size(); + const index_t lhs_rows = lhs->dim(lhs_rank - 2); + const index_t lhs_cols = lhs->dim(lhs_rank - 1); + const index_t rhs_rank = rhs->dim_size(); + const index_t rhs_rows = rhs->dim(rhs_rank - 2); + const index_t rhs_cols = rhs->dim(rhs_rank - 1); + + const index_t rows = transpose_a_ ? lhs_cols : lhs_rows; + const index_t cols = transpose_b_ ? rhs_rows : rhs_cols; + const index_t depth = transpose_a_ ? 
lhs_rows : lhs_cols; + const index_t + lhs_batch = + std::accumulate(lhs->shape().begin(), lhs->shape().end() - 2, 1, + std::multiplies()); + const index_t + rhs_batch = + std::accumulate(rhs->shape().begin(), rhs->shape().end() - 2, 1, + std::multiplies()); + index_t batch = 1; + std::vector output_shape; + if (lhs_rank >= rhs_rank) { + output_shape = lhs->shape(); + output_shape[lhs_rank - 2] = rows; + output_shape[lhs_rank - 1] = cols; + batch = lhs_batch; } else { - width = B->dim(rank - 1); + output_shape = rhs->shape(); + output_shape[rhs_rank - 2] = rows; + output_shape[rhs_rank - 1] = cols; + batch = rhs_batch; + } + bool lhs_batched = true; + bool rhs_batched = true; + if (lhs_rank < rhs_rank) { + lhs_batched = false; + } else if (rhs_rank < lhs_rank) { + rhs_batched = false; } - std::vector c_shape = A->shape(); - c_shape[rank - 2] = height; - c_shape[rank - 1] = width; - - MACE_RETURN_IF_ERROR(C->Resize(c_shape)); + MACE_RETURN_IF_ERROR(C->Resize(output_shape)); constexpr gemmlowp::MapOrder kRowMajor = gemmlowp::MapOrder::RowMajor; constexpr gemmlowp::MapOrder kColMajor = gemmlowp::MapOrder::ColMajor; #define MATMUL_FIXPOINT_IMPL(AOrder, BOrder, OutType) \ MatMulFixpointImpl()( \ - context, A, B, height, K, width, C); + context, lhs, rhs, batch, rows, depth, cols, lhs_batched, rhs_batched, C); #define MATMUL_FIXPOINT_IMPL_TRANSPOSE_OR_NOT(OutType) \ if (transpose_a_) { \ @@ -380,7 +503,6 @@ class MatMulOp : public MatMulOpBase { }; #endif // MACE_ENABLE_OPENCL - void RegisterMatMul(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "MatMul", MatMulOp, DeviceType::CPU, float); diff --git a/mace/ops/matmul_benchmark.cc b/mace/ops/matmul_benchmark.cc index 07a51ebf..acacdc7f 100644 --- a/mace/ops/matmul_benchmark.cc +++ b/mace/ops/matmul_benchmark.cc @@ -101,11 +101,14 @@ void MatmulBenchmark_Mace_SGemm(int iters, int m, int k, int n) { std::vector rhs(k * n); std::vector result(m * n); - ops::MatrixMap matrix_lhs(1, m, k, RowMajor, lhs.data(), - true); - ops::MatrixMap matrix_rhs(1, k, n, RowMajor, rhs.data(), - true); - ops::MatrixMap matrix_result(1, m, n, RowMajor, result.data()); + ops::SGemmMatrixMap + matrix_lhs(1, m, k, SGemmRowMajor, lhs.data(), + true); + ops::SGemmMatrixMap + matrix_rhs(1, k, n, SGemmRowMajor, rhs.data(), + true); + ops::SGemmMatrixMap + matrix_result(1, m, n, SGemmRowMajor, result.data()); ops::SGemm sgemm; @@ -395,6 +398,7 @@ void MatMulTransposeBenchmark( MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, uint8_t, CPU); MACE_BM_MATMUL_OP(1, 30000, 256, 1); +MACE_BM_MATMUL_OP(1, 128, 256, 128); MACE_BM_MATMUL_OP(2, 128, 128, 49); MACE_BM_MATMUL_OP(3, 128, 128, 49); MACE_BM_MATMUL_OP(4, 128, 128, 49); diff --git a/mace/ops/matmul_test.cc b/mace/ops/matmul_test.cc index dc15485e..741393ff 100644 --- a/mace/ops/matmul_test.cc +++ b/mace/ops/matmul_test.cc @@ -15,6 +15,7 @@ #include #include "mace/ops/ops_test_util.h" +#include "mace/ops/ref/gemm.h" namespace mace { namespace ops { @@ -71,34 +72,121 @@ TEST_F(MatMulOpTest, SimpleCPUWithBatch) { } namespace { -void QuantOutputUint8(const std::vector &batch, - const index_t height, - const index_t channels, - const index_t out_width, - const bool transpose_a, - const bool transpose_b) { + +template +void Complex(const std::vector &batch, + const index_t rows, + const index_t depth, + const index_t cols, + const bool transpose_lhs, + const bool transpose_rhs, + const bool lhs_batched, + const bool rhs_batched) { // Construct graph OpsTestNet net; // Add input data + index_t lhs_rows = transpose_lhs ? 
depth : rows; + index_t lhs_cols = transpose_lhs ? rows : depth; + index_t rhs_rows = transpose_rhs ? cols : depth; + index_t rhs_cols = transpose_rhs ? depth : cols; + std::vector lhs_shape = {lhs_rows, lhs_cols}; + std::vector rhs_shape = {rhs_rows, rhs_cols}; + if (lhs_batched) { + lhs_shape.insert(lhs_shape.begin(), batch.begin(), batch.end()); + } + if (rhs_batched) { + rhs_shape.insert(rhs_shape.begin(), batch.begin(), batch.end()); + } + net.AddRandomInput("A", lhs_shape); + net.AddRandomInput("B", rhs_shape); + + OpDefBuilder("MatMul", "MatMulTest") + .Input("A") + .AddIntArg("transpose_a", transpose_lhs ? 1 : 0) + .Input("B") + .AddIntArg("transpose_b", transpose_rhs ? 1 : 0) + .Output("Output") + .AddIntArg("T", DT_FLOAT) + .Finalize(net.NewOperatorDef()); + net.RunOp(CPU); + + ref::Gemm gemm; + Tensor expected_output_tensor; + std::vector expected_output_shape({rows, cols}); + expected_output_shape.insert(expected_output_shape.begin(), + batch.begin(), + batch.end()); + expected_output_tensor.Resize(expected_output_shape); index_t batch_count = std::accumulate(batch.begin(), batch.end(), 1, std::multiplies()); - if (transpose_a) { - net.AddRandomInput("A", {batch_count, channels, height}); - } else { - net.AddRandomInput("A", {batch_count, height, channels}); + gemm.Compute(nullptr, + net.GetTensor("A"), + net.GetTensor("B"), + batch_count, + lhs_rows, + lhs_cols, + rhs_rows, + rhs_cols, + transpose_lhs, + transpose_rhs, + false, + lhs_batched, + rhs_batched, + &expected_output_tensor); + + ExpectTensorNear(expected_output_tensor, *net.GetTensor("Output")); +} +} // namespace + +TEST_F(MatMulOpTest, ComplexCPUWithBatch) { + Complex({1}, 3, 3, 3, false, false, true, true); + Complex({}, 3, 3, 3, false, false, true, true); + Complex({16}, 31, 61, 67, false, true, true, true); + Complex({31}, 31, 61, 67, true, false, true, true); + Complex({2, 3}, 31, 61, 67, true, true, true, true); + Complex({1}, 1, 30001, 253, false, true, true, true); + Complex({2}, 253, 300, 1, false, false, true, true); + // test one-side batched + Complex({2, 3}, 31, 61, 67, true, true, false, true); + Complex({2, 3}, 31, 61, 67, true, true, true, false); + Complex({2, 3}, 31, 61, 67, true, true, false, true); +} + +namespace { +void QuantOutputUint8(const std::vector &batch, + const index_t rows, + const index_t depth, + const index_t cols, + const bool transpose_lhs, + const bool transpose_rhs, + const bool lhs_batched = true, + const bool rhs_batched = true) { + // Construct graph + OpsTestNet net; + + // Add input data + // Add input data + index_t lhs_rows = transpose_lhs ? depth : rows; + index_t lhs_cols = transpose_lhs ? rows : depth; + index_t rhs_rows = transpose_rhs ? cols : depth; + index_t rhs_cols = transpose_rhs ? depth : cols; + std::vector lhs_shape = {lhs_rows, lhs_cols}; + std::vector rhs_shape = {rhs_rows, rhs_cols}; + if (lhs_batched) { + lhs_shape.insert(lhs_shape.begin(), batch.begin(), batch.end()); } - if (transpose_b) { - net.AddRandomInput("B", {batch_count, out_width, channels}); - } else { - net.AddRandomInput("B", {batch_count, channels, out_width}); + if (rhs_batched) { + rhs_shape.insert(rhs_shape.begin(), batch.begin(), batch.end()); } + net.AddRandomInput("A", lhs_shape); + net.AddRandomInput("B", rhs_shape); OpDefBuilder("MatMul", "MatMulTest") .Input("A") - .AddIntArg("transpose_a", transpose_a ? 1 : 0) + .AddIntArg("transpose_a", transpose_lhs ? 1 : 0) .Input("B") - .AddIntArg("transpose_b", transpose_b ? 1 : 0) + .AddIntArg("transpose_b", transpose_rhs ? 
1 : 0) .Output("Output") .AddIntArg("T", DT_FLOAT) .Finalize(net.NewOperatorDef()); @@ -133,9 +221,9 @@ void QuantOutputUint8(const std::vector &batch, OpDefBuilder("MatMul", "QuantizeMatMulTest") .Input("QuantizedA") - .AddIntArg("transpose_a", transpose_a ? 1 : 0) + .AddIntArg("transpose_a", transpose_lhs ? 1 : 0) .Input("QuantizedB") - .AddIntArg("transpose_b", transpose_b ? 1 : 0) + .AddIntArg("transpose_b", transpose_rhs ? 1 : 0) .Output("QuantizedOutput") .AddIntArg("T", DT_UINT8) .OutputType({DT_UINT8}) @@ -161,39 +249,38 @@ void QuantOutputUint8(const std::vector &batch, } void QuantOutputInt32(const std::vector &batch, - const index_t height, - const index_t channels, - const index_t out_width, - const bool transpose_a, - const bool transpose_b) { + const index_t rows, + const index_t depth, + const index_t cols, + const bool transpose_lhs, + const bool transpose_rhs, + const bool lhs_batched = true, + const bool rhs_batched = true) { // Construct graph OpsTestNet net; // Add input data - index_t batch_count = std::accumulate(batch.begin(), batch.end(), 1, - std::multiplies()); - if (transpose_a) { - net.AddRandomInput("A", {batch_count, channels, height}, - false); - } else { - net.AddRandomInput("A", {batch_count, height, channels}, - false); + // Add input data + index_t lhs_rows = transpose_lhs ? depth : rows; + index_t lhs_cols = transpose_lhs ? rows : depth; + index_t rhs_rows = transpose_rhs ? cols : depth; + index_t rhs_cols = transpose_rhs ? depth : cols; + std::vector lhs_shape = {lhs_rows, lhs_cols}; + std::vector rhs_shape = {rhs_rows, rhs_cols}; + if (lhs_batched) { + lhs_shape.insert(lhs_shape.begin(), batch.begin(), batch.end()); } - if (transpose_b) { - net.AddRandomInput("B", - {batch_count, out_width, channels}, - false); - } else { - net.AddRandomInput("B", - {batch_count, channels, out_width}, - false); + if (rhs_batched) { + rhs_shape.insert(rhs_shape.begin(), batch.begin(), batch.end()); } + net.AddRandomInput("A", lhs_shape); + net.AddRandomInput("B", rhs_shape); OpDefBuilder("MatMul", "MatMulTest") .Input("A") - .AddIntArg("transpose_a", transpose_a ? 1 : 0) + .AddIntArg("transpose_a", transpose_lhs ? 1 : 0) .Input("B") - .AddIntArg("transpose_b", transpose_b ? 1 : 0) + .AddIntArg("transpose_b", transpose_rhs ? 1 : 0) .Output("Output") .AddIntArg("T", DT_FLOAT) .Finalize(net.NewOperatorDef()); @@ -219,9 +306,9 @@ void QuantOutputInt32(const std::vector &batch, OpDefBuilder("MatMul", "QuantizeMatMulTest") .Input("QuantizedA") - .AddIntArg("transpose_a", transpose_a ? 1 : 0) + .AddIntArg("transpose_a", transpose_lhs ? 1 : 0) .Input("QuantizedB") - .AddIntArg("transpose_b", transpose_b ? 1 : 0) + .AddIntArg("transpose_b", transpose_rhs ? 
1 : 0) .Output("QuantizedOutput") .AddIntArg("T", DT_UINT8) .OutputType({DT_INT32}) @@ -256,10 +343,12 @@ TEST_F(MatMulOpTest, QuantOutputUint8) { QuantOutputUint8({1}, 64, 32, 128, true, true); QuantOutputUint8({2, 3}, 64, 32, 128, true, true); // UnAligned - QuantOutputUint8({2}, 3, 3, 3, false, false); QuantOutputUint8({16}, 31, 61, 67, false, true); QuantOutputUint8({31}, 31, 61, 67, true, false); QuantOutputUint8({2, 3}, 31, 61, 67, true, true); + + QuantOutputUint8({2, 3}, 31, 61, 67, true, true, true, false); + QuantOutputUint8({2, 3}, 31, 61, 67, true, true, false, true); } TEST_F(MatMulOpTest, QuantOutputInt32) { @@ -281,12 +370,14 @@ TEST_F(MatMulOpTest, QuantOutputInt32) { QuantOutputInt32({3}, 128, 256, 1, false, false); // UnAligned - QuantOutputInt32({2}, 3, 3, 3, false, false); QuantOutputInt32({16}, 31, 61, 67, false, true); QuantOutputInt32({31}, 31, 61, 67, true, false); QuantOutputInt32({2, 3}, 31, 61, 67, true, true); QuantOutputInt32({1}, 1, 30001, 253, false, true); QuantOutputInt32({2}, 253, 300, 1, false, false); + + QuantOutputInt32({2, 3}, 31, 61, 67, true, true, true, false); + QuantOutputInt32({2, 3}, 31, 61, 67, true, true, false, true); } } // namespace test diff --git a/mace/ops/ref/conv_2d.cc b/mace/ops/ref/conv_2d.cc new file mode 100644 index 00000000..4707d922 --- /dev/null +++ b/mace/ops/ref/conv_2d.cc @@ -0,0 +1,111 @@ +// Copyright 2019 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
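The reference convolution implemented just below resizes its output with CalcOutputSize(..., RoundType::FLOOR) and splits the total pad_h/pad_w evenly (pad_top = pad_h >> 1). A minimal sketch of the floor-mode output-size arithmetic this implies; ConvOutDim and the concrete sizes are illustrative assumptions, not MACE API:

#include <cstdint>
#include <cstdio>

// Floor-mode output size as implied by the reference Conv2d below: the padded
// input span minus the dilated kernel extent, stepped by the stride.
// ConvOutDim and the numbers in main() are illustrative only.
int64_t ConvOutDim(int64_t in, int64_t total_pad, int64_t kernel,
                   int64_t stride, int64_t dilation) {
  const int64_t dilated_kernel = (kernel - 1) * dilation + 1;
  return (in + total_pad - dilated_kernel) / stride + 1;
}

int main() {
  // Hypothetical 3x3, stride-2, dilation-1 convolution on a 224x224 input
  // with a total padding of 2 rows and 2 columns.
  std::printf("out = %lld x %lld\n",
              static_cast<long long>(ConvOutDim(224, 2, 3, 2, 1)),
              static_cast<long long>(ConvOutDim(224, 2, 3, 2, 1)));  // 112 x 112
  return 0;
}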
+ + +#include "mace/ops/ref/conv_2d.h" + +#include +#include "mace/ops/common/conv_pool_2d_util.h" + +namespace mace { +namespace ops { +namespace ref { + +MaceStatus Conv2d::Compute(const OpContext *context, + const Tensor *input, + const Tensor *filter, + Tensor *output) { + MACE_UNUSED(context); + + const std::vector in_shape = input->shape(); + const std::vector filter_shape = filter->shape(); + const std::vector out_shape = output->shape(); + const std::vector stride_hw{stride_h_, stride_w_}; + const std::vector dilation_hw{dilation_h_, dilation_w_}; + const std::vector paddings{pad_h_, pad_w_}; + const index_t pad_top = pad_h_ >> 1; + const index_t pad_left = pad_w_ >> 1; + + std::vector output_shape(4); + + CalcOutputSize(in_shape.data(), + NCHW, + filter_shape.data(), + OIHW, + paddings.data(), + dilation_hw.data(), + stride_hw.data(), + RoundType::FLOOR, + output_shape.data()); + output->Resize(output_shape); + + const index_t in_image_size = in_shape[2] * in_shape[3]; + const index_t out_image_size = out_shape[2] * out_shape[3]; + const index_t in_batch_size = filter_shape[1] * in_image_size; + const index_t out_batch_size = filter_shape[0] * out_image_size; + const index_t filter_size = filter_shape[2] * filter_shape[3]; + Tensor::MappingGuard input_guard(input); + Tensor::MappingGuard filter_guard(filter); + Tensor::MappingGuard output_guard(output); + auto input_data = input->data(); + auto filter_data = filter->data(); + auto output_data = output->mutable_data(); + +#pragma omp parallel for collapse(2) schedule(runtime) + for (index_t b = 0; b < in_shape[0]; b++) { + for (index_t m = 0; m < filter_shape[0]; ++m) { + const index_t in_height = in_shape[2]; + const index_t in_width = in_shape[3]; + const index_t out_height = out_shape[2]; + const index_t out_width = out_shape[3]; + const index_t in_channels = filter_shape[1]; + + float *out_ptr_base = + output_data + b * out_batch_size + m * out_image_size; + + for (index_t h = 0; h < out_height; ++h) { + for (index_t w = 0; w < out_width; ++w) { + float sum = 0; + + for (index_t c = 0; c < in_channels; ++c) { + const float *in_ptr_base = + input_data + b * in_batch_size + c * in_image_size; + const float *filter_ptr = + filter_data + m * in_channels * filter_size + c * filter_size; + + for (index_t kh = 0; kh < filter_shape[2]; ++kh) { + for (index_t kw = 0; kw < filter_shape[3]; ++kw) { + const index_t ih = -pad_top + h * stride_h_ + kh * dilation_h_; + const index_t iw = -pad_left + w * stride_w_ + kw * dilation_w_; + if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) { + sum += in_ptr_base[ih * in_width + iw] * filter_ptr[kw]; + } + } // kw + filter_ptr += filter_shape[3]; + } // kh + } // c + + out_ptr_base[h * out_width + w] = sum; + } // w + } // h + } // m + } // b + return MaceStatus::MACE_SUCCESS; +} + +} // namespace ref +} // namespace ops +} // namespace mace + + diff --git a/mace/ops/ref/conv_2d.h b/mace/ops/ref/conv_2d.h new file mode 100644 index 00000000..e99af5cf --- /dev/null +++ b/mace/ops/ref/conv_2d.h @@ -0,0 +1,76 @@ +// Copyright 2019 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#ifndef MACE_OPS_REF_CONV_2D_H_ +#define MACE_OPS_REF_CONV_2D_H_ + +#include "mace/public/mace.h" +#include "mace/core/tensor.h" +#include "mace/core/op_context.h" + +namespace mace { +namespace ops { +namespace ref { + +template +class Conv2d { + public: + Conv2d(int stride_h, int stride_w, int dilation_h, int dilation_w); + ~Conv2d() {} + MaceStatus Compute( + const OpContext *context, + const Tensor *input, + const Tensor *filter, + Tensor *output); +}; + +template<> +class Conv2d { + public: + Conv2d(int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w) + : pad_h_(pad_h), + pad_w_(pad_w), + stride_h_(stride_h), + stride_w_(stride_w), + dilation_h_(dilation_h), + dilation_w_(dilation_w) {} + ~Conv2d() {} + // Always row-major after transpose + MaceStatus Compute( + const OpContext *context, + const Tensor *input, + const Tensor *filter, + Tensor *output); + + private: + int pad_h_; + int pad_w_; + int stride_h_; + int stride_w_; + int dilation_h_; + int dilation_w_; +}; + +} // namespace ref +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_REF_CONV_2D_H_ + diff --git a/mace/ops/ref/gemm.cc b/mace/ops/ref/gemm.cc new file mode 100644 index 00000000..e9d13c91 --- /dev/null +++ b/mace/ops/ref/gemm.cc @@ -0,0 +1,116 @@ +// Copyright 2019 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
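The reference GEMM implemented just below (and the Gemv changes later in this patch) broadcast a non-batched operand by scaling its batch stride with static_cast<index_t>(lhs_batched) / static_cast<index_t>(rhs_batched), so a false flag collapses the offset to zero. A minimal sketch of that addressing, assuming the hypothetical helper name BatchOffset and illustrative sizes:

#include <cstdint>

// Batch-broadcast addressing used by the reference kernels in this patch:
// a side that is not batched gets a zero batch stride, so every batch index
// reads the same single matrix.
int64_t BatchOffset(bool batched, int64_t batch_index,
                    int64_t rows, int64_t cols) {
  return static_cast<int64_t>(batched) * batch_index * rows * cols;
}

int main() {
  const int64_t b = 3, rows = 31, depth = 61;
  const int64_t lhs_off = BatchOffset(true, b, rows, depth);   // 3 * 31 * 61
  const int64_t rhs_off = BatchOffset(false, b, rows, depth);  // 0: broadcast
  return (lhs_off == 3 * 31 * 61 && rhs_off == 0) ? 0 : 1;
}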
+ + +#include "mace/ops/ref/gemm.h" + +namespace mace { +namespace ops { +namespace ref { + +MaceStatus Gemm::Compute(const OpContext *context, + const Tensor *lhs, + const Tensor *rhs, + const index_t batch, + const index_t rows, + const index_t cols, + const index_t depth, + const MatrixMajor lhs_major, + const MatrixMajor rhs_major, + const MatrixMajor output_major, + const bool lhs_batched, + const bool rhs_batched, + Tensor *output) { + MACE_UNUSED(context); + + Tensor::MappingGuard lhs_guard(lhs); + Tensor::MappingGuard rhs_guard(rhs); + Tensor::MappingGuard output_guard(output); + const float *lhs_data = lhs->data(); + const float *rhs_data = rhs->data(); + float *output_data = output->mutable_data(); + + for (index_t b = 0; b < batch; ++b) { + MatrixMap + lhs_matrix + (lhs_data + static_cast(lhs_batched) * b * rows * depth, + lhs_major, + rows, + depth); + MatrixMap + rhs_matrix + (rhs_data + static_cast(rhs_batched) * b * depth * cols, + rhs_major, + depth, + cols); + MatrixMap + output_matrix(output_data + b * rows * cols, output_major, rows, cols); + + for (index_t r = 0; r < rows; ++r) { + for (index_t c = 0; c < cols; ++c) { + float sum = 0; + for (index_t d = 0; d < depth; ++d) { + sum += lhs_matrix(r, d) * rhs_matrix(d, c); + } // d + + *output_matrix.data(r, c) = sum; + } // c + } // r + } // b + + return MaceStatus::MACE_SUCCESS; +} + +MaceStatus Gemm::Compute(const OpContext *context, + const Tensor *lhs, + const Tensor *rhs, + const index_t batch, + const index_t lhs_rows, + const index_t lhs_cols, + const index_t rhs_rows, + const index_t rhs_cols, + const bool transpose_lhs, + const bool transpose_rhs, + const bool transpose_out, + const bool lhs_batched, + const bool rhs_batched, + Tensor *output) { + index_t rows = transpose_lhs ? lhs_cols : lhs_rows; + index_t depth = transpose_lhs ? lhs_rows : lhs_cols; + index_t cols = transpose_rhs ? rhs_rows : rhs_cols; + index_t depth2 = transpose_rhs ? rhs_cols : rhs_rows; + MACE_CHECK(depth == depth2, + "Matrices that multiply have inconsistent depth dim: ", + depth, + " vs. ", + depth2); + + return Compute(context, + lhs, + rhs, + batch, + rows, + cols, + depth, + transpose_lhs ? ColMajor : RowMajor, + transpose_rhs ? ColMajor : RowMajor, + transpose_out ? ColMajor : RowMajor, + lhs_batched, + rhs_batched, + output); +} + +} // namespace ref +} // namespace ops +} // namespace mace diff --git a/mace/ops/ref/gemm.h b/mace/ops/ref/gemm.h new file mode 100644 index 00000000..bf1826ad --- /dev/null +++ b/mace/ops/ref/gemm.h @@ -0,0 +1,89 @@ +// Copyright 2019 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
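The convenience overload of Gemm<float>::Compute above derives the multiplication geometry from the stored shapes and the transpose flags before dispatching to the MatrixMajor-based overload declared in the header that follows. A small sketch of that mapping; GemmGeometry and DeriveGeometry are hypothetical names used only for illustration:

#include <cassert>
#include <cstdint>

// Mirrors how the convenience overload above turns stored operand shapes plus
// transpose flags into (rows, depth, cols), including the depth consistency
// check. Names are made up for this sketch.
struct GemmGeometry { int64_t rows; int64_t depth; int64_t cols; };

GemmGeometry DeriveGeometry(int64_t lhs_rows, int64_t lhs_cols,
                            int64_t rhs_rows, int64_t rhs_cols,
                            bool transpose_lhs, bool transpose_rhs) {
  GemmGeometry g;
  g.rows = transpose_lhs ? lhs_cols : lhs_rows;
  g.depth = transpose_lhs ? lhs_rows : lhs_cols;
  g.cols = transpose_rhs ? rhs_rows : rhs_cols;
  const int64_t depth2 = transpose_rhs ? rhs_cols : rhs_rows;
  assert(g.depth == depth2 && "matrices have inconsistent depth");
  (void)depth2;
  return g;
}

int main() {
  // lhs stored as 61x31 with transpose_lhs, rhs stored as 67x61 with
  // transpose_rhs: the product is 31x67 with depth 61.
  const GemmGeometry g = DeriveGeometry(61, 31, 67, 61, true, true);
  return (g.rows == 31 && g.depth == 61 && g.cols == 67) ? 0 : 1;
}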
+ + +#ifndef MACE_OPS_REF_GEMM_H_ +#define MACE_OPS_REF_GEMM_H_ + +#include "mace/public/mace.h" +#include "mace/core/tensor.h" +#include "mace/core/op_context.h" +#include "mace/ops/common/matrix.h" + +namespace mace { +namespace ops { +namespace ref { + +template +class Gemm { + public: + Gemm() {} + ~Gemm() {} + MaceStatus Compute(const OpContext *context, + const Tensor *lhs, + const Tensor *rhs, + const index_t batch, + const index_t rows, + const index_t cols, + const index_t depth, + const MatrixMajor lhs_major, + const MatrixMajor rhs_major, + const MatrixMajor output_major, + const bool lhs_batched, + const bool rhs_batched, + Tensor *output); +}; + +template<> +class Gemm { + public: + Gemm() {} + ~Gemm() {} + MaceStatus Compute(const OpContext *context, + const Tensor *lhs, + const Tensor *rhs, + const index_t batch, + const index_t rows, + const index_t cols, + const index_t depth, + const MatrixMajor lhs_major, + const MatrixMajor rhs_major, + const MatrixMajor output_major, + const bool lhs_batched, + const bool rhs_batched, + Tensor *output); + // Original matrix before transpose has row-major + MaceStatus Compute( + const OpContext *context, + const Tensor *lhs, + const Tensor *rhs, + const index_t batch, + const index_t lhs_rows, + const index_t lhs_cols, + const index_t rhs_rows, + const index_t rhs_cols, + const bool transpose_lhs, + const bool transpose_rhs, + const bool transpose_out, + const bool lhs_batched, + const bool rhs_batched, + Tensor *output); +}; + +} // namespace ref +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_REF_GEMM_H_ + diff --git a/mace/ops/ref/gemv.cc b/mace/ops/ref/gemv.cc index 555c99e2..59fc31dc 100644 --- a/mace/ops/ref/gemv.cc +++ b/mace/ops/ref/gemv.cc @@ -31,6 +31,7 @@ MaceStatus Gemv::Compute(const OpContext *context, const index_t lhs_height, const index_t lhs_width, const bool lhs_batched, + const bool rhs_batched, Tensor *output) { MACE_UNUSED(context); @@ -52,9 +53,9 @@ MaceStatus Gemv::Compute(const OpContext *context, float sum = bias ? 
bias_data[h] : 0; for (index_t w = 0; w < lhs_width; ++w) { sum += lhs_data[ - static_cast(lhs_batched) * b * lhs_height * lhs_width - + h * lhs_width + w] - * rhs_data[b * lhs_width + w]; + static_cast(lhs_batched) * b * lhs_height * lhs_width + + h * lhs_width + w] + * rhs_data[static_cast(rhs_batched) * b * lhs_width + w]; } // w output_data[b * lhs_height + h] = sum; @@ -73,6 +74,7 @@ MaceStatus Gemv::Compute(const OpContext *context, const index_t lhs_height, const index_t lhs_width, const bool lhs_batched, + const bool rhs_batched, Tensor *output) { MACE_UNUSED(context); @@ -102,7 +104,8 @@ MaceStatus Gemv::Compute(const OpContext *context, sum += (lhs_data[ static_cast(lhs_batched) * b * lhs_height * lhs_width + h * lhs_width + w] - lhs_zero) - * (rhs_data[b * lhs_width + w] - rhs_zero); + * (rhs_data[static_cast(rhs_batched) * b * lhs_width + w] + - rhs_zero); } // w output_data[b * lhs_height + h] = @@ -120,6 +123,7 @@ MaceStatus Gemv::Compute(const OpContext *context, const index_t lhs_height, const index_t lhs_width, const bool lhs_batched, + const bool rhs_batched, Tensor *output) { MACE_UNUSED(context); @@ -146,7 +150,8 @@ MaceStatus Gemv::Compute(const OpContext *context, sum += (lhs_data[ static_cast(lhs_batched) * b * lhs_height * lhs_width + h * lhs_width + w] - lhs_zero) - * (rhs_data[b * lhs_width + w] - rhs_zero); + * (rhs_data[static_cast(rhs_batched) * b * lhs_width + w] + - rhs_zero); } // w output_data[b * lhs_height + h] = sum; diff --git a/mace/ops/ref/gemv.h b/mace/ops/ref/gemv.h index 46892839..7116b8fa 100644 --- a/mace/ops/ref/gemv.h +++ b/mace/ops/ref/gemv.h @@ -39,6 +39,7 @@ class Gemv { const index_t lhs_height, const index_t lhs_width, const bool lhs_batched, + const bool rhs_batched, Tensor *output); }; @@ -57,6 +58,7 @@ class Gemv { const index_t lhs_height, const index_t lhs_width, const bool lhs_batched, + const bool rhs_batched, Tensor *output); }; @@ -76,6 +78,7 @@ class Gemv { const index_t lhs_height, const index_t lhs_width, const bool lhs_batched, + const bool rhs_batched, Tensor *output); }; @@ -94,6 +97,7 @@ class Gemv { const index_t lhs_height, const index_t lhs_width, const bool lhs_batched, + const bool rhs_batched, Tensor *output); }; #endif // MACE_ENABLE_QUANTIZE diff --git a/mace/ops/sgemm.cc b/mace/ops/sgemm.cc index 445b9cf6..1601aac2 100644 --- a/mace/ops/sgemm.cc +++ b/mace/ops/sgemm.cc @@ -27,39 +27,17 @@ #define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3]) #endif -namespace { - -inline void AdviseFree(void *addr, size_t length) { - int page_size = sysconf(_SC_PAGESIZE); - void *addr_aligned = - reinterpret_cast( - (reinterpret_cast(addr) + page_size - 1) - & (~(page_size - 1))); - uintptr_t delta = - reinterpret_cast(addr_aligned) - - reinterpret_cast(addr); - if (length >= delta + page_size) { - size_t len_aligned = (length - delta) & (~(page_size - 1)); - int ret = madvise(addr_aligned, len_aligned, MADV_DONTNEED); - if (ret != 0) { - LOG(ERROR) << "Advise free failed: " << strerror(errno); - } - } -} - -} // namespace - namespace mace { namespace ops { -void SGemm::operator()(const MatrixMap &lhs, - const MatrixMap &rhs, - MatrixMap *result, +void SGemm::operator()(const SGemmMatrixMap &lhs, + const SGemmMatrixMap &rhs, + SGemmMatrixMap *result, ScratchBuffer *scratch_buffer) { if (lhs.is_const() && !rhs.is_const()) { - MatrixMap lhs_transpose = lhs.transpose(); - MatrixMap rhs_transpose = rhs.transpose(); - MatrixMap result_transpose = result->transpose(); + SGemmMatrixMap lhs_transpose = lhs.transpose(); + SGemmMatrixMap 
rhs_transpose = rhs.transpose(); + SGemmMatrixMap result_transpose = result->transpose(); return operator()(rhs_transpose, lhs_transpose, &result_transpose, @@ -150,18 +128,18 @@ void SGemm::Run(const float *A, width_c = height_b; } - MatrixMap matrix_a = - MatrixMap(batch, + SGemmMatrixMap matrix_a = + SGemmMatrixMap(batch, height_a, width_a, - ops::RowMajor, + ops::SGemmRowMajor, A, is_a_weight); - MatrixMap matrix_b = - ops::MatrixMap(batch, + SGemmMatrixMap matrix_b = + ops::SGemmMatrixMap(batch, height_b, width_b, - ops::RowMajor, + ops::SGemmRowMajor, B, is_b_weight); if (transpose_a) { @@ -170,7 +148,8 @@ void SGemm::Run(const float *A, if (transpose_b) { matrix_b = matrix_b.transpose(); } - MatrixMap matrix_c(batch, height_c, width_c, ops::RowMajor, C); + SGemmMatrixMap + matrix_c(batch, height_c, width_c, ops::SGemmRowMajor, C); operator()(matrix_a, matrix_b, &matrix_c, scratch_buffer); } @@ -930,17 +909,17 @@ void SGemm::RunPerBatch(const float *lhs_data, } // bw } -void SGemm::PackLhs(const MatrixMap &lhs, +void SGemm::PackLhs(const SGemmMatrixMap &lhs, PackedBlock *packed_block) { - Pack(lhs, PackOrder::ColMajor, packed_block); + Pack(lhs, PackOrder::SGemmColMajor, packed_block); } -void SGemm::PackRhs(const MatrixMap &rhs, +void SGemm::PackRhs(const SGemmMatrixMap &rhs, PackedBlock *packed_block) { - Pack(rhs, PackOrder::RowMajor, packed_block); + Pack(rhs, PackOrder::SGemmRowMajor, packed_block); } -void SGemm::Pack(const MatrixMap &src, +void SGemm::Pack(const SGemmMatrixMap &src, const PackOrder order, PackedBlock *packed_block) { MACE_CHECK_NOTNULL(packed_block); @@ -963,7 +942,7 @@ void SGemm::Pack(const MatrixMap &src, } void SGemm::UnPack(const PackedBlock &packed_result, - MatrixMap *matrix_map) { + SGemmMatrixMap *matrix_map) { MACE_CHECK_NOTNULL(matrix_map); const index_t height = matrix_map->row(); @@ -984,7 +963,7 @@ void SGemm::UnPack(const PackedBlock &packed_result, #undef MACE_SGEMM_UNPACK_PER_BATCH } -void SGemm::PackPerBatch(const MatrixMap &src, +void SGemm::PackPerBatch(const SGemmMatrixMap &src, const PackOrder order, const index_t batch_index, float *packed_data) { @@ -994,7 +973,8 @@ void SGemm::PackPerBatch(const MatrixMap &src, const index_t width = src.col(); auto src_data = src.batch_data(batch_index); - if (src.map_major() == Major::RowMajor && order == PackOrder::ColMajor) { + if (src.map_major() == Major::SGemmRowMajor + && order == PackOrder::SGemmColMajor) { // This is for packing no-transpose lhs. index_t h = 0; #if defined(MACE_ENABLE_NEON) @@ -1040,8 +1020,8 @@ void SGemm::PackPerBatch(const MatrixMap &src, for (index_t ih = h; ih < height; ++ih) { std::copy_n(src_data + ih * width, width, packed_data + ih * width); } - } else if (src.map_major() == Major::ColMajor && - order == PackOrder::ColMajor) { + } else if (src.map_major() == Major::SGemmColMajor && + order == PackOrder::SGemmColMajor) { // This is for packing transpose-needed lhs. index_t h = 0; #if defined(MACE_ENABLE_NEON) @@ -1082,8 +1062,8 @@ void SGemm::PackPerBatch(const MatrixMap &src, packed_data_ptr[w] = src_data_ptr[w * height]; } } - } else if (src.map_major() == Major::RowMajor && - order == PackOrder::RowMajor) { + } else if (src.map_major() == Major::SGemmRowMajor && + order == PackOrder::SGemmRowMajor) { // This is for packing no-transpose rhs. 
index_t w = 0; #if defined(MACE_ENABLE_NEON) @@ -1108,8 +1088,8 @@ void SGemm::PackPerBatch(const MatrixMap &src, packed_data_ptr[h] = src_data_ptr[h * width]; } } - } else if (src.map_major() == Major::ColMajor && - order == PackOrder::RowMajor) { + } else if (src.map_major() == Major::SGemmColMajor && + order == PackOrder::SGemmRowMajor) { // This is for packing transpose-needed rhs. index_t w = 0; #if defined(MACE_ENABLE_NEON) @@ -1138,14 +1118,14 @@ void SGemm::PackPerBatch(const MatrixMap &src, void SGemm::UnPackPerBatch(const float *packed_data, const index_t batch_index, - MatrixMap *matrix_map) { + SGemmMatrixMap *matrix_map) { MACE_CHECK_NOTNULL(matrix_map); const index_t height = matrix_map->row(); const index_t width = matrix_map->col(); auto unpacked_data = matrix_map->batch_data(batch_index); - if (matrix_map->map_major() == Major::RowMajor) { + if (matrix_map->map_major() == Major::SGemmRowMajor) { // This is for non-transposed result index_t w = 0; #if defined(MACE_ENABLE_NEON) diff --git a/mace/ops/sgemm.h b/mace/ops/sgemm.h index 25bbfcc7..1320d1be 100644 --- a/mace/ops/sgemm.h +++ b/mace/ops/sgemm.h @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// This implementation is deprecated. use mace/ops/arm/fp32/gemm.h instead. + #ifndef MACE_OPS_SGEMM_H_ #define MACE_OPS_SGEMM_H_ @@ -30,16 +32,16 @@ namespace mace { namespace ops { enum Major { - RowMajor, - ColMajor + SGemmRowMajor, + SGemmColMajor }; template -class MatrixMap { +class SGemmMatrixMap { public: - MatrixMap() {} + SGemmMatrixMap() {} - MatrixMap(const index_t batch, + SGemmMatrixMap(const index_t batch, const index_t row, const index_t col, const Major major, @@ -48,14 +50,20 @@ class MatrixMap { batch_(batch), row_(row), col_(col), - stride_(major == RowMajor ? col : row), + stride_(major == SGemmRowMajor ? col : row), major_(major), data_(data), is_const_(is_const) {} - MatrixMap transpose() const { - Major transpose_major = major_ == RowMajor ? ColMajor : RowMajor; - return MatrixMap(batch_, col_, row_, transpose_major, data_, is_const_); + SGemmMatrixMap transpose() const { + Major transpose_major = + major_ == SGemmRowMajor ? 
SGemmColMajor : SGemmRowMajor; + return SGemmMatrixMap(batch_, + col_, + row_, + transpose_major, + data_, + is_const_); } index_t batch() const { @@ -114,9 +122,9 @@ class SGemm { packed_rhs_(nullptr), packed_(false) {} - void operator()(const MatrixMap &lhs, - const MatrixMap &rhs, - MatrixMap *result, + void operator()(const SGemmMatrixMap &lhs, + const SGemmMatrixMap &rhs, + SGemmMatrixMap *result, ScratchBuffer *scratch_buffer = nullptr); void Run(const float *A, @@ -133,28 +141,28 @@ class SGemm { float *C, ScratchBuffer *scratch_buffer = nullptr); - void PackLhs(const MatrixMap &lhs, + void PackLhs(const SGemmMatrixMap &lhs, PackedBlock *packed_block); - void PackRhs(const MatrixMap &rhs, + void PackRhs(const SGemmMatrixMap &rhs, PackedBlock *packed_block); void UnPack(const PackedBlock &packed_result, - MatrixMap *matrix_map); + SGemmMatrixMap *matrix_map); private: - void Pack(const MatrixMap &src, + void Pack(const SGemmMatrixMap &src, const PackOrder order, PackedBlock *packed_block); - void PackPerBatch(const MatrixMap &src, + void PackPerBatch(const SGemmMatrixMap &src, const PackOrder order, const index_t batch_index, float *packed_data); void UnPackPerBatch(const float *packed_data, const index_t batch_index, - MatrixMap *matrix_map); + SGemmMatrixMap *matrix_map); void RunInternal(const PackedBlock &lhs, const PackedBlock &rhs, diff --git a/mace/ops/sgemm_pack_test.cc b/mace/ops/sgemm_pack_test.cc index c8911985..69766cb9 100644 --- a/mace/ops/sgemm_pack_test.cc +++ b/mace/ops/sgemm_pack_test.cc @@ -31,10 +31,11 @@ void TestPack(const std::vector &data, Major src_order, PackOrder pack_order) { SGemm sg; - MatrixMap src_matrix(1, height, width, src_order, data.data()); + SGemmMatrixMap + src_matrix(1, height, width, src_order, data.data()); PackedBlock packed; packed.Resize({height, width}); - if (pack_order == PackOrder::ColMajor) { + if (pack_order == PackOrder::SGemmColMajor) { sg.PackLhs(src_matrix, &packed); } else { sg.PackRhs(src_matrix, &packed); @@ -57,18 +58,19 @@ void TestUnPack(const index_t height, data[i] = rand_r(&seed); } - MatrixMap src_matrix(1, height, width, src_order, data.data()); + SGemmMatrixMap + src_matrix(1, height, width, src_order, data.data()); PackedBlock packed; packed.Resize({height, width}); SGemm sg; - if (pack_order == PackOrder::ColMajor) { + if (pack_order == PackOrder::SGemmColMajor) { sg.PackLhs(src_matrix, &packed); } else { sg.PackRhs(src_matrix, &packed); } std::vector unpacked(matrix_size); - MatrixMap + SGemmMatrixMap unpacked_matrix(1, height, width, src_order, unpacked.data()); sg.UnPack(packed, &unpacked_matrix); auto unpacked_data = unpacked.data(); @@ -87,78 +89,78 @@ TEST(SGemmPackTest, Pack) { // For no-transpose lhs TestPack(data, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - 3, 4, Major::RowMajor, PackOrder::ColMajor); + 3, 4, Major::SGemmRowMajor, PackOrder::SGemmColMajor); #if defined(MACE_ENABLE_NEON) TestPack(data, {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16}, - 4, 4, Major::RowMajor, PackOrder::ColMajor); + 4, 4, Major::SGemmRowMajor, PackOrder::SGemmColMajor); TestPack(data, {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16, 17, 18, 19, 20}, - 5, 4, Major::RowMajor, PackOrder::ColMajor); + 5, 4, Major::SGemmRowMajor, PackOrder::SGemmColMajor); #if defined(__aarch64__) TestPack(data, {1, 5, 9, 13, 17, 21, 25, 29, 2, 6, 10, 14, 18, 22, 26, 30, 3, 7, 11, 15, 19, 23, 27, 31, 4, 8, 12, 16, 20, 24, 28, 32, 33, 34, 35, 36}, - 9, 4, Major::RowMajor, PackOrder::ColMajor); + 9, 4, Major::SGemmRowMajor, 
PackOrder::SGemmColMajor); #endif #endif // For transpose-needed lhs TestPack(data, {1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12}, - 3, 4, Major::ColMajor, PackOrder::ColMajor); + 3, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor); #if defined(MACE_ENABLE_NEON) TestPack(data, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - 4, 4, Major::ColMajor, PackOrder::ColMajor); + 4, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor); TestPack(data, {1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 5, 10, 15, 20}, - 5, 4, Major::ColMajor, PackOrder::ColMajor); + 5, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor); #if defined(__aarch64__) TestPack(data, {1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 9, 18, 27, 36}, - 9, 4, Major::ColMajor, PackOrder::ColMajor); + 9, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor); #endif #endif // For no-transpose rhs TestPack(data, {1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12}, - 4, 3, Major::RowMajor, PackOrder::RowMajor); + 4, 3, Major::SGemmRowMajor, PackOrder::SGemmRowMajor); #if defined(MACE_ENABLE_NEON) TestPack(data, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - 4, 4, Major::RowMajor, PackOrder::RowMajor); + 4, 4, Major::SGemmRowMajor, PackOrder::SGemmRowMajor); TestPack(data, {1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 5, 10, 15, 20}, - 4, 5, Major::RowMajor, PackOrder::RowMajor); + 4, 5, Major::SGemmRowMajor, PackOrder::SGemmRowMajor); #endif // For transpose-needed rhs TestPack(data, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - 4, 3, Major::ColMajor, PackOrder::RowMajor); + 4, 3, Major::SGemmColMajor, PackOrder::SGemmRowMajor); #if defined(MACE_ENABLE_NEON) TestPack(data, {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16}, - 4, 4, Major::ColMajor, PackOrder::RowMajor); + 4, 4, Major::SGemmColMajor, PackOrder::SGemmRowMajor); TestPack(data, {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16, 17, 18, 19, 20}, - 4, 5, Major::ColMajor, PackOrder::RowMajor); + 4, 5, Major::SGemmColMajor, PackOrder::SGemmRowMajor); #endif } TEST(SGemmPackTest, UnPack) { - TestUnPack(4, 3, Major::RowMajor, PackOrder::RowMajor); - TestUnPack(4, 4, Major::RowMajor, PackOrder::RowMajor); - TestUnPack(4, 5, Major::RowMajor, PackOrder::RowMajor); - TestUnPack(4, 100, Major::RowMajor, PackOrder::RowMajor); - TestUnPack(4, 3, Major::ColMajor, PackOrder::RowMajor); - TestUnPack(4, 4, Major::ColMajor, PackOrder::RowMajor); - TestUnPack(4, 5, Major::ColMajor, PackOrder::RowMajor); - TestUnPack(4, 100, Major::ColMajor, PackOrder::RowMajor); + TestUnPack(4, 3, Major::SGemmRowMajor, PackOrder::SGemmRowMajor); + TestUnPack(4, 4, Major::SGemmRowMajor, PackOrder::SGemmRowMajor); + TestUnPack(4, 5, Major::SGemmRowMajor, PackOrder::SGemmRowMajor); + TestUnPack(4, 100, Major::SGemmRowMajor, PackOrder::SGemmRowMajor); + TestUnPack(4, 3, Major::SGemmColMajor, PackOrder::SGemmRowMajor); + TestUnPack(4, 4, Major::SGemmColMajor, PackOrder::SGemmRowMajor); + TestUnPack(4, 5, Major::SGemmColMajor, PackOrder::SGemmRowMajor); + TestUnPack(4, 100, Major::SGemmColMajor, PackOrder::SGemmRowMajor); } } // namespace test -- GitLab
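For reference, a minimal usage sketch of the renamed SGemm types exercised by the benchmark and pack tests above (MatrixMap/RowMajor became SGemmMatrixMap/SGemmRowMajor, and the path is now marked deprecated in favor of mace/ops/arm/fp32/gemm.h). The <const float>/<float> template arguments are inferred from the declarations in sgemm.h, and this compiles only inside the MACE source tree:

#include <vector>

#include "mace/ops/sgemm.h"

// Row-major float GEMM through the deprecated SGemm path, mirroring the
// benchmark change above. Sizes are illustrative.
void SGemmExample(int m, int k, int n) {
  std::vector<float> lhs(m * k), rhs(k * n), result(m * n);

  mace::ops::SGemmMatrixMap<const float>
      matrix_lhs(1, m, k, mace::ops::SGemmRowMajor, lhs.data(), true);
  mace::ops::SGemmMatrixMap<const float>
      matrix_rhs(1, k, n, mace::ops::SGemmRowMajor, rhs.data(), true);
  mace::ops::SGemmMatrixMap<float>
      matrix_result(1, m, n, mace::ops::SGemmRowMajor, result.data());

  mace::ops::SGemm sgemm;
  sgemm(matrix_lhs, matrix_rhs, &matrix_result);  // scratch buffer is optional
}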