Commit cdbe44f8 authored by Megvii Engine Team

feat(dnn): add gemv support in conv1x1 with format NCHW

GitOrigin-RevId: 97679e85265dd8aa18ed022f3e02f240a796f79e
Parent 6972fc7d
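Background (not part of the commit): when a 1x1 convolution has OH * OW == 1, each output pixel is a single dot product over the input channels, so the whole convolution collapses to a matrix-vector product (GEMV) with M = OC, K = IC and N = 1. The standalone sketch below illustrates that mapping with a naive kernel mirroring the fallback gemv_like added in this diff; the helper name naive_gemv and the main() driver are illustrative only, not part of MegEngine.

// Hypothetical standalone illustration: conv1x1 with OH * OW == 1 as a GEMV.
// A is the OC x IC filter matrix, B is the IC x 1 input column, C is OC x 1.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

template <typename itype, typename otype>
void naive_gemv(const itype* A, const itype* B, otype* C, size_t M, size_t N,
                size_t K, size_t LDA, size_t LDB, size_t LDC) {
    for (size_t m = 0; m < M; ++m) {
        memset(C + m * LDC, 0, sizeof(otype) * N);
        for (size_t k = 0; k < K; ++k)
            for (size_t n = 0; n < N; ++n)
                C[m * LDC + n] += static_cast<otype>(A[m * LDA + k]) *
                                  static_cast<otype>(B[k * LDB + n]);
    }
}

int main() {
    const size_t OC = 4, IC = 3;             // filter shape: OC x IC x 1 x 1
    std::vector<int8_t> filter(OC * IC, 1);  // A: OC x IC
    std::vector<int8_t> src(IC, 2);          // B: IC x 1 (the single output pixel)
    std::vector<int32_t> dst(OC, 0);         // C: OC x 1
    naive_gemv<int8_t, int32_t>(filter.data(), src.data(), dst.data(),
                                /*M=*/OC, /*N=*/1, /*K=*/IC,
                                /*LDA=*/IC, /*LDB=*/1, /*LDC=*/1);
    for (size_t oc = 0; oc < OC; ++oc)
        printf("dst[%zu] = %d\n", oc, static_cast<int>(dst[oc]));  // IC * 1 * 2 = 6
    return 0;
}

This dimension mapping (M = oc tile size, N = 1, K = IC, LDA = IC, LDB = LDC = 1) is exactly how Conv1x1GemvWorker calls GemvLike::do_gemv further down in this diff.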
......@@ -193,7 +193,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
DEFAULT \
}
#define FOR_BIAS(_bias_mode) \
#define FOR_BIAS(_bias_mode, OH, OW) \
switch (_bias_mode) { \
case megdnn::BiasMode::NO_BIAS: \
FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY); \
......@@ -208,6 +208,10 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
} \
break; \
default: \
if (OH * OW == 1) { \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
break; \
} \
megdnn_throw("quantized unsupported biasmode"); \
break; \
}
......@@ -218,7 +222,9 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
megdnn::DType bias_type, megdnn::DType dst_type, size_t N,
size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) {
FOR_BIAS(bias_mode);
//! when OH * OW == 1, a per-element bias is indistinguishable from a
//! channel-broadcast bias, so bias_mode may come in as BiasMode::BIAS;
//! handle that case in the default branch with the broadcast kernel.
FOR_BIAS(bias_mode, OH, OW);
}
};
......
......@@ -43,9 +43,9 @@ void exec_int_8x8x16(const MatrixMulImpl::KernParam& kern_param) {
size_t N = kern_param.N;
size_t K = kern_param.K;
size_t LDB = kern_param.LDB;
exec_gemm_int8_int8_int16(kern_param.A<dt_int8>(),
kern_param.B<dt_int8>(),
kern_param.C<dt_int16>(), M, K, N, LDB, w0, w1);
exec_gemm_int8_int8_int16(
kern_param.A<dt_int8>(), kern_param.B<dt_int8>(),
kern_param.C<dt_int16>(), M, K, N, LDB, w0, w1);
}
MIDOUT_END();
}
......@@ -79,8 +79,7 @@ void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
arm_common::matmul::gemv_like_int8(Aptr, Bptr, Cptr, M, N, K, LDA, LDB,
LDC);
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
} // anonymous namespace
......@@ -110,7 +109,7 @@ void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
const auto Aptr = kern_param.A<dt_float32>(),
Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
arm_common::sgemm_sgemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
} // anonymous namespace
......@@ -140,25 +139,14 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(
/* ===================== F32 Gevm algo ===================== */
namespace {
void gevm_fp32_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDB = kern_param.LDB;
const auto Aptr = kern_param.A<dt_float32>(),
Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
arm_common::sgemm_sgemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1);
}
void gevm_int8_kern(const MatrixMulImpl::KernParam& kern_param) {
template <typename stype, typename dtype>
void gevm_like_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDB = kern_param.LDB;
const auto Aptr = kern_param.A<dt_int8>(),
Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
arm_common::matmul::gemv_like_int8(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1);
const auto Aptr = kern_param.A<stype>(), Bptr = kern_param.B<stype>();
auto Cptr = kern_param.C<dtype>();
megdnn::arm_common::gemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1);
}
} // anonymous namespace
bool MatrixMulImpl::AlgoGevm::usable(
......@@ -170,8 +158,16 @@ bool MatrixMulImpl::AlgoGevm::usable(
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32();
return (fp32_ok || can_be_treated_as_int8x8x32(kern_size_param)) &&
preferred(kern_size_param);
bool fp16_ok = false;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
fp16_ok = kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float16();
#endif
bool int8_ok = can_be_treated_as_int8x8x32(kern_size_param);
return (fp32_ok || fp16_ok || int8_ok) && preferred(kern_size_param);
}
bool MatrixMulImpl::AlgoGevm::preferred(
......@@ -183,11 +179,17 @@ bool MatrixMulImpl::AlgoGevm::preferred(
MatrixMulImpl::kern_t MatrixMulImpl::AlgoGevm::get_kern(
const KernSizeParam& kern_size_param) const {
if (kern_size_param.A_type == dtype::Float32()) {
return gevm_fp32_kern;
return gevm_like_kern<dt_float32, dt_float32>;
} else if (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) {
return gevm_int8_kern;
} else {
return gevm_like_kern<dt_int8, dt_int32>;
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
else if (kern_size_param.A_type == dtype::Float16()) {
return gevm_like_kern<__fp16, __fp16>;
}
#endif
else {
megdnn_assert(
false, "no avaliable kern got A_type: %s B_type: %s C_type: %s",
kern_size_param.A_type.name(), kern_size_param.B_type.name(),
......@@ -205,10 +207,10 @@ void f16_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
Bptr = kern_param.B<dt_float16>();
auto Cptr = kern_param.C<dt_float16>();
MIDOUT_BEGIN(megdnn_arm_hgemv, void) {
arm_common::hgemv_exec(reinterpret_cast<const __fp16*>(Aptr),
reinterpret_cast<const __fp16*>(Bptr),
reinterpret_cast<__fp16*>(Cptr), M, N, K, LDA,
LDB, LDC);
arm_common::gemv_like(reinterpret_cast<const __fp16*>(Aptr),
reinterpret_cast<const __fp16*>(Bptr),
reinterpret_cast<__fp16*>(Cptr), M, N, K, LDA,
LDB, LDC);
}
MIDOUT_END();
}
......
......@@ -96,11 +96,11 @@ void hgemv_naive_n(const __fp16* __restrict A, const __fp16* __restrict B,
}
} // namespace
void megdnn::arm_common::hgemv_exec(const __fp16* __restrict A,
const __fp16* __restrict B,
__fp16* __restrict C, size_t M, size_t N,
size_t K, size_t Astride, size_t Bstride,
size_t Cstride) {
void megdnn::arm_common::gemv_like(const __fp16* __restrict A,
const __fp16* __restrict B,
__fp16* __restrict C, size_t M, size_t N,
size_t K, size_t Astride, size_t Bstride,
size_t Cstride) {
megdnn_assert((M <= 4) || (M == 8 && K <= 2) || (N == 1 && Bstride == 1));
if (N == 1) {
return hgemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
......
......@@ -16,13 +16,14 @@
namespace megdnn {
namespace arm_common {
void hgemv_exec(const __fp16* __restrict A, const __fp16* __restrict B,
__fp16* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
bool is_hgemv_preferred(bool transposeA, bool transposeB, size_t M, size_t N,
size_t K, size_t /*LDA*/, size_t LDB, size_t /*LDC*/);
void gemv_like(const __fp16* __restrict A, const __fp16* __restrict B,
__fp16* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
} // namespace arm_common
} // namespace megdnn
......
......@@ -42,40 +42,6 @@ void sgemv_naive_n(const float* __restrict A, const float* __restrict B,
#define calculate(i) sum##i = vmlaq_f32(sum##i, a##i, b0);
#define vstore(i) C[(m + i) * Cstride] = vaddvq_f32(sum##i) + acc##i;
size_t m = 0;
for (; m + 4 <= M; m += 4) {
float acc0, acc1, acc2, acc3;
float32x4_t a0, a1, a2, a3, b0;
float32x4_t sum0, sum1, sum2, sum3;
UNROLL_OUT(vdupq_sum, 4)
size_t k = 0;
for (; k + 4 <= K; k += 4) {
UNROLL_OUT(loadA, 4)
UNROLL_OUT(loadB, 1)
UNROLL_OUT(calculate, 4)
}
UNROLL_OUT(reset_acc, 4)
for (; k < K; ++k) {
UNROLL_OUT(acc_calu, 4)
}
UNROLL_OUT(vstore, 4)
}
for (; m + 2 <= M; m += 2) {
float acc0, acc1;
float32x4_t a0, a1, b0;
float32x4_t sum0, sum1;
UNROLL_OUT(vdupq_sum, 2)
size_t k = 0;
for (; k + 4 <= K; k += 4) {
UNROLL_OUT(loadA, 2)
UNROLL_OUT(loadB, 1)
UNROLL_OUT(calculate, 2)
}
UNROLL_OUT(reset_acc, 2)
for (; k < K; ++k) {
UNROLL_OUT(acc_calu, 2)
}
UNROLL_OUT(vstore, 2)
}
for (; m < M; m += 1) {
float acc0;
float32x4_t a0, b0;
......@@ -107,9 +73,9 @@ void sgemv_naive_n(const float* __restrict A, const float* __restrict B,
namespace megdnn {
namespace arm_common {
void sgemm_sgemv_like(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride) {
void gemv_like(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride) {
megdnn_assert(M < 8 || (M == 8 && K <= 2) || (N == 1 && Bstride == 1));
if (N == 1) {
return sgemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
......
......@@ -20,9 +20,10 @@ bool is_sgemv_like_preferred(bool row_major, bool transposeA, bool transposeB,
size_t /* LDA */, size_t LDB, float beta,
size_t /* LDC */);
void sgemm_sgemv_like(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
void gemv_like(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
} // namespace arm_common
} // namespace megdnn
......
......@@ -172,9 +172,10 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
} // namespace
#endif
bool matmul::is_gemv_like_preferred_int8(bool transposeA, bool transposeB,
size_t M, size_t N, size_t K,
size_t LDA, size_t LDB, size_t LDC) {
bool arm_common::is_gemv_like_preferred_int8(bool transposeA, bool transposeB,
size_t M, size_t N, size_t K,
size_t LDA, size_t LDB,
size_t LDC) {
MEGDNN_MARK_USED_VAR(LDA);
MEGDNN_MARK_USED_VAR(LDB);
MEGDNN_MARK_USED_VAR(LDC);
......@@ -188,15 +189,16 @@ bool matmul::is_gemv_like_preferred_int8(bool transposeA, bool transposeB,
return N == 1 && LDB == 1;
}
void matmul::gemv_like_int8(const int8_t* __restrict A,
const int8_t* __restrict B, int32_t* __restrict C,
size_t M, size_t N, size_t K, size_t Astride,
size_t Bstride, size_t Cstride) {
void arm_common::gemv_like(const int8_t* __restrict A,
const int8_t* __restrict B, int32_t* __restrict C,
size_t M, size_t N, size_t K, size_t Astride,
size_t Bstride, size_t Cstride) {
megdnn_assert(N == 1);
MIDOUT_BEGIN(megdnn_arm_common_int8_gemv) {
MIDOUT_BEGIN(megdnn_arm_common_int8_gemv,
midout_iv("INT8_gemv_like"_hash)) {
return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
} MIDOUT_END();
}
MIDOUT_END();
}
// vim: syntax=cpp.doxygen
......@@ -15,16 +15,15 @@
namespace megdnn {
namespace arm_common {
namespace matmul {
bool is_gemv_like_preferred_int8(bool transposeA, bool transposeB, size_t M,
size_t N, size_t K, size_t LDA, size_t LDB,
size_t LDC);
void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
} // namespace matmul
void gemv_like(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
} // namespace arm_common
} // namespace megdnn
......
......@@ -54,7 +54,7 @@ size_t ConvBiasImpl::AlgoConv1x1::get_workspace(
size_t compt_oc_block_size = get_oc_tile_size_heuristic(param);
auto matmul_param =
get_matmul_kern_param(param, OH * OW, compt_oc_block_size);
utils::get_matmul_kern_param(param, OH * OW, compt_oc_block_size);
auto pack_mode = m_matmul_algo->packmode();
if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::DEFAULT) {
......@@ -92,7 +92,6 @@ size_t ConvBiasImpl::AlgoConv1x1::get_workspace(
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoConv1x1::dispatch_kerns(
ConvBiasImpl* opr, const NCBKernSizeParam& param) const {
SmallVector<ConvBiasImpl::NCBKern> ret_kern;
size_t OH = param.osz[0];
size_t OW = param.osz[1];
size_t OC = param.filter_meta.ocpg;
......@@ -102,7 +101,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoConv1x1::dispatch_kerns(
size_t oc_blocks_per_group = div_ceil(OC, compt_oc_block_size);
auto matmul_param =
get_matmul_kern_param(param, OH * OW, compt_oc_block_size);
utils::get_matmul_kern_param(param, OH * OW, compt_oc_block_size);
WorkspaceBundle whole_bundle = {nullptr, {}};
WorkspaceBundle thread_bundle = {nullptr, {}};
WorkspaceBundle matmul_bundle = {nullptr, {}};
......@@ -138,7 +137,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoConv1x1::dispatch_kerns(
}
//! get thread bundle
thread_bundle = get_thread_bundle(param, matmul_bundle.get_size(2),
thread_bundle = utils::get_thread_bundle(param, matmul_bundle.get_size(2),
compt_oc_block_size);
Conv1x1StrategyBase* conv1x1_strategy =
......@@ -178,7 +177,6 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoConv1x1::dispatch_kerns(
}
}
ret_kern.push_back({kern_compt, {BATCH, GROUP, oc_blocks_per_group}});
return ret_kern;
}
......@@ -201,8 +199,11 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1)
return false;
if (param.src_type.enumv() != param.filter_type.enumv() &&
param.src_type.enumv() != DTypeEnum::Int8 &&
if (param.src_type.enumv() != param.filter_type.enumv()) {
return false;
}
if (param.src_type.enumv() != DTypeEnum::Int8 &&
param.src_type.enumv() != DTypeEnum::QuantizedS8 &&
param.src_type.enumv() != DTypeEnum::Quantized8Asymm &&
#if !MEGDNN_DISABLE_FLOAT16
......@@ -211,6 +212,7 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
param.src_type.enumv() != DTypeEnum::Float32) {
return false;
}
//! make sure the 8x8x16 and 8x8x32 bias mode is NO_BIAS and the
//! nonlineMode is IDENTITY, otherwise return false, meaning that
//! 8x8x32 and 8x8x16 do not support PostProcess
......@@ -233,7 +235,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
size_t OH = param.osz[0];
size_t OW = param.osz[1];
MatrixMulImpl::KernSizeParam matmul_param = get_matmul_kern_param(
MatrixMulImpl::KernSizeParam matmul_param = utils::get_matmul_kern_param(
param, OH * OW, get_oc_tile_size_heuristic(param));
bool matmul_usable = m_matmul_algo->usable(matmul_param);
......@@ -250,3 +253,27 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
MIDOUT_END();
return false;
}
bool ConvBiasImpl::AlgoConv1x1::is_preferred(
ConvBiasImpl*, const NCBKernSizeParam& param) const {
size_t OH = param.osz[0];
size_t OW = param.osz[1];
if (OH * OW != 1) {
return true;
} else {
#if (MEGDNN_ARMV7 || MEGDNN_AARCH64)
if (param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int16) {
return true;
}
#elif MEGDNN_X86
size_t OC = param.filter_meta.ocpg;
if (OC > 2 || param.src_type.enumv() == DTypeEnum::Float32)
return true;
#endif
return false;
}
}
// vim: syntax=cpp.doxygen
\ No newline at end of file
......@@ -41,9 +41,7 @@ public:
SmallVector<NCBKern> dispatch_kerns(
ConvBiasImpl* opr, const NCBKernSizeParam& param) const override;
bool is_preferred(ConvBiasImpl*, const NCBKernSizeParam&) const override{
return true;
}
bool is_preferred(ConvBiasImpl*, const NCBKernSizeParam&) const override;
protected:
size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const;
......
/**
* \file dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h"
#include "src/fallback/conv_bias/conv1x1/conv1x1_utils.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "megdnn/opr_param_defs.h"
#include "src/naive/convolution/helper.h"
#include "src/fallback/matrix_mul/gemv.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/arm_common/matrix_mul/fp32/exec_sgemv.h"
#include "src/arm_common/matrix_mul/fp16/hgemv.h"
#include "src/arm_common/matrix_mul/int8/gemv.h"
#endif
#include "midout.h"
MIDOUT_DECL(megdnn_fallback_conv1x1_gemv)
using namespace megdnn;
using namespace fallback;
#if MEGDNN_X86
using namespace x86;
#endif
using namespace conv1x1;
namespace {
#if MEGDNN_X86
template <typename stype, typename btype, param::ConvBias::Format F>
struct GemvLike {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
megdnn_throw("x86 conv1x1 gemv only supports format : NCHW");
}
};
template <typename stype, typename btype>
struct GemvLike<stype, btype, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::fallback::gemv_like<stype, btype>(A, B, C, M, N, K, LDA, LDB,
LDC);
}
};
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
template <typename stype, typename btype, param::ConvBias::Format F>
struct GemvLike {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
megdnn_throw("arm conv1x1 gemv only supports format : NCHW");
}
};
template <typename stype, typename btype>
struct GemvLike<stype, btype, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC);
}
};
template <>
struct GemvLike<dt_int8, dt_int16, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_int8* A, const dt_int8* B, dt_int16* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::fallback::gemv_like<dt_int8, dt_int16>(A, B, C, M, N, K, LDA,
LDB, LDC);
}
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
struct GemvLike<dt_float16, dt_float16, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_float16* A, const dt_float16* B,
dt_float16* C, size_t M, size_t N, size_t K,
size_t LDA, size_t LDB, size_t LDC, DType src,
DType filter) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::arm_common::gemv_like(reinterpret_cast<const __fp16*>(A),
reinterpret_cast<const __fp16*>(B),
reinterpret_cast<__fp16*>(C), M, N, K,
LDA, LDB, LDC);
}
};
#endif
#endif
template <>
struct GemvLike<dt_uint8, dt_int32, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_uint8* A, const dt_uint8* B,
dt_int32* C, size_t M, size_t N, size_t K,
size_t LDA, size_t LDB, size_t LDC, DType src,
DType filter) {
uint8_t zp0 = src.param<dtype::Quantized8Asymm>().zero_point;
uint8_t zp1 = filter.param<dtype::Quantized8Asymm>().zero_point;
megdnn::fallback::gemv_like<dt_uint8, dt_int32>(A, B, C, M, N, K, LDA,
LDB, LDC, zp0, zp1);
}
};
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode,
param::ConvBias::Format format>
struct Conv1x1GemvWorker {
static void exec(WorkspaceBundle& whole_bundle,
WorkspaceBundle& thread_bundle, size_t oc_tile_size,
const ConvBiasImpl::NCBKernSizeParam& param,
const ConvBiasImpl::NCBKernParam& ncb_param,
const ConvBiasImpl::NCBKernIndex& ncb_index) {
whole_bundle.set(ncb_param.workspace_ptr);
size_t OC = param.filter_meta.ocpg;
size_t IC = param.filter_meta.icpg;
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
size_t oc_tile_id_in_group = ncb_index.ndrange_id[2];
size_t thread_id = ncb_index.thread_id;
size_t oc_start = oc_tile_size * oc_tile_id_in_group;
size_t oc_end = oc_start + oc_tile_size;
oc_end = (oc_end <= OC ? oc_end : OC);
size_t numbers_of_ncb_filter_offset =
oc_tile_size * IC * oc_tile_id_in_group;
const src_ctype* Aptr = ncb_param.filter<src_ctype>(group_id) +
numbers_of_ncb_filter_offset;
const src_ctype* Bptr = ncb_param.src<src_ctype>(batch_id, group_id);
size_t thread_offset = thread_bundle.total_size_in_bytes() * thread_id;
size_t bytes_offset_of_matmul_dst_this_thread =
thread_offset + thread_bundle.get_size(0);
bias_ctype* matmul_temp_dst = reinterpret_cast<bias_ctype*>(
reinterpret_cast<int8_t*>(whole_bundle.get(0)) +
bytes_offset_of_matmul_dst_this_thread);
size_t numbers_of_ncb_dst_offset = oc_tile_size * oc_tile_id_in_group;
dst_ctype* conv_bias_dst =
ncb_param.dst<dst_ctype>(batch_id, group_id) +
numbers_of_ncb_dst_offset;
bool is_dst_8bit =
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
bias_ctype* gemv_dst =
is_dst_8bit ? matmul_temp_dst
: reinterpret_cast<bias_ctype*>(conv_bias_dst);
GemvLike<src_ctype, bias_ctype, format>::do_gemv(
Aptr, Bptr, gemv_dst, oc_end - oc_start, 1, IC, IC, 1, 1,
ncb_param.filter_type, ncb_param.src_type);
//! do postprocess
void* bias_ptr = nullptr;
if (param.bias_mode == megdnn::BiasMode::BIAS) {
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>(
ncb_param.bias<bias_ctype>(batch_id, group_id) +
numbers_of_ncb_dst_offset));
} else {
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>(
ncb_param.bias<bias_ctype>(batch_id, group_id) + oc_start));
}
PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
gemv_dst, bias_ptr, conv_bias_dst, param.bias_mode,
param.nonlineMode, param.bias_type, param.dst_type, 1_z,
oc_end - oc_start, 1, 1, 1);
}
};
} // namespace
size_t ConvBiasImpl::AlgoConv1x1Gemv::get_oc_tile_size_heuristic(
const NCBKernSizeParam& param) const {
size_t OC = param.filter_meta.ocpg;
size_t oc_block_size_one_thread = div_ceil(OC, param.nr_threads);
return round_up<size_t>(oc_block_size_one_thread, 16);
}
size_t ConvBiasImpl::AlgoConv1x1Gemv::get_workspace(
ConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
midout_iv("AlgoConv1x1Gemv::get_workspace"_hash)) {
size_t compt_oc_block_size = get_oc_tile_size_heuristic(param);
auto thread_bundle =
utils::get_thread_bundle(param, 0, compt_oc_block_size);
return WorkspaceBundle{
nullptr,
{thread_bundle.total_size_in_bytes() * param.nr_threads}}
.total_size_in_bytes();
}
MIDOUT_END();
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
ConvBiasImpl* opr, const NCBKernSizeParam& param) const {
SmallVector<ConvBiasImpl::NCBKern> ret_kern;
size_t OC = param.filter_meta.ocpg;
size_t compt_oc_block_size = get_oc_tile_size_heuristic(param);
size_t GROUP = param.filter_meta.group;
size_t BATCH = param.n;
size_t oc_blocks_per_group = div_ceil(OC, compt_oc_block_size);
//! get thread bundle
auto thread_bundle =
utils::get_thread_bundle(param, 0, compt_oc_block_size);
auto whole_bundle = WorkspaceBundle{
nullptr, {thread_bundle.total_size_in_bytes() * param.nr_threads}};
using conv1x1_gemv_kern =
std::function<void(WorkspaceBundle&, WorkspaceBundle&, size_t,
const ConvBiasImpl::NCBKernSizeParam&,
const ConvBiasImpl::NCBKernParam&,
const ConvBiasImpl::NCBKernIndex&)>;
conv1x1_gemv_kern conv1x1_gemv_worker = nullptr;
#define cb1(_format, _dt, _post_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \
conv1x1_gemv_worker = \
Conv1x1GemvWorker<_dt, _dt, _dt, _post_ctype, _post_ctype, \
_postprocess_mode, _format>::exec; \
} \
} \
MIDOUT_END()
#define cb2(_format, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == param.src_type.enumv() && \
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
conv1x1_gemv_worker = \
Conv1x1GemvWorker<_src_ctype, _bias_ctype, _dst_ctype, \
DTypeTrait<_i_bias_type>::ctype, \
DTypeTrait<_i_dst_type>::ctype, \
_postprocess_mode, _format>::exec; \
} \
} \
MIDOUT_END()
switch (opr->param().format) {
case param::ConvBias::Format::NCHW:
cb1(param::ConvBias::Format::NCHW, dt_float32, dt_float32,
PostprocessMode::FLOAT, "NCHW::GEMV::FLOAT"_hash);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
cb1(param::ConvBias::Format::NCHW, dt_float16, __fp16,
PostprocessMode::FLOAT, "NCHW::GEMV::FLOAT16_FP16"_hash);
#endif
cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32,
dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"NCHW::GEMV::INT8x8x32_INT32"_hash);
cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int16, dt_int16,
dt_int8, dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
"NCHW::GEMV::INT8x8x16_INT16"_hash);
cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
"NCHW::GEMV::QINT8x8x32_QINT32"_hash);
cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32,
dt_int8, PostprocessMode::QUANTIZED,
"NCHW::GEMV::QINT8x8x32_QINT8"_hash);
cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm,
dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
"NCHW::GEMV::QUINT8x8x32_QINT32"_hash);
cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm,
dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, dt_int32,
dt_uint8, PostprocessMode::QUANTIZED,
"NCHW::GEMV::QUINT8x8x32_QUINT8"_hash);
break;
default:
megdnn_throw("Invalid Format");
break;
}
#undef cb1
#undef cb2
megdnn_assert(conv1x1_gemv_worker, "No suitable gemv worker");
auto kern_compt =
[compt_oc_block_size, param, conv1x1_gemv_worker, whole_bundle,
thread_bundle](
const ConvBiasImpl::NCBKernParam& ncb_param,
const ConvBiasImpl::NCBKernIndex& ncb_index) mutable {
conv1x1_gemv_worker(whole_bundle, thread_bundle,
compt_oc_block_size, param, ncb_param,
std::move(ncb_index));
};
ret_kern.push_back({kern_compt, {BATCH, GROUP, oc_blocks_per_group}});
return ret_kern;
}
bool ConvBiasImpl::AlgoConv1x1Gemv::usable(ConvBiasImpl* opr,
const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
midout_iv("AlgoConv1x1Gemv::usable"_hash)) {
//! whether 1x1
size_t FH = param.filter_meta.spatial[0],
FW = param.filter_meta.spatial[1];
size_t PH = param.filter_meta.padding[0],
PW = param.filter_meta.padding[1];
size_t SH = param.filter_meta.stride[0],
SW = param.filter_meta.stride[1];
if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) {
return false;
}
//! whether gemv
size_t OH = param.osz[0];
size_t OW = param.osz[1];
if (OH * OW != 1) {
return false;
}
//! int16 inputs have no gemv support, not even a naive kernel
if ((param.src_type.enumv() == param.filter_type.enumv() &&
param.src_type.enumv() == DTypeEnum::Int16) &&
param.dst_type.enumv() == DTypeEnum::Int32) {
return false;
}
//! make sure the 8x8x16 and 8x8x32 bias mode is NO_BIAS and the
//! nonlineMode is IDENTITY, otherwise return false, meaning that
//! 8x8x32 and 8x8x16 do not support PostProcess
if (param.dst_type.enumv() == DTypeEnum::Int16 ||
param.dst_type.enumv() == DTypeEnum::Int32 ||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
if (param.bias_mode != megdnn::BiasMode::NO_BIAS ||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
return false;
}
}
//! supports a few dtypes
if (param.src_type.enumv() != param.filter_type.enumv()) {
return false;
}
if (param.src_type.enumv() != DTypeEnum::Int8 &&
param.src_type.enumv() != DTypeEnum::QuantizedS8 &&
param.src_type.enumv() != DTypeEnum::Quantized8Asymm &&
#if !MEGDNN_DISABLE_FLOAT16
param.src_type.enumv() != DTypeEnum::Float16 &&
#endif
param.src_type.enumv() != DTypeEnum::Float32) {
return false;
}
bool is_param_ok =
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT;
bool is_format_and_dtype_ok = false;
#if MEGDNN_X86
if (opr->param().format == param::ConvBias::Format::NCHW) {
//! x86 supports all dtypes in NCHW
is_format_and_dtype_ok = true;
}
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
//! add NCHW44 and NCHW44_DOT support in the future
if (opr->param().format == param::ConvBias::Format::NCHW) {
//! NCHW format supports all dtypes
is_format_and_dtype_ok = true;
}
#endif
return is_param_ok && is_format_and_dtype_ok;
}
MIDOUT_END();
return false;
}
bool ConvBiasImpl::AlgoConv1x1Gemv::is_preferred(
ConvBiasImpl*, const NCBKernSizeParam& param) const {
size_t OC = param.filter_meta.ocpg;
if (OC <= 2 && param.src_type.enumv() != DTypeEnum::Float32)
return true;
#if (MEGDNN_ARMV7 || MEGDNN_AARCH64)
//! maybe add support for Quantized8Asymm in the future
return (param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int32) ||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS32) ||
#if !MEGDNN_DISABLE_FLOAT16
(param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16) ||
#endif
(param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32);
#else
return false;
#endif
}
// vim: syntax=cpp.doxygen
\ No newline at end of file
/**
* \file dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/thin/small_vector.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/opr_impl.h"
namespace megdnn {
namespace fallback {
class ConvBiasImpl::AlgoConv1x1Gemv final : public AlgoBase {
public:
AlgoConv1x1Gemv() = default;
bool is_reproducible() const override { return true; }
const char* name() const override {
return "CONV1x1_GEMV";
}
bool usable(ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(ConvBiasImpl*,
const NCBKernSizeParam& param) const override;
SmallVector<NCBKern> dispatch_kerns(
ConvBiasImpl* opr, const NCBKernSizeParam& param) const override;
bool is_preferred(ConvBiasImpl*, const NCBKernSizeParam&) const override;
protected:
size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const;
};
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -11,30 +11,12 @@
#pragma once
#include "src/fallback/conv_bias/conv1x1/conv1x1_strategy.h"
#include "src/fallback/conv_bias/conv1x1/conv1x1_utils.h"
namespace megdnn {
namespace fallback {
namespace conv1x1 {
namespace {
//! get_thread_bundle
WorkspaceBundle get_thread_bundle(const ConvBiasImpl::NCBKernSizeParam& param,
size_t matmul_c_size, size_t oc_tile_size) {
//! in some cases, the matmul result needs temporary space to be stored in
size_t OH = param.osz[0];
size_t OW = param.osz[1];
bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
size_t matmul_dst_bytes_per_thread =
is_dst_8bit ? oc_tile_size * OH * OW * sizeof(param.bias_type) : 0;
return WorkspaceBundle{nullptr,
{matmul_c_size, matmul_dst_bytes_per_thread}};
}
} // anonymous namespace
template <MatrixMulImpl::AlgoBase::PackMode pack_mode>
class Conv1x1Kerns {
public:
......@@ -51,7 +33,7 @@ public:
//! matmul_param records a matmul with M = oc_tile_size, K = IC,
//! N = OH * OW; this does not affect the packb bytes
auto matmul_bundle = matmul_algo->get_bundle(matmul_param);
auto thread_bundle = get_thread_bundle(param, matmul_bundle.get_size(2),
auto thread_bundle = utils::get_thread_bundle(param, matmul_bundle.get_size(2),
oc_tile_size);
//! size per thread
......@@ -86,7 +68,7 @@ public:
const MatrixMulImpl::AlgoBase* matmul_algo,
size_t oc_tile_size) {
size_t matmul_size = matmul_algo->get_workspace(matmul_param);
auto thread_bundle = get_thread_bundle(param, matmul_size, oc_tile_size);
auto thread_bundle = utils::get_thread_bundle(param, matmul_size, oc_tile_size);
//! size per thread
size_t all_threads_bytes =
thread_bundle.total_size_in_bytes() * param.nr_threads;
......
......@@ -10,8 +10,8 @@
* implied.
*/
#include "src/fallback/conv_bias/conv1x1/conv1x1_utils.h"
#include "src/fallback/conv_bias/conv1x1/conv1x1_strategy.h"
#include <unordered_map>
#include "midout.h"
......@@ -20,53 +20,7 @@ MIDOUT_DECL(megdnn_fallback_conv1x1_factory_strategy)
namespace megdnn {
namespace fallback {
namespace conv1x1 {
namespace {
struct StrategyHashParam {
ConvBiasImpl::NCBKernSizeParam param;
param::ConvBias::Format format;
MatrixMulImpl::AlgoBase::PackMode packmode;
};
struct StrategyHashParamHash {
std::size_t operator()(const StrategyHashParam& sparam) const {
constexpr size_t base = 1; //! avoid a zero hash key
std::size_t result =
static_cast<std::size_t>(sparam.param.src_type.enumv()) + base;
result = result ^
((static_cast<std::size_t>(sparam.param.dst_type.enumv()) +
base)
<< 3);
result = result ^
((static_cast<std::size_t>(sparam.param.filter_type.enumv()) +
base)
<< 6);
result = result ^
((static_cast<std::size_t>(sparam.param.bias_type.enumv()) +
base)
<< 9);
result = result ^
((static_cast<std::size_t>(sparam.format) + base) << 12);
result = result ^
((static_cast<std::size_t>(sparam.packmode) + base) << 15);
return result;
};
};
struct StrategyHashParamEqual {
bool operator()(const StrategyHashParam& param1,
const StrategyHashParam& param2) const {
bool flags = true;
flags = param1.param.src_type == param2.param.src_type && flags;
flags = param1.param.filter_type == param2.param.filter_type && flags;
flags = param1.param.bias_type == param2.param.bias_type && flags;
flags = param1.param.dst_type == param2.param.dst_type && flags;
flags = param1.format == param2.format && flags;
flags = param1.packmode == param2.packmode && flags;
return flags;
};
};
//! NOTE: keep this function consistent with can_make_conv1x1_strategy
//! whenever you modify it
std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
......@@ -176,39 +130,14 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
megdnn_throw("Invalid Data Type");
return nullptr;
}
class StrategyDelegationStorage {
public:
Conv1x1StrategyBase* get(const ConvBiasImpl::NCBKernSizeParam& param,
MatrixMulImpl::AlgoBase::PackMode pack_mode,
param::ConvBias::Format format) {
MEGDNN_LOCK_GUARD(m_mtx);
StrategyHashParam sparam;
sparam.param = param;
sparam.format = format;
sparam.packmode = pack_mode;
if (m_map_strategies.find(sparam) == m_map_strategies.end()) {
auto strategy = create_conv1x1_strategy(param, pack_mode, format);
m_map_strategies[sparam] = std::move(strategy);
}
return m_map_strategies[sparam].get();
}
private:
std::mutex m_mtx;
std::unordered_map<StrategyHashParam, std::unique_ptr<Conv1x1StrategyBase>,
StrategyHashParamHash, StrategyHashParamEqual>
m_map_strategies;
};
} // anonymous namespace
Conv1x1StrategyBase* Conv1x1Factory::make_conv1x1_strategy(
const ConvBiasImpl::NCBKernSizeParam& param,
MatrixMulImpl::AlgoBase::PackMode pack_mode,
param::ConvBias::Format format) {
static StrategyDelegationStorage storage;
return storage.get(param, pack_mode, format);
static utils::StrategyDelegationStorage<Conv1x1StrategyBase> storage;
return storage.get(param, pack_mode, format, create_conv1x1_strategy);
}
bool Conv1x1Factory::can_make_conv1x1_strategy(
......@@ -277,3 +206,5 @@ bool Conv1x1Factory::can_make_conv1x1_strategy(
} // namespace conv1x1
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
......@@ -13,6 +13,8 @@
#pragma once
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/conv1x1/conv1x1_utils.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
......@@ -27,44 +29,6 @@ namespace conv1x1 {
using namespace x86;
#endif
namespace {
//! get_matmul_kern_param
MatrixMulImpl::KernSizeParam get_matmul_kern_param(
const ConvBiasImpl::NCBKernSizeParam& param, size_t n, size_t m) {
size_t M = m;
size_t N = n;
size_t K = param.filter_meta.icpg; //! K = IC
size_t LDA = K, LDB = N, LDC = N;
bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
size_t pack_c_size = pack_size(param.filter_meta.format);
auto format = param::MatrixMul::Format::DEFAULT;
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) {
format = param::MatrixMul::Format::MK4;
} else if (param.filter_meta.format ==
param::ConvBias::Format::NCHW44_DOT) {
format = param::MatrixMul::Format::MK4_DOT;
}
return {param.filter_type,
param.src_type,
is_dst_8bit ? param.bias_type : param.dst_type,
M,
N,
K,
LDA * pack_c_size,
LDB * pack_c_size,
LDC * pack_c_size,
false,
false,
param::MatrixMul::ComputeMode::DEFAULT,
format};
}
} // namespace
class Conv1x1StrategyBase {
public:
virtual void packA(WorkspaceBundle& whole_bundle,
......@@ -134,7 +98,7 @@ public:
size_t IC = param.filter_meta.icpg;
MatrixMulImpl::KernParam matmul_kern_param;
static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) =
get_matmul_kern_param(param, OH * OW, oc_end - oc_start);
utils::get_matmul_kern_param(param, OH * OW, oc_end - oc_start);
size_t bytes_offset_of_a_panel =
group_id * packa_bytes_per_group +
......@@ -176,8 +140,7 @@ public:
MatrixMulImpl::KernParam matmul_kern_param;
static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) =
get_matmul_kern_param(param, OH * OW, OC);
utils::get_matmul_kern_param(param, OH * OW, OC);
rep(batch, BATCH) {
rep(g, GROUP) {
......@@ -238,7 +201,7 @@ public:
MatrixMulImpl::KernParam matmul_kern_param;
static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) =
get_matmul_kern_param(param, OH * OW, oc_end - oc_start);
utils::get_matmul_kern_param(param, OH * OW, oc_end - oc_start);
size_t bytes_offset_of_a_panel =
group_id * packa_bytes_per_group +
......@@ -328,7 +291,6 @@ public:
MatrixMulImpl::AlgoBase::PackMode pack_mode,
param::ConvBias::Format format);
};
} // namespace conv1x1
} // namespace fallback
} // namespace megdnn
......
/**
* \file dnn/src/fallback/conv_bias/conv1x1/conv1x1_utils.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/fallback/conv_bias/conv1x1/conv1x1_utils.h"
namespace megdnn {
namespace fallback {
namespace conv1x1 {
namespace utils {
//! get_thread_bundle
WorkspaceBundle get_thread_bundle(const ConvBiasImpl::NCBKernSizeParam& param,
size_t matmul_c_size, size_t oc_tile_size) {
//! in some cases, the matmul result needs temporary space to be stored in
size_t OH = param.osz[0];
size_t OW = param.osz[1];
bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
size_t matmul_dst_bytes_per_thread =
is_dst_8bit ? oc_tile_size * OH * OW * sizeof(param.bias_type) : 0;
return WorkspaceBundle{nullptr,
{matmul_c_size, matmul_dst_bytes_per_thread}};
}
//! get_matmul_kern_param
MatrixMulImpl::KernSizeParam get_matmul_kern_param(
const ConvBiasImpl::NCBKernSizeParam& param, size_t n, size_t m) {
size_t M = m;
size_t N = n;
size_t K = param.filter_meta.icpg; //! K = IC
size_t LDA = K, LDB = N, LDC = N;
bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
size_t pack_c_size = pack_size(param.filter_meta.format);
auto format = param::MatrixMul::Format::DEFAULT;
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) {
format = param::MatrixMul::Format::MK4;
} else if (param.filter_meta.format ==
param::ConvBias::Format::NCHW44_DOT) {
format = param::MatrixMul::Format::MK4_DOT;
}
return {param.filter_type,
param.src_type,
is_dst_8bit ? param.bias_type : param.dst_type,
M,
N,
K,
LDA * pack_c_size,
LDB * pack_c_size,
LDC * pack_c_size,
false,
false,
param::MatrixMul::ComputeMode::DEFAULT,
format};
}
} // namespace utils
} // namespace conv1x1
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
/**
* \file dnn/src/fallback/conv_bias/conv1x1/conv1x1_utils.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <unordered_map>
#include "src/fallback/conv_bias/opr_impl.h"
namespace megdnn {
namespace fallback {
namespace conv1x1 {
namespace utils {
struct StrategyHashKey {
ConvBiasImpl::NCBKernSizeParam param;
param::ConvBias::Format format;
MatrixMulImpl::AlgoBase::PackMode packmode;
};
struct StrategyHasher {
std::size_t operator()(const StrategyHashKey& key) const {
constexpr size_t base = 1; //! avoid a zero hash key
std::size_t result =
static_cast<std::size_t>(key.param.src_type.enumv()) + base;
result = result ^
((static_cast<std::size_t>(key.param.dst_type.enumv()) + base)
<< 3);
result = result ^
((static_cast<std::size_t>(key.param.filter_type.enumv()) +
base)
<< 6);
result = result ^
((static_cast<std::size_t>(key.param.bias_type.enumv()) + base)
<< 9);
result = result ^ ((static_cast<std::size_t>(key.format) + base) << 12);
result = result ^
((static_cast<std::size_t>(key.packmode) + base) << 15);
return result;
}
};
struct StrategyHashKeyEqual {
bool operator()(const StrategyHashKey& key1,
const StrategyHashKey& key2) const {
return key1.param.src_type == key2.param.src_type &&
key1.param.filter_type == key2.param.filter_type &&
key1.param.bias_type == key2.param.bias_type &&
key1.param.dst_type == key2.param.dst_type &&
key1.format == key2.format && key1.packmode == key2.packmode;
}
};
template <typename T>
class StrategyDelegationStorage {
using creator = std::function<std::unique_ptr<T>(
const ConvBiasImpl::NCBKernSizeParam&,
MatrixMulImpl::AlgoBase::PackMode, param::ConvBias::Format)>;
public:
T* get(const ConvBiasImpl::NCBKernSizeParam& param,
MatrixMulImpl::AlgoBase::PackMode pack_mode,
param::ConvBias::Format format, creator Fun) {
MEGDNN_LOCK_GUARD(m_mtx);
StrategyHashKey key;
key.param = param;
key.format = format;
key.packmode = pack_mode;
if (m_map_strategies.find(key) == m_map_strategies.end()) {
auto strategy = Fun(param, pack_mode, format);
m_map_strategies[key] = std::move(strategy);
}
return m_map_strategies[key].get();
}
private:
std::mutex m_mtx;
std::unordered_map<StrategyHashKey, std::unique_ptr<T>, StrategyHasher,
StrategyHashKeyEqual>
m_map_strategies;
};
//! get_thread_bundle
WorkspaceBundle get_thread_bundle(const ConvBiasImpl::NCBKernSizeParam& param,
size_t matmul_c_size, size_t oc_tile_size);
//! get_matmul_kern_param
MatrixMulImpl::KernSizeParam get_matmul_kern_param(
const ConvBiasImpl::NCBKernSizeParam& param, size_t n, size_t m);
} // namespace utils
} // namespace conv1x1
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
......@@ -647,8 +647,11 @@ bool ConvBiasImpl::AlgoIm2col::usable(
return false;
}
if (param.src_type.enumv() != param.filter_type.enumv() &&
param.src_type.enumv() != DTypeEnum::Int8 &&
if (param.src_type.enumv() != param.filter_type.enumv()) {
return false;
}
if (param.src_type.enumv() != DTypeEnum::Int8 &&
param.src_type.enumv() != DTypeEnum::QuantizedS8 &&
param.src_type.enumv() != DTypeEnum::Quantized8Asymm &&
#if !MEGDNN_DISABLE_FLOAT16
......
......@@ -16,6 +16,7 @@
#include "src/common/utils.h"
#include "src/fallback/conv_bias/algos.h"
#include "src/fallback/conv_bias/conv1x1/algos.h"
#include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h"
#include "src/fallback/conv_bias/im2col/algos.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/naive/convolution/algorithms.h"
......@@ -53,6 +54,10 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
public:
AlgoPack() {
refhold.emplace_back(new AlgoConv1x1Gemv());
all_algos.emplace_back(refhold.back().get());
static CpuOprDelegationStorage<> storage;
auto matmul_opr = storage.get<MatrixMul>();
auto&& matmul_algos =
......
......@@ -259,6 +259,7 @@ private:
class AlgoNaive;
class AlgoIm2col;
class AlgoConv1x1;
class AlgoConv1x1Gemv;
class AlgoWinogradF32;
class AlgoWinogradF32_4x4;
class AlgoWinogradQS8;
......
......@@ -11,6 +11,7 @@
#include "src/fallback/matrix_mul/algos.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
#include "src/fallback/matrix_mul/gemv.h"
#include "src/fallback/matrix_mul/generic_strategy.h"
#include "midout.h"
......@@ -71,39 +72,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern,
float);
/* ===================== gemv algo ===================== */
namespace {
template <typename itype, typename otype, bool have_zp = false>
void gemm_gemv_like(const MatrixMulImpl::KernParam& kern_param) {
const itype* A = kern_param.A<itype>();
const itype* B = kern_param.B<itype>();
uint8_t zp0, zp1;
if (have_zp) {
zp0 = kern_param.A_type.param<dtype::Quantized8Asymm>().zero_point;
zp1 = kern_param.B_type.param<dtype::Quantized8Asymm>().zero_point;
}
otype* C = kern_param.C<otype>();
for (size_t m = 0; m < kern_param.M; ++m) {
memset(C + m * kern_param.LDC, 0, sizeof(otype) * kern_param.N);
for (size_t k = 0; k < kern_param.K; ++k)
for (size_t n = 0; n < kern_param.N; ++n) {
if (!have_zp)
C[m * kern_param.LDC + n] +=
static_cast<otype>(A[m * kern_param.LDA + k]) *
static_cast<otype>(B[k * kern_param.LDB + n]);
else {
C[m * kern_param.LDC + n] +=
(static_cast<otype>(A[m * kern_param.LDA + k]) -
static_cast<otype>(zp0)) *
(static_cast<otype>(B[k * kern_param.LDB + n]) -
static_cast<otype>(zp1));
}
}
}
}
} // anonymous namespace
bool MatrixMulImpl::AlgoGemv::usable(
const KernSizeParam& kern_size_param) const {
return !kern_size_param.trA && !kern_size_param.trB &&
......
/**
* \file dnn/src/fallback/matrix_mul/gemv.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/fallback/matrix_mul/opr_impl.h"
namespace megdnn {
namespace fallback {
template <typename itype, typename otype>
void gemv_like(const itype* A, const itype* B, otype* C, size_t M, size_t N,
size_t K, size_t LDA, size_t LDB, size_t LDC) {
for (size_t m = 0; m < M; ++m) {
memset(C + m * LDC, 0, sizeof(otype) * N);
for (size_t k = 0; k < K; ++k)
for (size_t n = 0; n < N; ++n) {
C[m * LDC + n] += static_cast<otype>(A[m * LDA + k]) *
static_cast<otype>(B[k * LDB + n]);
}
}
}
template <typename itype, typename otype>
void gemv_like(const itype* A, const itype* B, otype* C, size_t M, size_t N,
size_t K, size_t LDA, size_t LDB, size_t LDC, uint8_t zp0,
uint8_t zp1) {
for (size_t m = 0; m < M; ++m) {
memset(C + m * LDC, 0, sizeof(otype) * N);
for (size_t k = 0; k < K; ++k)
for (size_t n = 0; n < N; ++n) {
C[m * LDC + n] += (static_cast<otype>(A[m * LDA + k]) -
static_cast<otype>(zp0)) *
(static_cast<otype>(B[k * LDB + n]) -
static_cast<otype>(zp1));
}
}
}
template <typename itype, typename otype, bool have_zp = false>
void gemm_gemv_like(const MatrixMulImpl::KernParam& kern_param) {
const itype* A = kern_param.A<itype>();
const itype* B = kern_param.B<itype>();
otype* C = kern_param.C<otype>();
size_t M = kern_param.M;
size_t N = kern_param.N;
size_t K = kern_param.K;
size_t LDA = kern_param.LDA;
size_t LDB = kern_param.LDB;
size_t LDC = kern_param.LDC;
if (have_zp) {
uint8_t zp0 = kern_param.A_type.param<dtype::Quantized8Asymm>().zero_point;
uint8_t zp1 = kern_param.B_type.param<dtype::Quantized8Asymm>().zero_point;
gemv_like<itype, otype>(A, B, C, M, N, K, LDA, LDB, LDC, zp0, zp1);
}
else {
gemv_like<itype, otype>(A, B, C, M, N, K, LDA, LDB, LDC);
}
}
} // namespace fallback
} // namespace megdnn
\ No newline at end of file
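A quick usage sketch (hypothetical, assuming the gemv.h header above is on the include path) of the zero-point overload: Quantized8Asymm stores uint8 values offset by a per-tensor zero point, so the kernel subtracts zp0 from A and zp1 from B before accumulating into int32. The function quint8_gemv_example is illustrative only.

// Hypothetical caller of megdnn::fallback::gemv_like with zero points.
#include <cstddef>
#include <cstdint>
#include <vector>
#include "src/fallback/matrix_mul/gemv.h"

void quint8_gemv_example() {
    const size_t M = 2, N = 1, K = 4;
    std::vector<uint8_t> A(M * K, 130);  // stored values; zero point 128 -> real value 2
    std::vector<uint8_t> B(K * N, 131);  // stored values; zero point 128 -> real value 3
    std::vector<int32_t> C(M * N, 0);
    // In the real call sites zp0/zp1 come from dtype::Quantized8Asymm zero_point fields.
    megdnn::fallback::gemv_like<uint8_t, int32_t>(
            A.data(), B.data(), C.data(), M, N, K,
            /*LDA=*/K, /*LDB=*/N, /*LDC=*/N, /*zp0=*/128, /*zp1=*/128);
    // Each C[m] == K * (130 - 128) * (131 - 128) == 24.
}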
......@@ -1861,6 +1861,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
#elif MEGDNN_ARMV7
check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32:48");
#endif
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
......@@ -1905,16 +1911,23 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) {
dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
"CONV1x1:AARCH32_F16_K4X16X1:24");
#endif
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
}
#endif
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) {
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
#define cb(name) \
checker_conv_bias(get_conv_bias_1x1_args(false, false, true, true), \
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
std::vector<conv_bias::TestArg> args =
get_conv_bias_1x1_args(false, false, true, true);
#define cb(name) \
checker_conv_bias(args, handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
dtype::QuantizedS8(60.25f), name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
......@@ -1928,17 +1941,27 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) {
cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:48");
#endif
#undef cb
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
checker_conv_bias(gemv_args, handle(), &rng, epsilon,
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f),
"CONV1x1_GEMV");
}
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
NormalRNG rng(128.f);
#define cb(name) \
checker_conv_bias(get_conv_bias_1x1_args(false, false, true, true), \
handle(), &rng, epsilon, \
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
dtype::QuantizedS32(1.2 * 1.3), \
UniformIntRNG rng{-50, 50};
std::vector<conv_bias::TestArg> args =
get_conv_bias_1x1_args(false, false, true, true);
#define cb(name) \
checker_conv_bias(args, handle(), &rng, epsilon, \
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
dtype::QuantizedS32(1.2 * 1.3), \
dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
float epsilon = 0.001;
#if MEGDNN_AARCH64
......@@ -1952,17 +1975,29 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
cb("CONV1x1:ARMV7_QUINT8_K4X8X8:48");
#endif
#undef cb
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
checker_conv_bias(gemv_args, handle(), &rng, epsilon,
dtype::Quantized8Asymm(1.2f, (uint8_t)125),
dtype::Quantized8Asymm(1.3f, (uint8_t)129),
dtype::QuantizedS32(1.2 * 1.3),
dtype::Quantized8Asymm(50.3f, (uint8_t)120),
"CONV1x1_GEMV");
}
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) {
UniformIntRNG rng{-50, 50};
NormalRNG rng(128.f);
float epsilon = 0.001;
#define cb(name) \
checker_conv_bias(get_conv_bias_1x1_args(true, true), handle(), &rng, \
epsilon, dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
#define cb(name) \
checker_conv_bias(args, handle(), &rng, epsilon, \
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
dtype::QuantizedS32(1.2 * 1.3), {}, name);
#if MEGDNN_AARCH64
......@@ -1978,15 +2013,25 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) {
cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24");
#endif
#undef cb
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
checker_conv_bias(gemv_args, handle(), &rng, epsilon,
dtype::Quantized8Asymm(1.2f, (uint8_t)125),
dtype::Quantized8Asymm(1.3f, (uint8_t)129),
dtype::QuantizedS32(1.2 * 1.3), {}, "CONV1x1_GEMV");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
#define cb(name) \
checker_conv_bias(get_conv_bias_1x1_args(true, true), handle(), &rng, \
epsilon, dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, \
dtype::Int16{}, name);
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
#define cb(name) \
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name);
#if MEGDNN_AARCH64
cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24");
......@@ -1997,6 +2042,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
#endif
cb("CONV1x1:ARM_COMMON_INT8X8X16:48");
#undef cb
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
checker_conv_bias(gemv_args, handle(), &rng, epsilon, dtype::Int8{},
dtype::Int8{}, dtype::Int16{}, dtype::Int16{},
"CONV1x1_GEMV");
}
#endif
......@@ -2024,6 +2079,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
cb("CONV1x1:ARMV7_INT8X8X32_K4X2X16:48");
#endif
#undef cb
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
checker_conv_bias_mul_int8x8x32(gemv_args, handle(), "CONV1x1_GEMV");
}
#ifndef __ARM_FEATURE_DOTPROD
......
......@@ -254,6 +254,50 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV) {
for (size_t N : {512, 1024})
run(M, K, N);
}
TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) {
int exec_times = 50;
Benchmarker<MatrixMul> benchmarker(handle());
benchmarker.set_times(exec_times);
benchmarker.set_before_exec_callback(
AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV"));
auto run = [&](size_t M, size_t K, size_t N) {
std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")"
<< std::endl;
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32());
auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times;
auto computations = 2 * M * K * N * 1e-6;
auto perf = computations / time;
std::cout << "gemv fp32, Performance is " << perf << " Gflops"
<< std::endl;
};
std::cout << "warm up:\n";
for (int i = 0; i < 50; i++) {
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_display(false)
.exec({{2, 1024}, {1024, 512}, {}});
benchmarker.set_display(true);
}
// run gemv
run(12, 48, 1);
run(48, 12, 1);
run(32, 128, 1);
run(128, 32, 1);
run(64, 256, 1);
run(256, 64, 1);
run(128, 512, 1);
run(512, 128, 1);
run(256, 1024, 1);
run(1024, 256, 1);
}
TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) {
int exec_times = 50;
Benchmarker<MatrixMul> benchmarker(handle());
......@@ -290,6 +334,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) {
for (size_t N : {512, 1024})
run(M, K, N);
}
TEST_F(ARM_COMMON, BENCHMARK_SGEMM) {
int exec_times = 10;
Benchmarker<MatrixMul> benchmarker(handle());
......
......@@ -1081,7 +1081,7 @@ std::vector<megdnn::test::conv_bias::TestArg> get_conv_bias_1x1_args(
for (size_t n : {1, 2})
for (size_t oc : {1, 9, 33})
for (size_t ic : {1, 16, 64})
for (size_t size : {7, 14, 28})
for (size_t size : {1, 7, 14, 28})
for (auto nlmode : nonlinemode)
for (auto convmode : convmodes) {
pack(n, oc, ic, size, size, 1, nlmode, convmode);
......