提交 a1877ee0 编写于 作者: M Megvii Engine Team

refactor(dnn): refactor algo interface, use algoinfo instead of global algorithm

GitOrigin-RevId: 479718ac7557af090996757063cbed58d1c78f2e
上级 cb59c278
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
*/
#pragma once
......@@ -92,24 +93,72 @@ enum class AlgoDataType : uint32_t {
/*!
* \brief Abstract representation of an algorithm for implementing
* the operator
*
* All pointers to Algorithm should be allocated globally and usable
* across multiple megdnn handles, and they should not be freed by
* the caller.
*/
class Algorithm {
public:
    static constexpr uint32_t INVALID_ALGO_TYPE = static_cast<uint32_t>(-1);

    /**
     * \brief Algorithm information; the concrete algo implementation can be
     * recovered from AlgorithmInfo::Info::Desc
     */
    struct Info {
        struct Desc {
            //! backend (handle type) the algo belongs to
            Handle::HandleType handle_type;
            //! identifies the concrete algo implementation within the backend
            uint32_t type = INVALID_ALGO_TYPE;
            //! serialized param of the algo
            std::string param;
            bool valid() const { return type != INVALID_ALGO_TYPE; }
            void reset() { type = INVALID_ALGO_TYPE; }
            bool operator==(const Desc& rhs) const {
                return handle_type == rhs.handle_type && type == rhs.type &&
                       param == rhs.param;
            }
        } desc;
        //! human-readable algorithm name
        std::string name;
        bool is_reproducible = false;
        bool valid() const { return desc.valid(); }
        void reset() { desc.reset(); }
        //! desc alone identifies the algo, so equality compares only desc
        bool operator==(const Info& rhs) const { return desc == rhs.desc; }
    };

    virtual ~Algorithm() = default;

    /**
     * \brief whether the execution result is
     * reproducible across multiple runs.
     */
    virtual bool is_reproducible() const = 0;
    virtual const char* name() const = 0;
    //! serialized param
    virtual std::string param() const { return {}; }
    virtual uint32_t type() const = 0;

    Handle::HandleType handle_type() const { return m_handle_type; }
    //! pack handle type, algo type, param, name and reproducibility into Info
    Info info() const {
        return {{handle_type(), type(), param()}, name(), is_reproducible()};
    }

    //! append the raw bytes of a trivially-copyable value to \p result
    template <typename T>
    static void serialize_write_pod(const T& val, std::string& result) {
        result.append(reinterpret_cast<const char*>(&val), sizeof(T));
    }

    static void serialize_write_pod(const char* val, std::string& result) {
        result.append(val, strlen(val));
    }

    //! read a trivially-copyable value back from serialized bytes
    template <typename T>
    static T deserialize_read_pod(const std::string& data, size_t offset = 0) {
        T ret;
        //! memcpy instead of dereferencing a reinterpret_cast'ed pointer:
        //! avoids undefined behavior from misaligned access when
        //! alignof(T) > 1
        memcpy(&ret, data.data() + offset, sizeof(T));
        return ret;
    }

protected:
    Handle::HandleType m_handle_type = Handle::HandleType::NAIVE;
};
......@@ -127,6 +176,8 @@ class MultiAlgoOpr;
template <class Opr>
class MultiAlgoOpr<Opr, -1> {
public:
using AlgorithmInfo = detail::Algorithm::Info;
using AlgorithmDesc = detail::Algorithm::Info::Desc;
using Algorithm = detail::Algorithm;
/*!
* \brief get a string representation for current algorithm set;
......@@ -139,8 +190,8 @@ public:
//! policy for executing the operator
struct ExecutionPolicy {
//! nullptr means using heuristic
Algorithm* algorithm = nullptr;
//! INVALID_ALGO_TYPE algo_type means using heuristic
AlgorithmInfo algo;
};
ExecutionPolicy& execution_policy() { return m_execution_policy; }
......@@ -161,6 +212,39 @@ template <class Opr>
class MultiAlgoOpr<Opr, 3> : public MultiAlgoOpr<Opr, -1> {
public:
using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info;
//! get all possible algorithm decriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2)) {
ret.emplace_back(algo->info());
}
return ret;
}
/**
* \brief Returns the best algorithm information which indicate the
* algorithm by heuristic.
*
* The selected algorithm should not use workspace more than
* \p workspace_limit_in_bytes.
*/
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
bool reproducible = false) {
return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes,
reproducible)
->info();
}
protected:
~MultiAlgoOpr() = default;
//! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms(
......@@ -179,9 +263,6 @@ public:
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
bool reproducible = false) = 0;
protected:
~MultiAlgoOpr() = default;
};
//! specializae for nargs == 4
......@@ -189,6 +270,40 @@ template <class Opr>
class MultiAlgoOpr<Opr, 4> : public MultiAlgoOpr<Opr, -1> {
public:
using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info;
//! get all possible algorithm decriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2,
const TensorLayout& p3) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) {
ret.emplace_back(algo->info());
}
return ret;
}
/**
* \brief Returns the best algorithm information which indicate the
* algorithm by heuristic.
*
* The selected algorithm should not use workspace more than
* \p workspace_limit_in_bytes.
*/
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
bool reproducible = false) {
return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes,
reproducible)
->info();
}
protected:
~MultiAlgoOpr() = default;
//! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms(
......@@ -207,9 +322,6 @@ public:
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
bool reproducible = false) = 0;
protected:
~MultiAlgoOpr() = default;
};
//! specializae for nargs == 5
......@@ -217,6 +329,42 @@ template <class Opr>
class MultiAlgoOpr<Opr, 5> : public MultiAlgoOpr<Opr, -1> {
public:
using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info;
//! get all possible algorithm decriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2,
const TensorLayout& p3,
const TensorLayout& p4) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) {
ret.emplace_back(algo->info());
}
return ret;
}
/**
* \brief Returns the best algorithm information which indicate the
* algorithm by heuristic.
*
* The selected algorithm should not use workspace more than
* \p workspace_limit_in_bytes.
*/
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
bool reproducible = false) {
return get_algorithm_heuristic(p0, p1, p2, p3, p4,
workspace_limit_in_bytes, reproducible)
->info();
}
protected:
~MultiAlgoOpr() = default;
//! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms(
......@@ -237,9 +385,6 @@ public:
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
bool reproducible = false) = 0;
protected:
~MultiAlgoOpr() = default;
};
//! specializae for nargs == 8
......@@ -247,6 +392,42 @@ template <class Opr>
class MultiAlgoOpr<Opr, 8> : public MultiAlgoOpr<Opr, -1> {
public:
using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info;
//! get all possible algorithm decriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) {
ret.emplace_back(algo->info());
}
return ret;
}
/**
* \brief Returns the best algorithm information which indicate the
* algorithm by heuristic.
*
* The selected algorithm should not use workspace more than
*/
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
bool reproducible = false) {
return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7,
workspace_limit_in_bytes, reproducible)
->info();
}
protected:
~MultiAlgoOpr() = default;
//! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms(
......@@ -269,9 +450,6 @@ public:
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
bool reproducible = false) = 0;
protected:
~MultiAlgoOpr() = default;
};
} // namespace detail
} // namespace megdnn
......
......@@ -31,6 +31,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(AARCH64_DIRECT_STRD2_FP16)
};
} // namespace aarch64
} // namespace megdnn
......
......@@ -36,6 +36,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(AARCH64_DIRECT_STRD2_FP32)
};
} // namespace aarch64
......
......@@ -48,6 +48,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
}
MEGDNN_DECL_ALGO_TYPE(AARCH64_MATMUL_S8)
};
} // namespace aarch64
......
......@@ -32,28 +32,54 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoF16DirectStride2 f16_direct_stride2;
#endif
fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_direct_algos;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_matmul_algos;
public:
AlgoPack() {
matmul_algos.emplace_back(&qu8_matrix_mul);
matmul_algos.emplace_back(&s8_matrix_mul);
m_matmul_algos.emplace_back(&qu8_matrix_mul);
m_matmul_algos.emplace_back(&s8_matrix_mul);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
direct_algos.emplace_back(&f16_direct_stride2);
m_direct_algos.emplace_back(&f16_direct_stride2);
#endif
direct_algos.emplace_back(&f32_direct_stride2);
m_direct_algos.emplace_back(&f32_direct_stride2);
for (auto&& algo : m_direct_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
for (auto&& algo : m_matmul_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const {
return m_direct_algos;
}
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos()
const {
return m_matmul_algos;
}
SmallVector<AlgoBase*> direct_algos;
SmallVector<AlgoBase*> matmul_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack;
auto&& algos = arm_common::ConvBiasImpl::algo_pack();
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
sl_algo_pack.direct_algos.end());
//! Meyers singleton: constructed on first use; initialization is
//! thread-safe since C++11.
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() {
    static AlgoPack s_pack;
    return s_pack;
}
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl)
SmallVector<fallback::ConvBiasImpl::AlgoBase*>
ConvBiasImpl::get_all_packed_algo() {
auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo();
algos.insert(algos.begin(), algo_pack().direct_algos().begin(),
algo_pack().direct_algos().end());
    //! We put matmul algos at the begin, because a matmul algo gets
    //! privilege when its prefer hook returns true.
algos.insert(algos.begin(), sl_algo_pack.matmul_algos.begin(),
sl_algo_pack.matmul_algos.end());
algos.insert(algos.begin(), algo_pack().matmul_algos().begin(),
algo_pack().matmul_algos().end());
return std::move(algos);
}
......
......@@ -25,7 +25,9 @@ public:
}
};
SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override;
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl);
protected:
const char* get_algorithm_set_name() const override;
......@@ -38,6 +40,7 @@ private:
class AlgoF16DirectStride2;
#endif
class AlgoPack;
static const AlgoPack& algo_pack();
};
} // namespace aarch64
......
......@@ -48,6 +48,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL};
}
MEGDNN_DECL_ALGO_TYPE(AARCH64_MATMUL_QU8)
};
} // namespace aarch64
} // namespace megdnn
......
......@@ -27,6 +27,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_K8X12X1)
};
class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase {
......@@ -37,6 +38,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_K8X12X1)
};
class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase {
......@@ -47,6 +49,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_K4X16X1)
};
class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase {
......@@ -58,10 +61,17 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4)
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16)
};
class MatrixMulImpl::AlgoF32Gemv final
: public arm_common::MatrixMulImpl::AlgoF32Gemv {};
: public arm_common::MatrixMulImpl::AlgoF32Gemv {
public:
AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() {
m_handle_type = Handle::HandleType::AARCH64;
}
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_GEMV)
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase {
......@@ -72,6 +82,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_K8X24X1)
};
class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase {
......@@ -83,6 +94,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8)
MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_MK8_8X8)
};
#endif
......@@ -98,6 +110,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X12X4_DOTPROD)
};
class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase {
......@@ -110,6 +123,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD)
};
#else
......@@ -124,6 +138,7 @@ public:
PackMode packmode() const override { return PackMode::DEFAULT; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_4X4X16)
};
class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase {
......@@ -136,6 +151,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K4X4X16)
};
class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase {
......@@ -147,6 +163,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X8X8)
};
#endif
......@@ -160,6 +177,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K8X8X8)
};
class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase {
......@@ -171,6 +189,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K4X4X16)
};
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase {
......@@ -186,6 +205,7 @@ public:
PackMode packmode() const override { return PackMode::DEFAULT; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_16X12X4)
};
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase {
......@@ -201,6 +221,7 @@ public:
PackMode packmode() const override { return PackMode::DEFAULT; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_K8X8X8)
};
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase {
......@@ -214,6 +235,7 @@ public:
PackMode packmode() const override { return PackMode::DEFAULT; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_4X4X8)
};
class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase {
......@@ -225,6 +247,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_K12X8X1)
};
class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase {
......@@ -236,6 +259,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8)
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_MK8_8X8)
};
#if __ARM_FEATURE_DOTPROD
......@@ -249,6 +273,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X4_DOTPROD)
};
class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase {
......@@ -262,6 +287,7 @@ public:
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT)
MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_GEMV_DOTPROD)
};
#else
......@@ -273,6 +299,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X8)
};
#endif
......
......@@ -51,49 +51,66 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoQuint8K8x8x8 quint8_k8x8x8;
#endif
SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos;
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;
public:
SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos;
AlgoPack() {
all_algos.emplace_back(&f32_gemv);
all_algos.emplace_back(&f32K8x12x1);
all_algos.emplace_back(&f32_mk4_8x12x1);
all_algos.emplace_back(&f32k4x16x1);
all_algos.emplace_back(&f32mk4_4x16);
m_all_algos.emplace_back(&f32_gemv);
m_all_algos.emplace_back(&f32K8x12x1);
m_all_algos.emplace_back(&f32_mk4_8x12x1);
m_all_algos.emplace_back(&f32k4x16x1);
m_all_algos.emplace_back(&f32mk4_4x16);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos.emplace_back(&f16_k8x24x1);
all_algos.emplace_back(&f16_mk8_8x8);
m_all_algos.emplace_back(&f16_k8x24x1);
m_all_algos.emplace_back(&f16_mk8_8x8);
#endif
#if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod);
all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod);
m_all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod);
m_all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod);
#else
all_algos.emplace_back(&int8x8x32_k4x4x16);
all_algos.emplace_back(&int8x8x32_k8x8x8);
all_algos.emplace_back(&int8x8x32_mk4_4x4x16);
m_all_algos.emplace_back(&int8x8x32_k4x4x16);
m_all_algos.emplace_back(&int8x8x32_k8x8x8);
m_all_algos.emplace_back(&int8x8x32_mk4_4x4x16);
#endif
all_algos.emplace_back(&int8x8x16_k4x4x16);
all_algos.emplace_back(&int8x8x16_k8x8x8);
all_algos.emplace_back(&int8x8x16_mk4_k8x8x8);
all_algos.emplace_back(&int8x8x16_mk4_4x4x8);
all_algos.emplace_back(&int8x8x16_mk4_16x12x4);
m_all_algos.emplace_back(&int8x8x16_k4x4x16);
m_all_algos.emplace_back(&int8x8x16_k8x8x8);
m_all_algos.emplace_back(&int8x8x16_mk4_k8x8x8);
m_all_algos.emplace_back(&int8x8x16_mk4_4x4x8);
m_all_algos.emplace_back(&int8x8x16_mk4_16x12x4);
all_algos.emplace_back(&int16x16x32_k12x8x1);
all_algos.emplace_back(&int16x16x32_mk8_8x8);
m_all_algos.emplace_back(&int16x16x32_k12x8x1);
m_all_algos.emplace_back(&int16x16x32_mk8_8x8);
#if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&quint8_gemv_dotprod);
all_algos.emplace_back(&quint8_k8x8x4_dotprod);
m_all_algos.emplace_back(&quint8_gemv_dotprod);
m_all_algos.emplace_back(&quint8_k8x8x4_dotprod);
#else
all_algos.emplace_back(&quint8_k8x8x8);
m_all_algos.emplace_back(&quint8_k8x8x8);
#endif
for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
const SmallVector<fallback::MatrixMulImpl::AlgoBase*>& all_algos() const {
return m_all_algos;
}
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
static AlgoPack s_algo_pack;
auto&& algos = arm_common::MatrixMulImpl::algo_pack();
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(),
s_algo_pack.all_algos.end());
//! Meyers singleton: constructed on first use; initialization is
//! thread-safe since C++11.
const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() {
    static AlgoPack s_pack;
    return s_pack;
}
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl)
SmallVector<fallback::MatrixMulImpl::AlgoBase*>
MatrixMulImpl::get_all_packed_algo() {
auto&& algos = arm_common::MatrixMulImpl::get_all_packed_algo();
algos.insert(algos.begin(), algo_pack().all_algos().begin(),
algo_pack().all_algos().end());
return std::move(algos);
}
......
......@@ -25,7 +25,10 @@ public:
}
};
SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo()
override;
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl);
private:
class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1
......@@ -66,6 +69,8 @@ private:
class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16
class AlgoPack;
public:
static const AlgoPack& algo_pack();
};
} // namespace aarch64
......
......@@ -30,6 +30,7 @@ public:
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_FP16)
};
class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase {
......@@ -45,7 +46,7 @@ public:
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP16)
};
class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase {
public:
......@@ -61,6 +62,7 @@ public:
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP16)
};
class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase {
public:
......@@ -75,6 +77,7 @@ public:
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_FP16)
};
class ConvBiasImpl::AlgoF16Direct final : public AlgoBase {
......@@ -94,6 +97,7 @@ public:
ConvAlgoTypePack get_algo_type() const override{
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP16)
};
class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase {
......@@ -110,6 +114,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP16)
};
} // namespace arm_common
......
......@@ -30,6 +30,7 @@ public:
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_FP32)
};
class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase {
......@@ -45,6 +46,7 @@ public:
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP32)
};
class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase {
......@@ -60,6 +62,7 @@ public:
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_FP32)
};
class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase {
......@@ -75,6 +78,7 @@ public:
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F54_FP32)
};
class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase {
......@@ -90,6 +94,7 @@ public:
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP32)
};
//===================== NCHW44 Winograd Support =====================//
......@@ -107,6 +112,7 @@ public:
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32)
};
class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase {
......@@ -123,6 +129,7 @@ public:
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32)
};
class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase {
......@@ -139,6 +146,7 @@ public:
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32)
};
// ================================================================= //
......@@ -157,6 +165,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP32)
};
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase {
......@@ -174,6 +183,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP32)
};
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
......@@ -191,6 +201,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_FP32)
};
class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase {
......@@ -209,6 +220,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_FP32)
};
class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase {
......@@ -227,6 +239,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_FP32)
};
class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase {
......@@ -244,6 +257,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW44_F32)
};
} // namespace arm_common
......
......@@ -33,6 +33,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_S8)
};
class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase {
......@@ -49,6 +50,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_S8)
};
class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase {
......@@ -65,6 +67,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44)
};
class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase {
......@@ -81,6 +84,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_S8)
};
class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase {
......@@ -95,6 +99,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD1_NCHW44_S8)
};
class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase {
......@@ -109,6 +114,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD2_NCHW44_S8)
};
#if __ARM_FEATURE_DOTPROD
......@@ -126,6 +132,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8)
};
class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase {
......@@ -142,6 +149,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_S8)
};
class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase {
......@@ -159,6 +167,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_S8)
};
class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase {
......@@ -180,6 +189,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_DOT_S8)
};
#endif
......@@ -196,6 +206,7 @@ public:
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_S8)
};
//=======================input int8 compute fp32 output int8============
......@@ -213,6 +224,7 @@ public:
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32)
};
//=======================input int8 compute int16 output int8============
......@@ -231,6 +243,7 @@ public:
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8)
};
} // namespace arm_common
......
......@@ -39,6 +39,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_INT8X8X16)
};
class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase {
......@@ -54,6 +55,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_INT8X8X16)
};
class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase {
......@@ -80,6 +82,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_INT8X8X16)
};
class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase {
......@@ -96,12 +99,16 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_F2_INT8X8X16)
};
class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final : public AlgoBase {
class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final
: public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; }
const char* name() const override {
return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44";
}
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(
......@@ -111,6 +118,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_STRD1_STRD2_NCHW44_INT8X8X16)
};
class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase {
......@@ -129,6 +137,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_INT8X8X16)
};
} // namespace arm_common
......
......@@ -88,46 +88,50 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
#endif
SmallVector<std::unique_ptr<AlgoBase>> refhold;
fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_direct_algos;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_winograd_algos;
public:
AlgoPack() {
#if __ARM_FEATURE_DOTPROD
direct_algos.emplace_back(&ds8_direct_stride1);
direct_algos.emplace_back(&ds8_direct_stride2);
direct_algos.emplace_back(&du8_direct_stride1);
direct_algos.emplace_back(&du8_direct_stride2);
m_direct_algos.emplace_back(&ds8_direct_stride1);
m_direct_algos.emplace_back(&ds8_direct_stride2);
m_direct_algos.emplace_back(&du8_direct_stride1);
m_direct_algos.emplace_back(&du8_direct_stride2);
direct_algos.emplace_back(&ds8_direct_nchw44);
direct_algos.emplace_back(&ds8_direct_nchw_nchw44);
m_direct_algos.emplace_back(&ds8_direct_nchw44);
m_direct_algos.emplace_back(&ds8_direct_nchw_nchw44);
#endif
direct_algos.emplace_back(&qu8_direct_stride2);
direct_algos.emplace_back(&qu8_direct_stride1);
direct_algos.emplace_back(&s8_direct_stride2);
direct_algos.emplace_back(&s8_direct_nchw44);
direct_algos.emplace_back(&s8x8x16_direct_nchw44);
direct_algos.emplace_back(&s8_direct_nchw_nchw44);
direct_algos.emplace_back(&s8_direct_stride1);
direct_algos.emplace_back(&s8x8x16_channel_wise_stride1_stride2_nchw44);
direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44);
direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44);
m_direct_algos.emplace_back(&qu8_direct_stride2);
m_direct_algos.emplace_back(&qu8_direct_stride1);
m_direct_algos.emplace_back(&s8_direct_stride2);
m_direct_algos.emplace_back(&s8_direct_nchw44);
m_direct_algos.emplace_back(&s8x8x16_direct_nchw44);
m_direct_algos.emplace_back(&s8_direct_nchw_nchw44);
m_direct_algos.emplace_back(&s8_direct_stride1);
m_direct_algos.emplace_back(
&s8x8x16_channel_wise_stride1_stride2_nchw44);
m_direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44);
m_direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
direct_algos.emplace_back(&f16_direct_stride1);
direct_algos.emplace_back(&f16_direct);
m_direct_algos.emplace_back(&f16_direct_stride1);
m_direct_algos.emplace_back(&f16_direct);
#endif
direct_algos.emplace_back(&i8x8x16_direct);
direct_algos.emplace_back(&i8x8x16_stride2_filter2);
direct_algos.emplace_back(&i8x8x16_stride2);
direct_algos.emplace_back(&i8x8x16_nchw_nchw44);
m_direct_algos.emplace_back(&i8x8x16_direct);
m_direct_algos.emplace_back(&i8x8x16_stride2_filter2);
m_direct_algos.emplace_back(&i8x8x16_stride2);
m_direct_algos.emplace_back(&i8x8x16_nchw_nchw44);
direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44);
direct_algos.emplace_back(&f32_chanel_wise_nchw44);
direct_algos.emplace_back(&f32_direct_nchw44);
m_direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44);
m_direct_algos.emplace_back(&f32_chanel_wise_nchw44);
m_direct_algos.emplace_back(&f32_direct_nchw44);
direct_algos.emplace_back(&f32_direct_stride1);
direct_algos.emplace_back(&f32_direct_stride2);
direct_algos.emplace_back(&f32_direct);
m_direct_algos.emplace_back(&f32_direct_stride1);
m_direct_algos.emplace_back(&f32_direct_stride2);
m_direct_algos.emplace_back(&f32_direct);
static CpuOprDelegationStorage<2> storage;
auto matmul_opr = storage.get<MatrixMul, 0>();
......@@ -143,31 +147,31 @@ public:
refhold.emplace_back(new AlgoFP32WinogradF23_4x4(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF63_4x4(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
//! uncomment this when low precision mode is done
#if 0
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
#endif
//! Qint8x8x32 winograd compute with fp32
refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
}
}
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
......@@ -180,15 +184,15 @@ public:
refhold.emplace_back(new AlgoFP32WinogradF63(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF54(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF45(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
}
}
......@@ -203,15 +207,15 @@ public:
refhold.emplace_back(new AlgoFP16WinogradF23(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP16WinogradF45(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP16WinogradF63(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
}
}
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
......@@ -224,7 +228,7 @@ public:
refhold.emplace_back(new AlgoFP16WinogradF23_8x8(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
}
}
#endif
......@@ -238,25 +242,48 @@ public:
refhold.emplace_back(new AlgoS8WinogradF23_8x8(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoS8WinogradF23_8x8_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
}
}
for (auto&& algo : m_direct_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
for (auto&& algo : m_winograd_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
SmallVector<AlgoBase*> direct_algos;
SmallVector<AlgoBase*> winograd_algos;
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos()
const {
return m_direct_algos;
}
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& winograd_algos()
const {
return m_winograd_algos;
}
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack;
auto&& algos = fallback::ConvBiasImpl::algo_pack();
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
sl_algo_pack.direct_algos.end());
algos.insert(algos.end(), sl_algo_pack.winograd_algos.begin(),
sl_algo_pack.winograd_algos.end());
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() {
    //! Process-wide registry of arm_common conv-bias algorithms.
    //! Function-local static: built lazily on first use and shared across
    //! all megdnn handles; initialization is thread-safe since C++11.
    static AlgoPack sl_pack;
    return sl_pack;
}
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl)
SmallVector<fallback::ConvBiasImpl::AlgoBase*>
ConvBiasImpl::get_all_packed_algo() {
    //! Collect every algorithm visible from this backend: the arm_common
    //! algorithms are merged into the list returned by the fallback layer.
    auto algos = fallback::ConvBiasImpl::get_all_packed_algo();
    //! hoist the registry lookup instead of calling algo_pack() four times
    const auto& pack = algo_pack();
    //! direct algos are prepended so they are searched before fallback ones
    algos.insert(algos.begin(), pack.direct_algos().begin(),
                 pack.direct_algos().end());
    //! winograd algos are appended to the tail
    algos.insert(algos.end(), pack.winograd_algos().begin(),
                 pack.winograd_algos().end());
    //! return the local by name: NRVO/implicit move applies, which the old
    //! `return std::move(ref)` on an auto&& binding only emulated
    return algos;
}
......
......@@ -12,6 +12,7 @@
#pragma once
#include "src/common/utils.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/common/algo_base.h"
namespace megdnn {
namespace arm_common {
......@@ -27,7 +28,7 @@ public:
}
};
SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override;
bool is_matmul_quantized_prefer(
const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param)
......@@ -35,7 +36,8 @@ public:
SmallVector<AlgoCategory> suggest_algo_category_order(
const NCBKernSizeParam& param) const override;
class AlgoPack;
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl);
protected:
const char* get_algorithm_set_name() const override;
......@@ -95,6 +97,9 @@ private:
class AlgoF16Direct;
class AlgoF16DirectStride1;
#endif
class AlgoPack;
static const AlgoPack& algo_pack();
};
} // namespace arm_common
......
......@@ -32,6 +32,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_QU8)
};
class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase {
......@@ -48,6 +49,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_QU8)
};
#if __ARM_FEATURE_DOTPROD
class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase {
......@@ -65,6 +67,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_QU8)
};
class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase {
......@@ -81,6 +84,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_QU8)
};
#endif
} // namespace arm_common
......
......@@ -36,6 +36,7 @@ public:
ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_INT8X8X32)
};
class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final
......@@ -54,6 +55,7 @@ public:
ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_INT8X8X32)
};
#endif
......
......@@ -1086,6 +1086,10 @@ bool deconv::can_stride1_int8x8x32_dot(const NCBKernSizeParam& param) {
(FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
FH >= PH + 1 && FW >= PW + 1;
avaiable &= (param.filter_type.enumv() == DTypeEnum::QuantizedS8 ||
param.filter_type.enumv() == DTypeEnum::Int8) &&
(param.grad_type.enumv() == DTypeEnum::QuantizedS32 ||
param.grad_type.enumv() == DTypeEnum::Int32);
return avaiable &&
((FH == 2 && OC <= 8) ||
((FH == 3 || FH == 5 || FH == 7) && (IC < 32 && OC <= 16)));
......
......@@ -1180,6 +1180,10 @@ bool deconv::can_stride2_int8x8x32_dot(const NCBKernSizeParam& param) {
(FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
FH >= PH + 1 && FW >= PW + 1;
avaiable &= (param.filter_type.enumv() == DTypeEnum::QuantizedS8 ||
param.filter_type.enumv() == DTypeEnum::Int8) &&
(param.grad_type.enumv() == DTypeEnum::QuantizedS32 ||
param.grad_type.enumv() == DTypeEnum::Int32);
return avaiable && ((FH == 2 && OC <= 4) || (FH == 3 && OC <= 8) ||
(FH == 5 && OC <= 16) || (FH == 7 && OC < 32));
}
......
......@@ -23,15 +23,54 @@ using namespace arm_common;
/* ===================== ConvolutionBackwardData ===================== */
struct ConvolutionBackwardDataImpl::AlgoPack {
//! Registry of the arm_common ConvolutionBackwardData (deconv) algorithms.
//! Owns the algorithm objects and keeps a Desc -> algo mapper so an
//! algorithm can be recovered from its serialized AlgorithmInfo::Desc.
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
#if __ARM_FEATURE_DOTPROD
//! direct deconv kernels using the dot-product extension; only compiled
//! when the target supports __ARM_FEATURE_DOTPROD
AlgoSdot8DirectStride1 i8x8x32_direct_stride1_sdot;
AlgoSdot8DirectStride2 i8x8x32_direct_stride2_sdot;
AlgoUdot8DirectStride1 quint8_direct_stride1_udot;
AlgoUdot8DirectStride2 quint8_direct_stride2_udot;
#endif
//! maps each algo's info().desc to the algo object (filled in the ctor)
fallback::ConvolutionBackwardDataImpl::AlgoBase::Mapper m_all_algos_map;
//! non-owning pointers to the members above, in registration order
SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*>
m_all_algos;
public:
//! registers every available algorithm and indexes it by its desc
AlgoPack() {
#if __ARM_FEATURE_DOTPROD
m_all_algos.emplace_back(&i8x8x32_direct_stride1_sdot);
m_all_algos.emplace_back(&i8x8x32_direct_stride2_sdot);
m_all_algos.emplace_back(&quint8_direct_stride1_udot);
m_all_algos.emplace_back(&quint8_direct_stride2_udot);
#endif
for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
//! all registered algorithms (non-owning view)
const SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*>&
all_algos() const {
return m_all_algos;
}
//! Desc -> algo lookup table, used by get_algo_from_desc-style helpers
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack;
const ConvolutionBackwardDataImpl::AlgoPack&
ConvolutionBackwardDataImpl::algo_pack() {
    //! Single shared deconv-algorithm registry, constructed on first use
    //! (thread-safe static initialization) and never destroyed early.
    static AlgoPack sl_pack;
    return sl_pack;
}
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl)
SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*>
ConvolutionBackwardDataImpl::get_all_packed_algo() {
    //! Merge the arm_common deconv algorithms into the fallback list.
    auto algos = fallback::ConvolutionBackwardDataImpl::get_all_packed_algo();
    //! hoist the registry lookup instead of calling algo_pack() twice
    const auto& pack = algo_pack();
    //! prepend so arm_common algos are searched before fallback ones
    algos.insert(algos.begin(), pack.all_algos().begin(),
                 pack.all_algos().end());
    //! return by name: NRVO/implicit move replaces `return std::move(ref)`
    return algos;
}
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
......@@ -52,35 +91,6 @@ size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
param);
}
std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
const NCBKernSizeParam& param) {
auto ret = fallback::ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
param);
#if __ARM_FEATURE_DOTPROD
if ((param.filter_type.enumv() == DTypeEnum::QuantizedS8 ||
param.filter_type.enumv() == DTypeEnum::Int8) &&
(param.grad_type.enumv() == DTypeEnum::QuantizedS32 ||
param.grad_type.enumv() == DTypeEnum::Int32)) {
if (sm_algo_pack.i8x8x32_direct_stride1_sdot.usable(this, param)) {
ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride1_sdot);
}
if (sm_algo_pack.i8x8x32_direct_stride2_sdot.usable(this, param)) {
ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride2_sdot);
}
} else if (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.grad_type.enumv() == DTypeEnum::QuantizedS32) {
if (sm_algo_pack.quint8_direct_stride1_udot.usable(this, param)) {
ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride1_udot);
}
if (sm_algo_pack.quint8_direct_stride2_udot.usable(this, param)) {
ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride2_udot);
}
}
#endif
return ret;
}
const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
// arm common version 0
return "DeconvAC0";
......
......@@ -47,11 +47,14 @@ protected:
size_t ncb_1g_get_workspace(Algorithm* algo,
const NCBKernSizeParam& param) override;
std::vector<Algorithm*> ncb_1g_get_all_algorithms(
const NCBKernSizeParam& param) override;
const char* get_algorithm_set_name() const override;
SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*>
get_all_packed_algo() override;
public:
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl);
private:
#if __ARM_FEATURE_DOTPROD
class AlgoSdot8DirectStride1;
......@@ -59,8 +62,8 @@ private:
class AlgoUdot8DirectStride1;
class AlgoUdot8DirectStride2;
#endif
struct AlgoPack;
static AlgoPack sm_algo_pack;
class AlgoPack;
static const AlgoPack& algo_pack();
};
} // namespace arm_common
......
......@@ -36,6 +36,7 @@ public:
ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_QU8)
};
class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final
......@@ -55,6 +56,7 @@ public:
ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_QU8)
};
#endif
} // namespace arm_common
......
......@@ -1236,6 +1236,9 @@ bool deconv::can_stride1_quint8_dot(const NCBKernSizeParam& param) {
(FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
FH >= PH + 1 && FW >= PW + 1;
avaiable &= (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm ||
param.grad_type.enumv() == DTypeEnum::Int32);
/**
* \note In the kernel, we use int32_t to calc the value, in order
* not generate negative number, we first initialize SHIFT and sub
......
......@@ -1337,6 +1337,9 @@ bool deconv::can_stride2_quint8_dot(const NCBKernSizeParam& param) {
(FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
FH >= PH + 1 && FW >= PW + 1;
avaiable &= (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm ||
param.grad_type.enumv() == DTypeEnum::Int32);
/**
* \note In the kernel, we use uint32_t to calc the value, in order
* not generate negative number, we first initialize SHIFT and sub
......
......@@ -59,6 +59,7 @@ public:
virtual bool is_available(const KernParam&) const = 0;
virtual void exec(const KernParam&) const = 0;
virtual ~AlgoBase() = default;
uint32_t type() const override { return INVALID_ALGO_TYPE; };
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......
......@@ -26,6 +26,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X16)
};
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase {
......@@ -39,6 +40,7 @@ public:
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV)
};
class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase {
......@@ -52,6 +54,7 @@ public:
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4)
};
#if __ARM_FEATURE_DOTPROD
......@@ -66,6 +69,7 @@ public:
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4_DOT)
};
#endif
......@@ -96,6 +100,7 @@ public:
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F32_GEMV_MK4)
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......@@ -110,6 +115,7 @@ public:
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F16_GEMV)
};
#endif
......@@ -130,6 +136,7 @@ public:
static_cast<uint32_t>(AlgoDataType::FLOAT32) |
static_cast<uint32_t>(AlgoDataType::QINT8X8X32)),
DEFAULT)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_GEVM)
};
} // namespace arm_common
......
......@@ -28,28 +28,47 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoGevm gevm;
AlgoF32GemvMK4 f32_gemv_mk4;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos;
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack() {
all_algos.emplace_back(&int8x8x16);
m_all_algos.emplace_back(&int8x8x16);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos.emplace_back(&f16gemv);
m_all_algos.emplace_back(&f16gemv);
#endif
#if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_gemv_mk4_dot);
m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot);
#endif
all_algos.emplace_back(&int8x8x32_gemv);
all_algos.emplace_back(&int8x8x32_gemv_mk4);
all_algos.emplace_back(&f32_gemv_mk4);
all_algos.emplace_back(&gevm);
m_all_algos.emplace_back(&int8x8x32_gemv);
m_all_algos.emplace_back(&int8x8x32_gemv_mk4);
m_all_algos.emplace_back(&f32_gemv_mk4);
m_all_algos.emplace_back(&gevm);
for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
const SmallVector<fallback::MatrixMulImpl::AlgoBase*>& all_algos() const {
return m_all_algos;
}
SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() {
    //! Lazily-built, process-wide matmul algorithm registry; the local
    //! static makes construction thread-safe and the reference stable.
    static AlgoPack sl_pack;
    return sl_pack;
}
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl)
SmallVector<fallback::MatrixMulImpl::AlgoBase*>
MatrixMulImpl::get_all_packed_algo() {
static AlgoPack s_algo_pack;
auto&& algos = fallback::MatrixMulImpl::algo_pack();
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(),
s_algo_pack.all_algos.end());
auto&& algos = fallback::MatrixMulImpl::get_all_packed_algo();
algos.insert(algos.begin(), algo_pack().all_algos().begin(),
algo_pack().all_algos().end());
return std::move(algos);
}
......
......@@ -11,6 +11,7 @@
#pragma once
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/opr_impl.h"
#include "src/common/algo_base.h"
namespace megdnn {
namespace arm_common {
......@@ -27,7 +28,10 @@ public:
}
};
SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo()
override;
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl);
protected:
class AlgoF32Gemv; // Arm_common F32 Gemv
......@@ -43,6 +47,9 @@ protected:
#endif
class AlgoInt8x8x16; // Arm_common Int 8x8x16
class AlgoPack;
public:
static const AlgoPack& algo_pack();
};
} // namespace arm_common
......
......@@ -10,6 +10,7 @@
* implied.
*/
#pragma once
#include "megdnn/oprs/base.h"
#include "src/fallback/pooling/opr_impl.h"
namespace megdnn {
......@@ -72,6 +73,8 @@ public:
virtual ~AlgoBase() = default;
virtual bool usable(const PoolingKernSizeParam& param) const = 0;
virtual void exec(const PoolingKernParam& param) const = 0;
uint32_t type() const override { return INVALID_ALGO_TYPE; };
};
private:
......
......@@ -40,6 +40,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
}
MEGDNN_DECL_ALGO_TYPE(ARMV7_MATMUL_S8)
};
} // namespace armv7
......
......@@ -24,22 +24,40 @@ using namespace armv7;
class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8MatrixMul s8_matrix_mul;
AlgoQU8MatrixMul qu8_matrix_mul;
fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_all_algos;
public:
AlgoPack() {
all_algos.emplace_back(&qu8_matrix_mul);
all_algos.emplace_back(&s8_matrix_mul);
m_all_algos.emplace_back(&qu8_matrix_mul);
m_all_algos.emplace_back(&s8_matrix_mul);
for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& all_algos()
const {
return m_all_algos;
}
SmallVector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack;
auto&& algos = arm_common::ConvBiasImpl::algo_pack();
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() {
    //! armv7 conv-bias algorithm registry: one shared instance, created on
    //! first call (thread-safe C++11 static initialization).
    static AlgoPack sl_pack;
    return sl_pack;
}
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl)
SmallVector<fallback::ConvBiasImpl::AlgoBase*>
ConvBiasImpl::get_all_packed_algo() {
auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo();
//! TODO fused matmul bias is slower than matmul + elemwise in armv7 now,
//! and nearly equal in aarch64, because of the waste of register in
//! postprocess
algos.insert(algos.end(), sl_algo_pack.all_algos.begin(),
sl_algo_pack.all_algos.end());
algos.insert(algos.end(), algo_pack().all_algos().begin(),
algo_pack().all_algos().end());
return std::move(algos);
}
......
......@@ -25,7 +25,9 @@ public:
}
};
SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override;
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl);
protected:
const char* get_algorithm_set_name() const override;
......@@ -34,6 +36,7 @@ private:
class AlgoS8MatrixMul;
class AlgoQU8MatrixMul;
class AlgoPack;
static const AlgoPack& algo_pack();
};
} // namespace armv7
......
......@@ -42,6 +42,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL};
}
MEGDNN_DECL_ALGO_TYPE(ARMV7_MATMUL_QU8)
};
} // namespace armv7
......
......@@ -27,6 +27,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_F32)
};
class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase {
......@@ -37,6 +38,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_MK4_PACK_4X12)
};
class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase {
......@@ -48,6 +50,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4)
MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_MK4_4x8)
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......@@ -59,6 +62,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_K4X16X1)
};
class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase {
public:
......@@ -69,6 +73,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8)
MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_MK8_4X8)
};
#endif
#if __ARM_FEATURE_DOTPROD
......@@ -80,6 +85,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8_K6X8X4)
};
class MatrixMulImpl::AlgoQuint8DotK4x8x4 final : public AlgoBase {
......@@ -90,6 +96,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_QUINT8_K4X8X4)
};
class MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd final : public AlgoBase {
......@@ -102,11 +109,18 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8_MK4_8X4X4_DOTPROD)
};
#endif
class MatrixMulImpl::AlgoF32Gemv final
: public arm_common::MatrixMulImpl::AlgoF32Gemv {};
: public arm_common::MatrixMulImpl::AlgoF32Gemv {
public:
AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() {
m_handle_type = Handle::HandleType::ARMV7;
}
MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_GEMV)
};
class MatrixMulImpl::AlgoInt8x8x32K4x2x16 final : public AlgoBase {
public:
......@@ -117,6 +131,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_K4X2X16)
};
class MatrixMulImpl::AlgoInt8x8x32K4x8x8 final : public AlgoBase {
......@@ -128,6 +143,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_K4X8X8)
};
class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase {
......@@ -138,6 +154,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_QUINT8_K4X8X8)
};
class MatrixMulImpl::AlgoInt8x8x16K4x2x16 final : public AlgoBase {
......@@ -149,6 +166,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K4X2X16)
};
class MatrixMulImpl::AlgoInt8x8x16K4x8x8 final : public AlgoBase {
......@@ -160,6 +178,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K4X8X8)
};
class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase {
......@@ -171,6 +190,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_MK4_K8X8X4)
};
class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase {
......@@ -182,6 +202,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT16X16X32_K12X4X1)
};
class MatrixMulImpl::AlgoInt16x16x32MK8_4x8 final : public AlgoBase {
......@@ -193,6 +214,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8)
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT16X16X32_MK8_4X8)
};
class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase {
......@@ -204,6 +226,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_MK4_4X2X16)
};
} // namespace armv7
......
......@@ -43,42 +43,60 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1;
AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos;
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;
public:
SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos;
AlgoPack() {
all_algos.emplace_back(&f32_gemv);
all_algos.emplace_back(&f32);
all_algos.emplace_back(&f32_mk4_pack_4x12);
all_algos.emplace_back(&f32_mk4_4x8);
m_all_algos.emplace_back(&f32_gemv);
m_all_algos.emplace_back(&f32);
m_all_algos.emplace_back(&f32_mk4_pack_4x12);
m_all_algos.emplace_back(&f32_mk4_4x8);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos.emplace_back(&f16_k4x16x1);
all_algos.emplace_back(&f16_mk8_4x8);
m_all_algos.emplace_back(&f16_k4x16x1);
m_all_algos.emplace_back(&f16_mk8_4x8);
#endif
#if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod);
all_algos.emplace_back(&int8_k6x8x4);
all_algos.emplace_back(&quint8_k4x8x4);
m_all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod);
m_all_algos.emplace_back(&int8_k6x8x4);
m_all_algos.emplace_back(&quint8_k4x8x4);
#endif
all_algos.emplace_back(&int8x8x32_mk4_4x2x16);
all_algos.emplace_back(&int8x8x32_k4x2x16);
all_algos.emplace_back(&int8x8x32_k4x8x8);
all_algos.emplace_back(&quint8_k4x8x8);
all_algos.emplace_back(&int8x8x16_mk4_8x8x4);
all_algos.emplace_back(&int8x8x16_k4x2x16);
all_algos.emplace_back(&int8x8x16_k4x8x8);
m_all_algos.emplace_back(&int8x8x32_mk4_4x2x16);
m_all_algos.emplace_back(&int8x8x32_k4x2x16);
m_all_algos.emplace_back(&int8x8x32_k4x8x8);
m_all_algos.emplace_back(&quint8_k4x8x8);
m_all_algos.emplace_back(&int8x8x16_mk4_8x8x4);
m_all_algos.emplace_back(&int8x8x16_k4x2x16);
m_all_algos.emplace_back(&int8x8x16_k4x8x8);
m_all_algos.emplace_back(&int16x16x32_k12x4x1);
m_all_algos.emplace_back(&int16x16x32_mk8_4x8);
for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
all_algos.emplace_back(&int16x16x32_k12x4x1);
all_algos.emplace_back(&int16x16x32_mk8_4x8);
const SmallVector<fallback::MatrixMulImpl::AlgoBase*>& all_algos() const {
return m_all_algos;
}
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
static AlgoPack s_algo_pack;
auto algos = arm_common::MatrixMulImpl::algo_pack();
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(),
s_algo_pack.all_algos.end());
const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() {
static AlgoPack algo_pack;
return algo_pack;
}
//! Collect every packed-matmul algo usable by this backend: the armv7 algos
//! registered in this file are placed in front of the ones inherited from
//! arm_common, so they are preferred during enumeration.
SmallVector<fallback::MatrixMulImpl::AlgoBase*>
MatrixMulImpl::get_all_packed_algo() {
    auto inherited = arm_common::MatrixMulImpl::get_all_packed_algo();
    const auto& own = algo_pack().all_algos();
    inherited.insert(inherited.begin(), own.begin(), own.end());
    return inherited;
}
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl)
// vim: syntax=cpp.doxygen
......@@ -25,7 +25,10 @@ public:
}
};
SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo()
override;
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl);
private:
class AlgoF32; // Armv7 F32
......@@ -52,6 +55,9 @@ private:
// DotProduct
#endif
class AlgoPack;
public:
static const AlgoPack& algo_pack();
};
} // namespace armv7
......
/**
* \file dnn/src/common/algo_base.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <functional>
#include <string>
#include "megdnn/oprs/base.h"
#include "src/common/utils.h"
namespace megdnn {
//! Implements Algorithm::type() in a concrete algo class: returns the class's
//! entry of the enclosing AlgoType enum as the raw integer stored in
//! Algorithm::Info::Desc::type (INVALID_ALGO_TYPE marks "unset").
#define MEGDNN_DECL_ALGO_TYPE(_type)                              \
    uint32_t type() const override {                              \
        return static_cast<std::underlying_type<AlgoType>::type>( \
                AlgoType::_type);                                 \
    }
//! Declares, inside a fallback-based operator impl class, the static lookup
//! that maps a serialized AlgorithmDesc back to the concrete AlgoBase
//! instance; pair with MEGDNN_FB_DEF_GET_ALGO_FROM_DESC in the .cpp file.
#define MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(_opr)          \
    static fallback::_opr::AlgoBase* get_algo_from_desc( \
            const AlgorithmDesc& desc)
//! Defines _opr::get_algo_from_desc() declared by
//! MEGDNN_FB_DECL_GET_ALGO_FROM_DESC: resolves \p desc through the operator's
//! AlgoPack map and asserts the descriptor is registered.
//! Uses one find() and reuses the iterator instead of the previous
//! find() + end() + at() triple traversal of the map.
#define MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(_opr)          \
    fallback::_opr::AlgoBase* _opr::get_algo_from_desc( \
            const AlgorithmDesc& desc) {                \
        auto&& algo_map = algo_pack().all_algos_map();  \
        auto algo_iter = algo_map.find(desc);           \
        megdnn_assert(algo_iter != algo_map.end());     \
        return algo_iter->second;                       \
    }
//! Same as MEGDNN_FB_DEF_GET_ALGO_FROM_DESC but for operators whose AlgoBase
//! is declared on the impl itself (e.g. CUDA backends): resolves \p desc via
//! the AlgoPack map with a single lookup and asserts it is registered.
#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr)                               \
    _opr::AlgoBase* _opr::get_algo_from_desc(const AlgorithmDesc& desc) { \
        auto&& algo_map = algo_pack().all_algos_map();                    \
        auto algo_iter = algo_map.find(desc);                             \
        megdnn_assert(algo_iter != algo_map.end());                       \
        return algo_iter->second;                                         \
    }
/**
 * \brief Mixin that constructs algos on demand from an
 * Algorithm::Info::Desc (for backends whose algos cannot all be enumerated
 * statically).
 *
 * Constructed instances are owned by \c m_refhold and cached in
 * \c m_all_algos_map, so repeated requests for the same desc return the same
 * pointer and the instances stay valid for the mixin's lifetime.
 */
template <typename AlgoBase>
class AlgoConstructMixin {
private:
    //! keeps ownership of every algo deserialized by this mixin
    std::vector<std::unique_ptr<AlgoBase>> m_refhold;

protected:
    //! desc -> algo cache, also exposed through all_algos_map()
    typename AlgoBase::Mapper m_all_algos_map;

public:
    //! construct the algo described by \p desc (or fetch the cached one)
    //! and return the instance; ownership stays with the mixin
    AlgoBase* construct_and_get_algo(
            const detail::Algorithm::Info::Desc& desc) {
        auto iter = m_all_algos_map.find(desc);
        if (iter != m_all_algos_map.end()) {
            // cache hit: reuse the iterator instead of a redundant at()
            return iter->second;
        }
        // rebuild the serialized form expected by AlgoBase::deserialize:
        // the raw algo type followed by its serialized param
        std::string serialized_bin;
        AlgoBase::serialize_write_pod(desc.type, serialized_bin);
        serialized_bin += desc.param;
        m_refhold.emplace_back(AlgoBase::deserialize(serialized_bin));
        AlgoBase* algo = m_refhold.back().get();
        m_all_algos_map.emplace(desc, algo);
        return algo;
    }

    //! drop every constructed algo together with the desc -> algo cache
    void clear() {
        m_all_algos_map.clear();
        m_refhold.clear();
    }

    const typename AlgoBase::Mapper& all_algos_map() const {
        return m_all_algos_map;
    }
};
} // namespace megdnn
namespace std {
//! hash support so Algorithm::Info::Desc can key unordered containers
//! (e.g. the AlgoBase::Mapper maps); mixes param, type and handle_type.
template <>
struct hash<megdnn::detail::Algorithm::Info::Desc> {
    std::size_t operator()(
            const megdnn::detail::Algorithm::Info::Desc& desc) const {
        std::size_t seed = std::hash<std::string>()(desc.param);
        seed = megdnn::hash_combine<std::size_t>(
                seed, std::hash<uint32_t>()(desc.type));
        seed = megdnn::hash_combine<std::size_t>(
                seed, std::hash<uint32_t>()(
                              static_cast<uint32_t>(desc.handle_type)));
        return seed;
    }
};
}  // namespace std
// vim: syntax=cpp.doxygen
......@@ -25,15 +25,34 @@ namespace megdnn {
*/
//! Resolve the algorithm to run: prefer the one explicitly configured on the
//! operator's execution policy; otherwise fall back to the heuristic choice
//! with no workspace limit and no reproducibility requirement.
//! (Stale pre-refactor lines — a duplicate `ret` declaration and the old
//! `execution_policy().algorithm` branch — left the body with unbalanced
//! braces; only the AlgorithmInfo-based implementation is kept.)
template <class Opr, typename... Args>
typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
    typename Opr::AlgorithmInfo ret;
    auto set = opr->execution_policy().algo;
    if (set.valid()) {
        ret = set;
    } else {
        ret = opr->get_algorithm_info_heuristic(
                std::forward<Args>(args)..., std::numeric_limits<size_t>::max(),
                false);
    }
    return opr->get_algo_from_desc(ret.desc);
}
/*!
 * \brief get user-configured algorithm, or heuristic algorithm. used in opencl
 * whose algo need to be constructed each time.
 *
 * A configured algo may not have been instantiated yet, so it is
 * (re)constructed from its serialized desc via the operator's AlgoPack.
 */
template <class Opr, typename... Args>
typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) {
    auto set = opr->execution_policy().algo;
    if (set.valid()) {
        return opr->algo_pack().construct_and_get_algo(set.desc);
    } else {
        // heuristic: unlimited workspace, reproducibility not required
        auto ret = opr->get_algorithm_info_heuristic(
                std::forward<Args>(args)..., std::numeric_limits<size_t>::max(),
                false);
        return opr->get_algo_from_desc(ret.desc);
    }
    // NOTE: the stale `return static_cast<typename Opr::AlgoBase*>(ret);`
    // tail was removed — it was unreachable and ill-formed on instantiation
    // (AlgorithmInfo is not convertible to AlgoBase*).
}
/*!
......
......@@ -9,6 +9,32 @@
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
/**
* Boost Software License - Version 1.0 - August 17th, 2003
*
* Permission is hereby granted, free of charge, to any person or organization
* obtaining a copy of the software and accompanying documentation covered by
* this license (the "Software") to use, reproduce, display, distribute,
* execute, and transmit the Software, and to prepare derivative works of the
* Software, and to permit third-parties to whom the Software is furnished to
* do so, all subject to the following:
*
* The copyright notices in the Software and this entire statement, including
* the above license grant, this restriction and the following disclaimer,
* must be included in all copies of the Software, in whole or in part, and
* all derivative works of the Software, unless such copies or derivative
* works are solely in the form of machine-executable object code generated by
* a source language processor.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
* SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
* FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*/
#pragma once
#include "megdnn/arch.h"
......@@ -263,6 +289,13 @@ constexpr uint32_t operator"" _hash(char const* str, size_t count) {
return XXHash64CT::hash(str, count, 20160701);
}
//! boost-style hash mixing step: folds \p value into \p seed.
//! refer to https://www.boost.org/doc/libs/1_64_0/boost/functional/hash/hash.hpp
template <typename T>
inline T hash_combine(T seed, T value) {
    const T mixed = value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed ^ mixed;
}
template <typename Vec>
std::string vec2str(Vec&& vec) {
std::string res;
......
......@@ -18,8 +18,14 @@ using namespace cuda;
//! Register the batch-conv-bias algos and build the desc -> algo index used
//! by get_algo_from_desc().
BatchConvBiasForwardImpl::AlgoPack::AlgoPack() {
    all_algos.push_back(&int8_nchw4_gemm_dotprod);
    all_algos.push_back(&int8_nchw4_implicit_gemm_dotprod);

    // index every registered algo by its descriptor
    for (auto&& registered : all_algos) {
        m_all_algos_map.emplace(registered->info().desc, registered);
    }
}
MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchConvBiasForwardImpl)
BatchConvBiasForwardImpl::AlgoPack BatchConvBiasForwardImpl::sm_algo_pack;
BatchConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs(
......
......@@ -11,13 +11,16 @@
#pragma once
#include <csetjmp>
#include <unordered_map>
#include "megdnn/oprs.h"
#include "src/common/utils.h"
#include "src/cuda/batch_conv_bias/opr_impl.h"
#include "src/cuda/handle.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
namespace megdnn {
namespace cuda {
......@@ -26,6 +29,12 @@ protected:
~AlgoBase() = default;
public:
enum class AlgoType : uint32_t {
CUDA_GEMM_NCHW4_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_PRECOMP_NCHW4_DOTPROD_INT8,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
BatchConvBiasForwardImpl* opr;
......@@ -85,6 +94,7 @@ public:
const char* name() const override {
return "BATCH_CONV_BIAS_INT8_NCHW4_GEMM_DOTPROD";
}
MEGDNN_DECL_ALGO_TYPE(CUDA_GEMM_NCHW4_DOTPROD_INT8)
};
class BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp final
......@@ -99,15 +109,16 @@ public:
const char* name() const override {
return "BATCH_CONV_BIAS_INT8_NCHW4_IMPLICIT_GEMM_PRECOMP_DOTPROD";
}
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_PRECOMP_NCHW4_DOTPROD_INT8)
private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
const SizeArgs& args) const;
};
class BatchConvBiasForwardImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;
class BatchConvBiasForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack();
......@@ -116,6 +127,8 @@ public:
AlgoInt8NCHW4DotProdImplicitGemmPrecomp int8_nchw4_implicit_gemm_dotprod;
std::vector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
} // namespace cuda
......
......@@ -26,6 +26,18 @@ public:
const TensorLayout& bias,
const TensorLayout& z,
const TensorLayout& dst) override;
const char* get_algorithm_set_name() const override;
class AlgoBase;
class AlgoInt8NCHW4DotProdGemm;
class AlgoInt8NCHW4DotProdImplicitGemmPrecomp;
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
......@@ -37,15 +49,6 @@ public:
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;
const char* get_algorithm_set_name() const override;
class AlgoBase;
class AlgoInt8NCHW4DotProdGemm;
class AlgoInt8NCHW4DotProdImplicitGemmPrecomp;
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
private:
static AlgoPack sm_algo_pack;
......
......@@ -60,4 +60,12 @@ BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() {
for (auto& algo : brute_force_algos) {
all_algos.push_back(&algo);
}
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl)
// vim: syntax=cpp.doxygen
......@@ -16,6 +16,8 @@
#include "src/common/utils.h"
#include "src/cuda/batched_matrix_mul/opr_impl.h"
#include "src/cuda/matrix_mul/cublasLt_wrapper.h"
#include "src/common/metahelper.h"
#if CUDA_VERSION >= 10010
#include <cublasLt.h>
#endif
......@@ -28,6 +30,14 @@ protected:
~AlgoBase() = default;
public:
enum class AlgoType : uint32_t {
CUDA_BRUTE_FORCE,
CUDA_CUBLAS,
CUDA_CUBLASLT,
CUDA_INT8X8X32,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
BatchedMatrixMulForwardImpl* opr;
......@@ -90,6 +100,13 @@ public:
void exec(const ExecArgs& args) const final;
bool is_reproducible() const override { return true; }
const char* name() const override { return m_name.c_str(); }
MEGDNN_DECL_ALGO_TYPE(CUDA_BRUTE_FORCE)
std::string param() const override {
std::string ret;
serialize_write_pod(m_algorithm, ret);
return ret;
}
};
class BatchedMatrixMulForwardImpl::AlgoCublas final
: public BatchedMatrixMulForwardImpl::AlgoBase {
......@@ -100,6 +117,7 @@ public:
void exec(const ExecArgs& args) const final;
bool is_reproducible() const override { return true; }
const char* name() const override { return "CUBLAS"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS)
};
#if CUDA_VERSION >= 10010
class BatchedMatrixMulForwardImpl::AlgoCublasLt final : public AlgoBase {
......@@ -110,6 +128,7 @@ public:
void exec(const ExecArgs& args) const final;
bool is_reproducible() const override { return true; }
const char* name() const override { return "CUBLAS_LT"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT)
};
#endif
class BatchedMatrixMulForwardImpl::AlgoInt8x8x32 final
......@@ -121,11 +140,13 @@ public:
void exec(const ExecArgs& args) const final;
bool is_reproducible() const override { return true; }
const char* name() const override { return "INT8x8x32"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_INT8X8X32)
};
class BatchedMatrixMulForwardImpl::AlgoPack {
class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
MatrixMulForwardImpl::AlgoPack mm_pack;
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;
public:
AlgoPack();
......@@ -137,6 +158,8 @@ public:
AlgoInt8x8x32 int8x8x32;
std::vector<AlgoBase*> all_algos;
std::vector<AlgoBruteForce> brute_force_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
} // namespace cuda
} // namespace megdnn
......@@ -24,7 +24,7 @@ bool BatchedMatrixMulForwardImpl::AlgoBruteForce::is_available(
const SizeArgs& args) const {
MatrixMulForwardImpl mm{args.opr->handle()};
mm.param() = {args.opr->param().transposeA, args.opr->param().transposeB};
mm.execution_policy() = {m_algorithm};
mm.execution_policy() = {m_algorithm->info()};
auto mm_layout_a = args.layout_a.remove_axis(0);
auto mm_layout_b = args.layout_b.remove_axis(0);
......@@ -39,7 +39,7 @@ size_t BatchedMatrixMulForwardImpl::AlgoBruteForce::get_workspace_in_bytes(
auto mm_opr = args.opr->handle()->create_operator<MatrixMulForward>();
mm_opr->param() = {args.opr->param().transposeA,
args.opr->param().transposeB};
mm_opr->execution_policy() = {m_algorithm};
mm_opr->execution_policy() = {m_algorithm->info()};
return mm_opr->get_workspace_in_bytes(args.layout_a, args.layout_b,
args.layout_c);
......@@ -50,7 +50,7 @@ void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec(
auto&& mm_opr = args.opr->handle()->create_operator<MatrixMulForward>();
mm_opr->param() = {args.opr->param().transposeA,
args.opr->param().transposeB};
mm_opr->execution_policy() = {m_algorithm};
mm_opr->execution_policy() = {m_algorithm->info()};
rep(n, N) {
TensorND A_, B_, C_;
auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) {
......
......@@ -32,6 +32,16 @@ public:
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& A, const TensorLayout& B,
const TensorLayout& C) override;
const char* get_algorithm_set_name() const override {
return "BATCHED_MATMUL";
}
bool is_thread_safe() const override { return true; }
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
protected:
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) override;
......@@ -40,12 +50,6 @@ public:
const TensorLayout& C,
size_t workspace_limit_in_bytes,
bool reproducible) override;
const char* get_algorithm_set_name() const override {
return "BATCHED_MATMUL";
}
bool is_thread_safe() const override { return true; }
static const AlgoPack& algo_pack() { return sm_algo_pack; }
private:
static AlgoPack sm_algo_pack;
......
......@@ -100,10 +100,16 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
for (size_t i = all_algo_size; i < all_algos.size(); ++i) {
non_cudnn_algos.push_back(all_algos[i]);
}
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
ConvBiasForwardImpl::AlgoPack ConvBiasForwardImpl::sm_algo_pack;
MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvBiasForwardImpl)
ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs(
ConvBiasForwardImpl* o, const TensorLayout& src,
const TensorLayout& filter, const TensorLayout& bias,
......@@ -172,43 +178,10 @@ std::string ConvBiasForwardImpl::AlgoBase::SizeArgs::to_string() const {
}
void ConvBiasForwardImpl::AlgoPack::fill_cudnn_algos() {
    // Register every cuDNN forward-convolution algorithm known to
    // CudnnAlgoPack; each enum value implicitly constructs the wrapper algo,
    // which looks up its name/reproducibility attributes itself.
    // (The stale pre-refactor DEF_ALGO macro block that was merged into this
    // body duplicated these registrations and left dangling #define/#undef
    // directives; only the table-driven implementation is kept.)
    for (auto&& algo : CudnnAlgoPack::conv_fwd_algos()) {
        cudnn_conv_bias_activations.push_back(algo.first);
        cudnn_convs.push_back(algo.first);
    }
}
#if CUDA_VERSION >= 10000
......
......@@ -6,19 +6,23 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/common/algo_base.h"
#include "src/common/utils.h"
#include "src/common/metahelper.h"
#include "src/cuda/conv_bias/conv_bias_int8.cuh"
#include "src/cuda/conv_bias/helper.h"
#include "src/cuda/conv_bias/opr_impl.h"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/handle.h"
#include "src/cuda/cudnn_wrapper.h"
#include <cuda.h>
#include <memory>
......@@ -38,11 +42,39 @@ protected:
~AlgoBase() = default;
public:
enum class AlgoType : uint32_t {
CUDA_CUDNN_CONVBIAS,
CUDA_CHANWISE,
CUDA_CHANWISE_SMALL,
CUDA_CHANWISE_INT8X8X32,
CUDA_CUDNN_CONV,
CUDA_INPLACE_MATMUL,
CUDA_MATMUL,
CUDA_MATMUL_INT8X8X32,
CUDA_1X1,
CUDA_BATCHED_MATMUL,
CUDA_GROUP_CONV_GENERAL,
CUDA_WMMA_UINT4X4X32,
CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8,
CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8,
CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8,
CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8,
CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8,
CUDA_BFLOAT16,
CUDA_IMPLICIT_GEMM_SASS_NCHW4_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_SASS_NCHW32_IMMA_INT8,
CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW32_IMMA_INT8,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs : public conv_bias::BiasForwardSizeArgs {
ConvBiasForwardImpl* opr;
const PreprocessedFilter* preprocessed_filter;
std::string to_string() const;
SizeArgs(ConvBiasForwardImpl* opr, const TensorLayout& src,
const TensorLayout& filter, const TensorLayout& bias,
......@@ -80,13 +112,17 @@ public:
virtual void exec(const ExecArgs& args) const = 0;
virtual size_t get_preprocess_workspace_in_bytes(
const SizeArgs& args) const {
MEGDNN_MARK_USED_VAR(args);
return 0;
}
virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const SizeArgs& args) const {
MEGDNN_MARK_USED_VAR(args);
return {};
}
virtual void exec_preprocess(const ExecArgs& args) const {}
virtual void exec_preprocess(const ExecArgs& args) const {
MEGDNN_MARK_USED_VAR(args);
}
bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
......@@ -114,11 +150,14 @@ public:
class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final : public AlgoBase {
public:
AlgoCUDNNConvBiasActivation(bool is_reproducible, const char* name,
cudnnConvolutionFwdAlgo_t cudnn_enum)
: m_is_reproducible(is_reproducible),
m_name(ConvBiasForward::algo_name<DefaultParam>(name, {})),
m_cudnn_enum(cudnn_enum) {}
AlgoCUDNNConvBiasActivation(cudnnConvolutionFwdAlgo_t cudnn_enum)
: m_cudnn_enum(cudnn_enum) {
megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) !=
CudnnAlgoPack::conv_fwd_algos().end());
m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum);
m_name = ConvBiasForward::algo_name<DefaultParam>(
"CUDNN:ConvBiasActivation:" + m_attr.name, {});
}
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
......@@ -127,16 +166,24 @@ public:
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return m_is_reproducible; }
bool is_reproducible() const override { return m_attr.is_reproducible; }
cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; }
bool is_cudnn() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONVBIAS)
std::string param() const override {
std::string ret;
serialize_write_pod(m_cudnn_enum, ret);
return ret;
}
private:
bool m_is_reproducible;
std::string m_name;
cudnnConvolutionFwdAlgo_t m_cudnn_enum;
CudnnAlgoPack::Attr m_attr;
};
class ConvBiasForwardImpl::AlgoChanwise final : public AlgoBase {
......@@ -154,6 +201,8 @@ public:
}
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
private:
mutable std::string m_name;
};
......@@ -172,6 +221,7 @@ public:
return m_name.c_str();
}
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL)
private:
mutable std::string m_name;
......@@ -190,6 +240,7 @@ public:
return m_name.c_str();
}
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_INT8X8X32)
private:
mutable std::string m_name;
......@@ -197,27 +248,39 @@ private:
class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoBase {
public:
AlgoCUDNNConv(bool is_reproducible, const char* name,
cudnnConvolutionFwdAlgo_t cudnn_enum)
: m_is_reproducible(is_reproducible),
m_name(ConvBiasForward::algo_name<DefaultParam>(name, {})),
m_cudnn_enum(cudnn_enum) {}
AlgoCUDNNConv(cudnnConvolutionFwdAlgo_t cudnn_enum)
: m_cudnn_enum(cudnn_enum) {
megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) !=
CudnnAlgoPack::conv_fwd_algos().end());
m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum);
m_name = ConvBiasForward::algo_name<DefaultParam>(
"CUDNN:Convolution:" + m_attr.name, {});
}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return m_is_reproducible; }
bool is_reproducible() const override { return m_attr.is_reproducible; }
const char* name() const override { return m_name.c_str(); }
cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; }
bool is_cudnn() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONV)
std::string param() const override {
std::string ret;
serialize_write_pod(m_cudnn_enum, ret);
return ret;
}
private:
bool m_is_reproducible;
std::string m_name;
cudnnConvolutionFwdAlgo_t m_cudnn_enum;
CudnnAlgoPack::Attr m_attr;
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};
......@@ -237,6 +300,7 @@ public:
return m_name.c_str();
}
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL)
private:
mutable std::string m_name;
......@@ -261,6 +325,7 @@ public:
return m_name.c_str();
}
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
......@@ -281,6 +346,7 @@ public:
return m_name.c_str();
}
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL_INT8X8X32)
private:
bool need_src_unroll(const SizeArgs& args) const;
......@@ -310,6 +376,7 @@ public:
return m_name.c_str();
}
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_1X1)
private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
......@@ -333,6 +400,7 @@ public:
return m_name.c_str();
}
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL)
private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
......@@ -354,6 +422,13 @@ public:
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg,
TensorLayout& dst_pg, TensorLayout& bias_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
std::string param() const override {
std::string ret;
serialize_write_pod(m_impl, ret);
return ret;
}
private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
......@@ -370,10 +445,13 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "QUINT4x4x32_WMMA"; }
bool is_reproducible() const override { return true; }
private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const;
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
const SizeArgs& args) const;
bool use_kernel_fhxfw(const SizeArgs& args) const;
size_t get_workspace_in_bytes_do_conv(const SizeArgs& args) const;
MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32)
};
#endif
......@@ -395,6 +473,7 @@ public:
const convolution::ConvParam& param, float alpha, float beta,
float gamma, float scale, cudaStream_t stream,
param::ConvBias::NonlineMode nonlinear_mode);
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8)
};
class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final
......@@ -415,8 +494,9 @@ public:
warp_k == 32 && stage == 2) {
return "";
}
return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n,
threadblock_k, warp_m, warp_n, warp_k, stage);
return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m,
threadblock_n, threadblock_k, warp_m, warp_n,
warp_k, stage);
}
};
AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param)
......@@ -433,6 +513,13 @@ public:
SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const SizeArgs& args) const override;
void exec_preprocess(const ExecArgs& args) const override;
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8)
std::string param() const override {
std::string ret;
serialize_write_pod(m_algo_param, ret);
return ret;
}
private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
......@@ -457,9 +544,7 @@ public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override {
return m_name.c_str();
}
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }
template <typename BiasVisitor>
static void dispatch_nonlinear_mode(
......@@ -471,6 +556,14 @@ public:
MMATileSize mma_tile_size);
static std::string to_string(MMATileSize mma_tile_size);
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8)
std::string param() const override {
std::string ret;
serialize_write_pod(m_mma_tile_size, ret);
return ret;
}
private:
MMATileSize m_mma_tile_size;
std::string m_name;
......@@ -488,10 +581,16 @@ public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override {
return m_name.c_str();
}
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8)
std::string param() const override {
std::string ret;
serialize_write_pod(m_mma_tile_size, ret);
return ret;
}
private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
const SizeArgs& args) const;
......@@ -513,6 +612,13 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8)
std::string param() const override {
std::string ret;
serialize_write_pod(m_mma_tile_size, ret);
return ret;
}
private:
MMATileSize m_mma_tile_size;
......@@ -533,6 +639,13 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8)
std::string param() const override {
std::string ret;
serialize_write_pod(m_mma_tile_size, ret);
return ret;
}
private:
MMATileSize m_mma_tile_size;
......@@ -570,6 +683,13 @@ public:
SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const SizeArgs& args) const override;
void exec_preprocess(const ExecArgs& args) const override;
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8)
std::string param() const override {
std::string ret;
serialize_write_pod(m_algo_param, ret);
return ret;
}
private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
......@@ -592,6 +712,14 @@ public:
bool is_reproducible() const override { return m_impl->is_reproducible(); }
MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)
std::string param() const override {
std::string ret;
serialize_write_pod(m_impl, ret);
return ret;
}
private:
SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr,
TensorLayout& fsrc, TensorLayout& ffilter,
......@@ -603,17 +731,16 @@ private:
};
class ConvBiasForwardImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;
class ConvBiasForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack();
std::vector<AlgoBase*> all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos,
bfloat16_algos;
non_cudnn_algos, bfloat16_algos;
std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations;
std::vector<AlgoCUDNNConv> cudnn_convs;
AlgoChanwise chanwise;
......@@ -646,6 +773,8 @@ public:
AlgoBase* cudnn_conv_from_enum(cudnnConvolutionFwdAlgo_t algo);
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
private:
#if CUDA_VERSION >= 10000
void fill_imma_algos();
......
......@@ -47,7 +47,7 @@ ConvBiasForwardImpl::AlgoBFloat16::float_args(
change_dtype(fdst);
opr->param() = args.opr->param();
opr->param().compute_mode = Param::ComputeMode::DEFAULT;
opr->execution_policy() = {m_impl};
opr->execution_policy() = {m_impl->info()};
return SizeArgs(opr, fsrc, ffilter, fbias, fz, fdst);
}
......@@ -110,7 +110,7 @@ void ConvBiasForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
auto convbias_opr = args.handle->create_operator<ConvBias>();
convbias_opr->param() = args.opr->param();
convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
convbias_opr->execution_policy() = {m_impl};
convbias_opr->execution_policy() = {m_impl->info()};
convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor,
fdst_tensor, nullptr, cvter.workspace());
}
......
......@@ -63,12 +63,12 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
auto conv_args = args;
auto cudnn_conv_bias_act_from_enum_wrapper =
[this](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* {
[](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* {
return sm_algo_pack.cudnn_conv_bias_act_from_enum(algo);
};
auto cudnn_conv_from_enum_wrapper =
[this](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* {
[](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* {
return sm_algo_pack.cudnn_conv_from_enum(algo);
};
......
......@@ -24,17 +24,6 @@ public:
_megdnn_tensor_out dst,
const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) override;
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias,
const TensorLayout& z,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&,
......@@ -80,6 +69,20 @@ public:
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias,
const TensorLayout& z,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;
private:
static AlgoPack sm_algo_pack;
};
......
......@@ -52,8 +52,14 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
all_algos.push_back(bfloat16_refhold.back().get());
bfloat16_algos.push_back(bfloat16_refhold.back().get());
}
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl)
ConvolutionBackwardDataImpl::AlgoCUDNN*
ConvolutionBackwardDataImpl::AlgoPack::cudnn_from_enum(
cudnnConvolutionBwdDataAlgo_t algo) {
......
......@@ -11,8 +11,11 @@
#pragma once
#include "src/cuda/convolution/helper.h"
#include <unordered_map>
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/cuda/convolution/helper.h"
#include "src/cuda/cudnn_wrapper.h"
namespace megdnn {
namespace cuda {
......@@ -23,154 +26,146 @@ namespace cuda {
* All the algo impls should try to support non-contiguous batch dim, for group
* conv execution.
*/
class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm {
protected:
~AlgoBase() = default;
public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
HandleImpl *handle;
CanonizedFilterMeta filter_meta;
const TensorLayout *diff_layout, *grad_layout, *filter_layout;
ConvolutionBackwardDataImpl *opr;
std::string to_string() const;
void init_desc(convolution::CUDNNBwdDataDescs &desc) const {
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
}
SizeArgs(ConvolutionBackwardDataImpl* opr,
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad);
SizeArgs(ConvolutionBackwardDataImpl* opr,
const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad);
convolution::ForwardSizeArgs as_fwd_args() const {
return {handle, grad_layout, filter_layout, filter_meta,
diff_layout};
}
};
struct ExecArgs: public SizeArgs {
const TensorND *filter_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;
ExecArgs(ConvolutionBackwardDataImpl *opr,
_megdnn_tensor_in filter,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs &args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0;
virtual void exec(const ExecArgs &args) const = 0;
bool is_available_wk(const SizeArgs &args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
class ConvolutionBackwardDataImpl::AlgoBase : public Algorithm {
protected:
~AlgoBase() = default;
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(
const SizeArgs &args, const Workspace &workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd data algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
public:
enum class AlgoType : uint32_t {
CUDA_CUDNN,
CUDA_MATMUL,
CUDA_CHANWISE,
CUDA_CHANWISE_SMALL,
CUDA_BFLOAT16,
CUDA_GROUP_CONV_GENERAL,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
HandleImpl* handle;
CanonizedFilterMeta filter_meta;
const TensorLayout *diff_layout, *grad_layout, *filter_layout;
ConvolutionBackwardDataImpl* opr;
std::string to_string() const;
void init_desc(convolution::CUDNNBwdDataDescs& desc) const {
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
}
virtual bool is_cudnn() const {
return false;
SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter,
const TensorLayout& diff, const TensorLayout& grad);
SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad);
convolution::ForwardSizeArgs as_fwd_args() const {
return {handle, grad_layout, filter_layout, filter_meta,
diff_layout};
}
};
struct ExecArgs : public SizeArgs {
const TensorND *filter_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;
ExecArgs(ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs& args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void exec(const ExecArgs& args) const = 0;
bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd data algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
}
virtual bool is_cudnn() const { return false; }
};
class ConvolutionBackwardDataImpl::AlgoCUDNN final : public AlgoBase {
bool m_is_reproducible;
const char *m_name;
cudnnConvolutionBwdDataAlgo_t m_cudnn_enum;
CudnnAlgoPack::Attr m_attr;
public:
public:
AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum)
: m_cudnn_enum(cudnn_enum) {
megdnn_assert(CudnnAlgoPack::conv_bwd_data_algos().find(cudnn_enum) !=
CudnnAlgoPack::conv_bwd_data_algos().end());
m_attr = CudnnAlgoPack::conv_bwd_data_algos().at(cudnn_enum);
}
AlgoCUDNN(bool is_reproducible, const char *name,
cudnnConvolutionBwdDataAlgo_t cudnn_enum):
m_is_reproducible(is_reproducible),
m_name(name),
m_cudnn_enum(cudnn_enum)
{}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_reproducible() const override { return m_attr.is_reproducible; }
bool is_reproducible() const override {
return m_is_reproducible;
}
const char* name() const override { return m_attr.name.c_str(); }
const char* name() const override {
return m_name;
}
cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; }
cudnnConvolutionBwdDataAlgo_t cudnn_enum() const {
return m_cudnn_enum;
}
bool is_cudnn() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN)
bool is_cudnn() const override {
return true;
}
std::string param() const override {
std::string ret;
serialize_write_pod(m_cudnn_enum, ret);
return ret;
}
};
//! im2col and matmul, with dilation
class ConvolutionBackwardDataImpl::AlgoMatmul final: public AlgoBase {
template<typename T>
static void exec_internal(const ExecArgs &args);
class ConvolutionBackwardDataImpl::AlgoMatmul final : public AlgoBase {
template <typename T>
static void exec_internal(const ExecArgs& args);
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override {
return "MATMUL";
}
bool is_reproducible() const override {
return true;
}
const char* name() const override { return "MATMUL"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
};
class ConvolutionBackwardDataImpl::AlgoChanwise final: public AlgoBase {
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override {
return "CHANNEL_WISE";
}
bool is_reproducible() const override {
return true;
}
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
};
class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final: public AlgoBase {
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override {
return "CHANNEL_WISE_SMALL";
}
bool is_reproducible() const override {
return true;
}
const char* name() const override { return "CHANNEL_WISE_SMALL"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL)
};
class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase {
......@@ -190,61 +185,72 @@ private:
TensorLayout& fsrc, TensorLayout& ffilter,
TensorLayout& fdst) const;
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)
std::string param() const override {
std::string ret;
serialize_write_pod(m_algorithm, ret);
return ret;
}
};
//! implement group conv by another algo
class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase {
AlgoBase *m_impl;
class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final
: public AlgoBase {
AlgoBase* m_impl;
std::string m_name;
public:
AlgoGroupConvGeneral(AlgoBase *impl);
public:
AlgoGroupConvGeneral(AlgoBase* impl);
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override {
return m_name.c_str();
}
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override {
return m_impl->is_reproducible();
}
bool is_reproducible() const override { return m_impl->is_reproducible(); }
static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg,
TensorLayout& grad_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
static void modify_size_args(SizeArgs &args,
TensorLayout &diff_pg, TensorLayout &grad_pg);
std::string param() const override {
std::string ret;
serialize_write_pod(m_impl, ret);
return ret;
}
};
class ConvolutionBackwardDataImpl::AlgoPack {
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
// defined in cudnn.cpp
void fill_cudnn_algos();
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator = (const AlgoPack &) = delete;
AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack();
public:
AlgoPack();
std::vector<AlgoCUDNN> cudnn;
AlgoMatmul matmul;
AlgoChanwise chanwise;
AlgoChanwiseSmall chanwise_small;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
std::vector<AlgoCUDNN> cudnn;
AlgoMatmul matmul;
AlgoChanwise chanwise;
AlgoChanwiseSmall chanwise_small;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
std::vector<AlgoBase*>
std::vector<AlgoBase*>
//! all algorithms
all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos,
bfloat16_algos;
non_cudnn_algos, bfloat16_algos;
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
} // namespace cuda
} // namespace megdnn
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -42,7 +42,7 @@ ConvolutionBackwardDataImpl::AlgoBFloat16::float_args(
change_dtype(fgrad);
opr->param() = args.opr->param();
opr->param().compute_mode = Param::ComputeMode::DEFAULT;
opr->execution_policy() = {m_algorithm};
opr->execution_policy() = {m_algorithm->info()};
return SizeArgs(opr, ffilter, fdiff, fgrad);
}
......@@ -105,7 +105,7 @@ void ConvolutionBackwardDataImpl::AlgoBFloat16::exec(
args.handle->create_operator<ConvolutionBackwardData>();
conv_back_data_opr->param() = args.opr->param();
conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
conv_back_data_opr->execution_policy() = {m_algorithm};
conv_back_data_opr->execution_policy() = {m_algorithm->info()};
conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor,
cvter.workspace());
}
......
......@@ -98,35 +98,9 @@ void ConvolutionBackwardDataImpl::AlgoCUDNN::exec(
}
void ConvolutionBackwardDataImpl::AlgoPack::fill_cudnn_algos() {
#define V1(v) #v
#define V(v) V1(v)
#define DEF_ALGO(NAME, REPROD) \
cudnn.push_back({ \
REPROD, #NAME \
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \
"." V(CUDNN_PATCHLEVEL), \
NAME})
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true);
#if CUDNN_MAJOR >= 5
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true);
#if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, true);
#endif
#endif
#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1)
#pragma message "not latest cudnn"
#endif
#undef DEF_ALGO
#undef V
#undef V1
for (auto&& algo : CudnnAlgoPack::conv_bwd_data_algos()) {
cudnn.push_back(algo.first);
}
}
// vim: syntax=cpp.doxygen
......@@ -49,8 +49,14 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() {
all_algos.push_back(bfloat16_refhold.back().get());
bfloat16_algos.push_back(bfloat16_refhold.back().get());
}
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardFilterImpl)
ConvolutionBackwardFilterImpl::AlgoCUDNN*
ConvolutionBackwardFilterImpl::AlgoPack::cudnn_from_enum(
cudnnConvolutionBwdFilterAlgo_t algo) {
......
......@@ -6,13 +6,16 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/cuda/convolution/helper.h"
#include <unordered_map>
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/cuda/convolution/helper.h"
namespace megdnn {
namespace cuda {
......@@ -23,141 +26,134 @@ namespace cuda {
* All the algo impls should try to support non-contiguous batch dim, for group
* conv execution.
*/
class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm {
protected:
~AlgoBase() = default;
public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
HandleImpl *handle;
const TensorLayout *src_layout, *diff_layout, *grad_layout;
CanonizedFilterMeta grad_filter_meta;
ConvolutionBackwardFilterImpl *opr;
std::string to_string() const;
void init_desc(convolution::CUDNNBwdFilterDescs &desc) const {
desc.set(*src_layout, *diff_layout, grad_filter_meta,
opr->param());
}
SizeArgs(ConvolutionBackwardFilterImpl *opr,
const TensorLayout &src, const TensorLayout &diff,
const TensorLayout &grad);
SizeArgs(ConvolutionBackwardFilterImpl* opr,
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta);
convolution::ForwardSizeArgs as_fwd_args() const {
return {handle, src_layout, grad_layout, grad_filter_meta,
diff_layout};
}
};
struct ExecArgs: public SizeArgs {
const TensorND *src_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;
ExecArgs(ConvolutionBackwardFilterImpl *opr,
_megdnn_tensor_in src,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs &args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0;
virtual void exec(const ExecArgs &args) const = 0;
bool is_available_wk(const SizeArgs &args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
class ConvolutionBackwardFilterImpl::AlgoBase : public Algorithm {
protected:
~AlgoBase() = default;
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
public:
enum class AlgoType : uint32_t {
CUDA_CUDNN,
CUDA_MATMUL,
CUDA_CHANWISE,
CUDA_BFLOAT16,
CUDA_GROUP_CONV_GENERAL,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
HandleImpl* handle;
const TensorLayout *src_layout, *diff_layout, *grad_layout;
CanonizedFilterMeta grad_filter_meta;
ConvolutionBackwardFilterImpl* opr;
std::string to_string() const;
void init_desc(convolution::CUDNNBwdFilterDescs& desc) const {
desc.set(*src_layout, *diff_layout, grad_filter_meta, opr->param());
}
AlgoBase& check_workspace(
const SizeArgs &args, const Workspace &workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd filter algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
}
virtual bool is_cudnn() const {
return false;
SizeArgs(ConvolutionBackwardFilterImpl* opr, const TensorLayout& src,
const TensorLayout& diff, const TensorLayout& grad);
SizeArgs(ConvolutionBackwardFilterImpl* opr, const TensorLayout& src,
const TensorLayout& diff, const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta);
convolution::ForwardSizeArgs as_fwd_args() const {
return {handle, src_layout, grad_layout, grad_filter_meta,
diff_layout};
}
};
struct ExecArgs : public SizeArgs {
const TensorND *src_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;
ExecArgs(ConvolutionBackwardFilterImpl* opr, _megdnn_tensor_in src,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs& args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void exec(const ExecArgs& args) const = 0;
bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd filter algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
}
virtual bool is_cudnn() const { return false; }
};
class ConvolutionBackwardFilterImpl::AlgoCUDNN final : public AlgoBase {
bool m_is_reproducible;
const char *m_name;
cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum;
CudnnAlgoPack::Attr m_attr;
public:
public:
AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum)
: m_cudnn_enum(cudnn_enum) {
megdnn_assert(CudnnAlgoPack::conv_bwd_flt_algos().find(cudnn_enum) !=
CudnnAlgoPack::conv_bwd_flt_algos().end());
m_attr = CudnnAlgoPack::conv_bwd_flt_algos().at(cudnn_enum);
}
AlgoCUDNN(bool is_reproducible, const char *name,
cudnnConvolutionBwdFilterAlgo_t cudnn_enum):
m_is_reproducible(is_reproducible),
m_name(name),
m_cudnn_enum(cudnn_enum)
{}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_reproducible() const override { return m_attr.is_reproducible; }
bool is_reproducible() const override {
return m_is_reproducible;
}
const char* name() const override { return m_attr.name.c_str(); }
const char* name() const override {
return m_name;
}
cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { return m_cudnn_enum; }
cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const {
return m_cudnn_enum;
}
bool is_cudnn() const override { return true; }
bool is_cudnn() const override {
return true;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN)
std::string param() const override {
std::string ret;
serialize_write_pod(m_cudnn_enum, ret);
return ret;
}
};
//! im2col and matmul, with dilation
class ConvolutionBackwardFilterImpl::AlgoMatmul final: public AlgoBase {
template<typename T>
static void exec_internal(const ExecArgs &args);
class ConvolutionBackwardFilterImpl::AlgoMatmul final : public AlgoBase {
template <typename T>
static void exec_internal(const ExecArgs& args);
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override {
return "MATMUL";
}
bool is_reproducible() const override {
return true;
}
const char* name() const override { return "MATMUL"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
};
class ConvolutionBackwardFilterImpl::AlgoChanwise final: public AlgoBase {
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
class ConvolutionBackwardFilterImpl::AlgoChanwise final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override {
return "CHANNEL_WISE";
}
bool is_reproducible() const override {
return true;
}
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
};
class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase {
......@@ -169,6 +165,13 @@ public:
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)
std::string param() const override {
std::string ret;
serialize_write_pod(m_algorithm, ret);
return ret;
}
private:
std::string m_name;
......@@ -180,57 +183,62 @@ private:
};
//! implement group conv by another algo
class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase {
AlgoBase *m_impl;
class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final
: public AlgoBase {
AlgoBase* m_impl;
std::string m_name;
public:
AlgoGroupConvGeneral(AlgoBase *impl);
public:
AlgoGroupConvGeneral(AlgoBase* impl);
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_reproducible() const override { return m_impl->is_reproducible(); }
const char* name() const override {
return m_name.c_str();
}
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg,
TensorLayout& diff_pg);
bool is_reproducible() const override {
return m_impl->is_reproducible();
}
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
static void modify_size_args(SizeArgs &args,
TensorLayout &src_pg, TensorLayout &diff_pg);
std::string param() const override {
std::string ret;
serialize_write_pod(m_impl, ret);
return ret;
}
};
class ConvolutionBackwardFilterImpl::AlgoPack {
class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj {
// defined in cudnn.cpp
void fill_cudnn_algos();
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator = (const AlgoPack &) = delete;
AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack();
public:
AlgoPack();
std::vector<AlgoCUDNN> cudnn;
AlgoMatmul matmul;
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
std::vector<AlgoCUDNN> cudnn;
AlgoMatmul matmul;
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
std::vector<AlgoBase*>
std::vector<AlgoBase*>
//! all algorithms
all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos,
bfloat16_algos;
non_cudnn_algos, bfloat16_algos;
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo);
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo);
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
} // namespace cuda
} // namespace megdnn
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -42,7 +42,7 @@ ConvolutionBackwardFilterImpl::AlgoBFloat16::float_args(
change_dtype(fgrad);
opr->param() = args.opr->param();
opr->param().compute_mode = Param::ComputeMode::DEFAULT;
opr->execution_policy() = {m_algorithm};
opr->execution_policy() = {m_algorithm->info()};
return SizeArgs(opr, fsrc, fdiff, fgrad);
}
......@@ -107,7 +107,7 @@ void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec(
conv_back_filter_opr->param() = args.opr->param();
conv_back_filter_opr->param().compute_mode =
Param::ComputeMode::DEFAULT;
conv_back_filter_opr->execution_policy() = {m_algorithm};
conv_back_filter_opr->execution_policy() = {m_algorithm->info()};
conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor,
cvter.workspace());
}
......
......@@ -80,35 +80,9 @@ void ConvolutionBackwardFilterImpl::AlgoCUDNN::exec(
}
void ConvolutionBackwardFilterImpl::AlgoPack::fill_cudnn_algos() {
#define V1(v) #v
#define V(v) V1(v)
#define DEF_ALGO(NAME, REPROD) \
cudnn.push_back({ \
REPROD, #NAME \
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \
"." V(CUDNN_PATCHLEVEL), \
NAME})
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false);
#if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1)
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, true);
#if CUDNN_MAJOR >= 6
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true);
#endif
#endif
#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1)
#pragma message "not latest cudnn"
#endif
#undef DEF_ALGO
#undef V
#undef V1
for(auto&& algo : CudnnAlgoPack::conv_bwd_flt_algos()) {
cudnn.push_back(algo.first);
}
}
// vim: syntax=cpp.doxygen
......@@ -70,7 +70,7 @@ ConvolutionForwardImpl::conv_bias_extra_data(const TensorLayout& src,
conv_param.dilate_w,
0,
conv_param.compute_mode};
ret.convbias_opr->execution_policy() = {this->execution_policy().algorithm};
ret.convbias_opr->execution_policy() = {this->execution_policy().algo};
return ret;
}
......@@ -183,15 +183,6 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter,
CUDNNBwdDataDescs desc;
args.init_desc(desc);
//disable, segfault in megbrain, need further investigate.
#if 0
bool is_heuristic_success= convolution::
PerformanceModelBackwardData::get_algo_backward_data_success(
args, desc, workspace_limit_in_bytes, &algo);
if (is_heuristic_success) {
return sm_algo_pack.cudnn_from_enum(algo);
}
#endif
#if CUDNN_MAJOR >= 7
int max_count = 0;
cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(
......
......@@ -24,14 +24,6 @@ class ConvolutionForwardImpl: public ConvolutionForward {
const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) override;
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src,
const TensorLayout &filter,
const TensorLayout &dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst,
......@@ -60,99 +52,129 @@ class ConvolutionForwardImpl: public ConvolutionForward {
TensorLayout bias_layout;
TensorLayout z_layout;
};
private:
ConvBiasExtraData conv_bias_extra_data(const TensorLayout&,
const TensorLayout&,
const TensorLayout&);
};
class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
public:
using ConvolutionBackwardData::ConvolutionBackwardData;
void exec(_megdnn_tensor_in filter,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter,
const TensorLayout &diff,
const TensorLayout &grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_limit_in_bytes, bool reproducible);
size_t get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) override;
const char* get_algorithm_set_name() const override;
class AlgoBase;
class AlgoCUDNN;
class AlgoMatmul;
class AlgoChanwise;
class AlgoChanwiseSmall;
class AlgoGroupConvGeneral;
class AlgoBFloat16;
class AlgoPack;
static const AlgoPack& algo_pack() {
return sm_algo_pack;
}
private:
static AlgoPack sm_algo_pack;
ConvBiasExtraData conv_bias_extra_data(const TensorLayout&,
const TensorLayout&,
const TensorLayout&);
};
class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
public:
using ConvolutionBackwardFilter::ConvolutionBackwardFilter;
void exec(_megdnn_tensor_in src,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src,
const TensorLayout &diff,
const TensorLayout &grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& gradk,
const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes,
bool reproducible);
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad) override;
const char* get_algorithm_set_name() const override;
class AlgoBase;
class AlgoCUDNN;
class AlgoMatmul;
class AlgoChanwise;
class AlgoGroupConvGeneral;
class AlgoBFloat16;
class AlgoPack;
static const AlgoPack& algo_pack() {
return sm_algo_pack;
}
class ConvolutionBackwardDataImpl : public ConvolutionBackwardData {
public:
using ConvolutionBackwardData::ConvolutionBackwardData;
void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& filter, const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_limit_in_bytes, bool reproducible) {
return get_algorithm_heuristic(filter, filter_meta, diff, grad,
workspace_limit_in_bytes, reproducible)
->info();
}
size_t get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) override;
const char* get_algorithm_set_name() const override;
class AlgoBase;
class AlgoCUDNN;
class AlgoMatmul;
class AlgoChanwise;
class AlgoChanwiseSmall;
class AlgoGroupConvGeneral;
class AlgoBFloat16;
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
private:
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible);
static AlgoPack sm_algo_pack;
};
private:
static AlgoPack sm_algo_pack;
class ConvolutionBackwardFilterImpl : public ConvolutionBackwardFilter {
public:
using ConvolutionBackwardFilter::ConvolutionBackwardFilter;
void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad) override;
AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes,
bool reproducible) {
return get_algorithm_heuristic(src, diff, grad, grad_meta,
workspace_limit_in_bytes, reproducible)
->info();
}
const char* get_algorithm_set_name() const override;
class AlgoBase;
class AlgoCUDNN;
class AlgoMatmul;
class AlgoChanwise;
class AlgoGroupConvGeneral;
class AlgoBFloat16;
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
private:
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes,
bool reproducible);
static AlgoPack sm_algo_pack;
};
} // namespace cuda
} // namespace megdnn
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -39,8 +39,14 @@ Convolution3DBackwardDataImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&i);
}
megdnn_assert(all_algos_data == all_algos.data());
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DBackwardDataImpl)
Convolution3DBackwardDataImpl::AlgoCUDNN*
Convolution3DBackwardDataImpl::AlgoPack::cudnn_from_enum(
cudnnConvolutionBwdDataAlgo_t algo) {
......@@ -96,7 +102,7 @@ std::string Convolution3DBackwardDataImpl::AlgoBase::SizeArgs::to_string() const
fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], fm.spatial[2],
diff_layout->to_string().c_str(),
grad_layout->to_string().c_str(),
fm.padding[0], fm.padding[1], fm.padding[2],
fm.padding[0], fm.padding[1], fm.padding[2],
fm.stride[0], fm.stride[1], fm.stride[2],
fm.dilation[0], fm.dilation[1] ,fm.dilation[2],
!fm.should_flip,
......
......@@ -6,13 +6,16 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/cuda/convolution3d/helper.h"
#include <unordered_map>
#include "src/cuda/convolution3d/helper.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
namespace megdnn {
namespace cuda {
......@@ -23,170 +26,174 @@ namespace cuda {
* All the algo impls should try to support non-contiguous batch dim, for group
* conv execution.
*/
class Convolution3DBackwardDataImpl::AlgoBase: public Algorithm {
protected:
~AlgoBase() = default;
public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
HandleImpl *handle;
CanonizedFilterMeta filter_meta;
const TensorLayout *diff_layout, *grad_layout;
Convolution3DBackwardDataImpl *opr;
std::string to_string() const;
void init_desc(convolution3d::CUDNNBwdDataDescs &desc) const {
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
}
SizeArgs(Convolution3DBackwardDataImpl *opr,
const TensorLayout &filter, const TensorLayout &diff,
const TensorLayout &grad);
SizeArgs(Convolution3DBackwardDataImpl *opr,
const CanonizedFilterMeta &filter, const TensorLayout &diff,
const TensorLayout &grad);
convolution3d::ForwardSizeArgs as_fwd_args() const {
return {handle, grad_layout, filter_meta, diff_layout,
opr->param().data_type};
}
};
struct ExecArgs: public SizeArgs {
const TensorND *filter_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;
ExecArgs(Convolution3DBackwardDataImpl *opr,
_megdnn_tensor_in filter,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs &args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0;
virtual void exec(const ExecArgs &args) const = 0;
bool is_available_wk(const SizeArgs &args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(
const SizeArgs &args, const Workspace &workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd data algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
class Convolution3DBackwardDataImpl::AlgoBase : public Algorithm {
protected:
~AlgoBase() = default;
public:
enum class AlgoType : uint32_t {
CUDA_GROUP_CONV_GENERAL,
CUDA_CUDNN,
CUDA_CHANWISE,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
HandleImpl* handle;
CanonizedFilterMeta filter_meta;
const TensorLayout *diff_layout, *grad_layout;
Convolution3DBackwardDataImpl* opr;
std::string to_string() const;
void init_desc(convolution3d::CUDNNBwdDataDescs& desc) const {
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
}
virtual bool is_cudnn() const {
return false;
SizeArgs(Convolution3DBackwardDataImpl* opr, const TensorLayout& filter,
const TensorLayout& diff, const TensorLayout& grad);
SizeArgs(Convolution3DBackwardDataImpl* opr,
const CanonizedFilterMeta& filter, const TensorLayout& diff,
const TensorLayout& grad);
convolution3d::ForwardSizeArgs as_fwd_args() const {
return {handle, grad_layout, filter_meta, diff_layout,
opr->param().data_type};
}
};
struct ExecArgs : public SizeArgs {
const TensorND *filter_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;
ExecArgs(Convolution3DBackwardDataImpl* opr, _megdnn_tensor_in filter,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs& args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void exec(const ExecArgs& args) const = 0;
bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd data algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
}
virtual bool is_cudnn() const { return false; }
};
class Convolution3DBackwardDataImpl::AlgoCUDNN final : public AlgoBase {
bool m_is_reproducible;
const char *m_name;
cudnnConvolutionBwdDataAlgo_t m_cudnn_enum;
CudnnAlgoPack::Attr m_attr;
public:
public:
AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum)
: m_cudnn_enum(cudnn_enum) {
megdnn_assert(CudnnAlgoPack::conv3d_bwd_data_algos().find(cudnn_enum) !=
CudnnAlgoPack::conv3d_bwd_data_algos().end());
m_attr = CudnnAlgoPack::conv3d_bwd_data_algos().at(cudnn_enum);
}
AlgoCUDNN(bool is_reproducible, const char *name,
cudnnConvolutionBwdDataAlgo_t cudnn_enum):
m_is_reproducible(is_reproducible),
m_name(name),
m_cudnn_enum(cudnn_enum)
{}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_reproducible() const override { return m_attr.is_reproducible; }
bool is_reproducible() const override {
return m_is_reproducible;
}
const char* name() const override { return m_attr.name.c_str(); }
const char* name() const override {
return m_name;
}
cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; }
cudnnConvolutionBwdDataAlgo_t cudnn_enum() const {
return m_cudnn_enum;
}
bool is_cudnn() const override { return true; }
bool is_cudnn() const override {
return true;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN)
std::string param() const override {
std::string ret;
serialize_write_pod(m_cudnn_enum, ret);
return ret;
}
};
class Convolution3DBackwardDataImpl::AlgoChanwise final: public AlgoBase {
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
class Convolution3DBackwardDataImpl::AlgoChanwise final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override {
return "CHANNEL_WISE";
}
bool is_reproducible() const override {
return true;
}
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
};
//! implement group conv by another algo
class Convolution3DBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase {
AlgoBase *m_impl;
class Convolution3DBackwardDataImpl::AlgoGroupConvGeneral final
: public AlgoBase {
AlgoBase* m_impl;
std::string m_name;
public:
AlgoGroupConvGeneral(AlgoBase *impl);
public:
AlgoGroupConvGeneral(AlgoBase* impl);
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override {
return m_name.c_str();
}
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override {
return m_impl->is_reproducible();
}
bool is_reproducible() const override { return m_impl->is_reproducible(); }
static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg,
TensorLayout& grad_pg);
static void modify_size_args(SizeArgs &args,
TensorLayout &diff_pg, TensorLayout &grad_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
std::string param() const override {
std::string ret;
serialize_write_pod(m_impl, ret);
return ret;
}
};
class Convolution3DBackwardDataImpl::AlgoPack {
class Convolution3DBackwardDataImpl::AlgoPack : NonCopyableObj {
// defined in cudnn.cpp
void fill_cudnn_algos();
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator = (const AlgoPack &) = delete;
AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack();
public:
AlgoPack();
std::vector<AlgoCUDNN> cudnn;
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<AlgoCUDNN> cudnn;
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<AlgoBase*>
std::vector<AlgoBase*>
//! all algorithms
all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos;
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
} // namespace cuda
} // namespace megdnn
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -80,27 +80,9 @@ void Convolution3DBackwardDataImpl::AlgoCUDNN::exec(
}
void Convolution3DBackwardDataImpl::AlgoPack::fill_cudnn_algos() {
#define V1(v) #v
#define V(v) V1(v)
#define DEF_ALGO(NAME, REPROD) \
cudnn.push_back({ \
REPROD, #NAME \
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \
"." V(CUDNN_PATCHLEVEL), \
NAME})
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true);
#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1)
#pragma message "not latest cudnn"
#endif
#undef DEF_ALGO
#undef V
#undef V1
for (auto&& algo : CudnnAlgoPack::conv3d_bwd_data_algos()) {
cudnn.push_back(algo.first);
}
}
// vim: syntax=cpp.doxygen
......@@ -17,7 +17,7 @@ using namespace cuda;
Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() {
non_cudnn_algos.push_back(&chanwise);
non_cudnn_algos.push_back(&inplace_matmul);
non_cudnn_algos.push_back(&inplace_matmul);
all_algos.push_back(&chanwise); // prefer chanwise
fill_cudnn_algos();
......@@ -41,8 +41,14 @@ Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() {
}
megdnn_assert(all_algos_data == all_algos.data());
non_cudnn_algos.push_back(all_algos.rbegin()[0]); //group inplace_matmul
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DBackwardFilterImpl)
Convolution3DBackwardFilterImpl::AlgoCUDNN*
Convolution3DBackwardFilterImpl::AlgoPack::cudnn_from_enum(
cudnnConvolutionBwdFilterAlgo_t algo) {
......@@ -99,9 +105,9 @@ Convolution3DBackwardFilterImpl::AlgoBase::SizeArgs::to_string() const {
"pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s",
src_layout->to_string().c_str(),
diff_layout->to_string().c_str(),
fm.group, fm.ocpg, fm.icpg,
fm.group, fm.ocpg, fm.icpg,
fm.spatial[0], fm.spatial[1], fm.spatial[2],
fm.padding[0], fm.padding[1], fm.padding[2],
fm.padding[0], fm.padding[1], fm.padding[2],
fm.stride[0], fm.stride[1], fm.stride[2],
fm.dilation[0], fm.dilation[1], fm.dilation[2],
!fm.should_flip,
......
......@@ -6,198 +6,198 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/cuda/convolution3d/helper.h"
#include <unordered_map>
#include "src/cuda/convolution3d/helper.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
namespace megdnn {
namespace cuda {
class Convolution3DBackwardFilterImpl::AlgoBase: public Algorithm {
protected:
~AlgoBase() = default;
public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
HandleImpl *handle;
const TensorLayout *src_layout, *diff_layout;
CanonizedFilterMeta grad_filter_meta;
Convolution3DBackwardFilterImpl *opr;
std::string to_string() const;
void init_desc(convolution3d::CUDNNBwdFilterDescs &desc) const {
desc.set(*src_layout, *diff_layout, grad_filter_meta,
opr->param());
}
SizeArgs(Convolution3DBackwardFilterImpl *opr,
const TensorLayout &src, const TensorLayout &diff,
const TensorLayout &grad);
SizeArgs(Convolution3DBackwardFilterImpl *opr,
const TensorLayout &src, const TensorLayout &diff,
const CanonizedFilterMeta &grad);
convolution3d::ForwardSizeArgs as_fwd_args() const {
return {handle, src_layout, grad_filter_meta, diff_layout,
opr->param().data_type};
}
};
struct ExecArgs: public SizeArgs {
const TensorND *src_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;
ExecArgs(Convolution3DBackwardFilterImpl *opr,
_megdnn_tensor_in src,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs &args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0;
virtual void exec(const ExecArgs &args) const = 0;
bool is_available_wk(const SizeArgs &args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd filter algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
class Convolution3DBackwardFilterImpl::AlgoBase : public Algorithm {
protected:
~AlgoBase() = default;
public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
enum class AlgoType : uint32_t {
CUDA_GROUP_CONV_GENERAL,
CUDA_CUDNN,
CUDA_INPLACE_MATMUL,
CUDA_CHANWISE,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
struct SizeArgs {
HandleImpl* handle;
const TensorLayout *src_layout, *diff_layout;
CanonizedFilterMeta grad_filter_meta;
Convolution3DBackwardFilterImpl* opr;
std::string to_string() const;
void init_desc(convolution3d::CUDNNBwdFilterDescs& desc) const {
desc.set(*src_layout, *diff_layout, grad_filter_meta, opr->param());
}
SizeArgs(Convolution3DBackwardFilterImpl* opr, const TensorLayout& src,
const TensorLayout& diff, const TensorLayout& grad);
SizeArgs(Convolution3DBackwardFilterImpl* opr, const TensorLayout& src,
const TensorLayout& diff, const CanonizedFilterMeta& grad);
virtual bool is_cudnn() const {
return false;
convolution3d::ForwardSizeArgs as_fwd_args() const {
return {handle, src_layout, grad_filter_meta, diff_layout,
opr->param().data_type};
}
};
struct ExecArgs : public SizeArgs {
const TensorND *src_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;
ExecArgs(Convolution3DBackwardFilterImpl* opr, _megdnn_tensor_in src,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs& args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void exec(const ExecArgs& args) const = 0;
bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd filter algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
}
virtual bool is_cudnn() const { return false; }
};
class Convolution3DBackwardFilterImpl::AlgoCUDNN final : public AlgoBase {
bool m_is_reproducible;
const char *m_name;
cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum;
CudnnAlgoPack::Attr m_attr;
public:
public:
AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum)
: m_cudnn_enum(cudnn_enum) {
megdnn_assert(CudnnAlgoPack::conv3d_bwd_flt_algos().find(cudnn_enum) !=
CudnnAlgoPack::conv3d_bwd_flt_algos().end());
m_attr = CudnnAlgoPack::conv3d_bwd_flt_algos().at(cudnn_enum);
}
AlgoCUDNN(bool is_reproducible, const char *name,
cudnnConvolutionBwdFilterAlgo_t cudnn_enum):
m_is_reproducible(is_reproducible),
m_name(name),
m_cudnn_enum(cudnn_enum)
{}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_reproducible() const override { return m_attr.is_reproducible; }
bool is_reproducible() const override {
return m_is_reproducible;
}
const char* name() const override { return m_attr.name.c_str(); }
const char* name() const override {
return m_name;
}
cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { return m_cudnn_enum; }
cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const {
return m_cudnn_enum;
}
bool is_cudnn() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN)
bool is_cudnn() const override {
return true;
}
std::string param() const override {
std::string ret;
serialize_write_pod(m_cudnn_enum, ret);
return ret;
}
};
class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final
: public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final: public AlgoBase {
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
const char* name() const override {
return "INPLACE_MATMUL";
}
bool is_reproducible() const override {
return false;
}
const char* name() const override { return "INPLACE_MATMUL"; }
bool is_reproducible() const override { return false; }
MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL)
};
class Convolution3DBackwardFilterImpl::AlgoChanwise final: public AlgoBase {
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
class Convolution3DBackwardFilterImpl::AlgoChanwise final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override {
return "CHANNEL_WISE";
}
bool is_reproducible() const override {
return true;
}
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
};
//! implement group conv by another algo
class Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase {
AlgoBase *m_impl;
class Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral final
: public AlgoBase {
AlgoBase* m_impl;
std::string m_name;
public:
AlgoGroupConvGeneral(AlgoBase *impl);
public:
AlgoGroupConvGeneral(AlgoBase* impl);
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override {
return m_name.c_str();
}
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override {
return m_impl->is_reproducible();
}
bool is_reproducible() const override { return m_impl->is_reproducible(); }
static void modify_size_args(SizeArgs &args,
TensorLayout &src_pg, TensorLayout &diff_pg);
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg,
TensorLayout& diff_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
std::string param() const override {
std::string ret;
serialize_write_pod(m_impl, ret);
return ret;
}
};
class Convolution3DBackwardFilterImpl::AlgoPack {
class Convolution3DBackwardFilterImpl::AlgoPack : NonCopyableObj {
// defined in cudnn.cpp
void fill_cudnn_algos();
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator = (const AlgoPack &) = delete;
AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack();
public:
AlgoPack();
std::vector<AlgoCUDNN> cudnn;
AlgoInplaceMatmul inplace_matmul;
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<AlgoCUDNN> cudnn;
AlgoInplaceMatmul inplace_matmul;
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<AlgoBase*>
std::vector<AlgoBase*>
//! all algorithms
all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos;
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo);
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo);
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
} // namespace cuda
} // namespace megdnn
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -66,29 +66,9 @@ void Convolution3DBackwardFilterImpl::AlgoCUDNN::exec(
}
void Convolution3DBackwardFilterImpl::AlgoPack::fill_cudnn_algos() {
#define V1(v) #v
#define V(v) V1(v)
#define DEF_ALGO(NAME, REPROD) \
cudnn.push_back({REPROD, \
#NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V( \
CUDNN_PATCHLEVEL), \
NAME})
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false);
#pragma message \
"fp16 dilated conv with odd size filter, only algo_1 works, need focus on doc"
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false);
#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1)
#pragma message "not latest cudnn"
#endif
#undef DEF_ALGO
#undef V
#undef V1
for (auto&& algo : CudnnAlgoPack::conv3d_bwd_flt_algos()) {
cudnn.push_back(algo.first);
}
}
// vim: syntax=cpp.doxygen
......@@ -21,13 +21,13 @@ Convolution3DForwardImpl::AlgoPack::AlgoPack() {
non_cudnn_algos.push_back(&a1x1x1);
all_algos.push_back(&chanwise);
fill_cudnn_algos();
for (auto &&i: cudnn) {
all_algos.push_back(&i);
all_algos.push_back(&i);
}
all_algos.push_back(&inplace_matmul);
all_algos.push_back(&a1x1x1);
all_algos.push_back(&a1x1x1);
all_algos.reserve(all_algos.size() * 2);
// add gconv algos by AlgoGroupConvGeneral
......@@ -42,10 +42,16 @@ Convolution3DForwardImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&i);
}
megdnn_assert(all_algos_data == all_algos.data());
non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group inplace_matmul
non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group inplace_matmul
non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group 1x1x1
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DForwardImpl)
Convolution3DForwardImpl::AlgoCUDNN*
Convolution3DForwardImpl::AlgoPack::cudnn_from_enum(
cudnnConvolutionFwdAlgo_t algo) {
......@@ -99,7 +105,7 @@ std::string Convolution3DForwardImpl::AlgoBase::SizeArgs::to_string() const {
"src=%s, filter=%u{%u,%u,%u,%u,%u}, dst=%s, "
"pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s",
src_layout->to_string().c_str(),
fm.group, fm.ocpg, fm.icpg,
fm.group, fm.ocpg, fm.icpg,
fm.spatial[0], fm.spatial[1], fm.spatial[2],
dst_layout->to_string().c_str(),
fm.padding[0], fm.padding[1], fm.padding[2],
......
......@@ -78,30 +78,10 @@ void Convolution3DForwardImpl::AlgoCUDNN::exec(
cudnnGetErrorString(status), args.to_string().c_str());
}
/*!
 * \brief populate \c cudnn with the cuDNN 3d forward-convolution
 * algorithms this backend knows about.
 *
 * NOTE(review): this span reads like a rendered diff — the DEF_ALGO
 * macro path below and the CudnnAlgoPack-driven loop at the bottom are
 * the old and new forms of the same registration and are not expected
 * to coexist in a compiling file; confirm against the repository.
 */
void Convolution3DForwardImpl::AlgoPack::fill_cudnn_algos() {
//! V(x) expands x and then stringizes it, so the pushed algo name can
//! embed the cudnn version macros below
#define V1(v) #v
#define V(v) V1(v)
#define DEF_ALGO(NAME, REPROD) \
cudnn.push_back({ \
REPROD, #NAME \
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \
"." V(CUDNN_PATCHLEVEL), \
NAME})
// legacy registration: REPROD flag, stringized name + cudnn version, enum
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true);
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true);
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true);
// compile-time warning when building against an older cuDNN release
#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1)
#pragma message "not latest cudnn"
#endif
#undef DEF_ALGO
#undef V
#undef V1
// new-style registration: enumerate the enum keys of the table exposed
// by CudnnAlgoPack instead of hard-coding each algorithm here
for (auto&& algo : CudnnAlgoPack::conv3d_fwd_algos()) {
cudnn.push_back(algo.first);
}
}
// vim: syntax=cpp.doxygen
此差异已折叠。
......@@ -10,6 +10,7 @@
*/
#pragma once
#include <unordered_map>
#include "megdnn/basic_types.h"
#include "megdnn/oprs/nn.h"
#include "src/cuda/cudnn_with_check.h"
......@@ -27,7 +28,7 @@ class TensorDesc {
public:
TensorDesc();
//! default layout is nchw
void set(const TensorLayout& layout, const param::Convolution::Format =
void set(const TensorLayout& layout, const param::Convolution::Format =
param::Convolution::Format::NCHW);
~TensorDesc();
cudnnTensorDescriptor_t desc;
......@@ -103,9 +104,52 @@ class Conv3DDesc {
cudnnConvolutionDescriptor_t desc;
};
/*!
 * \brief static registry of the cuDNN convolution algorithms known to
 * this backend, keyed by the cuDNN algorithm enums.
 *
 * NOTE(review): stray "} // namespace ..." lines appear in the middle of
 * this class body — this text looks like a rendered diff with old closing
 * braces interleaved among the new members; confirm the real declaration
 * in the repository before treating this span as compilable.
 */
class CudnnAlgoPack {
public:
//! per-algorithm attributes stored in each table
struct Attr {
//! human-readable algorithm name
std::string name;
//! whether the algorithm is marked reproducible
bool is_reproducible;
};
//! 2d convolution backward-data algorithms
static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, Attr>
conv_bwd_data_algos();
} // namespace cuda
} // namespace megdnn
//! 2d convolution backward-filter algorithms
static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, Attr>
conv_bwd_flt_algos();
//! 2d convolution forward algorithms
static const std::unordered_map<cudnnConvolutionFwdAlgo_t, Attr>
conv_fwd_algos();
//! 3d convolution backward-data algorithms
static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, Attr>
conv3d_bwd_data_algos();
//! 3d convolution backward-filter algorithms
static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, Attr>
conv3d_bwd_flt_algos();
//! 3d convolution forward algorithms
static const std::unordered_map<cudnnConvolutionFwdAlgo_t, Attr>
conv3d_fwd_algos();
};
} // namespace cuda
} // namespace megdnn
namespace std {

//! Make the cuDNN algorithm enums usable as keys of std::unordered_map
//! by hashing their underlying integral value.
template <>
struct hash<cudnnConvolutionBwdDataAlgo_t> {
    std::size_t operator()(const cudnnConvolutionBwdDataAlgo_t& algo) const {
        return std::hash<uint32_t>()(static_cast<uint32_t>(algo));
    }
};

template <>
struct hash<cudnnConvolutionBwdFilterAlgo_t> {
    std::size_t operator()(const cudnnConvolutionBwdFilterAlgo_t& algo) const {
        return std::hash<uint32_t>()(static_cast<uint32_t>(algo));
    }
};

template <>
struct hash<cudnnConvolutionFwdAlgo_t> {
    std::size_t operator()(const cudnnConvolutionFwdAlgo_t& algo) const {
        return std::hash<uint32_t>()(static_cast<uint32_t>(algo));
    }
};

} // namespace std
// vim: syntax=cpp.doxygen
......@@ -19,7 +19,12 @@ using OprImpl = DeformableConvBackwardDataImpl;
/*!
 * \brief register every deformable-conv backward-data algorithm and
 * build the desc -> algorithm lookup table used by get-algo-from-desc.
 */
OprImpl::AlgoPack::AlgoPack() {
    all_algos.push_back(&algo_matmul);
    // index each registered algorithm by its serialized descriptor
    for (size_t i = 0; i < all_algos.size(); ++i) {
        auto* entry = all_algos[i];
        m_all_algos_map.emplace(entry->info().desc, entry);
    }
}
MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvBackwardDataImpl)
OprImpl::AlgoPack OprImpl::sm_algo_pack;
......
......@@ -13,11 +13,15 @@
#include "megdnn/oprs.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/cuda/handle.h"
#include "src/cuda/deformable_conv/opr_impl.h"
#include <unordered_map>
namespace megdnn {
namespace cuda {
......@@ -26,6 +30,10 @@ protected:
~AlgoBase() = default;
public:
enum class AlgoType : uint32_t {
CUDA_MATMUL,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
DeformableConvBackwardDataImpl* opr;
......@@ -107,17 +115,18 @@ public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AlgoMatmul"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
};
/*!
 * \brief container owning all algorithms of
 * DeformableConvBackwardDataImpl, plus a desc -> algorithm map.
 *
 * NOTE(review): two class heads appear below (one with hand-deleted copy
 * operations, one deriving from NonCopyableObj) — characteristic of a
 * rendered diff interleaving the old and new declarations; confirm the
 * real declaration in the repository.
 */
class DeformableConvBackwardDataImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;
class DeformableConvBackwardDataImpl::AlgoPack : NonCopyableObj {
//! desc -> algorithm lookup table, filled by the constructor
AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack();
AlgoMatmul algo_matmul;
//! all algorithms
std::vector<AlgoBase*> all_algos;
//! read-only view of the desc -> algorithm map
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
} // namespace cuda
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册