提交 a1677d7a 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(dnn/arm_common): add fp32 gevm

GitOrigin-RevId: 4d348bbb345f4537b011c1c23f6d1b2ccee5739f
上级 5d950063
...@@ -138,6 +138,63 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern( ...@@ -138,6 +138,63 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(
return f32_gemv_kern; return f32_gemv_kern;
} }
/* ===================== F32 Gevm algo ===================== */
namespace {
void gevm_fp32_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDB = kern_param.LDB;
const auto Aptr = kern_param.A<dt_float32>(),
Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
arm_common::sgemm_sgemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1);
}
void gevm_int8_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDB = kern_param.LDB;
const auto Aptr = kern_param.A<dt_int8>(),
Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
arm_common::matmul::gemv_like_int8(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1);
}
} // anonymous namespace
bool MatrixMulImpl::AlgoGevm::usable(
const KernSizeParam& kern_size_param) const {
// enumerate the M, N, K, only usable when preferred
bool fp32_ok =
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32();
return (fp32_ok || can_be_treated_as_int8x8x32(kern_size_param)) &&
preferred(kern_size_param);
}
bool MatrixMulImpl::AlgoGevm::preferred(
const KernSizeParam& kern_size_param) const {
auto M = kern_size_param.M;
return kern_size_param.trB && M == 1;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoGevm::get_kern(
const KernSizeParam& kern_size_param) const {
if (kern_size_param.A_type == dtype::Float32()) {
return gevm_fp32_kern;
} else if (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) {
return gevm_int8_kern;
} else {
megdnn_assert(
false, "no avaliable kern got A_type: %s B_type: %s C_type: %s",
kern_size_param.A_type.name(), kern_size_param.B_type.name(),
kern_size_param.C_type.name());
}
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/* ===================== F16 Gemv algo ===================== */ /* ===================== F16 Gemv algo ===================== */
namespace { namespace {
......
...@@ -70,6 +70,21 @@ public: ...@@ -70,6 +70,21 @@ public:
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
}; };
#endif #endif
class MatrixMulImpl::AlgoGevm : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "ARM_COMMON_GEVM"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
};
} // namespace arm_common } // namespace arm_common
} // namespace megdnn } // namespace megdnn
......
...@@ -28,6 +28,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { ...@@ -28,6 +28,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoF16Gemv f16gemv; AlgoF16Gemv f16gemv;
#endif #endif
AlgoInt8x8x32Gemv int8x8x32_gemv; AlgoInt8x8x32Gemv int8x8x32_gemv;
AlgoGevm gevm;
public: public:
AlgoPack() { AlgoPack() {
all_algos.emplace_back(&int8x8x16); all_algos.emplace_back(&int8x8x16);
...@@ -35,6 +36,7 @@ public: ...@@ -35,6 +36,7 @@ public:
all_algos.emplace_back(&f16gemv); all_algos.emplace_back(&f16gemv);
#endif #endif
all_algos.emplace_back(&int8x8x32_gemv); all_algos.emplace_back(&int8x8x32_gemv);
all_algos.emplace_back(&gevm);
} }
SmallVector<AlgoBase*> all_algos; SmallVector<AlgoBase*> all_algos;
}; };
......
...@@ -27,6 +27,7 @@ protected: ...@@ -27,6 +27,7 @@ protected:
static void* const sm_arm_common_algo_type; static void* const sm_arm_common_algo_type;
class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv
class AlgoF32Gemv; // Arm_common F32 Gemv class AlgoF32Gemv; // Arm_common F32 Gemv
class AlgoGevm; // Arm_common Gemv(support int8 and fp32)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoF16Gemv; class AlgoF16Gemv;
#endif #endif
......
...@@ -164,6 +164,62 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV) { ...@@ -164,6 +164,62 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV) {
run(M, K, N); run(M, K, N);
} }
TEST_F(ARM_COMMON, QINT8x8x32_GEVM) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;
checker.set_before_exec_callback(
AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));
std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127);
checker.set_rng(0, rng.get()).set_rng(1, rng.get());
auto run = [&](size_t M, size_t K, size_t N) {
Param param;
param.transposeA = false;
param.transposeB = true;
TensorShape A, B;
A = TensorShape{M, K};
B = TensorShape{N, K};
checker.set_param(param)
.set_dtype(0, dtype::QuantizedS8(2.5f))
.set_dtype(1, dtype::QuantizedS8(2.5f))
.set_dtype(2, dtype::QuantizedS32(6.25f))
.execs({A, B, {}});
};
// M = 1
for (size_t N : {1, 10, 16, 33, 64})
for (size_t K : {7, 512, 1024})
for (size_t M : {1})
run(M, K, N);
}
TEST_F(ARM_COMMON, FP32_GEVM) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;
checker.set_before_exec_callback(
AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));
checker.set_epsilon(1e-2);
auto run = [&](size_t M, size_t K, size_t N) {
Param param;
param.transposeA = false;
param.transposeB = true;
TensorShape A, B;
A = TensorShape{M, K};
B = TensorShape{N, K};
checker.set_param(param).execs({A, B, {}});
};
// M = 1
for (size_t M : {1})
for (size_t K : {1000, 4096, 25088})
for (size_t N : {1000, 4096})
run(M, K, N);
}
#if MEGDNN_WITH_BENCHMARK #if MEGDNN_WITH_BENCHMARK
TEST_F(ARM_COMMON, BENCHMARK_SGEMV) { TEST_F(ARM_COMMON, BENCHMARK_SGEMV) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册