Commit bcc88336 authored by 李寅

Optimize gemm

Parent 19bf36b7
......@@ -14,6 +14,10 @@
#include "mace/core/allocator.h"
#include <unistd.h>
#include <sys/mman.h>
#include <cerrno>
#include <cstring>
#include <memory>
namespace mace {
Allocator *GetCPUAllocator() {
......@@ -21,4 +25,22 @@ Allocator *GetCPUAllocator() {
return &allocator;
}
void AdviseFree(void *addr, size_t length) {
int page_size = sysconf(_SC_PAGESIZE);
void *addr_aligned =
reinterpret_cast<void *>(
(reinterpret_cast<uintptr_t>(addr) + page_size - 1)
& (~(page_size - 1)));
uintptr_t delta =
reinterpret_cast<uintptr_t>(addr_aligned)
- reinterpret_cast<uintptr_t>(addr);
if (length >= delta + page_size) {
size_t len_aligned = (length - delta) & (~(page_size - 1));
int ret = madvise(addr_aligned, len_aligned, MADV_DONTNEED);
if (ret != 0) {
LOG(ERROR) << "Advise free failed: " << strerror(errno);
}
}
}
} // namespace mace
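// Editor's sketch (not part of this commit): how AdviseFree() is intended to
// be used once a packed copy of a constant tensor has been cached. The helper
// name ReleaseOriginalWeights is hypothetical; madvise(MADV_DONTNEED) is only
// a hint, so the pages may remain resident until the kernel reclaims them.
#include <cstddef>
#include "mace/core/allocator.h"
namespace mace {
inline void ReleaseOriginalWeights(const float *weights, std::size_t count) {
  // Dropping constness is deliberate: the buffer must not be read again
  // before it is repopulated, since its pages may now read back as zeros.
  AdviseFree(reinterpret_cast<void *>(const_cast<float *>(weights)),
             count * sizeof(float));
}
}  // namespace mace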
......@@ -40,6 +40,10 @@ constexpr size_t kMaceAlignment = 64;
constexpr size_t kMaceAlignment = 32;
#endif
inline index_t PadAlignSize(index_t size) {
return (size + kMaceAlignment - 1) & (~(kMaceAlignment - 1));
}
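// Editor's note (worked example, assuming kMaceAlignment == 64):
//   PadAlignSize(1)   == 64
//   PadAlignSize(64)  == 64
//   PadAlignSize(100) == 128
// i.e. a size is rounded up to the next multiple of kMaceAlignment.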
class Allocator {
public:
Allocator() {}
......@@ -140,6 +144,8 @@ class CPUAllocator : public Allocator {
// Global CPU allocator used for CPU/GPU/DSP
Allocator *GetCPUAllocator();
void AdviseFree(void *addr, size_t length);
} // namespace mace
#endif // MACE_CORE_ALLOCATOR_H_
......@@ -384,7 +384,7 @@ class BufferSlice : public BufferBase {
BufferSlice(const BufferSlice &other)
: BufferSlice(other.buffer_, other.offset_, other.size_) {}
~BufferSlice() {
virtual ~BufferSlice() {
if (buffer_ != nullptr && mapped_buf_ != nullptr) {
UnMap();
}
......@@ -506,7 +506,7 @@ class ScratchBuffer: public Buffer {
virtual ~ScratchBuffer() {}
MaceStatus GrowSize(index_t size) {
MaceStatus GrowSize(const index_t size) {
if (size > size_) {
VLOG(1) << "Grow scratch size to: " << size;
MACE_CHECK(offset_ == 0, "scratch is being used, cannot grow size");
......
......@@ -25,11 +25,11 @@ void OpContext::set_device(Device *device) {
device_ = device;
}
Device* OpContext::device() {
Device* OpContext::device() const {
return device_;
}
Workspace* OpContext::workspace() {
Workspace* OpContext::workspace() const {
return ws_;
}
......@@ -37,7 +37,7 @@ void OpContext::set_future(StatsFuture *future) {
future_ = future;
}
StatsFuture *OpContext::future() {
StatsFuture *OpContext::future() const {
return future_;
}
......
......@@ -26,11 +26,11 @@ class OpContext {
OpContext(Workspace *ws, Device *device);
~OpContext();
void set_device(Device *device);
Device *device();
Workspace *workspace();
Device *device() const;
Workspace *workspace() const;
void set_future(StatsFuture *future);
StatsFuture *future();
StatsFuture *future() const;
private:
Device *device_;
Workspace *ws_;
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_FP32_CONV_2D_H_
#define MACE_OPS_ARM_FP32_CONV_2D_H_
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/gemm.h"
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
class Conv2dBase {
public:
Conv2dBase() = default;
virtual ~Conv2dBase() = default;
virtual MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output) = 0;
};
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_FP32_CONV_2D_H_
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/arm/fp32/conv_2d_1x1.h"
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
MaceStatus Conv2dK1x1::Compute(const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output) {
index_t batch = input->dim(0);
index_t height = input->dim(2);
index_t width = input->dim(3);
index_t in_channels = input->dim(1);
index_t out_channels = filter->dim(0);
MACE_RETURN_IF_ERROR(output->Resize({batch, out_channels, height, width}));
context->device()->scratch_buffer()->Rewind();
return gemm_.Compute(context,
filter,
input,
batch,
out_channels,
in_channels,
in_channels,
height * width,
false,
false,
false,
false,
true,
output);
}
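// Editor's note (a sketch of the mapping above, not new behaviour): in NCHW
// layout a 1x1 convolution is a per-batch matrix product
//   output[oc, h * w] = sum_ic filter[oc, ic] * input[ic, h * w],
// so Compute() passes lhs = filter as an (out_channels x in_channels)
// row-major matrix and rhs = input as an (in_channels x height * width)
// row-major matrix, with no transposes, the filter un-batched (it is shared
// across the batch) and the input batched.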
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_FP32_CONV_2D_1X1_H_
#define MACE_OPS_ARM_FP32_CONV_2D_1X1_H_
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/gemm.h"
#include "mace/ops/arm/fp32/conv_2d.h"
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
class Conv2dK1x1 : public Conv2dBase {
public:
Conv2dK1x1() : gemm_(true) {}
virtual ~Conv2dK1x1() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
private:
Gemm gemm_;
};
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_FP32_CONV_2D_1X1_H_
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/arm/fp32/gemm.h"
#include <arm_neon.h>
#include <algorithm>
#include <utility>
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
enum { kNoCache, kCacheLhs, kCacheRhs };
MaceStatus Gemm::Compute(const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t rows,
const index_t cols,
const index_t depth,
const MatrixMajor lhs_major,
const MatrixMajor rhs_major,
const MatrixMajor output_major,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
MACE_UNUSED(context);
MACE_CHECK(output->size() == batch * rows * cols,
"Output tensor must be resized before calling gemm.");
Tensor::MappingGuard lhs_guard(lhs);
Tensor::MappingGuard rhs_guard(rhs);
Tensor::MappingGuard output_guard(output);
const float *lhs_data = lhs->data<float>();
const float *rhs_data = rhs->data<float>();
float *output_data = output->mutable_data<float>();
#ifdef __aarch64__
const index_t row_block_size = 8;
#else
const index_t row_block_size = 4;
#endif
const index_t col_block_size = 8;
const index_t depth_block_size = 4;
const index_t row_block_count = RoundUpDiv(rows, row_block_size);
const index_t col_block_count = RoundUpDiv(cols, col_block_size);
const index_t rows_padded = RoundUp(rows, row_block_size);
const index_t cols_padded = RoundUp(cols, col_block_size);
const index_t depth_padded = RoundUp(depth, depth_block_size);
ScratchBuffer *scratch = &tmp_scratch_buffer_;
if (context != nullptr && context->device()->scratch_buffer() != nullptr) {
scratch = context->device()->scratch_buffer();
}
index_t packed_lhs_size =
PadAlignSize(sizeof(float) * rows_padded * depth_padded);
index_t packed_rhs_size =
PadAlignSize(sizeof(float) * depth_padded * cols_padded);
index_t packed_output_size =
PadAlignSize(sizeof(float) * rows_padded * cols_padded);
// Grow the scratch buffer to the combined size of packed lhs, rhs and output
// anyway, in case the constant tensor is not cached (to save memory).
MACE_RETURN_IF_ERROR(scratch->GrowSize(
packed_lhs_size + packed_rhs_size + packed_output_size));
float *packed_lhs_data =
scratch->Scratch(packed_lhs_size).mutable_data<float>();
float *packed_rhs_data =
scratch->Scratch(packed_rhs_size).mutable_data<float>();
float *packed_output_data =
scratch->Scratch(packed_output_size).mutable_data<float>();
int cache_side = kNoCache;
if (cached_ == kCacheLhs) {
packed_lhs_data = pack_cache_.mutable_data<float>();
} else if (cached_ == kCacheRhs) {
packed_rhs_data = pack_cache_.mutable_data<float>();
} else if (should_cache_pack_) {
if (lhs->is_weight() && !lhs_batched) {
cache_side = kCacheLhs;
pack_cache_.Resize(packed_lhs_size);
packed_lhs_data = pack_cache_.mutable_data<float>();
} else if (rhs->is_weight() && !rhs_batched) {
cache_side = kCacheRhs;
pack_cache_.Resize(packed_rhs_size);
packed_rhs_data = pack_cache_.mutable_data<float>();
}
}
for (index_t b = 0; b < batch; ++b) {
MatrixMap<const float>
lhs_matrix
(lhs_data + static_cast<index_t>(lhs_batched) * b * rows * depth,
lhs_major,
rows,
depth);
MatrixMap<const float>
rhs_matrix
(rhs_data + static_cast<index_t>(rhs_batched) * b * depth * cols,
rhs_major,
depth,
cols);
MatrixMap<float> output_matrix
(output_data + b * rows * cols, output_major, rows, cols);
// pack lhs
if (cached_ != kCacheLhs) {
#pragma omp parallel for schedule(runtime)
for (index_t row_block_idx = 0; row_block_idx < row_block_count;
++row_block_idx) {
const index_t start_row = row_block_idx * row_block_size;
const index_t
row_block_len = std::min(row_block_size, rows - start_row);
float *packed_lhs_data_block =
packed_lhs_data + row_block_idx * row_block_size * depth_padded;
PackLhs(lhs_matrix.block(start_row, 0, row_block_len, depth),
packed_lhs_data_block);
}
if (cache_side == kCacheLhs) {
cached_ = kCacheLhs;
if (lhs->UnderlyingBuffer()->OnHost()) {
AdviseFree(reinterpret_cast<void *>(const_cast<float *>(lhs->data<
float>())),
lhs->raw_size());
}
}
}
// pack rhs
if (cached_ != kCacheRhs) {
#pragma omp parallel for schedule(runtime)
for (index_t col_block_idx = 0; col_block_idx < col_block_count;
++col_block_idx) {
const index_t start_col = col_block_idx * col_block_size;
const index_t
col_block_len = std::min(col_block_size, cols - start_col);
float *packed_rhs_data_block =
packed_rhs_data + col_block_idx * col_block_size * depth_padded;
PackRhs(rhs_matrix.block(0, start_col, depth, col_block_len),
packed_rhs_data_block);
}
if (cache_side == kCacheRhs) {
cached_ = kCacheRhs;
if (rhs->UnderlyingBuffer()->OnHost()) {
AdviseFree(reinterpret_cast<void *>(const_cast<float *>(rhs->data<
float>())),
rhs->raw_size());
}
}
}
// multiply lhs and rhs
#pragma omp parallel for schedule(runtime)
for (index_t row_block_idx = 0; row_block_idx < row_block_count;
++row_block_idx) {
const index_t start_row = row_block_idx * row_block_size;
const index_t row_block_len = std::min(row_block_size, rows - start_row);
const float *packed_lhs_data_block =
packed_lhs_data + row_block_idx * row_block_size * depth_padded;
for (index_t col_block_idx = 0; col_block_idx < col_block_count;
++col_block_idx) {
const index_t start_col = col_block_idx * col_block_size;
const index_t
col_block_len = std::min(col_block_size, cols - start_col);
const float *packed_rhs_data_block =
packed_rhs_data + col_block_idx * col_block_size * depth_padded;
float *packed_output_data_block =
packed_output_data + row_block_idx * row_block_size * cols_padded
+ col_block_idx * col_block_size;
ComputeBlock(packed_lhs_data_block,
packed_rhs_data_block,
depth_padded,
packed_output_data_block);
MatrixMap<float> output_block = output_matrix.block(start_row,
start_col,
row_block_len,
col_block_len);
UnpackOutput(packed_output_data_block, &output_block);
} // col_block_idx
} // row_block_idx
} // b
return MaceStatus::MACE_SUCCESS;
}
void Gemm::ComputeBlock(const float *packed_lhs_data,
const float *packed_rhs_data,
const index_t depth_padded,
float *packed_output_data) {
/* Ref:
for (index_t r = 0; r < r_block_size; ++r) {
for (index_t c = 0; c < c_block_size; ++c) {
float sum = 0;
for (index_t d = 0; d < depth; ++d) {
// (r, d) * (d, c)
sum += packed_lhs_data[d * r_block_size + r]
* packed_rhs_data[d * c_block_size + c];
}
packed_output_data[r * c_block_size + c] = sum;
}
}
*/
const float *lhs_ptr = packed_lhs_data;
const float *rhs_ptr = packed_rhs_data;
const index_t depth_block_count = depth_padded / 4;
#ifdef __aarch64__
// Register layout: (8x4) x (4x8)
//
// +--------+--------+
// | v8 ... | v9 ... |
// Rhs +--------+--------+
// | v10... | v11... |
// +--------+--------+
// | v12... | v13... |
// +--------+--------+
// | v14... | v15... |
// +--------+--------+
//
// Lhs
//
// +----+----+----+----+ - - +--------+--------+
// | v0 | v2 | v4 | v6 | | v16... | v17... |
// | . | | | | | v18... | v19... |
// | . | | | | | v20... | v21... |
// | . | | | | | v22... | v23... |
// +----+----+----+----+ +--------+--------+
// | v1 | v3 | v5 | v7 | | v24... | v25... |
// | . | | | | | v26... | v27... |
// | . | | | | | v28... | v29... |
// | . | | | | | v30... | v31... |
// +----+----+----+----+ +--------+--------+
//
// Accumulator
//
if (depth_block_count > 0) {
index_t r_depth_block_count = depth_block_count;
// just make compiler happy
MACE_UNUSED(r_depth_block_count);
asm volatile(
"dup v16.4s, wzr \n"
"dup v17.4s, wzr \n"
"dup v18.4s, wzr \n"
"dup v19.4s, wzr \n"
"dup v20.4s, wzr \n"
"dup v21.4s, wzr \n"
"dup v22.4s, wzr \n"
"dup v23.4s, wzr \n"
"dup v24.4s, wzr \n"
"dup v25.4s, wzr \n"
"dup v26.4s, wzr \n"
"dup v27.4s, wzr \n"
"dup v28.4s, wzr \n"
"dup v29.4s, wzr \n"
"dup v30.4s, wzr \n"
"dup v31.4s, wzr \n"
// prologue
"ld1 {v0.4s}, [%[lhs_ptr]], #16 \n"
"ld1 {v1.4s}, [%[lhs_ptr]], #16 \n"
"ld1 {v2.4s}, [%[lhs_ptr]], #16 \n"
"ld1 {v3.4s}, [%[lhs_ptr]], #16 \n"
"ld1 {v4.4s}, [%[lhs_ptr]], #16 \n"
"ld1 {v5.4s}, [%[lhs_ptr]], #16 \n"
"ld1 {v6.4s}, [%[lhs_ptr]], #16 \n"
"ld1 {v7.4s}, [%[lhs_ptr]], #16 \n"
"ld1 {v8.4s}, [%[rhs_ptr]], #16 \n"
"ld1 {v9.4s}, [%[rhs_ptr]], #16 \n"
"ld1 {v10.4s}, [%[rhs_ptr]], #16 \n"
"ld1 {v11.4s}, [%[rhs_ptr]], #16 \n"
"ld1 {v12.4s}, [%[rhs_ptr]], #16 \n"
"ld1 {v13.4s}, [%[rhs_ptr]], #16 \n"
"ld1 {v14.4s}, [%[rhs_ptr]], #16 \n"
"ld1 {v15.4s}, [%[rhs_ptr]], #16 \n"
"subs %[r_depth_block_count], %[r_depth_block_count], #1 \n"
"beq 1f\n"
"0: \n"
"fmla v16.4s, v8.4s, v0.s[0] \n"
"fmla v17.4s, v9.4s, v0.s[0] \n"
"fmla v18.4s, v8.4s, v0.s[1] \n"
"fmla v19.4s, v9.4s, v0.s[1] \n"
"fmla v20.4s, v8.4s, v0.s[2] \n"
"fmla v21.4s, v9.4s, v0.s[2] \n"
"fmla v22.4s, v8.4s, v0.s[3] \n"
"fmla v23.4s, v9.4s, v0.s[3] \n"
"ld1 {v0.4s}, [%[lhs_ptr]], #16 \n"
"fmla v24.4s, v8.4s, v1.s[0] \n"
"fmla v25.4s, v9.4s, v1.s[0] \n"
"fmla v26.4s, v8.4s, v1.s[1] \n"
"fmla v27.4s, v9.4s, v1.s[1] \n"
"fmla v28.4s, v8.4s, v1.s[2] \n"
"fmla v29.4s, v9.4s, v1.s[2] \n"
"fmla v30.4s, v8.4s, v1.s[3] \n"
"fmla v31.4s, v9.4s, v1.s[3] \n"
"ld1 {v1.4s}, [%[lhs_ptr]], #16 \n"
"ld1 {v8.4s}, [%[rhs_ptr]], #16 \n"
"ld1 {v9.4s}, [%[rhs_ptr]], #16 \n"
"fmla v16.4s, v10.4s, v2.s[0] \n"
"fmla v17.4s, v11.4s, v2.s[0] \n"
"fmla v18.4s, v10.4s, v2.s[1] \n"
"fmla v19.4s, v11.4s, v2.s[1] \n"
"fmla v20.4s, v10.4s, v2.s[2] \n"
"fmla v21.4s, v11.4s, v2.s[2] \n"
"fmla v22.4s, v10.4s, v2.s[3] \n"
"fmla v23.4s, v11.4s, v2.s[3] \n"
"ld1 {v2.4s}, [%[lhs_ptr]], #16 \n"
"fmla v24.4s, v10.4s, v3.s[0] \n"
"fmla v25.4s, v11.4s, v3.s[0] \n"
"fmla v26.4s, v10.4s, v3.s[1] \n"
"fmla v27.4s, v11.4s, v3.s[1] \n"
"fmla v28.4s, v10.4s, v3.s[2] \n"
"fmla v29.4s, v11.4s, v3.s[2] \n"
"fmla v30.4s, v10.4s, v3.s[3] \n"
"fmla v31.4s, v11.4s, v3.s[3] \n"
"ld1 {v3.4s}, [%[lhs_ptr]], #16 \n"
"ld1 {v10.4s}, [%[rhs_ptr]], #16 \n"
"ld1 {v11.4s}, [%[rhs_ptr]], #16 \n"
"fmla v16.4s, v12.4s, v4.s[0] \n"
"fmla v17.4s, v13.4s, v4.s[0] \n"
"fmla v18.4s, v12.4s, v4.s[1] \n"
"fmla v19.4s, v13.4s, v4.s[1] \n"
"fmla v20.4s, v12.4s, v4.s[2] \n"
"fmla v21.4s, v13.4s, v4.s[2] \n"
"fmla v22.4s, v12.4s, v4.s[3] \n"
"fmla v23.4s, v13.4s, v4.s[3] \n"
"ld1 {v4.4s}, [%[lhs_ptr]], #16 \n"
"fmla v24.4s, v12.4s, v5.s[0] \n"
"fmla v25.4s, v13.4s, v5.s[0] \n"
"fmla v26.4s, v12.4s, v5.s[1] \n"
"fmla v27.4s, v13.4s, v5.s[1] \n"
"fmla v28.4s, v12.4s, v5.s[2] \n"
"fmla v29.4s, v13.4s, v5.s[2] \n"
"fmla v30.4s, v12.4s, v5.s[3] \n"
"fmla v31.4s, v13.4s, v5.s[3] \n"
"ld1 {v5.4s}, [%[lhs_ptr]], #16 \n"
"ld1 {v12.4s}, [%[rhs_ptr]], #16 \n"
"ld1 {v13.4s}, [%[rhs_ptr]], #16 \n"
"fmla v16.4s, v14.4s, v6.s[0] \n"
"fmla v17.4s, v15.4s, v6.s[0] \n"
"fmla v18.4s, v14.4s, v6.s[1] \n"
"fmla v19.4s, v15.4s, v6.s[1] \n"
"fmla v20.4s, v14.4s, v6.s[2] \n"
"fmla v21.4s, v15.4s, v6.s[2] \n"
"fmla v22.4s, v14.4s, v6.s[3] \n"
"fmla v23.4s, v15.4s, v6.s[3] \n"
"ld1 {v6.4s}, [%[lhs_ptr]], #16 \n"
"subs %[r_depth_block_count], %[r_depth_block_count], #1 \n"
"fmla v24.4s, v14.4s, v7.s[0] \n"
"fmla v25.4s, v15.4s, v7.s[0] \n"
"fmla v26.4s, v14.4s, v7.s[1] \n"
"fmla v27.4s, v15.4s, v7.s[1] \n"
"fmla v28.4s, v14.4s, v7.s[2] \n"
"fmla v29.4s, v15.4s, v7.s[2] \n"
"fmla v30.4s, v14.4s, v7.s[3] \n"
"fmla v31.4s, v15.4s, v7.s[3] \n"
"ld1 {v7.4s}, [%[lhs_ptr]], #16 \n"
"ld1 {v14.4s}, [%[rhs_ptr]], #16 \n"
"ld1 {v15.4s}, [%[rhs_ptr]], #16 \n"
"bne 0b \n"
// epilogue
"1:\n"
"fmla v16.4s, v8.4s, v0.s[0] \n"
"fmla v17.4s, v9.4s, v0.s[0] \n"
"fmla v18.4s, v8.4s, v0.s[1] \n"
"fmla v19.4s, v9.4s, v0.s[1] \n"
"fmla v20.4s, v8.4s, v0.s[2] \n"
"fmla v21.4s, v9.4s, v0.s[2] \n"
"fmla v22.4s, v8.4s, v0.s[3] \n"
"fmla v23.4s, v9.4s, v0.s[3] \n"
"fmla v24.4s, v8.4s, v1.s[0] \n"
"fmla v25.4s, v9.4s, v1.s[0] \n"
"fmla v26.4s, v8.4s, v1.s[1] \n"
"fmla v27.4s, v9.4s, v1.s[1] \n"
"fmla v28.4s, v8.4s, v1.s[2] \n"
"fmla v29.4s, v9.4s, v1.s[2] \n"
"fmla v30.4s, v8.4s, v1.s[3] \n"
"fmla v31.4s, v9.4s, v1.s[3] \n"
"fmla v16.4s, v10.4s, v2.s[0] \n"
"fmla v17.4s, v11.4s, v2.s[0] \n"
"fmla v18.4s, v10.4s, v2.s[1] \n"
"fmla v19.4s, v11.4s, v2.s[1] \n"
"fmla v20.4s, v10.4s, v2.s[2] \n"
"fmla v21.4s, v11.4s, v2.s[2] \n"
"fmla v22.4s, v10.4s, v2.s[3] \n"
"fmla v23.4s, v11.4s, v2.s[3] \n"
"fmla v24.4s, v10.4s, v3.s[0] \n"
"fmla v25.4s, v11.4s, v3.s[0] \n"
"fmla v26.4s, v10.4s, v3.s[1] \n"
"fmla v27.4s, v11.4s, v3.s[1] \n"
"fmla v28.4s, v10.4s, v3.s[2] \n"
"fmla v29.4s, v11.4s, v3.s[2] \n"
"fmla v30.4s, v10.4s, v3.s[3] \n"
"fmla v31.4s, v11.4s, v3.s[3] \n"
"fmla v16.4s, v12.4s, v4.s[0] \n"
"fmla v17.4s, v13.4s, v4.s[0] \n"
"fmla v18.4s, v12.4s, v4.s[1] \n"
"fmla v19.4s, v13.4s, v4.s[1] \n"
"fmla v20.4s, v12.4s, v4.s[2] \n"
"fmla v21.4s, v13.4s, v4.s[2] \n"
"fmla v22.4s, v12.4s, v4.s[3] \n"
"fmla v23.4s, v13.4s, v4.s[3] \n"
"fmla v24.4s, v12.4s, v5.s[0] \n"
"fmla v25.4s, v13.4s, v5.s[0] \n"
"fmla v26.4s, v12.4s, v5.s[1] \n"
"fmla v27.4s, v13.4s, v5.s[1] \n"
"fmla v28.4s, v12.4s, v5.s[2] \n"
"fmla v29.4s, v13.4s, v5.s[2] \n"
"fmla v30.4s, v12.4s, v5.s[3] \n"
"fmla v31.4s, v13.4s, v5.s[3] \n"
"fmla v16.4s, v14.4s, v6.s[0] \n"
"fmla v17.4s, v15.4s, v6.s[0] \n"
"fmla v18.4s, v14.4s, v6.s[1] \n"
"fmla v19.4s, v15.4s, v6.s[1] \n"
"fmla v20.4s, v14.4s, v6.s[2] \n"
"fmla v21.4s, v15.4s, v6.s[2] \n"
"fmla v22.4s, v14.4s, v6.s[3] \n"
"fmla v23.4s, v15.4s, v6.s[3] \n"
"fmla v24.4s, v14.4s, v7.s[0] \n"
"fmla v25.4s, v15.4s, v7.s[0] \n"
"fmla v26.4s, v14.4s, v7.s[1] \n"
"fmla v27.4s, v15.4s, v7.s[1] \n"
"fmla v28.4s, v14.4s, v7.s[2] \n"
"fmla v29.4s, v15.4s, v7.s[2] \n"
"fmla v30.4s, v14.4s, v7.s[3] \n"
"fmla v31.4s, v15.4s, v7.s[3] \n"
"st1 {v16.4s}, [%[packed_output_data]], #16 \n"
"st1 {v17.4s}, [%[packed_output_data]], #16 \n"
"st1 {v18.4s}, [%[packed_output_data]], #16 \n"
"st1 {v19.4s}, [%[packed_output_data]], #16 \n"
"st1 {v20.4s}, [%[packed_output_data]], #16 \n"
"st1 {v21.4s}, [%[packed_output_data]], #16 \n"
"st1 {v22.4s}, [%[packed_output_data]], #16 \n"
"st1 {v23.4s}, [%[packed_output_data]], #16 \n"
"st1 {v24.4s}, [%[packed_output_data]], #16 \n"
"st1 {v25.4s}, [%[packed_output_data]], #16 \n"
"st1 {v26.4s}, [%[packed_output_data]], #16 \n"
"st1 {v27.4s}, [%[packed_output_data]], #16 \n"
"st1 {v28.4s}, [%[packed_output_data]], #16 \n"
"st1 {v29.4s}, [%[packed_output_data]], #16 \n"
"st1 {v30.4s}, [%[packed_output_data]], #16 \n"
"st1 {v31.4s}, [%[packed_output_data]], #16 \n"
: // outputs
[lhs_ptr] "+r"(lhs_ptr),
[rhs_ptr] "+r"(rhs_ptr),
[packed_output_data] "+r"(packed_output_data),
[r_depth_block_count] "+r"(r_depth_block_count)
: // inputs
: // clobbers
"cc", "memory",
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
"v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
}
#else // armeabi-v7a
// Register layout: (4x4) x (4x8)
//
// +--------+--------+
// | q4 ... | q5 ... |
// Rhs +--------+--------+
// | q6 ... | q7 ... |
// +--------+--------+
// | q4 ... | q5 ... |
// +--------+--------+
// | q6 ... | q7 ... |
// +--------+--------+
//
// Lhs
//
// +----+----+----+----+ - - +--------+--------+
// | q0 | q1 | q2 | q3 | | q8... | q9... |
// | . | | | | | q10... | q11... |
// | . | | | | | q12... | q13... |
// | . | | | | | q14... | q15... |
// +----+----+----+----+ +--------+--------+
//
// Accumulator
//
if (depth_block_count > 0) {
index_t r_depth_block_count = depth_block_count;
// just make compiler happy
MACE_UNUSED(r_depth_block_count);
asm volatile(
"mov r0, #0\n"
"vdup.f32 q8, r0 \n"
"vdup.f32 q9, r0 \n"
"vdup.f32 q10, r0 \n"
"vdup.f32 q11, r0 \n"
"vdup.f32 q12, r0 \n"
"vdup.f32 q13, r0 \n"
"vdup.f32 q14, r0 \n"
"vdup.f32 q15, r0 \n"
// prologue
"vld1.f32 {d0-d1}, [%[lhs_ptr]]! \n"
"vld1.f32 {d2-d3}, [%[lhs_ptr]]! \n"
"vld1.f32 {d4-d5}, [%[lhs_ptr]]! \n"
"vld1.f32 {d6-d7}, [%[lhs_ptr]]! \n"
"vld1.f32 {d8-d9}, [%[rhs_ptr]]! \n"
"vld1.f32 {d10-d11}, [%[rhs_ptr]]! \n"
"vld1.f32 {d12-d13}, [%[rhs_ptr]]! \n"
"vld1.f32 {d14-d15}, [%[rhs_ptr]]! \n"
"subs %[r_depth_block_count], %[r_depth_block_count], #1 \n"
"beq 1f\n"
"0: \n"
"vmla.f32 q8, q4, d0[0] \n"
"vmla.f32 q9, q5, d0[0] \n"
"vmla.f32 q10, q4, d0[1] \n"
"vmla.f32 q11, q5, d0[1] \n"
"vmla.f32 q12, q4, d1[0] \n"
"vmla.f32 q13, q5, d1[0] \n"
"vmla.f32 q14, q4, d1[1] \n"
"vmla.f32 q15, q5, d1[1] \n"
"vld1.f32 {d0-d1}, [%[lhs_ptr]]! \n"
"vld1.f32 {d8-d9}, [%[rhs_ptr]]! \n"
"vld1.f32 {d10-d11}, [%[rhs_ptr]]! \n"
"vmla.f32 q8, q6, d2[0] \n"
"vmla.f32 q9, q7, d2[0] \n"
"vmla.f32 q10, q6, d2[1] \n"
"vmla.f32 q11, q7, d2[1] \n"
"vmla.f32 q12, q6, d3[0] \n"
"vmla.f32 q13, q7, d3[0] \n"
"vmla.f32 q14, q6, d3[1] \n"
"vmla.f32 q15, q7, d3[1] \n"
"vld1.f32 {d2-d3}, [%[lhs_ptr]]! \n"
"vld1.f32 {d12-d13}, [%[rhs_ptr]]! \n"
"vld1.f32 {d14-d15}, [%[rhs_ptr]]! \n"
"vmla.f32 q8, q4, d4[0] \n"
"vmla.f32 q9, q5, d4[0] \n"
"vmla.f32 q10, q4, d4[1] \n"
"vmla.f32 q11, q5, d4[1] \n"
"vmla.f32 q12, q4, d5[0] \n"
"vmla.f32 q13, q5, d5[0] \n"
"vmla.f32 q14, q4, d5[1] \n"
"vmla.f32 q15, q5, d5[1] \n"
"vld1.f32 {d4-d5}, [%[lhs_ptr]]! \n"
"vld1.f32 {d8-d9}, [%[rhs_ptr]]! \n"
"vld1.f32 {d10-d11}, [%[rhs_ptr]]! \n"
"subs %[r_depth_block_count], %[r_depth_block_count], #1 \n"
"vmla.f32 q8, q6, d6[0] \n"
"vmla.f32 q9, q7, d6[0] \n"
"vmla.f32 q10, q6, d6[1] \n"
"vmla.f32 q11, q7, d6[1] \n"
"vmla.f32 q12, q6, d7[0] \n"
"vmla.f32 q13, q7, d7[0] \n"
"vmla.f32 q14, q6, d7[1] \n"
"vmla.f32 q15, q7, d7[1] \n"
"vld1.f32 {d6-d7}, [%[lhs_ptr]]! \n"
"vld1.f32 {d12-d13}, [%[rhs_ptr]]! \n"
"vld1.f32 {d14-d15}, [%[rhs_ptr]]! \n"
"bne 0b \n"
// epilogue
"1:\n"
"vmla.f32 q8, q4, d0[0] \n"
"vmla.f32 q9, q5, d0[0] \n"
"vmla.f32 q10, q4, d0[1] \n"
"vmla.f32 q11, q5, d0[1] \n"
"vmla.f32 q12, q4, d1[0] \n"
"vmla.f32 q13, q5, d1[0] \n"
"vmla.f32 q14, q4, d1[1] \n"
"vmla.f32 q15, q5, d1[1] \n"
"vld1.f32 {d8-d9}, [%[rhs_ptr]]! \n"
"vld1.f32 {d10-d11}, [%[rhs_ptr]]! \n"
"vmla.f32 q8, q6, d2[0] \n"
"vmla.f32 q9, q7, d2[0] \n"
"vmla.f32 q10, q6, d2[1] \n"
"vmla.f32 q11, q7, d2[1] \n"
"vmla.f32 q12, q6, d3[0] \n"
"vmla.f32 q13, q7, d3[0] \n"
"vmla.f32 q14, q6, d3[1] \n"
"vmla.f32 q15, q7, d3[1] \n"
"vld1.f32 {d12-d13}, [%[rhs_ptr]]! \n"
"vld1.f32 {d14-d15}, [%[rhs_ptr]]! \n"
"vmla.f32 q8, q4, d4[0] \n"
"vmla.f32 q9, q5, d4[0] \n"
"vmla.f32 q10, q4, d4[1] \n"
"vmla.f32 q11, q5, d4[1] \n"
"vmla.f32 q12, q4, d5[0] \n"
"vmla.f32 q13, q5, d5[0] \n"
"vmla.f32 q14, q4, d5[1] \n"
"vmla.f32 q15, q5, d5[1] \n"
"vmla.f32 q8, q6, d6[0] \n"
"vmla.f32 q9, q7, d6[0] \n"
"vmla.f32 q10, q6, d6[1] \n"
"vmla.f32 q11, q7, d6[1] \n"
"vmla.f32 q12, q6, d7[0] \n"
"vmla.f32 q13, q7, d7[0] \n"
"vmla.f32 q14, q6, d7[1] \n"
"vmla.f32 q15, q7, d7[1] \n"
"vst1.f32 {d16-d17}, [%[packed_output_data]]! \n"
"vst1.f32 {d18-d19}, [%[packed_output_data]]! \n"
"vst1.f32 {d20-d21}, [%[packed_output_data]]! \n"
"vst1.f32 {d22-d23}, [%[packed_output_data]]! \n"
"vst1.f32 {d24-d25}, [%[packed_output_data]]! \n"
"vst1.f32 {d26-d27}, [%[packed_output_data]]! \n"
"vst1.f32 {d28-d29}, [%[packed_output_data]]! \n"
"vst1.f32 {d30-d31}, [%[packed_output_data]]! \n"
: // outputs
[lhs_ptr] "+r"(lhs_ptr),
[rhs_ptr] "+r"(rhs_ptr),
[packed_output_data] "+r"(packed_output_data),
[r_depth_block_count] "+r"(r_depth_block_count)
: // inputs
: // clobbers
"cc", "memory", "r0",
"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
}
#endif
}
void Gemm::PackLhs(const MatrixMap<const float> &lhs,
float *packed_lhs) {
#ifdef __aarch64__
Pack<8, 4>(lhs, ColMajor, packed_lhs);
#else
Pack<4, 4>(lhs, ColMajor, packed_lhs);
#endif
}
void Gemm::PackRhs(const MatrixMap<const float> &rhs,
float *packed_rhs) {
Pack<8, 4>(rhs, RowMajor, packed_rhs);
}
void Gemm::UnpackOutput(const float *packed_output, MatrixMap<float> *output) {
#ifdef __aarch64__
Unpack<8, 8>(packed_output, output);
#else
Unpack<4, 8>(packed_output, output);
#endif
}
template<>
void Gemm::Pack<4, 4>(const MatrixMap<const float> &matrix,
MatrixMajor dst_major,
float *packed_matrix) {
const index_t rows = matrix.rows();
const index_t cols = matrix.cols();
// Use the same terminology as gemmlowp:
// "depth" is the accumulated (shared) dim; "width" is the other dim
// lhs
index_t width = rows;
index_t depth = cols;
index_t width_stride = matrix.rows_stride();
index_t depth_stride = matrix.cols_stride();
if (dst_major == RowMajor) {
// rhs
std::swap(width, depth);
std::swap(width_stride, depth_stride);
}
const float *data = matrix.data();
float *packed_ptr = packed_matrix;
const index_t block_size = 4;
const index_t depth_padded = RoundUp(depth, static_cast<index_t>(4));
if (depth_padded > depth) {
memset(packed_ptr + depth * block_size,
0,
sizeof(float) * (depth_padded - depth) * block_size);
}
if (dst_major == matrix.matrix_major()) {
if (width < block_size) {
const index_t width_remain = block_size - width;
for (index_t d = 0; d < depth; ++d) {
memcpy(packed_ptr, data, sizeof(float) * width);
memset(packed_ptr + width, 0, sizeof(float) * width_remain);
data += depth_stride;
packed_ptr += block_size;
}
} else {
for (index_t d = 0; d < depth; ++d) {
float32x4_t vi = vld1q_f32(data);
vst1q_f32(packed_ptr, vi);
data += depth_stride;
packed_ptr += block_size;
}
}
} else {
if (width < block_size) {
const index_t width_remain = block_size - width;
for (index_t d = 0; d < depth; ++d) {
for (index_t w = 0; w < width; ++w) {
packed_ptr[w] = data[w * width_stride + d];
} // w
memset(packed_ptr + width, 0, sizeof(float) * width_remain);
packed_ptr += block_size;
} // d
} else {
const float *data0 = data;
const float *data1 = data + width_stride;
const float *data2 = data1 + width_stride;
const float *data3 = data2 + width_stride;
const index_t depth_block = depth / 4;
const index_t depth_remain = depth - depth_block * 4;
for (index_t depth_block_idx = 0; depth_block_idx < depth_block;
++depth_block_idx) {
float32x4_t v0 = vld1q_f32(data0);
float32x4_t v1 = vld1q_f32(data1);
float32x4_t v2 = vld1q_f32(data2);
float32x4_t v3 = vld1q_f32(data3);
float32x4x2_t v02_intertwined = vzipq_f32(v0, v2);
float32x4x2_t v13_intertwined = vzipq_f32(v1, v3);
float32x4x2_t v0123_intertwined =
vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]);
float32x4x2_t v0123n_intertwined =
vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]);
vst1q_f32(packed_ptr, v0123_intertwined.val[0]);
packed_ptr += 4;
vst1q_f32(packed_ptr, v0123_intertwined.val[1]);
packed_ptr += 4;
vst1q_f32(packed_ptr, v0123n_intertwined.val[0]);
packed_ptr += 4;
vst1q_f32(packed_ptr, v0123n_intertwined.val[1]);
packed_ptr += 4;
data0 += 4;
data1 += 4;
data2 += 4;
data3 += 4;
}
for (index_t d = 0; d < depth_remain; ++d) {
float32x4_t vi = {*data0, *data1, *data2, *data3};
vst1q_f32(packed_ptr, vi);
packed_ptr += 4;
++data0;
++data1;
++data2;
++data3;
} // d
}
}
}
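// Editor's note (describes the layout produced above, not new code):
// Pack<4, 4> stores the block depth-major: for each depth index d the four
// width entries are contiguous, so block element (w, d) ends up at
// packed_matrix[d * 4 + w], and rows up to depth_padded are zero-filled.
// This matches the indexing in the reference loop of ComputeBlock().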
template<>
void Gemm::Pack<8, 4>(const MatrixMap<const float> &matrix,
MatrixMajor dst_major,
float *packed_matrix) {
const index_t rows = matrix.rows();
const index_t cols = matrix.cols();
// Use the same terminology as gemmlowp:
// "depth" is the accumulated (shared) dim; "width" is the other dim
// lhs
index_t width = rows;
index_t depth = cols;
index_t width_stride = matrix.rows_stride();
index_t depth_stride = matrix.cols_stride();
if (dst_major == RowMajor) {
// rhs
std::swap(width, depth);
std::swap(width_stride, depth_stride);
}
const float *data = matrix.data();
float *packed_ptr = packed_matrix;
const index_t block_size = 8;
const index_t depth_padded = RoundUp(depth, static_cast<index_t>(4));
if (depth_padded > depth) {
memset(packed_ptr + depth * block_size,
0,
sizeof(float) * (depth_padded - depth) * block_size);
}
if (dst_major == matrix.matrix_major()) {
if (width < block_size) {
const index_t width_remain = block_size - width;
for (index_t d = 0; d < depth; ++d) {
memcpy(packed_ptr, data, sizeof(float) * width);
memset(packed_ptr + width, 0, sizeof(float) * width_remain);
data += depth_stride;
packed_ptr += block_size;
}
} else {
for (index_t d = 0; d < depth; ++d) {
float32x4_t vi = vld1q_f32(data);
vst1q_f32(packed_ptr, vi);
float32x4_t vin = vld1q_f32(data + 4);
vst1q_f32(packed_ptr + 4, vin);
data += depth_stride;
packed_ptr += block_size;
}
}
} else {
if (width < block_size) {
const index_t width_remain = block_size - width;
for (index_t d = 0; d < depth; ++d) {
for (index_t w = 0; w < width; ++w) {
packed_ptr[w] = data[w * width_stride + d];
} // w
memset(packed_ptr + width, 0, sizeof(float) * width_remain);
packed_ptr += block_size;
} // d
} else {
const float *data0 = data;
const float *data1 = data + width_stride;
const float *data2 = data1 + width_stride;
const float *data3 = data2 + width_stride;
const float *data4 = data3 + width_stride;
const float *data5 = data4 + width_stride;
const float *data6 = data5 + width_stride;
const float *data7 = data6 + width_stride;
const index_t depth_block = depth / 4;
const index_t depth_remain = depth - depth_block * 4;
for (index_t depth_block_idx = 0; depth_block_idx < depth_block;
++depth_block_idx) {
float32x4_t v0 = vld1q_f32(data0);
float32x4_t v1 = vld1q_f32(data1);
float32x4_t v2 = vld1q_f32(data2);
float32x4_t v3 = vld1q_f32(data3);
float32x4x2_t v02_intertwined = vzipq_f32(v0, v2);
float32x4x2_t v13_intertwined = vzipq_f32(v1, v3);
float32x4x2_t v0123_intertwined =
vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]);
float32x4x2_t v0123n_intertwined =
vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]);
float32x4_t v4 = vld1q_f32(data4);
float32x4_t v5 = vld1q_f32(data5);
float32x4_t v6 = vld1q_f32(data6);
float32x4_t v7 = vld1q_f32(data7);
float32x4x2_t v46_intertwined = vzipq_f32(v4, v6);
float32x4x2_t v57_intertwined = vzipq_f32(v5, v7);
float32x4x2_t v4567_intertwined =
vzipq_f32(v46_intertwined.val[0], v57_intertwined.val[0]);
float32x4x2_t v4567n_intertwined =
vzipq_f32(v46_intertwined.val[1], v57_intertwined.val[1]);
vst1q_f32(packed_ptr, v0123_intertwined.val[0]);
packed_ptr += 4;
vst1q_f32(packed_ptr, v4567_intertwined.val[0]);
packed_ptr += 4;
vst1q_f32(packed_ptr, v0123_intertwined.val[1]);
packed_ptr += 4;
vst1q_f32(packed_ptr, v4567_intertwined.val[1]);
packed_ptr += 4;
vst1q_f32(packed_ptr, v0123n_intertwined.val[0]);
packed_ptr += 4;
vst1q_f32(packed_ptr, v4567n_intertwined.val[0]);
packed_ptr += 4;
vst1q_f32(packed_ptr, v0123n_intertwined.val[1]);
packed_ptr += 4;
vst1q_f32(packed_ptr, v4567n_intertwined.val[1]);
packed_ptr += 4;
data0 += 4;
data1 += 4;
data2 += 4;
data3 += 4;
data4 += 4;
data5 += 4;
data6 += 4;
data7 += 4;
}
for (index_t d = 0; d < depth_remain; ++d) {
float32x4_t vi = {*data0, *data1, *data2, *data3};
vst1q_f32(packed_ptr, vi);
packed_ptr += 4;
float32x4_t vin = {*data4, *data5, *data6, *data7};
vst1q_f32(packed_ptr, vin);
packed_ptr += 4;
++data0;
++data1;
++data2;
++data3;
++data4;
++data5;
++data6;
++data7;
} // d
}
}
}
template<>
void Gemm::Unpack<4, 8>(const float *packed_output, MatrixMap<float> *output) {
const index_t rows = output->rows();
const index_t cols = output->cols();
index_t row_stride = output->rows_stride();
index_t col_stride = output->cols_stride();
float *output_ptr = output->data();
const float *packed_ptr = packed_output;
const index_t block_size = 8;
// packed_output is always row-major
if (output->matrix_major() == RowMajor) {
if (cols < block_size) {
for (index_t r = 0; r < rows; ++r) {
memcpy(output_ptr, packed_ptr, sizeof(float) * cols);
output_ptr += row_stride;
packed_ptr += block_size;
}
} else {
for (index_t r = 0; r < rows; ++r) {
float32x4_t vi = vld1q_f32(packed_ptr);
vst1q_f32(output_ptr, vi);
float32x4_t vin = vld1q_f32(packed_ptr + 4);
vst1q_f32(output_ptr + 4, vin);
output_ptr += row_stride;
packed_ptr += block_size;
}
}
} else {
// ColMajor
if (rows < block_size) {
for (index_t c = 0; c < cols; ++c) {
for (index_t r = 0; r < rows; ++r) {
output_ptr[c * col_stride + r] = packed_ptr[r * block_size + c];
} // r
} // c
} else {
const float *data0 = packed_ptr;
const float *data1 = data0 + block_size;
const float *data2 = data1 + block_size;
const float *data3 = data2 + block_size;
index_t col_block = cols / 4;
index_t col_remain = cols - col_block * 4;
for (index_t col_block_idx = 0; col_block_idx < col_block;
++col_block_idx) {
float32x4_t v0 = vld1q_f32(data0);
float32x4_t v1 = vld1q_f32(data1);
float32x4_t v2 = vld1q_f32(data2);
float32x4_t v3 = vld1q_f32(data3);
float32x4x2_t v02_intertwined = vzipq_f32(v0, v2);
float32x4x2_t v13_intertwined = vzipq_f32(v1, v3);
float32x4x2_t v0123_intertwined =
vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]);
float32x4x2_t v0123n_intertwined =
vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]);
vst1q_f32(output_ptr, v0123_intertwined.val[0]);
output_ptr += col_stride;
vst1q_f32(output_ptr, v0123_intertwined.val[1]);
output_ptr += col_stride;
vst1q_f32(output_ptr, v0123n_intertwined.val[0]);
output_ptr += col_stride;
vst1q_f32(output_ptr, v0123n_intertwined.val[1]);
output_ptr += col_stride;
data0 += 4;
data1 += 4;
data2 += 4;
data3 += 4;
}
for (index_t c = 0; c < col_remain; ++c) {
float32x4_t vi = {*data0, *data1, *data2, *data3};
vst1q_f32(output_ptr, vi);
output_ptr += col_stride;
++data0;
++data1;
++data2;
++data3;
} // c
}
}
}
template<>
void Gemm::Unpack<8, 8>(const float *packed_output, MatrixMap<float> *output) {
const index_t rows = output->rows();
const index_t cols = output->cols();
index_t row_stride = output->rows_stride();
index_t col_stride = output->cols_stride();
float *output_ptr = output->data();
const float *packed_ptr = packed_output;
const index_t block_size = 8;
// packed_output is always row-major
if (output->matrix_major() == RowMajor) {
if (cols < block_size) {
for (index_t r = 0; r < rows; ++r) {
memcpy(output_ptr, packed_ptr, sizeof(float) * cols);
output_ptr += row_stride;
packed_ptr += block_size;
}
} else {
for (index_t r = 0; r < rows; ++r) {
float32x4_t vi = vld1q_f32(packed_ptr);
vst1q_f32(output_ptr, vi);
float32x4_t vin = vld1q_f32(packed_ptr + 4);
vst1q_f32(output_ptr + 4, vin);
output_ptr += row_stride;
packed_ptr += block_size;
}
}
} else {
// ColMajor
if (rows < block_size) {
for (index_t c = 0; c < cols; ++c) {
for (index_t r = 0; r < rows; ++r) {
output_ptr[c * col_stride + r] = packed_ptr[r * block_size + c];
} // r
} // c
} else {
const float *data0 = packed_ptr;
const float *data1 = data0 + block_size;
const float *data2 = data1 + block_size;
const float *data3 = data2 + block_size;
const float *data4 = data3 + block_size;
const float *data5 = data4 + block_size;
const float *data6 = data5 + block_size;
const float *data7 = data6 + block_size;
index_t col_block = cols / 4;
index_t col_remain = cols - col_block * 4;
for (index_t col_block_idx = 0; col_block_idx < col_block;
++col_block_idx) {
float32x4_t v0 = vld1q_f32(data0);
float32x4_t v1 = vld1q_f32(data1);
float32x4_t v2 = vld1q_f32(data2);
float32x4_t v3 = vld1q_f32(data3);
float32x4x2_t v02_intertwined = vzipq_f32(v0, v2);
float32x4x2_t v13_intertwined = vzipq_f32(v1, v3);
float32x4x2_t v0123_intertwined =
vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]);
float32x4x2_t v0123n_intertwined =
vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]);
float32x4_t v4 = vld1q_f32(data4);
float32x4_t v5 = vld1q_f32(data5);
float32x4_t v6 = vld1q_f32(data6);
float32x4_t v7 = vld1q_f32(data7);
float32x4x2_t v46_intertwined = vzipq_f32(v4, v6);
float32x4x2_t v57_intertwined = vzipq_f32(v5, v7);
float32x4x2_t v4567_intertwined =
vzipq_f32(v46_intertwined.val[0], v57_intertwined.val[0]);
float32x4x2_t v4567n_intertwined =
vzipq_f32(v46_intertwined.val[1], v57_intertwined.val[1]);
vst1q_f32(output_ptr, v0123_intertwined.val[0]);
vst1q_f32(output_ptr + 4, v4567_intertwined.val[0]);
output_ptr += col_stride;
vst1q_f32(output_ptr, v0123_intertwined.val[1]);
vst1q_f32(output_ptr + 4, v4567_intertwined.val[1]);
output_ptr += col_stride;
vst1q_f32(output_ptr, v0123n_intertwined.val[0]);
vst1q_f32(output_ptr + 4, v4567n_intertwined.val[0]);
output_ptr += col_stride;
vst1q_f32(output_ptr, v0123n_intertwined.val[1]);
vst1q_f32(output_ptr + 4, v4567n_intertwined.val[1]);
output_ptr += col_stride;
data0 += 4;
data1 += 4;
data2 += 4;
data3 += 4;
data4 += 4;
data5 += 4;
data6 += 4;
data7 += 4;
}
for (index_t c = 0; c < col_remain; ++c) {
float32x4_t vi = {*data0, *data1, *data2, *data3};
vst1q_f32(output_ptr, vi);
float32x4_t vin = {*data4, *data5, *data6, *data7};
vst1q_f32(output_ptr + 4, vin);
output_ptr += col_stride;
++data0;
++data1;
++data2;
++data3;
++data4;
++data5;
++data6;
++data7;
} // c
}
}
}
MaceStatus Gemm::Compute(const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t lhs_rows,
const index_t lhs_cols,
const index_t rhs_rows,
const index_t rhs_cols,
const bool transpose_lhs,
const bool transpose_rhs,
const bool transpose_out,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
index_t rows = transpose_lhs ? lhs_cols : lhs_rows;
index_t depth = transpose_lhs ? lhs_rows : lhs_cols;
index_t cols = transpose_rhs ? rhs_rows : rhs_cols;
index_t depth2 = transpose_rhs ? rhs_cols : rhs_rows;
MACE_CHECK(depth == depth2,
"Matrices that multiply have inconsistent depth dim: ",
depth,
" vs. ",
depth2);
return Compute(context,
lhs,
rhs,
batch,
rows,
cols,
depth,
transpose_lhs ? ColMajor : RowMajor,
transpose_rhs ? ColMajor : RowMajor,
transpose_out ? ColMajor : RowMajor,
lhs_batched,
rhs_batched,
output);
}
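// Editor's sketch (illustrative only; gemm, A, B and C are assumed to exist
// and C to be resized by the caller): computing C = A^T * B for a row-major
// A of shape (K, M) and a row-major B of shape (K, N) maps to
//   rows = M, cols = N, depth = K,
//   lhs_major = ColMajor (since transpose_lhs is true),
//   rhs_major = RowMajor, output_major = RowMajor:
// gemm.Compute(context, &A, &B, /*batch=*/1,
//              /*lhs_rows=*/K, /*lhs_cols=*/M,
//              /*rhs_rows=*/K, /*rhs_cols=*/N,
//              /*transpose_lhs=*/true, /*transpose_rhs=*/false,
//              /*transpose_out=*/false, /*lhs_batched=*/true,
//              /*rhs_batched=*/true, &C);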
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_FP32_GEMM_H_
#define MACE_OPS_ARM_FP32_GEMM_H_
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/common/matrix.h"
// This implements matrix-matrix multiplication.
// In the case of matrix-vector multiplication, use gemv.h/gemv.cc instead
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
class Gemm {
public:
explicit Gemm(const bool should_cache_pack)
: tmp_scratch_buffer_(GetCPUAllocator()),
pack_cache_(GetCPUAllocator()),
should_cache_pack_(should_cache_pack),
cached_(0) {}
Gemm() : Gemm(false) {}
~Gemm() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t rows,
const index_t cols,
const index_t depth,
const MatrixMajor lhs_major,
const MatrixMajor rhs_major,
const MatrixMajor output_major,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
// The original matrices (before any transpose) are row-major
MaceStatus Compute(
const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t lhs_rows,
const index_t lhs_cols,
const index_t rhs_rows,
const index_t rhs_cols,
const bool transpose_lhs,
const bool transpose_rhs,
const bool transpose_out,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
private:
void ComputeBlock(const float *packed_lhs_data,
const float *packed_rhs_data,
const index_t depth_padded,
float *packed_output_data);
void PackLhs(const MatrixMap<const float> &lhs,
float *packed_lhs);
void PackRhs(const MatrixMap<const float> &rhs,
float *packed_rhs);
void UnpackOutput(const float *packed_output,
MatrixMap<float> *output);
template<int RowBlockSize, int ColBlockSize>
void Unpack(const float *packed_output,
MatrixMap<float> *output) {
const index_t rows = output->rows();
const index_t cols = output->cols();
for (index_t r = 0; r < rows; ++r) {
for (index_t c = 0; c < cols; ++c) {
*output->data(r, c) = packed_output[r * ColBlockSize + c];
}
}
}
template<int WidthBlockSize, int DepthBlockSize>
void Pack(const MatrixMap<const float> &matrix,
MatrixMajor dst_major,
float *packed_matrix) {
const index_t rows = matrix.rows();
const index_t cols = matrix.cols();
index_t depth = cols;
if (dst_major == RowMajor) {
// rhs
depth = rows;
}
const index_t depth_padded = RoundUp(depth, static_cast<index_t>(4));
memset(packed_matrix, 0, sizeof(float) * WidthBlockSize * depth_padded);
if (dst_major == ColMajor) {
for (index_t c = 0; c < cols; ++c) {
for (index_t r = 0; r < rows; ++r) {
packed_matrix[c * WidthBlockSize + r] = matrix(r, c);
}
}
} else {
for (index_t r = 0; r < rows; ++r) {
for (index_t c = 0; c < cols; ++c) {
packed_matrix[r * WidthBlockSize + c] = matrix(r, c);
}
}
}
}
ScratchBuffer tmp_scratch_buffer_;
Buffer pack_cache_;
bool should_cache_pack_;
int cached_;
};
template<>
void Gemm::Pack<4, 4>(const MatrixMap<const float> &matrix,
MatrixMajor dst_major,
float *packed_matrix);
template<>
void Gemm::Pack<8, 4>(const MatrixMap<const float> &matrix,
MatrixMajor dst_major,
float *packed_matrix);
template<>
void Gemm::Unpack<4, 8>(const float *packed_output, MatrixMap<float> *output);
template<>
void Gemm::Unpack<8, 8>(const float *packed_output, MatrixMap<float> *output);
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_FP32_GEMM_H_
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/gemm.h"
#include "mace/ops/ref/gemm.h"
#include "mace/ops/testing/test_utils.h"
namespace mace {
namespace ops {
namespace test {
void TestGemmFloat32(const index_t batch,
const index_t rows,
const index_t cols,
const index_t depth,
const MatrixMajor lhs_major,
const MatrixMajor rhs_major,
const MatrixMajor output_major,
const bool lhs_batched,
const bool rhs_batched) {
Tensor lhs(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor rhs(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor output(GetCPUAllocator(), DataType::DT_FLOAT);
lhs.Resize({lhs_batched ? batch : 1, rows, depth});
rhs.Resize({rhs_batched ? batch : 1, depth, cols});
output.Resize({batch, rows, cols});
{
Tensor::MappingGuard lhs_guard(&lhs);
Tensor::MappingGuard rhs_guard(&rhs);
float *lhs_data = lhs.mutable_data<float>();
float *rhs_data = rhs.mutable_data<float>();
float *output_data = output.mutable_data<float>();
GenerateRandomRealTypeData<float>(lhs.shape(), lhs_data);
GenerateRandomRealTypeData<float>(rhs.shape(), rhs_data);
GenerateRandomRealTypeData<float>(output.shape(), output_data);
}
::mace::ops::arm::fp32::Gemm gemm;
gemm.Compute(nullptr,
&lhs,
&rhs,
batch,
rows,
cols,
depth,
lhs_major,
rhs_major,
output_major,
lhs_batched,
rhs_batched,
&output);
Tensor expected_output(GetCPUAllocator(), DataType::DT_FLOAT);
expected_output.Resize({batch, rows, cols});
::mace::ops::ref::Gemm<float> gemm_ref;
gemm_ref.Compute(nullptr,
&lhs,
&rhs,
batch,
rows,
cols,
depth,
lhs_major,
rhs_major,
output_major,
lhs_batched,
rhs_batched,
&expected_output);
ExpectTensorNear<float>(expected_output, output);
}
TEST(ArmGemm, TestGemmFloat32) {
TestGemmFloat32(1, 47, 69, 37, RowMajor, RowMajor, RowMajor, true, true);
TestGemmFloat32(1, 47, 69, 37, RowMajor, RowMajor, ColMajor, true, true);
TestGemmFloat32(1, 47, 69, 37, RowMajor, ColMajor, RowMajor, true, true);
TestGemmFloat32(1, 47, 69, 37, RowMajor, ColMajor, ColMajor, true, true);
TestGemmFloat32(1, 47, 69, 37, ColMajor, RowMajor, RowMajor, true, true);
TestGemmFloat32(1, 47, 69, 37, ColMajor, RowMajor, ColMajor, true, true);
TestGemmFloat32(1, 47, 69, 37, ColMajor, ColMajor, RowMajor, true, true);
TestGemmFloat32(1, 47, 69, 37, ColMajor, ColMajor, ColMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, RowMajor, RowMajor, RowMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, RowMajor, RowMajor, ColMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, RowMajor, ColMajor, RowMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, RowMajor, ColMajor, ColMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, ColMajor, RowMajor, RowMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, ColMajor, RowMajor, ColMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, ColMajor, ColMajor, RowMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, ColMajor, ColMajor, ColMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, RowMajor, RowMajor, RowMajor, true, false);
TestGemmFloat32(3, 47, 69, 37, RowMajor, RowMajor, RowMajor, false, true);
TestGemmFloat32(16, 31, 61, 67, RowMajor, ColMajor, RowMajor, true, true);
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -22,6 +22,9 @@
#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#endif
// Unrolling is disabled by default, since cache set conflicts can be significant
// #define MACE_GEMV_UNROLL 1
namespace mace {
namespace ops {
namespace arm {
......@@ -35,14 +38,26 @@ MaceStatus Gemv::Compute(const OpContext *context,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
MACE_UNUSED(context);
MACE_CHECK(output->size() == batch * lhs_height,
"Output tensor must be resized before calling gemv.");
Tensor::MappingGuard lhs_guard(lhs);
Tensor::MappingGuard rhs_guard(rhs);
Tensor::MappingGuard bias_guard(bias);
Tensor::MappingGuard output_guard(output);
const float *lhs_data = lhs->data<float>();
const float *rhs_data = rhs->data<float>();
const float *bias_data = nullptr;
if (bias) {
bias_data = bias->data<float>();
}
float *output_data = output->mutable_data<float>();
const index_t h_block_size = 4;
const index_t h_block_count = RoundUpDiv(lhs_height, h_block_size);
const index_t w_block_size = 8;
......@@ -52,28 +67,20 @@ MaceStatus Gemv::Compute(const OpContext *context,
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < batch; ++b) {
for (index_t h_block_idx = 0; h_block_idx < h_block_count; ++h_block_idx) {
// TODO(liyin): these could be hoisted outside the loop,
// but OpenMP limits the parameter count
const float *lhs_data = lhs->data<float>();
const float *rhs_data = rhs->data<float>();
const float *bias_data = nullptr;
if (bias) {
bias_data = bias->data<float>();
}
float *output_data = output->mutable_data<float>();
const index_t h_start = h_block_idx * h_block_size;
const float
*lhs_ptr = lhs_data
+ static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ lhs_width * h_block_idx * h_block_size;
const float *rhs_ptr = rhs_data + b * lhs_width;
+ lhs_width * h_start;
const float *rhs_ptr =
rhs_data + static_cast<index_t>(rhs_batched) * b * lhs_width;
float
*ret_ptr = output_data + b * lhs_height + h_block_idx * h_block_size;
*ret_ptr = output_data + b * lhs_height + h_start;
const index_t h_block_len =
std::min(h_block_size, lhs_height - h_block_idx * h_block_size);
const index_t h_offset = h_block_idx * h_block_size;
std::min(h_block_size, lhs_height - h_start);
#ifdef MACE_GEMV_UNROLL
if (h_block_len == 4) {
float32x4_t vo0 = vdupq_n_f32(0);
float32x4_t vo1 = vdupq_n_f32(0);
......@@ -149,6 +156,11 @@ MaceStatus Gemv::Compute(const OpContext *context,
"vmla.f32 %q[vo2], q4, q8\n"
"vmla.f32 %q[vo3], q6, q8\n"
"vld1.f32 {d0-d1}, [r1]!\n"
"vld1.f32 {d4-d5}, [r2]!\n"
"vld1.f32 {d8-d9}, [r3]!\n"
"vld1.f32 {d12-d13}, [r4]!\n"
"vld1.f32 {d16-d17}, [r0]!\n"
"vmla.f32 %q[vo0], q1, q9\n"
"vmla.f32 %q[vo1], q3, q9\n"
......@@ -157,13 +169,6 @@ MaceStatus Gemv::Compute(const OpContext *context,
"subs %[r_w_block_count], #1\n"
"vld1.f32 {d0-d1}, [r1]!\n"
"vld1.f32 {d4-d5}, [r2]!\n"
"vld1.f32 {d8-d9}, [r3]!\n"
"vld1.f32 {d12-d13}, [r4]!\n"
"vld1.f32 {d16-d17}, [r0]!\n"
"vld1.f32 {d2-d3}, [r1]!\n"
"vld1.f32 {d6-d7}, [r2]!\n"
"vld1.f32 {d10-d11}, [r3]!\n"
......@@ -257,26 +262,30 @@ MaceStatus Gemv::Compute(const OpContext *context,
vo = vaddq_f32(vo, vbias);
vst1q_f32(ret_ptr, vo);
} else { // h_block_len < 4
// TODO(liyin): handle the cases (1, 2, 3) separately here to accelerate
#endif // MACE_GEMV_UNROLL
const float *tmp_lhs_ptr = lhs_ptr;
const float *tmp_rhs_ptr = rhs_ptr;
for (index_t h = 0; h < h_block_len; ++h) {
lhs_ptr = tmp_lhs_ptr + h * lhs_width;
rhs_ptr = tmp_rhs_ptr;
float32x4_t vo0 = vdupq_n_f32(0);
float32x4_t vo0n = vdupq_n_f32(0);
for (index_t w = 0; w < w_block_count; ++w) {
float32x4_t vr0 = vld1q_f32(rhs_ptr);
float32x4_t vr0n = vld1q_f32(rhs_ptr + 4);
float32x4_t vl0 = vld1q_f32(lhs_ptr);
float32x4_t vl0n = vld1q_f32(lhs_ptr + 4);
// may introduce small precision differences depending on the accumulation order
vo0 = vmlaq_f32(vo0, vl0, vr0);
vo0 = vmlaq_f32(vo0, vl0n, vr0n);
vo0n = vmlaq_f32(vo0n, vl0n, vr0n);
lhs_ptr += 8;
rhs_ptr += 8;
} // w
float s0 = vaddvq_f32(vo0) + (bias ? bias_data[h_offset + h] : 0);
vo0 = vaddq_f32(vo0, vo0n);
float s0 = vaddvq_f32(vo0) + (bias ? bias_data[h_start + h] : 0);
for (index_t w = 0; w < w_remain; ++w) {
s0 += lhs_ptr[0] * rhs_ptr[0];
++lhs_ptr;
......@@ -285,7 +294,9 @@ MaceStatus Gemv::Compute(const OpContext *context,
ret_ptr[h] = s0;
} // h
#ifdef MACE_GEMV_UNROLL
} // if
#endif // MACE_GEMV_UNROLL
} // h_block_idx
} // b
......
......@@ -38,6 +38,7 @@ class Gemv {
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
......
......@@ -28,13 +28,14 @@ namespace test {
void TestGemvFloat32(const index_t batch,
const index_t height,
const index_t width,
const bool lhs_batched) {
const bool lhs_batched,
const bool rhs_batched) {
Tensor lhs(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor rhs(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor bias(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor output(GetCPUAllocator(), DataType::DT_FLOAT);
lhs.Resize({lhs_batched ? batch : 1, height, width});
rhs.Resize({batch, width});
rhs.Resize({rhs_batched ? batch : 1, width});
bias.Resize({height});
output.Resize({batch, height});
{
......@@ -57,6 +58,7 @@ void TestGemvFloat32(const index_t batch,
height,
width,
lhs_batched,
rhs_batched,
&output);
Tensor expected_output(GetCPUAllocator(), DataType::DT_FLOAT);
......@@ -70,28 +72,22 @@ void TestGemvFloat32(const index_t batch,
height,
width,
lhs_batched,
rhs_batched,
&expected_output);
Tensor::MappingGuard output_guard(&output);
Tensor::MappingGuard expected_guard(&expected_output);
const float *output_data = output.data<float>();
const float *expected_data = expected_output.data<float>();
for (index_t i = 0; i < output.size(); ++i) {
EXPECT_NEAR(expected_data[i], output_data[i], 0.001);
}
ExpectTensorNear<float>(expected_output, output);
}
TEST(ArmGemv, TestGemvFloat32) {
TestGemvFloat32(1, 16, 4, true);
TestGemvFloat32(1, 16, 256, true);
TestGemvFloat32(2, 16, 256, true);
TestGemvFloat32(3, 63, 257, true);
TestGemvFloat32(1, 16, 4, true, true);
TestGemvFloat32(1, 16, 256, true, true);
TestGemvFloat32(2, 16, 256, true, true);
TestGemvFloat32(3, 63, 257, true, true);
TestGemvFloat32(1, 16, 4, false);
TestGemvFloat32(1, 16, 256, false);
TestGemvFloat32(2, 16, 256, false);
TestGemvFloat32(3, 63, 257, false);
TestGemvFloat32(2, 16, 256, false, true);
TestGemvFloat32(3, 63, 257, false, true);
TestGemvFloat32(2, 16, 256, true, false);
TestGemvFloat32(3, 63, 257, true, false);
}
} // namespace test
......
......@@ -43,6 +43,7 @@ MaceStatus Gemv<OUTPUT_TYPE>::Compute(const OpContext *context,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
MACE_UNUSED(context);
......@@ -100,7 +101,8 @@ MaceStatus Gemv<OUTPUT_TYPE>::Compute(const OpContext *context,
*lhs_ptr = lhs_data
+ static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ lhs_width * h_block_idx * h_block_size;
const uint8_t *rhs_ptr = rhs_data + b * lhs_width;
const uint8_t *rhs_ptr =
rhs_data + static_cast<index_t>(rhs_batched) * b * lhs_width;
OUTPUT_TYPE
*ret_ptr = output_data + b * lhs_height + h_block_idx * h_block_size;
......
......@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// This implements the matrix-vector multiplication described in
// https://github.com/google/gemmlowp/blob/master/todo/fast-gemv.txt
#ifndef MACE_OPS_ARM_Q8_GEMV_H_
#define MACE_OPS_ARM_Q8_GEMV_H_
......@@ -39,6 +42,7 @@ class Gemv {
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
......
......@@ -28,7 +28,8 @@ namespace test {
void TestGemvInt32(const index_t batch,
const index_t height,
const index_t width,
const bool lhs_batched) {
const bool lhs_batched,
const bool rhs_batched) {
Tensor lhs(GetCPUAllocator(), DataType::DT_UINT8);
Tensor rhs(GetCPUAllocator(), DataType::DT_UINT8);
Tensor bias(GetCPUAllocator(), DataType::DT_INT32);
......@@ -38,7 +39,7 @@ void TestGemvInt32(const index_t batch,
lhs.SetZeroPoint(0);
rhs.SetZeroPoint(0);
lhs.Resize({lhs_batched ? batch : 1, height, width});
rhs.Resize({batch, width});
rhs.Resize({rhs_batched ? batch : 1, width});
bias.Resize({height});
output.Resize({batch, height});
{
......@@ -62,6 +63,7 @@ void TestGemvInt32(const index_t batch,
height,
width,
lhs_batched,
rhs_batched,
&output);
Tensor expected_output(GetCPUAllocator(), DataType::DT_INT32);
......@@ -75,6 +77,7 @@ void TestGemvInt32(const index_t batch,
height,
width,
lhs_batched,
rhs_batched,
&expected_output);
Tensor::MappingGuard output_guard(&output);
......@@ -90,7 +93,8 @@ void TestGemvInt32(const index_t batch,
void TestGemvUint8(const index_t batch,
const index_t height,
const index_t width,
const bool lhs_batched) {
const bool lhs_batched,
const bool rhs_batched) {
Tensor lhs(GetCPUAllocator(), DataType::DT_UINT8);
Tensor rhs(GetCPUAllocator(), DataType::DT_UINT8);
Tensor bias(GetCPUAllocator(), DataType::DT_INT32);
......@@ -127,6 +131,7 @@ void TestGemvUint8(const index_t batch,
height,
width,
lhs_batched,
rhs_batched,
&output);
Tensor expected_output(GetCPUAllocator(), DataType::DT_INT32);
......@@ -142,6 +147,7 @@ void TestGemvUint8(const index_t batch,
height,
width,
lhs_batched,
rhs_batched,
&expected_output);
Tensor::MappingGuard output_guard(&output);
......@@ -155,27 +161,27 @@ void TestGemvUint8(const index_t batch,
}
TEST(ArmGemv, TestGemvInt32) {
TestGemvInt32(1, 16, 4, true);
TestGemvInt32(1, 16, 256, true);
TestGemvInt32(2, 16, 256, true);
TestGemvInt32(3, 63, 257, true);
TestGemvInt32(1, 16, 4, false);
TestGemvInt32(1, 16, 256, false);
TestGemvInt32(2, 16, 256, false);
TestGemvInt32(3, 63, 257, false);
TestGemvInt32(1, 16, 4, true, true);
TestGemvInt32(1, 16, 256, true, true);
TestGemvInt32(2, 16, 256, true, true);
TestGemvInt32(3, 63, 257, true, true);
TestGemvInt32(2, 16, 256, false, true);
TestGemvInt32(3, 63, 257, false, true);
TestGemvInt32(2, 16, 256, true, false);
TestGemvInt32(3, 63, 257, true, false);
}
TEST(ArmGemv, TestGemvUint8) {
TestGemvUint8(1, 16, 4, true);
TestGemvUint8(1, 16, 256, true);
TestGemvUint8(2, 16, 256, true);
TestGemvUint8(3, 63, 257, true);
TestGemvUint8(1, 16, 4, false);
TestGemvUint8(1, 16, 256, false);
TestGemvUint8(2, 16, 256, false);
TestGemvUint8(3, 63, 257, false);
TestGemvUint8(1, 16, 4, true, true);
TestGemvUint8(1, 16, 256, true, true);
TestGemvUint8(2, 16, 256, true, true);
TestGemvUint8(3, 63, 257, true, true);
TestGemvUint8(2, 16, 256, false, true);
TestGemvUint8(3, 63, 257, false, true);
TestGemvUint8(2, 16, 256, true, false);
TestGemvUint8(3, 63, 257, true, false);
}
} // namespace test
......
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_COMMON_MATRIX_H_
#define MACE_OPS_COMMON_MATRIX_H_
namespace mace {
namespace ops {
enum MatrixMajor {
RowMajor,
ColMajor
};
inline MatrixMajor TransposeMatrixMajor(const MatrixMajor src_major) {
return src_major == RowMajor ? ColMajor : RowMajor;
}
template<typename T>
class MatrixMap {
public:
MatrixMap()
: data_(nullptr),
matrix_major_(RowMajor),
rows_(0),
cols_(0),
stride_(0) {}
MatrixMap(T *data,
const MatrixMajor matrix_major,
const index_t rows,
const index_t cols) :
data_(data),
matrix_major_(matrix_major),
rows_(rows),
cols_(cols),
stride_(matrix_major == ColMajor ? rows : cols) {}
MatrixMap(T *data,
const MatrixMajor matrix_major,
const index_t rows,
const index_t cols,
const index_t stride) :
data_(data),
matrix_major_(matrix_major),
rows_(rows),
cols_(cols),
stride_(stride) {}
MatrixMap(const MatrixMap &other)
: data_(other.data_),
matrix_major_(other.matrix_major_),
rows_(other.rows_),
cols_(other.cols_),
stride_(other.stride_) {}
MatrixMajor matrix_major() const { return matrix_major_; }
index_t rows() const { return rows_; }
index_t cols() const { return cols_; }
index_t stride() const { return stride_; }
int rows_stride() const {
return matrix_major_ == MatrixMajor::ColMajor ? 1 : stride_;
}
int cols_stride() const {
return matrix_major_ == MatrixMajor::RowMajor ? 1 : stride_;
}
index_t size() const { return rows_ * cols_; }
T *data() const { return data_; }
T *data(int rows, int cols) const {
return data_ + rows * rows_stride() + cols * cols_stride();
}
T &operator()(int row, int col) const { return *data(row, col); }
MatrixMap block(int start_row, int start_col, int block_rows,
int block_cols) const {
MACE_CHECK(start_row >= 0);
MACE_CHECK(start_row + block_rows <= rows_);
MACE_CHECK(start_col >= 0);
MACE_CHECK(start_col + block_cols <= cols_);
return MatrixMap(data(start_row, start_col),
matrix_major_,
block_rows,
block_cols,
stride_);
}
private:
T *data_;
MatrixMajor matrix_major_;
index_t rows_;
index_t cols_;
index_t stride_;
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_COMMON_MATRIX_H_
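A minimal usage sketch of the MatrixMap view defined above (the buffer and function name are illustrative, not part of this header): element access goes through rows_stride()/cols_stride(), so the same code works for both majors, and block() returns a sub-view that keeps the parent stride.
// Sketch only: MatrixMap does not own the buffer it wraps.
void MatrixMapSketch() {
  float buf[6] = {1, 2, 3, 4, 5, 6};                       // 2x3, row-major
  mace::ops::MatrixMap<float> m(buf, mace::ops::RowMajor, 2, 3);
  float v = m(1, 2);                                       // buf[1 * 3 + 2] == 6
  mace::ops::MatrixMap<float> sub = m.block(0, 1, 2, 2);   // 2x2 view, stride stays 3
  *sub.data(1, 0) = 0.0f;                                  // writes buf[1 * 3 + 1]
  (void)v;
}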
......@@ -33,6 +33,13 @@
#include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/utils/utils.h"
#ifdef MACE_ENABLE_NEON
#include "mace/ops/arm/fp32/conv_2d.h"
#include "mace/ops/arm/fp32/conv_2d_1x1.h"
#else
#include "mace/ops/ref/conv_2d.h"
#endif // MACE_ENABLE_NEON
#ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/gemmlowp_util.h"
#include "mace/ops/quantization_util.h"
......@@ -61,7 +68,8 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
"leakyrelu_coefficient", 0.0f)),
is_filter_transformed_(false) {}
is_filter_transformed_(false),
conv2d_delegator_(nullptr) {}
MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(INPUT);
......@@ -69,9 +77,17 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr;
Tensor *output = this->Output(OUTPUT);
index_t input_batch = input->dim(0);
index_t input_channels = input->dim(1);
std::vector<index_t> filter_shape(4);
filter_shape = filter->shape();
index_t stride_h = strides_[0];
index_t stride_w = strides_[1];
index_t dilation_h = dilations_[0];
index_t dilation_w = dilations_[1];
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
if (paddings_.empty()) {
......@@ -99,25 +115,26 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
index_t height = output->dim(2);
index_t width = output->dim(3);
index_t input_batch = input->dim(0);
index_t input_channels = input->dim(1);
index_t input_height = input->dim(2);
index_t input_width = input->dim(3);
index_t filter_h = filter_shape[2];
index_t filter_w = filter_shape[3];
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
MACE_CHECK(filter_shape[0] == channels, filter_shape[0], " != ", channels);
MACE_CHECK(filter_shape[1] == input_channels, filter_shape[1], " != ",
input_channels);
index_t stride_h = strides_[0];
index_t stride_w = strides_[1];
index_t dilation_h = dilations_[0];
index_t dilation_w = dilations_[1];
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
#ifdef MACE_ENABLE_NEON
index_t input_height = input->dim(2);
index_t input_width = input->dim(3);
index_t filter_h = filter->dim(2);
index_t filter_w = filter->dim(3);
if (filter_h == 1 && filter_w == 1 && stride_h == 1 && stride_w == 1
&& dilation_h == 1 && dilation_w == 1) {
if (conv2d_delegator_.get() == nullptr) {
conv2d_delegator_.reset(new arm::fp32::Conv2dK1x1());
}
conv2d_delegator_->Compute(context, input, filter, output);
} else {
// TODO(liyin): the code below needs to be refactored.
// Delegate to each kernel instead of handling them all here.
index_t padded_input_height = input_height + paddings[0];
index_t padded_input_width = input_width + paddings[1];
index_t extra_input_height = padded_input_height;
......@@ -132,41 +149,48 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard filter_guard(filter);
Tensor::MappingGuard bias_guard(bias);
Tensor::MappingGuard output_guard(output);
auto filter_data = filter->data<float>();
auto bias_data = bias == nullptr ? nullptr : bias->data<float>();
auto output_data = output->mutable_data<float>();
std::function<void(const float *input, float *output)> conv_func;
bool
use_winograd = filter_h == 3 && filter_w == 3
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1
&& stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1
&& input_channels >= 8 && channels >= 8;
bool use_neon_3x3_s1 = filter_h == 3 && filter_w == 3
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
&& stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1;
bool use_neon_3x3_s2 = filter_h == 3 && filter_w == 3
&& stride_h == 2 && stride_w == 2 && dilation_h == 1 && dilation_w == 1;
bool use_neon_1x1_s1 = filter_h == 1 && filter_w == 1
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
&& stride_h == 2 && stride_w == 2 && dilation_h == 1
&& dilation_w == 1;
bool use_neon_5x5_s1 = filter_h == 5 && filter_w == 5
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
&& stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1;
bool use_neon_1x7_s1 = filter_h == 1 && filter_w == 7
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
&& stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1;
bool use_neon_7x1_s1 = filter_h == 7 && filter_w == 1
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
&& stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1;
bool use_neon_7x7_s1 = filter_h == 7 && filter_w == 7
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
&& stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1;
bool use_neon_7x7_s2 = filter_h == 7 && filter_w == 7
&& stride_h == 2 && stride_w == 2 && dilation_h == 1 && dilation_w == 1;
&& stride_h == 2 && stride_w == 2 && dilation_h == 1
&& dilation_w == 1;
bool use_neon_7x7_s3 = filter_h == 7 && filter_w == 7
&& stride_h == 3 && stride_w == 3 && dilation_h == 1 && dilation_w == 1;
&& stride_h == 3 && stride_w == 3 && dilation_h == 1
&& dilation_w == 1;
bool use_neon_1x15_s1 = filter_h == 1 && filter_w == 15
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
&& stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1;
bool use_neon_15x1_s1 = filter_h == 15 && filter_w == 1
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
&& stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1;
std::vector<index_t> transformed_input_shape;
std::vector<index_t> transformed_output_shape;
......@@ -184,7 +208,8 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
extra_input_height =
std::max(padded_input_height, extra_output_height + 2);
extra_output_width = RoundUp<index_t>(width, winograd_out_tile_size);
extra_input_width = std::max(padded_input_width, extra_output_width + 2);
extra_input_width =
std::max(padded_input_width, extra_output_width + 2);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
......@@ -192,7 +217,8 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
pad_right += (extra_input_width - padded_input_width);
}
index_t tile_height_count = extra_output_height / winograd_out_tile_size;
index_t
tile_height_count = extra_output_height / winograd_out_tile_size;
index_t tile_width_count = extra_output_width / winograd_out_tile_size;
index_t tile_count = tile_height_count * tile_width_count;
index_t in_tile_area =
......@@ -205,13 +231,11 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
{in_tile_area, batch, channels,
tile_count});
transformed_filter_shape.insert(transformed_filter_shape.end(),
{in_tile_area, channels, input_channels});
{in_tile_area, channels,
input_channels});
} else {
index_t tile_h, tile_w;
if (use_neon_1x1_s1) {
tile_h = 1;
tile_w = 1;
} else if (use_neon_3x3_s1) {
if (use_neon_3x3_s1) {
tile_h = 2;
tile_w = 4;
} else if (use_neon_7x1_s1 || use_neon_15x1_s1) {
......@@ -270,12 +294,8 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
* sizeof(float);
total_scratch_size += padded_output_size;
}
// scratch for sgemm, preoccupy enough buffer
if (use_neon_1x1_s1) {
total_scratch_size += (input_batch * input_height * input_width
* (input_channels + channels))
* sizeof(float);
} else if (use_winograd) {
if (use_winograd) {
total_scratch_size += transformed_input_size + transformed_output_size;
}
......@@ -286,7 +306,8 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
Tensor
transformed_input(scratch->Scratch(transformed_input_size), DT_FLOAT);
Tensor
transformed_output(scratch->Scratch(transformed_output_size), DT_FLOAT);
transformed_output
(scratch->Scratch(transformed_output_size), DT_FLOAT);
Tensor padded_input(scratch->Scratch(padded_input_size), DT_FLOAT);
Tensor padded_output(scratch->Scratch(padded_output_size), DT_FLOAT);
const index_t extra_input_shape[4] =
......@@ -329,7 +350,8 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
}
float *transformed_input_data = transformed_input.mutable_data<float>();
float *transformed_output_data = transformed_output.mutable_data<float>();
float
*transformed_output_data = transformed_output.mutable_data<float>();
conv_func = [=](const float *pad_input, float *pad_output) {
WinoGradConv3x3s1(pad_input,
......@@ -346,19 +368,6 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
&sgemm_,
scratch);
};
} else if (use_neon_1x1_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK1x1S1(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
channels,
pad_output,
&sgemm_,
scratch);
};
} else if (use_neon_3x3_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK3x3S1(pad_input,
......@@ -468,7 +477,7 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
extra_output_width});
padded_output.Clear();
pad_output_ptr = &padded_output;
} else if (!use_neon_1x1_s1) {
} else {
output->Clear();
}
......@@ -484,7 +493,8 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
for (index_t c = 0; c < channels; ++c) {
for (index_t h = 0; h < height; ++h) {
memcpy(
output_data + b * channels * height * width + c * height * width
output_data + b * channels * height * width
+ c * height * width
+ h * width,
pad_output_data
+ b * channels * extra_output_height * extra_output_width
......@@ -495,7 +505,23 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
}
}
}
}
#else
if (conv2d_delegator_.get() == nullptr) {
conv2d_delegator_.reset(new ref::Conv2d<float>(paddings[0],
paddings[1],
stride_h,
stride_w,
dilation_h,
dilation_w));
}
conv2d_delegator_->Compute(context, input, filter, output);
#endif
Tensor::MappingGuard bias_guard(bias);
Tensor::MappingGuard output_guard(output);
auto bias_data = bias == nullptr ? nullptr : bias->data<float>();
auto output_data = output->mutable_data<float>();
if (bias_data != nullptr) {
const index_t image_size = height * width;
#pragma omp parallel for collapse(2) schedule(runtime)
......@@ -702,13 +728,16 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
} // m
} // b
}
private:
const ActivationType activation_;
const float relux_max_limit_;
const float leakyrelu_coefficient_;
bool is_filter_transformed_;
SGemm sgemm_;
#ifdef MACE_ENABLE_NEON
std::unique_ptr<arm::fp32::Conv2dBase> conv2d_delegator_;
#else
std::unique_ptr<ref::Conv2d<float>> conv2d_delegator_;
#endif // MACE_ENABLE_NEON
private:
MACE_OP_INPUT_TAGS(INPUT, FILTER, BIAS);
......
......@@ -102,6 +102,7 @@ class FullyConnectedOp<DeviceType::CPU, float> : public FullyConnectedOpBase {
output_size,
input_size,
false,
true,
output);
Tensor::MappingGuard guard_output(output);
float *output_ptr = output->mutable_data<float>();
......@@ -162,6 +163,7 @@ class FullyConnectedOp<DeviceType::CPU, uint8_t>
output_size,
input_size,
false,
true,
output);
return MaceStatus::MACE_SUCCESS;
}
......
......@@ -25,7 +25,7 @@
#include "mace/utils/utils.h"
#ifdef MACE_ENABLE_NEON
#include "mace/ops/arm/fp32/gemm.h"
#include "mace/ops/arm/fp32/gemv.h"
#ifdef MACE_ENABLE_QUANTIZE
......@@ -33,6 +33,7 @@
#endif // MACE_ENABLE_QUANTIZE
#else
#include "mace/ops/ref/gemm.h"
#include "mace/ops/ref/gemv.h"
#endif // MACE_ENABLE_NEON
......@@ -58,35 +59,45 @@ class MatMulOpBase : public Operation {
inline void Validate() {
const Tensor *A = this->Input(INPUT_A);
const Tensor *B = this->Input(INPUT_B);
MACE_CHECK(A->dim_size() == B->dim_size() && A->dim_size() >= 2,
"rank(A) should be equal to rank(B), rank should be greater "
"than or equal to 2");
index_t rank = A->dim_size();
for (index_t i = 0; i < rank - 2; ++i) {
const index_t lhs_rank = A->dim_size();
const index_t rhs_rank = B->dim_size();
MACE_CHECK(lhs_rank >= 2 && rhs_rank >= 2,
"rank should be greater than or equal to 2");
if (lhs_rank == rhs_rank) {
for (index_t i = 0; i < A->dim_size() - 2; ++i) {
MACE_CHECK(A->dim(i) == B->dim(i),
"batch dimensions are not equal: ",
A->dim(i),
" vs. ",
B->dim(i));
}
index_t ak = transpose_a_ ? A->dim(rank - 2) : A->dim(rank - 1);
index_t bk = transpose_b_ ? B->dim(rank - 1) : B->dim(rank - 2);
MACE_CHECK(ak == bk, "the number of A's column ", ak,
" must be equal to B's row ", bk);
} else {
MACE_CHECK(lhs_rank == 2 || rhs_rank == 2,
"Either lhs or rhs matrix should has rank 2 "
"for non-batched matrix multiplication");
}
index_t
lhs_depth = transpose_a_ ? A->dim(lhs_rank - 2) : A->dim(lhs_rank - 1);
index_t
rhs_depth = transpose_b_ ? B->dim(rhs_rank - 1) : B->dim(rhs_rank - 2);
MACE_CHECK(lhs_depth == rhs_depth, "the number of A's column ", lhs_depth,
" must be equal to B's row ", rhs_depth);
}
protected:
MACE_OP_INPUT_TAGS(INPUT_A, INPUT_B);
MACE_OP_INPUT_TAGS(INPUT_A, INPUT_B, BIAS);
MACE_OP_OUTPUT_TAGS(OUTPUT);
bool transpose_a_;
bool transpose_b_;
};
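To make the relaxed Validate() rules above concrete, a few shape pairs it accepts or rejects (illustrative values, transpose flags assumed false):
// lhs [2, 3, 16, 32] x rhs [2, 3, 32, 64] -> ok: equal ranks, matching batch dims
// lhs [16, 32]       x rhs [2, 3, 32, 64] -> ok: the rank-2 side is treated as unbatched
// lhs [2, 16, 32]    x rhs [3, 2, 32, 64] -> rejected: ranks differ and neither side has rank 2
// lhs [2, 16, 32]    x rhs [2, 48, 64]    -> rejected: depth 32 != 48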
template <DeviceType D, class T>
template<DeviceType D, class T>
class MatMulOp;
template <>
template<>
class MatMulOp<CPU, float> : public MatMulOpBase {
public:
explicit MatMulOp(OpConstructContext *context)
......@@ -94,72 +105,116 @@ class MatMulOp<CPU, float> : public MatMulOpBase {
MaceStatus Run(OpContext *context) override {
Validate();
const Tensor *A = this->Input(INPUT_A);
const Tensor *B = this->Input(INPUT_B);
const Tensor *lhs = this->Input(INPUT_A);
const Tensor *rhs = this->Input(INPUT_B);
const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr;
Tensor *C = this->Output(OUTPUT);
index_t batch;
index_t height;
index_t K;
index_t width;
index_t rank = A->dim_size();
height = A->dim(rank - 2);
K = A->dim(rank - 1);
if (transpose_a_) {
std::swap(height, K);
}
if (transpose_b_) {
width = B->dim(rank - 2);
const index_t lhs_rank = lhs->dim_size();
const index_t lhs_rows = lhs->dim(lhs_rank - 2);
const index_t lhs_cols = lhs->dim(lhs_rank - 1);
const index_t rhs_rank = rhs->dim_size();
const index_t rhs_rows = rhs->dim(rhs_rank - 2);
const index_t rhs_cols = rhs->dim(rhs_rank - 1);
const index_t rows = transpose_a_ ? lhs_cols : lhs_rows;
const index_t cols = transpose_b_ ? rhs_rows : rhs_cols;
const index_t depth = transpose_a_ ? lhs_rows : lhs_cols;
const index_t
lhs_batch =
std::accumulate(lhs->shape().begin(), lhs->shape().end() - 2, 1,
std::multiplies<index_t>());
const index_t
rhs_batch =
std::accumulate(rhs->shape().begin(), rhs->shape().end() - 2, 1,
std::multiplies<index_t>());
index_t batch = 1;
std::vector<index_t> output_shape;
if (lhs_rank >= rhs_rank) {
output_shape = lhs->shape();
output_shape[lhs_rank - 2] = rows;
output_shape[lhs_rank - 1] = cols;
batch = lhs_batch;
} else {
width = B->dim(rank - 1);
output_shape = rhs->shape();
output_shape[rhs_rank - 2] = rows;
output_shape[rhs_rank - 1] = cols;
batch = rhs_batch;
}
bool lhs_batched = true;
bool rhs_batched = true;
if (lhs_rank < rhs_rank) {
lhs_batched = false;
} else if (rhs_rank < lhs_rank) {
rhs_batched = false;
}
batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
std::multiplies<index_t>());
std::vector<index_t> c_shape = A->shape();
c_shape[rank - 2] = height;
c_shape[rank - 1] = width;
MACE_RETURN_IF_ERROR(C->Resize(c_shape));
Tensor::MappingGuard guarda(A);
Tensor::MappingGuard guardb(B);
Tensor::MappingGuard guardc(C);
const float *a_ptr_base = A->data<float>();
const float *b_ptr_base = B->data<float>();
float *c_ptr_base = C->mutable_data<float>();
const index_t height_a = A->dim(rank - 2);
const index_t width_a = A->dim(rank - 1);
const index_t height_b = B->dim(rank - 2);
const index_t width_b = B->dim(rank - 1);
auto scratch_buffer = context->device()->scratch_buffer();
scratch_buffer->Rewind();
MACE_RETURN_IF_ERROR(C->Resize(output_shape));
sgemm_.Run(a_ptr_base,
b_ptr_base,
if (rows == 1 && transpose_b_) {
return gemv_.Compute(context,
rhs,
lhs,
bias,
batch,
cols,
depth,
rhs_batched,
lhs_batched,
C);
} else if (cols == 1 && !transpose_a_) {
return gemv_.Compute(context,
lhs,
rhs,
bias,
batch,
height_a,
width_a,
height_b,
width_b,
rows,
depth,
lhs_batched,
rhs_batched,
C);
} else {
context->device()->scratch_buffer()->Rewind();
MaceStatus ret = gemm_.Compute(context,
lhs,
rhs,
batch,
lhs_rows,
lhs_cols,
rhs_rows,
rhs_cols,
transpose_a_,
transpose_b_,
A->is_weight(),
B->is_weight(),
c_ptr_base,
context->device()->scratch_buffer());
false,
lhs_batched,
rhs_batched,
C);
if (bias != nullptr) {
MACE_CHECK(bias->dim_size() == 1 && bias->dim(0) == cols,
"bias' dim should be <= 2.");
Tensor::MappingGuard bias_guard(bias);
Tensor::MappingGuard c_guard(C);
const float *bias_data = bias->data<float>();
float *c_data = C->mutable_data<float>();
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < batch * rows; ++i) {
for (index_t w = 0; w < cols; ++w) {
c_data[i * cols + w] += bias_data[w];
}
}
}
return MaceStatus::MACE_SUCCESS;
return ret;
}
}
private:
SGemm sgemm_;
#ifdef MACE_ENABLE_NEON
arm::fp32::Gemm gemm_;
arm::fp32::Gemv gemv_;
#else
ref::Gemv<float> gemv_;
ref::Gemm<float> gemm_;
#endif // MACE_ENABLE_NEON
};
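A condensed sketch of the dispatch rule used in Run() above (not the member code itself; the enum and helper names are illustrative): shapes that degenerate to a vector product are routed to Gemv, with the operands swapped when it is the lhs row count that collapses, and everything else goes through Gemm.
enum class MatMulKernel { GemvSwapped, Gemv, Gemm };
// rows/cols are the logical output dimensions after applying the transpose flags.
inline MatMulKernel PickMatMulKernel(index_t rows, index_t cols,
                                     bool transpose_a, bool transpose_b) {
  if (rows == 1 && transpose_b) return MatMulKernel::GemvSwapped;  // rhs acts as the matrix
  if (cols == 1 && !transpose_a) return MatMulKernel::Gemv;        // lhs acts as the matrix
  return MatMulKernel::Gemm;
}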
......@@ -174,18 +229,36 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
void operator()(OpContext *context,
const Tensor *A,
const Tensor *B,
const index_t batch,
const index_t height,
const index_t K,
const index_t width,
const bool lhs_bached,
const bool rhs_bached,
Tensor *C) {
index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
std::multiplies<index_t>());
#if defined(MACE_ENABLE_NEON)
if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) {
gemv_kernel_.Compute(context, A, B, nullptr, batch, height, K, true, C);
gemv_kernel_.Compute(context,
A,
B,
nullptr,
batch,
height,
K,
true,
true,
C);
} else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) {
gemv_kernel_.Compute(context, B, A, nullptr, batch, width, K, true, C);
gemv_kernel_.Compute(context,
B,
A,
nullptr,
batch,
width,
K,
true,
true,
C);
} else {
#endif // MACE_ENABLE_NEON
Tensor::MappingGuard guarda(A);
......@@ -208,9 +281,13 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
for (index_t i = 0; i < batch; ++i) {
gemmlowp::MatrixMap<const uint8_t, AOrder>
a_matrix(a_ptr_base + i * a_size, height, K);
a_matrix(a_ptr_base + static_cast<index_t>(lhs_bached) * i * a_size,
height,
K);
gemmlowp::MatrixMap<const uint8_t, BOrder>
b_matrix(b_ptr_base + i * b_size, K, width);
b_matrix(b_ptr_base + static_cast<index_t>(rhs_bached) * i * b_size,
K,
width);
gemmlowp::MatrixMap <uint8_t, gemmlowp::MapOrder::RowMajor>
c_matrix(c_ptr_base + i * c_size, height, width);
......@@ -234,20 +311,39 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
void operator()(OpContext *context,
const Tensor *A,
const Tensor *B,
const index_t batch,
const index_t height,
const index_t K,
const index_t width,
const bool lhs_bached,
const bool rhs_bached,
Tensor *C) {
C->SetScale(A->scale() * B->scale());
C->SetZeroPoint(0);
index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
std::multiplies<index_t>());
#if defined(MACE_ENABLE_NEON)
if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) {
gemv_kernel_.Compute(context, A, B, nullptr, batch, height, K, true, C);
gemv_kernel_.Compute(context,
A,
B,
nullptr,
batch,
height,
K,
lhs_bached,
rhs_bached,
C);
} else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) {
gemv_kernel_.Compute(context, B, A, nullptr, batch, width, K, true, C);
gemv_kernel_.Compute(context,
B,
A,
nullptr,
batch,
width,
K,
lhs_bached,
rhs_bached,
C);
} else {
#endif // MACE_ENABLE_NEON
Tensor::MappingGuard guarda(A);
......@@ -257,7 +353,8 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
auto b_ptr_base = B->data<uint8_t>();
auto c_ptr_base = C->mutable_data<int32_t>();
auto
gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext();
gemm_context =
context->device()->cpu_runtime()->GetGemmlowpContext();
MACE_CHECK_NOTNULL(gemm_context);
index_t a_size = height * K;
......@@ -268,9 +365,15 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
for (index_t i = 0; i < batch; ++i) {
gemmlowp::MatrixMap<const uint8_t, AOrder>
a_matrix(a_ptr_base + i * a_size, height, K);
a_matrix
(a_ptr_base + static_cast<index_t>(lhs_bached) * i * a_size,
height,
K);
gemmlowp::MatrixMap<const uint8_t, BOrder>
b_matrix(b_ptr_base + i * b_size, K, width);
b_matrix
(b_ptr_base + static_cast<index_t>(rhs_bached) * i * b_size,
K,
width);
gemmlowp::MatrixMap <int32_t, gemmlowp::MapOrder::RowMajor>
c_matrix(c_ptr_base + i * c_size, height, width);
......@@ -280,7 +383,6 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
-B->zero_point(), output_pipeline);
}
}
#if defined(MACE_ENABLE_NEON)
}
......@@ -289,44 +391,65 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
#endif // MACE_ENABLE_NEON
};
template <>
class MatMulOp<DeviceType::CPU, uint8_t>: public MatMulOpBase {
template<>
class MatMulOp<DeviceType::CPU, uint8_t> : public MatMulOpBase {
public:
explicit MatMulOp(OpConstructContext *context)
: MatMulOpBase(context) {}
MaceStatus Run(OpContext *context) override {
Validate();
const Tensor *A = this->Input(INPUT_A);
const Tensor *B = this->Input(INPUT_B);
const Tensor *lhs = this->Input(INPUT_A);
const Tensor *rhs = this->Input(INPUT_B);
Tensor *C = this->Output(OUTPUT);
index_t rank = A->dim_size();
index_t height = A->dim(rank - 2);
index_t K = A->dim(rank - 1);
index_t width;
if (transpose_a_) {
std::swap(height, K);
}
if (transpose_b_) {
width = B->dim(rank - 2);
const index_t lhs_rank = lhs->dim_size();
const index_t lhs_rows = lhs->dim(lhs_rank - 2);
const index_t lhs_cols = lhs->dim(lhs_rank - 1);
const index_t rhs_rank = rhs->dim_size();
const index_t rhs_rows = rhs->dim(rhs_rank - 2);
const index_t rhs_cols = rhs->dim(rhs_rank - 1);
const index_t rows = transpose_a_ ? lhs_cols : lhs_rows;
const index_t cols = transpose_b_ ? rhs_rows : rhs_cols;
const index_t depth = transpose_a_ ? lhs_rows : lhs_cols;
const index_t
lhs_batch =
std::accumulate(lhs->shape().begin(), lhs->shape().end() - 2, 1,
std::multiplies<index_t>());
const index_t
rhs_batch =
std::accumulate(rhs->shape().begin(), rhs->shape().end() - 2, 1,
std::multiplies<index_t>());
index_t batch = 1;
std::vector<index_t> output_shape;
if (lhs_rank >= rhs_rank) {
output_shape = lhs->shape();
output_shape[lhs_rank - 2] = rows;
output_shape[lhs_rank - 1] = cols;
batch = lhs_batch;
} else {
width = B->dim(rank - 1);
output_shape = rhs->shape();
output_shape[rhs_rank - 2] = rows;
output_shape[rhs_rank - 1] = cols;
batch = rhs_batch;
}
bool lhs_batched = true;
bool rhs_batched = true;
if (lhs_rank < rhs_rank) {
lhs_batched = false;
} else if (rhs_rank < lhs_rank) {
rhs_batched = false;
}
std::vector<index_t> c_shape = A->shape();
c_shape[rank - 2] = height;
c_shape[rank - 1] = width;
MACE_RETURN_IF_ERROR(C->Resize(c_shape));
MACE_RETURN_IF_ERROR(C->Resize(output_shape));
constexpr gemmlowp::MapOrder kRowMajor = gemmlowp::MapOrder::RowMajor;
constexpr gemmlowp::MapOrder kColMajor = gemmlowp::MapOrder::ColMajor;
#define MATMUL_FIXPOINT_IMPL(AOrder, BOrder, OutType) \
MatMulFixpointImpl<AOrder, BOrder, OutType>()( \
context, A, B, height, K, width, C);
context, lhs, rhs, batch, rows, depth, cols, lhs_batched, rhs_batched, C);
#define MATMUL_FIXPOINT_IMPL_TRANSPOSE_OR_NOT(OutType) \
if (transpose_a_) { \
......@@ -380,7 +503,6 @@ class MatMulOp<DeviceType::GPU, T> : public MatMulOpBase {
};
#endif // MACE_ENABLE_OPENCL
void RegisterMatMul(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "MatMul", MatMulOp,
DeviceType::CPU, float);
......
......@@ -101,11 +101,14 @@ void MatmulBenchmark_Mace_SGemm(int iters, int m, int k, int n) {
std::vector<float> rhs(k * n);
std::vector<float> result(m * n);
ops::MatrixMap<const float> matrix_lhs(1, m, k, RowMajor, lhs.data(),
ops::SGemmMatrixMap<const float>
matrix_lhs(1, m, k, SGemmRowMajor, lhs.data(),
true);
ops::MatrixMap<const float> matrix_rhs(1, k, n, RowMajor, rhs.data(),
ops::SGemmMatrixMap<const float>
matrix_rhs(1, k, n, SGemmRowMajor, rhs.data(),
true);
ops::MatrixMap<float> matrix_result(1, m, n, RowMajor, result.data());
ops::SGemmMatrixMap<float>
matrix_result(1, m, n, SGemmRowMajor, result.data());
ops::SGemm sgemm;
......@@ -395,6 +398,7 @@ void MatMulTransposeBenchmark(
MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, uint8_t, CPU);
MACE_BM_MATMUL_OP(1, 30000, 256, 1);
MACE_BM_MATMUL_OP(1, 128, 256, 128);
MACE_BM_MATMUL_OP(2, 128, 128, 49);
MACE_BM_MATMUL_OP(3, 128, 128, 49);
MACE_BM_MATMUL_OP(4, 128, 128, 49);
......
......@@ -15,6 +15,7 @@
#include <fstream>
#include "mace/ops/ops_test_util.h"
#include "mace/ops/ref/gemm.h"
namespace mace {
namespace ops {
......@@ -71,34 +72,121 @@ TEST_F(MatMulOpTest, SimpleCPUWithBatch) {
}
namespace {
void QuantOutputUint8(const std::vector<index_t> &batch,
const index_t height,
const index_t channels,
const index_t out_width,
const bool transpose_a,
const bool transpose_b) {
template<DeviceType D>
void Complex(const std::vector<index_t> &batch,
const index_t rows,
const index_t depth,
const index_t cols,
const bool transpose_lhs,
const bool transpose_rhs,
const bool lhs_batched,
const bool rhs_batched) {
// Construct graph
OpsTestNet net;
// Add input data
index_t lhs_rows = transpose_lhs ? depth : rows;
index_t lhs_cols = transpose_lhs ? rows : depth;
index_t rhs_rows = transpose_rhs ? cols : depth;
index_t rhs_cols = transpose_rhs ? depth : cols;
std::vector<index_t> lhs_shape = {lhs_rows, lhs_cols};
std::vector<index_t> rhs_shape = {rhs_rows, rhs_cols};
if (lhs_batched) {
lhs_shape.insert(lhs_shape.begin(), batch.begin(), batch.end());
}
if (rhs_batched) {
rhs_shape.insert(rhs_shape.begin(), batch.begin(), batch.end());
}
net.AddRandomInput<CPU, float>("A", lhs_shape);
net.AddRandomInput<CPU, float>("B", rhs_shape);
OpDefBuilder("MatMul", "MatMulTest")
.Input("A")
.AddIntArg("transpose_a", transpose_lhs ? 1 : 0)
.Input("B")
.AddIntArg("transpose_b", transpose_rhs ? 1 : 0)
.Output("Output")
.AddIntArg("T", DT_FLOAT)
.Finalize(net.NewOperatorDef());
net.RunOp(CPU);
ref::Gemm<float> gemm;
Tensor expected_output_tensor;
std::vector<index_t> expected_output_shape({rows, cols});
expected_output_shape.insert(expected_output_shape.begin(),
batch.begin(),
batch.end());
expected_output_tensor.Resize(expected_output_shape);
index_t batch_count = std::accumulate(batch.begin(), batch.end(), 1,
std::multiplies<index_t>());
if (transpose_a) {
net.AddRandomInput<CPU, float>("A", {batch_count, channels, height});
} else {
net.AddRandomInput<CPU, float>("A", {batch_count, height, channels});
gemm.Compute(nullptr,
net.GetTensor("A"),
net.GetTensor("B"),
batch_count,
lhs_rows,
lhs_cols,
rhs_rows,
rhs_cols,
transpose_lhs,
transpose_rhs,
false,
lhs_batched,
rhs_batched,
&expected_output_tensor);
ExpectTensorNear<float>(expected_output_tensor, *net.GetTensor("Output"));
}
} // namespace
TEST_F(MatMulOpTest, ComplexCPUWithBatch) {
Complex<DeviceType::CPU>({1}, 3, 3, 3, false, false, true, true);
Complex<DeviceType::CPU>({}, 3, 3, 3, false, false, true, true);
Complex<DeviceType::CPU>({16}, 31, 61, 67, false, true, true, true);
Complex<DeviceType::CPU>({31}, 31, 61, 67, true, false, true, true);
Complex<DeviceType::CPU>({2, 3}, 31, 61, 67, true, true, true, true);
Complex<DeviceType::CPU>({1}, 1, 30001, 253, false, true, true, true);
Complex<DeviceType::CPU>({2}, 253, 300, 1, false, false, true, true);
// test one-side batched
Complex<DeviceType::CPU>({2, 3}, 31, 61, 67, true, true, false, true);
Complex<DeviceType::CPU>({2, 3}, 31, 61, 67, true, true, true, false);
Complex<DeviceType::CPU>({2, 3}, 31, 61, 67, true, true, false, true);
}
namespace {
void QuantOutputUint8(const std::vector<index_t> &batch,
const index_t rows,
const index_t depth,
const index_t cols,
const bool transpose_lhs,
const bool transpose_rhs,
const bool lhs_batched = true,
const bool rhs_batched = true) {
// Construct graph
OpsTestNet net;
// Add input data
index_t lhs_rows = transpose_lhs ? depth : rows;
index_t lhs_cols = transpose_lhs ? rows : depth;
index_t rhs_rows = transpose_rhs ? cols : depth;
index_t rhs_cols = transpose_rhs ? depth : cols;
std::vector<index_t> lhs_shape = {lhs_rows, lhs_cols};
std::vector<index_t> rhs_shape = {rhs_rows, rhs_cols};
if (lhs_batched) {
lhs_shape.insert(lhs_shape.begin(), batch.begin(), batch.end());
}
if (transpose_b) {
net.AddRandomInput<CPU, float>("B", {batch_count, out_width, channels});
} else {
net.AddRandomInput<CPU, float>("B", {batch_count, channels, out_width});
if (rhs_batched) {
rhs_shape.insert(rhs_shape.begin(), batch.begin(), batch.end());
}
net.AddRandomInput<CPU, float>("A", lhs_shape);
net.AddRandomInput<CPU, float>("B", rhs_shape);
OpDefBuilder("MatMul", "MatMulTest")
.Input("A")
.AddIntArg("transpose_a", transpose_a ? 1 : 0)
.AddIntArg("transpose_a", transpose_lhs ? 1 : 0)
.Input("B")
.AddIntArg("transpose_b", transpose_b ? 1 : 0)
.AddIntArg("transpose_b", transpose_rhs ? 1 : 0)
.Output("Output")
.AddIntArg("T", DT_FLOAT)
.Finalize(net.NewOperatorDef());
......@@ -133,9 +221,9 @@ void QuantOutputUint8(const std::vector<index_t> &batch,
OpDefBuilder("MatMul", "QuantizeMatMulTest")
.Input("QuantizedA")
.AddIntArg("transpose_a", transpose_a ? 1 : 0)
.AddIntArg("transpose_a", transpose_lhs ? 1 : 0)
.Input("QuantizedB")
.AddIntArg("transpose_b", transpose_b ? 1 : 0)
.AddIntArg("transpose_b", transpose_rhs ? 1 : 0)
.Output("QuantizedOutput")
.AddIntArg("T", DT_UINT8)
.OutputType({DT_UINT8})
......@@ -161,39 +249,38 @@ void QuantOutputUint8(const std::vector<index_t> &batch,
}
void QuantOutputInt32(const std::vector<index_t> &batch,
const index_t height,
const index_t channels,
const index_t out_width,
const bool transpose_a,
const bool transpose_b) {
const index_t rows,
const index_t depth,
const index_t cols,
const bool transpose_lhs,
const bool transpose_rhs,
const bool lhs_batched = true,
const bool rhs_batched = true) {
// Construct graph
OpsTestNet net;
// Add input data
index_t batch_count = std::accumulate(batch.begin(), batch.end(), 1,
std::multiplies<index_t>());
if (transpose_a) {
net.AddRandomInput<CPU, float>("A", {batch_count, channels, height},
false);
} else {
net.AddRandomInput<CPU, float>("A", {batch_count, height, channels},
false);
// Add input data
index_t lhs_rows = transpose_lhs ? depth : rows;
index_t lhs_cols = transpose_lhs ? rows : depth;
index_t rhs_rows = transpose_rhs ? cols : depth;
index_t rhs_cols = transpose_rhs ? depth : cols;
std::vector<index_t> lhs_shape = {lhs_rows, lhs_cols};
std::vector<index_t> rhs_shape = {rhs_rows, rhs_cols};
if (lhs_batched) {
lhs_shape.insert(lhs_shape.begin(), batch.begin(), batch.end());
}
if (transpose_b) {
net.AddRandomInput<CPU, float>("B",
{batch_count, out_width, channels},
false);
} else {
net.AddRandomInput<CPU, float>("B",
{batch_count, channels, out_width},
false);
if (rhs_batched) {
rhs_shape.insert(rhs_shape.begin(), batch.begin(), batch.end());
}
net.AddRandomInput<CPU, float>("A", lhs_shape);
net.AddRandomInput<CPU, float>("B", rhs_shape);
OpDefBuilder("MatMul", "MatMulTest")
.Input("A")
.AddIntArg("transpose_a", transpose_a ? 1 : 0)
.AddIntArg("transpose_a", transpose_lhs ? 1 : 0)
.Input("B")
.AddIntArg("transpose_b", transpose_b ? 1 : 0)
.AddIntArg("transpose_b", transpose_rhs ? 1 : 0)
.Output("Output")
.AddIntArg("T", DT_FLOAT)
.Finalize(net.NewOperatorDef());
......@@ -219,9 +306,9 @@ void QuantOutputInt32(const std::vector<index_t> &batch,
OpDefBuilder("MatMul", "QuantizeMatMulTest")
.Input("QuantizedA")
.AddIntArg("transpose_a", transpose_a ? 1 : 0)
.AddIntArg("transpose_a", transpose_lhs ? 1 : 0)
.Input("QuantizedB")
.AddIntArg("transpose_b", transpose_b ? 1 : 0)
.AddIntArg("transpose_b", transpose_rhs ? 1 : 0)
.Output("QuantizedOutput")
.AddIntArg("T", DT_UINT8)
.OutputType({DT_INT32})
......@@ -256,10 +343,12 @@ TEST_F(MatMulOpTest, QuantOutputUint8) {
QuantOutputUint8({1}, 64, 32, 128, true, true);
QuantOutputUint8({2, 3}, 64, 32, 128, true, true);
// UnAligned
QuantOutputUint8({2}, 3, 3, 3, false, false);
QuantOutputUint8({16}, 31, 61, 67, false, true);
QuantOutputUint8({31}, 31, 61, 67, true, false);
QuantOutputUint8({2, 3}, 31, 61, 67, true, true);
QuantOutputUint8({2, 3}, 31, 61, 67, true, true, true, false);
QuantOutputUint8({2, 3}, 31, 61, 67, true, true, false, true);
}
TEST_F(MatMulOpTest, QuantOutputInt32) {
......@@ -281,12 +370,14 @@ TEST_F(MatMulOpTest, QuantOutputInt32) {
QuantOutputInt32({3}, 128, 256, 1, false, false);
// UnAligned
QuantOutputInt32({2}, 3, 3, 3, false, false);
QuantOutputInt32({16}, 31, 61, 67, false, true);
QuantOutputInt32({31}, 31, 61, 67, true, false);
QuantOutputInt32({2, 3}, 31, 61, 67, true, true);
QuantOutputInt32({1}, 1, 30001, 253, false, true);
QuantOutputInt32({2}, 253, 300, 1, false, false);
QuantOutputInt32({2, 3}, 31, 61, 67, true, true, true, false);
QuantOutputInt32({2, 3}, 31, 61, 67, true, true, false, true);
}
} // namespace test
......
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ref/conv_2d.h"
#include <vector>
#include "mace/ops/common/conv_pool_2d_util.h"
namespace mace {
namespace ops {
namespace ref {
MaceStatus Conv2d<float>::Compute(const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output) {
MACE_UNUSED(context);
const std::vector<index_t> in_shape = input->shape();
const std::vector<index_t> filter_shape = filter->shape();
const std::vector<index_t> out_shape = output->shape();
const std::vector<int> stride_hw{stride_h_, stride_w_};
const std::vector<int> dilation_hw{dilation_h_, dilation_w_};
const std::vector<int> paddings{pad_h_, pad_w_};
const index_t pad_top = pad_h_ >> 1;
const index_t pad_left = pad_w_ >> 1;
std::vector<index_t> output_shape(4);
CalcOutputSize(in_shape.data(),
NCHW,
filter_shape.data(),
OIHW,
paddings.data(),
dilation_hw.data(),
stride_hw.data(),
RoundType::FLOOR,
output_shape.data());
output->Resize(output_shape);
const index_t in_image_size = in_shape[2] * in_shape[3];
const index_t out_image_size = out_shape[2] * out_shape[3];
const index_t in_batch_size = filter_shape[1] * in_image_size;
const index_t out_batch_size = filter_shape[0] * out_image_size;
const index_t filter_size = filter_shape[2] * filter_shape[3];
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard filter_guard(filter);
Tensor::MappingGuard output_guard(output);
auto input_data = input->data<float>();
auto filter_data = filter->data<float>();
auto output_data = output->mutable_data<float>();
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < in_shape[0]; b++) {
for (index_t m = 0; m < filter_shape[0]; ++m) {
const index_t in_height = in_shape[2];
const index_t in_width = in_shape[3];
const index_t out_height = out_shape[2];
const index_t out_width = out_shape[3];
const index_t in_channels = filter_shape[1];
float *out_ptr_base =
output_data + b * out_batch_size + m * out_image_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
float sum = 0;
for (index_t c = 0; c < in_channels; ++c) {
const float *in_ptr_base =
input_data + b * in_batch_size + c * in_image_size;
const float *filter_ptr =
filter_data + m * in_channels * filter_size + c * filter_size;
for (index_t kh = 0; kh < filter_shape[2]; ++kh) {
for (index_t kw = 0; kw < filter_shape[3]; ++kw) {
const index_t ih = -pad_top + h * stride_h_ + kh * dilation_h_;
const index_t iw = -pad_left + w * stride_w_ + kw * dilation_w_;
if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
sum += in_ptr_base[ih * in_width + iw] * filter_ptr[kw];
}
} // kw
filter_ptr += filter_shape[3];
} // kh
} // c
out_ptr_base[h * out_width + w] = sum;
} // w
} // h
} // m
} // b
return MaceStatus::MACE_SUCCESS;
}
} // namespace ref
} // namespace ops
} // namespace mace
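A minimal call into the reference kernel above (tensor construction is elided and the shapes are illustrative; only the constructor arguments and the Compute() signature come from this file):
// NCHW input [1, 3, 8, 8], OIHW filter [4, 3, 3, 3], total padding of 2 on each axis.
mace::ops::ref::Conv2d<float> ref_conv2d(/*pad_h=*/2, /*pad_w=*/2,
                                         /*stride_h=*/1, /*stride_w=*/1,
                                         /*dilation_h=*/1, /*dilation_w=*/1);
ref_conv2d.Compute(/*context=*/nullptr, &input, &filter, &output);  // output is resized inside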
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_REF_CONV_2D_H_
#define MACE_OPS_REF_CONV_2D_H_
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
namespace mace {
namespace ops {
namespace ref {
template<typename OUTPUT_TYPE>
class Conv2d {
public:
Conv2d(int stride_h, int stride_w, int dilation_h, int dilation_w);
~Conv2d() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
};
template<>
class Conv2d<float> {
public:
Conv2d(int pad_h,
int pad_w,
int stride_h,
int stride_w,
int dilation_h,
int dilation_w)
: pad_h_(pad_h),
pad_w_(pad_w),
stride_h_(stride_h),
stride_w_(stride_w),
dilation_h_(dilation_h),
dilation_w_(dilation_w) {}
~Conv2d() {}
// Always row-major after transpose
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
private:
int pad_h_;
int pad_w_;
int stride_h_;
int stride_w_;
int dilation_h_;
int dilation_w_;
};
} // namespace ref
} // namespace ops
} // namespace mace
#endif // MACE_OPS_REF_CONV_2D_H_
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ref/gemm.h"
namespace mace {
namespace ops {
namespace ref {
MaceStatus Gemm<float>::Compute(const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t rows,
const index_t cols,
const index_t depth,
const MatrixMajor lhs_major,
const MatrixMajor rhs_major,
const MatrixMajor output_major,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
MACE_UNUSED(context);
Tensor::MappingGuard lhs_guard(lhs);
Tensor::MappingGuard rhs_guard(rhs);
Tensor::MappingGuard output_guard(output);
const float *lhs_data = lhs->data<float>();
const float *rhs_data = rhs->data<float>();
float *output_data = output->mutable_data<float>();
for (index_t b = 0; b < batch; ++b) {
MatrixMap<const float>
lhs_matrix
(lhs_data + static_cast<index_t>(lhs_batched) * b * rows * depth,
lhs_major,
rows,
depth);
MatrixMap<const float>
rhs_matrix
(rhs_data + static_cast<index_t>(rhs_batched) * b * depth * cols,
rhs_major,
depth,
cols);
MatrixMap<float>
output_matrix(output_data + b * rows * cols, output_major, rows, cols);
for (index_t r = 0; r < rows; ++r) {
for (index_t c = 0; c < cols; ++c) {
float sum = 0;
for (index_t d = 0; d < depth; ++d) {
sum += lhs_matrix(r, d) * rhs_matrix(d, c);
} // d
*output_matrix.data(r, c) = sum;
} // c
} // r
} // b
return MaceStatus::MACE_SUCCESS;
}
MaceStatus Gemm<float>::Compute(const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t lhs_rows,
const index_t lhs_cols,
const index_t rhs_rows,
const index_t rhs_cols,
const bool transpose_lhs,
const bool transpose_rhs,
const bool transpose_out,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
index_t rows = transpose_lhs ? lhs_cols : lhs_rows;
index_t depth = transpose_lhs ? lhs_rows : lhs_cols;
index_t cols = transpose_rhs ? rhs_rows : rhs_cols;
index_t depth2 = transpose_rhs ? rhs_cols : rhs_rows;
MACE_CHECK(depth == depth2,
"Matrices that multiply have inconsistent depth dim: ",
depth,
" vs. ",
depth2);
return Compute(context,
lhs,
rhs,
batch,
rows,
cols,
depth,
transpose_lhs ? ColMajor : RowMajor,
transpose_rhs ? ColMajor : RowMajor,
transpose_out ? ColMajor : RowMajor,
lhs_batched,
rhs_batched,
output);
}
} // namespace ref
} // namespace ops
} // namespace mace
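For reference, one call to the transpose-flag overload above (tensor setup is elided and the sizes are illustrative; output is assumed already resized to {2, 16, 64}):
// lhs stored as [2, 32, 16] but logically transposed, rhs stored as [2, 32, 64]:
// rows = 16, depth = 32, cols = 64, so the result is [2, 16, 64] in row-major.
mace::ops::ref::Gemm<float> ref_gemm;
ref_gemm.Compute(/*context=*/nullptr, &lhs, &rhs,
                 /*batch=*/2,
                 /*lhs_rows=*/32, /*lhs_cols=*/16,
                 /*rhs_rows=*/32, /*rhs_cols=*/64,
                 /*transpose_lhs=*/true, /*transpose_rhs=*/false,
                 /*transpose_out=*/false,
                 /*lhs_batched=*/true, /*rhs_batched=*/true,
                 &output);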
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_REF_GEMM_H_
#define MACE_OPS_REF_GEMM_H_
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/common/matrix.h"
namespace mace {
namespace ops {
namespace ref {
template<typename OUTPUT_TYPE>
class Gemm {
public:
Gemm() {}
~Gemm() {}
MaceStatus Compute(const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t rows,
const index_t cols,
const index_t depth,
const MatrixMajor lhs_major,
const MatrixMajor rhs_major,
const MatrixMajor output_major,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
template<>
class Gemm<float> {
public:
Gemm() {}
~Gemm() {}
MaceStatus Compute(const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t rows,
const index_t cols,
const index_t depth,
const MatrixMajor lhs_major,
const MatrixMajor rhs_major,
const MatrixMajor output_major,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
// The original matrices (before any transpose) are row-major.
MaceStatus Compute(
const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t lhs_rows,
const index_t lhs_cols,
const index_t rhs_rows,
const index_t rhs_cols,
const bool transpose_lhs,
const bool transpose_rhs,
const bool transpose_out,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
} // namespace ref
} // namespace ops
} // namespace mace
#endif // MACE_OPS_REF_GEMM_H_
......@@ -31,6 +31,7 @@ MaceStatus Gemv<float>::Compute(const OpContext *context,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
MACE_UNUSED(context);
......@@ -54,7 +55,7 @@ MaceStatus Gemv<float>::Compute(const OpContext *context,
sum += lhs_data[
static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ h * lhs_width + w]
* rhs_data[b * lhs_width + w];
* rhs_data[static_cast<index_t>(rhs_batched) * b * lhs_width + w];
} // w
output_data[b * lhs_height + h] = sum;
......@@ -73,6 +74,7 @@ MaceStatus Gemv<uint8_t>::Compute(const OpContext *context,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
MACE_UNUSED(context);
......@@ -102,7 +104,8 @@ MaceStatus Gemv<uint8_t>::Compute(const OpContext *context,
sum += (lhs_data[
static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ h * lhs_width + w] - lhs_zero)
* (rhs_data[b * lhs_width + w] - rhs_zero);
* (rhs_data[static_cast<index_t>(rhs_batched) * b * lhs_width + w]
- rhs_zero);
} // w
output_data[b * lhs_height + h] =
......@@ -120,6 +123,7 @@ MaceStatus Gemv<int32_t>::Compute(const OpContext *context,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
MACE_UNUSED(context);
......@@ -146,7 +150,8 @@ MaceStatus Gemv<int32_t>::Compute(const OpContext *context,
sum += (lhs_data[
static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ h * lhs_width + w] - lhs_zero)
* (rhs_data[b * lhs_width + w] - rhs_zero);
* (rhs_data[static_cast<index_t>(rhs_batched) * b * lhs_width + w]
- rhs_zero);
} // w
output_data[b * lhs_height + h] = sum;
......
......@@ -39,6 +39,7 @@ class Gemv {
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
......@@ -57,6 +58,7 @@ class Gemv<float> {
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
......@@ -76,6 +78,7 @@ class Gemv<uint8_t> {
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
......@@ -94,6 +97,7 @@ class Gemv<int32_t> {
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
#endif // MACE_ENABLE_QUANTIZE
......
......@@ -27,39 +27,17 @@
#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#endif
namespace {
inline void AdviseFree(void *addr, size_t length) {
int page_size = sysconf(_SC_PAGESIZE);
void *addr_aligned =
reinterpret_cast<void *>(
(reinterpret_cast<uintptr_t>(addr) + page_size - 1)
& (~(page_size - 1)));
uintptr_t delta =
reinterpret_cast<uintptr_t>(addr_aligned)
- reinterpret_cast<uintptr_t>(addr);
if (length >= delta + page_size) {
size_t len_aligned = (length - delta) & (~(page_size - 1));
int ret = madvise(addr_aligned, len_aligned, MADV_DONTNEED);
if (ret != 0) {
LOG(ERROR) << "Advise free failed: " << strerror(errno);
}
}
}
} // namespace
namespace mace {
namespace ops {
void SGemm::operator()(const MatrixMap<const float> &lhs,
const MatrixMap<const float> &rhs,
MatrixMap<float> *result,
void SGemm::operator()(const SGemmMatrixMap<const float> &lhs,
const SGemmMatrixMap<const float> &rhs,
SGemmMatrixMap<float> *result,
ScratchBuffer *scratch_buffer) {
if (lhs.is_const() && !rhs.is_const()) {
MatrixMap<const float> lhs_transpose = lhs.transpose();
MatrixMap<const float> rhs_transpose = rhs.transpose();
MatrixMap<float> result_transpose = result->transpose();
SGemmMatrixMap<const float> lhs_transpose = lhs.transpose();
SGemmMatrixMap<const float> rhs_transpose = rhs.transpose();
SGemmMatrixMap<float> result_transpose = result->transpose();
return operator()(rhs_transpose,
lhs_transpose,
&result_transpose,
......@@ -150,18 +128,18 @@ void SGemm::Run(const float *A,
width_c = height_b;
}
MatrixMap<const float> matrix_a =
MatrixMap<const float>(batch,
SGemmMatrixMap<const float> matrix_a =
SGemmMatrixMap<const float>(batch,
height_a,
width_a,
ops::RowMajor,
ops::SGemmRowMajor,
A,
is_a_weight);
MatrixMap<const float> matrix_b =
ops::MatrixMap<const float>(batch,
SGemmMatrixMap<const float> matrix_b =
ops::SGemmMatrixMap<const float>(batch,
height_b,
width_b,
ops::RowMajor,
ops::SGemmRowMajor,
B,
is_b_weight);
if (transpose_a) {
......@@ -170,7 +148,8 @@ void SGemm::Run(const float *A,
if (transpose_b) {
matrix_b = matrix_b.transpose();
}
MatrixMap<float> matrix_c(batch, height_c, width_c, ops::RowMajor, C);
SGemmMatrixMap<float>
matrix_c(batch, height_c, width_c, ops::SGemmRowMajor, C);
operator()(matrix_a, matrix_b, &matrix_c, scratch_buffer);
}
......@@ -930,17 +909,17 @@ void SGemm::RunPerBatch(const float *lhs_data,
} // bw
}
void SGemm::PackLhs(const MatrixMap<const float> &lhs,
void SGemm::PackLhs(const SGemmMatrixMap<const float> &lhs,
PackedBlock *packed_block) {
Pack(lhs, PackOrder::ColMajor, packed_block);
Pack(lhs, PackOrder::SGemmColMajor, packed_block);
}
void SGemm::PackRhs(const MatrixMap<const float> &rhs,
void SGemm::PackRhs(const SGemmMatrixMap<const float> &rhs,
PackedBlock *packed_block) {
Pack(rhs, PackOrder::RowMajor, packed_block);
Pack(rhs, PackOrder::SGemmRowMajor, packed_block);
}
void SGemm::Pack(const MatrixMap<const float> &src,
void SGemm::Pack(const SGemmMatrixMap<const float> &src,
const PackOrder order,
PackedBlock *packed_block) {
MACE_CHECK_NOTNULL(packed_block);
......@@ -963,7 +942,7 @@ void SGemm::Pack(const MatrixMap<const float> &src,
}
void SGemm::UnPack(const PackedBlock &packed_result,
MatrixMap<float> *matrix_map) {
SGemmMatrixMap<float> *matrix_map) {
MACE_CHECK_NOTNULL(matrix_map);
const index_t height = matrix_map->row();
......@@ -984,7 +963,7 @@ void SGemm::UnPack(const PackedBlock &packed_result,
#undef MACE_SGEMM_UNPACK_PER_BATCH
}
void SGemm::PackPerBatch(const MatrixMap<const float> &src,
void SGemm::PackPerBatch(const SGemmMatrixMap<const float> &src,
const PackOrder order,
const index_t batch_index,
float *packed_data) {
......@@ -994,7 +973,8 @@ void SGemm::PackPerBatch(const MatrixMap<const float> &src,
const index_t width = src.col();
auto src_data = src.batch_data(batch_index);
if (src.map_major() == Major::RowMajor && order == PackOrder::ColMajor) {
if (src.map_major() == Major::SGemmRowMajor
&& order == PackOrder::SGemmColMajor) {
// This is for packing no-transpose lhs.
index_t h = 0;
#if defined(MACE_ENABLE_NEON)
......@@ -1040,8 +1020,8 @@ void SGemm::PackPerBatch(const MatrixMap<const float> &src,
for (index_t ih = h; ih < height; ++ih) {
std::copy_n(src_data + ih * width, width, packed_data + ih * width);
}
} else if (src.map_major() == Major::ColMajor &&
order == PackOrder::ColMajor) {
} else if (src.map_major() == Major::SGemmColMajor &&
order == PackOrder::SGemmColMajor) {
// This is for packing transpose-needed lhs.
index_t h = 0;
#if defined(MACE_ENABLE_NEON)
......@@ -1082,8 +1062,8 @@ void SGemm::PackPerBatch(const MatrixMap<const float> &src,
packed_data_ptr[w] = src_data_ptr[w * height];
}
}
} else if (src.map_major() == Major::RowMajor &&
order == PackOrder::RowMajor) {
} else if (src.map_major() == Major::SGemmRowMajor &&
order == PackOrder::SGemmRowMajor) {
// This is for packing no-transpose rhs.
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
......@@ -1108,8 +1088,8 @@ void SGemm::PackPerBatch(const MatrixMap<const float> &src,
packed_data_ptr[h] = src_data_ptr[h * width];
}
}
} else if (src.map_major() == Major::ColMajor &&
order == PackOrder::RowMajor) {
} else if (src.map_major() == Major::SGemmColMajor &&
order == PackOrder::SGemmRowMajor) {
// This is for packing transpose-needed rhs.
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
......@@ -1138,14 +1118,14 @@ void SGemm::PackPerBatch(const MatrixMap<const float> &src,
void SGemm::UnPackPerBatch(const float *packed_data,
const index_t batch_index,
MatrixMap<float> *matrix_map) {
SGemmMatrixMap<float> *matrix_map) {
MACE_CHECK_NOTNULL(matrix_map);
const index_t height = matrix_map->row();
const index_t width = matrix_map->col();
auto unpacked_data = matrix_map->batch_data(batch_index);
if (matrix_map->map_major() == Major::RowMajor) {
if (matrix_map->map_major() == Major::SGemmRowMajor) {
// This is for non-transposed result
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
......
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// This implementation is deprecated. Use mace/ops/arm/fp32/gemm.h instead.
#ifndef MACE_OPS_SGEMM_H_
#define MACE_OPS_SGEMM_H_
......@@ -30,16 +32,16 @@ namespace mace {
namespace ops {
enum Major {
RowMajor,
ColMajor
SGemmRowMajor,
SGemmColMajor
};
template<typename T>
class MatrixMap {
class SGemmMatrixMap {
public:
MatrixMap() {}
SGemmMatrixMap() {}
MatrixMap(const index_t batch,
SGemmMatrixMap(const index_t batch,
const index_t row,
const index_t col,
const Major major,
......@@ -48,14 +50,20 @@ class MatrixMap {
batch_(batch),
row_(row),
col_(col),
stride_(major == RowMajor ? col : row),
stride_(major == SGemmRowMajor ? col : row),
major_(major),
data_(data),
is_const_(is_const) {}
MatrixMap transpose() const {
Major transpose_major = major_ == RowMajor ? ColMajor : RowMajor;
return MatrixMap(batch_, col_, row_, transpose_major, data_, is_const_);
SGemmMatrixMap transpose() const {
Major transpose_major =
major_ == SGemmRowMajor ? SGemmColMajor : SGemmRowMajor;
return SGemmMatrixMap(batch_,
col_,
row_,
transpose_major,
data_,
is_const_);
}
index_t batch() const {
......@@ -114,9 +122,9 @@ class SGemm {
packed_rhs_(nullptr),
packed_(false) {}
void operator()(const MatrixMap<const float> &lhs,
const MatrixMap<const float> &rhs,
MatrixMap<float> *result,
void operator()(const SGemmMatrixMap<const float> &lhs,
const SGemmMatrixMap<const float> &rhs,
SGemmMatrixMap<float> *result,
ScratchBuffer *scratch_buffer = nullptr);
void Run(const float *A,
......@@ -133,28 +141,28 @@ class SGemm {
float *C,
ScratchBuffer *scratch_buffer = nullptr);
void PackLhs(const MatrixMap<const float> &lhs,
void PackLhs(const SGemmMatrixMap<const float> &lhs,
PackedBlock *packed_block);
void PackRhs(const MatrixMap<const float> &rhs,
void PackRhs(const SGemmMatrixMap<const float> &rhs,
PackedBlock *packed_block);
void UnPack(const PackedBlock &packed_result,
MatrixMap<float> *matrix_map);
SGemmMatrixMap<float> *matrix_map);
private:
void Pack(const MatrixMap<const float> &src,
void Pack(const SGemmMatrixMap<const float> &src,
const PackOrder order,
PackedBlock *packed_block);
void PackPerBatch(const MatrixMap<const float> &src,
void PackPerBatch(const SGemmMatrixMap<const float> &src,
const PackOrder order,
const index_t batch_index,
float *packed_data);
void UnPackPerBatch(const float *packed_data,
const index_t batch_index,
MatrixMap<float> *matrix_map);
SGemmMatrixMap<float> *matrix_map);
void RunInternal(const PackedBlock &lhs,
const PackedBlock &rhs,
......
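For context, a minimal usage sketch of the renamed types. It only mirrors the constructor and operator() signatures visible in this diff; the header path is assumed from the include guard (mace/ops/sgemm.h), and the dimension roles (lhs batch x M x K, rhs batch x K x N, result batch x M x N) follow the usual GEMM convention rather than anything the diff spells out.

#include <vector>
#include "mace/ops/sgemm.h"

void SGemmUsageSketch() {
  using namespace mace::ops;
  // Map existing buffers without copying: 2x3 lhs, 3x4 rhs, 2x4 result.
  std::vector<float> a(2 * 3, 1.0f), b(3 * 4, 1.0f), c(2 * 4, 0.0f);
  SGemmMatrixMap<const float> lhs(1, 2, 3, Major::SGemmRowMajor, a.data());
  SGemmMatrixMap<const float> rhs(1, 3, 4, Major::SGemmRowMajor, b.data());
  SGemmMatrixMap<float> result(1, 2, 4, Major::SGemmRowMajor, c.data());

  SGemm sgemm;
  sgemm(lhs, rhs, &result);  // scratch_buffer defaults to nullptr

  // transpose() only swaps the logical view: a 3x2 SGemmColMajor map over
  // the same lhs data, no element is moved.
  SGemmMatrixMap<const float> lhs_t = lhs.transpose();
  (void)lhs_t;
}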
......@@ -31,10 +31,11 @@ void TestPack(const std::vector<float> &data,
Major src_order,
PackOrder pack_order) {
SGemm sg;
MatrixMap<const float> src_matrix(1, height, width, src_order, data.data());
SGemmMatrixMap<const float>
src_matrix(1, height, width, src_order, data.data());
PackedBlock packed;
packed.Resize({height, width});
if (pack_order == PackOrder::ColMajor) {
if (pack_order == PackOrder::SGemmColMajor) {
sg.PackLhs(src_matrix, &packed);
} else {
sg.PackRhs(src_matrix, &packed);
......@@ -57,18 +58,19 @@ void TestUnPack(const index_t height,
data[i] = rand_r(&seed);
}
MatrixMap<const float> src_matrix(1, height, width, src_order, data.data());
SGemmMatrixMap<const float>
src_matrix(1, height, width, src_order, data.data());
PackedBlock packed;
packed.Resize({height, width});
SGemm sg;
if (pack_order == PackOrder::ColMajor) {
if (pack_order == PackOrder::SGemmColMajor) {
sg.PackLhs(src_matrix, &packed);
} else {
sg.PackRhs(src_matrix, &packed);
}
std::vector<float> unpacked(matrix_size);
MatrixMap<float>
SGemmMatrixMap<float>
unpacked_matrix(1, height, width, src_order, unpacked.data());
sg.UnPack(packed, &unpacked_matrix);
auto unpacked_data = unpacked.data();
......@@ -87,78 +89,78 @@ TEST(SGemmPackTest, Pack) {
// For no-transpose lhs
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
3, 4, Major::RowMajor, PackOrder::ColMajor);
3, 4, Major::SGemmRowMajor, PackOrder::SGemmColMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16},
4, 4, Major::RowMajor, PackOrder::ColMajor);
4, 4, Major::SGemmRowMajor, PackOrder::SGemmColMajor);
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16, 17, 18, 19,
20},
5, 4, Major::RowMajor, PackOrder::ColMajor);
5, 4, Major::SGemmRowMajor, PackOrder::SGemmColMajor);
#if defined(__aarch64__)
TestPack(data,
{1, 5, 9, 13, 17, 21, 25, 29, 2, 6, 10, 14, 18, 22, 26, 30, 3, 7, 11,
15, 19, 23, 27, 31, 4, 8, 12, 16, 20, 24, 28, 32, 33, 34, 35, 36},
9, 4, Major::RowMajor, PackOrder::ColMajor);
9, 4, Major::SGemmRowMajor, PackOrder::SGemmColMajor);
#endif
#endif
// For transpose-needed lhs
TestPack(data,
{1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12},
3, 4, Major::ColMajor, PackOrder::ColMajor);
3, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
4, 4, Major::ColMajor, PackOrder::ColMajor);
4, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor);
TestPack(data,
{1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 5, 10, 15,
20},
5, 4, Major::ColMajor, PackOrder::ColMajor);
5, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor);
#if defined(__aarch64__)
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21,
22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 9, 18, 27, 36},
9, 4, Major::ColMajor, PackOrder::ColMajor);
9, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor);
#endif
#endif
// For no-transpose rhs
TestPack(data,
{1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12},
4, 3, Major::RowMajor, PackOrder::RowMajor);
4, 3, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
4, 4, Major::RowMajor, PackOrder::RowMajor);
4, 4, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
TestPack(data,
{1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 5, 10, 15,
20},
4, 5, Major::RowMajor, PackOrder::RowMajor);
4, 5, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
#endif
// For transpose-needed rhs
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
4, 3, Major::ColMajor, PackOrder::RowMajor);
4, 3, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16},
4, 4, Major::ColMajor, PackOrder::RowMajor);
4, 4, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16, 17, 18, 19,
20},
4, 5, Major::ColMajor, PackOrder::RowMajor);
4, 5, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
#endif
}
TEST(SGemmPackTest, UnPack) {
TestUnPack(4, 3, Major::RowMajor, PackOrder::RowMajor);
TestUnPack(4, 4, Major::RowMajor, PackOrder::RowMajor);
TestUnPack(4, 5, Major::RowMajor, PackOrder::RowMajor);
TestUnPack(4, 100, Major::RowMajor, PackOrder::RowMajor);
TestUnPack(4, 3, Major::ColMajor, PackOrder::RowMajor);
TestUnPack(4, 4, Major::ColMajor, PackOrder::RowMajor);
TestUnPack(4, 5, Major::ColMajor, PackOrder::RowMajor);
TestUnPack(4, 100, Major::ColMajor, PackOrder::RowMajor);
TestUnPack(4, 3, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 4, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 5, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 100, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 3, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 4, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 5, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 100, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
}
} // namespace test
......
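The expected vectors in TestPack above follow from the NEON pack layout: four consecutive rows are interleaved column by column and any leftover rows are copied unchanged (aarch64 additionally packs 8-row blocks, as the 9x4 cases show). A small sketch that reproduces the 4x4 and 5x4 lhs expectations, assuming a fixed block height of 4; InterleaveRows4 is an illustrative helper, not MACE code.

#include <cstdint>
#include <vector>

using index_t = int64_t;

// Interleave 4 consecutive rows of a row-major matrix column by column,
// copying leftover rows unchanged. For {1..16} with height=4, width=4 this
// yields {1,5,9,13, 2,6,10,14, 3,7,11,15, 4,8,12,16}, matching the test.
std::vector<float> InterleaveRows4(const std::vector<float> &src,
                                   index_t height, index_t width) {
  const index_t block = 4;
  std::vector<float> packed(src.size());
  index_t out = 0;
  index_t h = 0;
  for (; h + block <= height; h += block) {
    for (index_t w = 0; w < width; ++w) {
      for (index_t b = 0; b < block; ++b) {
        packed[out++] = src[(h + b) * width + w];
      }
    }
  }
  for (; h < height; ++h) {  // leftover rows, as in the scalar tail
    for (index_t w = 0; w < width; ++w) {
      packed[out++] = src[h * width + w];
    }
  }
  return packed;
}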