From cc85f82c6a77aaf4aabbf619ce9d3dc61670d460 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Tue, 28 Aug 2018 12:24:26 +0800 Subject: [PATCH] Init sgemm --- mace/kernels/sgemm.cc | 92 +++++++++++++++++++++++++++ mace/kernels/sgemm.h | 140 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 232 insertions(+) create mode 100644 mace/kernels/sgemm.cc create mode 100644 mace/kernels/sgemm.h diff --git a/mace/kernels/sgemm.cc b/mace/kernels/sgemm.cc new file mode 100644 index 00000000..ae9a4e0f --- /dev/null +++ b/mace/kernels/sgemm.cc @@ -0,0 +1,92 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "mace/kernels/sgemm.h" + +#if defined(MACE_ENABLE_NEON) +#include +#endif + +namespace mace { +namespace kernels { + +void SGemm::operator()(const MatrixMap &lhs, + const MatrixMap &rhs, + MatrixMap *result) { + PackedBlock packed_lhs; + PackLhs(lhs, &packed_lhs); + + PackedBlock packed_rhs; + PackRhs(rhs, &packed_rhs); + + PackedBlock packed_result; + operator()(packed_lhs, + packed_rhs, + lhs.row(), + lhs.col(), + rhs.col(), + &packed_result); + UnPack(packed_result, result); +} + +void SGemm::operator()(const PackedBlock &lhs, + const PackedBlock &rhs, + const index_t height, + const index_t depth, + const index_t width, + PackedBlock *result) { + (void) lhs; + (void) rhs; + (void) result; + (void) height; + (void) depth; + (void) width; + + // (8, 8) * (8, 4) + + // (4, 4) * (4, 4) + + // remain +} + +void SGemm::PackLhs(const MatrixMap &lhs, + PackedBlock *packed_block) { + Pack(lhs, PackOrder::ColMajor, packed_block); +} + +void SGemm::PackRhs(const MatrixMap &rhs, + PackedBlock *packed_block) { + Pack(rhs, PackOrder::RowMajor, packed_block); +} + +void SGemm::UnPack(const PackedBlock &packed_result, + MatrixMap *matrix_map) { + (void) packed_result; + (void) matrix_map; +} + +void SGemm::Pack(const MatrixMap &src, + const PackOrder order, + PackedBlock *packed_block) { + (void) src; + (void) order; + (void) packed_block; +} + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/sgemm.h b/mace/kernels/sgemm.h new file mode 100644 index 00000000..15cec1dd --- /dev/null +++ b/mace/kernels/sgemm.h @@ -0,0 +1,140 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_KERNELS_SGEMM_H_ +#define MACE_KERNELS_SGEMM_H_ + +#if defined(MACE_ENABLE_NEON) +#include +#endif + +#include "mace/core/types.h" +#include "mace/core/allocator.h" +#include "mace/core/tensor.h" + +namespace mace { +namespace kernels { + +enum Major { + RowMajor, + ColMajor +}; + +template +class MatrixMap { + public: + MatrixMap(const index_t row, + const index_t col, + const Major major, + T *data) : + row_(row), + col_(col), + stride_(major == RowMajor ? col : row), + major_(major), + data_(data) {} + + MatrixMap transpose(const MatrixMap &matrix_map) { + Major transpose_major = matrix_map.major_ == RowMajor ? ColMajor : RowMajor; + return MatrixMap(matrix_map.col_, + matrix_map.row_, + transpose_major, + matrix_map.data_); + } + + index_t row() const { + return row_; + } + + index_t col() const { + return col_; + } + + index_t stride() const { + return stride_; + } + + Major major() const { + return major_; + } + + T *data() const { + return data_; + } + + T *data(int row, int col) const { + return data_ + row * stride_ + col; + } + + private: + index_t row_; + index_t col_; + index_t stride_; + Major major_; + T *data_; +}; + +typedef Major PackOrder; + +template +class PackedBlock { + public: + PackedBlock() : data_tensor_(GetDeviceAllocator(CPU), + DataTypeToEnum::v()) {} + + const T *data() { + return data_tensor_.data(); + } + + T *mutable_data() { + return data_tensor_.mutable_data(); + } + + Tensor *tensor() { + return &data_tensor_; + } + + private: + Tensor data_tensor_; +}; + +class SGemm { + public: + void operator()(const MatrixMap &lhs, + const MatrixMap &rhs, + MatrixMap *result); + + void operator()(const PackedBlock &lhs, + const PackedBlock &rhs, + const index_t height, + const index_t depth, + const index_t width, + PackedBlock *result); + + void PackLhs(const MatrixMap &lhs, PackedBlock *packed_block); + + void PackRhs(const MatrixMap &rhs, PackedBlock *packed_block); + + void UnPack(const PackedBlock &packed_result, + MatrixMap *matrix_map); + + private: + void Pack(const MatrixMap &src, + const PackOrder order, + PackedBlock *packed_block); +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_SGEMM_H_ -- GitLab