Commit 1be8257c authored by 李滨

Merge branch 'opt_gemm' into 'master'

Optimize gemm

See merge request !980
@@ -14,6 +14,10 @@
#include "mace/core/allocator.h"
#include <unistd.h>
#include <sys/mman.h>
#include <memory>
namespace mace {
Allocator *GetCPUAllocator() {
@@ -21,4 +25,22 @@ Allocator *GetCPUAllocator() {
return &allocator;
}
void AdviseFree(void *addr, size_t length) {
int page_size = sysconf(_SC_PAGESIZE);
void *addr_aligned =
reinterpret_cast<void *>(
(reinterpret_cast<uintptr_t>(addr) + page_size - 1)
& (~(page_size - 1)));
uintptr_t delta =
reinterpret_cast<uintptr_t>(addr_aligned)
- reinterpret_cast<uintptr_t>(addr);
if (length >= delta + page_size) {
size_t len_aligned = (length - delta) & (~(page_size - 1));
int ret = madvise(addr_aligned, len_aligned, MADV_DONTNEED);
if (ret != 0) {
LOG(ERROR) << "Advise free failed: " << strerror(errno);
}
}
}
}  // namespace mace
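AdviseFree returns memory to the OS at page granularity: the start address is rounded up to the next page boundary, the usable length is rounded down, and madvise(MADV_DONTNEED) is issued only if at least one whole page fits inside [addr, addr + length). A minimal standalone sketch of that arithmetic, assuming a fixed 4 KiB page size (illustrative only, not part of this merge request):

#include <cstddef>
#include <cstdint>
#include <utility>

constexpr std::size_t kPageSize = 4096;  // assumed page size for the example

// Returns {first whole page inside the range, number of advisable bytes},
// mirroring the rounding done by AdviseFree above.
inline std::pair<std::uintptr_t, std::size_t> WholePages(std::uintptr_t addr,
                                                         std::size_t length) {
  const std::uintptr_t first_page = (addr + kPageSize - 1) & ~(kPageSize - 1);
  const std::size_t delta = first_page - addr;
  if (length < delta + kPageSize) return {first_page, 0};  // no full page fits
  return {first_page, (length - delta) & ~(kPageSize - 1)};
}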
@@ -40,6 +40,10 @@ constexpr size_t kMaceAlignment = 64;
constexpr size_t kMaceAlignment = 32;
#endif
inline index_t PadAlignSize(index_t size) {
return (size + kMaceAlignment - 1) & (~(kMaceAlignment - 1));
}
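PadAlignSize rounds a byte count up to the next multiple of kMaceAlignment so that per-tensor scratch requests can be summed while keeping every sub-buffer aligned. For a 64-byte alignment build the rounding behaves as below (illustrative checks only, not part of the diff):

static_assert(((1 + 64 - 1) & ~(64 - 1)) == 64, "PadAlignSize(1) == 64");
static_assert(((100 + 64 - 1) & ~(64 - 1)) == 128, "PadAlignSize(100) == 128");
static_assert(((128 + 64 - 1) & ~(64 - 1)) == 128, "PadAlignSize(128) == 128");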
class Allocator {
public:
Allocator() {}
@@ -140,6 +144,8 @@ class CPUAllocator : public Allocator {
// Global CPU allocator used for CPU/GPU/DSP
Allocator *GetCPUAllocator();
void AdviseFree(void *addr, size_t length);
}  // namespace mace
#endif  // MACE_CORE_ALLOCATOR_H_
@@ -384,7 +384,7 @@ class BufferSlice : public BufferBase {
BufferSlice(const BufferSlice &other)
: BufferSlice(other.buffer_, other.offset_, other.size_) {}
-~BufferSlice() {
+virtual ~BufferSlice() {
if (buffer_ != nullptr && mapped_buf_ != nullptr) {
UnMap();
}
@@ -506,7 +506,7 @@ class ScratchBuffer: public Buffer {
virtual ~ScratchBuffer() {}
-MaceStatus GrowSize(index_t size) {
+MaceStatus GrowSize(const index_t size) {
if (size > size_) {
VLOG(1) << "Grow scratch size to: " << size;
MACE_CHECK(offset_ == 0, "scratch is being used, cannot grow size");
......
@@ -25,11 +25,11 @@ void OpContext::set_device(Device *device) {
device_ = device;
}
-Device* OpContext::device() {
+Device* OpContext::device() const {
return device_;
}
-Workspace* OpContext::workspace() {
+Workspace* OpContext::workspace() const {
return ws_;
}
@@ -37,7 +37,7 @@ void OpContext::set_future(StatsFuture *future) {
future_ = future;
}
-StatsFuture *OpContext::future() {
+StatsFuture *OpContext::future() const {
return future_;
}
......
@@ -26,11 +26,11 @@ class OpContext {
OpContext(Workspace *ws, Device *device);
~OpContext();
void set_device(Device *device);
-Device *device();
+Device *device() const;
-Workspace *workspace();
+Workspace *workspace() const;
void set_future(StatsFuture *future);
-StatsFuture *future();
+StatsFuture *future() const;
private:
Device *device_;
Workspace *ws_;
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_FP32_CONV_2D_H_
#define MACE_OPS_ARM_FP32_CONV_2D_H_
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/gemm.h"
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
class Conv2dBase {
public:
Conv2dBase() = default;
virtual ~Conv2dBase() = default;
virtual MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output) = 0;
};
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_FP32_CONV_2D_H_
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/arm/fp32/conv_2d_1x1.h"
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
MaceStatus Conv2dK1x1::Compute(const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output) {
index_t batch = input->dim(0);
index_t height = input->dim(2);
index_t width = input->dim(3);
index_t in_channels = input->dim(1);
index_t out_channels = filter->dim(0);
MACE_RETURN_IF_ERROR(output->Resize({batch, out_channels, height, width}));
context->device()->scratch_buffer()->Rewind();
return gemm_.Compute(context,
filter,
input,
batch,
out_channels,
in_channels,
in_channels,
height * width,
false,
false,
false,
false,
true,
output);
}
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
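With a 1x1 kernel and matching input/output spatial size, the convolution above is a plain matrix product: the filter is used as the unbatched left matrix of shape [out_channels, in_channels] and each batch's input as the right matrix of shape [in_channels, height * width]; the last two flags in the gemm_.Compute call mark the filter as shared across the batch and the input as batched. A naive reference of the same computation (illustrative only, not the optimized path):

// Assumes NCHW float buffers:
// output[b][oc][i] = sum_ic filter[oc][ic] * input[b][ic][i].
void Conv1x1Reference(const float *input, const float *filter, float *output,
                      int batch, int in_channels, int out_channels, int hw) {
  for (int b = 0; b < batch; ++b) {
    for (int oc = 0; oc < out_channels; ++oc) {
      for (int i = 0; i < hw; ++i) {
        float acc = 0.f;
        for (int ic = 0; ic < in_channels; ++ic) {
          acc += filter[oc * in_channels + ic] *
                 input[(b * in_channels + ic) * hw + i];
        }
        output[(b * out_channels + oc) * hw + i] = acc;
      }
    }
  }
}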
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_FP32_CONV_2D_1X1_H_
#define MACE_OPS_ARM_FP32_CONV_2D_1X1_H_
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/gemm.h"
#include "mace/ops/arm/fp32/conv_2d.h"
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
class Conv2dK1x1 : public Conv2dBase {
public:
Conv2dK1x1() : gemm_(true) {}
virtual ~Conv2dK1x1() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
private:
Gemm gemm_;
};
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_FP32_CONV_2D_1X1_H_
This diff is collapsed.
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_FP32_GEMM_H_
#define MACE_OPS_ARM_FP32_GEMM_H_
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/common/matrix.h"
// This implements matrix-matrix multiplication.
// In the case of matrix-vector multiplication, use gemv.h/gemv.cc instead
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
class Gemm {
public:
explicit Gemm(const bool should_cache_pack)
: tmp_scratch_buffer_(GetCPUAllocator()),
pack_cache_(GetCPUAllocator()),
should_cache_pack_(should_cache_pack),
cached_(0) {}
Gemm() : Gemm(false) {}
~Gemm() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t rows,
const index_t cols,
const index_t depth,
const MatrixMajor lhs_major,
const MatrixMajor rhs_major,
const MatrixMajor output_major,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
// Original matrix before transpose has row-major
MaceStatus Compute(
const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t lhs_rows,
const index_t lhs_cols,
const index_t rhs_rows,
const index_t rhs_cols,
const bool transpose_lhs,
const bool transpose_rhs,
const bool transpose_out,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
private:
void ComputeBlock(const float *packed_lhs_data,
const float *packed_rhs_data,
const index_t depth_padded,
float *packed_output_data);
void PackLhs(const MatrixMap<const float> &lhs,
float *packed_lhs);
void PackRhs(const MatrixMap<const float> &rhs,
float *packed_rhs);
void UnpackOutput(const float *packed_output,
MatrixMap<float> *output);
template<int RowBlockSize, int ColBlockSize>
void Unpack(const float *packed_output,
MatrixMap<float> *output) {
const index_t rows = output->rows();
const index_t cols = output->cols();
for (index_t r = 0; r < rows; ++r) {
for (index_t c = 0; c < cols; ++c) {
*output->data(r, c) = packed_output[r * ColBlockSize + c];
}
}
}
template<int WidthBlockSize, int DepthBlockSize>
void Pack(const MatrixMap<const float> &matrix,
MatrixMajor dst_major,
float *packed_matrix) {
const index_t rows = matrix.rows();
const index_t cols = matrix.cols();
index_t depth = cols;
if (dst_major == RowMajor) {
// rhs
depth = rows;
}
const index_t depth_padded = RoundUp(depth, static_cast<index_t>(4));
memset(packed_matrix, 0, sizeof(float) * WidthBlockSize * depth_padded);
if (dst_major == ColMajor) {
for (index_t c = 0; c < cols; ++c) {
for (index_t r = 0; r < rows; ++r) {
packed_matrix[c * WidthBlockSize + r] = matrix(r, c);
}
}
} else {
for (index_t r = 0; r < rows; ++r) {
for (index_t c = 0; c < cols; ++c) {
packed_matrix[r * WidthBlockSize + c] = matrix(r, c);
}
}
}
}
ScratchBuffer tmp_scratch_buffer_;
Buffer pack_cache_;
bool should_cache_pack_;
int cached_;
};
template<>
void Gemm::Pack<4, 4>(const MatrixMap<const float> &matrix,
MatrixMajor dst_major,
float *packed_matrix);
template<>
void Gemm::Pack<8, 4>(const MatrixMap<const float> &matrix,
MatrixMajor dst_major,
float *packed_matrix);
template<>
void Gemm::Unpack<4, 8>(const float *packed_output, MatrixMap<float> *output);
template<>
void Gemm::Unpack<8, 8>(const float *packed_output, MatrixMap<float> *output);
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_FP32_GEMM_H_
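As a reading aid, the first Compute overload amounts to a batched matrix product; a naive reference of its semantics for the row-major, both-sides-batched case is sketched below (illustrative only; the real kernel packs 4x4/8x4 blocks into scratch memory and runs NEON micro-kernels):

// out[b, r, c] = sum_d lhs[b, r, d] * rhs[b, d, c], all buffers row-major.
void NaiveGemm(const float *lhs, const float *rhs, float *out,
               int batch, int rows, int cols, int depth) {
  for (int b = 0; b < batch; ++b) {
    for (int r = 0; r < rows; ++r) {
      for (int c = 0; c < cols; ++c) {
        float acc = 0.f;
        for (int d = 0; d < depth; ++d) {
          acc += lhs[(b * rows + r) * depth + d] *
                 rhs[(b * depth + d) * cols + c];
        }
        out[(b * rows + r) * cols + c] = acc;
      }
    }
  }
}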
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/gemm.h"
#include "mace/ops/ref/gemm.h"
#include "mace/ops/testing/test_utils.h"
namespace mace {
namespace ops {
namespace test {
void TestGemmFloat32(const index_t batch,
const index_t rows,
const index_t cols,
const index_t depth,
const MatrixMajor lhs_major,
const MatrixMajor rhs_major,
const MatrixMajor output_major,
const bool lhs_batched,
const bool rhs_batched) {
Tensor lhs(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor rhs(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor output(GetCPUAllocator(), DataType::DT_FLOAT);
lhs.Resize({lhs_batched ? batch : 1, rows, depth});
rhs.Resize({rhs_batched ? batch : 1, depth, cols});
output.Resize({batch, rows, cols});
{
Tensor::MappingGuard lhs_guard(&lhs);
Tensor::MappingGuard rhs_guard(&rhs);
float *lhs_data = lhs.mutable_data<float>();
float *rhs_data = rhs.mutable_data<float>();
float *output_data = output.mutable_data<float>();
GenerateRandomRealTypeData<float>(lhs.shape(), lhs_data);
GenerateRandomRealTypeData<float>(rhs.shape(), rhs_data);
GenerateRandomRealTypeData<float>(output.shape(), output_data);
}
::mace::ops::arm::fp32::Gemm gemm;
gemm.Compute(nullptr,
&lhs,
&rhs,
batch,
rows,
cols,
depth,
lhs_major,
rhs_major,
output_major,
lhs_batched,
rhs_batched,
&output);
Tensor expected_output(GetCPUAllocator(), DataType::DT_FLOAT);
expected_output.Resize({batch, rows, cols});
::mace::ops::ref::Gemm<float> gemm_ref;
gemm_ref.Compute(nullptr,
&lhs,
&rhs,
batch,
rows,
cols,
depth,
lhs_major,
rhs_major,
output_major,
lhs_batched,
rhs_batched,
&expected_output);
ExpectTensorNear<float>(expected_output, output);
}
TEST(ArmGemm, TestGemmFloat32) {
TestGemmFloat32(1, 47, 69, 37, RowMajor, RowMajor, RowMajor, true, true);
TestGemmFloat32(1, 47, 69, 37, RowMajor, RowMajor, ColMajor, true, true);
TestGemmFloat32(1, 47, 69, 37, RowMajor, ColMajor, RowMajor, true, true);
TestGemmFloat32(1, 47, 69, 37, RowMajor, ColMajor, ColMajor, true, true);
TestGemmFloat32(1, 47, 69, 37, ColMajor, RowMajor, RowMajor, true, true);
TestGemmFloat32(1, 47, 69, 37, ColMajor, RowMajor, ColMajor, true, true);
TestGemmFloat32(1, 47, 69, 37, ColMajor, ColMajor, RowMajor, true, true);
TestGemmFloat32(1, 47, 69, 37, ColMajor, ColMajor, ColMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, RowMajor, RowMajor, RowMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, RowMajor, RowMajor, ColMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, RowMajor, ColMajor, RowMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, RowMajor, ColMajor, ColMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, ColMajor, RowMajor, RowMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, ColMajor, RowMajor, ColMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, ColMajor, ColMajor, RowMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, ColMajor, ColMajor, ColMajor, true, true);
TestGemmFloat32(3, 47, 69, 37, RowMajor, RowMajor, RowMajor, true, false);
TestGemmFloat32(3, 47, 69, 37, RowMajor, RowMajor, RowMajor, false, true);
TestGemmFloat32(16, 31, 61, 67, RowMajor, ColMajor, RowMajor, true, true);
}
} // namespace test
} // namespace ops
} // namespace mace
@@ -22,6 +22,9 @@
#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#endif
// Disable unroll by default, since cache set conflict could be significant
// #define MACE_GEMV_UNROLL 1
namespace mace {
namespace ops {
namespace arm {
@@ -35,14 +38,26 @@ MaceStatus Gemv::Compute(const OpContext *context,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
MACE_UNUSED(context);
MACE_CHECK(output->size() == batch * lhs_height,
"Need resize output tensor before call gemv.");
Tensor::MappingGuard lhs_guard(lhs);
Tensor::MappingGuard rhs_guard(rhs);
Tensor::MappingGuard bias_guard(bias);
Tensor::MappingGuard output_guard(output);
const float *lhs_data = lhs->data<float>();
const float *rhs_data = rhs->data<float>();
const float *bias_data = nullptr;
if (bias) {
bias_data = bias->data<float>();
}
float *output_data = output->mutable_data<float>();
const index_t h_block_size = 4;
const index_t h_block_count = RoundUpDiv(lhs_height, h_block_size);
const index_t w_block_size = 8;
@@ -52,28 +67,20 @@ MaceStatus Gemv::Compute(const OpContext *context,
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < batch; ++b) {
for (index_t h_block_idx = 0; h_block_idx < h_block_count; ++h_block_idx) {
-// TODO(liyin): it can be put it outside the loop,
-// but openmp limits param count
-const float *lhs_data = lhs->data<float>();
-const float *rhs_data = rhs->data<float>();
-const float *bias_data = nullptr;
-if (bias) {
-bias_data = bias->data<float>();
-}
-float *output_data = output->mutable_data<float>();
+const index_t h_start = h_block_idx * h_block_size;
const float
*lhs_ptr = lhs_data
+ static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
-+ lhs_width * h_block_idx * h_block_size;
++ lhs_width * h_start;
-const float *rhs_ptr = rhs_data + b * lhs_width;
+const float *rhs_ptr =
rhs_data + static_cast<index_t>(rhs_batched) * b * lhs_width;
float
-*ret_ptr = output_data + b * lhs_height + h_block_idx * h_block_size;
+*ret_ptr = output_data + b * lhs_height + h_start;
const index_t h_block_len =
-std::min(h_block_size, lhs_height - h_block_idx * h_block_size);
+std::min(h_block_size, lhs_height - h_start);
-const index_t h_offset = h_block_idx * h_block_size;
#ifdef MACE_GEMV_UNROLL
if (h_block_len == 4) {
float32x4_t vo0 = vdupq_n_f32(0);
float32x4_t vo1 = vdupq_n_f32(0);
@@ -149,6 +156,11 @@ MaceStatus Gemv::Compute(const OpContext *context,
"vmla.f32 %q[vo2], q4, q8\n"
"vmla.f32 %q[vo3], q6, q8\n"
"vld1.f32 {d0-d1}, [r1]!\n"
"vld1.f32 {d4-d5}, [r2]!\n"
"vld1.f32 {d8-d9}, [r3]!\n"
"vld1.f32 {d12-d13}, [r4]!\n"
"vld1.f32 {d16-d17}, [r0]!\n"
"vmla.f32 %q[vo0], q1, q9\n"
"vmla.f32 %q[vo1], q3, q9\n"
@@ -157,13 +169,6 @@ MaceStatus Gemv::Compute(const OpContext *context,
"subs %[r_w_block_count], #1\n"
-"vld1.f32 {d0-d1}, [r1]!\n"
-"vld1.f32 {d4-d5}, [r2]!\n"
-"vld1.f32 {d8-d9}, [r3]!\n"
-"vld1.f32 {d12-d13}, [r4]!\n"
-"vld1.f32 {d16-d17}, [r0]!\n"
"vld1.f32 {d2-d3}, [r1]!\n"
"vld1.f32 {d6-d7}, [r2]!\n"
"vld1.f32 {d10-d11}, [r3]!\n"
@@ -257,26 +262,30 @@ MaceStatus Gemv::Compute(const OpContext *context,
vo = vaddq_f32(vo, vbias);
vst1q_f32(ret_ptr, vo);
} else {  // h_block_len < 4
-// TODO(liyin): handle here case by case (1,2,3) to accelerate
+#endif  // MACE_GEMV_UNROLL
const float *tmp_lhs_ptr = lhs_ptr;
const float *tmp_rhs_ptr = rhs_ptr;
for (index_t h = 0; h < h_block_len; ++h) {
lhs_ptr = tmp_lhs_ptr + h * lhs_width;
rhs_ptr = tmp_rhs_ptr;
float32x4_t vo0 = vdupq_n_f32(0);
float32x4_t vo0n = vdupq_n_f32(0);
for (index_t w = 0; w < w_block_count; ++w) {
float32x4_t vr0 = vld1q_f32(rhs_ptr);
float32x4_t vr0n = vld1q_f32(rhs_ptr + 4);
float32x4_t vl0 = vld1q_f32(lhs_ptr);
float32x4_t vl0n = vld1q_f32(lhs_ptr + 4);
// may cause some precision error depending on the compute order
vo0 = vmlaq_f32(vo0, vl0, vr0);
-vo0 = vmlaq_f32(vo0, vl0n, vr0n);
+vo0n = vmlaq_f32(vo0n, vl0n, vr0n);
lhs_ptr += 8;
rhs_ptr += 8;
}  // w
-float s0 = vaddvq_f32(vo0) + (bias ? bias_data[h_offset + h] : 0);
+vo0 = vaddq_f32(vo0, vo0n);
float s0 = vaddvq_f32(vo0) + (bias ? bias_data[h_start + h] : 0);
for (index_t w = 0; w < w_remain; ++w) {
s0 += lhs_ptr[0] * rhs_ptr[0];
++lhs_ptr;
@@ -285,7 +294,9 @@ MaceStatus Gemv::Compute(const OpContext *context,
ret_ptr[h] = s0;
}  // h
#ifdef MACE_GEMV_UNROLL
}  // if
#endif  // MACE_GEMV_UNROLL
}  // h_block_idx
}  // b
......
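The reworked tail path keeps two NEON accumulators (vo0 and vo0n) and merges them with a single vaddq_f32 after the w loop instead of folding every product into one register; as the in-code comment notes, regrouping the additions this way can change the result in the last bits. A scalar sketch of the two-accumulator ordering (illustrative only):

float DotTwoAccumulators(const float *a, const float *b, int n) {
  float acc0 = 0.f, acc1 = 0.f;            // mirrors vo0 / vo0n
  for (int i = 0; i + 1 < n; i += 2) {
    acc0 += a[i] * b[i];
    acc1 += a[i + 1] * b[i + 1];
  }
  if (n & 1) acc0 += a[n - 1] * b[n - 1];  // scalar remainder
  return acc0 + acc1;                      // the final vaddq_f32-style merge
}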
@@ -38,6 +38,7 @@ class Gemv {
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
......
@@ -28,13 +28,14 @@ namespace test {
void TestGemvFloat32(const index_t batch,
const index_t height,
const index_t width,
-const bool lhs_batched) {
+const bool lhs_batched,
const bool rhs_batched) {
Tensor lhs(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor rhs(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor bias(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor output(GetCPUAllocator(), DataType::DT_FLOAT);
lhs.Resize({lhs_batched ? batch : 1, height, width});
-rhs.Resize({batch, width});
+rhs.Resize({rhs_batched ? batch : 1, width});
bias.Resize({height});
output.Resize({batch, height});
{
@@ -57,6 +58,7 @@ void TestGemvFloat32(const index_t batch,
height,
width,
lhs_batched,
rhs_batched,
&output);
Tensor expected_output(GetCPUAllocator(), DataType::DT_FLOAT);
@@ -70,28 +72,22 @@ void TestGemvFloat32(const index_t batch,
height,
width,
lhs_batched,
rhs_batched,
&expected_output);
-Tensor::MappingGuard output_guard(&output);
-Tensor::MappingGuard expected_guard(&expected_output);
-const float *output_data = output.data<float>();
-const float *expected_data = expected_output.data<float>();
-for (index_t i = 0; i < output.size(); ++i) {
-EXPECT_NEAR(expected_data[i], output_data[i], 0.001);
-}
+ExpectTensorNear<float>(expected_output, output);
}
TEST(ArmGemv, TestGemvFloat32) {
-TestGemvFloat32(1, 16, 4, true);
+TestGemvFloat32(1, 16, 4, true, true);
-TestGemvFloat32(1, 16, 256, true);
+TestGemvFloat32(1, 16, 256, true, true);
-TestGemvFloat32(2, 16, 256, true);
+TestGemvFloat32(2, 16, 256, true, true);
-TestGemvFloat32(3, 63, 257, true);
+TestGemvFloat32(3, 63, 257, true, true);
-TestGemvFloat32(1, 16, 4, false);
+TestGemvFloat32(2, 16, 256, false, true);
-TestGemvFloat32(1, 16, 256, false);
+TestGemvFloat32(3, 63, 257, false, true);
-TestGemvFloat32(2, 16, 256, false);
+TestGemvFloat32(2, 16, 256, true, false);
-TestGemvFloat32(3, 63, 257, false);
+TestGemvFloat32(3, 63, 257, true, false);
}
}  // namespace test
......
@@ -43,6 +43,7 @@ MaceStatus Gemv<OUTPUT_TYPE>::Compute(const OpContext *context,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
MACE_UNUSED(context);
@@ -100,7 +101,8 @@ MaceStatus Gemv<OUTPUT_TYPE>::Compute(const OpContext *context,
*lhs_ptr = lhs_data
+ static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ lhs_width * h_block_idx * h_block_size;
-const uint8_t *rhs_ptr = rhs_data + b * lhs_width;
+const uint8_t *rhs_ptr =
rhs_data + static_cast<index_t>(rhs_batched) * b * lhs_width;
OUTPUT_TYPE
*ret_ptr = output_data + b * lhs_height + h_block_idx * h_block_size;
......
@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// This implements matrix-vector multiplication described as
// https://github.com/google/gemmlowp/blob/master/todo/fast-gemv.txt
#ifndef MACE_OPS_ARM_Q8_GEMV_H_
#define MACE_OPS_ARM_Q8_GEMV_H_
@@ -39,6 +42,7 @@ class Gemv {
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
......
@@ -28,7 +28,8 @@ namespace test {
void TestGemvInt32(const index_t batch,
const index_t height,
const index_t width,
-const bool lhs_batched) {
+const bool lhs_batched,
const bool rhs_batched) {
Tensor lhs(GetCPUAllocator(), DataType::DT_UINT8);
Tensor rhs(GetCPUAllocator(), DataType::DT_UINT8);
Tensor bias(GetCPUAllocator(), DataType::DT_INT32);
@@ -38,7 +39,7 @@ void TestGemvInt32(const index_t batch,
lhs.SetZeroPoint(0);
rhs.SetZeroPoint(0);
lhs.Resize({lhs_batched ? batch : 1, height, width});
-rhs.Resize({batch, width});
+rhs.Resize({rhs_batched ? batch : 1, batch, width});
bias.Resize({height});
output.Resize({batch, height});
{
@@ -62,6 +63,7 @@ void TestGemvInt32(const index_t batch,
height,
width,
lhs_batched,
rhs_batched,
&output);
Tensor expected_output(GetCPUAllocator(), DataType::DT_INT32);
@@ -75,6 +77,7 @@ void TestGemvInt32(const index_t batch,
height,
width,
lhs_batched,
rhs_batched,
&expected_output);
Tensor::MappingGuard output_guard(&output);
@@ -90,7 +93,8 @@ void TestGemvInt32(const index_t batch,
void TestGemvUint8(const index_t batch,
const index_t height,
const index_t width,
-const bool lhs_batched) {
+const bool lhs_batched,
const bool rhs_batched) {
Tensor lhs(GetCPUAllocator(), DataType::DT_UINT8);
Tensor rhs(GetCPUAllocator(), DataType::DT_UINT8);
Tensor bias(GetCPUAllocator(), DataType::DT_INT32);
@@ -127,6 +131,7 @@ void TestGemvUint8(const index_t batch,
height,
width,
lhs_batched,
rhs_batched,
&output);
Tensor expected_output(GetCPUAllocator(), DataType::DT_INT32);
@@ -142,6 +147,7 @@ void TestGemvUint8(const index_t batch,
height,
width,
lhs_batched,
rhs_batched,
&expected_output);
Tensor::MappingGuard output_guard(&output);
@@ -155,27 +161,27 @@ void TestGemvUint8(const index_t batch,
}
TEST(ArmGemv, TestGemvInt32) {
-TestGemvInt32(1, 16, 4, true);
+TestGemvInt32(1, 16, 4, true, true);
-TestGemvInt32(1, 16, 256, true);
+TestGemvInt32(1, 16, 256, true, true);
-TestGemvInt32(2, 16, 256, true);
+TestGemvInt32(2, 16, 256, true, true);
-TestGemvInt32(3, 63, 257, true);
+TestGemvInt32(3, 63, 257, true, true);
-TestGemvInt32(1, 16, 4, false);
+TestGemvInt32(2, 16, 256, false, true);
-TestGemvInt32(1, 16, 256, false);
+TestGemvInt32(3, 63, 257, false, true);
-TestGemvInt32(2, 16, 256, false);
+TestGemvInt32(2, 16, 256, true, false);
-TestGemvInt32(3, 63, 257, false);
+TestGemvInt32(3, 63, 257, true, false);
}
TEST(ArmGemv, TestGemvUint8) {
-TestGemvUint8(1, 16, 4, true);
+TestGemvUint8(1, 16, 4, true, true);
-TestGemvUint8(1, 16, 256, true);
+TestGemvUint8(1, 16, 256, true, true);
-TestGemvUint8(2, 16, 256, true);
+TestGemvUint8(2, 16, 256, true, true);
-TestGemvUint8(3, 63, 257, true);
+TestGemvUint8(3, 63, 257, true, true);
-TestGemvUint8(1, 16, 4, false);
+TestGemvUint8(2, 16, 256, false, true);
-TestGemvUint8(1, 16, 256, false);
+TestGemvUint8(3, 63, 257, false, true);
-TestGemvUint8(2, 16, 256, false);
+TestGemvUint8(2, 16, 256, true, false);
-TestGemvUint8(3, 63, 257, false);
+TestGemvUint8(3, 63, 257, true, false);
}
}  // namespace test
......
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_COMMON_MATRIX_H_
#define MACE_OPS_COMMON_MATRIX_H_
namespace mace {
namespace ops {
enum MatrixMajor {
RowMajor,
ColMajor
};
inline MatrixMajor TransposeMatrixMajor(const MatrixMajor src_major) {
return src_major == RowMajor ? ColMajor : RowMajor;
}
template<typename T>
class MatrixMap {
public:
MatrixMap()
: data_(nullptr),
matrix_major_(RowMajor),
rows_(0),
cols_(0),
stride_(0) {}
MatrixMap(T *data,
const MatrixMajor matrix_major,
const index_t rows,
const index_t cols) :
data_(data),
matrix_major_(matrix_major),
rows_(rows),
cols_(cols),
stride_(matrix_major == ColMajor ? rows : cols) {}
MatrixMap(T *data,
const MatrixMajor matrix_major,
const index_t rows,
const index_t cols,
const index_t stride) :
data_(data),
matrix_major_(matrix_major),
rows_(rows),
cols_(cols),
stride_(stride) {}
MatrixMap(const MatrixMap &other)
: data_(other.data_),
matrix_major_(other.matrix_major_),
rows_(other.rows_),
cols_(other.cols_),
stride_(other.stride_) {}
MatrixMajor matrix_major() const { return matrix_major_; }
index_t rows() const { return rows_; }
index_t cols() const { return cols_; }
index_t stride() const { return stride_; }
int rows_stride() const {
return matrix_major_ == MatrixMajor::ColMajor ? 1 : stride_;
}
int cols_stride() const {
return matrix_major_ == MatrixMajor::RowMajor ? 1 : stride_;
}
index_t size() const { return rows_ * cols_; }
T *data() const { return data_; }
T *data(int rows, int cols) const {
return data_ + rows * rows_stride() + cols * cols_stride();
}
T &operator()(int row, int col) const { return *data(row, col); }
MatrixMap block(int start_row, int start_col, int block_rows,
int block_cols) const {
MACE_CHECK(start_row >= 0);
MACE_CHECK(start_row + block_rows <= rows_);
MACE_CHECK(start_col >= 0);
MACE_CHECK(start_col + block_cols <= cols_);
return MatrixMap(data(start_row, start_col),
matrix_major_,
block_rows,
block_cols,
stride_);
}
private:
T *data_;
MatrixMajor matrix_major_;
index_t rows_;
index_t cols_;
index_t stride_;
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_COMMON_MATRIX_H_
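MatrixMap is a non-owning view: it only stores a pointer, the major order, the logical rows/cols and a stride, so block() can hand out sub-matrices of a packed buffer without copying. A small usage sketch (assumes this header is included; values are illustrative):

#include <vector>

void MatrixMapExample() {
  std::vector<float> buf = {1, 2, 3,
                            4, 5, 6};      // a 2x3 row-major matrix
  mace::ops::MatrixMap<float> m(buf.data(), mace::ops::RowMajor,
                                /*rows=*/2, /*cols=*/3);
  float x = m(1, 2);                       // 6.f: data_ + 1 * stride_ + 2
  // 2x1 sub-view starting at column 1; writes go through to buf.
  mace::ops::MatrixMap<float> col = m.block(0, 1, 2, 1);
  col(0, 0) = 7.f;                         // buf becomes {1, 7, 3, 4, 5, 6}
  (void)x;
}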
This diff is collapsed.
@@ -102,6 +102,7 @@ class FullyConnectedOp<DeviceType::CPU, float> : public FullyConnectedOpBase {
output_size,
input_size,
false,
true,
output);
Tensor::MappingGuard guard_output(output);
float *output_ptr = output->mutable_data<float>();
@@ -162,6 +163,7 @@ class FullyConnectedOp<DeviceType::CPU, uint8_t>
output_size,
input_size,
false,
true,
output);
return MaceStatus::MACE_SUCCESS;
}
......
@@ -25,7 +25,7 @@
#include "mace/utils/utils.h"
#ifdef MACE_ENABLE_NEON
#include "mace/ops/arm/fp32/gemm.h"
#include "mace/ops/arm/fp32/gemv.h"
#ifdef MACE_ENABLE_QUANTIZE
@@ -33,6 +33,7 @@
#endif  // MACE_ENABLE_QUANTIZE
#else
#include "mace/ops/ref/gemm.h"
#include "mace/ops/ref/gemv.h"
#endif  // MACE_ENABLE_NEON
@@ -58,35 +59,45 @@ class MatMulOpBase : public Operation {
inline void Validate() { inline void Validate() {
const Tensor *A = this->Input(INPUT_A); const Tensor *A = this->Input(INPUT_A);
const Tensor *B = this->Input(INPUT_B); const Tensor *B = this->Input(INPUT_B);
MACE_CHECK(A->dim_size() == B->dim_size() && A->dim_size() >= 2, const index_t lhs_rank = A->dim_size();
"rank(A) should be equal to rank(B), rank should be greater " const index_t rhs_rank = B->dim_size();
"than or equal to 2");
index_t rank = A->dim_size(); MACE_CHECK(lhs_rank >= 2 && rhs_rank >= 2,
for (index_t i = 0; i < rank - 2; ++i) { "rank should be greater than or equal to 2");
MACE_CHECK(A->dim(i) == B->dim(i), if (lhs_rank == rhs_rank) {
"batch dimensions are not equal: ", for (index_t i = 0; i < A->dim_size() - 2; ++i) {
A->dim(i), MACE_CHECK(A->dim(i) == B->dim(i),
" vs. ", "batch dimensions are not equal: ",
B->dim(i)); A->dim(i),
" vs. ",
B->dim(i));
}
} else {
MACE_CHECK(lhs_rank == 2 || rhs_rank == 2,
"Either lhs or rhs matrix should has rank 2 "
"for non-batched matrix multiplication");
} }
index_t ak = transpose_a_ ? A->dim(rank - 2) : A->dim(rank - 1);
index_t bk = transpose_b_ ? B->dim(rank - 1) : B->dim(rank - 2); index_t
MACE_CHECK(ak == bk, "the number of A's column ", ak, lhs_depth = transpose_a_ ? A->dim(lhs_rank - 2) : A->dim(lhs_rank - 1);
" must be equal to B's row ", bk); index_t
rhs_depth = transpose_b_ ? B->dim(rhs_rank - 1) : B->dim(rhs_rank - 2);
MACE_CHECK(lhs_depth == rhs_depth, "the number of A's column ", lhs_depth,
" must be equal to B's row ", rhs_depth);
} }
protected: protected:
MACE_OP_INPUT_TAGS(INPUT_A, INPUT_B); MACE_OP_INPUT_TAGS(INPUT_A, INPUT_B, BIAS);
MACE_OP_OUTPUT_TAGS(OUTPUT); MACE_OP_OUTPUT_TAGS(OUTPUT);
bool transpose_a_; bool transpose_a_;
bool transpose_b_; bool transpose_b_;
}; };
template <DeviceType D, class T> template<DeviceType D, class T>
class MatMulOp; class MatMulOp;
template <> template<>
class MatMulOp<CPU, float> : public MatMulOpBase { class MatMulOp<CPU, float> : public MatMulOpBase {
public: public:
explicit MatMulOp(OpConstructContext *context) explicit MatMulOp(OpConstructContext *context)
...@@ -94,72 +105,116 @@ class MatMulOp<CPU, float> : public MatMulOpBase { ...@@ -94,72 +105,116 @@ class MatMulOp<CPU, float> : public MatMulOpBase {
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
Validate(); Validate();
const Tensor *A = this->Input(INPUT_A); const Tensor *lhs = this->Input(INPUT_A);
const Tensor *B = this->Input(INPUT_B); const Tensor *rhs = this->Input(INPUT_B);
const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr;
Tensor *C = this->Output(OUTPUT); Tensor *C = this->Output(OUTPUT);
index_t batch; const index_t lhs_rank = lhs->dim_size();
index_t height; const index_t lhs_rows = lhs->dim(lhs_rank - 2);
index_t K; const index_t lhs_cols = lhs->dim(lhs_rank - 1);
index_t width; const index_t rhs_rank = rhs->dim_size();
const index_t rhs_rows = rhs->dim(rhs_rank - 2);
index_t rank = A->dim_size(); const index_t rhs_cols = rhs->dim(rhs_rank - 1);
height = A->dim(rank - 2);
K = A->dim(rank - 1); const index_t rows = transpose_a_ ? lhs_cols : lhs_rows;
if (transpose_a_) { const index_t cols = transpose_b_ ? rhs_rows : rhs_cols;
std::swap(height, K); const index_t depth = transpose_a_ ? lhs_rows : lhs_cols;
} const index_t
if (transpose_b_) { lhs_batch =
width = B->dim(rank - 2); std::accumulate(lhs->shape().begin(), lhs->shape().end() - 2, 1,
std::multiplies<index_t>());
const index_t
rhs_batch =
std::accumulate(rhs->shape().begin(), rhs->shape().end() - 2, 1,
std::multiplies<index_t>());
index_t batch = 1;
std::vector<index_t> output_shape;
if (lhs_rank >= rhs_rank) {
output_shape = lhs->shape();
output_shape[lhs_rank - 2] = rows;
output_shape[lhs_rank - 1] = cols;
batch = lhs_batch;
} else { } else {
width = B->dim(rank - 1); output_shape = rhs->shape();
output_shape[rhs_rank - 2] = rows;
output_shape[rhs_rank - 1] = cols;
batch = rhs_batch;
}
bool lhs_batched = true;
bool rhs_batched = true;
if (lhs_rank < rhs_rank) {
lhs_batched = false;
} else if (rhs_rank < lhs_rank) {
rhs_batched = false;
} }
batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
std::multiplies<index_t>());
std::vector<index_t> c_shape = A->shape();
c_shape[rank - 2] = height;
c_shape[rank - 1] = width;
MACE_RETURN_IF_ERROR(C->Resize(c_shape));
Tensor::MappingGuard guarda(A);
Tensor::MappingGuard guardb(B);
Tensor::MappingGuard guardc(C);
const float *a_ptr_base = A->data<float>();
const float *b_ptr_base = B->data<float>();
float *c_ptr_base = C->mutable_data<float>();
const index_t height_a = A->dim(rank - 2);
const index_t width_a = A->dim(rank - 1);
const index_t height_b = B->dim(rank - 2);
const index_t width_b = B->dim(rank - 1);
auto scratch_buffer = context->device()->scratch_buffer();
scratch_buffer->Rewind();
sgemm_.Run(a_ptr_base,
b_ptr_base,
batch,
height_a,
width_a,
height_b,
width_b,
transpose_a_,
transpose_b_,
A->is_weight(),
B->is_weight(),
c_ptr_base,
context->device()->scratch_buffer());
return MaceStatus::MACE_SUCCESS; MACE_RETURN_IF_ERROR(C->Resize(output_shape));
if (rows == 1 && transpose_b_) {
return gemv_.Compute(context,
rhs,
lhs,
bias,
batch,
cols,
depth,
rhs_batched,
lhs_batched,
C);
} else if (cols == 1 && !transpose_a_) {
return gemv_.Compute(context,
lhs,
rhs,
bias,
batch,
rows,
depth,
lhs_batched,
rhs_batched,
C);
} else {
context->device()->scratch_buffer()->Rewind();
MaceStatus ret = gemm_.Compute(context,
lhs,
rhs,
batch,
lhs_rows,
lhs_cols,
rhs_rows,
rhs_cols,
transpose_a_,
transpose_b_,
false,
lhs_batched,
rhs_batched,
C);
if (bias != nullptr) {
MACE_CHECK(bias->dim_size() == 1 && bias->dim(0) == cols,
"bias' dim should be <= 2.");
Tensor::MappingGuard bias_guard(bias);
Tensor::MappingGuard c_guard(C);
const float *bias_data = bias->data<float>();
float *c_data = C->mutable_data<float>();
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < batch * rows; ++i) {
for (index_t w = 0; w < cols; ++w) {
c_data[i * cols + w] += bias_data[w];
}
}
}
return ret;
}
} }
private: private:
SGemm sgemm_;
#ifdef MACE_ENABLE_NEON #ifdef MACE_ENABLE_NEON
arm::fp32::Gemm gemm_;
arm::fp32::Gemv gemv_; arm::fp32::Gemv gemv_;
#else #else
ref::Gemv<float> gemv_; ref::Gemv<float> gemv_;
ref::Gemm<float> gemm_;
#endif // MACE_ENABLE_NEON #endif // MACE_ENABLE_NEON
}; };
...@@ -174,18 +229,36 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> { ...@@ -174,18 +229,36 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
void operator()(OpContext *context, void operator()(OpContext *context,
const Tensor *A, const Tensor *A,
const Tensor *B, const Tensor *B,
const index_t batch,
const index_t height, const index_t height,
const index_t K, const index_t K,
const index_t width, const index_t width,
const bool lhs_bached,
const bool rhs_bached,
Tensor *C) { Tensor *C) {
index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
std::multiplies<index_t>());
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) { if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) {
gemv_kernel_.Compute(context, A, B, nullptr, batch, height, K, true, C); gemv_kernel_.Compute(context,
A,
B,
nullptr,
batch,
height,
K,
true,
true,
C);
} else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) { } else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) {
gemv_kernel_.Compute(context, B, A, nullptr, batch, width, K, true, C); gemv_kernel_.Compute(context,
B,
A,
nullptr,
batch,
width,
K,
true,
true,
C);
} else { } else {
#endif // MACE_ENABLE_NEON #endif // MACE_ENABLE_NEON
Tensor::MappingGuard guarda(A); Tensor::MappingGuard guarda(A);
...@@ -208,9 +281,13 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> { ...@@ -208,9 +281,13 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
for (index_t i = 0; i < batch; ++i) { for (index_t i = 0; i < batch; ++i) {
gemmlowp::MatrixMap<const uint8_t, AOrder> gemmlowp::MatrixMap<const uint8_t, AOrder>
a_matrix(a_ptr_base + i * a_size, height, K); a_matrix(a_ptr_base + static_cast<index_t>(lhs_bached) * i * a_size,
height,
K);
gemmlowp::MatrixMap<const uint8_t, BOrder> gemmlowp::MatrixMap<const uint8_t, BOrder>
b_matrix(b_ptr_base + i * b_size, K, width); b_matrix(b_ptr_base + static_cast<index_t>(rhs_bached) * i * b_size,
K,
width);
gemmlowp::MatrixMap <uint8_t, gemmlowp::MapOrder::RowMajor> gemmlowp::MatrixMap <uint8_t, gemmlowp::MapOrder::RowMajor>
c_matrix(c_ptr_base + i * c_size, height, width); c_matrix(c_ptr_base + i * c_size, height, width);
...@@ -234,20 +311,39 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> { ...@@ -234,20 +311,39 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
void operator()(OpContext *context, void operator()(OpContext *context,
const Tensor *A, const Tensor *A,
const Tensor *B, const Tensor *B,
const index_t batch,
const index_t height, const index_t height,
const index_t K, const index_t K,
const index_t width, const index_t width,
const bool lhs_bached,
const bool rhs_bached,
Tensor *C) { Tensor *C) {
C->SetScale(A->scale() * B->scale()); C->SetScale(A->scale() * B->scale());
C->SetZeroPoint(0); C->SetZeroPoint(0);
index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
std::multiplies<index_t>());
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) { if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) {
gemv_kernel_.Compute(context, A, B, nullptr, batch, height, K, true, C); gemv_kernel_.Compute(context,
A,
B,
nullptr,
batch,
height,
K,
lhs_bached,
rhs_bached,
C);
} else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) { } else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) {
gemv_kernel_.Compute(context, B, A, nullptr, batch, width, K, true, C); gemv_kernel_.Compute(context,
B,
A,
nullptr,
batch,
width,
K,
lhs_bached,
rhs_bached,
C);
} else { } else {
#endif // MACE_ENABLE_NEON #endif // MACE_ENABLE_NEON
Tensor::MappingGuard guarda(A); Tensor::MappingGuard guarda(A);
...@@ -257,7 +353,8 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> { ...@@ -257,7 +353,8 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
auto b_ptr_base = B->data<uint8_t>(); auto b_ptr_base = B->data<uint8_t>();
auto c_ptr_base = C->mutable_data<int32_t>(); auto c_ptr_base = C->mutable_data<int32_t>();
auto auto
gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext(); gemm_context =
context->device()->cpu_runtime()->GetGemmlowpContext();
MACE_CHECK_NOTNULL(gemm_context); MACE_CHECK_NOTNULL(gemm_context);
index_t a_size = height * K; index_t a_size = height * K;
...@@ -268,9 +365,15 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> { ...@@ -268,9 +365,15 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
for (index_t i = 0; i < batch; ++i) { for (index_t i = 0; i < batch; ++i) {
gemmlowp::MatrixMap<const uint8_t, AOrder> gemmlowp::MatrixMap<const uint8_t, AOrder>
a_matrix(a_ptr_base + i * a_size, height, K); a_matrix
(a_ptr_base + static_cast<index_t>(lhs_bached) * i * a_size,
height,
K);
gemmlowp::MatrixMap<const uint8_t, BOrder> gemmlowp::MatrixMap<const uint8_t, BOrder>
b_matrix(b_ptr_base + i * b_size, K, width); b_matrix
(b_ptr_base + static_cast<index_t>(rhs_bached) * i * b_size,
K,
width);
gemmlowp::MatrixMap <int32_t, gemmlowp::MapOrder::RowMajor> gemmlowp::MatrixMap <int32_t, gemmlowp::MapOrder::RowMajor>
c_matrix(c_ptr_base + i * c_size, height, width); c_matrix(c_ptr_base + i * c_size, height, width);
...@@ -280,7 +383,6 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> { ...@@ -280,7 +383,6 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
-B->zero_point(), output_pipeline); -B->zero_point(), output_pipeline);
} }
} }
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
} }
...@@ -289,44 +391,65 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> { ...@@ -289,44 +391,65 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
#endif // MACE_ENABLE_NEON #endif // MACE_ENABLE_NEON
}; };
template <> template<>
class MatMulOp<DeviceType::CPU, uint8_t>: public MatMulOpBase { class MatMulOp<DeviceType::CPU, uint8_t> : public MatMulOpBase {
public: public:
explicit MatMulOp(OpConstructContext *context) explicit MatMulOp(OpConstructContext *context)
: MatMulOpBase(context) {} : MatMulOpBase(context) {}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
Validate(); Validate();
const Tensor *A = this->Input(INPUT_A); const Tensor *lhs = this->Input(INPUT_A);
const Tensor *B = this->Input(INPUT_B); const Tensor *rhs = this->Input(INPUT_B);
Tensor *C = this->Output(OUTPUT); Tensor *C = this->Output(OUTPUT);
index_t rank = A->dim_size(); const index_t lhs_rank = lhs->dim_size();
index_t height = A->dim(rank - 2); const index_t lhs_rows = lhs->dim(lhs_rank - 2);
index_t K = A->dim(rank - 1); const index_t lhs_cols = lhs->dim(lhs_rank - 1);
index_t width; const index_t rhs_rank = rhs->dim_size();
const index_t rhs_rows = rhs->dim(rhs_rank - 2);
if (transpose_a_) { const index_t rhs_cols = rhs->dim(rhs_rank - 1);
std::swap(height, K);
} const index_t rows = transpose_a_ ? lhs_cols : lhs_rows;
if (transpose_b_) { const index_t cols = transpose_b_ ? rhs_rows : rhs_cols;
width = B->dim(rank - 2); const index_t depth = transpose_a_ ? lhs_rows : lhs_cols;
const index_t
lhs_batch =
std::accumulate(lhs->shape().begin(), lhs->shape().end() - 2, 1,
std::multiplies<index_t>());
const index_t
rhs_batch =
std::accumulate(rhs->shape().begin(), rhs->shape().end() - 2, 1,
std::multiplies<index_t>());
index_t batch = 1;
std::vector<index_t> output_shape;
if (lhs_rank >= rhs_rank) {
output_shape = lhs->shape();
output_shape[lhs_rank - 2] = rows;
output_shape[lhs_rank - 1] = cols;
batch = lhs_batch;
} else { } else {
width = B->dim(rank - 1); output_shape = rhs->shape();
output_shape[rhs_rank - 2] = rows;
output_shape[rhs_rank - 1] = cols;
batch = rhs_batch;
}
bool lhs_batched = true;
bool rhs_batched = true;
if (lhs_rank < rhs_rank) {
lhs_batched = false;
} else if (rhs_rank < lhs_rank) {
rhs_batched = false;
} }
std::vector<index_t> c_shape = A->shape(); MACE_RETURN_IF_ERROR(C->Resize(output_shape));
c_shape[rank - 2] = height;
c_shape[rank - 1] = width;
MACE_RETURN_IF_ERROR(C->Resize(c_shape));
constexpr gemmlowp::MapOrder kRowMajor = gemmlowp::MapOrder::RowMajor; constexpr gemmlowp::MapOrder kRowMajor = gemmlowp::MapOrder::RowMajor;
constexpr gemmlowp::MapOrder kColMajor = gemmlowp::MapOrder::ColMajor; constexpr gemmlowp::MapOrder kColMajor = gemmlowp::MapOrder::ColMajor;
#define MATMUL_FIXPOINT_IMPL(AOrder, BOrder, OutType) \ #define MATMUL_FIXPOINT_IMPL(AOrder, BOrder, OutType) \
MatMulFixpointImpl<AOrder, BOrder, OutType>()( \ MatMulFixpointImpl<AOrder, BOrder, OutType>()( \
context, A, B, height, K, width, C); context, lhs, rhs, batch, rows, depth, cols, lhs_batched, rhs_batched, C);
#define MATMUL_FIXPOINT_IMPL_TRANSPOSE_OR_NOT(OutType) \ #define MATMUL_FIXPOINT_IMPL_TRANSPOSE_OR_NOT(OutType) \
if (transpose_a_) { \ if (transpose_a_) { \
...@@ -380,7 +503,6 @@ class MatMulOp<DeviceType::GPU, T> : public MatMulOpBase { ...@@ -380,7 +503,6 @@ class MatMulOp<DeviceType::GPU, T> : public MatMulOpBase {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterMatMul(OpRegistryBase *op_registry) { void RegisterMatMul(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "MatMul", MatMulOp, MACE_REGISTER_OP(op_registry, "MatMul", MatMulOp,
DeviceType::CPU, float); DeviceType::CPU, float);
......
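The MatMul changes above generalize batching: the two inputs may have different ranks, the higher-rank side supplies the batch dimensions and the output shape, and a rank-2 side is reused for every batch through the lhs_batched/rhs_batched flags; degenerate shapes (a single output row with transpose_b_, or a single output column without transpose_a_) are routed to Gemv, everything else to Gemm with an optional per-column bias add. A compact restatement of the batching rule (names here are mine, not MACE API):

#include <cstddef>
#include <cstdint>
#include <vector>

struct MatMulPlan { std::int64_t batch; bool lhs_batched; bool rhs_batched; };

// The higher-rank input supplies the batch dims; a lower-rank input is reused.
MatMulPlan PlanBatching(const std::vector<std::int64_t> &lhs_shape,
                        const std::vector<std::int64_t> &rhs_shape) {
  const auto &big =
      lhs_shape.size() >= rhs_shape.size() ? lhs_shape : rhs_shape;
  std::int64_t batch = 1;
  for (std::size_t i = 0; i + 2 < big.size(); ++i) batch *= big[i];
  return {batch,
          /*lhs_batched=*/lhs_shape.size() >= rhs_shape.size(),
          /*rhs_batched=*/rhs_shape.size() >= lhs_shape.size()};
}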
@@ -101,11 +101,14 @@ void MatmulBenchmark_Mace_SGemm(int iters, int m, int k, int n) {
std::vector<float> rhs(k * n);
std::vector<float> result(m * n);
-ops::MatrixMap<const float> matrix_lhs(1, m, k, RowMajor, lhs.data(),
-true);
-ops::MatrixMap<const float> matrix_rhs(1, k, n, RowMajor, rhs.data(),
-true);
-ops::MatrixMap<float> matrix_result(1, m, n, RowMajor, result.data());
+ops::SGemmMatrixMap<const float>
+matrix_lhs(1, m, k, SGemmRowMajor, lhs.data(),
+true);
+ops::SGemmMatrixMap<const float>
+matrix_rhs(1, k, n, SGemmRowMajor, rhs.data(),
+true);
+ops::SGemmMatrixMap<float>
+matrix_result(1, m, n, SGemmRowMajor, result.data());
ops::SGemm sgemm;
@@ -395,6 +398,7 @@ void MatMulTransposeBenchmark(
MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, uint8_t, CPU);
MACE_BM_MATMUL_OP(1, 30000, 256, 1);
MACE_BM_MATMUL_OP(1, 128, 256, 128);
MACE_BM_MATMUL_OP(2, 128, 128, 49);
MACE_BM_MATMUL_OP(3, 128, 128, 49);
MACE_BM_MATMUL_OP(4, 128, 128, 49);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <fstream> #include <fstream>
#include "mace/ops/ops_test_util.h" #include "mace/ops/ops_test_util.h"
#include "mace/ops/ref/gemm.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -71,34 +72,121 @@ TEST_F(MatMulOpTest, SimpleCPUWithBatch) { ...@@ -71,34 +72,121 @@ TEST_F(MatMulOpTest, SimpleCPUWithBatch) {
} }
namespace { namespace {
void QuantOutputUint8(const std::vector<index_t> &batch,
const index_t height, template<DeviceType D>
const index_t channels, void Complex(const std::vector<index_t> &batch,
const index_t out_width, const index_t rows,
const bool transpose_a, const index_t depth,
const bool transpose_b) { const index_t cols,
const bool transpose_lhs,
const bool transpose_rhs,
const bool lhs_batched,
const bool rhs_batched) {
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
index_t lhs_rows = transpose_lhs ? depth : rows;
index_t lhs_cols = transpose_lhs ? rows : depth;
index_t rhs_rows = transpose_rhs ? cols : depth;
index_t rhs_cols = transpose_rhs ? depth : cols;
std::vector<index_t> lhs_shape = {lhs_rows, lhs_cols};
std::vector<index_t> rhs_shape = {rhs_rows, rhs_cols};
if (lhs_batched) {
lhs_shape.insert(lhs_shape.begin(), batch.begin(), batch.end());
}
if (rhs_batched) {
rhs_shape.insert(rhs_shape.begin(), batch.begin(), batch.end());
}
net.AddRandomInput<CPU, float>("A", lhs_shape);
net.AddRandomInput<CPU, float>("B", rhs_shape);
OpDefBuilder("MatMul", "MatMulTest")
.Input("A")
.AddIntArg("transpose_a", transpose_lhs ? 1 : 0)
.Input("B")
.AddIntArg("transpose_b", transpose_rhs ? 1 : 0)
.Output("Output")
.AddIntArg("T", DT_FLOAT)
.Finalize(net.NewOperatorDef());
net.RunOp(CPU);
ref::Gemm<float> gemm;
Tensor expected_output_tensor;
std::vector<index_t> expected_output_shape({rows, cols});
expected_output_shape.insert(expected_output_shape.begin(),
batch.begin(),
batch.end());
expected_output_tensor.Resize(expected_output_shape);
index_t batch_count = std::accumulate(batch.begin(), batch.end(), 1, index_t batch_count = std::accumulate(batch.begin(), batch.end(), 1,
std::multiplies<index_t>()); std::multiplies<index_t>());
if (transpose_a) { gemm.Compute(nullptr,
net.AddRandomInput<CPU, float>("A", {batch_count, channels, height}); net.GetTensor("A"),
} else { net.GetTensor("B"),
net.AddRandomInput<CPU, float>("A", {batch_count, height, channels}); batch_count,
lhs_rows,
lhs_cols,
rhs_rows,
rhs_cols,
transpose_lhs,
transpose_rhs,
false,
lhs_batched,
rhs_batched,
&expected_output_tensor);
ExpectTensorNear<float>(expected_output_tensor, *net.GetTensor("Output"));
}
} // namespace
TEST_F(MatMulOpTest, ComplexCPUWithBatch) {
Complex<DeviceType::CPU>({1}, 3, 3, 3, false, false, true, true);
Complex<DeviceType::CPU>({}, 3, 3, 3, false, false, true, true);
Complex<DeviceType::CPU>({16}, 31, 61, 67, false, true, true, true);
Complex<DeviceType::CPU>({31}, 31, 61, 67, true, false, true, true);
Complex<DeviceType::CPU>({2, 3}, 31, 61, 67, true, true, true, true);
Complex<DeviceType::CPU>({1}, 1, 30001, 253, false, true, true, true);
Complex<DeviceType::CPU>({2}, 253, 300, 1, false, false, true, true);
// test one-side batched
Complex<DeviceType::CPU>({2, 3}, 31, 61, 67, true, true, false, true);
Complex<DeviceType::CPU>({2, 3}, 31, 61, 67, true, true, true, false);
Complex<DeviceType::CPU>({2, 3}, 31, 61, 67, true, true, false, true);
}
namespace {
void QuantOutputUint8(const std::vector<index_t> &batch,
const index_t rows,
const index_t depth,
const index_t cols,
const bool transpose_lhs,
const bool transpose_rhs,
const bool lhs_batched = true,
const bool rhs_batched = true) {
// Construct graph
OpsTestNet net;
// Add input data
index_t lhs_rows = transpose_lhs ? depth : rows;
index_t lhs_cols = transpose_lhs ? rows : depth;
index_t rhs_rows = transpose_rhs ? cols : depth;
index_t rhs_cols = transpose_rhs ? depth : cols;
std::vector<index_t> lhs_shape = {lhs_rows, lhs_cols};
std::vector<index_t> rhs_shape = {rhs_rows, rhs_cols};
if (lhs_batched) {
lhs_shape.insert(lhs_shape.begin(), batch.begin(), batch.end());
}
if (rhs_batched) {
rhs_shape.insert(rhs_shape.begin(), batch.begin(), batch.end());
}
net.AddRandomInput<CPU, float>("A", lhs_shape);
net.AddRandomInput<CPU, float>("B", rhs_shape);
OpDefBuilder("MatMul", "MatMulTest") OpDefBuilder("MatMul", "MatMulTest")
.Input("A") .Input("A")
.AddIntArg("transpose_a", transpose_a ? 1 : 0) .AddIntArg("transpose_a", transpose_lhs ? 1 : 0)
.Input("B") .Input("B")
.AddIntArg("transpose_b", transpose_b ? 1 : 0) .AddIntArg("transpose_b", transpose_rhs ? 1 : 0)
.Output("Output") .Output("Output")
.AddIntArg("T", DT_FLOAT) .AddIntArg("T", DT_FLOAT)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
...@@ -133,9 +221,9 @@ void QuantOutputUint8(const std::vector<index_t> &batch, ...@@ -133,9 +221,9 @@ void QuantOutputUint8(const std::vector<index_t> &batch,
OpDefBuilder("MatMul", "QuantizeMatMulTest") OpDefBuilder("MatMul", "QuantizeMatMulTest")
.Input("QuantizedA") .Input("QuantizedA")
.AddIntArg("transpose_a", transpose_a ? 1 : 0) .AddIntArg("transpose_a", transpose_lhs ? 1 : 0)
.Input("QuantizedB") .Input("QuantizedB")
.AddIntArg("transpose_b", transpose_b ? 1 : 0) .AddIntArg("transpose_b", transpose_rhs ? 1 : 0)
.Output("QuantizedOutput") .Output("QuantizedOutput")
.AddIntArg("T", DT_UINT8) .AddIntArg("T", DT_UINT8)
.OutputType({DT_UINT8}) .OutputType({DT_UINT8})
...@@ -161,39 +249,38 @@ void QuantOutputUint8(const std::vector<index_t> &batch, ...@@ -161,39 +249,38 @@ void QuantOutputUint8(const std::vector<index_t> &batch,
} }
void QuantOutputInt32(const std::vector<index_t> &batch,
const index_t rows,
const index_t depth,
const index_t cols,
const bool transpose_lhs,
const bool transpose_rhs,
const bool lhs_batched = true,
const bool rhs_batched = true) {
// Construct graph
OpsTestNet net;
// Add input data
index_t lhs_rows = transpose_lhs ? depth : rows;
index_t lhs_cols = transpose_lhs ? rows : depth;
index_t rhs_rows = transpose_rhs ? cols : depth;
index_t rhs_cols = transpose_rhs ? depth : cols;
std::vector<index_t> lhs_shape = {lhs_rows, lhs_cols};
std::vector<index_t> rhs_shape = {rhs_rows, rhs_cols};
if (lhs_batched) {
lhs_shape.insert(lhs_shape.begin(), batch.begin(), batch.end());
}
if (rhs_batched) {
rhs_shape.insert(rhs_shape.begin(), batch.begin(), batch.end());
}
net.AddRandomInput<CPU, float>("A", lhs_shape);
net.AddRandomInput<CPU, float>("B", rhs_shape);
OpDefBuilder("MatMul", "MatMulTest")
.Input("A")
.AddIntArg("transpose_a", transpose_lhs ? 1 : 0)
.Input("B")
.AddIntArg("transpose_b", transpose_rhs ? 1 : 0)
.Output("Output")
.AddIntArg("T", DT_FLOAT)
.Finalize(net.NewOperatorDef());
...@@ -219,9 +306,9 @@ void QuantOutputInt32(const std::vector<index_t> &batch,
OpDefBuilder("MatMul", "QuantizeMatMulTest")
.Input("QuantizedA")
.AddIntArg("transpose_a", transpose_lhs ? 1 : 0)
.Input("QuantizedB")
.AddIntArg("transpose_b", transpose_rhs ? 1 : 0)
.Output("QuantizedOutput")
.AddIntArg("T", DT_UINT8)
.OutputType({DT_INT32})
...@@ -256,10 +343,12 @@ TEST_F(MatMulOpTest, QuantOutputUint8) {
QuantOutputUint8({1}, 64, 32, 128, true, true);
QuantOutputUint8({2, 3}, 64, 32, 128, true, true);
// UnAligned
QuantOutputUint8({2}, 3, 3, 3, false, false);
QuantOutputUint8({16}, 31, 61, 67, false, true);
QuantOutputUint8({31}, 31, 61, 67, true, false);
QuantOutputUint8({2, 3}, 31, 61, 67, true, true);
QuantOutputUint8({2, 3}, 31, 61, 67, true, true, true, false);
QuantOutputUint8({2, 3}, 31, 61, 67, true, true, false, true);
}
TEST_F(MatMulOpTest, QuantOutputInt32) {
...@@ -281,12 +370,14 @@ TEST_F(MatMulOpTest, QuantOutputInt32) {
QuantOutputInt32({3}, 128, 256, 1, false, false);
// UnAligned
QuantOutputInt32({2}, 3, 3, 3, false, false);
QuantOutputInt32({16}, 31, 61, 67, false, true);
QuantOutputInt32({31}, 31, 61, 67, true, false);
QuantOutputInt32({2, 3}, 31, 61, 67, true, true);
QuantOutputInt32({1}, 1, 30001, 253, false, true);
QuantOutputInt32({2}, 253, 300, 1, false, false);
QuantOutputInt32({2, 3}, 31, 61, 67, true, true, true, false);
QuantOutputInt32({2, 3}, 31, 61, 67, true, true, false, true);
}
} // namespace test
...
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ref/conv_2d.h"
#include <vector>
#include "mace/ops/common/conv_pool_2d_util.h"
namespace mace {
namespace ops {
namespace ref {
MaceStatus Conv2d<float>::Compute(const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output) {
MACE_UNUSED(context);
const std::vector<index_t> in_shape = input->shape();
const std::vector<index_t> filter_shape = filter->shape();
const std::vector<index_t> out_shape = output->shape();
const std::vector<int> stride_hw{stride_h_, stride_w_};
const std::vector<int> dilation_hw{dilation_h_, dilation_w_};
const std::vector<int> paddings{pad_h_, pad_w_};
const index_t pad_top = pad_h_ >> 1;
const index_t pad_left = pad_w_ >> 1;
std::vector<index_t> output_shape(4);
CalcOutputSize(in_shape.data(),
NCHW,
filter_shape.data(),
OIHW,
paddings.data(),
dilation_hw.data(),
stride_hw.data(),
RoundType::FLOOR,
output_shape.data());
output->Resize(output_shape);
const index_t in_image_size = in_shape[2] * in_shape[3];
const index_t out_image_size = out_shape[2] * out_shape[3];
const index_t in_batch_size = filter_shape[1] * in_image_size;
const index_t out_batch_size = filter_shape[0] * out_image_size;
const index_t filter_size = filter_shape[2] * filter_shape[3];
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard filter_guard(filter);
Tensor::MappingGuard output_guard(output);
auto input_data = input->data<float>();
auto filter_data = filter->data<float>();
auto output_data = output->mutable_data<float>();
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < in_shape[0]; b++) {
for (index_t m = 0; m < filter_shape[0]; ++m) {
const index_t in_height = in_shape[2];
const index_t in_width = in_shape[3];
const index_t out_height = out_shape[2];
const index_t out_width = out_shape[3];
const index_t in_channels = filter_shape[1];
float *out_ptr_base =
output_data + b * out_batch_size + m * out_image_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
float sum = 0;
for (index_t c = 0; c < in_channels; ++c) {
const float *in_ptr_base =
input_data + b * in_batch_size + c * in_image_size;
const float *filter_ptr =
filter_data + m * in_channels * filter_size + c * filter_size;
for (index_t kh = 0; kh < filter_shape[2]; ++kh) {
for (index_t kw = 0; kw < filter_shape[3]; ++kw) {
const index_t ih = -pad_top + h * stride_h_ + kh * dilation_h_;
const index_t iw = -pad_left + w * stride_w_ + kw * dilation_w_;
if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
sum += in_ptr_base[ih * in_width + iw] * filter_ptr[kw];
}
} // kw
filter_ptr += filter_shape[3];
} // kh
} // c
out_ptr_base[h * out_width + w] = sum;
} // w
} // h
} // m
} // b
return MaceStatus::MACE_SUCCESS;
}
} // namespace ref
} // namespace ops
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_REF_CONV_2D_H_
#define MACE_OPS_REF_CONV_2D_H_
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
namespace mace {
namespace ops {
namespace ref {
template<typename OUTPUT_TYPE>
class Conv2d {
public:
Conv2d(int stride_h, int stride_w, int dilation_h, int dilation_w);
~Conv2d() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
};
template<>
class Conv2d<float> {
public:
Conv2d(int pad_h,
int pad_w,
int stride_h,
int stride_w,
int dilation_h,
int dilation_w)
: pad_h_(pad_h),
pad_w_(pad_w),
stride_h_(stride_h),
stride_w_(stride_w),
dilation_h_(dilation_h),
dilation_w_(dilation_w) {}
~Conv2d() {}
// Always row-major after transpose
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
private:
int pad_h_;
int pad_w_;
int stride_h_;
int stride_w_;
int dilation_h_;
int dilation_w_;
};
} // namespace ref
} // namespace ops
} // namespace mace
#endif // MACE_OPS_REF_CONV_2D_H_
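A minimal usage sketch of the reference kernel above (not part of the diff; the padding/stride/dilation values and the helper name are illustrative, and the tensors are assumed to be float, NCHW input and OIHW filter, as the Compute implementation expects):

#include "mace/ops/ref/conv_2d.h"

namespace mace {
namespace ops {

void RefConv2dExample(const Tensor *input, const Tensor *filter) {
  // Total padding of 2 on each spatial dim, stride 1, dilation 1.
  ref::Conv2d<float> conv2d(2, 2, 1, 1, 1, 1);
  Tensor output;  // resized inside Compute from the computed output shape
  // The reference kernel ignores the OpContext, so nullptr is acceptable,
  // mirroring how the matmul test drives ref::Gemm.
  conv2d.Compute(nullptr, input, filter, &output);
}

}  // namespace ops
}  // namespace mace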
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ref/gemm.h"
namespace mace {
namespace ops {
namespace ref {
MaceStatus Gemm<float>::Compute(const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t rows,
const index_t cols,
const index_t depth,
const MatrixMajor lhs_major,
const MatrixMajor rhs_major,
const MatrixMajor output_major,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
MACE_UNUSED(context);
Tensor::MappingGuard lhs_guard(lhs);
Tensor::MappingGuard rhs_guard(rhs);
Tensor::MappingGuard output_guard(output);
const float *lhs_data = lhs->data<float>();
const float *rhs_data = rhs->data<float>();
float *output_data = output->mutable_data<float>();
for (index_t b = 0; b < batch; ++b) {
MatrixMap<const float>
lhs_matrix
(lhs_data + static_cast<index_t>(lhs_batched) * b * rows * depth,
lhs_major,
rows,
depth);
MatrixMap<const float>
rhs_matrix
(rhs_data + static_cast<index_t>(rhs_batched) * b * depth * cols,
rhs_major,
depth,
cols);
MatrixMap<float>
output_matrix(output_data + b * rows * cols, output_major, rows, cols);
for (index_t r = 0; r < rows; ++r) {
for (index_t c = 0; c < cols; ++c) {
float sum = 0;
for (index_t d = 0; d < depth; ++d) {
sum += lhs_matrix(r, d) * rhs_matrix(d, c);
} // d
*output_matrix.data(r, c) = sum;
} // c
} // r
} // b
return MaceStatus::MACE_SUCCESS;
}
MaceStatus Gemm<float>::Compute(const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t lhs_rows,
const index_t lhs_cols,
const index_t rhs_rows,
const index_t rhs_cols,
const bool transpose_lhs,
const bool transpose_rhs,
const bool transpose_out,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
index_t rows = transpose_lhs ? lhs_cols : lhs_rows;
index_t depth = transpose_lhs ? lhs_rows : lhs_cols;
index_t cols = transpose_rhs ? rhs_rows : rhs_cols;
index_t depth2 = transpose_rhs ? rhs_cols : rhs_rows;
MACE_CHECK(depth == depth2,
"Matrices that multiply have inconsistent depth dim: ",
depth,
" vs. ",
depth2);
return Compute(context,
lhs,
rhs,
batch,
rows,
cols,
depth,
transpose_lhs ? ColMajor : RowMajor,
transpose_rhs ? ColMajor : RowMajor,
transpose_out ? ColMajor : RowMajor,
lhs_batched,
rhs_batched,
output);
}
} // namespace ref
} // namespace ops
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_REF_GEMM_H_
#define MACE_OPS_REF_GEMM_H_
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/common/matrix.h"
namespace mace {
namespace ops {
namespace ref {
template<typename OUTPUT_TYPE>
class Gemm {
public:
Gemm() {}
~Gemm() {}
MaceStatus Compute(const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t rows,
const index_t cols,
const index_t depth,
const MatrixMajor lhs_major,
const MatrixMajor rhs_major,
const MatrixMajor output_major,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
template<>
class Gemm<float> {
public:
Gemm() {}
~Gemm() {}
MaceStatus Compute(const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t rows,
const index_t cols,
const index_t depth,
const MatrixMajor lhs_major,
const MatrixMajor rhs_major,
const MatrixMajor output_major,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
// Original matrix before transpose has row-major
MaceStatus Compute(
const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const index_t batch,
const index_t lhs_rows,
const index_t lhs_cols,
const index_t rhs_rows,
const index_t rhs_cols,
const bool transpose_lhs,
const bool transpose_rhs,
const bool transpose_out,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
} // namespace ref
} // namespace ops
} // namespace mace
#endif // MACE_OPS_REF_GEMM_H_
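For reference, a small sketch of how the transpose-flag overload above can be driven (illustrative only; the helper name and the untransposed, both-batched configuration are assumptions, not code from this merge):

#include "mace/ops/ref/gemm.h"

namespace mace {
namespace ops {

// Multiplies a batched row-major lhs (rows x depth) by a batched row-major
// rhs (depth x cols) into a row-major output (rows x cols).
void RefGemmExample(const Tensor *lhs, const Tensor *rhs,
                    index_t batch, index_t rows, index_t depth, index_t cols,
                    Tensor *output) {
  output->Resize({batch, rows, cols});
  ref::Gemm<float> gemm;
  gemm.Compute(nullptr, lhs, rhs, batch,
               rows, depth,   // lhs_rows, lhs_cols
               depth, cols,   // rhs_rows, rhs_cols
               false, false,  // transpose_lhs, transpose_rhs
               false,         // transpose_out
               true, true,    // lhs_batched, rhs_batched
               output);
}

}  // namespace ops
}  // namespace mace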
...@@ -31,6 +31,7 @@ MaceStatus Gemv<float>::Compute(const OpContext *context,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
MACE_UNUSED(context);
...@@ -52,9 +53,9 @@ MaceStatus Gemv<float>::Compute(const OpContext *context,
float sum = bias ? bias_data[h] : 0;
for (index_t w = 0; w < lhs_width; ++w) {
sum += lhs_data[
static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ h * lhs_width + w]
* rhs_data[static_cast<index_t>(rhs_batched) * b * lhs_width + w];
} // w
output_data[b * lhs_height + h] = sum;
...@@ -73,6 +74,7 @@ MaceStatus Gemv<uint8_t>::Compute(const OpContext *context,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
MACE_UNUSED(context);
...@@ -102,7 +104,8 @@ MaceStatus Gemv<uint8_t>::Compute(const OpContext *context,
sum += (lhs_data[
static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ h * lhs_width + w] - lhs_zero)
* (rhs_data[static_cast<index_t>(rhs_batched) * b * lhs_width + w]
- rhs_zero);
} // w
output_data[b * lhs_height + h] =
...@@ -120,6 +123,7 @@ MaceStatus Gemv<int32_t>::Compute(const OpContext *context,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output) {
MACE_UNUSED(context);
...@@ -146,7 +150,8 @@ MaceStatus Gemv<int32_t>::Compute(const OpContext *context,
sum += (lhs_data[
static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ h * lhs_width + w] - lhs_zero)
* (rhs_data[static_cast<index_t>(rhs_batched) * b * lhs_width + w]
- rhs_zero);
} // w
output_data[b * lhs_height + h] = sum;
...
...@@ -39,6 +39,7 @@ class Gemv {
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
...@@ -57,6 +58,7 @@ class Gemv<float> {
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
...@@ -76,6 +78,7 @@ class Gemv<uint8_t> {
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
...@@ -94,6 +97,7 @@ class Gemv<int32_t> {
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
};
#endif // MACE_ENABLE_QUANTIZE
...
...@@ -27,39 +27,17 @@
#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#endif
namespace {
inline void AdviseFree(void *addr, size_t length) {
int page_size = sysconf(_SC_PAGESIZE);
void *addr_aligned =
reinterpret_cast<void *>(
(reinterpret_cast<uintptr_t>(addr) + page_size - 1)
& (~(page_size - 1)));
uintptr_t delta =
reinterpret_cast<uintptr_t>(addr_aligned)
- reinterpret_cast<uintptr_t>(addr);
if (length >= delta + page_size) {
size_t len_aligned = (length - delta) & (~(page_size - 1));
int ret = madvise(addr_aligned, len_aligned, MADV_DONTNEED);
if (ret != 0) {
LOG(ERROR) << "Advise free failed: " << strerror(errno);
}
}
}
} // namespace
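The helper removed above rounds addr up to the next page boundary before advising the kernel. A worked instance (numbers added for illustration, not part of the diff): with a 4096-byte page and addr = 0x1001, addr_aligned becomes 0x2000 and delta = 0xFFF, so madvise(MADV_DONTNEED) is only issued when length covers at least delta plus one full page, and len_aligned is rounded down to a whole number of pages.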
namespace mace {
namespace ops {
void SGemm::operator()(const SGemmMatrixMap<const float> &lhs,
const SGemmMatrixMap<const float> &rhs,
SGemmMatrixMap<float> *result,
ScratchBuffer *scratch_buffer) {
if (lhs.is_const() && !rhs.is_const()) {
SGemmMatrixMap<const float> lhs_transpose = lhs.transpose();
SGemmMatrixMap<const float> rhs_transpose = rhs.transpose();
SGemmMatrixMap<float> result_transpose = result->transpose();
return operator()(rhs_transpose,
lhs_transpose,
&result_transpose,
...@@ -150,18 +128,18 @@ void SGemm::Run(const float *A,
width_c = height_b;
}
SGemmMatrixMap<const float> matrix_a =
SGemmMatrixMap<const float>(batch,
height_a,
width_a,
ops::SGemmRowMajor,
A,
is_a_weight);
SGemmMatrixMap<const float> matrix_b =
ops::SGemmMatrixMap<const float>(batch,
height_b,
width_b,
ops::SGemmRowMajor,
B,
is_b_weight);
if (transpose_a) {
...@@ -170,7 +148,8 @@ void SGemm::Run(const float *A,
if (transpose_b) {
matrix_b = matrix_b.transpose();
}
SGemmMatrixMap<float>
matrix_c(batch, height_c, width_c, ops::SGemmRowMajor, C);
operator()(matrix_a, matrix_b, &matrix_c, scratch_buffer);
}
...@@ -930,17 +909,17 @@ void SGemm::RunPerBatch(const float *lhs_data,
} // bw
}
void SGemm::PackLhs(const SGemmMatrixMap<const float> &lhs,
PackedBlock *packed_block) {
Pack(lhs, PackOrder::SGemmColMajor, packed_block);
}
void SGemm::PackRhs(const SGemmMatrixMap<const float> &rhs,
PackedBlock *packed_block) {
Pack(rhs, PackOrder::SGemmRowMajor, packed_block);
}
void SGemm::Pack(const SGemmMatrixMap<const float> &src,
const PackOrder order,
PackedBlock *packed_block) {
MACE_CHECK_NOTNULL(packed_block);
...@@ -963,7 +942,7 @@ void SGemm::Pack(const MatrixMap<const float> &src,
}
void SGemm::UnPack(const PackedBlock &packed_result,
SGemmMatrixMap<float> *matrix_map) {
MACE_CHECK_NOTNULL(matrix_map);
const index_t height = matrix_map->row();
...@@ -984,7 +963,7 @@ void SGemm::UnPack(const PackedBlock &packed_result,
#undef MACE_SGEMM_UNPACK_PER_BATCH
}
void SGemm::PackPerBatch(const SGemmMatrixMap<const float> &src,
const PackOrder order,
const index_t batch_index,
float *packed_data) {
...@@ -994,7 +973,8 @@ void SGemm::PackPerBatch(const MatrixMap<const float> &src,
const index_t width = src.col();
auto src_data = src.batch_data(batch_index);
if (src.map_major() == Major::SGemmRowMajor
&& order == PackOrder::SGemmColMajor) {
// This is for packing no-transpose lhs.
index_t h = 0;
#if defined(MACE_ENABLE_NEON)
...@@ -1040,8 +1020,8 @@ void SGemm::PackPerBatch(const MatrixMap<const float> &src,
for (index_t ih = h; ih < height; ++ih) {
std::copy_n(src_data + ih * width, width, packed_data + ih * width);
}
} else if (src.map_major() == Major::SGemmColMajor &&
order == PackOrder::SGemmColMajor) {
// This is for packing transpose-needed lhs.
index_t h = 0;
#if defined(MACE_ENABLE_NEON)
...@@ -1082,8 +1062,8 @@ void SGemm::PackPerBatch(const MatrixMap<const float> &src,
packed_data_ptr[w] = src_data_ptr[w * height];
}
}
} else if (src.map_major() == Major::SGemmRowMajor &&
order == PackOrder::SGemmRowMajor) {
// This is for packing no-transpose rhs.
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
...@@ -1108,8 +1088,8 @@ void SGemm::PackPerBatch(const MatrixMap<const float> &src,
packed_data_ptr[h] = src_data_ptr[h * width];
}
}
} else if (src.map_major() == Major::SGemmColMajor &&
order == PackOrder::SGemmRowMajor) {
// This is for packing transpose-needed rhs.
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
...@@ -1138,14 +1118,14 @@ void SGemm::PackPerBatch(const MatrixMap<const float> &src,
void SGemm::UnPackPerBatch(const float *packed_data,
const index_t batch_index,
SGemmMatrixMap<float> *matrix_map) {
MACE_CHECK_NOTNULL(matrix_map);
const index_t height = matrix_map->row();
const index_t width = matrix_map->col();
auto unpacked_data = matrix_map->batch_data(batch_index);
if (matrix_map->map_major() == Major::SGemmRowMajor) {
// This is for non-transposed result
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
...
...@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// This implementation is deprecated. Use mace/ops/arm/fp32/gemm.h instead.
#ifndef MACE_OPS_SGEMM_H_
#define MACE_OPS_SGEMM_H_
...@@ -30,16 +32,16 @@ namespace mace {
namespace ops {
enum Major {
SGemmRowMajor,
SGemmColMajor
};
template<typename T>
class SGemmMatrixMap {
public:
SGemmMatrixMap() {}
SGemmMatrixMap(const index_t batch,
const index_t row,
const index_t col,
const Major major,
...@@ -48,14 +50,20 @@ class MatrixMap {
batch_(batch),
row_(row),
col_(col),
stride_(major == SGemmRowMajor ? col : row),
major_(major),
data_(data),
is_const_(is_const) {}
SGemmMatrixMap transpose() const {
Major transpose_major =
major_ == SGemmRowMajor ? SGemmColMajor : SGemmRowMajor;
return SGemmMatrixMap(batch_,
col_,
row_,
transpose_major,
data_,
is_const_);
}
index_t batch() const {
...@@ -114,9 +122,9 @@ class SGemm {
packed_rhs_(nullptr),
packed_(false) {}
void operator()(const SGemmMatrixMap<const float> &lhs,
const SGemmMatrixMap<const float> &rhs,
SGemmMatrixMap<float> *result,
ScratchBuffer *scratch_buffer = nullptr);
void Run(const float *A,
...@@ -133,28 +141,28 @@ class SGemm {
float *C,
ScratchBuffer *scratch_buffer = nullptr);
void PackLhs(const SGemmMatrixMap<const float> &lhs,
PackedBlock *packed_block);
void PackRhs(const SGemmMatrixMap<const float> &rhs,
PackedBlock *packed_block);
void UnPack(const PackedBlock &packed_result,
SGemmMatrixMap<float> *matrix_map);
private:
void Pack(const SGemmMatrixMap<const float> &src,
const PackOrder order,
PackedBlock *packed_block);
void PackPerBatch(const SGemmMatrixMap<const float> &src,
const PackOrder order,
const index_t batch_index,
float *packed_data);
void UnPackPerBatch(const float *packed_data,
const index_t batch_index,
SGemmMatrixMap<float> *matrix_map);
void RunInternal(const PackedBlock &lhs,
const PackedBlock &rhs,
...
...@@ -31,10 +31,11 @@ void TestPack(const std::vector<float> &data,
Major src_order,
PackOrder pack_order) {
SGemm sg;
SGemmMatrixMap<const float>
src_matrix(1, height, width, src_order, data.data());
PackedBlock packed;
packed.Resize({height, width});
if (pack_order == PackOrder::SGemmColMajor) {
sg.PackLhs(src_matrix, &packed);
} else {
sg.PackRhs(src_matrix, &packed);
...@@ -57,18 +58,19 @@ void TestUnPack(const index_t height,
data[i] = rand_r(&seed);
}
SGemmMatrixMap<const float>
src_matrix(1, height, width, src_order, data.data());
PackedBlock packed;
packed.Resize({height, width});
SGemm sg;
if (pack_order == PackOrder::SGemmColMajor) {
sg.PackLhs(src_matrix, &packed);
} else {
sg.PackRhs(src_matrix, &packed);
}
std::vector<float> unpacked(matrix_size);
SGemmMatrixMap<float>
unpacked_matrix(1, height, width, src_order, unpacked.data());
sg.UnPack(packed, &unpacked_matrix);
auto unpacked_data = unpacked.data();
...@@ -87,78 +89,78 @@ TEST(SGemmPackTest, Pack) {
// For no-transpose lhs
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
3, 4, Major::SGemmRowMajor, PackOrder::SGemmColMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16},
4, 4, Major::SGemmRowMajor, PackOrder::SGemmColMajor);
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16, 17, 18, 19,
20},
5, 4, Major::SGemmRowMajor, PackOrder::SGemmColMajor);
#if defined(__aarch64__)
TestPack(data,
{1, 5, 9, 13, 17, 21, 25, 29, 2, 6, 10, 14, 18, 22, 26, 30, 3, 7, 11,
15, 19, 23, 27, 31, 4, 8, 12, 16, 20, 24, 28, 32, 33, 34, 35, 36},
9, 4, Major::SGemmRowMajor, PackOrder::SGemmColMajor);
#endif
#endif
// For transpose-needed lhs
TestPack(data,
{1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12},
3, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
4, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor);
TestPack(data,
{1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 5, 10, 15,
20},
5, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor);
#if defined(__aarch64__)
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21,
22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 9, 18, 27, 36},
9, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor);
#endif
#endif
// For no-transpose rhs
TestPack(data,
{1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12},
4, 3, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
4, 4, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
TestPack(data,
{1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 5, 10, 15,
20},
4, 5, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
#endif
// For transpose-needed rhs
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
4, 3, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16},
4, 4, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16, 17, 18, 19,
20},
4, 5, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
#endif
}
TEST(SGemmPackTest, UnPack) {
TestUnPack(4, 3, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 4, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 5, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 100, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 3, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 4, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 5, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 100, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
}
} // namespace test
...