Commit 57a3298d authored by 李滨

Merge branch 'pack' into 'master'

Pack matmul to improve performance

See merge request !789
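The change replaces the old Gemv/Gemm dispatch in the CPU MatMulFunctor with an SGemm path that packs both operands into kernel-friendly panel layouts before multiplying, and reuses the packed block when an operand is a weight. As a rough illustration of what "packing" means here, a minimal sketch with a hypothetical helper (not the MACE API; the actual layouts are in mace/kernels/sgemm.cc below):

// Sketch only: pack a row-major (height x width) LHS into 4-row panels so a
// 4-row micro-kernel can read one contiguous group of 4 values per depth
// step; leftover rows are kept row-major.
#include <cstddef>

void PackLhs4(const float *src, std::size_t height, std::size_t width,
              float *dst) {
  std::size_t h = 0;
  for (; h + 4 <= height; h += 4) {
    for (std::size_t w = 0; w < width; ++w) {
      for (std::size_t r = 0; r < 4; ++r) {
        *dst++ = src[(h + r) * width + w];  // one column of the 4-row panel
      }
    }
  }
  for (; h < height; ++h) {  // remaining rows, copied as-is
    for (std::size_t w = 0; w < width; ++w) {
      *dst++ = src[h * width + w];
    }
  }
}

The matmul benchmarks below add a Mace_SGemm variant so the packed path can be compared against the existing Mace, Eigen, and gemmlowp baselines.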
@@ -85,7 +85,7 @@ ndk_versions_compatible_tests:
    - DEFAULT_NDK_PATH=$ANDROID_NDK_HOME
    - prefix_path=${DEFAULT_NDK_PATH%android-ndk-*}
    - >
-     for ndk in android-ndk-r12b android-ndk-r15c android-ndk-r16 android-ndk-r17b;
+     for ndk in android-ndk-r15c android-ndk-r16 android-ndk-r17b;
      do
        new_ndk_path=${prefix_path}${ndk};
        if [ "$new_ndk_path" != "$DEFAULT_NDK_PATH" ]; then
...
@@ -399,6 +399,10 @@ class Tensor {
    zero_point_ = zero_point;
  }
inline void SetIsWeight(bool is_weight) {
is_weight_ = is_weight;
}
 private:
  Allocator *allocator_;
  DataType dtype_;

@@ -409,7 +413,7 @@ class Tensor {
  bool is_buffer_owner_;
  bool unused_;
  std::string name_;
-  const bool is_weight_;
+  bool is_weight_;
  float scale_;
  int32_t zero_point_;
...
@@ -33,7 +33,8 @@ int main(int argc, char **argv) {
  // config runtime
  mace::MaceStatus status = mace::SetOpenMPThreadsAndAffinityPolicy(
      FLAGS_omp_num_threads,
-      static_cast<mace::CPUAffinityPolicy>(FLAGS_cpu_affinity_policy));
+      static_cast<mace::CPUAffinityPolicy>(FLAGS_cpu_affinity_policy),
+      true);
  if (status != mace::MACE_SUCCESS) {
    LOG(WARNING) << "Set openmp or cpu affinity failed.";
  }
...
@@ -13,11 +13,13 @@
// limitations under the License.

#include <gtest/gtest.h>
+#include <vector>
#include <memory>
#include <random>

#include "mace/core/types.h"
#include "mace/kernels/gemm.h"
+#include "mace/kernels/sgemm.h"

namespace mace {

@@ -72,6 +74,74 @@ void GemvTest(index_t batch, index_t N, index_t M) {
  }
}
void SGemmTest(index_t batch,
index_t N,
index_t K,
index_t M,
bool transpose_a,
bool transpose_b) {
std::unique_ptr<float[]> A(new float[batch * N * K]);
std::unique_ptr<float[]> B(new float[batch * K * M]);
std::unique_ptr<float[]> C(new float[batch * N * M]);
std::unique_ptr<float[]> C_ref(new float[batch * N * M]);
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
std::generate(A.get(), A.get() + batch * N * K,
[&gen, &nd] { return nd(gen); });
std::generate(B.get(), B.get() + batch * K * M,
[&gen, &nd] { return nd(gen); });
kernels::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get(), transpose_a,
transpose_b);
kernels::MatrixMap<const float> matrix_a;
kernels::MatrixMap<const float> matrix_b;
if (!transpose_a) {
matrix_a =
kernels::MatrixMap<const float>(batch,
N,
K,
kernels::RowMajor,
A.get());
} else {
matrix_a =
kernels::MatrixMap<const float>(batch,
K,
N,
kernels::RowMajor,
A.get());
matrix_a = matrix_a.transpose();
}
if (!transpose_b) {
matrix_b =
kernels::MatrixMap<const float>(batch,
K,
M,
kernels::RowMajor,
B.get());
} else {
matrix_b =
kernels::MatrixMap<const float>(batch,
M,
K,
kernels::RowMajor,
B.get());
matrix_b = matrix_b.transpose();
}
kernels::MatrixMap<float> matrix_c(batch, N, M, kernels::RowMajor, C.get());
kernels::SGemm sgemm;
sgemm(matrix_a, matrix_b, &matrix_c);
for (int i = 0; i < N * M; ++i) {
EXPECT_NEAR(C_ref[i], C[i], 0.1);
}
}
} // namespace

TEST(GEMMTest, AlignedWithoutBatch) {

@@ -114,4 +184,25 @@ TEST(GEMMTest, gemv) {
  GemvTest(3, 17, 63);
}
namespace {
void TestSGemmTranspose(index_t batch, index_t N, index_t K, index_t M) {
SGemmTest(batch, N, K, M, false, false);
SGemmTest(batch, N, K, M, true, false);
SGemmTest(batch, N, K, M, false, true);
SGemmTest(batch, N, K, M, true, true);
}
}
TEST(SGEMMTest, UnalignedWithoutBatch) {
std::vector<index_t> tests{1, 5, 14, 31, 47};
for (index_t N : tests) {
for (index_t K : tests) {
for (index_t M : tests) {
TestSGemmTranspose(1, N, K, M);
TestSGemmTranspose(16, N, K, M);
}
}
}
}
} // namespace mace
@@ -32,6 +32,7 @@
#include "mace/kernels/kernel.h"
#include "mace/utils/utils.h"
#include "mace/kernels/gemmlowp_util.h"
+#include "mace/kernels/sgemm.h"

#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/cl2_header.h"

@@ -83,39 +84,34 @@ struct MatMulFunctor : OpKernel {
    const T *b_ptr_base = B->data<T>();
    T *c_ptr_base = C->mutable_data<T>();

-    memset(c_ptr_base, 0, batch * height * width * sizeof(T));
-
-    if (height == 1 && width > 1 && B->is_weight()) {
-      // A * B = (B^T * A^T)^T
-      if (!transpose_b) {
-        if (B_transpose_.get() == nullptr) {
-          B_transpose_.reset(new Tensor(context_->device()->allocator(),
-                                        DataTypeToEnum<T>::v()));
-          B_transpose_->Resize({batch, width, K});
-          Tensor::MappingGuard guardbt(B_transpose_.get());
-          T *bt_ptr_base = B_transpose_->mutable_data<T>();
-          Transpose(b_ptr_base, K, width, width, bt_ptr_base);
-        }
-        Tensor::MappingGuard guardbt(B_transpose_.get());
-        T *bt_ptr_base = B_transpose_->mutable_data<T>();
-        Gemv(bt_ptr_base, a_ptr_base, batch, K, width, c_ptr_base);
-      } else {
-        Gemv(b_ptr_base, a_ptr_base, batch, K, width, c_ptr_base);
-      }
-    } else {
-      Gemm(a_ptr_base, b_ptr_base, batch, height, K, width, c_ptr_base,
-           transpose_a, transpose_b);
-    }
+    const index_t height_a = A->dim(rank - 2);
+    const index_t width_a = A->dim(rank - 1);
+    const index_t height_b = B->dim(rank - 2);
+    const index_t width_b = B->dim(rank - 1);
+
+    sgemm_.Run(a_ptr_base,
+               b_ptr_base,
+               batch,
+               height_a,
+               width_a,
+               height_b,
+               width_b,
+               transpose_a,
+               transpose_b,
+               A->is_weight(),
+               B->is_weight(),
+               c_ptr_base,
+               context_->workspace()->GetScratchBuffer(D));

    return MACE_SUCCESS;
  }

-  std::unique_ptr<Tensor> B_transpose_;
+  SGemm sgemm_;
};

template <>
struct MatMulFunctor<CPU, uint8_t> : OpKernel {
  explicit MatMulFunctor(OpKernelContext *context) : OpKernel(context) {}

  template<gemmlowp::MapOrder AOrder, gemmlowp::MapOrder BOrder>
  void MatMulImpl(const Tensor *A,
                  const Tensor *B,

@@ -213,6 +209,7 @@ struct MatMulFunctor<CPU, uint8_t> : OpKernel {
template <typename T>
struct MatMulFunctor<DeviceType::GPU, T> : OpKernel {
  explicit MatMulFunctor(OpKernelContext *context) : OpKernel(context) {}

  MaceStatus operator()(const Tensor *A,
                        const Tensor *B,
                        Tensor *C,
...
@@ -22,6 +22,7 @@
#include "mace/core/testing/test_benchmark.h"
#include "mace/kernels/gemm.h"
#include "mace/kernels/gemmlowp_util.h"
+#include "mace/kernels/sgemm.h"

namespace gemmlowp {

@@ -107,6 +108,28 @@ void MatmulBenchmark_Mace(int iters, int m, int k, int n) {
  }
}
void MatmulBenchmark_Mace_SGemm(int iters, int m, int k, int n) {
mace::testing::StopTiming();
std::vector<float> lhs(m * k);
std::vector<float> rhs(k * n);
std::vector<float> result(m * n);
kernels::MatrixMap<const float> matrix_lhs(1, m, k, RowMajor, lhs.data(),
true);
kernels::MatrixMap<const float> matrix_rhs(1, k, n, RowMajor, rhs.data(),
true);
kernels::MatrixMap<float> matrix_result(1, m, n, RowMajor, result.data());
kernels::SGemm sgemm;
sgemm(matrix_lhs, matrix_rhs, &matrix_result);
mace::testing::StartTiming();
while (iters--) {
sgemm(matrix_lhs, matrix_rhs, &matrix_result);
}
}
void MatmulBenchmark_Eigen(int iters, int m, int k, int n) {
  mace::testing::StopTiming();
  Eigen::MatrixXf lhs = Eigen::MatrixXf::Random(m, k);

@@ -202,6 +225,7 @@ void MatmulBenchmark_gemmlowp_int32(int iters, int rows, int depth, int cols) {
#define MACE_BM_MATMUL(M, K, N) \
  MACE_BM_MATMUL_FUNC(M, K, N, Mace, float); \
+ MACE_BM_MATMUL_FUNC(M, K, N, Mace_SGemm, float); \
  MACE_BM_MATMUL_FUNC(M, K, N, Eigen, float); \
  MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_uint8, uint8_t); \
  MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_int32, uint8_t);

@@ -215,15 +239,43 @@ MACE_BM_MATMUL(15, 384, 384);
MACE_BM_MATMUL(15, 384, 1536);
MACE_BM_MATMUL(15, 1536, 384);
-MACE_BM_MATMUL(1, 384, 384);
-MACE_BM_MATMUL(1, 384, 1536);
-MACE_BM_MATMUL(1, 1536, 384);
-MACE_BM_MATMUL(1, 384, 44678);
+MACE_BM_MATMUL(1, 256, 256);
+MACE_BM_MATMUL(1, 256, 1536);
+MACE_BM_MATMUL(1, 1536, 256);
+MACE_BM_MATMUL(256, 256, 1);
MACE_BM_MATMUL(1536, 256, 1);
MACE_BM_MATMUL(256, 1536, 1);
MACE_BM_MATMUL(29792, 256, 1);
MACE_BM_MATMUL(1, 256, 29792);
MACE_BM_MATMUL(2, 256, 256);
MACE_BM_MATMUL(2, 256, 1536);
MACE_BM_MATMUL(2, 1536, 256);
MACE_BM_MATMUL(3, 256, 256);
MACE_BM_MATMUL(3, 256, 1536);
MACE_BM_MATMUL(3, 1536, 256);
MACE_BM_MATMUL(4, 256, 256);
MACE_BM_MATMUL(4, 256, 1536);
MACE_BM_MATMUL(4, 1536, 256);
MACE_BM_MATMUL(8, 256, 256);
MACE_BM_MATMUL(8, 256, 1536);
MACE_BM_MATMUL(8, 1536, 256);
MACE_BM_MATMUL(10, 256, 256);
MACE_BM_MATMUL(10, 256, 1536);
MACE_BM_MATMUL(10, 1536, 256);
MACE_BM_MATMUL(15, 256, 256);
MACE_BM_MATMUL(15, 256, 1536);
MACE_BM_MATMUL(15, 1536, 256);
// Embedding size 128
MACE_BM_MATMUL(1, 128, 1536);
MACE_BM_MATMUL(1, 128, 44678);
// MobileNet
MACE_BM_MATMUL(128, 128, 3136);
MACE_BM_MATMUL(256, 256, 784);
MACE_BM_MATMUL(512, 512, 196);
MACE_BM_MATMUL(1024, 1024, 49);
} // namespace test
} // namespace kernels
} // namespace mace
@@ -12,80 +12,1158 @@
// See the License for the specific language governing permissions and
// limitations under the License.

-#include <algorithm>
-#include <cstring>
-#include <vector>
+#include <memory>

#include "mace/kernels/sgemm.h"
+#include "mace/core/runtime/cpu/cpu_runtime.h"

#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#endif
namespace mace {
namespace kernels {

-void SGemm::operator()(const MatrixMap<float> &lhs,
-                       const MatrixMap<float> &rhs,
-                       MatrixMap<float> *result) {
-  PackedBlock<float> packed_lhs;
-  PackLhs(lhs, &packed_lhs);
-
-  PackedBlock<float> packed_rhs;
-  PackRhs(rhs, &packed_rhs);
-
-  PackedBlock<float> packed_result;
-  operator()(packed_lhs,
-             packed_rhs,
-             lhs.row(),
-             lhs.col(),
-             rhs.col(),
-             &packed_result);
-  UnPack(packed_result, result);
-}

void SGemm::operator()(const MatrixMap<const float> &lhs,
                       const MatrixMap<const float> &rhs,
                       MatrixMap<float> *result,
                       ScratchBuffer *scratch_buffer) {
  if (rhs.col() < lhs.row()) {
    MatrixMap<const float> lhs_transpose = lhs.transpose();
    MatrixMap<const float> rhs_transpose = rhs.transpose();
    MatrixMap<float> result_transpose = result->transpose();
    return operator()(rhs_transpose,
                      lhs_transpose,
                      &result_transpose,
                      scratch_buffer);
  }

  if (scratch_buffer != nullptr) {
    scratch_buffer->Rewind();
    index_t total_size = result->size();
if (!lhs.is_const()) {
total_size += lhs.size();
}
if (!rhs.is_const()) {
total_size += rhs.size();
}
scratch_buffer->GrowSize(total_size * sizeof(float));
scratch_buffer->Rewind();
if (!lhs.is_const()) {
packed_lhs_.reset(new Tensor(scratch_buffer->Scratch(
lhs.size() * sizeof(float)), DT_FLOAT));
}
if (!rhs.is_const()) {
packed_lhs_.reset(new Tensor(scratch_buffer->Scratch(
rhs.size() * sizeof(float)), DT_FLOAT));
}
packed_result_.reset(new Tensor(scratch_buffer->Scratch(
result->size() * sizeof(float)), DT_FLOAT));
}
if (packed_lhs_.get() == nullptr) {
packed_lhs_.reset(new Tensor(GetCPUAllocator(), DT_FLOAT));
packed_lhs_->Resize({lhs.size()});
}
if (packed_rhs_.get() == nullptr) {
packed_rhs_.reset(new Tensor(GetCPUAllocator(), DT_FLOAT));
packed_rhs_->Resize({rhs.size()});
}
if (packed_result_.get() == nullptr) {
packed_result_.reset(new Tensor(GetCPUAllocator(), DT_FLOAT));
packed_result_->Resize({result->size()});
}
if (!lhs.is_const() || !packed_) {
PackLhs(lhs, packed_lhs_.get());
}
if (!rhs.is_const() || !packed_) {
PackRhs(rhs, packed_rhs_.get());
}
packed_ = true;
RunInternal(*packed_lhs_,
*packed_rhs_,
lhs.batch(),
lhs.row(),
lhs.col(),
rhs.col(),
packed_result_.get());
UnPack(*packed_result_, result);
}
void SGemm::Run(const float *A,
const float *B,
const index_t batch,
const index_t height_a,
const index_t width_a,
const index_t height_b,
const index_t width_b,
const bool transpose_a,
const bool transpose_b,
const bool is_a_weight,
const bool is_b_weight,
float *C,
ScratchBuffer *scratch_buffer) {
index_t height_c = height_a;
index_t width_c = width_b;
if (transpose_a) {
height_c = width_a;
}
if (transpose_b) {
width_c = height_b;
}
MatrixMap<const float> matrix_a =
MatrixMap<const float>(batch,
height_a,
width_a,
kernels::RowMajor,
A,
is_a_weight);
MatrixMap<const float> matrix_b =
kernels::MatrixMap<const float>(batch,
height_b,
width_b,
kernels::RowMajor,
B,
is_b_weight);
if (transpose_a) {
matrix_a = matrix_a.transpose();
}
if (transpose_b) {
matrix_b = matrix_b.transpose();
}
MatrixMap<float> matrix_c(batch, height_c, width_c, kernels::RowMajor, C);
operator()(matrix_a, matrix_b, &matrix_c, scratch_buffer);
}
-void SGemm::operator()(const PackedBlock<float> &lhs,
-                       const PackedBlock<float> &rhs,
-                       const index_t height,
-                       const index_t depth,
-                       const index_t width,
-                       PackedBlock<float> *result) {
-  (void) lhs;
-  (void) rhs;
-  (void) result;
-  (void) height;
-  (void) depth;
-  (void) width;
-
-  // (8, 8) * (8, 4)
-
-  // (4, 4) * (4, 4)
-
-  // remain
-}

#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)

// calculate 8 rows, 4 cols for each depth
#define MACE_SGEMM_PART_CAL_R8_C4_D1(D, VD, VDN) \
c0 = vfmaq_laneq_f32(c0, b##D, a##VD, 0); \
c1 = vfmaq_laneq_f32(c1, b##D, a##VD, 1); \
c2 = vfmaq_laneq_f32(c2, b##D, a##VD, 2); \
c3 = vfmaq_laneq_f32(c3, b##D, a##VD, 3); \
c4 = vfmaq_laneq_f32(c4, b##D, a##VDN, 0); \
c5 = vfmaq_laneq_f32(c5, b##D, a##VDN, 1); \
c6 = vfmaq_laneq_f32(c6, b##D, a##VDN, 2); \
c7 = vfmaq_laneq_f32(c7, b##D, a##VDN, 3);
// calculate 4 rows, 4 cols for each depth
#define MACE_SGEMM_PART_CAL_R4_C4_D1(D) \
c0 = vfmaq_laneq_f32(c0, b##D, a##D, 0); \
c1 = vfmaq_laneq_f32(c1, b##D, a##D, 1); \
c2 = vfmaq_laneq_f32(c2, b##D, a##D, 2); \
c3 = vfmaq_laneq_f32(c3, b##D, a##D, 3);
// calculate 4 cols for 8 depths for each row
#define MACE_SGEMM_PART_CAL_R1_C4_D8(R, VR, VRN) \
c##R = vfmaq_laneq_f32(c##R, b0, a##VR, 0); \
c##R = vfmaq_laneq_f32(c##R, b1, a##VR, 1); \
c##R = vfmaq_laneq_f32(c##R, b2, a##VR, 2); \
c##R = vfmaq_laneq_f32(c##R, b3, a##VR, 3); \
c##R = vfmaq_laneq_f32(c##R, b4, a##VRN, 0); \
c##R = vfmaq_laneq_f32(c##R, b5, a##VRN, 1); \
c##R = vfmaq_laneq_f32(c##R, b6, a##VRN, 2); \
c##R = vfmaq_laneq_f32(c##R, b7, a##VRN, 3);
// calculate 4 cols for 4 depths for each row
#define MACE_SGEMM_PART_CAL_R1_C4_D4(R) \
c##R = vfmaq_laneq_f32(c##R, b0, a##R, 0); \
c##R = vfmaq_laneq_f32(c##R, b1, a##R, 1); \
c##R = vfmaq_laneq_f32(c##R, b2, a##R, 2); \
c##R = vfmaq_laneq_f32(c##R, b3, a##R, 3);
// calculate 8 cols for 4 depths for each row
#define MACE_SGEMM_PART_CAL_R1_C8_D4(VR, VRN, R) \
c##VR = vfmaq_laneq_f32(c##VR, b0, a##R, 0); \
c##VR = vfmaq_laneq_f32(c##VR, b2, a##R, 1); \
c##VR = vfmaq_laneq_f32(c##VR, b4, a##R, 2); \
c##VR = vfmaq_laneq_f32(c##VR, b6, a##R, 3); \
c##VRN = vfmaq_laneq_f32(c##VRN, b1, a##R, 0); \
c##VRN = vfmaq_laneq_f32(c##VRN, b3, a##R, 1); \
c##VRN = vfmaq_laneq_f32(c##VRN, b5, a##R, 2); \
c##VRN = vfmaq_laneq_f32(c##VRN, b7, a##R, 3);
#else
#define MACE_SGEMM_PART_CAL_R8_C4_D1(D, VD, VDN) \
c0 = vmlaq_lane_f32(c0, b##D, vget_low_f32(a##VD), 0); \
c1 = vmlaq_lane_f32(c1, b##D, vget_low_f32(a##VD), 1); \
c2 = vmlaq_lane_f32(c2, b##D, vget_high_f32(a##VD), 0); \
c3 = vmlaq_lane_f32(c3, b##D, vget_high_f32(a##VD), 1); \
c4 = vmlaq_lane_f32(c4, b##D, vget_low_f32(a##VDN), 0); \
c5 = vmlaq_lane_f32(c5, b##D, vget_low_f32(a##VDN), 1); \
c6 = vmlaq_lane_f32(c6, b##D, vget_high_f32(a##VDN), 0); \
c7 = vmlaq_lane_f32(c7, b##D, vget_high_f32(a##VDN), 1);
#define MACE_SGEMM_PART_CAL_R4_C4_D1(D) \
c0 = vmlaq_lane_f32(c0, b##D, vget_low_f32(a##D), 0); \
c1 = vmlaq_lane_f32(c1, b##D, vget_low_f32(a##D), 1); \
c2 = vmlaq_lane_f32(c2, b##D, vget_high_f32(a##D), 0); \
c3 = vmlaq_lane_f32(c3, b##D, vget_high_f32(a##D), 1);
#define MACE_SGEMM_PART_CAL_R1_C4_D8(R, VR, VRN) \
c##R = vmlaq_lane_f32(c##R, b0, vget_low_f32(a##VR), 0); \
c##R = vmlaq_lane_f32(c##R, b1, vget_low_f32(a##VR), 1); \
c##R = vmlaq_lane_f32(c##R, b2, vget_high_f32(a##VR), 0); \
c##R = vmlaq_lane_f32(c##R, b3, vget_high_f32(a##VR), 1); \
c##R = vmlaq_lane_f32(c##R, b4, vget_low_f32(a##VRN), 0); \
c##R = vmlaq_lane_f32(c##R, b5, vget_low_f32(a##VRN), 1); \
c##R = vmlaq_lane_f32(c##R, b6, vget_high_f32(a##VRN), 0); \
c##R = vmlaq_lane_f32(c##R, b7, vget_high_f32(a##VRN), 1);
#define MACE_SGEMM_PART_CAL_R1_C4_D4(R) \
c##R = vmlaq_lane_f32(c##R, b0, vget_low_f32(a##R), 0); \
c##R = vmlaq_lane_f32(c##R, b1, vget_low_f32(a##R), 1); \
c##R = vmlaq_lane_f32(c##R, b2, vget_high_f32(a##R), 0); \
c##R = vmlaq_lane_f32(c##R, b3, vget_high_f32(a##R), 1);
#endif // __aarch64__
#endif // MACE_ENABLE_NEON
void SGemm::RunInternal(const PackedBlock &lhs,
const PackedBlock &rhs,
const index_t batch,
const index_t height,
const index_t depth,
const index_t width,
PackedBlock *result) {
const float *lhs_data = lhs.data<float>();
const float *rhs_data = rhs.data<float>();
float *result_data = result->mutable_data<float>();
#define MACE_SGEMM_RUN_PER_BATCH \
for (index_t b = 0; b < batch; ++b) { \
RunPerBatch(lhs_data + b * height * depth, \
rhs_data + b * depth * width, \
height, \
depth, \
width, \
result_data + b * height * width); \
}
if (batch >= MaceOpenMPThreadCount) {
#pragma omp parallel for
MACE_SGEMM_RUN_PER_BATCH
} else {
MACE_SGEMM_RUN_PER_BATCH
}
#undef MACE_SGEMM_RUN_PER_BATCH
}
void SGemm::RunPerBatch(const float *lhs_data,
                        const float *rhs_data,
const index_t height,
const index_t depth,
const index_t width,
float *result_data) {
#if defined(MACE_ENABLE_NEON)
const index_t block_w = width >> 2;
const index_t remain_w = width - (block_w << 2);
#else
const index_t remain_w = width;
#endif
#if defined(MACE_ENABLE_NEON)
// TODO(liyin): make better use l2(l1) cache, try to fit as much lhs data as
// as possible to cache, by tiling lhs by height and rhs by width.
// w: 4
#pragma omp parallel for
for (index_t bw = 0; bw < block_w; ++bw) {
index_t remain_h = height;
index_t block_h = 0;
const float *lhs_ptr = lhs_data;
float *res_ptr = result_data + height * (bw << 2);
#if defined(__aarch64__)
block_h = remain_h >> 3;
remain_h -= (block_h << 3);
// h: 8
for (index_t bh = 0; bh < block_h; ++bh) {
const float *rhs_ptr = rhs_data + depth * (bw << 2);
index_t remain_d = depth;
index_t block_d = remain_d >> 3;
remain_d -= (block_d << 3);
float32x4_t c0, c1, c2, c3, c4, c5, c6, c7;
c0 = vdupq_n_f32(0.f);
c1 = vdupq_n_f32(0.f);
c2 = vdupq_n_f32(0.f);
c3 = vdupq_n_f32(0.f);
c4 = vdupq_n_f32(0.f);
c5 = vdupq_n_f32(0.f);
c6 = vdupq_n_f32(0.f);
c7 = vdupq_n_f32(0.f);
// d: 8
for (index_t bd = 0; bd < block_d; ++bd) {
// 8.8.4
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13,
a14, a15;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
a0 = vld1q_f32(lhs_ptr);
a1 = vld1q_f32(lhs_ptr + 4);
a2 = vld1q_f32(lhs_ptr + 8);
a3 = vld1q_f32(lhs_ptr + 12);
a4 = vld1q_f32(lhs_ptr + 16);
a5 = vld1q_f32(lhs_ptr + 20);
a6 = vld1q_f32(lhs_ptr + 24);
a7 = vld1q_f32(lhs_ptr + 28);
a8 = vld1q_f32(lhs_ptr + 32);
a9 = vld1q_f32(lhs_ptr + 36);
a10 = vld1q_f32(lhs_ptr + 40);
a11 = vld1q_f32(lhs_ptr + 44);
a12 = vld1q_f32(lhs_ptr + 48);
a13 = vld1q_f32(lhs_ptr + 52);
a14 = vld1q_f32(lhs_ptr + 56);
a15 = vld1q_f32(lhs_ptr + 60);
b0 = vld1q_f32(rhs_ptr);
b1 = vld1q_f32(rhs_ptr + 4);
b2 = vld1q_f32(rhs_ptr + 8);
b3 = vld1q_f32(rhs_ptr + 12);
b4 = vld1q_f32(rhs_ptr + 16);
b5 = vld1q_f32(rhs_ptr + 20);
b6 = vld1q_f32(rhs_ptr + 24);
b7 = vld1q_f32(rhs_ptr + 28);
MACE_SGEMM_PART_CAL_R8_C4_D1(0, 0, 1); // d = 1
MACE_SGEMM_PART_CAL_R8_C4_D1(1, 2, 3); // d = 2
MACE_SGEMM_PART_CAL_R8_C4_D1(2, 4, 5);
MACE_SGEMM_PART_CAL_R8_C4_D1(3, 6, 7);
MACE_SGEMM_PART_CAL_R8_C4_D1(4, 8, 9);
MACE_SGEMM_PART_CAL_R8_C4_D1(5, 10, 11);
MACE_SGEMM_PART_CAL_R8_C4_D1(6, 12, 13);
MACE_SGEMM_PART_CAL_R8_C4_D1(7, 14, 15);
lhs_ptr += 64;
rhs_ptr += 32;
}
block_d = remain_d >> 2;
remain_d -= (block_d << 2);
// d: 4
for (index_t bd = 0; bd < block_d; ++bd) {
// 8.4.4
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7;
float32x4_t b0, b1, b2, b3;
a0 = vld1q_f32(lhs_ptr);
a1 = vld1q_f32(lhs_ptr + 4);
a2 = vld1q_f32(lhs_ptr + 8);
a3 = vld1q_f32(lhs_ptr + 12);
a4 = vld1q_f32(lhs_ptr + 16);
a5 = vld1q_f32(lhs_ptr + 20);
a6 = vld1q_f32(lhs_ptr + 24);
a7 = vld1q_f32(lhs_ptr + 28);
b0 = vld1q_f32(rhs_ptr);
b1 = vld1q_f32(rhs_ptr + 4);
b2 = vld1q_f32(rhs_ptr + 8);
b3 = vld1q_f32(rhs_ptr + 12);
MACE_SGEMM_PART_CAL_R8_C4_D1(0, 0, 1); // d = 1
MACE_SGEMM_PART_CAL_R8_C4_D1(1, 2, 3); // d = 2
MACE_SGEMM_PART_CAL_R8_C4_D1(2, 4, 5);
MACE_SGEMM_PART_CAL_R8_C4_D1(3, 6, 7);
lhs_ptr += 32;
rhs_ptr += 16;
}
// TODO(liyin): handle remain by each case
// d: remain
for (index_t d = 0; d < remain_d; ++d) {
// 8.1.4
float32x4_t a0, a1;
float32x4_t b0;
a0 = vld1q_f32(lhs_ptr);
a1 = vld1q_f32(lhs_ptr + 4);
b0 = vld1q_f32(rhs_ptr);
MACE_SGEMM_PART_CAL_R8_C4_D1(0, 0, 1); // d = 1
lhs_ptr += 8;
rhs_ptr += 4;
}
vst1q_f32(res_ptr, c0);
vst1q_f32(res_ptr + 4, c1);
vst1q_f32(res_ptr + 8, c2);
vst1q_f32(res_ptr + 12, c3);
vst1q_f32(res_ptr + 16, c4);
vst1q_f32(res_ptr + 20, c5);
vst1q_f32(res_ptr + 24, c6);
vst1q_f32(res_ptr + 28, c7);
res_ptr += 32;
} // bh: 8
#endif // __aarch64__
// h: 4
block_h = remain_h >> 2;
remain_h -= (block_h << 2);
for (index_t bh = 0; bh < block_h; ++bh) {
const float *rhs_ptr = rhs_data + depth * (bw << 2);
index_t remain_d = depth;
index_t block_d = 0;
float32x4_t c0, c1, c2, c3;
c0 = vdupq_n_f32(0.f);
c1 = vdupq_n_f32(0.f);
c2 = vdupq_n_f32(0.f);
c3 = vdupq_n_f32(0.f);
// d: 8
block_d = remain_d >> 3;
remain_d -= (block_d << 3);
#if defined(__aarch64__)
for (index_t bd = 0; bd < block_d; ++bd) {
// 4.8.4
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
a0 = vld1q_f32(lhs_ptr);
a1 = vld1q_f32(lhs_ptr + 4);
a2 = vld1q_f32(lhs_ptr + 8);
a3 = vld1q_f32(lhs_ptr + 12);
a4 = vld1q_f32(lhs_ptr + 16);
a5 = vld1q_f32(lhs_ptr + 20);
a6 = vld1q_f32(lhs_ptr + 24);
a7 = vld1q_f32(lhs_ptr + 28);
b0 = vld1q_f32(rhs_ptr);
b1 = vld1q_f32(rhs_ptr + 4);
b2 = vld1q_f32(rhs_ptr + 8);
b3 = vld1q_f32(rhs_ptr + 12);
b4 = vld1q_f32(rhs_ptr + 16);
b5 = vld1q_f32(rhs_ptr + 20);
b6 = vld1q_f32(rhs_ptr + 24);
b7 = vld1q_f32(rhs_ptr + 28);
MACE_SGEMM_PART_CAL_R4_C4_D1(0); // d = 1
MACE_SGEMM_PART_CAL_R4_C4_D1(1); // d = 2
MACE_SGEMM_PART_CAL_R4_C4_D1(2);
MACE_SGEMM_PART_CAL_R4_C4_D1(3);
MACE_SGEMM_PART_CAL_R4_C4_D1(4);
MACE_SGEMM_PART_CAL_R4_C4_D1(5);
MACE_SGEMM_PART_CAL_R4_C4_D1(6);
MACE_SGEMM_PART_CAL_R4_C4_D1(7);
lhs_ptr += 32;
rhs_ptr += 32;
}
#else // arm v7
// 4.8.4
if (block_d > 0) {
asm volatile(
"0: \n"
"vld1.f32 {d0-d1}, [%[lhs_ptr]]! \n"
"vld1.f32 {d2-d3}, [%[lhs_ptr]]! \n"
"vld1.f32 {d4-d5}, [%[lhs_ptr]]! \n"
"vld1.f32 {d20-d21}, [%[rhs_ptr]]! \n"
"vld1.f32 {d22-d23}, [%[rhs_ptr]]! \n"
"vld1.f32 {d24-d25}, [%[rhs_ptr]]! \n"
"vmla.f32 %[c0], q10, d0[0] \n"
"vmla.f32 %[c1], q10, d0[1] \n"
"vmla.f32 %[c2], q10, d1[0] \n"
"vmla.f32 %[c3], q10, d1[1] \n"
"vld1.f32 {d6-d7}, [%[lhs_ptr]]! \n"
"vld1.f32 {d26-d27}, [%[rhs_ptr]]! \n"
"vmla.f32 %[c0], q11, d2[0] \n"
"vmla.f32 %[c1], q11, d2[1] \n"
"vmla.f32 %[c2], q11, d3[0] \n"
"vmla.f32 %[c3], q11, d3[1] \n"
"vld1.f32 {d8-d9}, [%[lhs_ptr]]! \n"
"vld1.f32 {d28-d29}, [%[rhs_ptr]]! \n"
"vmla.f32 %[c0], q12, d4[0] \n"
"vmla.f32 %[c1], q12, d4[1] \n"
"vmla.f32 %[c2], q12, d5[0] \n"
"vmla.f32 %[c3], q12, d5[1] \n"
"vld1.f32 {d10-d11}, [%[lhs_ptr]]! \n"
"vld1.f32 {d30-d31}, [%[rhs_ptr]]! \n"
"vmla.f32 %[c0], q13, d6[0] \n"
"vmla.f32 %[c1], q13, d6[1] \n"
"vmla.f32 %[c2], q13, d7[0] \n"
"vmla.f32 %[c3], q13, d7[1] \n"
"vld1.f32 {d0-d1}, [%[lhs_ptr]]! \n"
"vld1.f32 {d2-d3}, [%[lhs_ptr]]! \n"
"vld1.f32 {d20-d21}, [%[rhs_ptr]]! \n"
"vld1.f32 {d22-d23}, [%[rhs_ptr]]! \n"
"vmla.f32 %[c0], q14, d8[0] \n"
"vmla.f32 %[c1], q14, d8[1] \n"
"vmla.f32 %[c2], q14, d9[0] \n"
"vmla.f32 %[c3], q14, d9[1] \n"
"vmla.f32 %[c0], q15, d10[0] \n"
"vmla.f32 %[c1], q15, d10[1] \n"
"vmla.f32 %[c2], q15, d11[0] \n"
"vmla.f32 %[c3], q15, d11[1] \n"
"vmla.f32 %[c0], q10, d0[0] \n"
"vmla.f32 %[c1], q10, d0[1] \n"
"vmla.f32 %[c2], q10, d1[0] \n"
"vmla.f32 %[c3], q10, d1[1] \n"
"subs %[block_d], %[block_d], #1 \n"
"vmla.f32 %[c0], q11, d2[0] \n"
"vmla.f32 %[c1], q11, d2[1] \n"
"vmla.f32 %[c2], q11, d3[0] \n"
"vmla.f32 %[c3], q11, d3[1] \n"
"bne 0b \n"
: // outputs
[lhs_ptr] "+r"(lhs_ptr),
[rhs_ptr] "+r"(rhs_ptr),
[res_ptr] "+r"(res_ptr),
[block_d] "+r"(block_d),
[c0] "+w"(c0),
[c1] "+w"(c1),
[c2] "+w"(c2),
[c3] "+w"(c3)
: // inputs
: // clabbers
"cc", "memory",
"q0", "q1", "q2", "q3", "q4", "q5",
"q10", "q11", "q12", "q13", "q14", "q15");
}
#endif // __aarch64__
// d: 4
block_d = remain_d >> 2;
remain_d -= (block_d << 2);
for (index_t bd = 0; bd < block_d; ++bd) {
// 4.4.4
float32x4_t a0, a1, a2, a3;
float32x4_t b0, b1, b2, b3;
a0 = vld1q_f32(lhs_ptr);
a1 = vld1q_f32(lhs_ptr + 4);
a2 = vld1q_f32(lhs_ptr + 8);
a3 = vld1q_f32(lhs_ptr + 12);
b0 = vld1q_f32(rhs_ptr);
b1 = vld1q_f32(rhs_ptr + 4);
b2 = vld1q_f32(rhs_ptr + 8);
b3 = vld1q_f32(rhs_ptr + 12);
MACE_SGEMM_PART_CAL_R4_C4_D1(0); // d = 1
MACE_SGEMM_PART_CAL_R4_C4_D1(1); // d = 2
MACE_SGEMM_PART_CAL_R4_C4_D1(2);
MACE_SGEMM_PART_CAL_R4_C4_D1(3);
lhs_ptr += 16;
rhs_ptr += 16;
}
// d: remain
for (index_t d = 0; d < remain_d; ++d) {
// 4.1.4
float32x4_t a0;
float32x4_t b0;
a0 = vld1q_f32(lhs_ptr);
b0 = vld1q_f32(rhs_ptr);
MACE_SGEMM_PART_CAL_R4_C4_D1(0); // d = 1
lhs_ptr += 4;
rhs_ptr += 4;
}
vst1q_f32(res_ptr, c0);
vst1q_f32(res_ptr + 4, c1);
vst1q_f32(res_ptr + 8, c2);
vst1q_f32(res_ptr + 12, c3);
res_ptr += 16;
} // bh: 4
// h: 1
for (index_t h = 0; h < remain_h; ++h) {
const float *rhs_ptr = rhs_data + depth * (bw << 2);
index_t remain_d = depth;
index_t block_d = 0;
float32x4_t c0 = vdupq_n_f32(0.f);
// d: 8
block_d = remain_d >> 3;
remain_d -= (block_d << 3);
for (index_t bd = 0; bd < block_d; ++bd) {
// 1.8.4
float32x4_t a0, a1;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
a0 = vld1q_f32(lhs_ptr);
a1 = vld1q_f32(lhs_ptr + 4);
b0 = vld1q_f32(rhs_ptr);
b1 = vld1q_f32(rhs_ptr + 4);
b2 = vld1q_f32(rhs_ptr + 8);
b3 = vld1q_f32(rhs_ptr + 12);
b4 = vld1q_f32(rhs_ptr + 16);
b5 = vld1q_f32(rhs_ptr + 20);
b6 = vld1q_f32(rhs_ptr + 24);
b7 = vld1q_f32(rhs_ptr + 28);
MACE_SGEMM_PART_CAL_R1_C4_D8(0, 0, 1);
lhs_ptr += 8;
rhs_ptr += 32;
}
block_d = remain_d >> 2;
remain_d -= (block_d << 2);
// d: 4
for (index_t bd = 0; bd < block_d; ++bd) {
// 1.4.4
float32x4_t a0;
float32x4_t b0, b1, b2, b3;
a0 = vld1q_f32(lhs_ptr);
b0 = vld1q_f32(rhs_ptr);
b1 = vld1q_f32(rhs_ptr + 4);
b2 = vld1q_f32(rhs_ptr + 8);
b3 = vld1q_f32(rhs_ptr + 12);
MACE_SGEMM_PART_CAL_R1_C4_D4(0);
lhs_ptr += 4;
rhs_ptr += 16;
}
// d: remain
float s0 = 0;
float s1 = 0;
float s2 = 0;
float s3 = 0;
for (index_t d = 0; d < remain_d; ++d) {
// 1.1.4
s0 += lhs_ptr[0] * rhs_ptr[0];
s1 += lhs_ptr[0] * rhs_ptr[1];
s2 += lhs_ptr[0] * rhs_ptr[2];
s3 += lhs_ptr[0] * rhs_ptr[3];
lhs_ptr += 1;
rhs_ptr += 4;
}
float32x4_t c0_remain = {s0, s1, s2, s3};
c0 += c0_remain;
vst1q_f32(res_ptr, c0);
res_ptr += 4;
} // bh: remain
} // bw
#endif // MACE_ENABLE_NEON
// ========================== remain width ===========================
result_data += (width - remain_w) * height;
rhs_data += (width - remain_w) * depth;
// w: 1
#pragma omp parallel for
for (index_t bw = 0; bw < remain_w; ++bw) {
index_t remain_h = height;
const float *lhs_ptr = lhs_data;
float *res_ptr = result_data + height * bw;
#if defined(MACE_ENABLE_NEON)
index_t block_h = 0;
#if defined(__aarch64__)
block_h = remain_h >> 3;
remain_h -= (block_h << 3);
// h: 8
for (index_t bh = 0; bh < block_h; ++bh) {
const float *rhs_ptr = rhs_data + depth * bw;
index_t remain_d = depth;
float32x4_t c0, c1;
c0 = vdupq_n_f32(0.f);
c1 = vdupq_n_f32(0.f);
index_t block_d = remain_d >> 2;
remain_d -= (block_d << 2);
// d: 4
for (index_t bd = 0; bd < block_d; ++bd) {
// 8.4.1
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t a0;
b0 = vld1q_f32(lhs_ptr);
b1 = vld1q_f32(lhs_ptr + 4);
b2 = vld1q_f32(lhs_ptr + 8);
b3 = vld1q_f32(lhs_ptr + 12);
b4 = vld1q_f32(lhs_ptr + 16);
b5 = vld1q_f32(lhs_ptr + 20);
b6 = vld1q_f32(lhs_ptr + 24);
b7 = vld1q_f32(lhs_ptr + 28);
a0 = vld1q_f32(rhs_ptr);
MACE_SGEMM_PART_CAL_R1_C8_D4(0, 1, 0);
lhs_ptr += 32;
rhs_ptr += 4;
}
// d: remain
for (index_t d = 0; d < remain_d; ++d) {
// 8.1.1
float32x4_t b0, b1;
float32x4_t a0 = vdupq_n_f32(rhs_ptr[0]);
b0 = vld1q_f32(lhs_ptr);
b1 = vld1q_f32(lhs_ptr + 4);
c0 = vfmaq_laneq_f32(c0, b0, a0, 0);
c1 = vfmaq_laneq_f32(c1, b1, a0, 0);
lhs_ptr += 8;
rhs_ptr += 1;
}
vst1q_f32(res_ptr, c0);
vst1q_f32(res_ptr + 4, c1);
res_ptr += 8;
} // bh: 8
#endif
// h: 4
block_h = remain_h >> 2;
remain_h -= (block_h << 2);
for (index_t bh = 0; bh < block_h; ++bh) {
const float *rhs_ptr = rhs_data + depth * bw;
index_t remain_d = depth;
index_t block_d = 0;
float32x4_t c0 = vdupq_n_f32(0.f);
block_d = remain_d >> 2;
remain_d -= (block_d << 2);
// d: 4
for (index_t bd = 0; bd < block_d; ++bd) {
// 4.4.1
float32x4_t b0, b1, b2, b3;
float32x4_t a0;
b0 = vld1q_f32(lhs_ptr);
b1 = vld1q_f32(lhs_ptr + 4);
b2 = vld1q_f32(lhs_ptr + 8);
b3 = vld1q_f32(lhs_ptr + 12);
a0 = vld1q_f32(rhs_ptr);
MACE_SGEMM_PART_CAL_R1_C4_D4(0);
lhs_ptr += 16;
rhs_ptr += 4;
}
// d: remain
for (index_t d = 0; d < remain_d; ++d) {
// 4.1.1
float32x4_t b0, b1;
float32x2_t a0 = vdup_n_f32(rhs_ptr[0]);
b0 = vld1q_f32(lhs_ptr);
c0 = vmlaq_lane_f32(c0, b0, a0, 0);
lhs_ptr += 4;
rhs_ptr += 1;
}
vst1q_f32(res_ptr, c0);
res_ptr += 4;
} // bh: 4
#endif // MACE_ENABLE_NEON
// h: 1
for (index_t h = 0; h < remain_h; ++h) {
const float *rhs_ptr = rhs_data + depth * bw;
index_t remain_d = depth;
float sum = 0.f;
#if defined(MACE_ENABLE_NEON)
index_t block_d = 0;
float32x4_t c0, c1;
c0 = vdupq_n_f32(0.f);
c1 = vdupq_n_f32(0.f);
block_d = remain_d >> 3;
remain_d -= (block_d << 3);
// d: 8
for (index_t bd = 0; bd < block_d; ++bd) {
// 1.8.1
float32x4_t a0, a1;
float32x4_t b0, b1;
a0 = vld1q_f32(lhs_ptr);
a1 = vld1q_f32(lhs_ptr + 4);
b0 = vld1q_f32(rhs_ptr);
b1 = vld1q_f32(rhs_ptr + 4);
c0 = vmlaq_f32(c0, a0, b0);
c1 = vmlaq_f32(c1, a1, b1);
lhs_ptr += 8;
rhs_ptr += 8;
}
block_d = remain_d >> 2;
remain_d -= (block_d << 2);
// d: 4
for (index_t bd = 0; bd < block_d; ++bd) {
// 1.4.1
float32x4_t a0;
float32x4_t b0;
a0 = vld1q_f32(lhs_ptr);
b0 = vld1q_f32(rhs_ptr);
c0 = vmlaq_f32(c0, a0, b0);
lhs_ptr += 4;
rhs_ptr += 4;
}
sum += vaddvq_f32(c0);
sum += vaddvq_f32(c1);
#endif // MACE_ENABLE_NEON
// d: remain
for (index_t d = 0; d < remain_d; ++d) {
// 1.1.1
sum += lhs_ptr[0] * rhs_ptr[0];
lhs_ptr += 1;
rhs_ptr += 1;
}
*res_ptr = sum;
++res_ptr;
} // bh: remain
} // bw
}
-void SGemm::PackLhs(const MatrixMap<float> &lhs,
-                    PackedBlock<float> *packed_block) {
+void SGemm::PackLhs(const MatrixMap<const float> &lhs,
+                    PackedBlock *packed_block) {
  Pack(lhs, PackOrder::ColMajor, packed_block);
}
-void SGemm::PackRhs(const MatrixMap<float> &rhs,
-                    PackedBlock<float> *packed_block) {
+void SGemm::PackRhs(const MatrixMap<const float> &rhs,
+                    PackedBlock *packed_block) {
  Pack(rhs, PackOrder::RowMajor, packed_block);
}
void SGemm::Pack(const MatrixMap<const float> &src,
const PackOrder order,
PackedBlock *packed_block) {
MACE_CHECK_NOTNULL(packed_block);
const index_t height = src.row();
const index_t width = src.col();
auto packed_data = packed_block->mutable_data<float>();
#define MACE_SGEMM_PACK_PER_BATCH \
for (index_t b = 0; b < src.batch(); ++b) { \
PackPerBatch(src, order, b, packed_data + b * height * width); \
}
if (src.batch() >= MaceOpenMPThreadCount) {
#pragma omp parallel for
MACE_SGEMM_PACK_PER_BATCH
} else {
MACE_SGEMM_PACK_PER_BATCH
}
#undef MACE_SGEMM_PACK_PER_BATCH
}
-void SGemm::UnPack(const PackedBlock<float> &packed_result,
-                   MatrixMap<float> *matrix_map) {
-  (void) packed_result;
-  (void) matrix_map;
-}

void SGemm::UnPack(const PackedBlock &packed_result,
                   MatrixMap<float> *matrix_map) {
  MACE_CHECK_NOTNULL(matrix_map);
const index_t height = matrix_map->row();
const index_t width = matrix_map->col();
auto packed_data = packed_result.data<float>();
#define MACE_SGEMM_UNPACK_PER_BATCH \
for (index_t b = 0; b < matrix_map->batch(); ++b) { \
UnPackPerBatch(packed_data + b * height * width, b, matrix_map); \
}
if (matrix_map->batch() >= MaceOpenMPThreadCount) {
#pragma omp parallel for
MACE_SGEMM_UNPACK_PER_BATCH
} else {
MACE_SGEMM_UNPACK_PER_BATCH
}
#undef MACE_SGEMM_UNPACK_PER_BATCH
}
-void SGemm::Pack(const MatrixMap<float> &src,
-                 const PackOrder order,
-                 PackedBlock<float> *packed_block) {
-  (void) src;
-  (void) order;
-  (void) packed_block;
-}

void SGemm::PackPerBatch(const MatrixMap<const float> &src,
                         const PackOrder order,
                         const index_t batch_index,
                         float *packed_data) {
  MACE_CHECK_NOTNULL(packed_data);
const index_t height = src.row();
const index_t width = src.col();
auto src_data = src.batch_data(batch_index);
if (src.major() == Major::RowMajor && order == PackOrder::ColMajor) {
// This is for packing no-transpose lhs.
index_t h = 0;
#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)
#pragma omp parallel for
for (index_t ih = h; ih <= height - 8; ih += 8) {
const float *src_data_ptr = src_data + ih * width;
float *packed_data_ptr = packed_data + ih * width;
for (index_t w = 0; w < width; ++w) {
const index_t src_offset = w;
const index_t packed_offset = w * 8;
float32x4_t vs0 = {src_data_ptr[src_offset],
src_data_ptr[src_offset + width],
src_data_ptr[src_offset + 2 * width],
src_data_ptr[src_offset + 3 * width]};
float32x4_t vs1 = {src_data_ptr[src_offset + 4 * width],
src_data_ptr[src_offset + 5 * width],
src_data_ptr[src_offset + 6 * width],
src_data_ptr[src_offset + 7 * width]};
vst1q_f32(packed_data_ptr + packed_offset, vs0);
vst1q_f32(packed_data_ptr + packed_offset + 4, vs1);
}
}
h += (height - h) / 8 * 8;
#endif
#pragma omp parallel for
for (index_t ih = h; ih <= height - 4; ih += 4) {
const float *src_data_ptr = src_data + ih * width;
float *packed_data_ptr = packed_data + ih * width;
for (index_t w = 0; w < width; ++w) {
const index_t src_offset = w;
const index_t packed_offset = w * 4;
float32x4_t vs = {src_data_ptr[src_offset],
src_data_ptr[src_offset + width],
src_data_ptr[src_offset + 2 * width],
src_data_ptr[src_offset + 3 * width]};
vst1q_f32(packed_data_ptr + packed_offset, vs);
}
}
h += (height - h) / 4 * 4;
#endif
#pragma omp parallel for
for (index_t ih = h; ih < height; ++ih) {
std::copy_n(src_data + ih * width, width, packed_data + ih * width);
}
} else if (src.major() == Major::ColMajor && order == PackOrder::ColMajor) {
// This is for packing transpose-needed lhs.
index_t h = 0;
#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)
#pragma omp parallel for
for (index_t ih = h; ih <= height - 8; ih += 8) {
const float *src_data_ptr = src_data + ih;
float *packed_data_ptr = packed_data + ih * width;
for (index_t w = 0; w < width; ++w) {
const index_t src_offset = w * height;
const index_t packed_offset = w * 8;
float32x4_t vs0 = vld1q_f32(src_data_ptr + src_offset);
float32x4_t vs1 = vld1q_f32(src_data_ptr + src_offset + 4);
vst1q_f32(packed_data_ptr + packed_offset, vs0);
vst1q_f32(packed_data_ptr + packed_offset + 4, vs1);
}
}
h += (height - h) / 8 * 8;
#endif
#pragma omp parallel for
for (index_t ih = h; ih <= height - 4; ih += 4) {
const float *src_data_ptr = src_data + ih;
float *packed_data_ptr = packed_data + ih * width;
for (index_t w = 0; w < width; ++w) {
const index_t src_offset = w * height;
const index_t packed_offset = w * 4;
float32x4_t vs = vld1q_f32(src_data_ptr + src_offset);
vst1q_f32(packed_data_ptr + packed_offset, vs);
}
}
h += (height - h) / 4 * 4;
#endif
#pragma omp parallel for
for (index_t ih = h; ih < height; ++ih) {
const float *src_data_ptr = src_data + ih;
float *packed_data_ptr = packed_data + ih * width;
for (index_t w = 0; w < width; ++w) {
packed_data_ptr[w] = src_data_ptr[w * height];
}
}
} else if (src.major() == Major::RowMajor && order == PackOrder::RowMajor) {
// This is for packing no-transpose rhs.
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
#pragma omp parallel for
for (index_t iw = w; iw <= width - 4; iw += 4) {
const float *src_data_ptr = src_data + iw;
float *packed_data_ptr = packed_data + iw * height;
for (index_t h = 0; h < height; ++h) {
const index_t src_offset = h * width;
const index_t packed_offset = h * 4;
float32x4_t vs = vld1q_f32(src_data_ptr + src_offset);
vst1q_f32(packed_data_ptr + packed_offset, vs);
}
}
w += (width - w) / 4 * 4;
#endif
#pragma omp parallel for
for (index_t iw = w; iw < width; ++iw) {
const float *src_data_ptr = src_data + iw;
float *packed_data_ptr = packed_data + iw * height;
for (index_t h = 0; h < height; ++h) {
packed_data_ptr[h] = src_data_ptr[h * width];
}
}
} else if (src.major() == Major::ColMajor && order == PackOrder::RowMajor) {
// This is for packing transpose-needed rhs.
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
#pragma omp parallel for
for (index_t iw = w; iw <= width - 4; iw += 4) {
const float *src_data_ptr = src_data + iw * height;
float *packed_data_ptr = packed_data + iw * height;
for (index_t h = 0; h < height; ++h) {
const index_t src_offset = h;
const index_t packed_offset = h * 4;
float32x4_t vs = {src_data_ptr[src_offset],
src_data_ptr[src_offset + height],
src_data_ptr[src_offset + 2 * height],
src_data_ptr[src_offset + 3 * height]};
vst1q_f32(packed_data_ptr + packed_offset, vs);
}
}
w += (width - w) / 4 * 4;
#endif
#pragma omp parallel for
for (index_t iw = w; iw < width; ++iw) {
std::copy_n(src_data + iw * height, height, packed_data + iw * height);
}
}
}
void SGemm::UnPackPerBatch(const float *packed_data,
const index_t batch_index,
MatrixMap<float> *matrix_map) {
MACE_CHECK_NOTNULL(matrix_map);
const index_t height = matrix_map->row();
const index_t width = matrix_map->col();
auto unpacked_data = matrix_map->batch_data(batch_index);
if (matrix_map->major() == Major::RowMajor) {
// This is for non-transposed result
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
#pragma omp parallel for
for (index_t iw = w; iw <= width - 4; iw += 4) {
const float *packed_data_ptr = packed_data + iw * height;
float *unpacked_data_ptr = unpacked_data + iw;
for (index_t h = 0; h < height; ++h) {
const index_t packed_offset = h * 4;
const index_t unpacked_offset = h * width;
float32x4_t vs = vld1q_f32(packed_data_ptr + packed_offset);
vst1q_f32(unpacked_data_ptr + unpacked_offset, vs);
}
}
w += (width - w) / 4 * 4;
#endif
#pragma omp parallel for
for (index_t iw = w; iw < width; ++iw) {
const float *packed_data_ptr = packed_data + iw * height;
float *unpacked_data_ptr = unpacked_data + iw;
for (index_t h = 0; h < height; ++h) {
unpacked_data_ptr[h * width] = packed_data_ptr[h];
}
}
} else {
// This is for transposed result
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
#pragma omp parallel for
for (index_t iw = w; iw <= width - 4; iw += 4) {
const float *packed_data_ptr = packed_data + iw * height;
float *unpacked_data_ptr = unpacked_data + iw * height;
for (index_t h = 0; h < height; ++h) {
const index_t packed_offset = h * 4;
const index_t unpacked_offset = h;
float32x4_t vs = vld1q_f32(packed_data_ptr + packed_offset);
unpacked_data_ptr[unpacked_offset] = vs[0];
unpacked_data_ptr[unpacked_offset + height] = vs[1];
unpacked_data_ptr[unpacked_offset + 2 * height] = vs[2];
unpacked_data_ptr[unpacked_offset + 3 * height] = vs[3];
}
}
w += (width - w) / 4 * 4;
#endif
#pragma omp parallel for
for (index_t iw = w; iw < width; ++iw) {
std::copy_n(
packed_data + iw * height, height, unpacked_data + iw * height);
}
}
}
} // namespace kernels
...
@@ -15,6 +15,9 @@
#ifndef MACE_KERNELS_SGEMM_H_
#define MACE_KERNELS_SGEMM_H_

+#include <memory>
+#include <utility>

#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
@@ -34,22 +37,29 @@ enum Major {
template<typename T>
class MatrixMap {
 public:
-  MatrixMap(const index_t row,
+  MatrixMap() {}
+
+  MatrixMap(const index_t batch,
+            const index_t row,
            const index_t col,
            const Major major,
-            T *data) :
+            T *data,
+            const bool is_const = false) :
+      batch_(batch),
      row_(row),
      col_(col),
      stride_(major == RowMajor ? col : row),
      major_(major),
-      data_(data) {}
+      data_(data),
+      is_const_(is_const) {}

-  MatrixMap<T> transpose(const MatrixMap<T> &matrix_map) {
-    Major transpose_major = matrix_map.major_ == RowMajor ? ColMajor : RowMajor;
-    return MatrixMap<T>(matrix_map.col_,
-                        matrix_map.row_,
-                        transpose_major,
-                        matrix_map.data_);
+  MatrixMap transpose() const {
+    Major transpose_major = major_ == RowMajor ? ColMajor : RowMajor;
+    return MatrixMap(batch_, col_, row_, transpose_major, data_, is_const_);
+  }
+
+  index_t batch() const {
+    return batch_;
  }

  index_t row() const {
@@ -72,66 +82,100 @@ class MatrixMap {
    return data_;
  }

-  T *data(int row, int col) const {
-    return data_ + row * stride_ + col;
-  }
+  T *batch_data(index_t batch) const {
+    return data_ + batch * row_ * col_;
}
index_t size() const {
return batch_ * row_ * col_;
}
bool is_const() const {
return is_const_;
} }
 private:
  index_t batch_;
  index_t row_;
  index_t col_;
  index_t stride_;
  Major major_;
  T *data_;
  bool is_const_;
};

typedef Major PackOrder;

-template<typename T>
-class PackedBlock {
- public:
-  PackedBlock() : data_tensor_(GetCPUAllocator(),
-                               DataTypeToEnum<T>::v()) {}
-
-  const T *data() {
-    return data_tensor_.data<T>();
-  }
-
-  T *mutable_data() {
-    return data_tensor_.mutable_data<T>();
-  }
-
-  Tensor *tensor() {
-    return &data_tensor_;
-  }
-
- private:
-  Tensor data_tensor_;
-};
+typedef Tensor PackedBlock;
class SGemm {
 public:
-  void operator()(const MatrixMap<float> &lhs,
-                  const MatrixMap<float> &rhs,
-                  MatrixMap<float> *result);
-
-  void operator()(const PackedBlock<float> &lhs,
-                  const PackedBlock<float> &rhs,
-                  const index_t height,
-                  const index_t depth,
-                  const index_t width,
-                  PackedBlock<float> *result);
-
-  void PackLhs(const MatrixMap<float> &lhs, PackedBlock<float> *packed_block);
-
-  void PackRhs(const MatrixMap<float> &rhs, PackedBlock<float> *packed_block);
-
-  void UnPack(const PackedBlock<float> &packed_result,
-              MatrixMap<float> *matrix_map);
+  SGemm()
+      : packed_lhs_(nullptr),
+        packed_rhs_(nullptr),
+        packed_(false) {}
+
+  void operator()(const MatrixMap<const float> &lhs,
+                  const MatrixMap<const float> &rhs,
+                  MatrixMap<float> *result,
+                  ScratchBuffer *scratch_buffer = nullptr);
+
+  void Run(const float *A,
+           const float *B,
+           const index_t batch,
+           const index_t height_a,
+           const index_t width_a,
+           const index_t height_b,
+           const index_t width_b,
+           const bool transpose_a,
+           const bool transpose_b,
+           const bool is_a_weight,
+           const bool is_b_weight,
+           float *C,
+           ScratchBuffer *scratch_buffer = nullptr);
+
+  void PackLhs(const MatrixMap<const float> &lhs,
+               PackedBlock *packed_block);
+
+  void PackRhs(const MatrixMap<const float> &rhs,
+               PackedBlock *packed_block);
+
+  void UnPack(const PackedBlock &packed_result,
+              MatrixMap<float> *matrix_map);

 private:
-  void Pack(const MatrixMap<float> &src,
+  void Pack(const MatrixMap<const float> &src,
            const PackOrder order,
-            PackedBlock<float> *packed_block);
+            PackedBlock *packed_block);
void PackPerBatch(const MatrixMap<const float> &src,
const PackOrder order,
const index_t batch_index,
float *packed_data);
void UnPackPerBatch(const float *packed_data,
const index_t batch_index,
MatrixMap<float> *matrix_map);
void RunInternal(const PackedBlock &lhs,
const PackedBlock &rhs,
const index_t batch,
const index_t height,
const index_t depth,
const index_t width,
PackedBlock *result);
void RunPerBatch(const float *lhs,
const float *rhs,
const index_t height,
const index_t depth,
const index_t width,
float *result);
std::unique_ptr<Tensor> packed_lhs_;
std::unique_ptr<Tensor> packed_rhs_;
std::unique_ptr<Tensor> packed_result_;
bool packed_;
};

} // namespace kernels
...
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <random>
#include <vector>
#include "mace/kernels/sgemm.h"
namespace mace {
namespace kernels {
namespace test {
namespace {
void TestPack(const std::vector<float> &data,
const std::vector<float> &expected_data,
const index_t height,
const index_t width,
Major src_order,
PackOrder pack_order) {
SGemm sg;
MatrixMap<const float> src_matrix(1, height, width, src_order, data.data());
PackedBlock packed;
packed.Resize({height, width});
if (pack_order == PackOrder::ColMajor) {
sg.PackLhs(src_matrix, &packed);
} else {
sg.PackRhs(src_matrix, &packed);
}
auto packed_data = packed.data<float>();
for (index_t i = 0; i < packed.size(); ++i) {
EXPECT_EQ(expected_data[i], packed_data[i]);
}
}
void TestUnPack(const index_t height,
const index_t width,
Major src_order,
PackOrder pack_order) {
static auto seed = static_cast<unsigned int>(time(nullptr));
const index_t matrix_size = height * width;
std::vector<float> data(matrix_size);
for (int i = 0; i < matrix_size; ++i) {
data[i] = rand_r(&seed);
}
MatrixMap<const float> src_matrix(1, height, width, src_order, data.data());
PackedBlock packed;
packed.Resize({height, width});
SGemm sg;
if (pack_order == PackOrder::ColMajor) {
sg.PackLhs(src_matrix, &packed);
} else {
sg.PackRhs(src_matrix, &packed);
}
std::vector<float> unpacked(matrix_size);
MatrixMap<float>
unpacked_matrix(1, height, width, src_order, unpacked.data());
sg.UnPack(packed, &unpacked_matrix);
auto unpacked_data = unpacked.data();
for (index_t i = 0; i < packed.size(); ++i) {
EXPECT_EQ(data[i], unpacked_data[i]);
}
}
} // namespace
TEST(SGemmPackTest, Pack) {
std::vector<float> data =
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36};
// For no-transpose lhs
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
3, 4, Major::RowMajor, PackOrder::ColMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16},
4, 4, Major::RowMajor, PackOrder::ColMajor);
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16, 17, 18, 19,
20},
5, 4, Major::RowMajor, PackOrder::ColMajor);
#if defined(__aarch64__)
TestPack(data,
{1, 5, 9, 13, 17, 21, 25, 29, 2, 6, 10, 14, 18, 22, 26, 30, 3, 7, 11,
15, 19, 23, 27, 31, 4, 8, 12, 16, 20, 24, 28, 32, 33, 34, 35, 36},
9, 4, Major::RowMajor, PackOrder::ColMajor);
#endif
#endif
// For transpose-needed lhs
TestPack(data,
{1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12},
3, 4, Major::ColMajor, PackOrder::ColMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
4, 4, Major::ColMajor, PackOrder::ColMajor);
TestPack(data,
{1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 5, 10, 15,
20},
5, 4, Major::ColMajor, PackOrder::ColMajor);
#if defined(__aarch64__)
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21,
22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 9, 18, 27, 36},
9, 4, Major::ColMajor, PackOrder::ColMajor);
#endif
#endif
// For no-transpose rhs
TestPack(data,
{1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12},
4, 3, Major::RowMajor, PackOrder::RowMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
4, 4, Major::RowMajor, PackOrder::RowMajor);
TestPack(data,
{1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 5, 10, 15,
20},
4, 5, Major::RowMajor, PackOrder::RowMajor);
#endif
// For transpose-needed rhs
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
4, 3, Major::ColMajor, PackOrder::RowMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16},
4, 4, Major::ColMajor, PackOrder::RowMajor);
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16, 17, 18, 19,
20},
4, 5, Major::ColMajor, PackOrder::RowMajor);
#endif
}
TEST(SGemmPackTest, UnPack) {
TestUnPack(4, 3, Major::RowMajor, PackOrder::RowMajor);
TestUnPack(4, 4, Major::RowMajor, PackOrder::RowMajor);
TestUnPack(4, 5, Major::RowMajor, PackOrder::RowMajor);
TestUnPack(4, 100, Major::RowMajor, PackOrder::RowMajor);
TestUnPack(4, 3, Major::ColMajor, PackOrder::RowMajor);
TestUnPack(4, 4, Major::ColMajor, PackOrder::RowMajor);
TestUnPack(4, 5, Major::ColMajor, PackOrder::RowMajor);
TestUnPack(4, 100, Major::ColMajor, PackOrder::RowMajor);
}
} // namespace test
} // namespace kernels
} // namespace mace
@@ -40,7 +40,11 @@ class MatMulOp : public Operator<D, T> {
               "than or equal to 2");
    index_t rank = A->dim_size();
    for (index_t i = 0; i < rank - 2; ++i) {
-      MACE_CHECK(A->dim(i) == B->dim(i), "batch dimensions are not equal");
+      MACE_CHECK(A->dim(i) == B->dim(i),
+                 "batch dimensions are not equal: ",
+                 A->dim(i),
+                 " vs. ",
+                 B->dim(i));
    }
    index_t ak = transpose_a_ ? A->dim(rank - 2) : A->dim(rank - 1);
    index_t bk = transpose_b_ ? B->dim(rank - 1) : B->dim(rank - 2);
...
@@ -33,13 +33,15 @@ void MatMulBenchmark(
  // Add input data
  net.AddRandomInput<D, T>("A", {batch, height, channels});
  net.AddRandomInput<D, T>("B", {batch, channels, out_width});
net.GetTensor("A")->SetIsWeight(true);
net.GetTensor("B")->SetIsWeight(true);
  if (DataTypeToEnum<T>::value == DT_UINT8) {
    net.GetTensor("A")->SetScale(0.1);
    net.GetTensor("B")->SetScale(0.1);
  }

  if (D == DeviceType::GPU) {
-    BufferToImage<D, T>(&net, "A", "AImage", kernels::BufferType::IN_OUT_WIDTH);
+    BufferToImage<D, T>(&net, "A", "AImage",
+                        kernels::BufferType::IN_OUT_WIDTH);
    BufferToImage<D, T>(&net, "B", "BImage",
                        kernels::BufferType::IN_OUT_HEIGHT);

@@ -71,7 +73,7 @@ void MatMulBenchmark(
  mace::testing::StartTiming();
  while (iters--) {
-    net.RunOp(D);
+    net.Run();
  }
  net.Sync();
}

@@ -86,6 +88,8 @@ void MatMulTransposeBenchmark(
  // Add input data
  net.AddRandomInput<D, T>("A", {batch, height, channels});
  net.AddRandomInput<D, T>("B", {batch, out_width, channels});
net.GetTensor("A")->SetIsWeight(true);
net.GetTensor("B")->SetIsWeight(true);
  if (DataTypeToEnum<T>::value == DT_UINT8) {
    net.GetTensor("A")->SetScale(0.1);
    net.GetTensor("B")->SetScale(0.1);

@@ -116,7 +120,7 @@ void MatMulTransposeBenchmark(
  mace::testing::StartTiming();
  while (iters--) {
-    net.RunOp(D);
+    net.Run();
  }
  net.Sync();
}

@@ -154,10 +158,15 @@ void MatMulTransposeBenchmark(
  MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, float, CPU); \
  MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, uint8_t, CPU);
MACE_BM_MATMUL(1, 128, 128, 49);
MACE_BM_MATMUL(2, 128, 128, 49);
MACE_BM_MATMUL(3, 128, 128, 49);
MACE_BM_MATMUL(4, 128, 128, 49);
MACE_BM_MATMUL(16, 32, 128, 49);
MACE_BM_MATMUL(16, 32, 128, 961);
MACE_BM_MATMUL(16, 32, 128, 3969);
MACE_BM_MATMUL(16, 128, 128, 49);
MACE_BM_MATMUL(16, 49, 128, 128);
MACE_BM_MATMUL(16, 128, 128, 961);
MACE_BM_MATMUL(16, 128, 128, 3969);
...
@@ -211,8 +211,8 @@ void WinoMatMulBenchmark(
  const index_t round_w = (width + block_size - 1) / block_size;
  const index_t out_width = round_h * round_w;
  // Add input data
-  net.AddRandomInput<D, float>("A", {batch, out_channels, in_channels, 1});
-  net.AddRandomInput<D, float>("B", {batch, in_channels, out_width, 1});
+  net.AddRandomInput<D, float>("A", {batch, out_channels, in_channels});
+  net.AddRandomInput<D, float>("B", {batch, in_channels, out_width});
  if (D == DeviceType::GPU) {
    BufferToImage<D, T>(&net, "A", "AImage", kernels::BufferType::IN_OUT_WIDTH);
...