From 409a8772676feb891dc2f0db5a690a1eb8e238af Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 6 Jan 2021 23:47:23 +0800
Subject: [PATCH] feat(dnn): add algo interface for rocm&fallback matmul and batched matrix mul

GitOrigin-RevId: dea03a0f7a3ec436719d9c69570cf827bbb444b5
---
 dnn/src/fallback/batched_matrix_mul/algos.cpp | 107 +++++++++++
 dnn/src/fallback/batched_matrix_mul/algos.h | 109 +++++++++++
 .../fallback/batched_matrix_mul/opr_impl.cpp | 94 +++++----
 .../fallback/batched_matrix_mul/opr_impl.h | 54 ++++--
 dnn/src/fallback/conv_bias/im2col/factory.h | 7 +
 .../conv_bias/im2col/strategy_nopack.cpp | 4 +
 dnn/src/fallback/handle.cpp | 2 +-
 dnn/src/fallback/matrix_mul/algos.cpp | 84 ++++++++-
 dnn/src/fallback/matrix_mul/algos.h | 22 +++
 dnn/src/fallback/matrix_mul/opr_impl.cpp | 21 ++-
 dnn/src/fallback/matrix_mul/opr_impl.h | 2 +
 dnn/src/naive/matrix_mul/matrix_mul_helper.h | 178 +++++++++++++++---
 dnn/src/naive/matrix_mul/opr_impl.cpp | 75 +-------
 dnn/src/rocm/batched_matrix_mul/algos.cpp | 59 ++++++
 dnn/src/rocm/batched_matrix_mul/algos.h | 118 ++++++++++++
 dnn/src/rocm/batched_matrix_mul/blas.cpp | 140 ++++++++++++++
 dnn/src/rocm/batched_matrix_mul/opr_impl.cpp | 135 ++++---------
 dnn/src/rocm/batched_matrix_mul/opr_impl.h | 36 ++--
 dnn/src/rocm/matrix_mul/algos.cpp | 62 ++++++
 dnn/src/rocm/matrix_mul/algos.h | 118 ++++++++++++
 dnn/src/rocm/matrix_mul/blas.cpp | 162 ++++++++++++++++
 dnn/src/rocm/matrix_mul/opr_impl.cpp | 174 ++++------------
 dnn/src/rocm/matrix_mul/opr_impl.h | 23 ++-
 dnn/test/fallback/matrix_mul.cpp | 31 +++
 dnn/test/naive/matrix_mul.cpp | 4 +-
 25 files changed, 1380 insertions(+), 441 deletions(-)
 create mode 100644 dnn/src/fallback/batched_matrix_mul/algos.cpp
 create mode 100644 dnn/src/fallback/batched_matrix_mul/algos.h
 create mode 100644 dnn/src/rocm/batched_matrix_mul/algos.cpp
 create mode 100644 dnn/src/rocm/batched_matrix_mul/algos.h
 create mode 100644 dnn/src/rocm/batched_matrix_mul/blas.cpp
 create mode 100644 dnn/src/rocm/matrix_mul/algos.cpp
 create mode 100644 dnn/src/rocm/matrix_mul/algos.h
 create mode 100644 dnn/src/rocm/matrix_mul/blas.cpp

diff --git a/dnn/src/fallback/batched_matrix_mul/algos.cpp b/dnn/src/fallback/batched_matrix_mul/algos.cpp
new file mode 100644
index 000000000..9e809bd7f
--- /dev/null
+++ b/dnn/src/fallback/batched_matrix_mul/algos.cpp
@@ -0,0 +1,107 @@
+/**
+ * \file dnn/src/fallback/batched_matrix_mul/algos.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
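
The eight newly created files in the diffstat all follow one registration pattern: each operator implementation owns a static AlgoPack holding one instance of every algorithm, an ordered list for enumeration, and a descriptor-to-algorithm map for lookup. A minimal sketch of that pattern, assuming simplified stand-in types (AlgoBase, AlgoDefault and the integer descriptor here are illustrative, not the real megdnn declarations):

#include <cstdint>
#include <unordered_map>
#include <vector>

struct AlgoBase {
    virtual ~AlgoBase() = default;
    virtual const char* name() const = 0;
    virtual std::uint32_t desc() const = 0;  // stands in for info().desc
};

struct AlgoDefault final : AlgoBase {
    const char* name() const override { return "DEFAULT"; }
    std::uint32_t desc() const override { return 1; }
};

struct AlgoPack {
    AlgoDefault algo_default;
    std::vector<AlgoBase*> all_algos;                            // enumeration order
    std::unordered_map<std::uint32_t, AlgoBase*> all_algos_map;  // desc -> algo

    AlgoPack() {
        all_algos.push_back(&algo_default);
        for (auto&& algo : all_algos)
            all_algos_map.emplace(algo->desc(), algo);
    }
};

static AlgoPack sm_algo_pack;  // one static pack per operator impl, as below

The map is presumably what MEGDNN_DEF_GET_ALGO_FROM_DESC builds on: a stored descriptor can be turned back into the algorithm object without re-running any heuristic.
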
+ */ + +#include "src/fallback/batched_matrix_mul/algos.h" +#include "src/common/algo_base.h" +#include "src/naive/handle.h" + +using namespace megdnn; +using namespace fallback; + +BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() { + all_algos.push_back(&algo_default); + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } +} + +BatchedMatrixMulForwardImpl::AlgoPack BatchedMatrixMulForwardImpl::sm_algo_pack; + +MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl) + +BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs( + BatchedMatrixMulForwardImpl* o, const TensorLayout& A, + const TensorLayout& B, const TensorLayout& C) + : opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {} + +BatchedMatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs( + BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A, + _megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace) + : SizeArgs(opr, A.layout, B.layout, C.layout), + tensor_a{A}, + tensor_b{B}, + tensor_c{C}, + workspace{workspace} {} + +std::string BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { + auto&& param = opr->param(); + size_t m = layout_a.shape[0], n = layout_b.shape[1], + k = layout_a.shape[param.transposeA ? 0 : 1]; + MEGDNN_MARK_USED_VAR(m); + MEGDNN_MARK_USED_VAR(n); + MEGDNN_MARK_USED_VAR(k); + return megdnn_mangle(ssprintf( + "A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose " + "B=%d,ldA=%zu,ldB=%zu,ldC=%zu", + m, k, k, n, m, n, param.transposeA, param.transposeB, + layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); +} + +/* ===================== default algo ===================== */ +size_t BatchedMatrixMulForwardImpl::AlgoDefault::get_workspace_in_bytes( + const SizeArgs& args) const { + auto opr = inplace_cpu_handle()->create_operator(); + auto A_ = args.layout_a.remove_axis(0), B_ = args.layout_b.remove_axis(0), + C_ = args.layout_c.remove_axis(0); + opr->param() = args.opr->param(); + return opr->get_workspace_in_bytes(A_, B_, C_); +} + +void BatchedMatrixMulForwardImpl::AlgoDefault::exec( + const ExecArgs& args) const { + //! As megbrain may modify param when checking all transpose situations, so + //! 
here we should copy the param when dispatching kern + auto param = args.opr->param(); + auto kern = [args, param]() { + auto N = args.layout_a.shape[0]; + TensorND A_, B_, C_; + A_.raw_ptr = args.tensor_a.raw_ptr; + A_.layout = args.layout_a.remove_axis(0); + B_.raw_ptr = args.tensor_b.raw_ptr; + B_.layout = args.layout_b.remove_axis(0); + C_.raw_ptr = args.tensor_c.raw_ptr; + C_.layout = args.layout_c.remove_axis(0); + + auto Astrd = args.layout_a.dtype.size() * args.layout_a.stride[0], + Bstrd = args.layout_b.dtype.size() * args.layout_b.stride[0], + Cstrd = args.layout_c.dtype.size() * args.layout_c.stride[0]; + + auto advance_ptr = [](TensorND& dest, ptrdiff_t d) { + dest.raw_ptr = + static_cast(static_cast(dest.raw_ptr) + d); + }; + + auto opr = inplace_cpu_handle()->create_operator(); + opr->param() = param; + rep(n, N) { + opr->exec(A_, B_, C_, args.workspace); + advance_ptr(A_, Astrd); + advance_ptr(B_, Bstrd); + advance_ptr(C_, Cstrd); + } + }; + + static_cast(args.opr->handle())->dispatch_kern(kern); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/batched_matrix_mul/algos.h b/dnn/src/fallback/batched_matrix_mul/algos.h new file mode 100644 index 000000000..99626ba84 --- /dev/null +++ b/dnn/src/fallback/batched_matrix_mul/algos.h @@ -0,0 +1,109 @@ +/** + * \file dnn/src/fallback/batched_matrix_mul/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "megdnn/oprs.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" +#include "src/common/utils.h" +#include "src/fallback/batched_matrix_mul/opr_impl.h" + +#include +#include + +namespace megdnn { +namespace fallback { + +/*! 
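
The exec path just shown combines two ideas worth spelling out: the kernel lambda captures a copy of the operator param (the comment above notes megbrain may mutate it before the queued kernel runs), and the batch is walked by advancing raw pointers by the layout's batch stride in bytes. A compilable sketch under those assumptions, with hypothetical stand-in names (Dispatcher and exec_default are not megdnn APIs):

#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Stand-in for the CPU handle's deferred kernel queue.
struct Dispatcher {
    std::vector<std::function<void()>> queue;
    void dispatch_kern(std::function<void()> kern) {
        queue.push_back(std::move(kern));
    }
};

void exec_default(Dispatcher& dispatcher, void* a, void* b, void* c,
                  std::size_t batch, std::ptrdiff_t a_stride_bytes,
                  std::ptrdiff_t b_stride_bytes, std::ptrdiff_t c_stride_bytes,
                  int param_copy) {
    // Everything is captured by value: the kernel may run after the caller
    // has already mutated the operator, so it must keep its own param copy.
    dispatcher.dispatch_kern([=]() mutable {
        (void)param_copy;  // a real kernel would configure the matmul with it
        auto advance = [](void*& p, std::ptrdiff_t d) {
            p = static_cast<std::uint8_t*>(p) + d;
        };
        for (std::size_t i = 0; i < batch; ++i) {
            // run one single-matrix matmul on (a, b, c) here
            advance(a, a_stride_bytes);
            advance(b, b_stride_bytes);
            advance(c, c_stride_bytes);
        }
    });
}
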
+ * \brief base class for matrix mul algos + * + */ +class BatchedMatrixMulForwardImpl::AlgoBase : public Algorithm { +protected: + ~AlgoBase() = default; + +public: + enum class AlgoType : uint32_t { + fallback_BLAS, + }; + using Mapper = std::unordered_map; + + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::FALLBACK; } + + struct SizeArgs { + BatchedMatrixMulForwardImpl* opr; + TensorLayout layout_a, layout_b, layout_c; + + std::string to_string() const; + SizeArgs(BatchedMatrixMulForwardImpl* opr, const TensorLayout& A, + const TensorLayout& B, const TensorLayout& C); + }; + struct ExecArgs : public SizeArgs { + TensorND tensor_a, tensor_b, tensor_c; + Workspace workspace; + + ExecArgs(BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A, + _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace); + }; + + virtual bool is_available(const SizeArgs& args) const = 0; + virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; + virtual void exec(const ExecArgs&) const = 0; + + bool is_available_wk(const SizeArgs& args, size_t limit) const { + return is_available(args) && get_workspace_in_bytes(args) <= limit; + } + bool is_available_reproducible( + const SizeArgs& args, bool reproducible = true, + size_t limit = std::numeric_limits::max()) const { + return (!reproducible || is_reproducible()) && + is_available_wk(args, limit); + } + AlgoBase& check_workspace(const SizeArgs& args, + const Workspace& workspace) { + auto req = get_workspace_in_bytes(args); + megdnn_assert( + req <= workspace.size, + "matrix mul fwd algo %s: required workspace %zu bytes, got %zu", + name(), req, workspace.size); + return *this; + } +}; + +class BatchedMatrixMulForwardImpl::AlgoDefault final : public AlgoBase { +public: + AlgoDefault() = default; + bool is_available(const SizeArgs&) const override { return true; } + size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override; + const char* name() const override { return "DEFAULT"; } + virtual void exec(const ExecArgs&) const override; + bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(fallback_BLAS) +}; + +class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj { +private: + AlgoBase::Mapper m_all_algos_map; + +public: + AlgoPack(); + AlgoDefault algo_default; + std::vector all_algos; + + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } +}; + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/batched_matrix_mul/opr_impl.cpp b/dnn/src/fallback/batched_matrix_mul/opr_impl.cpp index 9a681e14e..7f613509c 100644 --- a/dnn/src/fallback/batched_matrix_mul/opr_impl.cpp +++ b/dnn/src/fallback/batched_matrix_mul/opr_impl.cpp @@ -6,67 +6,61 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
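
The availability predicates declared above stack in a fixed order, and the same trio recurs in the ROCm headers later in this patch. A minimal sketch of the logic, assuming a simplified Algo type in place of the real AlgoBase/SizeArgs pair:

#include <cstddef>
#include <limits>

struct Algo {
    bool available = true;
    bool reproducible = true;
    std::size_t workspace_bytes = 0;

    bool is_available_wk(std::size_t limit) const {
        return available && workspace_bytes <= limit;
    }
    bool is_available_reproducible(
            bool need_reproducible,
            std::size_t limit = std::numeric_limits<std::size_t>::max()) const {
        // Reproducibility only constrains the choice when the caller asks for it.
        return (!need_reproducible || reproducible) && is_available_wk(limit);
    }
};
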
*/ #include "./opr_impl.h" -#include "src/naive/handle.h" +#include "./algos.h" +#include "hcc_detail/hcc_defs_prologue.h" + +#include "src/common/algo_chooser.h" +#include "src/common/utils.cuh" +#include "src/fallback/handle.h" using namespace megdnn; using namespace fallback; -BatchedMatrixMulImpl::BatchedMatrixMulImpl(Handle *handle): - BatchedMatrixMulForwardImpl(handle), - m_storage(new CpuOprDelegationStorage<>), - m_opr(m_storage->get()) -{ +std::vector +BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, + const TensorLayout& B, + const TensorLayout& C) { + AlgoBase::SizeArgs args{this, A, B, C}; + return megdnn::get_all_algorithms(args); } -size_t BatchedMatrixMulImpl::get_workspace_in_bytes( - const TensorLayout &A, const TensorLayout &B, - const TensorLayout &C) { - auto A_ = A.remove_axis(0), B_ = B.remove_axis(0), C_ = C.remove_axis(0); - m_opr->param() = param(); - return m_opr->get_workspace_in_bytes(A_, B_, C_); +BatchedMatrixMulForwardImpl::Algorithm* +BatchedMatrixMulForwardImpl::get_algorithm_heuristic( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, + size_t workspace_limit_in_bytes, bool reproducible) { + AlgoBase::SizeArgs args{this, A, B, C}; + if (sm_algo_pack.algo_default.is_available_reproducible( + args, reproducible, workspace_limit_in_bytes)) { + return &sm_algo_pack.algo_default; + } + if (reproducible) { + return megdnn::get_reproducible_algo( + sm_algo_pack.all_algos, args, workspace_limit_in_bytes, + "batched matrix mul forward"); + } else { + return megdnn::get_usable_algo( + sm_algo_pack.all_algos, args, workspace_limit_in_bytes, + "batched matrix mul forward"); + } } -void BatchedMatrixMulImpl::exec(_megdnn_tensor_in A, - _megdnn_tensor_in B, - _megdnn_tensor_out C, - _megdnn_workspace workspace) { - check_exec(A.layout, B.layout, C.layout, workspace.size); - - m_opr->param() = this->param(); - auto kern = [this, A, B, C, workspace]() { - auto N = A.layout.shape[0]; - TensorND A_, B_, C_; - A_.raw_ptr = A.raw_ptr; - A_.layout = A.layout.remove_axis(0); - B_.raw_ptr = B.raw_ptr; - B_.layout = B.layout.remove_axis(0); - C_.raw_ptr = C.raw_ptr; - C_.layout = C.layout.remove_axis(0); - - auto Astrd = A.layout.dtype.size() * A.layout.stride[0], - Bstrd = B.layout.dtype.size() * B.layout.stride[0], - Cstrd = C.layout.dtype.size() * C.layout.stride[0]; - - auto advance_ptr = [](TensorND &dest, ptrdiff_t d) { - dest.raw_ptr = static_cast( - static_cast(dest.raw_ptr) + d); - }; - - rep(n, N) { - m_opr->exec(A_, B_, C_, workspace); - advance_ptr(A_, Astrd); - advance_ptr(B_, Bstrd); - advance_ptr(C_, Cstrd); - } - }; - - static_cast(handle())->dispatch_kern(kern); +size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { + AlgoBase::SizeArgs args{this, A, B, C}; + return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args); } +void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, + _megdnn_tensor_out C, + _megdnn_workspace workspace) { + check_exec(A.layout, B.layout, C.layout, workspace.size); + AlgoBase::ExecArgs args(this, A, B, C, workspace); + auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout); + algo->check_workspace(args, workspace).exec(args); +} // vim: syntax=cpp.doxygen - - diff --git a/dnn/src/fallback/batched_matrix_mul/opr_impl.h b/dnn/src/fallback/batched_matrix_mul/opr_impl.h index b7f6e1f5f..6ddb78956 100644 --- a/dnn/src/fallback/batched_matrix_mul/opr_impl.h +++ 
b/dnn/src/fallback/batched_matrix_mul/opr_impl.h @@ -15,26 +15,42 @@ namespace megdnn { namespace fallback { -class BatchedMatrixMulImpl: public naive::BatchedMatrixMulForwardImpl { - public: - BatchedMatrixMulImpl(Handle *handle); - void exec( - _megdnn_tensor_in A, - _megdnn_tensor_in B, - _megdnn_tensor_out C, - _megdnn_workspace workspace) override; - - size_t get_workspace_in_bytes(const TensorLayout &A, - const TensorLayout &B, - const TensorLayout &C) override; - - private: - std::unique_ptr> m_storage; - MatrixMulForward* m_opr; +class BatchedMatrixMulForwardImpl: public naive::BatchedMatrixMulForwardImpl { +public: + using naive::BatchedMatrixMulForwardImpl::BatchedMatrixMulForwardImpl; + void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&) override; + + bool is_thread_safe() const override { return true; } + + class AlgoBase; + class AlgoDefault; + class AlgoPack; + static const AlgoPack& algo_pack() { return sm_algo_pack; } + + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + +private: + std::vector get_all_algorithms( + const TensorLayout& /*A*/, const TensorLayout& /*B*/, + const TensorLayout& /*C*/) override; + + Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, + const TensorLayout& /*B*/, + const TensorLayout& /*C*/, + size_t /*workspace_limit_in_bytes*/, + bool /*reproducible*/) override; + + const char* get_algorithm_set_name() const override { + return "FALLBACK BATCHED MATMUL"; + } + + static AlgoPack sm_algo_pack; }; -} // namespace fallback -} // namespace megdnn +} // namespace fallback +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/fallback/conv_bias/im2col/factory.h b/dnn/src/fallback/conv_bias/im2col/factory.h index a5024b504..4b35e8848 100644 --- a/dnn/src/fallback/conv_bias/im2col/factory.h +++ b/dnn/src/fallback/conv_bias/im2col/factory.h @@ -473,6 +473,13 @@ public: PostprocessMode::NO_PROCESS, "NoPackStrategyType::FLOAT16_FLOAT16"_hash); break; +#endif +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case StrategyType::FLOAT_FP16: + cb1(NCHW, NO_PACK, dt_float16, __fp16, + PostprocessMode::NO_PROCESS, + "NoPackStrategyType::FLOAT_FP16"_hash); + break; #endif case StrategyType::INT8x8x16: cb3(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp index cb574b748..ad2ca1584 100644 --- a/dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp +++ b/dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp @@ -169,6 +169,10 @@ INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, megdnn::PostprocessMode::NO_PROCESS) #endif +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, + megdnn::PostprocessMode::NO_PROCESS) +#endif #undef INSTANTIAL_CLASS } // namespace megdnn diff --git a/dnn/src/fallback/handle.cpp b/dnn/src/fallback/handle.cpp index 55b1e7074..051e5b6c1 100644 --- a/dnn/src/fallback/handle.cpp +++ b/dnn/src/fallback/handle.cpp @@ -67,7 +67,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(ElemwiseMultiType) MEGDNN_SPECIALIZE_CREATE_OPERATOR(AddUpdate) MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskConvForward) MEGDNN_SPECIALIZE_CREATE_OPERATOR(Resize) -MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMul) 
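
The get_algorithm_heuristic override declared above (and implemented earlier in opr_impl.cpp) follows a try-the-preferred-algo-first shape that the ROCm version later in this patch shares. A sketch with hypothetical simplified types, not the megdnn get_reproducible_algo/get_usable_algo helpers:

#include <cstddef>
#include <stdexcept>
#include <vector>

struct Algo {
    bool available = true, reproducible = true;
    std::size_t workspace_bytes = 0;
    bool ok(bool need_reproducible, std::size_t limit) const {
        return available && workspace_bytes <= limit &&
               (!need_reproducible || reproducible);
    }
};

Algo* heuristic(Algo& preferred, std::vector<Algo*>& all, std::size_t limit,
                bool need_reproducible) {
    if (preferred.ok(need_reproducible, limit))
        return &preferred;  // algo_default / blas wins whenever it qualifies
    for (Algo* a : all)     // otherwise first acceptable algo in pack order
        if (a->ok(need_reproducible, limit))
            return a;
    throw std::runtime_error("no usable batched matrix mul algorithm");
}
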
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward) MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias) MEGDNN_SPECIALIZE_CREATE_OPERATOR(PowC) diff --git a/dnn/src/fallback/matrix_mul/algos.cpp b/dnn/src/fallback/matrix_mul/algos.cpp index 708e7fe59..371702ad3 100644 --- a/dnn/src/fallback/matrix_mul/algos.cpp +++ b/dnn/src/fallback/matrix_mul/algos.cpp @@ -10,13 +10,18 @@ */ #include "src/fallback/matrix_mul/algos.h" +#include "megdnn/opr_param_defs.h" #include "src/fallback/matrix_mul/gemm_impl.h" #include "src/fallback/matrix_mul/gemv.h" #include "src/fallback/matrix_mul/generic_strategy.h" + +#include "src/naive/matrix_mul/matrix_mul_helper.h" + #include "midout.h" MIDOUT_DECL(megdnn_fb_matmul_f32_kern) MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like) +MIDOUT_DECL(megdnn_fb_matmul_naive) using namespace megdnn; using namespace fallback; @@ -39,6 +44,32 @@ void f32_8x12x1_kern(const MatrixMulImpl::KernParam& kern_param) { } MIDOUT_END(); } + +void kern_naive(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_fb_matmul_naive, void) { + size_t M = kern_param.M, N = kern_param.N, K = kern_param.K; + size_t LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + +#define DISPATCH(TA, TB) \ + if (kern_param.trA == TA && kern_param.trB == TB) { \ + naive::dispatch_ta_tb( \ + kern_param.A_ptr, kern_param.B_ptr, kern_param.C_ptr, \ + kern_param.workspace_ptr, M, N, K, LDA, LDB, LDC, \ + kern_param.A_type, kern_param.B_type, kern_param.C_type, \ + kern_param.format, kern_param.compute_mode); \ + return; \ + } + DISPATCH(true, true); + DISPATCH(true, false); + DISPATCH(false, true); + DISPATCH(false, false); +#undef DISPATCH + megdnn_assert_internal(0); + + } + MIDOUT_END(); + +} } // anonymous namespace ////////////////////// AlgoF32K8x12x1 /////////////////////////// @@ -84,11 +115,14 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern, bool MatrixMulImpl::AlgoGemv::usable( const KernSizeParam& kern_size_param) const { return !kern_size_param.trA && !kern_size_param.trB && - kern_size_param.format == param::MatrixMul::Format::DEFAULT && - !((kern_size_param.A_type.enumv() == - kern_size_param.B_type.enumv()) && - (kern_size_param.A_type.enumv() == DTypeEnum::Int16) && - (kern_size_param.C_type.enumv() == DTypeEnum::Int32)); + kern_size_param.format == + param::MatrixMul::Format::DEFAULT && + kern_size_param.compute_mode == + param::MatrixMul::ComputeMode::DEFAULT && + !((kern_size_param.A_type.enumv() == + kern_size_param.B_type.enumv()) && + (kern_size_param.A_type.enumv() == DTypeEnum::Int16) && + (kern_size_param.C_type.enumv() == DTypeEnum::Int32)); } bool MatrixMulImpl::AlgoGemv::preferred( @@ -128,4 +162,44 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoGemv::get_kern( megdnn_assert(0); } +/* ===================== naive algo ===================== */ +bool MatrixMulImpl::AlgoNaive::usable(const KernSizeParam&) const { + return true; +} + +bool MatrixMulImpl::AlgoNaive::preferred(const KernSizeParam&) const { + return false; +} + +size_t MatrixMulImpl::AlgoNaive::get_workspace( + const KernSizeParam& kern_param) const { + MIDOUT_BEGIN( + megdnn_fb_matmul_naive, + midout_iv("MatrixMulForwardImpl::get_workspace_in_bytes"_hash)) { + if (kern_param.A_type.enumv() == DTypeEnum::Quantized4Asymm || + kern_param.A_type.enumv() == DTypeEnum::QuantizedS4) { + size_t ret = 0; + if (kern_param.trA) { + ret += kern_param.LDA * kern_param.K; + } else { + ret += kern_param.LDA * kern_param.M; + } + if (kern_param.trB) { + ret += kern_param.LDB * 
kern_param.N; + } else { + ret += kern_param.LDB * kern_param.K; + } + return ret; + } + return 0; + } + MIDOUT_END(); + +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoNaive::get_kern( + const KernSizeParam&) const { + return kern_naive; +} + // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/matrix_mul/algos.h b/dnn/src/fallback/matrix_mul/algos.h index f0cd51be6..df98c56f8 100644 --- a/dnn/src/fallback/matrix_mul/algos.h +++ b/dnn/src/fallback/matrix_mul/algos.h @@ -52,6 +52,28 @@ public: DEFAULT) }; +class MatrixMulImpl::AlgoNaive final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "FB_NAIVE"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMM; } + PackMode packmode() const override { return PackMode::NO_PACK; } + MEGDNN_DECL_ALGO_TYPE(FB_NAIVE) + MEGDNN_OVERRIDE_MATMUL_DESC( + 8, 16, 1, 4, + static_cast( + static_cast(AlgoDataType::FLOAT16) | + static_cast(AlgoDataType::FLOAT32) | + static_cast(AlgoDataType::INT8X8X16) | + static_cast(AlgoDataType::QINT8X8X32) | + static_cast(AlgoDataType::QUINT8X8X32)), + DEFAULT) +}; + } // namespace fallback } // namespace megdnn diff --git a/dnn/src/fallback/matrix_mul/opr_impl.cpp b/dnn/src/fallback/matrix_mul/opr_impl.cpp index c82c65ba2..b576aa7fa 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.cpp +++ b/dnn/src/fallback/matrix_mul/opr_impl.cpp @@ -35,6 +35,7 @@ using namespace fallback; class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoF32K8x12x1 f32_k8x12x1; AlgoGemv gemv; + AlgoNaive naive; SmallVector m_all_algos; AlgoBase::Mapper m_all_algos_map; @@ -42,6 +43,7 @@ public: AlgoPack() { m_all_algos.emplace_back(&gemv); m_all_algos.emplace_back(&f32_k8x12x1); + m_all_algos.emplace_back(&naive); for (auto&& algo : m_all_algos) { m_all_algos_map.emplace(algo->info().desc, algo); } @@ -147,19 +149,26 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic( algo_type.format = kern_size_param.format; auto algos = select_algo_type(algo_type); Algorithm *heuristic_algo = nullptr; + Algorithm *usable_algo = nullptr; for (auto&& algo : algos) { if (static_cast(algo)->usable(kern_size_param) && - static_cast(algo)->preferred_reproducible( - kern_size_param, reproducible) && static_cast(algo)->get_workspace(kern_size_param) <= workspace_limit_in_bytes) { - if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { - return algo; - } else if (!heuristic_algo) { - heuristic_algo = algo; + if (static_cast(algo)->preferred_reproducible( + kern_size_param, reproducible)) { + //! use gemv algo if it's prefered + if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { + return algo; + } else if (!heuristic_algo) { + heuristic_algo = algo; + } + } else if (!usable_algo) { + usable_algo = algo; } } } + if (!heuristic_algo) heuristic_algo = usable_algo; + megdnn_assert(heuristic_algo, "No usable algorithm found"); return heuristic_algo; } diff --git a/dnn/src/fallback/matrix_mul/opr_impl.h b/dnn/src/fallback/matrix_mul/opr_impl.h index 64bc738ff..74181cb8d 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.h +++ b/dnn/src/fallback/matrix_mul/opr_impl.h @@ -110,6 +110,7 @@ public: //! fallback FB_F32K8x12x1 = 1 << 0, FB_GEMV, + FB_NAIVE, #if MEGDNN_X86 //! 
x86 @@ -233,6 +234,7 @@ public: private: class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 class AlgoGemv; + class AlgoNaive; class AlgoPack; //! maintain all the algos of in the opr of fallback static const AlgoPack& algo_pack(); diff --git a/dnn/src/naive/matrix_mul/matrix_mul_helper.h b/dnn/src/naive/matrix_mul/matrix_mul_helper.h index 843f754a5..bfa227bce 100644 --- a/dnn/src/naive/matrix_mul/matrix_mul_helper.h +++ b/dnn/src/naive/matrix_mul/matrix_mul_helper.h @@ -141,20 +141,39 @@ void run_matrix_mul_mk8_tpl(const itype* A, const itype* B, otype* C, size_t M, } template -void exec_matrix_mul_quint4x4x32_helper(_megdnn_tensor_in A, - _megdnn_tensor_in B, - _megdnn_tensor_out C, - _megdnn_workspace workspace, - const param::MatrixMul& param) { +void exec_matrix_mul_quint4x4x32_helper( + const void* A, const void* B, void* C, void* workspace, size_t M, + size_t N, size_t K, ptrdiff_t LDA, ptrdiff_t LDB, ptrdiff_t LDC, + DType A_type, DType B_type, DType C_type, + const MatrixMul::Param::Format& format, + const MatrixMul::Param::ComputeMode& compute_mode) { + MEGDNN_MARK_USED_VAR(C_type); + MEGDNN_MARK_USED_VAR(format); + MEGDNN_MARK_USED_VAR(compute_mode); auto convert_layout = [](const TensorLayout& layout) { auto ret = layout; auto param = layout.dtype.param(); ret.dtype = dtype::Quantized8Asymm(param.scale, param.zero_point); return ret; }; - TensorND nA = {workspace.raw_ptr, convert_layout(A.layout)}; - TensorND nB = {workspace.raw_ptr + nA.layout.span().dist_byte(), - convert_layout(B.layout)}; + TensorLayout A_layout, B_layout; + if (transA) { + A_layout = TensorLayout({K, M}, {LDA, 1}, A_type); + } else { + A_layout = TensorLayout({M, K}, {LDA, 1}, A_type); + } + if (transB) { + B_layout = TensorLayout({N, K}, {LDB, 1}, B_type); + } else { + B_layout = TensorLayout({K, N}, {LDB, 1}, B_type); + } + + TensorND tensorA{const_cast(A), A_layout}; + TensorND tensorB{const_cast(B), B_layout}; + TensorND nA = {workspace, convert_layout(A_layout)}; + TensorND nB = { + static_cast(workspace) + nA.layout.span().dist_byte(), + convert_layout(B_layout)}; auto convert_4to8 = [](const TensorND& in, const TensorND& out) { auto ptr = static_cast(in.raw_ptr) + in.layout.span().low_byte; @@ -168,31 +187,48 @@ void exec_matrix_mul_quint4x4x32_helper(_megdnn_tensor_in A, out_ptr[i + 1] = val1; } }; - convert_4to8(A, nA); - convert_4to8(B, nB); - auto M = C.layout.shape[0], N = C.layout.shape[1]; - auto K = A.layout.shape[param.transposeA ? 
0 : 1]; - auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], - LDC = C.layout.stride[0]; + convert_4to8(tensorA, nA); + convert_4to8(tensorB, nB); run_matrix_mul_tpl( nA.compatible_ptr(), nB.compatible_ptr(), - C.compatible_ptr(), M, N, K, LDA, LDB, LDC, - nA.layout.dtype, nB.layout.dtype); + static_cast(C), M, N, K, LDA, LDB, LDC, nA.layout.dtype, + nB.layout.dtype); } template -void exec_matrix_mul_qint4x4x16_helper(_megdnn_tensor_in A, _megdnn_tensor_in B, - _megdnn_tensor_out C, - _megdnn_workspace workspace, - const param::MatrixMul& param) { +void exec_matrix_mul_qint4x4x16_helper( + const void* A, const void* B, void* C, void* workspace, size_t M, + size_t N, size_t K, ptrdiff_t LDA, ptrdiff_t LDB, ptrdiff_t LDC, + DType A_type, DType B_type, DType C_type, + const MatrixMul::Param::Format& format, + const MatrixMul::Param::ComputeMode& compute_mode) { + MEGDNN_MARK_USED_VAR(C_type); + MEGDNN_MARK_USED_VAR(format); + MEGDNN_MARK_USED_VAR(compute_mode); auto convert_layout = [](const TensorLayout& layout) { auto ret = layout; auto param = layout.dtype.param(); ret.dtype = dtype::QuantizedS8(param.scale); return ret; }; - TensorND nA = {workspace.raw_ptr, convert_layout(A.layout)}; - TensorND nB = {workspace.raw_ptr + nA.layout.span().dist_byte(), - convert_layout(B.layout)}; + TensorLayout A_layout, B_layout; + if (transA) { + A_layout = TensorLayout({K, M}, {LDA, 1}, A_type); + } else { + A_layout = TensorLayout({M, K}, {LDA, 1}, A_type); + } + if (transB) { + B_layout = TensorLayout({N, K}, {LDB, 1}, B_type); + } else { + B_layout = TensorLayout({K, N}, {LDB, 1}, B_type); + } + + TensorND tensorA{const_cast(A), A_layout}; + TensorND tensorB{const_cast(B), B_layout}; + + TensorND nA = {workspace, convert_layout(A_layout)}; + TensorND nB = { + static_cast(workspace) + nA.layout.span().dist_byte(), + convert_layout(B_layout)}; auto convert_4to8 = [](const TensorND& in, const TensorND& out) { auto ptr = static_cast(in.raw_ptr) + in.layout.span().low_byte; auto out_ptr = @@ -204,18 +240,98 @@ void exec_matrix_mul_qint4x4x16_helper(_megdnn_tensor_in A, _megdnn_tensor_in B, out_ptr[i + 1] = cur >> 4; } }; - convert_4to8(A, nA); - convert_4to8(B, nB); - auto M = C.layout.shape[0], N = C.layout.shape[1]; - auto K = A.layout.shape[param.transposeA ? 
0 : 1]; - auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], - LDC = C.layout.stride[0]; + convert_4to8(tensorA, nA); + convert_4to8(tensorB, nB); run_matrix_mul_tpl( nA.compatible_ptr(), nB.compatible_ptr(), - C.compatible_ptr(), M, N, K, LDA, LDB, LDC, - nA.layout.dtype, nB.layout.dtype); + static_cast(C), M, N, K, LDA, LDB, LDC, nA.layout.dtype, + nB.layout.dtype); } +template +void dispatch_ta_tb(const void* A, const void* B, void* C, void* workspace, + size_t M, size_t N, size_t K, ptrdiff_t LDA, ptrdiff_t LDB, + ptrdiff_t LDC, DType A_type, DType B_type, DType C_type, + const MatrixMul::Param::Format& format, + const MatrixMul::Param::ComputeMode& compute_mode) { +#define cb(_itype, _otype, _comp_type) \ + if (format == param::MatrixMul::Format::DEFAULT) { \ + return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \ + static_cast(A), static_cast(B), \ + static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \ + B_type); \ + } else if (format == param::MatrixMul::Format::MK4) { \ + return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \ + static_cast(A), static_cast(B), \ + static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \ + B_type); \ + } else if (format == param::MatrixMul::Format::MK4_DOT) { \ + return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ + static_cast(A), static_cast(B), \ + static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \ + B_type); \ + } else if (format == param::MatrixMul::Format::MK8) { \ + return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ + static_cast(A), static_cast(B), \ + static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \ + B_type); \ + } + + if (A_type == dtype::Float32()) { + cb(dt_float32, dt_float32, dt_float32); +#if !MEGDNN_DISABLE_FLOAT16 + } else if (A_type == dtype::Float16()) { + using Param = MatrixMul::Param; + if (compute_mode == Param::ComputeMode::DEFAULT) { + cb(dt_float16, dt_float16, dt_float16); + } else if (compute_mode == Param::ComputeMode::FLOAT32) { + cb(dt_float16, dt_float16, dt_float32); + } + } else if (A_type == dtype::BFloat16()) { + using Param = MatrixMul::Param; + if (compute_mode == Param::ComputeMode::DEFAULT) { + cb(dt_bfloat16, dt_bfloat16, dt_bfloat16); + } else if (compute_mode == Param::ComputeMode::FLOAT32) { + cb(dt_bfloat16, dt_bfloat16, dt_float32); + } +#endif + } else if (A_type == dtype::Int8() && + C_type == dtype::Int16()) { + cb(dt_int8, dt_int16, dt_int16); + } else if (A_type == dtype::Int16() && + C_type == dtype::Int32()) { + cb(dt_int16, dt_int32, dt_int32); + } else if ((A_type == dtype::Int8() || + A_type.enumv() == DTypeEnum::QuantizedS8) && + (C_type == dtype::Int32() || + C_type.enumv() == DTypeEnum::QuantizedS32)) { + cb(dt_int8, dt_int32, dt_int32); + } else if (A_type.enumv() == DTypeEnum::Quantized8Asymm && + C_type.enumv() == DTypeEnum::QuantizedS32) { + cb(uint8_t, dt_int32, dt_int32); + } else if (A_type.enumv() == DTypeEnum::Quantized4Asymm && + C_type.enumv() == DTypeEnum::QuantizedS32 && + format == param::MatrixMul::Format::DEFAULT) { + exec_matrix_mul_quint4x4x32_helper( + A, B, C, workspace, M, N, K, LDA, LDB, LDC, A_type, B_type, + C_type, format, compute_mode); + return; + } else if (A_type.enumv() == DTypeEnum::QuantizedS4 && + C_type.enumv() == DTypeEnum::QuantizedS16 && + format == param::MatrixMul::Format::DEFAULT) { + exec_matrix_mul_qint4x4x16_helper( + A, B, C, workspace, M, N, K, LDA, LDB, LDC, A_type, B_type, + C_type, format, compute_mode); + return; + } +#undef cb + megdnn_throw( + 
ssprintf("unsupported naive MatrixMul(%s, %s) -> %s (cmode = %d)", + A_type.name(), B_type.name(), C_type.name(), + static_cast(compute_mode))); +} + + } // namespace naive } // namespace megdnn diff --git a/dnn/src/naive/matrix_mul/opr_impl.cpp b/dnn/src/naive/matrix_mul/opr_impl.cpp index fd5128e50..2ba27d1f3 100644 --- a/dnn/src/naive/matrix_mul/opr_impl.cpp +++ b/dnn/src/naive/matrix_mul/opr_impl.cpp @@ -45,77 +45,10 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B, auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], LDC = C.layout.stride[0]; -#define cb(_itype, _otype, _comp_type) \ - if (param.format == param::MatrixMul::Format::DEFAULT) { \ - return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \ - A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ - C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ - A.layout.dtype, B.layout.dtype); \ - } else if (param.format == param::MatrixMul::Format::MK4) { \ - return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \ - A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ - C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ - A.layout.dtype, B.layout.dtype); \ - } else if (param.format == param::MatrixMul::Format::MK4_DOT) { \ - return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ - A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ - C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ - A.layout.dtype, B.layout.dtype); \ - } else if (param.format == param::MatrixMul::Format::MK8) { \ - return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ - A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ - C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ - A.layout.dtype, B.layout.dtype); \ - } - - if (A.layout.dtype == dtype::Float32()) { - cb(dt_float32, dt_float32, dt_float32); -#if !MEGDNN_DISABLE_FLOAT16 - } else if (A.layout.dtype == dtype::Float16()) { - using Param = MatrixMul::Param; - if (param.compute_mode == Param::ComputeMode::DEFAULT) { - cb(dt_float16, dt_float16, dt_float16); - } else if (param.compute_mode == Param::ComputeMode::FLOAT32) { - cb(dt_float16, dt_float16, dt_float32); - } - } else if (A.layout.dtype == dtype::BFloat16()) { - using Param = MatrixMul::Param; - if (param.compute_mode == Param::ComputeMode::DEFAULT) { - cb(dt_bfloat16, dt_bfloat16, dt_bfloat16); - } else if (param.compute_mode == Param::ComputeMode::FLOAT32) { - cb(dt_bfloat16, dt_bfloat16, dt_float32); - } -#endif - } else if (A.layout.dtype == dtype::Int8() && - C.layout.dtype == dtype::Int16()) { - cb(dt_int8, dt_int16, dt_int16); - } else if (A.layout.dtype == dtype::Int16() && - C.layout.dtype == dtype::Int32()) { - cb(dt_int16, dt_int32, dt_int32); - } else if ((A.layout.dtype == dtype::Int8() || - A.layout.dtype.enumv() == DTypeEnum::QuantizedS8) && - (C.layout.dtype == dtype::Int32() || - C.layout.dtype.enumv() == DTypeEnum::QuantizedS32)) { - cb(dt_int8, dt_int32, dt_int32); - } else if (A.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && - C.layout.dtype.enumv() == DTypeEnum::QuantizedS32) { - cb(uint8_t, dt_int32, dt_int32); - } else if (A.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm && - C.layout.dtype.enumv() == DTypeEnum::QuantizedS32 && - param.format == param::MatrixMul::Format::DEFAULT) { - exec_matrix_mul_quint4x4x32_helper(A, B, C, workspace, param); - return; - } else if (A.layout.dtype.enumv() == DTypeEnum::QuantizedS4 && - C.layout.dtype.enumv() == DTypeEnum::QuantizedS16 && - param.format == 
param::MatrixMul::Format::DEFAULT) { - exec_matrix_mul_qint4x4x16_helper(A, B, C, workspace, param); - return; - } -#undef cb - megdnn_throw(ssprintf( - "unsupported naive MatrixMul(%s, %s) -> %s (cmode = %d)", - A.layout.dtype.name(), B.layout.dtype.name(), C.layout.dtype.name(), - static_cast(param.compute_mode))); + dispatch_ta_tb(A.raw_ptr, B.raw_ptr, C.raw_ptr, workspace.raw_ptr, + M, N, K, LDA, LDB, LDC, A.layout.dtype, + B.layout.dtype, C.layout.dtype, param.format, + param.compute_mode); } void MatrixMulForwardImpl::exec_internal(_megdnn_tensor_in A, diff --git a/dnn/src/rocm/batched_matrix_mul/algos.cpp b/dnn/src/rocm/batched_matrix_mul/algos.cpp new file mode 100644 index 000000000..431dfe5ac --- /dev/null +++ b/dnn/src/rocm/batched_matrix_mul/algos.cpp @@ -0,0 +1,59 @@ +/** + * \file dnn/src/rocm/batched_matrix_mul/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/rocm/batched_matrix_mul/algos.h" +#include "src/common/algo_base.h" + +using namespace megdnn; +using namespace rocm; + +BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() { + all_algos.push_back(&blas); + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } +} + +BatchedMatrixMulForwardImpl::AlgoPack BatchedMatrixMulForwardImpl::sm_algo_pack; + +MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl) + +BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs( + BatchedMatrixMulForwardImpl* o, const TensorLayout& A, + const TensorLayout& B, const TensorLayout& C) + : opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {} + +BatchedMatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs( + BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A, + _megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace) + : SizeArgs(opr, A.layout, B.layout, C.layout), + tensor_a{A}, + tensor_b{B}, + tensor_c{C}, + workspace{workspace} {} + +std::string BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { + auto&& param = opr->param(); + size_t m = layout_a.shape[0], n = layout_b.shape[1], + k = layout_a.shape[param.transposeA ? 0 : 1]; + MEGDNN_MARK_USED_VAR(m); + MEGDNN_MARK_USED_VAR(n); + MEGDNN_MARK_USED_VAR(k); + return megdnn_mangle(ssprintf( + "A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose " + "B=%d,ldA=%zu,ldB=%zu,ldC=%zu", + m, k, k, n, m, n, param.transposeA, param.transposeB, + layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/batched_matrix_mul/algos.h b/dnn/src/rocm/batched_matrix_mul/algos.h new file mode 100644 index 000000000..de2d158ac --- /dev/null +++ b/dnn/src/rocm/batched_matrix_mul/algos.h @@ -0,0 +1,118 @@ +/** + * \file dnn/src/rocm/batched_matrix_mul/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
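
Looking back at the naive dispatch_ta_tb helper introduced above: its DISPATCH macro turns the runtime trA/trB flags into one of four template instantiations, so the innermost loops see the transpose decision as a compile-time constant. A self-contained sketch, simplified to float and the DEFAULT format (the real helper additionally dispatches on dtype, format, and compute mode):

#include <cstddef>

template <bool TA, bool TB>
void matmul_naive(const float* A, const float* B, float* C, std::size_t M,
                  std::size_t N, std::size_t K, std::size_t LDA,
                  std::size_t LDB, std::size_t LDC) {
    for (std::size_t m = 0; m < M; ++m)
        for (std::size_t n = 0; n < N; ++n) {
            float acc = 0.f;
            for (std::size_t k = 0; k < K; ++k) {
                float a = TA ? A[k * LDA + m] : A[m * LDA + k];  // A is K x M when transposed
                float b = TB ? B[n * LDB + k] : B[k * LDB + n];  // B is N x K when transposed
                acc += a * b;
            }
            C[m * LDC + n] = acc;
        }
}

void dispatch_ta_tb(bool trA, bool trB, const float* A, const float* B,
                    float* C, std::size_t M, std::size_t N, std::size_t K,
                    std::size_t LDA, std::size_t LDB, std::size_t LDC) {
#define DISPATCH(TA, TB)                                               \
    if (trA == TA && trB == TB)                                        \
        return matmul_naive<TA, TB>(A, B, C, M, N, K, LDA, LDB, LDC);
    DISPATCH(true, true)
    DISPATCH(true, false)
    DISPATCH(false, true)
    DISPATCH(false, false)
#undef DISPATCH
}
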
+ */ + +#pragma once +#include "megdnn/oprs.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" +#include "src/common/utils.h" +#include "src/rocm/batched_matrix_mul/opr_impl.h" + +#include +#include + +namespace megdnn { +namespace rocm { + +/*! + * \brief base class for matrix mul algos + * + */ +class BatchedMatrixMulForwardImpl::AlgoBase : public Algorithm { +protected: + ~AlgoBase() = default; + +public: + enum class AlgoType : uint32_t { + ROCM_BLAS, + }; + using Mapper = std::unordered_map; + + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } + struct SizeArgs { + BatchedMatrixMulForwardImpl* opr; + TensorLayout layout_a, layout_b, layout_c; + + std::string to_string() const; + SizeArgs(BatchedMatrixMulForwardImpl* opr, const TensorLayout& A, + const TensorLayout& B, const TensorLayout& C); + + bool can_be_treated_as_int8x8x32() const { + return layout_a.dtype.enumv() == layout_b.dtype.enumv() && + (layout_a.dtype.enumv() == DTypeEnum::Int8 || + layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) && + (layout_c.dtype.enumv() == DTypeEnum::Int32 || + layout_c.dtype.enumv() == DTypeEnum::QuantizedS32) && + opr->param().format == param::MatrixMul::Format::DEFAULT; + } + }; + struct ExecArgs : public SizeArgs { + TensorND tensor_a, tensor_b, tensor_c; + Workspace workspace; + + ExecArgs(BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A, + _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace); + }; + virtual bool is_available(const SizeArgs& args) const = 0; + virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; + virtual void exec(const ExecArgs& args) const = 0; + + bool is_available_wk(const SizeArgs& args, size_t limit) const { + return is_available(args) && get_workspace_in_bytes(args) <= limit; + } + bool is_available_reproducible( + const SizeArgs& args, bool reproducible = true, + size_t limit = std::numeric_limits::max()) const { + return (!reproducible || is_reproducible()) && + is_available_wk(args, limit); + } + AlgoBase& check_workspace(const SizeArgs& args, + const Workspace& workspace) { + auto req = get_workspace_in_bytes(args); + megdnn_assert( + req <= workspace.size, + "matrix mul fwd algo %s: required workspace %zu bytes, got %zu", + name(), req, workspace.size); + return *this; + } +}; + +class BatchedMatrixMulForwardImpl::AlgoBlas final : public AlgoBase { +public: + AlgoBlas() = default; + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override { + return 0_z; + } + const char* name() const override { return "BLAS"; } + void exec(const ExecArgs& args) const override; + bool is_reproducible() const override { return true; } + MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) +}; + +class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj { +private: + AlgoBase::Mapper m_all_algos_map; + +public: + AlgoPack(); + AlgoBlas blas; + std::vector all_algos; + + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } +}; + +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/batched_matrix_mul/blas.cpp b/dnn/src/rocm/batched_matrix_mul/blas.cpp new file mode 100644 index 000000000..9672377fa --- /dev/null +++ b/dnn/src/rocm/batched_matrix_mul/blas.cpp @@ -0,0 +1,140 @@ +/** + * \file dnn/src/rocm/batched_matrix_mul/Blas.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
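
A note on the AlgoBlas::exec implementation that follows: every rocBLAS call receives transposeB before transposeA, the sizes as (n, m, k), and tensor B before tensor A. That is deliberate. rocBLAS is column-major while megdnn layouts are row-major, and a row-major C = op(A) * op(B) occupies the same memory as a column-major C^T = op(B)^T * op(A)^T. A hypothetical helper recording just that argument mapping:

#include <cstddef>

struct RocblasGemmArgs {
    bool trans_first, trans_second;  // operations, in rocBLAS argument order
    std::size_t rows, cols, inner;   // what rocBLAS sees as m, n, k
    const void *first, *second;      // B goes first, A second
};

RocblasGemmArgs map_row_major(const void* A, const void* B, bool transA,
                              bool transB, std::size_t m, std::size_t n,
                              std::size_t k) {
    return {transB, transA, n, m, k, B, A};
}
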
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/rocm/batched_matrix_mul/algos.h" + +#include "hcc_detail/hcc_defs_prologue.h" +#include "src/rocm/handle.h" +#include "src/rocm/utils.h" + +using namespace megdnn; +using namespace rocm; + +bool BatchedMatrixMulForwardImpl::AlgoBlas::is_available( + const SizeArgs& args) const { + if (args.opr->param().format != param::MatrixMul::Format::DEFAULT) + return false; + if (args.layout_a.dtype == dtype::Float32() || + args.layout_a.dtype == dtype::Float16()) { + return true; + } + return false; +} + +void BatchedMatrixMulForwardImpl::AlgoBlas::exec(const ExecArgs& args) const { + auto batch = args.layout_a.shape[0]; + auto m = args.layout_c.shape[1], n = args.layout_c.shape[2]; + auto k = args.layout_a.shape[args.opr->param().transposeA ? 1 : 2]; + auto&& handle = concrete_handle(args.opr->handle()); + auto rocblas_handle_ = handle->get_rocblas_handle(); + + auto sgemm = [&]() { + auto zero = handle->zero_device(); + auto one = handle->one_device(); + rocblas_check(rocblas_sgemm_strided_batched( + rocblas_handle_, + args.opr->param().transposeB ? rocblas_operation_transpose + : rocblas_operation_none, + args.opr->param().transposeA ? rocblas_operation_transpose + : rocblas_operation_none, + n, m, k, one, args.tensor_b.ptr(), + (rocblas_int)(args.layout_b.stride[1]), + (rocblas_int)(args.layout_b.stride[0]), + args.tensor_a.ptr(), + (rocblas_int)(args.layout_a.stride[1]), + (rocblas_int)(args.layout_a.stride[0]), zero, + args.tensor_c.ptr(), + (rocblas_int)(args.layout_c.stride[1]), + (rocblas_int)(args.layout_c.stride[0]), (rocblas_int)(batch))); + + }; + +#if !MEGDNN_DISABLE_FLOAT16 + //! used for FLOAT_IO16xC32, not tested + auto gemm_ex = [&]() { + auto zero = handle->zero_device(); + auto one = handle->one_device(); + //! These two arguments for future use, see + //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp + int32_t solution_index = 0; + uint32_t flags = 1; + size_t ws_size = 0; + + rocblas_check(rocblas_gemm_strided_batched_ex( + rocblas_handle_, + args.opr->param().transposeB ? rocblas_operation_transpose + : rocblas_operation_none, + args.opr->param().transposeA ? rocblas_operation_transpose + : rocblas_operation_none, + n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_i8_r, + args.layout_b.stride[1], args.layout_b.stride[0], + args.tensor_a.raw_ptr, rocblas_datatype_i8_r, + args.layout_a.stride[1], args.layout_a.stride[0], zero, + args.tensor_c.raw_ptr, rocblas_datatype_i32_r, + args.layout_c.stride[1], args.layout_c.stride[0], + args.tensor_c.raw_ptr, rocblas_datatype_i32_r, + args.layout_c.stride[1], args.layout_c.stride[0], batch, + rocblas_datatype_i32_r, rocblas_gemm_algo_standard, + solution_index, flags, &ws_size, nullptr)); + + MEGDNN_MARK_USED_VAR(ws_size); + }; + + auto hgemm = [&]() { + auto one_half = handle->one_device_h(); + auto zero_half = handle->zero_device_h(); + rocblas_check(rocblas_hgemm_strided_batched( + rocblas_handle_, + args.opr->param().transposeB ? rocblas_operation_transpose + : rocblas_operation_none, + args.opr->param().transposeA ? 
rocblas_operation_transpose + : rocblas_operation_none, + n, m, k, reinterpret_cast(one_half), + static_cast(args.tensor_b.raw_ptr), + args.layout_b.stride[1], args.layout_b.stride[0], + static_cast(args.tensor_a.raw_ptr), + args.layout_a.stride[1], args.layout_a.stride[0], + reinterpret_cast(zero_half), + static_cast(args.tensor_c.raw_ptr), + args.layout_c.stride[1], args.layout_c.stride[0], batch)); + + }; +#endif + + if (args.opr->param().compute_mode == Param::ComputeMode::DEFAULT) { + if (args.layout_a.dtype == dtype::Float32()) { + sgemm(); + } +#if !MEGDNN_DISABLE_FLOAT16 + else { + megdnn_assert(args.layout_a.dtype == dtype::Float16(), + "invalid matmul data type"); + hgemm(); + } +#endif + } +#if !MEGDNN_DISABLE_FLOAT16 + else if (args.opr->param().compute_mode == Param::ComputeMode::FLOAT32) { + megdnn_assert(args.layout_b.dtype == dtype::Float16() && + args.layout_c.dtype == dtype::Float16() && + args.layout_a.dtype == dtype::Float16(), + "DataType::FLOAT_IO16xC32 is supported, when dtype of A, " + "B, C are all Float16"); + gemm_ex(); + } +#endif + else { + megdnn_throw("Unsupported data_type of matrix mul on rocm."); + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/batched_matrix_mul/opr_impl.cpp b/dnn/src/rocm/batched_matrix_mul/opr_impl.cpp index c53764dd3..496d86bf2 100644 --- a/dnn/src/rocm/batched_matrix_mul/opr_impl.cpp +++ b/dnn/src/rocm/batched_matrix_mul/opr_impl.cpp @@ -10,111 +10,58 @@ * implied. */ #include "./opr_impl.h" +#include "./algos.h" #include "hcc_detail/hcc_defs_prologue.h" +#include "src/common/algo_chooser.h" #include "src/common/utils.cuh" #include "src/rocm/handle.h" #include "src/rocm/utils.h" -namespace megdnn { -namespace rocm { +using namespace megdnn; +using namespace rocm; + +std::vector +BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, + const TensorLayout& B, + const TensorLayout& C) { + AlgoBase::SizeArgs args{this, A, B, C}; + return megdnn::get_all_algorithms(args); +} + +BatchedMatrixMulForwardImpl::Algorithm* +BatchedMatrixMulForwardImpl::get_algorithm_heuristic( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, + size_t workspace_limit_in_bytes, bool reproducible) { + AlgoBase::SizeArgs args{this, A, B, C}; + if (sm_algo_pack.blas.is_available_reproducible(args, reproducible, + workspace_limit_in_bytes)) { + return &sm_algo_pack.blas; + } + if (reproducible) { + return megdnn::get_reproducible_algo( + sm_algo_pack.all_algos, args, workspace_limit_in_bytes, + "batched matrix mul forward"); + } else { + return megdnn::get_usable_algo( + sm_algo_pack.all_algos, args, workspace_limit_in_bytes, + "batched matrix mul forward"); + } +} + +size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { + AlgoBase::SizeArgs args{this, A, B, C}; + return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args); +} void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace) { check_exec(A.layout, B.layout, C.layout, workspace.size); - auto dtype = A.layout.dtype; - megdnn_assert(dtype.category() == DTypeCategory::FLOAT && - param().format == param::MatrixMul::Format::DEFAULT); - - if (dtype == dtype::Float32() || - MEGDNN_FLOAT16_SELECT(dtype == dtype::Float16(), false)) { - auto batch = A.layout.shape[0]; - auto m = C.layout.shape[1], n = C.layout.shape[2]; - auto k = A.layout.shape[param().transposeA ? 
1 : 2]; - auto handle = concrete_handle(this->handle()); - auto rocblas_handle_ = handle->get_rocblas_handle(); - - auto io32_c32 = [&]() { - auto zero = handle->zero_device(); - auto one = handle->one_device(); - rocblas_check(rocblas_sgemm_strided_batched( - rocblas_handle_, - param().transposeB ? rocblas_operation_transpose - : rocblas_operation_none, - param().transposeA ? rocblas_operation_transpose - : rocblas_operation_none, - n, m, k, one, B.ptr(), - (rocblas_int)(B.layout.stride[1]), - (rocblas_int)(B.layout.stride[0]), A.ptr(), - (rocblas_int)(A.layout.stride[1]), - (rocblas_int)(A.layout.stride[0]), zero, - C.ptr(), (rocblas_int)(C.layout.stride[1]), - (rocblas_int)(C.layout.stride[0]), (rocblas_int)(batch))); - }; - -#if !MEGDNN_DISABLE_FLOAT16 - auto io16_c32 = [&]() { - auto zero = handle->zero_device(); - auto one = handle->one_device(); - int32_t solution_index = 0; - uint32_t flags = 1; - size_t ws_size = 0; - - rocblas_check(rocblas_gemm_strided_batched_ex( - rocblas_handle_, - param().transposeB ? rocblas_operation_transpose - : rocblas_operation_none, - param().transposeA ? rocblas_operation_transpose - : rocblas_operation_none, - n, m, k, one, B.raw_ptr, rocblas_datatype_i8_r, - B.layout.stride[1], B.layout.stride[0], A.raw_ptr, - rocblas_datatype_i8_r, A.layout.stride[1], - A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_i32_r, - C.layout.stride[1], C.layout.stride[0], C.raw_ptr, - rocblas_datatype_i32_r, C.layout.stride[1], - C.layout.stride[0], batch, rocblas_datatype_i32_r, - rocblas_gemm_algo_standard, solution_index, flags, &ws_size, - nullptr)); - }; - - auto io16_c16 = [&]() { - auto zero_half = handle->zero_device_h(); - auto one_half = handle->one_device_h(); - rocblas_check(rocblas_hgemm_strided_batched( - rocblas_handle_, - param().transposeB ? rocblas_operation_transpose - : rocblas_operation_none, - param().transposeA ? rocblas_operation_transpose - : rocblas_operation_none, - n, m, k, reinterpret_cast(one_half), - static_cast(B.raw_ptr), - B.layout.stride[1], B.layout.stride[0], - static_cast(A.raw_ptr), - A.layout.stride[1], A.layout.stride[0], - reinterpret_cast(zero_half), - static_cast(C.raw_ptr), C.layout.stride[1], - C.layout.stride[0], batch)); - - }; -#endif - - if (dtype == dtype::Float32()) { - io32_c32(); - } -#if !MEGDNN_DISABLE_FLOAT16 - else { - if (param().compute_mode == Param::ComputeMode::FLOAT32) { - io16_c32(); - } else { - io16_c16(); - } - } -#endif - } + AlgoBase::ExecArgs args(this, A, B, C, workspace); + auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout); + algo->check_workspace(args, workspace).exec(args); } -} // namespace rocm -} // namespace megdnn - // vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/batched_matrix_mul/opr_impl.h b/dnn/src/rocm/batched_matrix_mul/opr_impl.h index 60ca11172..fa9b0bfb0 100644 --- a/dnn/src/rocm/batched_matrix_mul/opr_impl.h +++ b/dnn/src/rocm/batched_matrix_mul/opr_impl.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
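
One more detail of the batched blas path reworked above: the rocblas_*_strided_batched calls take a leading dimension and a batch stride per operand, and the code reads both straight out of the 3-d layout. For a contiguous row-major (b, r, c) tensor, stride[1] == c is the leading dimension and stride[0] == r * c is the element distance between consecutive matrices. A small illustration (names hypothetical):

#include <cstddef>

struct StridedBatchedArg {
    int ld;                  // passed as ldX,     i.e. layout.stride[1]
    long long batch_stride;  // passed as strideX, i.e. layout.stride[0]
};

StridedBatchedArg arg_for_contiguous(std::size_t rows, std::size_t cols) {
    return {static_cast<int>(cols), static_cast<long long>(rows * cols)};
}
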
*/ #pragma once #include "megdnn/oprs.h" @@ -17,36 +18,35 @@ namespace rocm { class BatchedMatrixMulForwardImpl : public BatchedMatrixMulForward { public: using BatchedMatrixMulForward::BatchedMatrixMulForward; - BatchedMatrixMulForwardImpl(Handle* handle) - : BatchedMatrixMul(handle), - m_opr(handle->create_operator()) {} void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace) override; size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, - const TensorLayout&) override { - return 0; - } + const TensorLayout&) override; + + bool is_thread_safe() const override { return true; } + class AlgoBase; + class AlgoBlas; + class AlgoPack; + static const AlgoPack& algo_pack() { return sm_algo_pack; } + + static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); +private: std::vector get_all_algorithms( const TensorLayout& /*A*/, const TensorLayout& /*B*/, - const TensorLayout& /*C*/) override { - return {}; - } + const TensorLayout& /*C*/) override; Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, - bool /* reproducible */) override { - return nullptr; - } - - const char* get_algorithm_set_name() const override { return "DEFAULT"; } + bool /*reproducible*/) override; - bool is_thread_safe() const override { return true; } + const char* get_algorithm_set_name() const override { + return "ROCM BATCHED MATMUL"; + } -private: - std::unique_ptr m_opr; + static AlgoPack sm_algo_pack; }; } // namespace rocm diff --git a/dnn/src/rocm/matrix_mul/algos.cpp b/dnn/src/rocm/matrix_mul/algos.cpp new file mode 100644 index 000000000..d8063540e --- /dev/null +++ b/dnn/src/rocm/matrix_mul/algos.cpp @@ -0,0 +1,62 @@ +/** + * \file dnn/src/rocm/matrix_mul/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/rocm/matrix_mul/algos.h" +#include "src/common/algo_base.h" + +using namespace megdnn; +using namespace rocm; + +MatrixMulForwardImpl::AlgoPack::AlgoPack() { + all_algos.push_back(&blas); + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } +} + +MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack; + +MEGDNN_DEF_GET_ALGO_FROM_DESC(MatrixMulForwardImpl) + +MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(MatrixMulForwardImpl* o, + const TensorLayout& A, + const TensorLayout& B, + const TensorLayout& C) + : opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {} + +MatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs(MatrixMulForwardImpl* opr, + _megdnn_tensor_in A, + _megdnn_tensor_in B, + _megdnn_tensor_out C, + _megdnn_workspace workspace) + : SizeArgs(opr, A.layout, B.layout, C.layout), + tensor_a{A}, + tensor_b{B}, + tensor_c{C}, + workspace{workspace} {} + +std::string MatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { + auto&& param = opr->param(); + size_t m = layout_a.shape[0], n = layout_b.shape[1], + k = layout_a.shape[param.transposeA ? 
0 : 1]; + MEGDNN_MARK_USED_VAR(m); + MEGDNN_MARK_USED_VAR(n); + MEGDNN_MARK_USED_VAR(k); + return megdnn_mangle(ssprintf( + "A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose " + "B=%d,ldA=%zu,ldB=%zu,ldC=%zu", + m, k, k, n, m, n, param.transposeA, param.transposeB, + layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/matrix_mul/algos.h b/dnn/src/rocm/matrix_mul/algos.h new file mode 100644 index 000000000..dcf280c8c --- /dev/null +++ b/dnn/src/rocm/matrix_mul/algos.h @@ -0,0 +1,118 @@ +/** + * \file dnn/src/rocm/matrix_mul/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "megdnn/oprs.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" +#include "src/common/utils.h" +#include "src/rocm/matrix_mul/opr_impl.h" + +#include +#include + +namespace megdnn { +namespace rocm { + +/*! + * \brief base class for matrix mul algos + * + */ +class MatrixMulForwardImpl::AlgoBase : public Algorithm { +protected: + ~AlgoBase() = default; + +public: + enum class AlgoType : uint32_t { + ROCM_BLAS, + }; + using Mapper = std::unordered_map; + + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } + struct SizeArgs { + MatrixMulForwardImpl* opr; + TensorLayout layout_a, layout_b, layout_c; + + std::string to_string() const; + SizeArgs(MatrixMulForwardImpl* opr, const TensorLayout& A, + const TensorLayout& B, const TensorLayout& C); + + bool can_be_treated_as_int8x8x32() const { + return layout_a.dtype.enumv() == layout_b.dtype.enumv() && + (layout_a.dtype.enumv() == DTypeEnum::Int8 || + layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) && + (layout_c.dtype.enumv() == DTypeEnum::Int32 || + layout_c.dtype.enumv() == DTypeEnum::QuantizedS32) && + opr->param().format == param::MatrixMul::Format::DEFAULT; + } + }; + struct ExecArgs : public SizeArgs { + TensorND tensor_a, tensor_b, tensor_c; + Workspace workspace; + + ExecArgs(MatrixMulForwardImpl* opr, _megdnn_tensor_in A, + _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace); + }; + virtual bool is_available(const SizeArgs& args) const = 0; + virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; + virtual void exec(const ExecArgs& args) const = 0; + + bool is_available_wk(const SizeArgs& args, size_t limit) const { + return is_available(args) && get_workspace_in_bytes(args) <= limit; + } + bool is_available_reproducible( + const SizeArgs& args, bool reproducible = true, + size_t limit = std::numeric_limits::max()) const { + return (!reproducible || is_reproducible()) && + is_available_wk(args, limit); + } + AlgoBase& check_workspace(const SizeArgs& args, + const Workspace& workspace) { + auto req = get_workspace_in_bytes(args); + megdnn_assert( + req <= workspace.size, + "matrix mul fwd algo %s: required workspace %zu bytes, got %zu", + name(), req, workspace.size); + return *this; + } +}; + +class MatrixMulForwardImpl::AlgoBlas final : public AlgoBase { +public: + AlgoBlas() = default; + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override { + return 0_z; + } + const char* 
diff --git a/dnn/src/rocm/matrix_mul/algos.h b/dnn/src/rocm/matrix_mul/algos.h
new file mode 100644
index 000000000..dcf280c8c
--- /dev/null
+++ b/dnn/src/rocm/matrix_mul/algos.h
@@ -0,0 +1,118 @@
+/**
+ * \file dnn/src/rocm/matrix_mul/algos.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+#include "megdnn/oprs.h"
+#include "src/common/algo_base.h"
+#include "src/common/metahelper.h"
+#include "src/common/utils.h"
+#include "src/rocm/matrix_mul/opr_impl.h"
+
+#include <limits>
+#include <unordered_map>
+
+namespace megdnn {
+namespace rocm {
+
+/*!
+ * \brief base class for matrix mul algos
+ *
+ */
+class MatrixMulForwardImpl::AlgoBase : public Algorithm {
+protected:
+    ~AlgoBase() = default;
+
+public:
+    enum class AlgoType : uint32_t {
+        ROCM_BLAS,
+    };
+    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
+
+    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; }
+    struct SizeArgs {
+        MatrixMulForwardImpl* opr;
+        TensorLayout layout_a, layout_b, layout_c;
+
+        std::string to_string() const;
+        SizeArgs(MatrixMulForwardImpl* opr, const TensorLayout& A,
+                 const TensorLayout& B, const TensorLayout& C);
+
+        bool can_be_treated_as_int8x8x32() const {
+            return layout_a.dtype.enumv() == layout_b.dtype.enumv() &&
+                   (layout_a.dtype.enumv() == DTypeEnum::Int8 ||
+                    layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) &&
+                   (layout_c.dtype.enumv() == DTypeEnum::Int32 ||
+                    layout_c.dtype.enumv() == DTypeEnum::QuantizedS32) &&
+                   opr->param().format == param::MatrixMul::Format::DEFAULT;
+        }
+    };
+    struct ExecArgs : public SizeArgs {
+        TensorND tensor_a, tensor_b, tensor_c;
+        Workspace workspace;
+
+        ExecArgs(MatrixMulForwardImpl* opr, _megdnn_tensor_in A,
+                 _megdnn_tensor_in B, _megdnn_tensor_out C,
+                 _megdnn_workspace workspace);
+    };
+    virtual bool is_available(const SizeArgs& args) const = 0;
+    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
+    virtual void exec(const ExecArgs& args) const = 0;
+
+    bool is_available_wk(const SizeArgs& args, size_t limit) const {
+        return is_available(args) && get_workspace_in_bytes(args) <= limit;
+    }
+    bool is_available_reproducible(
+            const SizeArgs& args, bool reproducible = true,
+            size_t limit = std::numeric_limits<size_t>::max()) const {
+        return (!reproducible || is_reproducible()) &&
+               is_available_wk(args, limit);
+    }
+    AlgoBase& check_workspace(const SizeArgs& args,
+                              const Workspace& workspace) {
+        auto req = get_workspace_in_bytes(args);
+        megdnn_assert(
+                req <= workspace.size,
+                "matrix mul fwd algo %s: required workspace %zu bytes, got %zu",
+                name(), req, workspace.size);
+        return *this;
+    }
+};
+
+class MatrixMulForwardImpl::AlgoBlas final : public AlgoBase {
+public:
+    AlgoBlas() = default;
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override {
+        return 0_z;
+    }
+    const char* name() const override { return "BLAS"; }
+    void exec(const ExecArgs& args) const override;
+    bool is_reproducible() const override { return true; }
+    MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS)
+};
+
+class MatrixMulForwardImpl::AlgoPack : NonCopyableObj {
+private:
+    AlgoBase::Mapper m_all_algos_map;
+
+public:
+    AlgoPack();
+    AlgoBlas blas;
+    std::vector<AlgoBase*> all_algos;
+
+    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
+};
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
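can_be_treated_as_int8x8x32() above gates the int8 path on matching (quantized) int8 inputs and a (quantized) int32 output, in DEFAULT format. A standalone sketch of the same predicate, using a stand-in enum rather than megdnn's DType (the format check is omitted for brevity):

    #include <cassert>

    // Stand-in for megdnn's dtype enumeration (hypothetical).
    enum class DTypeEnum { Int8, QuantizedS8, Int32, QuantizedS32, Float32 };

    // Model of SizeArgs::can_be_treated_as_int8x8x32: both inputs must share
    // a (quantized) int8 dtype and the output must be (quantized) int32.
    bool can_be_treated_as_int8x8x32(DTypeEnum a, DTypeEnum b, DTypeEnum c) {
        return a == b &&
               (a == DTypeEnum::Int8 || a == DTypeEnum::QuantizedS8) &&
               (c == DTypeEnum::Int32 || c == DTypeEnum::QuantizedS32);
    }

    int main() {
        assert(can_be_treated_as_int8x8x32(DTypeEnum::Int8, DTypeEnum::Int8,
                                           DTypeEnum::Int32));
        assert(!can_be_treated_as_int8x8x32(DTypeEnum::Float32,
                                            DTypeEnum::Float32,
                                            DTypeEnum::Float32));
        return 0;
    }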
diff --git a/dnn/src/rocm/matrix_mul/blas.cpp b/dnn/src/rocm/matrix_mul/blas.cpp
new file mode 100644
index 000000000..8bc2f58cb
--- /dev/null
+++ b/dnn/src/rocm/matrix_mul/blas.cpp
@@ -0,0 +1,162 @@
+/**
+ * \file dnn/src/rocm/matrix_mul/blas.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "src/rocm/matrix_mul/algos.h"
+
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "src/rocm/handle.h"
+#include "src/rocm/utils.h"
+
+using namespace megdnn;
+using namespace rocm;
+
+bool MatrixMulForwardImpl::AlgoBlas::is_available(
+        const SizeArgs& args) const {
+    if (args.opr->param().format != param::MatrixMul::Format::DEFAULT)
+        return false;
+    if (args.layout_a.dtype == dtype::Float32() ||
+        args.layout_a.dtype == dtype::Float16()) {
+        return true;
+    } else if (args.layout_a.dtype.enumv() == DTypeEnum::Int8 ||
+               args.layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) {
+        auto k = args.layout_a.shape[args.opr->param().transposeA ? 0 : 1];
+        //! see
+        //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp:470
+        bool rocblas_int8x8x32_valid = true;
+        rocblas_int8x8x32_valid &= (k % 4 == 0);
+        rocblas_int8x8x32_valid &= (!args.opr->param().transposeB ||
+                                    args.layout_b.stride[0] % 4 == 0);
+        rocblas_int8x8x32_valid &= (!args.opr->param().transposeA ||
+                                    args.layout_a.stride[0] % 4 == 0);
+        return rocblas_int8x8x32_valid;
+    }
+    return false;
+}
+
+void MatrixMulForwardImpl::AlgoBlas::exec(const ExecArgs& args) const {
+    auto m = args.layout_c.shape[0], n = args.layout_c.shape[1];
+    auto k = args.layout_a.shape[args.opr->param().transposeA ? 0 : 1];
+    auto&& handle = concrete_handle(args.opr->handle());
+    auto rocblas_handle_ = handle->get_rocblas_handle();
+
+    auto sgemm = [&]() {
+        auto zero = handle->zero_device();
+        auto one = handle->one_device();
+        rocblas_check(rocblas_sgemm(
+                rocblas_handle_,
+                args.opr->param().transposeB ? rocblas_operation_transpose
+                                             : rocblas_operation_none,
+                args.opr->param().transposeA ? rocblas_operation_transpose
+                                             : rocblas_operation_none,
+                n, m, k, one, args.tensor_b.ptr<dt_float32>(),
+                args.layout_b.stride[0], args.tensor_a.ptr<dt_float32>(),
+                args.layout_a.stride[0], zero, args.tensor_c.ptr<dt_float32>(),
+                args.layout_c.stride[0]));
+    };
+
+#if !MEGDNN_DISABLE_FLOAT16
+    //! used for FLOAT_IO16xC32, not tested
+    auto gemm_ex = [&]() {
+        auto zero = handle->zero_device();
+        auto one = handle->one_device();
+        //! These two arguments are reserved for future use, see
+        //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp
+        int32_t solution_index = 0;
+        uint32_t flags = 1;
+        size_t ws_size = 0;
+        auto gemm_ex_err = rocblas_gemm_ex(
+                rocblas_handle_,
+                args.opr->param().transposeB ? rocblas_operation_transpose
+                                             : rocblas_operation_none,
+                args.opr->param().transposeA ? rocblas_operation_transpose
+                                             : rocblas_operation_none,
+                n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_f16_r,
+                args.layout_b.stride[0], args.tensor_a.raw_ptr,
+                rocblas_datatype_f16_r, args.layout_a.stride[0], zero,
+                args.tensor_c.raw_ptr, rocblas_datatype_f16_r,
+                args.layout_c.stride[0], args.tensor_c.raw_ptr,
+                rocblas_datatype_f16_r, args.layout_c.stride[0],
+                rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
+                solution_index, flags, &ws_size, nullptr);
+        rocblas_check(gemm_ex_err);
+        MEGDNN_MARK_USED_VAR(ws_size);
+    };
+
+    auto hgemm = [&]() {
+        auto one_half = handle->one_device_h();
+        auto zero_half = handle->zero_device_h();
+        auto hgemm_err = rocblas_hgemm(
+                rocblas_handle_,
+                args.opr->param().transposeB ? rocblas_operation_transpose
+                                             : rocblas_operation_none,
+                args.opr->param().transposeA ? rocblas_operation_transpose
+                                             : rocblas_operation_none,
+                n, m, k, reinterpret_cast<const rocblas_half*>(one_half),
+                static_cast<const rocblas_half*>(args.tensor_b.raw_ptr),
+                args.layout_b.stride[0],
+                static_cast<const rocblas_half*>(args.tensor_a.raw_ptr),
+                args.layout_a.stride[0],
+                reinterpret_cast<const rocblas_half*>(zero_half),
+                static_cast<rocblas_half*>(args.tensor_c.raw_ptr),
+                args.layout_c.stride[0]);
+        rocblas_check(hgemm_err);
+    };
+#endif
+
+    if (args.opr->param().compute_mode == Param::ComputeMode::DEFAULT) {
+        if (args.layout_a.dtype == dtype::Float32()) {
+            sgemm();
+        }
+#if !MEGDNN_DISABLE_FLOAT16
+        else {
+            megdnn_assert(args.layout_a.dtype == dtype::Float16(),
+                          "invalid matmul data type");
+            hgemm();
+        }
+#endif
+    }
+#if !MEGDNN_DISABLE_FLOAT16
+    else if (args.opr->param().compute_mode == Param::ComputeMode::FLOAT32) {
+        megdnn_assert(args.layout_b.dtype == dtype::Float16() &&
+                              args.layout_c.dtype == dtype::Float16() &&
+                              args.layout_a.dtype == dtype::Float16(),
+                      "DataType::FLOAT_IO16xC32 is only supported when the "
+                      "dtypes of A, B and C are all Float16");
+        gemm_ex();
+    }
+#endif
+    else {
+        megdnn_assert(args.can_be_treated_as_int8x8x32());
+        int32_t solution_index = 0;
+        uint32_t flags = 1;
+        size_t ws_size = 0;
+        auto zero = handle->zero_device_i32();
+        auto one = handle->one_device_i32();
+        rocblas_check(rocblas_gemm_ex(
+                rocblas_handle_,
+                args.opr->param().transposeB ? rocblas_operation_transpose
+                                             : rocblas_operation_none,
+                args.opr->param().transposeA ? rocblas_operation_transpose
+                                             : rocblas_operation_none,
+                n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_i8_r,
+                args.layout_b.stride[0], args.tensor_a.raw_ptr,
+                rocblas_datatype_i8_r, args.layout_a.stride[0], zero,
+                args.tensor_c.raw_ptr, rocblas_datatype_i32_r,
+                args.layout_c.stride[0], args.tensor_c.raw_ptr,
+                rocblas_datatype_i32_r, args.layout_c.stride[0],
+                rocblas_datatype_i32_r, rocblas_gemm_algo_standard,
+                solution_index, flags, &ws_size, nullptr));
+        MEGDNN_MARK_USED_VAR(ws_size);
+    }
+
+}
+
+// vim: syntax=cpp.doxygen
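One detail worth calling out in blas.cpp: every rocBLAS call above passes B before A and the sizes as (n, m, k). rocBLAS, like classic BLAS, assumes column-major storage, while MegDNN tensors are row-major; a row-major m x k matrix with leading dimension k is exactly a column-major k x m matrix, so row-major C = A * B is evaluated as column-major C^T = B^T * A^T. A self-contained check of that identity with plain loops (no rocBLAS required):

    #include <cassert>
    #include <cstddef>

    // Column-major GEMM: C[i + j * ldc] = sum_l A[i + l * lda] * B[l + j * ldb],
    // with C of size m x n.
    void gemm_colmajor(size_t m, size_t n, size_t k, const float* A, size_t lda,
                       const float* B, size_t ldb, float* C, size_t ldc) {
        for (size_t j = 0; j < n; ++j)
            for (size_t i = 0; i < m; ++i) {
                float acc = 0.f;
                for (size_t l = 0; l < k; ++l)
                    acc += A[i + l * lda] * B[l + j * ldb];
                C[i + j * ldc] = acc;
            }
    }

    int main() {
        // Row-major A (2x3), B (3x2), C (2x2).
        float A[] = {1, 2, 3, 4, 5, 6};
        float B[] = {7, 8, 9, 10, 11, 12};
        float C[4];
        // Swap the operands and pass (n, m, k), exactly as the rocblas_sgemm
        // call in AlgoBlas::exec does: the buffers of B and A, read
        // column-major, are B^T and A^T, so this computes C^T = B^T * A^T.
        gemm_colmajor(/*m=*/2, /*n=*/2, /*k=*/3, B, /*lda=*/2, A, /*ldb=*/3, C,
                      /*ldc=*/2);
        // C read back row-major is A * B = {{58, 64}, {139, 154}}.
        assert(C[0] == 58 && C[1] == 64 && C[2] == 139 && C[3] == 154);
        return 0;
    }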
diff --git a/dnn/src/rocm/matrix_mul/opr_impl.cpp b/dnn/src/rocm/matrix_mul/opr_impl.cpp
index e34ad53a5..75fe8f148 100644
--- a/dnn/src/rocm/matrix_mul/opr_impl.cpp
+++ b/dnn/src/rocm/matrix_mul/opr_impl.cpp
@@ -13,147 +13,53 @@
 #include "src/rocm/utils.h"
 #include "src/rocm/handle.h"
+#include "./algos.h"
+#include "src/common/algo_chooser.h"
 
-namespace megdnn {
-namespace rocm {
+using namespace megdnn;
+using namespace rocm;
 
-void MatrixMulForwardImpl::exec(_megdnn_tensor_in A,
-                                _megdnn_tensor_in B,
-                                _megdnn_tensor_out C,
-                                _megdnn_workspace workspace)
-{
-    check_exec(A.layout, B.layout, C.layout, workspace.size);
-
-    auto m = C.layout.shape[0], n = C.layout.shape[1];
-    auto k = A.layout.shape[param().transposeA ? 0 : 1];
-    auto handle = concrete_handle(this->handle());
-    auto rocblas_handle_ = handle->get_rocblas_handle();
-
-    auto sgemm = [&]() {
-        auto zero = handle->zero_device();
-        auto one = handle->one_device();
-        rocblas_check(rocblas_sgemm(
-                rocblas_handle_,
-                param().transposeB ? rocblas_operation_transpose
-                                   : rocblas_operation_none,
-                param().transposeA ? rocblas_operation_transpose
-                                   : rocblas_operation_none,
-                n, m, k, one, B.ptr<dt_float32>(), B.layout.stride[0],
-                A.ptr<dt_float32>(), A.layout.stride[0], zero,
-                C.ptr<dt_float32>(), C.layout.stride[0]));
-    };
-
-#if !MEGDNN_DISABLE_FLOAT16
-    //! used for FLOAT_IO16xC32, not tested
-    auto gemm_ex = [&]() {
-        auto zero = handle->zero_device();
-        auto one = handle->one_device();
-        //! These two arguments for future use, see
-        //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp
-        int32_t solution_index = 0;
-        uint32_t flags = 1;
-        size_t ws_size = 0;
-        auto gemm_ex_err = rocblas_gemm_ex(
-                rocblas_handle_,
-                param().transposeB ? rocblas_operation_transpose
-                                   : rocblas_operation_none,
-                param().transposeA ? rocblas_operation_transpose
-                                   : rocblas_operation_none,
-                n, m, k, one, B.raw_ptr, rocblas_datatype_f16_r,
-                B.layout.stride[0], A.raw_ptr, rocblas_datatype_f16_r,
-                A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_f16_r,
-                C.layout.stride[0], C.raw_ptr, rocblas_datatype_f16_r,
-                C.layout.stride[0], rocblas_datatype_f32_r,
-                rocblas_gemm_algo_standard, solution_index, flags, &ws_size,
-                nullptr);
-        rocblas_check(gemm_ex_err);
-    };
-
-    auto hgemm = [&]() {
-        auto one_half = handle->one_device_h();
-        auto zero_half = handle->zero_device_h();
-        auto hgemm_err = rocblas_hgemm(
-                rocblas_handle_,
-                param().transposeB ? rocblas_operation_transpose
-                                   : rocblas_operation_none,
-                param().transposeA ? rocblas_operation_transpose
-                                   : rocblas_operation_none,
-                n, m, k, reinterpret_cast<const rocblas_half*>(one_half),
-                static_cast<const rocblas_half*>(B.raw_ptr), B.layout.stride[0],
-                static_cast<const rocblas_half*>(A.raw_ptr), A.layout.stride[0],
-                reinterpret_cast<const rocblas_half*>(zero_half),
-                static_cast<rocblas_half*>(C.raw_ptr), C.layout.stride[0]);
-        rocblas_check(hgemm_err);
-    };
-#endif
+std::vector<Algorithm*>
+MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
+                                         const TensorLayout& B,
+                                         const TensorLayout& C) {
+    AlgoBase::SizeArgs args{this, A, B, C};
+    return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args);
+}
 
-    if (param().compute_mode == Param::ComputeMode::DEFAULT) {
-        if (A.layout.dtype == dtype::Float32()) {
-            sgemm();
-        }
-#if !MEGDNN_DISABLE_FLOAT16
-        else {
-            megdnn_assert(A.layout.dtype == dtype::Float16(),
-                          "invalid matmul data type");
-            hgemm();
-        }
-#endif
+MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
+        const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
+        size_t workspace_limit_in_bytes, bool reproducible) {
+    AlgoBase::SizeArgs args{this, A, B, C};
+    if (sm_algo_pack.blas.is_available_reproducible(
+                args, reproducible, workspace_limit_in_bytes)) {
+        return &sm_algo_pack.blas;
     }
-#if !MEGDNN_DISABLE_FLOAT16
-    else if (param().compute_mode == Param::ComputeMode::FLOAT32) {
-        megdnn_assert(B.layout.dtype == dtype::Float16() &&
-                              C.layout.dtype == dtype::Float16() &&
-                              A.layout.dtype == dtype::Float16(),
-                      "DataType::FLOAT_IO16xC32 is supported, when dtype of A, "
-                      "B, C are all Float16");
-        gemm_ex();
-    }
-#endif
-    else if (A.layout.dtype == dtype::Int8() &&
-             B.layout.dtype == dtype::Int8() &&
-             C.layout.dtype == dtype::Int32()) {
-        //! see
-        //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp:470
-        bool rocblas_int8x8x32_valid = true;
-        rocblas_int8x8x32_valid &= (k % 4 == 0);
-        rocblas_int8x8x32_valid &=
-                (!param().transposeB || B.layout.stride[0] % 4 == 0);
-        rocblas_int8x8x32_valid &=
-                (!param().transposeA || A.layout.stride[0] % 4 == 0);
-        megdnn_assert(rocblas_int8x8x32_valid,
-                      "rocblas int8x8x32 matmul requires K must be a multiple "
-                      "of 4, and/or LDA/LDB based on transpose mode"
-                      "get: %zu, is_trans_b = %d, %zu, is_trans_a = %d, %zu",
-                      k, param().transposeB, B.layout.stride[0],
-                      param().transposeA, A.layout.stride[0]);
-        int32_t solution_index = 0;
-        uint32_t flags = 1;
-        size_t ws_size = 0;
-        auto zero = handle->zero_device_i32();
-        auto one = handle->one_device_i32();
-        rocblas_check(rocblas_gemm_ex(
-                rocblas_handle_,
-                param().transposeB ? rocblas_operation_transpose
-                                   : rocblas_operation_none,
-                param().transposeA ? rocblas_operation_transpose
-                                   : rocblas_operation_none,
-                n, m, k, one, B.raw_ptr, rocblas_datatype_i8_r,
-                B.layout.stride[0], A.raw_ptr, rocblas_datatype_i8_r,
-                A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_i32_r,
-                C.layout.stride[0], C.raw_ptr, rocblas_datatype_i32_r,
-                C.layout.stride[0], rocblas_datatype_i32_r,
-                rocblas_gemm_algo_standard, solution_index, flags, &ws_size,
-                nullptr));
+    if (reproducible) {
+        return megdnn::get_reproducible_algo<MatrixMulForwardImpl>(
+                sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
+                "matrix mul forward");
     } else {
-        megdnn_assert((A.layout.dtype == dtype::Int8() &&
-                       B.layout.dtype == dtype::Int8() &&
-                       C.layout.dtype == dtype::Int16()),
-                      "invalid matmul data type");
-        megdnn_throw("cuda matmul does not support INT8x8x16 now");
+        return megdnn::get_usable_algo<MatrixMulForwardImpl>(
+                sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
+                "matrix mul forward");
     }
 }
 
-} // namespace rocm
-} // namespace megdnn
+size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A,
+                                                    const TensorLayout& B,
+                                                    const TensorLayout& C) {
+    AlgoBase::SizeArgs args{this, A, B, C};
+    return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args);
+}
+
+void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
+                                _megdnn_tensor_out C,
+                                _megdnn_workspace workspace) {
+    check_exec(A.layout, B.layout, C.layout, workspace.size);
+    AlgoBase::ExecArgs args(this, A, B, C, workspace);
+    auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout);
+    algo->check_workspace(args, workspace).exec(args);
+}
 
 // vim: syntax=cpp.doxygen
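The new get_algorithm_heuristic prefers the BLAS algorithm whenever it is usable under the reproducibility and workspace constraints, and only then falls back to scanning all registered algorithms. A minimal model of that priority scheme (hypothetical types; the real helpers are megdnn::get_reproducible_algo and megdnn::get_usable_algo):

    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    struct Algo {
        bool available;
        bool reproducible;
        std::size_t workspace;
    };

    // Mirrors AlgoBase::is_available_reproducible.
    bool usable(const Algo& a, bool need_reproducible, std::size_t limit) {
        return a.available && (!need_reproducible || a.reproducible) &&
               a.workspace <= limit;
    }

    // Preferred algorithm first, then the first acceptable one in
    // registration order; throw if nothing fits, as megdnn_assert would.
    const Algo* heuristic(const Algo& preferred, const std::vector<Algo>& all,
                          bool need_reproducible, std::size_t limit) {
        if (usable(preferred, need_reproducible, limit))
            return &preferred;
        for (auto&& a : all)
            if (usable(a, need_reproducible, limit))
                return &a;
        throw std::runtime_error("no usable matrix mul forward algorithm");
    }

    int main() {
        Algo blas{true, true, 0};
        // BLAS needs no workspace, so it wins under any limit.
        return heuristic(blas, {blas}, true, 1024) == &blas ? 0 : 1;
    }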
diff --git a/dnn/src/rocm/matrix_mul/opr_impl.h b/dnn/src/rocm/matrix_mul/opr_impl.h
index 5d8abad4e..ca2a190fa 100644
--- a/dnn/src/rocm/matrix_mul/opr_impl.h
+++ b/dnn/src/rocm/matrix_mul/opr_impl.h
@@ -20,29 +20,32 @@ public:
     void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
               _megdnn_workspace workspace) override;
     size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
-                                  const TensorLayout&) override {
-        return 0;
-    }
+                                  const TensorLayout&) override;
 
     bool is_thread_safe() const override { return true; }
+
+    class AlgoBase;
+    class AlgoBlas;
+    class AlgoPack;
+    static const AlgoPack& algo_pack() { return sm_algo_pack; }
+
+    static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
+
 private:
     std::vector<Algorithm*> get_all_algorithms(
             const TensorLayout& /*A*/, const TensorLayout& /*B*/,
-            const TensorLayout& /*C*/) override {
-        return {};
-    }
+            const TensorLayout& /*C*/) override;
     Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
                                        const TensorLayout& /*B*/,
                                        const TensorLayout& /*C*/,
                                        size_t /*workspace_limit_in_bytes*/,
-                                       bool /*reproducible*/) override {
-        return nullptr;
-    }
-
+                                       bool /*reproducible*/) override;
     const char* get_algorithm_set_name() const override {
         return "ROCM MATMUL";
     }
+
+    static AlgoPack sm_algo_pack;
 };
 
 } // namespace rocm
diff --git a/dnn/test/fallback/matrix_mul.cpp b/dnn/test/fallback/matrix_mul.cpp
index 74b77b29e..92807b9ac 100644
--- a/dnn/test/fallback/matrix_mul.cpp
+++ b/dnn/test/fallback/matrix_mul.cpp
@@ -46,6 +46,37 @@ TEST_F(FALLBACK, MATRIX_MUL) {
     }
 }
 
+TEST_F(FALLBACK, MATRIX_MUL_NAIVE) {
+    Checker<MatrixMul> checker(handle());
+    checker.set_before_exec_callback(AlgoChecker<MatrixMul>("FB_NAIVE"));
+    using Param = MatrixMul::Param;
+    auto args = matrix_mul::get_matmul_args();
+    for (auto arg : args) {
+        auto m = arg.m, n = arg.n, k = arg.k;
+        auto mask = arg.mask;
+        Param param;
+        param.transposeA = mask & 1;
+        param.transposeB = mask & 2;
+        TensorShape AS, BS, CS;
+
+        if (param.transposeA)
+            AS = TensorShape{k, m};
+        else
+            AS = TensorShape{m, k};
+        if (param.transposeB)
+            BS = TensorShape{n, k};
+        else
+            BS = TensorShape{k, n};
+        CS = TensorShape{m, n};
+        TensorLayout AL, BL, CL;
+        AL = TensorLayout(AS, dtype::Float32());
+        BL = TensorLayout(BS, dtype::Float32());
+        CL = TensorLayout(CS, dtype::Float32());
+        checker.set_param(param);
+        checker.execl({AL, BL, CL});
+    }
+}
+
 TEST_F(FALLBACK, BATCHED_MATRIX_MUL) {
     Checker<BatchedMatrixMul> checker(handle());
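The shape setup in the new test follows from the transpose semantics: for C = op(A) * op(B) with C of shape {m, n}, A is stored as {k, m} when transposeA is set and B as {n, k} when transposeB is set. A tiny standalone check of that mapping (hypothetical helper, mirroring the test loop):

    #include <cassert>
    #include <cstddef>

    struct Shape2 {
        std::size_t rows, cols;
    };

    // Storage shapes of A and B for C = op(A) * op(B) with C of shape {m, n}.
    Shape2 shape_of_A(std::size_t m, std::size_t k, bool transposeA) {
        return transposeA ? Shape2{k, m} : Shape2{m, k};
    }
    Shape2 shape_of_B(std::size_t n, std::size_t k, bool transposeB) {
        return transposeB ? Shape2{n, k} : Shape2{k, n};
    }

    int main() {
        // mask = 3 in the test loop -> both operands transposed.
        Shape2 AS = shape_of_A(3, 5, true);
        Shape2 BS = shape_of_B(4, 5, true);
        assert(AS.rows == 5 && AS.cols == 3);
        assert(BS.rows == 4 && BS.cols == 5);
        return 0;
    }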
diff --git a/dnn/test/naive/matrix_mul.cpp b/dnn/test/naive/matrix_mul.cpp
index b708d7f73..a4fe85e96 100644
--- a/dnn/test/naive/matrix_mul.cpp
+++ b/dnn/test/naive/matrix_mul.cpp
@@ -232,7 +232,7 @@ TEST_F(NAIVE, MATRIX_MUL_QUANTIZEDS4_4x4x16) {
                         2, 5, 3, 3, 7, 4, -7, 1,
                         -5, 7, -4, -1, -1, 2, 4, 1,
                         7, 2, -6, -2, -6, 3, 4, 4,
-                        -2, 2, 3, 0, 6, 5, 3, 4,
+                        -2, 2, 3, 0, 6, 5, 3, 4,
                         -1, -1, -5, 5, 2, 5, 1, 4,
                         6, 2, 0, 0, 3, 2, 2, 1,
                         -4, -3, 7, 5, 0, 3, 2, 3}),
@@ -243,7 +243,7 @@
                         3, -1, 2, 2, 7, 3, 6, 0,
                         5, 4, 0, 2, 2, 3, 3, 2,
                         1, -8, -7, -6, 0, -5, -4, 4,
-                        -3, 7, 1, 6, -2, 2, -1, 5,
+                        -3, 7, 1, 6, -2, 2, -1, 5,
                         2, 0, 7, 6, 5, 4, 3, 2,
                         0, 0, 1, 0, 5, 2, 2, 6}),
                 {}},
-- 
GitLab