提交 f5829926 编写于 作者: 李滨

Merge branch 'sgemm' into 'master'

Init sgemm

See merge request !770
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cstring>
#include <vector>
#include "mace/kernels/sgemm.h"
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
namespace mace {
namespace kernels {
// Multiplies two matrix views. Both operands are packed, the packed-block
// overload is invoked with the dimensions read from the views, and the packed
// result is then unpacked into *result.
void SGemm::operator()(const MatrixMap<float> &lhs,
                       const MatrixMap<float> &rhs,
                       MatrixMap<float> *result) {
  // GEMM dimensions as reported by the operand views.
  const index_t height = lhs.row();
  const index_t depth = lhs.col();
  const index_t width = rhs.col();

  PackedBlock<float> lhs_block;
  PackedBlock<float> rhs_block;
  PackedBlock<float> result_block;

  PackLhs(lhs, &lhs_block);
  PackRhs(rhs, &rhs_block);

  // Dispatch to the overload that works on packed blocks.
  (*this)(lhs_block, rhs_block, height, depth, width, &result_block);

  UnPack(result_block, result);
}
// Packed-block overload. The body is currently a placeholder: it performs no
// computation and leaves *result untouched.
void SGemm::operator()(const PackedBlock<float> &lhs,
                       const PackedBlock<float> &rhs,
                       const index_t height,
                       const index_t depth,
                       const index_t width,
                       PackedBlock<float> *result) {
  // Reference every parameter once (in signature order) so that
  // unused-parameter warnings stay quiet until the kernel is written.
  static_cast<void>(lhs);
  static_cast<void>(rhs);
  static_cast<void>(height);
  static_cast<void>(depth);
  static_cast<void>(width);
  static_cast<void>(result);

  // Outline left by the author -- presumably the planned kernel tile shapes
  // followed by a remainder path (TODO confirm):
  //   (8, 8) * (8, 4)
  //   (4, 4) * (4, 4)
  //   remain
}
void SGemm::PackLhs(const MatrixMap<float> &lhs,
PackedBlock<float> *packed_block) {
Pack(lhs, PackOrder::ColMajor, packed_block);
}
void SGemm::PackRhs(const MatrixMap<float> &rhs,
PackedBlock<float> *packed_block) {
Pack(rhs, PackOrder::RowMajor, packed_block);
}
// Intended to copy a packed result back into a matrix view. Currently a
// placeholder: nothing is read from packed_result or written to *matrix_map.
void SGemm::UnPack(const PackedBlock<float> &packed_result,
                   MatrixMap<float> *matrix_map) {
  // Parameters are only referenced to silence unused-parameter warnings.
  static_cast<void>(packed_result);
  static_cast<void>(matrix_map);
}
// Shared packing routine behind PackLhs() and PackRhs(). Currently a
// placeholder: no data is read from src or written to *packed_block.
void SGemm::Pack(const MatrixMap<float> &src,
                 const PackOrder order,
                 PackedBlock<float> *packed_block) {
  // Parameters are only referenced to silence unused-parameter warnings.
  static_cast<void>(src);
  static_cast<void>(order);
  static_cast<void>(packed_block);
}
} // namespace kernels
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_SGEMM_H_
#define MACE_KERNELS_SGEMM_H_
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include "mace/core/types.h"
#include "mace/core/allocator.h"
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
// Storage order of a dense matrix. MatrixMap derives its stride from this
// value: the column count for RowMajor, the row count for ColMajor.
enum Major {
  RowMajor = 0,
  ColMajor = 1
};
template<typename T>
class MatrixMap {
public:
MatrixMap(const index_t row,
const index_t col,
const Major major,
T *data) :
row_(row),
col_(col),
stride_(major == RowMajor ? col : row),
major_(major),
data_(data) {}
MatrixMap<T> transpose(const MatrixMap<T> &matrix_map) {
Major transpose_major = matrix_map.major_ == RowMajor ? ColMajor : RowMajor;
return MatrixMap<T>(matrix_map.col_,
matrix_map.row_,
transpose_major,
matrix_map.data_);
}
index_t row() const {
return row_;
}
index_t col() const {
return col_;
}
index_t stride() const {
return stride_;
}
Major major() const {
return major_;
}
T *data() const {
return data_;
}
T *data(int row, int col) const {
return data_ + row * stride_ + col;
}
private:
index_t row_;
index_t col_;
index_t stride_;
Major major_;
T *data_;
};
// PackOrder is another name for Major; SGemm::Pack() takes it to select the
// order in which a matrix is packed.
using PackOrder = Major;
// Buffer holding a packed matrix. The storage is a Tensor member constructed
// with the CPU device allocator and the data-type enum that corresponds to T.
template<typename T>
class PackedBlock {
 public:
  PackedBlock() : data_tensor_(GetDeviceAllocator(CPU),
                               DataTypeToEnum<T>::v()) {}

  // Read-only pointer to the packed elements. Const-qualified so it can be
  // called through the const PackedBlock references that
  // SGemm::operator() receives.
  const T *data() const {
    return data_tensor_.data<T>();
  }

  // Writable pointer to the packed elements.
  T *mutable_data() {
    return data_tensor_.mutable_data<T>();
  }

  // Underlying tensor; the block keeps ownership of it.
  Tensor *tensor() {
    return &data_tensor_;
  }

 private:
  Tensor data_tensor_;
};
// Single-precision matrix multiplication functor. Operands are packed into
// PackedBlock buffers, multiplied in packed form and unpacked into the result
// view.
class SGemm {
 public:
  // Multiplies lhs by rhs and unpacks the product into *result. Packs both
  // operands, then calls the packed overload below with
  // height = lhs.row(), depth = lhs.col() and width = rhs.col().
  void operator()(const MatrixMap<float> &lhs,
                  const MatrixMap<float> &rhs,
                  MatrixMap<float> *result);

  // Multiplies two already-packed operands into *result.
  // height, depth and width are the GEMM dimensions supplied by the caller.
  void operator()(const PackedBlock<float> &lhs,
                  const PackedBlock<float> &rhs,
                  const index_t height,
                  const index_t depth,
                  const index_t width,
                  PackedBlock<float> *result);

  // Packs the left-hand operand into *packed_block (column-major pack order).
  void PackLhs(const MatrixMap<float> &lhs,
               PackedBlock<float> *packed_block);

  // Packs the right-hand operand into *packed_block (row-major pack order).
  void PackRhs(const MatrixMap<float> &rhs,
               PackedBlock<float> *packed_block);

  // Transfers a packed result into the matrix view *matrix_map.
  void UnPack(const PackedBlock<float> &packed_result,
              MatrixMap<float> *matrix_map);

 private:
  // Common packing routine used by PackLhs() and PackRhs().
  void Pack(const MatrixMap<float> &src,
            const PackOrder order,
            PackedBlock<float> *packed_block);
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_SGEMM_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册