Commit 409a8772 authored by Megvii Engine Team

feat(dnn): add algo interface for rocm&fallback matmul and batched matrix mul

GitOrigin-RevId: dea03a0f7a3ec436719d9c69570cf827bbb444b5
Parent 27503620
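The files below restructure the fallback and ROCm (batched) matrix-mul implementations around an AlgoBase / AlgoPack / heuristic-selection pattern. For orientation, here is a condensed, self-contained sketch of that pattern; the class and function names are simplified stand-ins drawn from the diff, not the actual megdnn types.

#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>

struct SizeArgs {
    std::size_t batch, m, n, k;
};

struct AlgoBase {
    virtual ~AlgoBase() = default;
    virtual const char* name() const = 0;
    virtual bool is_available(const SizeArgs&) const = 0;
    virtual std::size_t get_workspace_in_bytes(const SizeArgs&) const = 0;
    virtual bool is_reproducible() const { return true; }
    // mirrors AlgoBase::is_available_reproducible in the diff below
    bool is_available_reproducible(const SizeArgs& args, bool reproducible,
                                   std::size_t limit) const {
        return (!reproducible || is_reproducible()) && is_available(args) &&
               get_workspace_in_bytes(args) <= limit;
    }
};

struct AlgoDefault final : AlgoBase {
    const char* name() const override { return "DEFAULT"; }
    bool is_available(const SizeArgs&) const override { return true; }
    std::size_t get_workspace_in_bytes(const SizeArgs&) const override {
        return 0;
    }
};

struct AlgoPack {
    AlgoDefault algo_default;
    std::vector<AlgoBase*> all_algos{&algo_default};
};

// mirrors get_algorithm_heuristic: try the preferred algo first, then fall
// back to any algo that is usable within the workspace limit
AlgoBase* get_algorithm_heuristic(AlgoPack& pack, const SizeArgs& args,
                                  std::size_t workspace_limit,
                                  bool reproducible) {
    if (pack.algo_default.is_available_reproducible(args, reproducible,
                                                    workspace_limit))
        return &pack.algo_default;
    for (auto* algo : pack.all_algos)
        if (algo->is_available_reproducible(args, reproducible,
                                            workspace_limit))
            return algo;
    return nullptr;
}

int main() {
    AlgoPack pack;
    SizeArgs args{8, 64, 64, 64};
    AlgoBase* algo = get_algorithm_heuristic(
            pack, args, std::numeric_limits<std::size_t>::max(), true);
    std::printf("selected algo: %s\n", algo ? algo->name() : "none");
    return 0;
}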
/**
* \file dnn/src/fallback/batched_matrix_mul/algos.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/fallback/batched_matrix_mul/algos.h"
#include "src/common/algo_base.h"
#include "src/naive/handle.h"
using namespace megdnn;
using namespace fallback;
BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&algo_default);
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
BatchedMatrixMulForwardImpl::AlgoPack BatchedMatrixMulForwardImpl::sm_algo_pack;
MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl)
BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(
BatchedMatrixMulForwardImpl* o, const TensorLayout& A,
const TensorLayout& B, const TensorLayout& C)
: opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {}
BatchedMatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs(
BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A,
_megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace)
: SizeArgs(opr, A.layout, B.layout, C.layout),
tensor_a{A},
tensor_b{B},
tensor_c{C},
workspace{workspace} {}
std::string BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const {
auto&& param = opr->param();
size_t m = layout_a.shape[0], n = layout_b.shape[1],
k = layout_a.shape[param.transposeA ? 0 : 1];
MEGDNN_MARK_USED_VAR(m);
MEGDNN_MARK_USED_VAR(n);
MEGDNN_MARK_USED_VAR(k);
return megdnn_mangle(ssprintf(
"A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose "
"B=%d,ldA=%zu,ldB=%zu,ldC=%zu",
m, k, k, n, m, n, param.transposeA, param.transposeB,
layout_a.stride[0], layout_b.stride[0], layout_c.stride[0]));
}
/* ===================== default algo ===================== */
size_t BatchedMatrixMulForwardImpl::AlgoDefault::get_workspace_in_bytes(
const SizeArgs& args) const {
auto opr = inplace_cpu_handle()->create_operator<MatrixMul>();
auto A_ = args.layout_a.remove_axis(0), B_ = args.layout_b.remove_axis(0),
C_ = args.layout_c.remove_axis(0);
opr->param() = args.opr->param();
return opr->get_workspace_in_bytes(A_, B_, C_);
}
void BatchedMatrixMulForwardImpl::AlgoDefault::exec(
const ExecArgs& args) const {
//! As megbrain may modify the param when checking all transpose situations,
//! we copy the param here before dispatching the kernel
auto param = args.opr->param();
auto kern = [args, param]() {
auto N = args.layout_a.shape[0];
TensorND A_, B_, C_;
A_.raw_ptr = args.tensor_a.raw_ptr;
A_.layout = args.layout_a.remove_axis(0);
B_.raw_ptr = args.tensor_b.raw_ptr;
B_.layout = args.layout_b.remove_axis(0);
C_.raw_ptr = args.tensor_c.raw_ptr;
C_.layout = args.layout_c.remove_axis(0);
auto Astrd = args.layout_a.dtype.size() * args.layout_a.stride[0],
Bstrd = args.layout_b.dtype.size() * args.layout_b.stride[0],
Cstrd = args.layout_c.dtype.size() * args.layout_c.stride[0];
auto advance_ptr = [](TensorND& dest, ptrdiff_t d) {
dest.raw_ptr =
static_cast<void*>(static_cast<dt_byte*>(dest.raw_ptr) + d);
};
auto opr = inplace_cpu_handle()->create_operator<MatrixMul>();
opr->param() = param;
rep(n, N) {
opr->exec(A_, B_, C_, args.workspace);
advance_ptr(A_, Astrd);
advance_ptr(B_, Bstrd);
advance_ptr(C_, Cstrd);
}
};
static_cast<naive::HandleImpl*>(args.opr->handle())->dispatch_kern(kern);
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/batched_matrix_mul/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/fallback/batched_matrix_mul/opr_impl.h"
#include <memory>
#include <unordered_map>
namespace megdnn {
namespace fallback {
/*!
* \brief base class for matrix mul algos
*
*/
class BatchedMatrixMulForwardImpl::AlgoBase : public Algorithm {
protected:
~AlgoBase() = default;
public:
enum class AlgoType : uint32_t {
fallback_BLAS,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::FALLBACK; }
struct SizeArgs {
BatchedMatrixMulForwardImpl* opr;
TensorLayout layout_a, layout_b, layout_c;
std::string to_string() const;
SizeArgs(BatchedMatrixMulForwardImpl* opr, const TensorLayout& A,
const TensorLayout& B, const TensorLayout& C);
};
struct ExecArgs : public SizeArgs {
TensorND tensor_a, tensor_b, tensor_c;
Workspace workspace;
ExecArgs(BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A,
_megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs& args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void exec(const ExecArgs&) const = 0;
bool is_available_wk(const SizeArgs& args, size_t limit) const {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) const {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(
req <= workspace.size,
"matrix mul fwd algo %s: required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
}
};
class BatchedMatrixMulForwardImpl::AlgoDefault final : public AlgoBase {
public:
AlgoDefault() = default;
bool is_available(const SizeArgs&) const override { return true; }
size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override;
const char* name() const override { return "DEFAULT"; }
virtual void exec(const ExecArgs&) const override;
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(fallback_BLAS)
};
class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack();
AlgoDefault algo_default;
std::vector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -6,67 +6,61 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./opr_impl.h"
#include "src/naive/handle.h"
#include "./algos.h"
#include "hcc_detail/hcc_defs_prologue.h"
#include "src/common/algo_chooser.h"
#include "src/common/utils.cuh"
#include "src/fallback/handle.h"
using namespace megdnn;
using namespace fallback;
BatchedMatrixMulImpl::BatchedMatrixMulImpl(Handle *handle):
BatchedMatrixMulForwardImpl(handle),
m_storage(new CpuOprDelegationStorage<>),
m_opr(m_storage->get<MatrixMul>())
{
std::vector<BatchedMatrixMulForwardImpl::Algorithm*>
BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) {
AlgoBase::SizeArgs args{this, A, B, C};
return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args);
}
size_t BatchedMatrixMulImpl::get_workspace_in_bytes(
const TensorLayout &A, const TensorLayout &B,
const TensorLayout &C) {
auto A_ = A.remove_axis(0), B_ = B.remove_axis(0), C_ = C.remove_axis(0);
m_opr->param() = param();
return m_opr->get_workspace_in_bytes(A_, B_, C_);
BatchedMatrixMulForwardImpl::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, bool reproducible) {
AlgoBase::SizeArgs args{this, A, B, C};
if (sm_algo_pack.algo_default.is_available_reproducible(
args, reproducible, workspace_limit_in_bytes)) {
return &sm_algo_pack.algo_default;
}
if (reproducible) {
return megdnn::get_reproducible_algo<BatchedMatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"batched matrix mul forward");
} else {
return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"batched matrix mul forward");
}
}
void BatchedMatrixMulImpl::exec(_megdnn_tensor_in A,
_megdnn_tensor_in B,
size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) {
AlgoBase::SizeArgs args{this, A, B, C};
return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args);
}
void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
_megdnn_tensor_out C,
_megdnn_workspace workspace) {
check_exec(A.layout, B.layout, C.layout, workspace.size);
m_opr->param() = this->param();
auto kern = [this, A, B, C, workspace]() {
auto N = A.layout.shape[0];
TensorND A_, B_, C_;
A_.raw_ptr = A.raw_ptr;
A_.layout = A.layout.remove_axis(0);
B_.raw_ptr = B.raw_ptr;
B_.layout = B.layout.remove_axis(0);
C_.raw_ptr = C.raw_ptr;
C_.layout = C.layout.remove_axis(0);
auto Astrd = A.layout.dtype.size() * A.layout.stride[0],
Bstrd = B.layout.dtype.size() * B.layout.stride[0],
Cstrd = C.layout.dtype.size() * C.layout.stride[0];
auto advance_ptr = [](TensorND &dest, ptrdiff_t d) {
dest.raw_ptr = static_cast<void*>(
static_cast<dt_byte*>(dest.raw_ptr) + d);
};
rep(n, N) {
m_opr->exec(A_, B_, C_, workspace);
advance_ptr(A_, Astrd);
advance_ptr(B_, Bstrd);
advance_ptr(C_, Cstrd);
}
};
static_cast<naive::HandleImpl*>(handle())->dispatch_kern(kern);
AlgoBase::ExecArgs args(this, A, B, C, workspace);
auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout);
algo->check_workspace(args, workspace).exec(args);
}
// vim: syntax=cpp.doxygen
......@@ -15,26 +15,42 @@
namespace megdnn {
namespace fallback {
class BatchedMatrixMulImpl: public naive::BatchedMatrixMulForwardImpl {
public:
BatchedMatrixMulImpl(Handle *handle);
void exec(
_megdnn_tensor_in A,
_megdnn_tensor_in B,
_megdnn_tensor_out C,
class BatchedMatrixMulForwardImpl: public naive::BatchedMatrixMulForwardImpl {
public:
using naive::BatchedMatrixMulForwardImpl::BatchedMatrixMulForwardImpl;
void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override;
size_t get_workspace_in_bytes(const TensorLayout &A,
const TensorLayout &B,
const TensorLayout &C) override;
bool is_thread_safe() const override { return true; }
private:
std::unique_ptr<CpuOprDelegationStorage<>> m_storage;
MatrixMulForward* m_opr;
class AlgoBase;
class AlgoDefault;
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
private:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override;
const char* get_algorithm_set_name() const override {
return "FALLBACK BATCHED MATMUL";
}
static AlgoPack sm_algo_pack;
};
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -473,6 +473,13 @@ public:
PostprocessMode::NO_PROCESS,
"NoPackStrategyType::FLOAT16_FLOAT16"_hash);
break;
#endif
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case StrategyType::FLOAT_FP16:
cb1(NCHW, NO_PACK, dt_float16, __fp16,
PostprocessMode::NO_PROCESS,
"NoPackStrategyType::FLOAT_FP16"_hash);
break;
#endif
case StrategyType::INT8x8x16:
cb3(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8,
......
......@@ -169,6 +169,10 @@ INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
megdnn::PostprocessMode::NO_PROCESS)
#endif
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16,
megdnn::PostprocessMode::NO_PROCESS)
#endif
#undef INSTANTIAL_CLASS
} // namespace megdnn
......
......@@ -67,7 +67,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(ElemwiseMultiType)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AddUpdate)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskConvForward)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Resize)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMul)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PowC)
......
......@@ -10,13 +10,18 @@
*/
#include "src/fallback/matrix_mul/algos.h"
#include "megdnn/opr_param_defs.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
#include "src/fallback/matrix_mul/gemv.h"
#include "src/fallback/matrix_mul/generic_strategy.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"
#include "midout.h"
MIDOUT_DECL(megdnn_fb_matmul_f32_kern)
MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like)
MIDOUT_DECL(megdnn_fb_matmul_naive)
using namespace megdnn;
using namespace fallback;
......@@ -39,6 +44,32 @@ void f32_8x12x1_kern(const MatrixMulImpl::KernParam& kern_param) {
}
MIDOUT_END();
}
void kern_naive(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_fb_matmul_naive, void) {
size_t M = kern_param.M, N = kern_param.N, K = kern_param.K;
size_t LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
#define DISPATCH(TA, TB) \
if (kern_param.trA == TA && kern_param.trB == TB) { \
naive::dispatch_ta_tb<TA, TB>( \
kern_param.A_ptr, kern_param.B_ptr, kern_param.C_ptr, \
kern_param.workspace_ptr, M, N, K, LDA, LDB, LDC, \
kern_param.A_type, kern_param.B_type, kern_param.C_type, \
kern_param.format, kern_param.compute_mode); \
return; \
}
DISPATCH(true, true);
DISPATCH(true, false);
DISPATCH(false, true);
DISPATCH(false, false);
#undef DISPATCH
megdnn_assert_internal(0);
}
MIDOUT_END();
}
} // anonymous namespace
////////////////////// AlgoF32K8x12x1 ///////////////////////////
......@@ -84,7 +115,10 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern,
bool MatrixMulImpl::AlgoGemv::usable(
const KernSizeParam& kern_size_param) const {
return !kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
kern_size_param.format ==
param::MatrixMul::Format::DEFAULT &&
kern_size_param.compute_mode ==
param::MatrixMul::ComputeMode::DEFAULT &&
!((kern_size_param.A_type.enumv() ==
kern_size_param.B_type.enumv()) &&
(kern_size_param.A_type.enumv() == DTypeEnum::Int16) &&
......@@ -128,4 +162,44 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoGemv::get_kern(
megdnn_assert(0);
}
/* ===================== naive algo ===================== */
bool MatrixMulImpl::AlgoNaive::usable(const KernSizeParam&) const {
return true;
}
bool MatrixMulImpl::AlgoNaive::preferred(const KernSizeParam&) const {
return false;
}
size_t MatrixMulImpl::AlgoNaive::get_workspace(
const KernSizeParam& kern_param) const {
MIDOUT_BEGIN(
megdnn_fb_matmul_naive,
midout_iv("MatrixMulForwardImpl::get_workspace_in_bytes"_hash)) {
if (kern_param.A_type.enumv() == DTypeEnum::Quantized4Asymm ||
kern_param.A_type.enumv() == DTypeEnum::QuantizedS4) {
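//! the naive 4-bit kernels widen A and B to 8-bit copies held in the
//! workspace, so reserve one byte per (leading-dimension padded) element
//! of each operand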
size_t ret = 0;
if (kern_param.trA) {
ret += kern_param.LDA * kern_param.K;
} else {
ret += kern_param.LDA * kern_param.M;
}
if (kern_param.trB) {
ret += kern_param.LDB * kern_param.N;
} else {
ret += kern_param.LDB * kern_param.K;
}
return ret;
}
return 0;
}
MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoNaive::get_kern(
const KernSizeParam&) const {
return kern_naive;
}
// vim: syntax=cpp.doxygen
......@@ -52,6 +52,28 @@ public:
DEFAULT)
};
class MatrixMulImpl::AlgoNaive final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "FB_NAIVE"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMM; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_DECL_ALGO_TYPE(FB_NAIVE)
MEGDNN_OVERRIDE_MATMUL_DESC(
8, 16, 1, 4,
static_cast<AlgoDataType>(
static_cast<uint32_t>(AlgoDataType::FLOAT16) |
static_cast<uint32_t>(AlgoDataType::FLOAT32) |
static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)),
DEFAULT)
};
} // namespace fallback
} // namespace megdnn
......
......@@ -35,6 +35,7 @@ using namespace fallback;
class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoF32K8x12x1 f32_k8x12x1;
AlgoGemv gemv;
AlgoNaive naive;
SmallVector<AlgoBase*> m_all_algos;
AlgoBase::Mapper m_all_algos_map;
......@@ -42,6 +43,7 @@ public:
AlgoPack() {
m_all_algos.emplace_back(&gemv);
m_all_algos.emplace_back(&f32_k8x12x1);
m_all_algos.emplace_back(&naive);
for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
......@@ -147,19 +149,26 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic(
algo_type.format = kern_size_param.format;
auto algos = select_algo_type(algo_type);
Algorithm *heuristic_algo = nullptr;
Algorithm *usable_algo = nullptr;
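//! selection order: a usable and preferred GEMV algo wins immediately;
//! otherwise the first preferred algo is remembered, and the first merely
//! usable algo is kept as a last resort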
for (auto&& algo : algos) {
if (static_cast<AlgoBase*>(algo)->usable(kern_size_param) &&
static_cast<AlgoBase*>(algo)->preferred_reproducible(
kern_size_param, reproducible) &&
static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <=
workspace_limit_in_bytes) {
if (static_cast<AlgoBase*>(algo)->preferred_reproducible(
kern_size_param, reproducible)) {
//! use the gemv algo if it is preferred
if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
return algo;
} else if (!heuristic_algo) {
heuristic_algo = algo;
}
} else if (!usable_algo) {
usable_algo = algo;
}
}
}
if (!heuristic_algo) heuristic_algo = usable_algo;
megdnn_assert(heuristic_algo, "No usable algorithm found");
return heuristic_algo;
}
......
......@@ -110,6 +110,7 @@ public:
//! fallback
FB_F32K8x12x1 = 1 << 0,
FB_GEMV,
FB_NAIVE,
#if MEGDNN_X86
//! x86
......@@ -233,6 +234,7 @@ public:
private:
class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1
class AlgoGemv;
class AlgoNaive;
class AlgoPack;
//! maintain all the algos of the fallback opr
static const AlgoPack& algo_pack();
......
......@@ -141,20 +141,39 @@ void run_matrix_mul_mk8_tpl(const itype* A, const itype* B, otype* C, size_t M,
}
template <bool transA, bool transB>
void exec_matrix_mul_quint4x4x32_helper(_megdnn_tensor_in A,
_megdnn_tensor_in B,
_megdnn_tensor_out C,
_megdnn_workspace workspace,
const param::MatrixMul& param) {
void exec_matrix_mul_quint4x4x32_helper(
const void* A, const void* B, void* C, void* workspace, size_t M,
size_t N, size_t K, ptrdiff_t LDA, ptrdiff_t LDB, ptrdiff_t LDC,
DType A_type, DType B_type, DType C_type,
const MatrixMul::Param::Format& format,
const MatrixMul::Param::ComputeMode& compute_mode) {
MEGDNN_MARK_USED_VAR(C_type);
MEGDNN_MARK_USED_VAR(format);
MEGDNN_MARK_USED_VAR(compute_mode);
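//! 4-bit elements cannot be addressed individually, so A and B are widened
//! to 8-bit copies placed in the workspace and the regular 8-bit kernel
//! then runs on those copies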
auto convert_layout = [](const TensorLayout& layout) {
auto ret = layout;
auto param = layout.dtype.param<dtype::Quantized4Asymm>();
ret.dtype = dtype::Quantized8Asymm(param.scale, param.zero_point);
return ret;
};
TensorND nA = {workspace.raw_ptr, convert_layout(A.layout)};
TensorND nB = {workspace.raw_ptr + nA.layout.span().dist_byte(),
convert_layout(B.layout)};
TensorLayout A_layout, B_layout;
if (transA) {
A_layout = TensorLayout({K, M}, {LDA, 1}, A_type);
} else {
A_layout = TensorLayout({M, K}, {LDA, 1}, A_type);
}
if (transB) {
B_layout = TensorLayout({N, K}, {LDB, 1}, B_type);
} else {
B_layout = TensorLayout({K, N}, {LDB, 1}, B_type);
}
TensorND tensorA{const_cast<void*>(A), A_layout};
TensorND tensorB{const_cast<void*>(B), B_layout};
TensorND nA = {workspace, convert_layout(A_layout)};
TensorND nB = {
static_cast<uint8_t*>(workspace) + nA.layout.span().dist_byte(),
convert_layout(B_layout)};
auto convert_4to8 = [](const TensorND& in, const TensorND& out) {
auto ptr =
static_cast<uint8_t*>(in.raw_ptr) + in.layout.span().low_byte;
......@@ -168,31 +187,48 @@ void exec_matrix_mul_quint4x4x32_helper(_megdnn_tensor_in A,
out_ptr[i + 1] = val1;
}
};
convert_4to8(A, nA);
convert_4to8(B, nB);
auto M = C.layout.shape[0], N = C.layout.shape[1];
auto K = A.layout.shape[param.transposeA ? 0 : 1];
auto LDA = A.layout.stride[0], LDB = B.layout.stride[0],
LDC = C.layout.stride[0];
convert_4to8(tensorA, nA);
convert_4to8(tensorB, nB);
run_matrix_mul_tpl<uint8_t, dt_int32, transA, transB, dt_int32>(
nA.compatible_ptr<uint8_t>(), nB.compatible_ptr<uint8_t>(),
C.compatible_ptr<dt_int32>(), M, N, K, LDA, LDB, LDC,
nA.layout.dtype, nB.layout.dtype);
static_cast<dt_int32*>(C), M, N, K, LDA, LDB, LDC, nA.layout.dtype,
nB.layout.dtype);
}
template <bool transA, bool transB>
void exec_matrix_mul_qint4x4x16_helper(_megdnn_tensor_in A, _megdnn_tensor_in B,
_megdnn_tensor_out C,
_megdnn_workspace workspace,
const param::MatrixMul& param) {
void exec_matrix_mul_qint4x4x16_helper(
const void* A, const void* B, void* C, void* workspace, size_t M,
size_t N, size_t K, ptrdiff_t LDA, ptrdiff_t LDB, ptrdiff_t LDC,
DType A_type, DType B_type, DType C_type,
const MatrixMul::Param::Format& format,
const MatrixMul::Param::ComputeMode& compute_mode) {
MEGDNN_MARK_USED_VAR(C_type);
MEGDNN_MARK_USED_VAR(format);
MEGDNN_MARK_USED_VAR(compute_mode);
auto convert_layout = [](const TensorLayout& layout) {
auto ret = layout;
auto param = layout.dtype.param<dtype::QuantizedS4>();
ret.dtype = dtype::QuantizedS8(param.scale);
return ret;
};
TensorND nA = {workspace.raw_ptr, convert_layout(A.layout)};
TensorND nB = {workspace.raw_ptr + nA.layout.span().dist_byte(),
convert_layout(B.layout)};
TensorLayout A_layout, B_layout;
if (transA) {
A_layout = TensorLayout({K, M}, {LDA, 1}, A_type);
} else {
A_layout = TensorLayout({M, K}, {LDA, 1}, A_type);
}
if (transB) {
B_layout = TensorLayout({N, K}, {LDB, 1}, B_type);
} else {
B_layout = TensorLayout({K, N}, {LDB, 1}, B_type);
}
TensorND tensorA{const_cast<void*>(A), A_layout};
TensorND tensorB{const_cast<void*>(B), B_layout};
TensorND nA = {workspace, convert_layout(A_layout)};
TensorND nB = {
static_cast<uint8_t*>(workspace) + nA.layout.span().dist_byte(),
convert_layout(B_layout)};
auto convert_4to8 = [](const TensorND& in, const TensorND& out) {
auto ptr = static_cast<int8_t*>(in.raw_ptr) + in.layout.span().low_byte;
auto out_ptr =
......@@ -204,18 +240,98 @@ void exec_matrix_mul_qint4x4x16_helper(_megdnn_tensor_in A, _megdnn_tensor_in B,
out_ptr[i + 1] = cur >> 4;
}
};
convert_4to8(A, nA);
convert_4to8(B, nB);
auto M = C.layout.shape[0], N = C.layout.shape[1];
auto K = A.layout.shape[param.transposeA ? 0 : 1];
auto LDA = A.layout.stride[0], LDB = B.layout.stride[0],
LDC = C.layout.stride[0];
convert_4to8(tensorA, nA);
convert_4to8(tensorB, nB);
run_matrix_mul_tpl<int8_t, dt_int16, transA, transB, dt_int16>(
nA.compatible_ptr<int8_t>(), nB.compatible_ptr<int8_t>(),
C.compatible_ptr<dt_int16>(), M, N, K, LDA, LDB, LDC,
nA.layout.dtype, nB.layout.dtype);
static_cast<dt_int16*>(C), M, N, K, LDA, LDB, LDC, nA.layout.dtype,
nB.layout.dtype);
}
template <bool TA, bool TB>
void dispatch_ta_tb(const void* A, const void* B, void* C, void* workspace,
size_t M, size_t N, size_t K, ptrdiff_t LDA, ptrdiff_t LDB,
ptrdiff_t LDC, DType A_type, DType B_type, DType C_type,
const MatrixMul::Param::Format& format,
const MatrixMul::Param::ComputeMode& compute_mode) {
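//! cb() instantiates the kernel template that matches the storage format;
//! every branch returns, so reaching the megdnn_throw at the bottom means the
//! dtype/format combination is unsupported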
#define cb(_itype, _otype, _comp_type) \
if (format == param::MatrixMul::Format::DEFAULT) { \
return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \
static_cast<const _itype*>(A), static_cast<const _itype*>(B), \
static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \
B_type); \
} else if (format == param::MatrixMul::Format::MK4) { \
return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \
static_cast<const _itype*>(A), static_cast<const _itype*>(B), \
static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \
B_type); \
} else if (format == param::MatrixMul::Format::MK4_DOT) { \
return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \
static_cast<const _itype*>(A), static_cast<const _itype*>(B), \
static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \
B_type); \
} else if (format == param::MatrixMul::Format::MK8) { \
return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \
static_cast<const _itype*>(A), static_cast<const _itype*>(B), \
static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \
B_type); \
}
if (A_type == dtype::Float32()) {
cb(dt_float32, dt_float32, dt_float32);
#if !MEGDNN_DISABLE_FLOAT16
} else if (A_type == dtype::Float16()) {
using Param = MatrixMul::Param;
if (compute_mode == Param::ComputeMode::DEFAULT) {
cb(dt_float16, dt_float16, dt_float16);
} else if (compute_mode == Param::ComputeMode::FLOAT32) {
cb(dt_float16, dt_float16, dt_float32);
}
} else if (A_type == dtype::BFloat16()) {
using Param = MatrixMul::Param;
if (compute_mode == Param::ComputeMode::DEFAULT) {
cb(dt_bfloat16, dt_bfloat16, dt_bfloat16);
} else if (compute_mode == Param::ComputeMode::FLOAT32) {
cb(dt_bfloat16, dt_bfloat16, dt_float32);
}
#endif
} else if (A_type == dtype::Int8() &&
C_type == dtype::Int16()) {
cb(dt_int8, dt_int16, dt_int16);
} else if (A_type == dtype::Int16() &&
C_type == dtype::Int32()) {
cb(dt_int16, dt_int32, dt_int32);
} else if ((A_type == dtype::Int8() ||
A_type.enumv() == DTypeEnum::QuantizedS8) &&
(C_type == dtype::Int32() ||
C_type.enumv() == DTypeEnum::QuantizedS32)) {
cb(dt_int8, dt_int32, dt_int32);
} else if (A_type.enumv() == DTypeEnum::Quantized8Asymm &&
C_type.enumv() == DTypeEnum::QuantizedS32) {
cb(uint8_t, dt_int32, dt_int32);
} else if (A_type.enumv() == DTypeEnum::Quantized4Asymm &&
C_type.enumv() == DTypeEnum::QuantizedS32 &&
format == param::MatrixMul::Format::DEFAULT) {
exec_matrix_mul_quint4x4x32_helper<TA, TB>(
A, B, C, workspace, M, N, K, LDA, LDB, LDC, A_type, B_type,
C_type, format, compute_mode);
return;
} else if (A_type.enumv() == DTypeEnum::QuantizedS4 &&
C_type.enumv() == DTypeEnum::QuantizedS16 &&
format == param::MatrixMul::Format::DEFAULT) {
exec_matrix_mul_qint4x4x16_helper<TA, TB>(
A, B, C, workspace, M, N, K, LDA, LDB, LDC, A_type, B_type,
C_type, format, compute_mode);
return;
}
#undef cb
megdnn_throw(
ssprintf("unsupported naive MatrixMul(%s, %s) -> %s (cmode = %d)",
A_type.name(), B_type.name(), C_type.name(),
static_cast<int>(compute_mode)));
}
} // namespace naive
} // namespace megdnn
......
......@@ -45,77 +45,10 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B,
auto LDA = A.layout.stride[0], LDB = B.layout.stride[0],
LDC = C.layout.stride[0];
#define cb(_itype, _otype, _comp_type) \
if (param.format == param::MatrixMul::Format::DEFAULT) { \
return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} else if (param.format == param::MatrixMul::Format::MK4) { \
return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} else if (param.format == param::MatrixMul::Format::MK4_DOT) { \
return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} else if (param.format == param::MatrixMul::Format::MK8) { \
return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
}
if (A.layout.dtype == dtype::Float32()) {
cb(dt_float32, dt_float32, dt_float32);
#if !MEGDNN_DISABLE_FLOAT16
} else if (A.layout.dtype == dtype::Float16()) {
using Param = MatrixMul::Param;
if (param.compute_mode == Param::ComputeMode::DEFAULT) {
cb(dt_float16, dt_float16, dt_float16);
} else if (param.compute_mode == Param::ComputeMode::FLOAT32) {
cb(dt_float16, dt_float16, dt_float32);
}
} else if (A.layout.dtype == dtype::BFloat16()) {
using Param = MatrixMul::Param;
if (param.compute_mode == Param::ComputeMode::DEFAULT) {
cb(dt_bfloat16, dt_bfloat16, dt_bfloat16);
} else if (param.compute_mode == Param::ComputeMode::FLOAT32) {
cb(dt_bfloat16, dt_bfloat16, dt_float32);
}
#endif
} else if (A.layout.dtype == dtype::Int8() &&
C.layout.dtype == dtype::Int16()) {
cb(dt_int8, dt_int16, dt_int16);
} else if (A.layout.dtype == dtype::Int16() &&
C.layout.dtype == dtype::Int32()) {
cb(dt_int16, dt_int32, dt_int32);
} else if ((A.layout.dtype == dtype::Int8() ||
A.layout.dtype.enumv() == DTypeEnum::QuantizedS8) &&
(C.layout.dtype == dtype::Int32() ||
C.layout.dtype.enumv() == DTypeEnum::QuantizedS32)) {
cb(dt_int8, dt_int32, dt_int32);
} else if (A.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm &&
C.layout.dtype.enumv() == DTypeEnum::QuantizedS32) {
cb(uint8_t, dt_int32, dt_int32);
} else if (A.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
C.layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&
param.format == param::MatrixMul::Format::DEFAULT) {
exec_matrix_mul_quint4x4x32_helper<TA, TB>(A, B, C, workspace, param);
return;
} else if (A.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&
C.layout.dtype.enumv() == DTypeEnum::QuantizedS16 &&
param.format == param::MatrixMul::Format::DEFAULT) {
exec_matrix_mul_qint4x4x16_helper<TA, TB>(A, B, C, workspace, param);
return;
}
#undef cb
megdnn_throw(ssprintf(
"unsupported naive MatrixMul(%s, %s) -> %s (cmode = %d)",
A.layout.dtype.name(), B.layout.dtype.name(), C.layout.dtype.name(),
static_cast<int>(param.compute_mode)));
dispatch_ta_tb<TA, TB>(A.raw_ptr, B.raw_ptr, C.raw_ptr, workspace.raw_ptr,
M, N, K, LDA, LDB, LDC, A.layout.dtype,
B.layout.dtype, C.layout.dtype, param.format,
param.compute_mode);
}
void MatrixMulForwardImpl::exec_internal(_megdnn_tensor_in A,
......
/**
* \file dnn/src/rocm/batched_matrix_mul/algos.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/rocm/batched_matrix_mul/algos.h"
#include "src/common/algo_base.h"
using namespace megdnn;
using namespace rocm;
BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&blas);
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
BatchedMatrixMulForwardImpl::AlgoPack BatchedMatrixMulForwardImpl::sm_algo_pack;
MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl)
BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(
BatchedMatrixMulForwardImpl* o, const TensorLayout& A,
const TensorLayout& B, const TensorLayout& C)
: opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {}
BatchedMatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs(
BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A,
_megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace)
: SizeArgs(opr, A.layout, B.layout, C.layout),
tensor_a{A},
tensor_b{B},
tensor_c{C},
workspace{workspace} {}
std::string BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const {
auto&& param = opr->param();
size_t m = layout_a.shape[0], n = layout_b.shape[1],
k = layout_a.shape[param.transposeA ? 0 : 1];
MEGDNN_MARK_USED_VAR(m);
MEGDNN_MARK_USED_VAR(n);
MEGDNN_MARK_USED_VAR(k);
return megdnn_mangle(ssprintf(
"A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose "
"B=%d,ldA=%zu,ldB=%zu,ldC=%zu",
m, k, k, n, m, n, param.transposeA, param.transposeB,
layout_a.stride[0], layout_b.stride[0], layout_c.stride[0]));
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/rocm/batched_matrix_mul/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/rocm/batched_matrix_mul/opr_impl.h"
#include <memory>
#include <unordered_map>
namespace megdnn {
namespace rocm {
/*!
* \brief base class for matrix mul algos
*
*/
class BatchedMatrixMulForwardImpl::AlgoBase : public Algorithm {
protected:
~AlgoBase() = default;
public:
enum class AlgoType : uint32_t {
ROCM_BLAS,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; }
struct SizeArgs {
BatchedMatrixMulForwardImpl* opr;
TensorLayout layout_a, layout_b, layout_c;
std::string to_string() const;
SizeArgs(BatchedMatrixMulForwardImpl* opr, const TensorLayout& A,
const TensorLayout& B, const TensorLayout& C);
bool can_be_treated_as_int8x8x32() const {
return layout_a.dtype.enumv() == layout_b.dtype.enumv() &&
(layout_a.dtype.enumv() == DTypeEnum::Int8 ||
layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) &&
(layout_c.dtype.enumv() == DTypeEnum::Int32 ||
layout_c.dtype.enumv() == DTypeEnum::QuantizedS32) &&
opr->param().format == param::MatrixMul::Format::DEFAULT;
}
};
struct ExecArgs : public SizeArgs {
TensorND tensor_a, tensor_b, tensor_c;
Workspace workspace;
ExecArgs(BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A,
_megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs& args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void exec(const ExecArgs& args) const = 0;
bool is_available_wk(const SizeArgs& args, size_t limit) const {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) const {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(
req <= workspace.size,
"matrix mul fwd algo %s: required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
}
};
class BatchedMatrixMulForwardImpl::AlgoBlas final : public AlgoBase {
public:
AlgoBlas() = default;
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override {
return 0_z;
}
const char* name() const override { return "BLAS"; }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS)
};
class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack();
AlgoBlas blas;
std::vector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/rocm/batched_matrix_mul/Blas.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/rocm/batched_matrix_mul/algos.h"
#include "hcc_detail/hcc_defs_prologue.h"
#include "src/rocm/handle.h"
#include "src/rocm/utils.h"
using namespace megdnn;
using namespace rocm;
bool BatchedMatrixMulForwardImpl::AlgoBlas::is_available(
const SizeArgs& args) const {
if (args.opr->param().format != param::MatrixMul::Format::DEFAULT)
return false;
if (args.layout_a.dtype == dtype::Float32() ||
args.layout_a.dtype == dtype::Float16()) {
return true;
}
return false;
}
void BatchedMatrixMulForwardImpl::AlgoBlas::exec(const ExecArgs& args) const {
auto batch = args.layout_a.shape[0];
auto m = args.layout_c.shape[1], n = args.layout_c.shape[2];
auto k = args.layout_a.shape[args.opr->param().transposeA ? 1 : 2];
auto&& handle = concrete_handle(args.opr->handle());
auto rocblas_handle_ = handle->get_rocblas_handle();
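//! rocBLAS assumes column-major storage while megdnn tensors are row-major,
//! so C = A * B is computed as C^T = B^T * A^T: the operands and the m/n
//! dimensions are passed in swapped order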
auto sgemm = [&]() {
auto zero = handle->zero_device();
auto one = handle->one_device();
rocblas_check(rocblas_sgemm_strided_batched(
rocblas_handle_,
args.opr->param().transposeB ? rocblas_operation_transpose
: rocblas_operation_none,
args.opr->param().transposeA ? rocblas_operation_transpose
: rocblas_operation_none,
n, m, k, one, args.tensor_b.ptr<dt_float32>(),
(rocblas_int)(args.layout_b.stride[1]),
(rocblas_int)(args.layout_b.stride[0]),
args.tensor_a.ptr<dt_float32>(),
(rocblas_int)(args.layout_a.stride[1]),
(rocblas_int)(args.layout_a.stride[0]), zero,
args.tensor_c.ptr<dt_float32>(),
(rocblas_int)(args.layout_c.stride[1]),
(rocblas_int)(args.layout_c.stride[0]), (rocblas_int)(batch)));
};
#if !MEGDNN_DISABLE_FLOAT16
//! used for FLOAT_IO16xC32, not tested
auto gemm_ex = [&]() {
auto zero = handle->zero_device();
auto one = handle->one_device();
//! These two arguments are reserved for future use; see
//! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp
int32_t solution_index = 0;
uint32_t flags = 1;
size_t ws_size = 0;
rocblas_check(rocblas_gemm_strided_batched_ex(
rocblas_handle_,
args.opr->param().transposeB ? rocblas_operation_transpose
: rocblas_operation_none,
args.opr->param().transposeA ? rocblas_operation_transpose
: rocblas_operation_none,
n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_i8_r,
args.layout_b.stride[1], args.layout_b.stride[0],
args.tensor_a.raw_ptr, rocblas_datatype_i8_r,
args.layout_a.stride[1], args.layout_a.stride[0], zero,
args.tensor_c.raw_ptr, rocblas_datatype_i32_r,
args.layout_c.stride[1], args.layout_c.stride[0],
args.tensor_c.raw_ptr, rocblas_datatype_i32_r,
args.layout_c.stride[1], args.layout_c.stride[0], batch,
rocblas_datatype_i32_r, rocblas_gemm_algo_standard,
solution_index, flags, &ws_size, nullptr));
MEGDNN_MARK_USED_VAR(ws_size);
};
auto hgemm = [&]() {
auto one_half = handle->one_device_h();
auto zero_half = handle->zero_device_h();
rocblas_check(rocblas_hgemm_strided_batched(
rocblas_handle_,
args.opr->param().transposeB ? rocblas_operation_transpose
: rocblas_operation_none,
args.opr->param().transposeA ? rocblas_operation_transpose
: rocblas_operation_none,
n, m, k, reinterpret_cast<const rocblas_half*>(one_half),
static_cast<const rocblas_half*>(args.tensor_b.raw_ptr),
args.layout_b.stride[1], args.layout_b.stride[0],
static_cast<const rocblas_half*>(args.tensor_a.raw_ptr),
args.layout_a.stride[1], args.layout_a.stride[0],
reinterpret_cast<const rocblas_half*>(zero_half),
static_cast<rocblas_half*>(args.tensor_c.raw_ptr),
args.layout_c.stride[1], args.layout_c.stride[0], batch));
};
#endif
if (args.opr->param().compute_mode == Param::ComputeMode::DEFAULT) {
if (args.layout_a.dtype == dtype::Float32()) {
sgemm();
}
#if !MEGDNN_DISABLE_FLOAT16
else {
megdnn_assert(args.layout_a.dtype == dtype::Float16(),
"invalid matmul data type");
hgemm();
}
#endif
}
#if !MEGDNN_DISABLE_FLOAT16
else if (args.opr->param().compute_mode == Param::ComputeMode::FLOAT32) {
megdnn_assert(args.layout_b.dtype == dtype::Float16() &&
args.layout_c.dtype == dtype::Float16() &&
args.layout_a.dtype == dtype::Float16(),
"DataType::FLOAT_IO16xC32 is supported, when dtype of A, "
"B, C are all Float16");
gemm_ex();
}
#endif
else {
megdnn_throw("Unsupported data_type of matrix mul on rocm.");
}
}
// vim: syntax=cpp.doxygen
......@@ -10,111 +10,58 @@
* implied.
*/
#include "./opr_impl.h"
#include "./algos.h"
#include "hcc_detail/hcc_defs_prologue.h"
#include "src/common/algo_chooser.h"
#include "src/common/utils.cuh"
#include "src/rocm/handle.h"
#include "src/rocm/utils.h"
namespace megdnn {
namespace rocm {
using namespace megdnn;
using namespace rocm;
void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
_megdnn_tensor_out C,
_megdnn_workspace workspace) {
check_exec(A.layout, B.layout, C.layout, workspace.size);
auto dtype = A.layout.dtype;
megdnn_assert(dtype.category() == DTypeCategory::FLOAT &&
param().format == param::MatrixMul::Format::DEFAULT);
if (dtype == dtype::Float32() ||
MEGDNN_FLOAT16_SELECT(dtype == dtype::Float16(), false)) {
auto batch = A.layout.shape[0];
auto m = C.layout.shape[1], n = C.layout.shape[2];
auto k = A.layout.shape[param().transposeA ? 1 : 2];
auto handle = concrete_handle(this->handle());
auto rocblas_handle_ = handle->get_rocblas_handle();
auto io32_c32 = [&]() {
auto zero = handle->zero_device();
auto one = handle->one_device();
rocblas_check(rocblas_sgemm_strided_batched(
rocblas_handle_,
param().transposeB ? rocblas_operation_transpose
: rocblas_operation_none,
param().transposeA ? rocblas_operation_transpose
: rocblas_operation_none,
n, m, k, one, B.ptr<dt_float32>(),
(rocblas_int)(B.layout.stride[1]),
(rocblas_int)(B.layout.stride[0]), A.ptr<dt_float32>(),
(rocblas_int)(A.layout.stride[1]),
(rocblas_int)(A.layout.stride[0]), zero,
C.ptr<dt_float32>(), (rocblas_int)(C.layout.stride[1]),
(rocblas_int)(C.layout.stride[0]), (rocblas_int)(batch)));
};
#if !MEGDNN_DISABLE_FLOAT16
auto io16_c32 = [&]() {
auto zero = handle->zero_device();
auto one = handle->one_device();
int32_t solution_index = 0;
uint32_t flags = 1;
size_t ws_size = 0;
rocblas_check(rocblas_gemm_strided_batched_ex(
rocblas_handle_,
param().transposeB ? rocblas_operation_transpose
: rocblas_operation_none,
param().transposeA ? rocblas_operation_transpose
: rocblas_operation_none,
n, m, k, one, B.raw_ptr, rocblas_datatype_i8_r,
B.layout.stride[1], B.layout.stride[0], A.raw_ptr,
rocblas_datatype_i8_r, A.layout.stride[1],
A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_i32_r,
C.layout.stride[1], C.layout.stride[0], C.raw_ptr,
rocblas_datatype_i32_r, C.layout.stride[1],
C.layout.stride[0], batch, rocblas_datatype_i32_r,
rocblas_gemm_algo_standard, solution_index, flags, &ws_size,
nullptr));
};
auto io16_c16 = [&]() {
auto zero_half = handle->zero_device_h();
auto one_half = handle->one_device_h();
rocblas_check(rocblas_hgemm_strided_batched(
rocblas_handle_,
param().transposeB ? rocblas_operation_transpose
: rocblas_operation_none,
param().transposeA ? rocblas_operation_transpose
: rocblas_operation_none,
n, m, k, reinterpret_cast<const rocblas_half*>(one_half),
static_cast<const rocblas_half*>(B.raw_ptr),
B.layout.stride[1], B.layout.stride[0],
static_cast<const rocblas_half*>(A.raw_ptr),
A.layout.stride[1], A.layout.stride[0],
reinterpret_cast<const rocblas_half*>(zero_half),
static_cast<rocblas_half*>(C.raw_ptr), C.layout.stride[1],
C.layout.stride[0], batch));
};
#endif
std::vector<BatchedMatrixMulForwardImpl::Algorithm*>
BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) {
AlgoBase::SizeArgs args{this, A, B, C};
return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args);
}
if (dtype == dtype::Float32()) {
io32_c32();
BatchedMatrixMulForwardImpl::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, bool reproducible) {
AlgoBase::SizeArgs args{this, A, B, C};
if (sm_algo_pack.blas.is_available_reproducible(args, reproducible,
workspace_limit_in_bytes)) {
return &sm_algo_pack.blas;
}
#if !MEGDNN_DISABLE_FLOAT16
else {
if (param().compute_mode == Param::ComputeMode::FLOAT32) {
io16_c32();
if (reproducible) {
return megdnn::get_reproducible_algo<BatchedMatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"batched matrix mul forward");
} else {
io16_c16();
}
}
#endif
return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"batched matrix mul forward");
}
}
} // namespace rocm
} // namespace megdnn
size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) {
AlgoBase::SizeArgs args{this, A, B, C};
return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args);
}
void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
_megdnn_tensor_out C,
_megdnn_workspace workspace) {
check_exec(A.layout, B.layout, C.layout, workspace.size);
AlgoBase::ExecArgs args(this, A, B, C, workspace);
auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout);
algo->check_workspace(args, workspace).exec(args);
}
// vim: syntax=cpp.doxygen
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
......@@ -17,36 +18,35 @@ namespace rocm {
class BatchedMatrixMulForwardImpl : public BatchedMatrixMulForward {
public:
using BatchedMatrixMulForward::BatchedMatrixMulForward;
BatchedMatrixMulForwardImpl(Handle* handle)
: BatchedMatrixMul(handle),
m_opr(handle->create_operator<MatrixMul>()) {}
void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
const TensorLayout&) override;
bool is_thread_safe() const override { return true; }
class AlgoBase;
class AlgoBlas;
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
private:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override {
return {};
}
const TensorLayout& /*C*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) override {
return nullptr;
}
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
bool /*reproducible*/) override;
bool is_thread_safe() const override { return true; }
const char* get_algorithm_set_name() const override {
return "ROCM BATCHED MATMUL";
}
private:
std::unique_ptr<MatrixMul> m_opr;
static AlgoPack sm_algo_pack;
};
} // namespace rocm
......
/**
* \file dnn/src/rocm/matrix_mul/algos.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/rocm/matrix_mul/algos.h"
#include "src/common/algo_base.h"
using namespace megdnn;
using namespace rocm;
MatrixMulForwardImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&blas);
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack;
MEGDNN_DEF_GET_ALGO_FROM_DESC(MatrixMulForwardImpl)
MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(MatrixMulForwardImpl* o,
const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C)
: opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {}
MatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs(MatrixMulForwardImpl* opr,
_megdnn_tensor_in A,
_megdnn_tensor_in B,
_megdnn_tensor_out C,
_megdnn_workspace workspace)
: SizeArgs(opr, A.layout, B.layout, C.layout),
tensor_a{A},
tensor_b{B},
tensor_c{C},
workspace{workspace} {}
std::string MatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const {
auto&& param = opr->param();
size_t m = layout_a.shape[0], n = layout_b.shape[1],
k = layout_a.shape[param.transposeA ? 0 : 1];
MEGDNN_MARK_USED_VAR(m);
MEGDNN_MARK_USED_VAR(n);
MEGDNN_MARK_USED_VAR(k);
return megdnn_mangle(ssprintf(
"A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose "
"B=%d,ldA=%zu,ldB=%zu,ldC=%zu",
m, k, k, n, m, n, param.transposeA, param.transposeB,
layout_a.stride[0], layout_b.stride[0], layout_c.stride[0]));
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/rocm/matrix_mul/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/rocm/matrix_mul/opr_impl.h"
#include <memory>
#include <unordered_map>
namespace megdnn {
namespace rocm {
/*!
* \brief base class for matrix mul algos
*
*/
class MatrixMulForwardImpl::AlgoBase : public Algorithm {
protected:
~AlgoBase() = default;
public:
enum class AlgoType : uint32_t {
ROCM_BLAS,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; }
struct SizeArgs {
MatrixMulForwardImpl* opr;
TensorLayout layout_a, layout_b, layout_c;
std::string to_string() const;
SizeArgs(MatrixMulForwardImpl* opr, const TensorLayout& A,
const TensorLayout& B, const TensorLayout& C);
bool can_be_treated_as_int8x8x32() const {
return layout_a.dtype.enumv() == layout_b.dtype.enumv() &&
(layout_a.dtype.enumv() == DTypeEnum::Int8 ||
layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) &&
(layout_c.dtype.enumv() == DTypeEnum::Int32 ||
layout_c.dtype.enumv() == DTypeEnum::QuantizedS32) &&
opr->param().format == param::MatrixMul::Format::DEFAULT;
}
};
struct ExecArgs : public SizeArgs {
TensorND tensor_a, tensor_b, tensor_c;
Workspace workspace;
ExecArgs(MatrixMulForwardImpl* opr, _megdnn_tensor_in A,
_megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs& args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void exec(const ExecArgs& args) const = 0;
bool is_available_wk(const SizeArgs& args, size_t limit) const {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) const {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(
req <= workspace.size,
"matrix mul fwd algo %s: required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
}
};
class MatrixMulForwardImpl::AlgoBlas final : public AlgoBase {
public:
AlgoBlas() = default;
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override {
return 0_z;
}
const char* name() const override { return "BLAS"; }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS)
};
class MatrixMulForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack();
AlgoBlas blas;
std::vector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/rocm/matrix_mul/Blas.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/rocm/matrix_mul/algos.h"
#include "hcc_detail/hcc_defs_prologue.h"
#include "src/rocm/handle.h"
#include "src/rocm/utils.h"
using namespace megdnn;
using namespace rocm;
bool MatrixMulForwardImpl::AlgoBlas::is_available(
const SizeArgs& args) const {
if (args.opr->param().format != param::MatrixMul::Format::DEFAULT)
return false;
if (args.layout_a.dtype == dtype::Float32() ||
args.layout_a.dtype == dtype::Float16()) {
return true;
} else if (args.layout_a.dtype.enumv() == DTypeEnum::Int8 ||
args.layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) {
auto k = args.layout_a.shape[args.opr->param().transposeA ? 0 : 1];
//! see
//! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp:470
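//! rocblas_gemm_ex only supports int8x8x32 when k is a multiple of 4 and any
//! transposed operand has a leading dimension that is a multiple of 4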
bool rocblas_int8x8x32_valid = true;
rocblas_int8x8x32_valid &= (k % 4 == 0);
rocblas_int8x8x32_valid &= (!args.opr->param().transposeB ||
args.layout_b.stride[0] % 4 == 0);
rocblas_int8x8x32_valid &= (!args.opr->param().transposeA ||
args.layout_a.stride[0] % 4 == 0);
return rocblas_int8x8x32_valid;
}
return false;
}
void MatrixMulForwardImpl::AlgoBlas::exec(const ExecArgs& args) const {
auto m = args.layout_c.shape[0], n = args.layout_c.shape[1];
auto k = args.layout_a.shape[args.opr->param().transposeA ? 0 : 1];
auto&& handle = concrete_handle(args.opr->handle());
auto rocblas_handle_ = handle->get_rocblas_handle();
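//! as in the batched version, operands and the m/n dimensions are swapped so
//! that row-major megdnn tensors match rocBLAS's column-major convention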
auto sgemm = [&]() {
auto zero = handle->zero_device();
auto one = handle->one_device();
rocblas_check(rocblas_sgemm(
rocblas_handle_,
args.opr->param().transposeB ? rocblas_operation_transpose
: rocblas_operation_none,
args.opr->param().transposeA ? rocblas_operation_transpose
: rocblas_operation_none,
n, m, k, one, args.tensor_b.ptr<dt_float32>(),
args.layout_b.stride[0], args.tensor_a.ptr<dt_float32>(),
args.layout_a.stride[0], zero, args.tensor_c.ptr<dt_float32>(),
args.layout_c.stride[0]));
};
#if !MEGDNN_DISABLE_FLOAT16
//! used for FLOAT_IO16xC32, not tested
auto gemm_ex = [&]() {
auto zero = handle->zero_device();
auto one = handle->one_device();
//! These two arguments are reserved for future use; see
//! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp
int32_t solution_index = 0;
uint32_t flags = 1;
size_t ws_size = 0;
auto gemm_ex_err = rocblas_gemm_ex(
rocblas_handle_,
args.opr->param().transposeB ? rocblas_operation_transpose
: rocblas_operation_none,
args.opr->param().transposeA ? rocblas_operation_transpose
: rocblas_operation_none,
n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_f16_r,
args.layout_b.stride[0], args.tensor_a.raw_ptr,
rocblas_datatype_f16_r, args.layout_a.stride[0], zero,
args.tensor_c.raw_ptr, rocblas_datatype_f16_r,
args.layout_c.stride[0], args.tensor_c.raw_ptr,
rocblas_datatype_f16_r, args.layout_c.stride[0],
rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
solution_index, flags, &ws_size, nullptr);
rocblas_check(gemm_ex_err);
MEGDNN_MARK_USED_VAR(ws_size);
};
auto hgemm = [&]() {
auto one_half = handle->one_device_h();
auto zero_half = handle->zero_device_h();
auto hgemm_err = rocblas_hgemm(
rocblas_handle_,
args.opr->param().transposeB ? rocblas_operation_transpose
: rocblas_operation_none,
args.opr->param().transposeA ? rocblas_operation_transpose
: rocblas_operation_none,
n, m, k, reinterpret_cast<const rocblas_half*>(one_half),
static_cast<const rocblas_half*>(args.tensor_b.raw_ptr),
args.layout_b.stride[0],
static_cast<const rocblas_half*>(args.tensor_a.raw_ptr),
args.layout_a.stride[0],
reinterpret_cast<const rocblas_half*>(zero_half),
static_cast<rocblas_half*>(args.tensor_c.raw_ptr),
args.layout_c.stride[0]);
rocblas_check(hgemm_err);
};
#endif
if (args.opr->param().compute_mode == Param::ComputeMode::DEFAULT) {
if (args.layout_a.dtype == dtype::Float32()) {
sgemm();
}
#if !MEGDNN_DISABLE_FLOAT16
else {
megdnn_assert(args.layout_a.dtype == dtype::Float16(),
"invalid matmul data type");
hgemm();
}
#endif
}
#if !MEGDNN_DISABLE_FLOAT16
else if (args.opr->param().compute_mode == Param::ComputeMode::FLOAT32) {
megdnn_assert(args.layout_b.dtype == dtype::Float16() &&
args.layout_c.dtype == dtype::Float16() &&
args.layout_a.dtype == dtype::Float16(),
"DataType::FLOAT_IO16xC32 is supported, when dtype of A, "
"B, C are all Float16");
gemm_ex();
}
#endif
else {
megdnn_assert(args.can_be_treated_as_int8x8x32());
int32_t solution_index = 0;
uint32_t flags = 1;
size_t ws_size = 0;
auto zero = handle->zero_device_i32();
auto one = handle->one_device_i32();
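        //! int8x8x32: int8 A/B, int32 C, accumulated in int32 via
        //! rocblas_gemm_ex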
rocblas_check(rocblas_gemm_ex(
rocblas_handle_,
args.opr->param().transposeB ? rocblas_operation_transpose
: rocblas_operation_none,
args.opr->param().transposeA ? rocblas_operation_transpose
: rocblas_operation_none,
n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_i8_r,
args.layout_b.stride[0], args.tensor_a.raw_ptr,
rocblas_datatype_i8_r, args.layout_a.stride[0], zero,
args.tensor_c.raw_ptr, rocblas_datatype_i32_r,
args.layout_c.stride[0], args.tensor_c.raw_ptr,
rocblas_datatype_i32_r, args.layout_c.stride[0],
rocblas_datatype_i32_r, rocblas_gemm_algo_standard,
solution_index, flags, &ws_size, nullptr));
MEGDNN_MARK_USED_VAR(ws_size);
}
}
// vim: syntax=cpp.doxygen
@@ -13,147 +13,53 @@
#include "src/rocm/utils.h"
#include "src/rocm/handle.h"
#include "./algos.h"
#include "src/common/algo_chooser.h"
namespace megdnn {
namespace rocm {
using namespace megdnn;
using namespace rocm;
void MatrixMulForwardImpl::exec(_megdnn_tensor_in A,
_megdnn_tensor_in B,
_megdnn_tensor_out C,
_megdnn_workspace workspace)
{
check_exec(A.layout, B.layout, C.layout, workspace.size);
auto m = C.layout.shape[0], n = C.layout.shape[1];
auto k = A.layout.shape[param().transposeA ? 0 : 1];
auto handle = concrete_handle(this->handle());
auto rocblas_handle_ = handle->get_rocblas_handle();
auto sgemm = [&]() {
auto zero = handle->zero_device();
auto one = handle->one_device();
rocblas_check(rocblas_sgemm(
rocblas_handle_,
param().transposeB ? rocblas_operation_transpose
: rocblas_operation_none,
param().transposeA ? rocblas_operation_transpose
: rocblas_operation_none,
n, m, k, one, B.ptr<dt_float32>(), B.layout.stride[0],
A.ptr<dt_float32>(), A.layout.stride[0], zero,
C.ptr<dt_float32>(), C.layout.stride[0]));
};
#if !MEGDNN_DISABLE_FLOAT16
//! used for FLOAT_IO16xC32, not tested
auto gemm_ex = [&]() {
auto zero = handle->zero_device();
auto one = handle->one_device();
        //! These two arguments are reserved for future use; see
//! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp
int32_t solution_index = 0;
uint32_t flags = 1;
size_t ws_size = 0;
auto gemm_ex_err = rocblas_gemm_ex(
rocblas_handle_,
param().transposeB ? rocblas_operation_transpose
: rocblas_operation_none,
param().transposeA ? rocblas_operation_transpose
: rocblas_operation_none,
n, m, k, one, B.raw_ptr, rocblas_datatype_f16_r,
B.layout.stride[0], A.raw_ptr, rocblas_datatype_f16_r,
A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_f16_r,
C.layout.stride[0], C.raw_ptr, rocblas_datatype_f16_r,
C.layout.stride[0], rocblas_datatype_f32_r,
rocblas_gemm_algo_standard, solution_index, flags, &ws_size,
nullptr);
rocblas_check(gemm_ex_err);
};
auto hgemm = [&]() {
auto one_half = handle->one_device_h();
auto zero_half = handle->zero_device_h();
auto hgemm_err = rocblas_hgemm(
rocblas_handle_,
param().transposeB ? rocblas_operation_transpose
: rocblas_operation_none,
param().transposeA ? rocblas_operation_transpose
: rocblas_operation_none,
n, m, k, reinterpret_cast<const rocblas_half*>(one_half),
static_cast<const rocblas_half*>(B.raw_ptr), B.layout.stride[0],
static_cast<const rocblas_half*>(A.raw_ptr), A.layout.stride[0],
reinterpret_cast<const rocblas_half*>(zero_half),
static_cast<rocblas_half*>(C.raw_ptr), C.layout.stride[0]);
rocblas_check(hgemm_err);
};
#endif
std::vector<MatrixMulForwardImpl::Algorithm*>
MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) {
AlgoBase::SizeArgs args{this, A, B, C};
return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args);
}
if (param().compute_mode == Param::ComputeMode::DEFAULT) {
if (A.layout.dtype == dtype::Float32()) {
sgemm();
}
#if !MEGDNN_DISABLE_FLOAT16
else {
megdnn_assert(A.layout.dtype == dtype::Float16(),
"invalid matmul data type");
hgemm();
MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, bool reproducible) {
AlgoBase::SizeArgs args{this, A, B, C};
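    //! prefer the rocBLAS-backed algorithm whenever it satisfies the
    //! reproducibility and workspace-limit constraints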
if (sm_algo_pack.blas.is_available_reproducible(
args, reproducible, workspace_limit_in_bytes)) {
return &sm_algo_pack.blas;
}
#endif
}
#if !MEGDNN_DISABLE_FLOAT16
else if (param().compute_mode == Param::ComputeMode::FLOAT32) {
megdnn_assert(B.layout.dtype == dtype::Float16() &&
C.layout.dtype == dtype::Float16() &&
A.layout.dtype == dtype::Float16(),
"DataType::FLOAT_IO16xC32 is supported, when dtype of A, "
"B, C are all Float16");
gemm_ex();
}
#endif
else if (A.layout.dtype == dtype::Int8() &&
B.layout.dtype == dtype::Int8() &&
C.layout.dtype == dtype::Int32()) {
//! see
//! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp:470
bool rocblas_int8x8x32_valid = true;
rocblas_int8x8x32_valid &= (k % 4 == 0);
rocblas_int8x8x32_valid &=
(!param().transposeB || B.layout.stride[0] % 4 == 0);
rocblas_int8x8x32_valid &=
(!param().transposeA || A.layout.stride[0] % 4 == 0);
megdnn_assert(rocblas_int8x8x32_valid,
"rocblas int8x8x32 matmul requires K must be a multiple "
"of 4, and/or LDA/LDB based on transpose mode"
"get: %zu, is_trans_b = %d, %zu, is_trans_a = %d, %zu",
k, param().transposeB, B.layout.stride[0],
param().transposeA, A.layout.stride[0]);
int32_t solution_index = 0;
uint32_t flags = 1;
size_t ws_size = 0;
auto zero = handle->zero_device_i32();
auto one = handle->one_device_i32();
rocblas_check(rocblas_gemm_ex(
rocblas_handle_,
param().transposeB ? rocblas_operation_transpose
: rocblas_operation_none,
param().transposeA ? rocblas_operation_transpose
: rocblas_operation_none,
n, m, k, one, B.raw_ptr, rocblas_datatype_i8_r,
B.layout.stride[0], A.raw_ptr, rocblas_datatype_i8_r,
A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_i32_r,
C.layout.stride[0], C.raw_ptr, rocblas_datatype_i32_r,
C.layout.stride[0], rocblas_datatype_i32_r,
rocblas_gemm_algo_standard, solution_index, flags, &ws_size,
nullptr));
if (reproducible) {
return megdnn::get_reproducible_algo<MatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"matrix mul forward");
} else {
megdnn_assert((A.layout.dtype == dtype::Int8() &&
B.layout.dtype == dtype::Int8() &&
C.layout.dtype == dtype::Int16()),
"invalid matmul data type");
megdnn_throw("cuda matmul does not support INT8x8x16 now");
return megdnn::get_usable_algo<MatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"matrix mul forward");
}
}
} // namespace rocm
} // namespace megdnn
size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) {
AlgoBase::SizeArgs args{this, A, B, C};
return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args);
}
void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
_megdnn_tensor_out C,
_megdnn_workspace workspace) {
check_exec(A.layout, B.layout, C.layout, workspace.size);
AlgoBase::ExecArgs args(this, A, B, C, workspace);
auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout);
algo->check_workspace(args, workspace).exec(args);
}
// vim: syntax=cpp.doxygen
@@ -20,29 +20,32 @@ public:
void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
const TensorLayout&) override;
bool is_thread_safe() const override { return true; }
class AlgoBase;
class AlgoBlas;
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
private:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override {
return {};
}
const TensorLayout& /*C*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override {
return nullptr;
}
bool /*reproducible*/) override;
const char* get_algorithm_set_name() const override {
return "ROCM MATMUL";
}
static AlgoPack sm_algo_pack;
};
} // namespace rocm
......
@@ -46,6 +46,37 @@ TEST_F(FALLBACK, MATRIX_MUL) {
}
}
TEST_F(FALLBACK, MATRIX_MUL_NAIVE) {
Checker<MatrixMul> checker(handle());
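    //! force the checker to run with the FB_NAIVE algorithm so the naive
    //! fallback kernel itself is exercised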
checker.set_before_exec_callback(AlgoChecker<MatrixMul>("FB_NAIVE"));
using Param = MatrixMul::Param;
auto args = matrix_mul::get_matmul_args();
for (auto arg : args) {
auto m = arg.m, n = arg.n, k = arg.k;
auto mask = arg.mask;
Param param;
param.transposeA = mask & 1;
param.transposeB = mask & 2;
TensorShape AS, BS, CS;
if (param.transposeA)
AS = TensorShape{k, m};
else
AS = TensorShape{m, k};
if (param.transposeB)
BS = TensorShape{n, k};
else
BS = TensorShape{k, n};
CS = TensorShape{m, n};
TensorLayout AL, BL, CL;
AL = TensorLayout(AS, dtype::Float32());
BL = TensorLayout(BS, dtype::Float32());
CL = TensorLayout(CS, dtype::Float32());
checker.set_param(param);
checker.execl({AL, BL, CL});
}
}
TEST_F(FALLBACK, BATCHED_MATRIX_MUL) {
Checker<BatchedMatrixMul> checker(handle());
......