提交 714cb232 编写于 作者: M Megvii Engine Team

feat(dnn): add gemv support in conv1x1 for NCHW44 and NCHW44_DOT (aarch64 binary size grows 2KB)

GitOrigin-RevId: f8b6d7a1b749fa469aa86f9b8bc2e9ef7f25c937
上级 9cc66963
......@@ -226,7 +226,13 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
break; \
default: \
if (OH * OW == 1) { \
if (pack_oc_size == 1) { \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
} else { \
megdnn_assert(pack_oc_size == 4, \
"Only support nchw44 in ARM"); \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \
} \
break; \
} \
megdnn_throw("quantized unsupported biasmode"); \
......
......@@ -101,6 +101,91 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern(
return int8x8x32_gemv_kern;
}
/* ===================== Int8x8x32 Gemv MK4 algo ===================== */
namespace {
//! kernel thunk: unpack the KernParam and forward the MK4 int8 gemv to
//! the arm_common implementation
void int8x8x32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
    const auto m = kern_param.M, n = kern_param.N, k = kern_param.K;
    const auto lda = kern_param.LDA, ldb = kern_param.LDB,
               ldc = kern_param.LDC;
    gemv_like_mk4(kern_param.A<dt_int8>(), kern_param.B<dt_int8>(),
                  kern_param.C<dt_int32>(), m, n, k, lda, ldb, ldc);
}
} // anonymous namespace
bool MatrixMulImpl::AlgoInt8x8x32GemvMK4::usable(
        const KernSizeParam& kern_size_param) const {
    //! restrict to a plain MK4 int8 gemv: single output column, packed B
    //! (LDB == 4), no transpose, and 4-aligned M/K
    const bool dtype_ok =
            kern_size_param.A_type == kern_size_param.B_type &&
            (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
             kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
            (kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
             kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32);
    const bool layout_ok =
            kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
            kern_size_param.format == param::MatrixMul::Format::MK4 &&
            !kern_size_param.trA && !kern_size_param.trB;
    const bool shape_ok = kern_size_param.M % 4 == 0 &&
                          kern_size_param.K % 4 == 0 &&
                          kern_size_param.N == 1 && kern_size_param.LDB == 4;
    return layout_ok && dtype_ok && shape_ok;
}
//! gemv is always preferred over gemm kernels whenever it is usable
bool MatrixMulImpl::AlgoInt8x8x32GemvMK4::preferred(
        const KernSizeParam& kern_size_param) const {
    MEGDNN_MARK_USED_VAR(kern_size_param);
    return true;
}
//! the same kernel serves every usable size; no dispatch on the param
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_gemv_mk4_kern;
}
#if __ARM_FEATURE_DOTPROD
/* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */
namespace {
//! kernel thunk: unpack the KernParam and forward the MK4_DOT int8 gemv
//! to the arm_common dot-product implementation
void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) {
    const auto m = kern_param.M, n = kern_param.N, k = kern_param.K;
    const auto lda = kern_param.LDA, ldb = kern_param.LDB,
               ldc = kern_param.LDC;
    gemv_like_mk4_dot(kern_param.A<dt_int8>(), kern_param.B<dt_int8>(),
                      kern_param.C<dt_int32>(), m, n, k, lda, ldb, ldc);
}
} // anonymous namespace
bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::usable(
        const KernSizeParam& kern_size_param) const {
    //! dot-product MK4 gemv: int8/qint8 inputs, int32/qint32 output,
    //! single output column with packed B (LDB == 4)
    const bool dtype_ok =
            kern_size_param.A_type == kern_size_param.B_type &&
            (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
             kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
            (kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
             kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32);
    const bool layout_ok =
            kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
            kern_size_param.format == param::MatrixMul::Format::MK4_DOT &&
            !kern_size_param.trA && !kern_size_param.trB;
    const bool shape_ok = kern_size_param.M % 4 == 0 &&
                          kern_size_param.K % 4 == 0 &&
                          kern_size_param.N == 1 && kern_size_param.LDB == 4;
    return layout_ok && dtype_ok && shape_ok;
}
bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::preferred(
const KernSizeParam& kern_size_param) const {
return true;
}
//! the same kernel serves every usable size; no dispatch on the param
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_gemv_mk4_dot_kern;
}
#endif
/* ===================== F32 Gemv algo ===================== */
namespace {
void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
......@@ -137,6 +222,46 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(
return f32_gemv_kern;
}
/* ================== F32 Gemv MK4 algo ================== */
namespace {
//! kernel thunk: unpack the KernParam and forward the MK4 f32 gemv to
//! the arm_common implementation
void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
    gemv_like_mk4(kern_param.A<dt_float32>(), kern_param.B<dt_float32>(),
                  kern_param.C<dt_float32>(), kern_param.M, kern_param.N,
                  kern_param.K, kern_param.LDA, kern_param.LDB,
                  kern_param.LDC);
}
} // anonymous namespace
bool MatrixMulImpl::AlgoF32GemvMK4::usable(
        const KernSizeParam& kern_size_param) const {
    //! only a plain f32 MK4 gemv is supported: no transpose, a single
    //! output column, packed B (LDB == 4) and 4-aligned M/K
    const bool dtype_ok =
            kern_size_param.B_type == kern_size_param.A_type &&
            kern_size_param.C_type == kern_size_param.A_type &&
            kern_size_param.A_type == dtype::Float32();
    const bool layout_ok =
            kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
            kern_size_param.format == param::MatrixMul::Format::MK4 &&
            !kern_size_param.trA && !kern_size_param.trB;
    const bool shape_ok = kern_size_param.M % 4 == 0 &&
                          kern_size_param.K % 4 == 0 &&
                          kern_size_param.N == 1 && kern_size_param.LDB == 4;
    return layout_ok && dtype_ok && shape_ok;
}
//! gemv is always preferred over gemm kernels whenever it is usable
bool MatrixMulImpl::AlgoF32GemvMK4::preferred(
        const KernSizeParam& kern_size_param) const {
    MEGDNN_MARK_USED_VAR(kern_size_param);
    return true;
}
//! the same kernel serves every usable size; no dispatch on the param
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern(
        const KernSizeParam&) const {
    return f32_gemv_mk4_kern;
}
/* ===================== F32 Gevm algo ===================== */
namespace {
template <typename stype, typename dtype>
......
......@@ -43,6 +43,36 @@ public:
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
};
//! int8x8x32 gemv algo for the MK4 (NCHW44) packed layout;
//! operates in-place on the inputs (NO_PACK, zero workspace)
class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override { return 0; }
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
    PackMode packmode() const override { return PackMode::NO_PACK; }
    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
};
#if __ARM_FEATURE_DOTPROD
//! int8x8x32 gemv algo for the MK4_DOT (NCHW44_DOT) packed layout; only
//! compiled when the dot-product ISA extension is available
class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override { return 0; }
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
    PackMode packmode() const override { return PackMode::NO_PACK; }
    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
};
#endif
class MatrixMulImpl::AlgoF32Gemv : public AlgoBase {
protected:
~AlgoF32Gemv() = default;
......@@ -60,6 +90,20 @@ public:
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
};
//! f32 gemv algo for the MK4 (NCHW44) packed layout (NO_PACK, zero
//! workspace)
class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override { return 0; }
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
    PackMode packmode() const override { return PackMode::NO_PACK; }
    MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4)
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class MatrixMulImpl::AlgoF16Gemv : public AlgoBase {
public:
......@@ -87,10 +131,9 @@ public:
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
MEGDNN_OVERRIDE_MATMUL_DESC(1, 1, 1, 4)
};
} // namespace arm_common
} // namespace megdnn
......
......@@ -13,11 +13,11 @@
#include "src/arm_common/matrix_mul/fp32/exec_sgemv.h"
#include <cstddef>
#include "include/megdnn/oprs.h"
#include "midout.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "midout.h"
MIDOUT_DECL(megdnn_fp32_sgemv)
using namespace megdnn;
......@@ -68,18 +68,10 @@ void sgemv_naive_n(const float* __restrict A, const float* __restrict B,
#if !defined(__aarch64__)
#undef vaddvq_f32
#endif
} // namespace
namespace megdnn {
namespace arm_common {
void gemv_like(const float* __restrict A, const float* __restrict B,
void sgemv_naive_m(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride) {
megdnn_assert(M < 8 || (M == 8 && K <= 2) || (N == 1 && Bstride == 1));
if (N == 1) {
return sgemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
}
size_t m = 0;
for (; m + 4 <= M; m += 4) {
size_t k = 0;
......@@ -762,6 +754,85 @@ void gemv_like(const float* __restrict A, const float* __restrict B,
}
}
}
//! f32 gemv for the MK4 (NCHW44) layout: A is consumed in 4x4 blocks of
//! 16 consecutive floats per step of 4 in K, B is a contiguous K-vector
//! (N == 1), and each 4-row block produces 4 floats in C.
void sgemv_naive_n_mk4(const float* __restrict A, const float* __restrict B,
                       float* __restrict C, size_t M, size_t N, size_t K,
                       size_t Astride, size_t Bstride, size_t Cstride) {
    constexpr size_t PACK_SIZE = 4;
    megdnn_assert(N == 1 && Bstride == PACK_SIZE && M % PACK_SIZE == 0 &&
                  K % PACK_SIZE == 0);
    auto Aptr = A;
    auto Cptr = C;
    size_t m = 0;
    while (m < M) {
        auto Aptr0 = Aptr;
        auto Cptr0 = Cptr;
        //! four partial accumulators, tree-reduced to c[0] after the k loop
        float32x4_t c[4];
#define INIT(step) c[step] = vdupq_n_f32(0.0f);
        UNROLL_CALL_RAW(4, INIT)
#undef INIT
        auto Bptr = B;
        size_t k = 0;
        while (k < K) {
            //! b holds the next 4 elements of the input vector
            float32x4_t b = vld1q_f32(Bptr);
            //! one 4x4 block of A (16 floats) as four 4-lane vectors
            float32x4x2_t a[2];
#define LOAD_A(step) a[step] = vld1q_f32_x2(Aptr0 + step * 8);
            UNROLL_CALL_RAW(2, LOAD_A)
#undef LOAD_A
            //! c[i] += (i-th 4-lane slice of the block) * b[i], fused
#define COMPT(step) \
    c[step] = vfmaq_laneq_f32(c[step], a[step / 2].val[step % 2], b, step % 4);
            UNROLL_CALL_RAW(4, COMPT)
#undef COMPT
            Bptr += Bstride;
            Aptr0 += PACK_SIZE * PACK_SIZE;
            k += PACK_SIZE;
        }
        //! reduce: c[0]+=c[2], c[1]+=c[3], then c[0]+=c[1]
#define ADD_C(step, stride) c[step] = vaddq_f32(c[step], c[step + stride]);
        UNROLL_CALL_RAW(2, ADD_C, 2)
        UNROLL_CALL_RAW(1, ADD_C, 1)
#undef ADD_C
        vst1q_f32(Cptr0, c[0]);
        Aptr += Astride;
        Cptr += Cstride;
        m += PACK_SIZE;
    }
}
} // namespace
namespace megdnn {
namespace arm_common {
//! f32 gemv/gemm-like entry for the plain NCHW layout: N == 1 uses the
//! column-vector kernel, otherwise the generic M-major kernel.
void gemv_like(const float* __restrict A, const float* __restrict B,
               float* __restrict C, size_t M, size_t N, size_t K,
               size_t Astride, size_t Bstride, size_t Cstride) {
    //! this path only accepts small problems or a true gemv (N == 1)
    megdnn_assert(M < 8 || (M == 8 && K <= 2) || (N == 1 && Bstride == 1));
    if (N == 1) {
        MIDOUT_BEGIN(megdnn_fp32_sgemv, midout_iv("F32_GEMV_NCHW_N"_hash)) {
            return sgemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
        }
        MIDOUT_END();
    } else {
        MIDOUT_BEGIN(megdnn_fp32_sgemv, midout_iv("F32_GEMV_NCHW_M"_hash)) {
            return sgemv_naive_m(A, B, C, M, N, K, Astride, Bstride, Cstride);
        }
        MIDOUT_END();
    }
}
//! f32 gemv entry for the MK4 (NCHW44) layout; only N == 1 with packed B
//! (Bstride == 4) is implemented.
void gemv_like_mk4(const float* __restrict A, const float* __restrict B,
                   float* __restrict C, size_t M, size_t N, size_t K,
                   size_t Astride, size_t Bstride, size_t Cstride) {
    megdnn_assert(N == 1 && Bstride == 4);
    MIDOUT_BEGIN(megdnn_fp32_sgemv, midout_iv("F32_GEMV_NCHW44_N"_hash)) {
        return sgemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride);
    }
    MIDOUT_END();
}
} // namespace arm_common
} // namespace megdnn
......
......@@ -24,6 +24,9 @@ void gemv_like(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
void gemv_like_mk4(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
} // namespace arm_common
} // namespace megdnn
......
......@@ -10,8 +10,8 @@
*/
#include <cstddef>
#include "src/arm_common/matrix_mul/int8/gemv.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/matrix_mul/int8/gemv.h"
#include "src/common/utils.h"
#include "megdnn/oprs.h"
......@@ -95,6 +95,80 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
C[m * Cstride] = acc0;
}
}
//! int8 gemv for the MK4 layout without dot-product instructions:
//! products are widened to int16 (vmull/vmlal) and horizontally reduced
//! (vaddlvq) into four scalar int32 accumulators per 4-row block.
void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B,
                      int32_t* __restrict C, size_t M, size_t N, size_t K,
                      size_t Astride, size_t Bstride, size_t Cstride) {
    constexpr size_t PACK_SIZE = 4;
    megdnn_assert(N == 1 && Bstride == 4);
    auto Aptr = A;
    size_t m = 0;
    for (; m < M; m += PACK_SIZE) {
        auto Bptr = B;
        auto Aptr0 = Aptr;
        int32_t acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;
        size_t k = 0;
        //! main loop: 16 elements of B per iteration; vld4q de-interleaves
        //! the packed 4x16 block so a.val[i] feeds accumulator i (-> C[i])
        for (; k + 16 <= K; k += 16) {
            int8x16x4_t a = vld4q_s8(Aptr0);
            int8x16_t b = vld1q_s8(Bptr);
            int16x8_t c[4];
            //! widen-multiply the low 8 lanes ...
            c[0] = vmull_s8(vget_low_s8(a.val[0]), vget_low_s8(b));
            c[1] = vmull_s8(vget_low_s8(a.val[1]), vget_low_s8(b));
            c[2] = vmull_s8(vget_low_s8(a.val[2]), vget_low_s8(b));
            c[3] = vmull_s8(vget_low_s8(a.val[3]), vget_low_s8(b));
            //! ... then multiply-accumulate the high 8 lanes
            //! NOTE(review): each int16 lane holds 2 products; the extreme
            //! (-128)*(-128)*2 == 32768 would wrap int16 -- confirm inputs
            //! cannot both be -128, or that the wrap is acceptable
            c[0] = vmlal_high_s8(c[0], a.val[0], b);
            c[1] = vmlal_high_s8(c[1], a.val[1], b);
            c[2] = vmlal_high_s8(c[2], a.val[2], b);
            c[3] = vmlal_high_s8(c[3], a.val[3], b);
            //! widening horizontal sum int16 -> int32
            acc0 += vaddlvq_s16(c[0]);
            acc1 += vaddlvq_s16(c[1]);
            acc2 += vaddlvq_s16(c[2]);
            acc3 += vaddlvq_s16(c[3]);
            Bptr += 16;
            Aptr0 += PACK_SIZE * 16;
        }
        //! 8-wide tail (at most one iteration after the 16-wide loop)
        for (; k + 8 <= K; k += 8) {
            int8x8x4_t a = vld4_s8(Aptr0);
            int8x8_t b = vld1_s8(Bptr);
            int16x8_t c[4];
            c[0] = vmull_s8(a.val[0], b);
            c[1] = vmull_s8(a.val[1], b);
            c[2] = vmull_s8(a.val[2], b);
            c[3] = vmull_s8(a.val[3], b);
            acc0 += vaddlvq_s16(c[0]);
            acc1 += vaddlvq_s16(c[1]);
            acc2 += vaddlvq_s16(c[2]);
            acc3 += vaddlvq_s16(c[3]);
            Bptr += 8;
            Aptr0 += PACK_SIZE * 8;
        }
        //! scalar tail; B[k] stays in sync with Aptr0 because Bptr and k
        //! advanced together above
        for (; k < K; ++k) {
            acc0 += static_cast<int32_t>(*(Aptr0 + 0)) * B[k];
            acc1 += static_cast<int32_t>(*(Aptr0 + 1)) * B[k];
            acc2 += static_cast<int32_t>(*(Aptr0 + 2)) * B[k];
            acc3 += static_cast<int32_t>(*(Aptr0 + 3)) * B[k];
            Aptr0 += 4;
        }
        C[0] = acc0;
        C[1] = acc1;
        C[2] = acc2;
        C[3] = acc3;
        Aptr += Astride;
        C += Cstride;
    }
}
} // namespace
#endif
......@@ -169,6 +243,139 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3];
}
}
//! int8 gemv for the MK4 layout, built with dot-product support: vdot
//! folds four multiply-adds per lane into one instruction.
void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B,
                      int32_t* __restrict C, size_t M, size_t N, size_t K,
                      size_t Astride, size_t Bstride, size_t Cstride) {
    constexpr size_t PACK_SIZE = 4;
    megdnn_assert(N == 1 && Bstride == 4);
    auto Aptr = A;
    size_t m = 0;
    for (; m < M; m += PACK_SIZE) {
        auto Bptr = B;
        auto Aptr0 = Aptr;
        int32_t acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;
        size_t k = 0;
        //! main loop: 16 elements of B per iteration; vld4q de-interleaves
        //! the packed block so a.val[i] feeds accumulator i (-> C[i])
        if (k + 16 <= K) {
            int32x4_t acc_neon[4];
            acc_neon[0] = vdupq_n_s32(0);
            acc_neon[1] = vdupq_n_s32(0);
            acc_neon[2] = vdupq_n_s32(0);
            acc_neon[3] = vdupq_n_s32(0);
            for (; k + 16 <= K; k += 16) {
                int8x16x4_t a = vld4q_s8(Aptr0);
                int8x16_t b = vld1q_s8(Bptr);
                acc_neon[0] = vdotq_s32(acc_neon[0], a.val[0], b);
                acc_neon[1] = vdotq_s32(acc_neon[1], a.val[1], b);
                acc_neon[2] = vdotq_s32(acc_neon[2], a.val[2], b);
                acc_neon[3] = vdotq_s32(acc_neon[3], a.val[3], b);
                Bptr += 16;
                Aptr0 += PACK_SIZE * 16;
            }
            //! horizontal sum of the 4-lane int32 accumulators
            acc0 = vaddvq_s32(acc_neon[0]);
            acc1 = vaddvq_s32(acc_neon[1]);
            acc2 = vaddvq_s32(acc_neon[2]);
            acc3 = vaddvq_s32(acc_neon[3]);
        }
        //! at most one 8-wide step remains after the 16-wide loop
        if (k + 8 <= K) {
            int32x2_t acc_neon[4];
            acc_neon[0] = vdup_n_s32(0);
            acc_neon[1] = vdup_n_s32(0);
            acc_neon[2] = vdup_n_s32(0);
            acc_neon[3] = vdup_n_s32(0);
            int8x8x4_t a = vld4_s8(Aptr0);
            int8x8_t b = vld1_s8(Bptr);
            acc_neon[0] = vdot_s32(acc_neon[0], a.val[0], b);
            acc_neon[1] = vdot_s32(acc_neon[1], a.val[1], b);
            acc_neon[2] = vdot_s32(acc_neon[2], a.val[2], b);
            acc_neon[3] = vdot_s32(acc_neon[3], a.val[3], b);
            Bptr += 8;
            Aptr0 += PACK_SIZE * 8;
            k += 8;
            acc0 += vaddv_s32(acc_neon[0]);
            acc1 += vaddv_s32(acc_neon[1]);
            acc2 += vaddv_s32(acc_neon[2]);
            acc3 += vaddv_s32(acc_neon[3]);
        }
        //! scalar tail; B[k] stays in sync because Bptr and k advanced
        //! together above
        for (; k < K; ++k) {
            acc0 += static_cast<int32_t>(*(Aptr0 + 0)) * B[k];
            acc1 += static_cast<int32_t>(*(Aptr0 + 1)) * B[k];
            acc2 += static_cast<int32_t>(*(Aptr0 + 2)) * B[k];
            acc3 += static_cast<int32_t>(*(Aptr0 + 3)) * B[k];
            Aptr0 += 4;
        }
        C[0] = acc0;
        C[1] = acc1;
        C[2] = acc2;
        C[3] = acc3;
        Aptr += Astride;
        C += Cstride;
    }
}
void gemv_naive_n_mk4_dot(const int8_t* __restrict A,
const int8_t* __restrict B, int32_t* __restrict C,
size_t M, size_t N, size_t K, size_t Astride,
size_t Bstride, size_t Cstride) {
constexpr size_t PACK_SIZE = 4;
megdnn_assert(N == 1 && Bstride == 4);
auto Aptr = A;
size_t m = 0;
for (; m < M; m += PACK_SIZE) {
auto Bptr = B;
auto Aptr0 = Aptr;
size_t k = 0;
int32x4_t acc_neon;
acc_neon = vdupq_n_s32(0);
for (; k + 16 <= K; k += 16) {
int8x16_t a0 = vld1q_s8(Aptr0);
int8x16_t a1 = vld1q_s8(Aptr0 + 16);
int8x16_t a2 = vld1q_s8(Aptr0 + 32);
int8x16_t a3 = vld1q_s8(Aptr0 + 48);
int8x16_t b = vld1q_s8(Bptr);
acc_neon = vdotq_laneq_s32(acc_neon, a0, b, 0);
acc_neon = vdotq_laneq_s32(acc_neon, a1, b, 1);
acc_neon = vdotq_laneq_s32(acc_neon, a2, b, 2);
acc_neon = vdotq_laneq_s32(acc_neon, a3, b, 3);
Bptr += 16;
Aptr0 += PACK_SIZE * 16;
}
if (k + 8 <= K) {
int8x16_t a0 = vld1q_s8(Aptr0);
int8x16_t a1 = vld1q_s8(Aptr0 + 16);
int8x8_t b = vld1_s8(Bptr);
acc_neon = vdotq_lane_s32(acc_neon, a0, b, 0);
acc_neon = vdotq_lane_s32(acc_neon, a1, b, 1);
Bptr += 8;
Aptr0 += PACK_SIZE * 8;
k += 8;
}
if (k + 4 <= K) {
int8x16_t a = vld1q_s8(Aptr0);
int32_t tmp = *(reinterpret_cast<const int32_t*>(Bptr));
int8x8_t b = vdup_n_s32(tmp);
acc_neon = vdotq_lane_s32(acc_neon, a, b, 0);
}
vst1q_s32(C, acc_neon);
Aptr += Astride;
C += Cstride;
}
}
} // namespace
#endif
......@@ -201,4 +408,33 @@ void arm_common::gemv_like(const int8_t* __restrict A,
MIDOUT_END();
}
//! int8 -> int32 gemv entry for the MK4 (NCHW44) layout; only N == 1 is
//! implemented
void arm_common::gemv_like_mk4(const int8_t* __restrict A,
                               const int8_t* __restrict B,
                               int32_t* __restrict C, size_t M, size_t N,
                               size_t K, size_t Astride, size_t Bstride,
                               size_t Cstride) {
    megdnn_assert(N == 1);
    MIDOUT_BEGIN(megdnn_arm_common_int8_gemv,
                 midout_iv("INT8_gemv_like_mk4"_hash)) {
        return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride);
    }
    MIDOUT_END();
}
#if __ARM_FEATURE_DOTPROD
//! int8 -> int32 gemv entry for the MK4_DOT (NCHW44_DOT) layout; only
//! N == 1 is implemented
void arm_common::gemv_like_mk4_dot(const int8_t* __restrict A,
                                   const int8_t* __restrict B,
                                   int32_t* __restrict C, size_t M, size_t N,
                                   size_t K, size_t Astride, size_t Bstride,
                                   size_t Cstride) {
    megdnn_assert(N == 1);
    MIDOUT_BEGIN(megdnn_arm_common_int8_gemv,
                 midout_iv("INT8_gemv_like_mk4_dot"_hash)) {
        return gemv_naive_n_mk4_dot(A, B, C, M, N, K, Astride, Bstride,
                                    Cstride);
    }
    MIDOUT_END();
}
#endif
// vim: syntax=cpp.doxygen
......@@ -24,6 +24,16 @@ void gemv_like(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
void gemv_like_mk4(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
#if __ARM_FEATURE_DOTPROD
void gemv_like_mk4_dot(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
#endif
} // namespace arm_common
} // namespace megdnn
......
......@@ -28,14 +28,24 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoF16Gemv f16gemv;
#endif
AlgoInt8x8x32Gemv int8x8x32_gemv;
AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4;
#if __ARM_FEATURE_DOTPROD
AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot;
#endif
AlgoGevm gevm;
AlgoF32GemvMK4 f32_gemv_mk4;
public:
AlgoPack() {
all_algos.emplace_back(&int8x8x16);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos.emplace_back(&f16gemv);
#endif
#if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_gemv_mk4_dot);
#endif
all_algos.emplace_back(&int8x8x32_gemv);
all_algos.emplace_back(&int8x8x32_gemv_mk4);
all_algos.emplace_back(&f32_gemv_mk4);
all_algos.emplace_back(&gevm);
}
SmallVector<AlgoBase*> all_algos;
......
......@@ -25,11 +25,16 @@ public:
protected:
static void* const sm_arm_common_algo_type;
class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv
class AlgoF32Gemv; // Arm_common F32 Gemv
class AlgoGevm; // Arm_common Gemv(support int8 and fp32)
class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44
class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv
class AlgoInt8x8x32GemvMK4; // Arm_common Int8x8x32 Gemv NCHW44
class AlgoGevm; // Arm_common Gevm(support int8 and fp32)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoF16Gemv;
#endif
#if __ARM_FEATURE_DOTPROD
class AlgoInt8x8x32GemvMK4Dot;// Arm_common Int8x8x32 Gemv NCHW44_DOT
#endif
class AlgoInt8x8x16; // Arm_common Int 8x8x16
class AlgoPack;
......
......@@ -407,6 +407,16 @@ __ai int32_t vaddv_s32(int32x2_t a) {
return vget_lane_s32(a, 0) + vget_lane_s32(a, 1);
}
//! horizontal add of all four int32 lanes (armv7 fallback for the
//! aarch64-only vaddvq_s32 intrinsic)
__ai int32_t vaddvq_s32(int32x4_t a) {
    int32_t sum = vgetq_lane_s32(a, 0);
    sum += vgetq_lane_s32(a, 1);
    sum += vgetq_lane_s32(a, 2);
    sum += vgetq_lane_s32(a, 3);
    return sum;
}
//! horizontal add of all four f32 lanes (armv7 fallback for the
//! aarch64-only vaddvq_f32 intrinsic); the left-to-right lane order of
//! the original is preserved so rounding is bit-identical
__ai float32_t vaddvq_f32(float32x4_t a) {
    float32_t sum = vgetq_lane_f32(a, 0);
    sum += vgetq_lane_f32(a, 1);
    sum += vgetq_lane_f32(a, 2);
    sum += vgetq_lane_f32(a, 3);
    return sum;
}
#endif // MEGDNN_ARMV7
//! pack vmovl_low_xx() on armv7 and armv8
......
......@@ -42,14 +42,27 @@ using namespace conv1x1;
namespace {
#if MEGDNN_X86
template <typename stype, typename btype, param::ConvBias::Format F>
struct GemvLike {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
megdnn_throw("x86 conv1x1 gemv only supports format : NCHW");
MEGDNN_MARK_USED_VAR(A);
MEGDNN_MARK_USED_VAR(B);
MEGDNN_MARK_USED_VAR(C);
MEGDNN_MARK_USED_VAR(M);
MEGDNN_MARK_USED_VAR(N);
MEGDNN_MARK_USED_VAR(K);
MEGDNN_MARK_USED_VAR(LDA);
MEGDNN_MARK_USED_VAR(LDB);
MEGDNN_MARK_USED_VAR(LDC);
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn_assert(false,
"unspported conv1x1 gemv : \nsrc_type : "
"%s\nfilter_type : %s\n",
src.name(), filter.name());
}
};
......@@ -66,39 +79,29 @@ struct GemvLike<stype, btype, param::ConvBias::Format::NCHW> {
}
};
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
template <typename stype, typename btype, param::ConvBias::Format F>
struct GemvLike {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
megdnn_throw("arm conv1x1 gemv only supports format : NCHW");
}
};
template <typename stype, typename btype>
struct GemvLike<stype, btype, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
template <>
struct GemvLike<dt_uint8, dt_int32, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_uint8* A, const dt_uint8* B,
dt_int32* C, size_t M, size_t N, size_t K,
size_t LDA, size_t LDB, size_t LDC, DType src,
DType filter) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC);
uint8_t zp0 = src.param<dtype::Quantized8Asymm>().zero_point;
uint8_t zp1 = filter.param<dtype::Quantized8Asymm>().zero_point;
megdnn::fallback::gemv_like<dt_uint8, dt_int32>(A, B, C, M, N, K, LDA,
LDB, LDC, zp0, zp1);
}
};
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
template <>
struct GemvLike<dt_int8, dt_int16, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_int8* A, const dt_int8* B, dt_int16* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
struct GemvLike<dt_float32, dt_float32, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_float32* A, const dt_float32* B,
dt_float32* C, size_t M, size_t N, size_t K,
size_t LDA, size_t LDB, size_t LDC, DType src,
DType filter) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::fallback::gemv_like<dt_int8, dt_int16>(A, B, C, M, N, K, LDA,
LDB, LDC);
megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC);
}
};
......@@ -118,21 +121,47 @@ struct GemvLike<dt_float16, dt_float16, param::ConvBias::Format::NCHW> {
}
};
#endif
#endif
template <>
struct GemvLike<dt_uint8, dt_int32, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_uint8* A, const dt_uint8* B,
dt_int32* C, size_t M, size_t N, size_t K,
size_t LDA, size_t LDB, size_t LDC, DType src,
struct GemvLike<dt_int8, dt_int32, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_int8* A, const dt_int8* B, dt_int32* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
uint8_t zp0 = src.param<dtype::Quantized8Asymm>().zero_point;
uint8_t zp1 = filter.param<dtype::Quantized8Asymm>().zero_point;
megdnn::fallback::gemv_like<dt_uint8, dt_int32>(A, B, C, M, N, K, LDA,
LDB, LDC, zp0, zp1);
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC);
}
};
//! NCHW44 conv1x1 gemv: forwarded to the arm_common MK4 gemv kernel
template <typename stype, typename btype>
struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44> {
    inline static void do_gemv(const stype* A, const stype* B, btype* C,
                               size_t M, size_t N, size_t K, size_t LDA,
                               size_t LDB, size_t LDC, DType src,
                               DType filter) {
        //! dtype params are only needed by the quantized-asymm fallback
        MEGDNN_MARK_USED_VAR(src);
        MEGDNN_MARK_USED_VAR(filter);
        megdnn::arm_common::gemv_like_mk4(A, B, C, M, N, K, LDA, LDB, LDC);
    }
};
#if __ARM_FEATURE_DOTPROD
//! NCHW44_DOT conv1x1 gemv: forwarded to the arm_common MK4_DOT kernel
//! (only compiled when the dot-product ISA extension is available)
template <typename stype, typename btype>
struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44_DOT> {
    inline static void do_gemv(const stype* A, const stype* B, btype* C,
                               size_t M, size_t N, size_t K, size_t LDA,
                               size_t LDB, size_t LDC, DType src,
                               DType filter) {
        //! dtype params are only needed by the quantized-asymm fallback
        MEGDNN_MARK_USED_VAR(src);
        MEGDNN_MARK_USED_VAR(filter);
        megdnn::arm_common::gemv_like_mk4_dot(A, B, C, M, N, K, LDA, LDB, LDC);
    }
};
#endif
#endif
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode,
......@@ -185,19 +214,18 @@ struct Conv1x1GemvWorker {
is_dst_8bit ? matmul_temp_dst
: reinterpret_cast<bias_ctype*>(conv_bias_dst);
size_t pack_size = megdnn::fallback::pack_size(format);
GemvLike<src_ctype, bias_ctype, format>::do_gemv(
Aptr, Bptr, gemv_dst, oc_end - oc_start, 1, IC, IC, 1, 1,
ncb_param.filter_type, ncb_param.src_type);
Aptr, Bptr, gemv_dst, oc_end - oc_start, 1, IC, IC * pack_size,
pack_size, pack_size, ncb_param.filter_type,
ncb_param.src_type);
//! do postprocess
void* bias_ptr = nullptr;
if (param.bias_mode == megdnn::BiasMode::BIAS) {
if (param.bias_mode != megdnn::BiasMode::NO_BIAS) {
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>(
ncb_param.bias<bias_ctype>(batch_id, group_id) +
numbers_of_ncb_dst_offset));
} else {
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>(
ncb_param.bias<bias_ctype>(batch_id, group_id) + oc_start));
}
PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
......@@ -211,9 +239,13 @@ struct Conv1x1GemvWorker {
//! heuristic OC tile: split the output channels evenly across the worker
//! threads, rounded up to a multiple of 16
size_t ConvBiasImpl::AlgoConv1x1Gemv::get_oc_tile_size_heuristic(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
                 midout_iv("AlgoConv1x1Gemv::get_oc_tile"_hash)) {
        size_t OC = param.filter_meta.ocpg;
        size_t oc_block_size_one_thread = div_ceil(OC, param.nr_threads);
        return round_up<size_t>(oc_block_size_one_thread, 16);
    }
    MIDOUT_END();
}
size_t ConvBiasImpl::AlgoConv1x1Gemv::get_workspace(
......@@ -286,6 +318,11 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
cb1(param::ConvBias::Format::NCHW, dt_float16, __fp16,
PostprocessMode::FLOAT, "NCHW::GEMV::FLOAT16_FP16"_hash);
#else
#if !MEGDNN_DISABLE_FLOAT16
cb1(param::ConvBias::Format::NCHW, dt_float16, dt_float16,
PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash);
#endif
#endif
cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32,
dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
......@@ -311,6 +348,37 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
"NCHW::GEMV::QUINT8x8x32_QUINT8"_hash);
break;
case param::ConvBias::Format::NCHW44:
cb1(param::ConvBias::Format::NCHW44, dt_float32, dt_float32,
PostprocessMode::FLOAT, "NCHW44::GEMV::FLOAT"_hash);
cb2(param::ConvBias::Format::NCHW44, dt_int8, dt_int32, dt_int32,
dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"NCHW44::GEMV::INT8x8x32_INT32"_hash);
cb2(param::ConvBias::Format::NCHW44, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
"NCHW44::GEMV::QINT8x8x32_QINT32"_hash);
cb2(param::ConvBias::Format::NCHW44, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32,
dt_int8, PostprocessMode::QUANTIZED,
"NCHW44::GEMV::QINT8x8x32_QINT8"_hash);
break;
case param::ConvBias::Format::NCHW44_DOT:
cb2(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32,
dt_int32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
"NCHW44_DOT::GEMV::INT8x8x32_INT32"_hash);
cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
"NCHW44_DOT::GEMV::QINT8x8x32_QINT32"_hash);
cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32,
dt_int8, PostprocessMode::QUANTIZED,
"NCHW44_DOT::GEMV::QINT8x8x32_QINT8"_hash);
break;
default:
megdnn_throw("Invalid Format");
break;
......@@ -338,6 +406,16 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(ConvBiasImpl* opr,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
midout_iv("AlgoConv1x1Gemv::usable"_hash)) {
#if MEGDNN_X86
if (opr->param().format != param::ConvBias::Format::NCHW)
return false;
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
if (opr->param().format != param::ConvBias::Format::NCHW &&
opr->param().format != param::ConvBias::Format::NCHW44 &&
opr->param().format != param::ConvBias::Format::NCHW44_DOT)
return false;
#endif
//! whether 1x1
size_t FH = param.filter_meta.spatial[0],
FW = param.filter_meta.spatial[1];
......@@ -390,59 +468,43 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(ConvBiasImpl* opr,
param.src_type.enumv() != DTypeEnum::Float32) {
return false;
}
bool is_param_ok =
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT;
bool is_format_and_dtype_ok = false;
#if MEGDNN_X86
if (opr->param().format == param::ConvBias::Format::NCHW) {
//! x86 supports all dtypes in NCHW
is_format_and_dtype_ok = true;
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
if (opr->param().format == param::ConvBias::Format::NCHW44) {
if (param.src_type.enumv() != DTypeEnum::Float32 &&
param.src_type.enumv() != DTypeEnum::Int8 &&
param.src_type.enumv() != DTypeEnum::QuantizedS8) {
return false;
}
} else if (opr->param().format == param::ConvBias::Format::NCHW44_DOT) {
if (param.src_type.enumv() != DTypeEnum::Int8 &&
param.src_type.enumv() != DTypeEnum::QuantizedS8) {
return false;
}
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
//! add NCHW44 and NCHW44_DOT support in the future
if (opr->param().format == param::ConvBias::Format::NCHW) {
//! NCHW format supports all dtype
is_format_and_dtype_ok = true;
}
#endif
return is_param_ok && is_format_and_dtype_ok;
return (param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT;
}
MIDOUT_END();
return false;
}
//! NOTE(review): this span appears to be unified-diff residue — it
//! interleaves the *removed* overload (signature + short OC<=2 heuristic,
//! ending at the first `return true;`) with the *added* replacement
//! (second signature line, MIDOUT scope, per-arch dtype checks).  It is
//! kept byte-identical here; reconcile against upstream before relying
//! on it — as written it is not valid C++.
bool ConvBiasImpl::AlgoConv1x1Gemv::is_preferred(
//! (old signature's second line — unnamed opr parameter)
ConvBiasImpl*, const NCBKernSizeParam& param) const {
//! old heuristic: presumably preferred gemv only for tiny output-channel
//! counts with non-fp32 input — TODO confirm against history
size_t OC = param.filter_meta.ocpg;
if (OC <= 2 && param.src_type.enumv() != DTypeEnum::Float32)
return true;
//! (new signature's second line — named opr parameter, used below)
ConvBiasImpl* opr, const NCBKernSizeParam& param) const {
//! midout: compile-time instrumentation scope for dead-code trimming
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
midout_iv("AlgoConv1x1Gemv::is_preferred"_hash)) {
#if (MEGDNN_ARMV7 || MEGDNN_AARCH64)
//! maybe add support for QuantizedAsym in the future
//! on ARM: prefer gemv for the explicitly whitelisted dtype triples
//! (int8->int32, qs8->qs8, qs8->qs32, fp16->fp16, fp32->fp32)
return (param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int32) ||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS32) ||
#if !MEGDNN_DISABLE_FLOAT16
(param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16) ||
#endif
(param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32);
#else
//! non-ARM: reject only NCHW + Quantized8Asymm, prefer gemv otherwise
if (opr->param().format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
return false;
}
#endif
return true;
}
MIDOUT_END();
//! unreachable fallback when the MIDOUT scope is compiled out
return false;
}
// vim: syntax=cpp.doxygen
\ No newline at end of file
......@@ -2036,7 +2036,6 @@ void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle,
RUNS;
auto matmul_used = benchmark_matmul.exec({A, B, {}}) / RUNS;
printf("\n%s: ", matmul_algo_name);
printf("%s %s:\n matmul: %f ms %f Gflops\nconv1x1: %f ms %f GFlops "
"speedup: "
"%f\n",
......@@ -2120,6 +2119,82 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_INT8x8x16) {
#endif
}
//! Benchmark fp32 conv1x1 (CONV1x1_GEMV algo) against the plain
//! ARM_COMMON_F32_GEMV matmul it lowers to: a 1x1 conv on a 1x1 spatial
//! input is exactly a gemv with M = OC, K = IC, N = OH * OW = 1.
TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_GEMV_FP32) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args;
    param::ConvBias conv_param;
    conv_param.stride_h = 1;
    conv_param.stride_w = 1;
    conv_param.pad_h = 0;
    conv_param.pad_w = 0;
    conv_param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY;
    //! src {1, K, 1, 1} x filter {M, K, 1, 1} -> gemv of shape (M, K) * (K, 1)
    auto run = [&](size_t M, size_t K) {
        args.emplace_back(conv_param, TensorShape{1, K, 1, 1},
                          TensorShape{M, K, 1, 1}, TensorShape{});
    };
    for (size_t M : {4, 64, 1024, 4096})
        for (size_t K : {128, 256, 1024, 4096})
            run(M, K);

    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;
    Benchmarker<MatrixMul> benchmark_matmul(handle());
    benchmark_matmul.set_before_exec_callback(
            AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV"));
    benchmark_matmul.set_times(RUNS)
            .set_dtype(0, dtype::Float32{})
            .set_dtype(1, dtype::Float32{})
            .set_dtype(2, dtype::Float32{})
            .set_param(param)
            .set_display(false);

    Benchmarker<ConvBias> benchmark_conv1x1(handle());
    benchmark_conv1x1.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>("CONV1x1_GEMV"));
    benchmark_conv1x1.set_times(RUNS)
            .set_dtype(0, dtype::Float32{})
            .set_dtype(1, dtype::Float32{})
            .set_dtype(2, dtype::Float32{})
            .set_dtype(4, dtype::Float32{})
            .set_display(false);

    printf("warm up:\n");
    for (int i = 0; i < 50; i++) {
        benchmark_matmul.exec({{1, 1024}, {1024, 512}, {}});
    }
    //! restore display once, after warm-up; previously this was toggled
    //! inside the loop, so the remaining 49 warm-up execs printed too
    benchmark_matmul.set_display(true);

    for (auto&& arg : args) {
        size_t IC = arg.src[1];
        size_t OH = arg.src[2];
        size_t OW = arg.src[3];
        size_t OC = arg.filter[0];
        size_t M = OC;
        size_t K = IC;
        size_t N = OH * OW;
        //! 2*M*N*K flops per run, reported in Gflop (time is in ms)
        float computations = M * N * K * 2.f / (1024 * 1024 * 1024) * 1e3;
        TensorShape A, B;
        A = TensorShape{M, K};
        B = TensorShape{K, N};
        auto conv1x1_used = benchmark_conv1x1.set_param(arg.param).exec(
                                    {arg.src, arg.filter, arg.bias, {}, {}}) /
                            RUNS;
        auto matmul_used = benchmark_matmul.exec({A, B, {}}) / RUNS;
        printf("%s %s:\n gemv: %f ms %f Gflops\nconv1x1: %f ms %f GFlops "
               "speedup: "
               "%f\n",
               arg.src.to_string().c_str(), arg.filter.to_string().c_str(),
               matmul_used, computations / matmul_used, conv1x1_used,
               computations / conv1x1_used, matmul_used / conv1x1_used);
    }
}
#ifndef __ARM_FEATURE_DOTPROD
TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) {
std::vector<TestArg> conv_bias_1x1_args_nchw44 =
......
......@@ -180,12 +180,15 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
for (size_t kernel : kernel_vec)
for (size_t oc : {4, 12})
for (size_t ic : {1, 3, 4, 12})
for (size_t h : {3, 5, 12})
for (size_t w : {7, 16, 23}) {
for (size_t h : {1, 3, 12})
for (size_t w : {1, 16, 23}) {
for (size_t group = 1;
group <=
std::min(std::min(oc, ic), 4_z);
++group) {
if (kernel != 1 && (h == 1 || w == 1)) {
continue;
}
pack(n, oc, ic, h, w, kernel, stride,
group, nlmode, bias);
}
......@@ -1897,6 +1900,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
#elif MEGDNN_ARMV7
check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32_MK4_PACK_4X12:24");
#endif
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) {
......@@ -1932,7 +1941,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) {
#endif
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
......@@ -2138,4 +2147,40 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
}
#endif
//! Check the CONV1x1_GEMV algo on NCHW44 quantized-int8 1x1 convolutions.
//! Only the cases whose spatial size is 1x1 are kept — those are exactly
//! the shapes the gemv specialization of conv1x1 handles.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> all_args =
            get_nchw44_conv_bias_args({1}, 1, true, false, false);
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;

    std::vector<conv_bias::TestArg> gemv_args;
    for (auto&& arg : all_args) {
        bool spatial_is_one = arg.src.shape[2] == 1 && arg.src.shape[3] == 1;
        if (spatial_is_one) {
            gemv_args.emplace_back(arg);
        }
    }

    checker_conv_bias(gemv_args, handle(), &rng, epsilon,
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
                      dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f),
                      "CONV1x1_GEMV");
}
#ifdef __ARM_FEATURE_DOTPROD
//! Check the CONV1x1_GEMV algo on NCHW44_DOT quantized-int8 1x1
//! convolutions (dot-product variant).  Only 1x1-spatial cases are kept,
//! since those are the shapes the gemv path of conv1x1 covers.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44_DOT) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> all_args =
            get_nchw44_conv_bias_args({1}, 1, true, false, false, false, true);
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;

    std::vector<conv_bias::TestArg> gemv_args;
    for (auto&& arg : all_args) {
        bool spatial_is_one = arg.src.shape[2] == 1 && arg.src.shape[3] == 1;
        if (spatial_is_one) {
            gemv_args.emplace_back(arg);
        }
    }

    checker_conv_bias(gemv_args, handle(), &rng, epsilon,
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
                      dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f),
                      "CONV1x1_GEMV");
}
#endif
// vim: syntax=cpp.doxygen
......@@ -164,6 +164,70 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV) {
run(M, K, N);
}
//! Correctness check for the ARM_COMMON_INT8X8X32_GEMV_MK4 algo:
//! quantized-s8 inputs, s32 accumulation, MK4 layout, N restricted to 1.
TEST_F(ARM_COMMON, QINT8x8x32_GEMV_MK4) {
    Checker<MatrixMul> checker(handle());
    using Param = MatrixMul::Param;

    checker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEMV_MK4"));

    std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127);
    checker.set_rng(0, rng.get()).set_rng(1, rng.get());

    auto run = [&](size_t M, size_t K, size_t N) {
        Param param;
        param.format = param::MatrixMul::Format::MK4;
        param.transposeA = false;
        param.transposeB = false;
        //! MK4 layout: A is (M/4, K/4, 4, 4), B is (K/4, N, 4).
        //! fix: B previously hard-coded 1 and ignored the N parameter;
        //! all current callers pass N == 1, so behavior is unchanged.
        TensorShape A, B;
        A = TensorShape{M / 4, K / 4, 4, 4};
        B = TensorShape{K / 4, N, 4};
        checker.set_param(param)
                .set_dtype(0, dtype::QuantizedS8(2.5f))
                .set_dtype(1, dtype::QuantizedS8(2.5f))
                .set_dtype(2, dtype::QuantizedS32(6.25f))
                .execs({A, B, {}});
    };

    // N = 1 (the only N this gemv algo accepts)
    for (size_t M : {4, 16, 128, 1024})
        for (size_t K : {4, 8, 12, 16, 20, 24, 256, 1024})
            run(M, K, 1);
}
#if __ARM_FEATURE_DOTPROD
//! Correctness check for the ARM_COMMON_INT8X8X32_GEMV_MK4_DOT algo
//! (dot-product variant): quantized-s8 inputs, s32 accumulation,
//! MK4_DOT layout, N restricted to 1.
TEST_F(ARM_COMMON, QINT8x8x32_GEMV_MK4_DOT) {
    Checker<MatrixMul> checker(handle());
    using Param = MatrixMul::Param;

    checker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"));

    std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127);
    checker.set_rng(0, rng.get()).set_rng(1, rng.get());

    auto run = [&](size_t M, size_t K, size_t N) {
        Param param;
        param.format = param::MatrixMul::Format::MK4_DOT;
        param.transposeA = false;
        param.transposeB = false;
        //! MK4_DOT layout: A is (M/4, K/4, 4, 4), B is (K/4, N, 4).
        //! fix: B previously hard-coded 1 and ignored the N parameter;
        //! all current callers pass N == 1, so behavior is unchanged.
        TensorShape A, B;
        A = TensorShape{M / 4, K / 4, 4, 4};
        B = TensorShape{K / 4, N, 4};
        checker.set_param(param)
                .set_dtype(0, dtype::QuantizedS8(2.5f))
                .set_dtype(1, dtype::QuantizedS8(2.5f))
                .set_dtype(2, dtype::QuantizedS32(6.25f))
                .execs({A, B, {}});
    };

    // N = 1 (the only N this gemv algo accepts)
    for (size_t M : {4, 16, 128, 1024})
        for (size_t K : {4, 8, 12, 16, 20, 24, 256, 1024})
            run(M, K, 1);
}
#endif
TEST_F(ARM_COMMON, QINT8x8x32_GEVM) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;
......@@ -220,6 +284,31 @@ TEST_F(ARM_COMMON, FP32_GEVM) {
run(M, K, N);
}
//! Correctness check for the ARM_COMMON_F32_GEMV_MK4 algo: fp32 gemv
//! in MK4 layout, exercised only at N == 1.
TEST_F(ARM_COMMON, FP32_GEMV_MK4) {
    Checker<MatrixMul> checker(handle());
    using Param = MatrixMul::Param;

    checker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV_MK4"));
    checker.set_epsilon(1e-2);

    auto execute = [&](size_t rows, size_t depth) {
        Param param;
        param.format = param::MatrixMul::Format::MK4;
        param.transposeA = false;
        param.transposeB = false;
        //! MK4 layout: A packed as (rows/4, depth/4, 4, 4), B as (depth/4, 1, 4)
        TensorShape mat_a{rows / 4, depth / 4, 4, 4};
        TensorShape mat_b{depth / 4, 1, 4};
        checker.set_param(param).execs({mat_a, mat_b, {}});
    };

    // N = 1
    for (size_t rows : {4, 16, 128, 1024})
        for (size_t depth : {4, 8, 12, 128, 256, 4096})
            execute(rows, depth);
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(ARM_COMMON, BENCHMARK_SGEMV) {
......@@ -228,18 +317,16 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV) {
benchmarker.set_times(exec_times);
auto run = [&](size_t M, size_t K, size_t N) {
std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")"
<< std::endl;
printf("SGEMV: (%zu, %zu, %zu)\n", M, K, N);
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32());
auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times;
auto computations = 2.f * M * K * N * 1e-6;
auto perf = computations / time;
std::cout << "gemv fp32, Performance is " << perf << " Gflops"
<< std::endl;
printf("gemv fp32, Performance is %f Gflops\n", perf);
};
std::cout << "warm up:\n";
printf("warm up:\n");
for (int i = 0; i < 50; i++) {
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
......@@ -253,6 +340,10 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV) {
for (size_t K : {1024, 1536, 2048})
for (size_t N : {512, 1024})
run(M, K, N);
for (size_t M : {4, 64, 1024, 4096})
for (size_t K : {128, 256, 1024, 4096})
run(M, K, 1);
}
TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) {
......@@ -263,23 +354,20 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) {
AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV"));
auto run = [&](size_t M, size_t K, size_t N) {
std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")"
<< std::endl;
printf("SGEMV: (%zu, %zu, %zu)\n", M, K, N);
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32());
auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times;
auto computations = 2 * M * K * N * 1e-6;
auto perf = computations / time;
std::cout << "gemv fp32, Performance is " << perf << " Gflops"
<< std::endl;
printf("gemv fp32, Performance is %f Gflops\n", perf);
};
std::cout << "warm up:\n";
printf("warm up:\n");
for (int i = 0; i < 50; i++) {
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_display(false)
.exec({{2, 1024}, {1024, 512}, {}});
benchmarker.set_display(true);
......@@ -298,6 +386,45 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) {
run(1024, 256, 1);
}
//! Benchmark the fp32 MK4-format gemv (N == 1) for a sweep of M and K.
TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) {
    int exec_times = 10;
    using Param = MatrixMul::Param;
    Param param;
    param.format = param::MatrixMul::Format::MK4;
    param.transposeA = false;
    param.transposeB = false;
    Benchmarker<MatrixMul> benchmarker(handle());
    benchmarker.set_times(exec_times);
    benchmarker.set_dtype(0, dtype::Float32())
            .set_dtype(1, dtype::Float32())
            .set_param(param);

    auto run = [&](size_t M, size_t K) {
        //! fix: previous printf referenced an undeclared `N` (compile
        //! error); gemv always has N == 1, so print it literally.
        printf("SGEMV_MK4: (%zu, %zu, 1)\n", M, K);
        //! MK4 layout: A is (M/4, K/4, 4, 4), B is (K/4, 1, 4)
        TensorShape A, B;
        A = TensorShape{M / 4, K / 4, 4, 4};
        B = TensorShape{K / 4, 1, 4};
        auto time = benchmarker.exec({A, B, {}}) / exec_times;
        //! 2*M*K flops per run (N == 1), time in ms -> Gflops
        auto computations = 2.f * M * K * 1e-6;
        auto perf = computations / time;
        printf("gemv mk4 fp32, Performance is %f Gflops\n", perf);
    };

    printf("warm up:\n");
    for (int i = 0; i < 50; i++) {
        benchmarker.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_display(false)
                .exec({{4, 256, 4, 4}, {256, 1, 4}, {}});
    }

    // run gemv mk4
    for (size_t M : {4, 64, 1024, 4096})
        for (size_t K : {128, 1024, 4096})
            run(M, K);
}
TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) {
int exec_times = 50;
Benchmarker<MatrixMul> benchmarker(handle());
......@@ -306,19 +433,17 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) {
AlgoChecker<MatrixMul>("ARM_COMMON_F16_GEMV"));
auto run = [&](size_t M, size_t K, size_t N) {
std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")"
<< std::endl;
printf("SGEMV_FP16: (%zu, %zu, %zu)\n", M, K, N);
benchmarker.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_dtype(2, dtype::Float16());
auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times;
auto computations = 2 * M * K * N * 1e-6;
auto perf = computations / time;
std::cout << "gemv fp16, Performance is " << perf << " Gflops"
<< std::endl;
printf("gemv fp16, Performance is %f Gflops\n", perf);
};
std::cout << "warm up:\n";
printf("warm up:\n");
for (int i = 0; i < 50; i++) {
benchmarker.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
......@@ -343,17 +468,15 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMM) {
float mod = 1000 * exec_times / 1e9;
auto run = [&](size_t M, size_t K, size_t N) {
float time = 1.f, perf = 1.f;
std::cout << "SGEMM: (" << M << ", " << K << ", " << N << ")"
<< std::endl;
printf("SGEMM: (%zu, %zu, %zu)\n", M, K, N);
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32());
time = benchmarker.exec({{M, K}, {K, N}, {}});
perf = 2.f * M * K * N / time * mod;
std::cout << "gemm fp32, Performance is " << perf << " Gflops"
<< std::endl;
printf("gemm, Performance is %f Gflops\n", perf);
};
std::cout << "warm up:\n";
printf("warm up:\n");
for (int i = 0; i < 50; i++) {
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册