Commit 9577c8cb authored by Bin Li

Add quantized matmul

Parent 2dbe26d6
......@@ -31,6 +31,10 @@ struct GemmlowpOutputPipeline {
gemmlowp::OutputStageBiasAddition<ColVectorMap>,
gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
gemmlowp::OutputStageSaturatingCastToUint8> Pipeline;
typedef std::tuple<
gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
gemmlowp::OutputStageSaturatingCastToUint8> NoBiasPipeline;
static Pipeline Make(
const int32_t *bias_data, const index_t channels, const float lhs_scale,
const float rhs_scale, const float output_scale,
......@@ -52,6 +56,23 @@ struct GemmlowpOutputPipeline {
return std::make_tuple(bias_addition_stage, quantize_down_stage,
saturating_cast_stage);
}
static NoBiasPipeline MakeNoBias(
const float lhs_scale, const float rhs_scale, const float output_scale,
const int32_t output_zero_point) {
int32_t quantized_multiplier;
int32_t right_shift;
kernels::GetOutputMultiplierAndShift(lhs_scale, rhs_scale, output_scale,
&quantized_multiplier, &right_shift);
gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
quantize_down_stage;
quantize_down_stage.result_offset_after_shift = output_zero_point;
quantize_down_stage.result_fixedpoint_multiplier = quantized_multiplier;
quantize_down_stage.result_shift = right_shift;
gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
return std::make_tuple(quantize_down_stage, saturating_cast_stage);
}
};
} // namespace mace
......
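For reference, a minimal scalar sketch (not part of this commit) of what the NoBiasPipeline computes for each int32 accumulator: scale by lhs_scale * rhs_scale / output_scale, add the output zero point, and saturate to uint8. kernels::GetOutputMultiplierAndShift is assumed to decompose that real multiplier into the fixed-point multiplier and right shift used above; the helper name RequantizeToUint8 is illustrative.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Reference arithmetic only; gemmlowp performs the equivalent computation
// with an integer fixed-point multiply and a rounding right shift.
inline uint8_t RequantizeToUint8(int32_t acc,
                                 float lhs_scale,
                                 float rhs_scale,
                                 float output_scale,
                                 int32_t output_zero_point) {
  const double real_multiplier =
      static_cast<double>(lhs_scale) * rhs_scale / output_scale;
  const int32_t scaled =
      static_cast<int32_t>(std::lround(acc * real_multiplier));
  const int32_t with_offset = scaled + output_zero_point;
  return static_cast<uint8_t>(std::min(255, std::max(0, with_offset)));
}
```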
......@@ -30,6 +30,7 @@
#include "mace/core/tensor.h"
#include "mace/kernels/gemm.h"
#include "mace/utils/utils.h"
#include "mace/kernels/gemmlowp_util.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/cl2_header.h"
......@@ -92,6 +93,101 @@ struct MatMulFunctor {
}
};
template <>
struct MatMulFunctor<CPU, uint8_t> {
template<gemmlowp::MapOrder AOrder, gemmlowp::MapOrder BOrder>
void MatMulImpl(const Tensor *A,
const Tensor *B,
const index_t height,
const index_t K,
const index_t width,
Tensor *C) {
gemmlowp::GemmContext& gemm_context = GetGemmlowpContext();
Tensor::MappingGuard guarda(A);
Tensor::MappingGuard guardb(B);
Tensor::MappingGuard guardc(C);
auto a_ptr_base = A->data<uint8_t>();
auto b_ptr_base = B->data<uint8_t>();
auto c_ptr_base = C->mutable_data<uint8_t>();
index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
std::multiplies<index_t>());
index_t a_size = height * K;
index_t b_size = K * width;
index_t c_size = height * width;
const auto &output_pipeline = GemmlowpOutputPipeline::MakeNoBias(
A->scale(), B->scale(), C->scale(), C->zero_point());
for (index_t i = 0; i < batch; ++i) {
gemmlowp::MatrixMap<const uint8_t, AOrder>
a_matrix(a_ptr_base + i * a_size, height, K);
gemmlowp::MatrixMap<const uint8_t, BOrder>
b_matrix(b_ptr_base + i * b_size, K, width);
gemmlowp::MatrixMap<uint8_t, gemmlowp::MapOrder::RowMajor>
c_matrix(c_ptr_base + i * c_size, height, width);
using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
gemmlowp::GemmWithOutputPipeline<uint8_t, uint8_t, BitDepthParams>(
&gemm_context, a_matrix, b_matrix, &c_matrix, -A->zero_point(),
-B->zero_point(), output_pipeline);
}
}
MaceStatus operator()(const Tensor *A,
const Tensor *B,
Tensor *C,
bool transpose_a,
bool transpose_b,
StatsFuture *future) {
MACE_UNUSED(future);
index_t rank = A->dim_size();
index_t height = A->dim(rank - 2);
index_t K = A->dim(rank - 1);
index_t width;
if (transpose_a) {
std::swap(height, K);
}
if (transpose_b) {
width = B->dim(rank - 2);
} else {
width = B->dim(rank - 1);
}
std::vector<index_t> c_shape = A->shape();
c_shape[rank - 2] = height;
c_shape[rank - 1] = width;
MACE_RETURN_IF_ERROR(C->Resize(c_shape));
constexpr gemmlowp::MapOrder kRowMajor = gemmlowp::MapOrder::RowMajor;
constexpr gemmlowp::MapOrder kColMajor = gemmlowp::MapOrder::ColMajor;
#define MATMUL_IMPL(AOrder, BOrder) \
MatMulImpl<AOrder, BOrder>(A, B, height, K, width, C);
if (transpose_a) {
if (transpose_b) {
MATMUL_IMPL(kColMajor, kColMajor);
} else {
MATMUL_IMPL(kColMajor, kRowMajor);
}
} else {
if (transpose_b) {
MATMUL_IMPL(kRowMajor, kColMajor);
} else {
MATMUL_IMPL(kRowMajor, kRowMajor);
}
}
#undef MATMUL_IMPL
return MACE_SUCCESS;
}
};
#ifdef MACE_ENABLE_OPENCL
template <typename T>
struct MatMulFunctor<DeviceType::GPU, T> {
......
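As a cross-check on the zero-point handling above, here is a naive int32 reference (not part of this commit; the function name is illustrative) of the per-batch accumulation gemmlowp performs before the output pipeline runs. Passing -A->zero_point() and -B->zero_point() to GemmWithOutputPipeline corresponds to the (value - zero_point) terms below, since gemmlowp adds the given offsets to each element.

```cpp
#include <cstdint>
#include <vector>

// Illustrative reference: row-major A (height x K) times row-major B
// (K x width), subtracting the tensor zero points, accumulating in int32.
// The MakeNoBias output pipeline would then requantize each accumulator.
void ReferenceQuantGemm(const uint8_t *a, const uint8_t *b,
                        int height, int K, int width,
                        int32_t a_zero_point, int32_t b_zero_point,
                        std::vector<int32_t> *c) {
  c->assign(static_cast<size_t>(height) * width, 0);
  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; ++j) {
      int32_t acc = 0;
      for (int k = 0; k < K; ++k) {
        acc += (static_cast<int32_t>(a[i * K + k]) - a_zero_point) *
               (static_cast<int32_t>(b[k * width + j]) - b_zero_point);
      }
      (*c)[static_cast<size_t>(i) * width + j] = acc;
    }
  }
}
```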
......@@ -109,9 +109,9 @@ void MatmulBenchmark_Mace(int iters, int m, int k, int n) {
void MatmulBenchmark_Eigen(int iters, int m, int k, int n) {
mace::testing::StopTiming();
Eigen::MatrixXd lhs = Eigen::MatrixXd::Random(m, k);
Eigen::MatrixXd rhs = Eigen::MatrixXd::Random(k, n);
Eigen::MatrixXd result = Eigen::MatrixXd::Zero(m, n);
Eigen::MatrixXf lhs = Eigen::MatrixXf::Random(m, k);
Eigen::MatrixXf rhs = Eigen::MatrixXf::Random(k, n);
Eigen::MatrixXf result = Eigen::MatrixXf::Zero(m, n);
// warm up
result = lhs * rhs;
mace::testing::StartTiming();
......
......@@ -24,6 +24,12 @@ void Register_MatMul(OperatorRegistryBase *op_registry) {
.Build(),
MatMulOp<DeviceType::CPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("MatMul")
.Device(DeviceType::CPU)
.TypeConstraint<uint8_t>("T")
.Build(),
MatMulOp<DeviceType::CPU, uint8_t>);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("MatMul")
.Device(DeviceType::GPU)
......
......@@ -31,8 +31,12 @@ void MatMulBenchmark(
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("A", {batch, height, channels});
net.AddRandomInput<D, float>("B", {batch, channels, out_width});
net.AddRandomInput<D, T>("A", {batch, height, channels});
net.AddRandomInput<D, T>("B", {batch, channels, out_width});
if (DataTypeToEnum<T>::value == DT_UINT8) {
net.GetTensor("A")->SetScale(0.1);
net.GetTensor("B")->SetScale(0.1);
}
if (D == DeviceType::GPU) {
BufferToImage<D, T>(&net, "A", "AImage", kernels::BufferType::IN_OUT_WIDTH);
......@@ -50,12 +54,18 @@ void MatMulBenchmark(
.Input("A")
.Input("B")
.Output("Output")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
}
net.Setup(D);
if (DataTypeToEnum<T>::value == DT_UINT8) {
net.GetTensor("Output")->SetScale(0.1);
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
for (int i = 0; i < 2; ++i) {
net.Run();
}
net.Sync();
......@@ -74,8 +84,12 @@ void MatMulTransposeBenchmark(
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("A", {batch, height, channels});
net.AddRandomInput<D, float>("B", {batch, out_width, channels});
net.AddRandomInput<D, T>("A", {batch, height, channels});
net.AddRandomInput<D, T>("B", {batch, out_width, channels});
if (DataTypeToEnum<T>::value == DT_UINT8) {
net.GetTensor("A")->SetScale(0.1);
net.GetTensor("B")->SetScale(0.1);
}
if (D == DeviceType::CPU) {
OpDefBuilder("MatMul", "MatMulBM")
......@@ -83,14 +97,20 @@ void MatMulTransposeBenchmark(
.Input("B")
.AddIntArg("transpose_b", 1)
.Output("Output")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else {
MACE_NOT_IMPLEMENTED;
}
net.Setup(D);
if (DataTypeToEnum<T>::value == DT_UINT8) {
net.GetTensor("Output")->SetScale(0.1);
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
for (int i = 0; i < 2; ++i) {
net.Run();
}
net.Sync();
......@@ -116,7 +136,8 @@ void MatMulTransposeBenchmark(
#define MACE_BM_MATMUL(N, H, C, W) \
MACE_BM_MATMUL_MACRO(N, H, C, W, float, CPU); \
MACE_BM_MATMUL_MACRO(N, H, C, W, float, GPU); \
MACE_BM_MATMUL_MACRO(N, H, C, W, half, GPU);
MACE_BM_MATMUL_MACRO(N, H, C, W, half, GPU); \
MACE_BM_MATMUL_MACRO(N, H, C, W, uint8_t, CPU);
#define MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, TYPE, DEVICE) \
static void MACE_BM_MATMUL_##T_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE( \
......@@ -130,7 +151,8 @@ void MatMulTransposeBenchmark(
MACE_BENCHMARK(MACE_BM_MATMUL_##T_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE)
#define MACE_BM_MATMUL_TRANPOSE(N, H, C, W) \
MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, float, CPU);
MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, float, CPU); \
MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, uint8_t, CPU);
MACE_BM_MATMUL(16, 32, 128, 49);
MACE_BM_MATMUL(16, 32, 128, 961);
......
......@@ -213,6 +213,116 @@ TEST_F(MatMulOpTest, OPENCLHalfUnAlignedWithBatch) {
Complex<half>({2, 3}, 31, 61, 67);
}
namespace {
void Quant(const std::vector<index_t> &batch,
const index_t height,
const index_t channels,
const index_t out_width,
const bool transpose_a,
const bool transpose_b) {
// Construct graph
OpsTestNet net;
// Add input data
index_t batch_count = std::accumulate(batch.begin(), batch.end(), 1,
std::multiplies<index_t>());
if (transpose_a) {
net.AddRandomInput<CPU, float>("A", {batch_count, channels, height});
} else {
net.AddRandomInput<CPU, float>("A", {batch_count, height, channels});
}
if (transpose_b) {
net.AddRandomInput<CPU, float>("B", {batch_count, out_width, channels});
} else {
net.AddRandomInput<CPU, float>("B", {batch_count, channels, out_width});
}
OpDefBuilder("MatMul", "MatMulTest")
.Input("A")
.AddIntArg("transpose_a", transpose_a ? 1 : 0)
.Input("B")
.AddIntArg("transpose_b", transpose_b ? 1 : 0)
.Output("Output")
.AddIntArg("T", DT_FLOAT)
.Finalize(net.NewOperatorDef());
net.RunOp(CPU);
OpDefBuilder("Quantize", "QuantizeA")
.Input("A")
.Output("QuantizedA")
.OutputType({DT_UINT8})
.AddIntArg("T", DT_UINT8)
.AddIntArg("non_zero", true)
.Finalize(net.NewOperatorDef());
net.RunOp();
OpDefBuilder("Quantize", "QuantizeB")
.Input("B")
.Output("QuantizedB")
.OutputType({DT_UINT8})
.AddIntArg("T", DT_UINT8)
.AddIntArg("non_zero", true)
.Finalize(net.NewOperatorDef());
net.RunOp();
OpDefBuilder("Quantize", "QuantizeOutput")
.Input("Output")
.Output("ExpectedQuantizedOutput")
.OutputType({DT_UINT8})
.AddIntArg("T", DT_UINT8)
.AddIntArg("non_zero", true)
.Finalize(net.NewOperatorDef());
net.RunOp();
OpDefBuilder("MatMul", "QuantizeMatMulTest")
.Input("QuantizedA")
.AddIntArg("transpose_a", transpose_a ? 1 : 0)
.Input("QuantizedB")
.AddIntArg("transpose_b", transpose_b ? 1 : 0)
.Output("QuantizedOutput")
.AddIntArg("T", DT_UINT8)
.Finalize(net.NewOperatorDef());
net.Setup(DeviceType::CPU);
Tensor *eq_output = net.GetTensor("ExpectedQuantizedOutput");
Tensor *q_output = net.GetTensor("QuantizedOutput");
q_output->SetScale(eq_output->scale());
q_output->SetZeroPoint(eq_output->zero_point());
net.Run();
OpDefBuilder("Dequantize", "DeQuantizeTest")
.Input("QuantizedOutput")
.Output("DequantizedOutput")
.OutputType({DT_FLOAT})
.AddIntArg("T", DT_UINT8)
.Finalize(net.NewOperatorDef());
net.RunOp();
// Check
ExpectTensorSimilar<float>(*net.GetOutput("Output"),
*net.GetTensor("DequantizedOutput"), 0.01);
}
} // namespace
TEST_F(MatMulOpTest, Quant) {
Quant({1}, 64, 128, 32, false, false);
Quant({1}, 64, 32, 128, false, false);
Quant({2, 3}, 64, 32, 128, false, false);
Quant({1}, 64, 128, 32, false, true);
Quant({1}, 64, 32, 128, false, true);
Quant({2, 3}, 64, 32, 128, false, true);
Quant({1}, 64, 128, 32, true, false);
Quant({1}, 64, 32, 128, true, false);
Quant({2, 3}, 64, 32, 128, true, false);
Quant({1}, 64, 128, 32, true, true);
Quant({1}, 64, 32, 128, true, true);
Quant({2, 3}, 64, 32, 128, true, true);
// UnAligned
Quant({2}, 3, 3, 3, false, false);
Quant({16}, 31, 61, 67, false, true);
Quant({31}, 31, 61, 67, true, false);
Quant({2, 3}, 31, 61, 67, true, true);
}
// TODO(liyin): test transpose after implementing gpu runtime
// now transpose test is in kernels_test
......
......@@ -136,14 +136,13 @@ void SetGPUHints(GPUPerfHint perf_hint, GPUPriorityHint priority_hint);
/// also be truncated to the corresponding cores number when num_threads_hint
/// is larger than it.
/// The OpenMP threads will be bind to (via sched_setaffinity) big cores
/// (AFFINITY_BIG_ONLY) and little cores (AFFINITY_LITTLE_ONLY).
/// (AFFINITY_BIG_ONLY) or little cores (AFFINITY_LITTLE_ONLY).
///
/// \param num_threads_hint it is only a hint.
/// \param policy one of CPUAffinityPolicy
/// \param status MACE_SUCCESS for successful, or it can't reliabley
/// detect big-LITTLE cores (see GetBigLittleCoreIDs). In such cases, it's
/// suggested to use AFFINITY_NONE to use all cores.
/// \return
/// \return MACE_SUCCESS on success; otherwise the function could not reliably
/// detect big-LITTLE cores (see GetBigLittleCoreIDs), in which case it is
/// suggested to use AFFINITY_NONE to use all cores.
__attribute__((visibility("default")))
MaceStatus SetOpenMPThreadPolicy(int num_threads_hint,
CPUAffinityPolicy policy);
......@@ -159,7 +158,6 @@ MaceStatus SetOpenMPThreadPolicy(int num_threads_hint,
///
/// \param num_threads
/// \param cpu_ids
/// \param status
/// \return
__attribute__((visibility("default")))
MaceStatus SetOpenMPThreadAffinity(int num_threads,
......@@ -180,13 +178,25 @@ __attribute__((visibility("default")))
MaceStatus GetBigLittleCoreIDs(std::vector<int> *big_core_ids,
std::vector<int> *little_core_ids);
// Set gemmlowp threads number and processor affinity.
// gemmlowp is used by mace for quantization.
// Caution: this function may hurt performance if improper parameters provided.
//
// This function may not work well on some chips (e.g. MTK). Setting thread
// affinity to offline cores may run very slow or unexpectedly. In such cases,
// please use SetGemmlowpThreadPolicy with default policy instead.
/// \brief Set gemmlowp threads number and affinity policy for quantization.
///
/// Caution: this function may hurt performance if improper parameters are
/// provided. gemmlowp shares threads with OpenMP, which are configured by
/// SetOpenMPThreadPolicy, so the affinity policy set by these two functions
/// should be the same.
/// When num_threads_hint is zero or negative, the function will set the
/// number of threads to the number of big (AFFINITY_BIG_ONLY), little
/// (AFFINITY_LITTLE_ONLY) or all (AFFINITY_NONE) cores according to the
/// policy. The thread count will also be truncated to the corresponding
/// core count when num_threads_hint is larger than it.
/// The gemmlowp threads will be bound to (via sched_setaffinity) big cores
/// (AFFINITY_BIG_ONLY) or little cores (AFFINITY_LITTLE_ONLY).
///
/// \param num_threads_hint it is only a hint.
/// \param policy one of CPUAffinityPolicy
/// \return MACE_SUCCESS on success; otherwise the function could not reliably
/// detect big-LITTLE cores (see GetBigLittleCoreIDs), in which case it is
/// suggested to use AFFINITY_NONE to use all cores.
__attribute__((visibility("default")))
MaceStatus SetGemmlowpThreadPolicy(int num_threads_hint,
CPUAffinityPolicy policy);
......
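A hedged usage sketch of the two thread-policy setters documented in this file (not part of this commit): configure OpenMP and gemmlowp with the same affinity policy, falling back to AFFINITY_NONE when big-LITTLE detection fails. The header path and the unscoped MACE_SUCCESS / CPUAffinityPolicy spellings are assumptions based on the declarations quoted above.

```cpp
#include "mace/public/mace_runtime.h"  // assumed location of these declarations

void ConfigureQuantizedThreadPolicy() {
  const int num_threads_hint = 4;
  // gemmlowp shares threads with OpenMP, so both calls must use the same
  // affinity policy.
  if (mace::SetOpenMPThreadPolicy(num_threads_hint, mace::AFFINITY_BIG_ONLY)
          != mace::MACE_SUCCESS) {
    // big-LITTLE cores could not be reliably detected; fall back to all cores.
    mace::SetOpenMPThreadPolicy(num_threads_hint, mace::AFFINITY_NONE);
    mace::SetGemmlowpThreadPolicy(num_threads_hint, mace::AFFINITY_NONE);
    return;
  }
  mace::SetGemmlowpThreadPolicy(num_threads_hint, mace::AFFINITY_BIG_ONLY);
}
```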