提交 c33a7173 编写于 作者: M Megvii Engine Team

feat(dnn): repalce is_reproducible with algo attribute in opencl, cpu, rocm and cuda

GitOrigin-RevId: 86dead0a1103cd6ca287a3eeea2893bf723a81d5
上级 249a116b
......@@ -99,6 +99,27 @@ enum class AlgoDataType : uint32_t {
class Algorithm {
public:
static constexpr uint32_t INVALID_ALGO_TYPE = static_cast<uint32_t>(-1);
/**
* \brief the attribe of the algo, such as REPRODUCIBLE, NAIVE
*
*/
enum class Attribute : uint32_t {
/**
* \brief whether the execution result is
* reproducible across multiple runs.
*/
REPRODUCIBLE = 1 << 0,
/**
* \brief whether the algo is naive
* Mark algorithms with simple implementation as NAIVE, so we can filter
* these algorithms to speed up fastrun.
* */
NAIVE = 1 << 1,
};
/**
* \brief Algorithm information, we can get real algo from
* AlgorithmInfo::Info::Desc
......@@ -121,7 +142,7 @@ public:
} desc;
//! algorithm name
std::string name;
bool is_reproducible;
Attribute attribute;
bool valid() const { return desc.valid(); }
void reset() { desc.reset(); }
//! desc donate the algo
......@@ -131,18 +152,20 @@ public:
virtual ~Algorithm() = default;
/**
* \brief whether the execution result is
* reproducible across multiple runs.
* \brief get the attribute of the algo
*/
virtual bool is_reproducible() const = 0;
virtual Attribute attribute() const = 0;
virtual const char* name() const = 0;
//! serialized param
virtual std::string param() const { return {}; }
virtual uint32_t type() const = 0;
bool contain_attribute(const Attribute& attr) const;
Handle::HandleType handle_type() const { return m_handle_type; }
Info info() const {
return {{handle_type(), type(), param()}, name(), is_reproducible()};
return {{handle_type(), type(), param()}, name(), attribute()};
}
Info::Desc desc() const { return {handle_type(), type(), param()}; }
......@@ -524,6 +547,7 @@ protected:
} // namespace detail
using Algorithm = detail::Algorithm;
using AlgoAttribute = Algorithm::Attribute;
using ExecutionPolicy = detail::ExecutionPolicy;
} // namespace megdnn
......
......@@ -19,7 +19,9 @@ namespace aarch64 {
class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV8F16STRD2"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......
......@@ -23,7 +23,9 @@ using FallbackConvBiasImpl = fallback::ConvBiasImpl;
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV8F32STRD2"; }
bool usable(const NCBKernSizeParam& param,
......
......@@ -25,7 +25,9 @@ class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase {
static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index);
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "S8MATMUL"; }
bool usable(const NCBKernSizeParam& param,
......
......@@ -25,7 +25,9 @@ class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase {
static void kimpl(const NCBKernParam& param, const NCBKernIndex&);
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "QU8MATMUL"; }
bool usable(const NCBKernSizeParam& param,
......
......@@ -21,7 +21,9 @@ namespace aarch64 {
class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_F32K8X12X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -32,7 +34,9 @@ public:
class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -43,7 +47,9 @@ public:
class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_F32K4X16X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -54,7 +60,9 @@ public:
class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_F32_MK4_4x16"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -76,7 +84,9 @@ public:
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_F16_K8X24X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -87,7 +97,9 @@ public:
class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_F16_MK8_8X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -102,7 +114,9 @@ public:
#if __ARM_FEATURE_DOTPROD
class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "AARCH64_INT8X8X32_K8X12X4_DOTPROD";
}
......@@ -115,7 +129,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD";
}
......@@ -129,7 +145,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -143,7 +161,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -156,7 +176,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -169,7 +191,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -182,7 +206,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -194,7 +220,9 @@ public:
class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -207,7 +235,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "AARCH64_INT8X8X16_MK4_16X12X4";
}
......@@ -223,7 +253,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "AARCH64_INT8X8X16_MK4_K8X8X8";
}
......@@ -239,7 +271,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -253,7 +287,9 @@ public:
class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -265,7 +301,9 @@ public:
class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -278,7 +316,9 @@ public:
#if __ARM_FEATURE_DOTPROD
class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "AARCH64_QUINT8_K8X8X4_DOTPROD";
}
......@@ -291,7 +331,9 @@ public:
class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -306,7 +348,9 @@ public:
class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......
......@@ -29,6 +29,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_FP16)
};
......@@ -45,6 +48,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP16)
};
......@@ -60,7 +66,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP16)
};
......@@ -76,6 +84,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_FP16)
};
......@@ -84,7 +95,9 @@ class ConvBiasImpl::AlgoF16Direct final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "F16DIRECT"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -104,7 +117,9 @@ class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "F16STRD1"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......
......@@ -29,6 +29,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_FP32)
};
......@@ -45,6 +48,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP32)
};
......@@ -61,6 +67,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_FP32)
};
......@@ -77,6 +86,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F54_FP32)
};
......@@ -93,6 +105,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP32)
};
......@@ -111,6 +126,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32)
};
......@@ -128,6 +146,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32)
};
......@@ -145,6 +166,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32)
};
......@@ -154,7 +178,9 @@ class ConvBiasImpl::AlgoF32Direct final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "F32DIRECT"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -172,7 +198,9 @@ class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "F32STRD1"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -190,7 +218,9 @@ class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "F32STRD2"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -209,7 +239,9 @@ class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase {
public:
AlgoF32DirectNCHW44() {}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "F32_CONV_NCHW44_DIRECT"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -228,7 +260,9 @@ class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase {
public:
AlgoF32DirectNCHWNCHW44() {}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "F32_CONV_NCHW_NCHW44"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -246,7 +280,9 @@ class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "F32_CHANNEL_WISE_NCHW44"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......
......@@ -20,7 +20,9 @@ namespace arm_common {
class ConvBiasImpl::AlgoS8DirectStride1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "S8STRD1"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -39,7 +41,9 @@ public:
class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "S8STRD2"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -56,7 +60,9 @@ public:
class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase {
public:
AlgoS8DirectNCHW44() {}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "S8_NCHW44_DIRECT"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -73,7 +79,9 @@ public:
class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase {
public:
AlgoS8DirectNCHWNCHW44() {}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "S8_CONV_NCHW_NCHW44"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -89,7 +97,9 @@ public:
class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "S8_CHAN_WISE_STRD1_NCHW44"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -104,7 +114,9 @@ public:
class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "S8_CHAN_WISE_STRD2_NCHW44"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -121,7 +133,9 @@ public:
class ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMDOTS8_NCHW_NCHW44"; }
bool usable(const NCBKernSizeParam&,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -138,7 +152,9 @@ public:
class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMDOTS8STRD1"; }
bool usable(const NCBKernSizeParam&,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -155,7 +171,9 @@ public:
class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMDOTS8STRD2"; }
bool usable(const NCBKernSizeParam&,
......@@ -174,7 +192,9 @@ class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase {
public:
AlgoDotS8Direct_NCHW44() {}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMDOTS8DIRECT_NCHW44"; }
bool usable(const NCBKernSizeParam&,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -205,6 +225,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_S8)
};
......@@ -223,6 +246,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32)
};
......@@ -241,7 +267,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8)
};
......
......@@ -29,7 +29,9 @@ class ConvBiasImpl::AlgoI8x8x16Direct final : public AlgoBase {
const CpuNDRange& workspace_ids);
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "I8816DIRECT"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -45,7 +47,9 @@ public:
class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase {
public:
AlgoS8x8x16DirectNCHW44() {}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "S8x8x16_NCHW44_DIRECT"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -71,7 +75,9 @@ class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase {
const CpuNDRange& workspace_ids);
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "I8816STRD2"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -87,7 +93,9 @@ public:
class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "I8816STRD2F2"; }
bool usable(const NCBKernSizeParam& param,
......@@ -105,10 +113,10 @@ public:
class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final
: public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44";
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(
......@@ -126,7 +134,9 @@ class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase {
public:
AlgoI8x8x16DirectNCHWNCHW44() {}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "I8816_CONV_NCHW_NCHW44"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......
......@@ -20,7 +20,9 @@ namespace arm_common {
class ConvBiasImpl::AlgoQU8DirectStride1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "QU8STRD1"; }
bool usable(const NCBKernSizeParam& param,
......@@ -38,7 +40,9 @@ public:
class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "QU8STRD2"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -55,7 +59,9 @@ public:
class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMDOTU8STRD1"; }
bool usable(const NCBKernSizeParam& param,
......@@ -73,7 +79,9 @@ public:
class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMDOTU8STRD2"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......
......@@ -23,7 +23,9 @@ namespace arm_common {
class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final
: public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "AARCH32_I8x8x32_DECONV_STRIDE1";
}
......@@ -42,7 +44,9 @@ public:
class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final
: public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "AARCH32_I8x8x32_DECONV_STRIDE2";
}
......
......@@ -22,7 +22,9 @@ namespace arm_common {
class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final
: public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE1";
}
......@@ -42,7 +44,9 @@ public:
class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final
: public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE2";
}
......
......@@ -18,7 +18,9 @@ namespace arm_common {
class ElemwiseImpl::AlgoBinary##case final \
: public ElemwiseImpl::AlgoBase { \
mutable std::string m_name; \
bool is_reproducible() const override { return true; } \
AlgoAttribute attribute() const override { \
return AlgoAttribute::REPRODUCIBLE; \
} \
const char* name() const override { \
if (m_name.empty()) { \
m_name = megdnn_mangle( \
......
......@@ -11,8 +11,8 @@
*/
#pragma once
#include "src/fallback/elemwise/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
namespace megdnn {
namespace arm_common {
class ElemwiseImpl final : public fallback::ElemwiseImpl {
......
......@@ -18,7 +18,9 @@ namespace arm_common {
class ElemwiseImpl::AlgoTernaryFma3##case final \
: public ElemwiseImpl::AlgoBase { \
mutable std::string m_name; \
bool is_reproducible() const override { return true; } \
AlgoAttribute attribute() const override { \
return AlgoAttribute::REPRODUCIBLE; \
} \
const char* name() const override { \
if (m_name.empty()) { \
m_name = megdnn_mangle( \
......
......@@ -16,7 +16,9 @@ namespace arm_common {
class ElemwiseImpl::AlgoUnary final : public ElemwiseImpl::AlgoBase {
mutable std::string m_name;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
if (m_name.empty()) {
m_name = megdnn_mangle(ssprintf("Elemwise::AlgoUnary"));
......
......@@ -19,7 +19,9 @@ namespace arm_common {
class MatrixMulImpl::AlgoInt8x8x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_COMMON_INT8X8X16"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -31,7 +33,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -45,7 +49,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -60,7 +66,9 @@ public:
#if __ARM_FEATURE_DOTPROD
class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -78,7 +86,9 @@ protected:
~AlgoF32Gemv() = default;
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_COMMON_F32_GEMV"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -91,7 +101,9 @@ public:
class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -106,7 +118,9 @@ public:
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class MatrixMulImpl::AlgoF16Gemv : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_COMMON_F16_GEMV"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -121,7 +135,9 @@ public:
class MatrixMulImpl::AlgoGevm : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_COMMON_GEVM"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......
......@@ -22,7 +22,9 @@ using AlgoBase = PoolingImpl::AlgoBase;
class PoolingImpl::AlgoFilterxModexStride1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
};
const char* name() const override { return "ARM_POOLING_STRIDE1"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
......@@ -30,14 +32,18 @@ public:
class PoolingImpl::AlgoFilter2ModexStride2 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
};
const char* name() const override { return "ARM_POOLING_STRIDE2"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
};
class PoolingImpl::AlgoFilter3MaxStride2 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
};
const char* name() const override { return "ARM_POOLING_FILTER3_MAX"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
......@@ -45,7 +51,9 @@ public:
class PoolingImpl::AlgoFilter3AverageStride2 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
};
const char* name() const override { return "ARM_POOLING_FILTER3_AVERAGE"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
......@@ -53,7 +61,9 @@ public:
class PoolingImpl::AlgoFilter4MaxStride2 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
};
const char* name() const override { return "ARM_POOLING_FILTER4_MAX"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
......@@ -61,7 +71,9 @@ public:
class PoolingImpl::AlgoFilter5MaxStride2 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
};
const char* name() const override { return "ARM_POOLING_FILTER5_MAX"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
......@@ -69,7 +81,9 @@ public:
class PoolingImpl::AlgoInt8Filter2MaxStride2 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
};
const char* name() const override { return "ARM_POOLING_INT8_FILTER2X2"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
......@@ -77,7 +91,9 @@ public:
class PoolingImpl::AlgoInt8Filter3MaxStride2 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
};
const char* name() const override { return "ARM_POOLING_INT8_FILTER3X3"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
......@@ -85,7 +101,9 @@ public:
class PoolingImpl::AlgoFilter3ModexStridexNCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
};
const char* name() const override { return "ARM_POOLING_FILTER3_MODEX_STRIDEX_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
......@@ -93,7 +111,9 @@ public:
class PoolingImpl::AlgoFilter2ModexStridexNCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
};
const char* name() const override { return "ARM_POOLING_FILTER2_MODEX_STRIDEX_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
......@@ -101,7 +121,9 @@ public:
class PoolingImpl::AlgoFilter4ModexStridexNCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
};
const char* name() const override { return "ARM_POOLING_FILTER4_MODEX_STRIDEX_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
......@@ -109,14 +131,18 @@ public:
class PoolingImpl::AlgoFilter5ModexStridexNCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
};
const char* name() const override { return "ARM_POOLING_FILTER5_MODEX_STRIDEX_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
};
class PoolingImpl::AlgoFp32ModexStridexNCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
};
const char* name() const override { return "ARM_POOLING_FP32_MODEX_STRIDEX_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
......
......@@ -24,7 +24,9 @@ class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase {
static void kimpl(const NCBKernParam& param, const NCBKernIndex&);
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "S8MATMUL"; }
bool usable(const NCBKernSizeParam& param,
......
......@@ -24,7 +24,9 @@ class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase {
static void kimpl(const NCBKernParam& param, const NCBKernIndex&);
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "QU8MATMUL"; }
bool usable(const NCBKernSizeParam& param,
......
......@@ -21,7 +21,9 @@ namespace armv7 {
class MatrixMulImpl::AlgoF32 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV7_F32"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -32,7 +34,9 @@ public:
class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -43,7 +47,9 @@ public:
class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV7_F32_MK4_4x8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -56,7 +62,9 @@ public:
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class MatrixMulImpl::AlgoF16K4x16x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH32_F16_K4X16X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -66,7 +74,9 @@ public:
};
class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH32_F16_MK8_4X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -79,7 +89,9 @@ public:
#if __ARM_FEATURE_DOTPROD
class MatrixMulImpl::AlgoInt8x8x32K6x8x4 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH32_INT8_K6X8X4"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -90,7 +102,9 @@ public:
class MatrixMulImpl::AlgoQuint8DotK4x8x4 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "AARCH32_QUINT8_K4X8X4"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -101,7 +115,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "AARCH32_INT8_MK4_8X4X4_DOTPROD";
}
......@@ -124,7 +140,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x32K4x2x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV7_INT8X8X32_K4X2X16"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -136,7 +154,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x32K4x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV7_INT8X8X32_K4X8X8"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -148,7 +168,9 @@ public:
class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV7_QUINT8_K4X8X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -159,7 +181,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x16K4x2x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV7_INT8X8X16_K4X2X16"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -171,7 +195,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x16K4x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV7_INT8X8X16_K4X8X8"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -183,7 +209,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x16K8x8x4 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV7_INT8X8X16_K8X8X4"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -195,7 +223,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -207,7 +237,9 @@ public:
class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV7_INT16X16X32_K12X4X1"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -219,7 +251,9 @@ public:
class MatrixMulImpl::AlgoInt16x16x32MK8_4x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV7_INT16X16X32_MK8_4X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -231,7 +265,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV7_INT8X8X32_MK4_4X2X16"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......
/**
* \file dnn/src/common/algo_base.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/common/algo_base.h"
#include "src/common/utils.h"
using namespace megdnn;
bool Algorithm::contain_attribute(const Attribute& attr) const {
return bool(attribute() & attr);
}
// vim: syntax=cpp.doxygen
......@@ -21,6 +21,8 @@
namespace megdnn {
MEGDNN_DEF_ENUM_CLASS_BIT_OPR(AlgoAttribute)
#define MEGDNN_DECL_ALGO_TYPE(_type) \
uint32_t type() const override { \
return static_cast<std::underlying_type<AlgoType>::type>( \
......
......@@ -82,7 +82,7 @@ template <typename Opr>
typename Opr::Algorithm* get_reproducible_algo(typename Opr::AlgoBase* algo,
bool reproducible) {
if (reproducible) {
if (algo->is_reproducible()) {
if (algo->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
return algo;
}
} else {
......@@ -113,7 +113,7 @@ typename Opr::Algorithm* get_reproducible_algo(
}
}
if (i->is_available(args)) {
if (!i->is_reproducible())
if (!i->contain_attribute(AlgoAttribute::REPRODUCIBLE))
available_but_not_reproducible = true;
}
}
......
......@@ -54,6 +54,7 @@
#include <mutex>
#include <string>
#include <thread>
#include <type_traits>
#if defined(_WIN32)
#include <windows.h>
......@@ -683,6 +684,62 @@ inline void* get_origin_ptr(const TensorND* tensor, void* ptr) {
return static_cast<void*>(static_cast<dt_byte*>(ptr) -
tensor->layout.span().low_byte);
}
template <typename T>
class EnumClassBit {
std::underlying_type_t<T> m_val;
constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {}
public:
constexpr EnumClassBit(T v)
: m_val(static_cast<std::underlying_type_t<T>>(v)) {}
constexpr operator T() const { return static_cast<T>(m_val); }
constexpr explicit operator bool() const { return m_val; }
#define DEF_OPR(op) \
constexpr EnumClassBit operator op(const EnumClassBit& rhs) const { \
return m_val op rhs.m_val; \
}
DEF_OPR(&)
DEF_OPR(|)
DEF_OPR (^)
constexpr EnumClassBit operator~() const { return ~m_val; }
#undef DEF_OPR
};
#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \
inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \
return ::megdnn::EnumClassBit<cls>(x) \
op ::megdnn::EnumClassBit<cls>(y); \
} \
inline constexpr ::megdnn::EnumClassBit<cls> operator op( \
::megdnn::EnumClassBit<cls> x, cls y) { \
return x op ::megdnn::EnumClassBit<cls>(y); \
}
#define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \
inline constexpr cls& operator op##=(cls& x, cls y) { \
x = x op ::megdnn::EnumClassBit<cls>(y); \
return x; \
}
#define MEGDNN_DEF_ENUM_CLASS_BIT_OPR(cls) \
_MEGDNN_DECBO_SINGLE_OPR(cls, &) \
_MEGDNN_DECBO_SINGLE_OPR(cls, |) \
_MEGDNN_DECBO_SINGLE_OPR(cls, ^) \
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, &) \
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, |) \
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, ^) \
inline constexpr ::megdnn::EnumClassBit<cls> operator~(cls x) { \
return ~::megdnn::EnumClassBit<cls>(x); \
}
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -14,6 +14,7 @@
#include <unordered_map>
#include "megdnn/oprs.h"
#include "megdnn/oprs/base.h"
#include "src/common/utils.h"
#include "src/cuda/batch_conv_bias/opr_impl.h"
#include "src/cuda/handle.h"
......@@ -67,7 +68,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
......@@ -89,7 +91,9 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "BATCH_CONV_BIAS_INT8_NCHW4_GEMM_DOTPROD";
......@@ -104,7 +108,9 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "BATCH_CONV_BIAS_INT8_NCHW4_IMPLICIT_GEMM_PRECOMP_DOTPROD";
......
......@@ -71,7 +71,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -94,7 +95,9 @@ public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override;
void exec(const ExecArgs& args) const final;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute()const override{
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "BRUTE_FORCE"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_BRUTE_FORCE)
......@@ -109,7 +112,9 @@ public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override;
void exec(const ExecArgs& args) const final;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "CUBLAS"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS)
};
......@@ -120,7 +125,9 @@ public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override;
void exec(const ExecArgs& args) const final;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "CUBLAS_LT"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT)
};
......@@ -132,7 +139,9 @@ public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override;
void exec(const ExecArgs& args) const final;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "INT8x8x32"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_INT8X8X32)
};
......
......@@ -130,7 +130,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
......@@ -165,7 +166,13 @@ public:
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return m_attr.is_reproducible; }
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_attr.is_reproducible) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; }
......@@ -198,7 +205,9 @@ public:
}
return m_name.c_str();
}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
......@@ -219,8 +228,10 @@ public:
}
return m_name.c_str();
}
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
private:
mutable std::string m_name;
......@@ -238,8 +249,10 @@ public:
}
return m_name.c_str();
}
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_INT8X8X32)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
private:
mutable std::string m_name;
......@@ -260,7 +273,13 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return m_attr.is_reproducible; }
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_attr.is_reproducible) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
const char* name() const override { return m_name.c_str(); }
......@@ -298,8 +317,10 @@ public:
}
return m_name.c_str();
}
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
private:
mutable std::string m_name;
......@@ -327,8 +348,10 @@ public:
std::vector<SearchItem> get_subopr_list(
const TensorLayoutArray& layouts,
const OperatorBase* opr) const override;
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
......@@ -347,8 +370,10 @@ public:
}
return m_name.c_str();
}
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL_INT8X8X32)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
private:
bool need_src_unroll(const SizeArgs& args) const;
......@@ -378,7 +403,10 @@ public:
const TensorLayoutArray& layouts,
const OperatorBase* opr) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL)
private:
......@@ -397,7 +425,13 @@ public:
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return m_impl->is_reproducible(); }
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg,
TensorLayout& dst_pg, TensorLayout& bias_pg);
......@@ -423,7 +457,9 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override { return "QUINT4x4x32_WMMA"; }
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
......@@ -444,7 +480,9 @@ public:
const char* name() const override {
return "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM";
}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
template <typename BiasVisitor>
static void dispatch_nonlinear_mode(
const int8_t* d_src, const int8_t* d_filter,
......@@ -486,7 +524,9 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
size_t get_preprocess_workspace_in_bytes(
const SizeArgs& args) const override;
SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
......@@ -524,7 +564,9 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
template <typename BiasVisitor>
static void dispatch_nonlinear_mode(
const int8_t* d_src, const int8_t* d_filter,
......@@ -561,7 +603,6 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8)
std::string param() const override {
......@@ -569,6 +610,9 @@ public:
serialize_write_pod(m_mma_tile_size, ret);
return ret;
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
......@@ -590,7 +634,6 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8)
std::string param() const override {
......@@ -598,6 +641,9 @@ public:
serialize_write_pod(m_mma_tile_size, ret);
return ret;
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
private:
MMATileSize m_mma_tile_size;
......@@ -617,7 +663,6 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8)
std::string param() const override {
......@@ -625,6 +670,9 @@ public:
serialize_write_pod(m_mma_tile_size, ret);
return ret;
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
private:
MMATileSize m_mma_tile_size;
......@@ -655,7 +703,9 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
static std::string to_string(AlgoParam algo_param);
size_t get_preprocess_workspace_in_bytes(
const SizeArgs& args) const override;
......@@ -690,7 +740,10 @@ public:
const OperatorBase* opr) const override;
const char* name() const override { return "CONVBIAS_BFLOAT16"; }
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)
private:
......
......@@ -82,7 +82,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
......@@ -115,10 +116,14 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return m_attr.is_reproducible; }
const char* name() const override { return m_attr.name.c_str(); }
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_attr.is_reproducible) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; }
bool is_cudnn() const override { return true; }
......@@ -146,8 +151,10 @@ public:
const OperatorBase* opr) const override;
const char* name() const override { return "MATMUL"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase {
......@@ -157,8 +164,10 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final : public AlgoBase {
......@@ -168,8 +177,10 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "CHANNEL_WISE_SMALL"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase {
......@@ -185,7 +196,10 @@ public:
const char* name() const override {
return "CONVOLUTION_BACKWARD_DATD_BFLOAT16";
}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
......@@ -207,11 +221,17 @@ public:
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return m_impl->is_reproducible(); }
static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg,
TensorLayout& grad_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
std::string param() const override {
std::string ret;
......
......@@ -81,7 +81,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
......@@ -114,9 +115,14 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return m_attr.is_reproducible; }
const char* name() const override { return m_attr.name.c_str(); }
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_attr.is_reproducible) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { return m_cudnn_enum; }
......@@ -145,8 +151,10 @@ public:
const OperatorBase* opr) const override;
const char* name() const override { return "MATMUL"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
class ConvolutionBackwardFilterImpl::AlgoChanwise final : public AlgoBase {
......@@ -156,8 +164,10 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase {
......@@ -173,7 +183,11 @@ public:
const char* name() const override {
return "CONVOLUTION_BACKWARD_FILTER_BFLOAT16";
}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)
private:
......@@ -195,12 +209,17 @@ public:
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return m_impl->is_reproducible(); }
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg,
TensorLayout& diff_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
std::string param() const override {
std::string ret;
......
......@@ -31,6 +31,7 @@ protected:
~AlgoBase() = default;
public:
enum class AlgoType : uint32_t {
CUDA_DEFAULT,
};
......@@ -65,7 +66,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) const {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -86,7 +88,10 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override;
const char* name() const override { return "DEFAULT"; }
void exec(const ExecArgs&) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
std::vector<SearchItem> get_subopr_list(
const TensorLayoutArray& layouts,
const OperatorBase* opr) const override;
......
......@@ -38,7 +38,6 @@ public:
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
HandleImpl* handle;
......@@ -79,7 +78,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -111,9 +111,14 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return m_attr.is_reproducible; }
const char* name() const override { return m_attr.name.c_str(); }
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_attr.is_reproducible) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; }
......@@ -135,8 +140,10 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
//! implement group conv by another algo
......@@ -154,10 +161,15 @@ public:
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return m_impl->is_reproducible(); }
static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg,
TensorLayout& grad_pg);
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
std::string param() const override {
......
......@@ -72,7 +72,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -104,7 +105,13 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return m_attr.is_reproducible; }
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_attr.is_reproducible) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
const char* name() const override { return m_attr.name.c_str(); }
......@@ -128,7 +135,9 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "INPLACE_MATMUL"; }
bool is_reproducible() const override { return false; }
AlgoAttribute attribute() const override {
return static_cast<AlgoAttribute>(0);
}
MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL)
};
......@@ -139,7 +148,9 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
};
......@@ -158,7 +169,13 @@ public:
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return m_impl->is_reproducible(); }
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg,
TensorLayout& diff_pg);
......@@ -201,3 +218,4 @@ public:
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -77,7 +77,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -102,7 +103,9 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "1x1x1"; }
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_1X1X1)
};
......@@ -120,8 +123,13 @@ public:
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return m_impl->is_reproducible(); }
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg,
TensorLayout& dst_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
......@@ -147,7 +155,13 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return m_attr.is_reproducible; }
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_attr.is_reproducible) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
const char* name() const override { return m_attr.name.c_str(); }
......@@ -172,7 +186,9 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "INPLACE_MATMUL"; }
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL)
};
......@@ -183,7 +199,9 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
};
......@@ -218,3 +236,4 @@ public:
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -83,7 +83,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -107,7 +108,9 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
std::vector<SearchItem> get_subopr_list(
const TensorLayoutArray& layouts,
......
......@@ -76,7 +76,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -99,7 +100,9 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
std::vector<SearchItem> get_subopr_list(
const TensorLayoutArray& layouts,
......
......@@ -71,7 +71,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -94,7 +95,9 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
std::vector<SearchItem> get_subopr_list(
const TensorLayoutArray& layouts,
......
......@@ -35,7 +35,6 @@ public:
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
LocalShareBackwardDataImpl* opr;
......@@ -63,7 +62,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -83,7 +83,9 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "LOCAL_SHARE_IMPLICIT_GEMM";
......@@ -100,7 +102,9 @@ public:
const SizeArgs& args) const;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "LOCAL_SHARE_BATCHED_MATMUL";
......
......@@ -62,7 +62,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -82,7 +83,9 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "LOCAL_SHARE_IMPLICIT_GEMM"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM)
......@@ -96,7 +99,9 @@ public:
const SizeArgs& args) const;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL)
......
......@@ -63,7 +63,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -85,7 +86,9 @@ public:
const SizeArgs& args) const;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE";
......@@ -102,7 +105,9 @@ public:
const SizeArgs& args) const;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE";
......@@ -118,7 +123,9 @@ public:
const SizeArgs& args) const;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL)
......
......@@ -86,7 +86,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) const {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -109,8 +110,10 @@ public:
}
const char* name() const override { return "CUBLAS"; }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
#if CUDA_VERSION >= 10000
......@@ -121,8 +124,10 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return "UINT4x4x32_WMMA"; }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
#endif
#if CUDA_VERSION >= 10010
......@@ -132,8 +137,10 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return "CUBLAS_LT"; }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
#endif
......@@ -146,8 +153,10 @@ public:
}
const char* name() const override { return "NAIVE"; }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
#if !MEGDNN_DISABLE_FLOAT16
......@@ -163,7 +172,10 @@ public:
const OperatorBase* opr) const override;
const char* name() const override { return "MATMUL_BFLOAT16"; }
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
......@@ -189,7 +201,9 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT)
std::string param() const override {
......@@ -214,7 +228,9 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K)
std::string param() const override {
......@@ -239,7 +255,9 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED)
std::string param() const override {
......
......@@ -66,7 +66,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) const {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -87,7 +88,9 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override;
const char* name() const override { return "DEFAULT"; }
virtual void exec(const ExecArgs&) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(fallback_BLAS)
};
......
......@@ -20,7 +20,9 @@ namespace fallback {
class ConvBiasImpl::AlgoNaive final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override{
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
const char* name() const override { return "FALLBACK_NAIVE"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -43,7 +45,9 @@ class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase {
public:
AlgoWinogradF32(MatrixMulImpl::AlgoBase* matmul_algo)
: m_matmul_algo{matmul_algo} {}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
......@@ -77,7 +81,9 @@ class ConvBiasImpl::AlgoWinogradF32_4x4 final : public AlgoBase {
public:
AlgoWinogradF32_4x4(MatrixMulImpl::AlgoBase* matmul_algo)
: m_matmul_algo{matmul_algo} {}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
......@@ -111,7 +117,9 @@ class ConvBiasImpl::AlgoWinogradQS8 final : public AlgoBase {
public:
AlgoWinogradQS8(MatrixMulImpl::AlgoBase* matmul_algo)
: m_matmul_algo{matmul_algo} {}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
......@@ -145,7 +153,9 @@ class ConvBiasImpl::AlgoWinogradQS8_8x8 final : public AlgoBase {
public:
AlgoWinogradQS8_8x8(MatrixMulImpl::AlgoBase* matmul_algo)
: m_matmul_algo{matmul_algo} {}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
......
......@@ -141,7 +141,6 @@ using BiasMode = ConvBiasForward::BiasMode;
}
#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(_algo_data_type) \
bool is_reproducible() const override { return true; } \
bool usable(const NCBKernSizeParam& param, \
AlgoSelectionStrategy algo_selection_strategy) const override; \
size_t get_workspace(const NCBKernSizeParam& param) const override; \
......
......@@ -29,7 +29,9 @@ public:
AlgoConv1x1(MatrixMulImpl::AlgoBase* matmul_algo, size_t oc_block_size)
: m_matmul_algo(matmul_algo), m_oc_block_size(oc_block_size) {}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return m_matmul_algo->attribute();
}
const char* name() const override {
if (m_name.empty()) {
......
......@@ -22,7 +22,9 @@ class ConvBiasImpl::AlgoConv1x1Gemv final : public AlgoBase {
public:
AlgoConv1x1Gemv() = default;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "CONV1x1_GEMV"; }
......
......@@ -27,7 +27,9 @@ public:
: m_matmul_algo(matmul_algo),
m_ohw_tile_size(ohw_tile_size) {}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return m_matmul_algo->attribute();
}
const char* name() const override {
if (m_name.empty()) {
m_name = ssprintf("IM2COLMATMUL:%s:%zu", m_matmul_algo->name(),
......
......@@ -320,10 +320,12 @@ public:
virtual bool is_preferred(const NCBKernSizeParam&) const {
return false;
}
bool usable_reproducible(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy,
bool reproducible = true) const {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
usable(param, algo_selection_strategy);
}
......
......@@ -75,7 +75,6 @@ void kern_naive(const ConvolutionBackwardDataImpl::NCBKernParam& p) {
class ConvolutionImpl::AlgoFallback final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "FALLBACK_ALGO"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -85,6 +84,10 @@ public:
SmallVector<NCBKern> dispatch_kern(
const NCBKernSizeParam& /*param*/) const override;
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::NAIVE};
}
......@@ -93,7 +96,6 @@ public:
class ConvolutionImpl::AlgoNaive final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "NAIVE_ALGO"; }
bool usable(const NCBKernSizeParam& /*param*/,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -103,6 +105,9 @@ public:
SmallVector<NCBKern> dispatch_kern(
const NCBKernSizeParam& /*param*/) const override;
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
ConvAlgoTypePack get_algo_type() const override {
auto support_data_type = static_cast<AlgoDataType>(
static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
......@@ -122,7 +127,6 @@ class ConvolutionImpl::AlgoDefault final : public AlgoBase {
public:
AlgoDefault(ConvBiasImpl::AlgoBase*);
bool is_reproducible() const override { return true; }
const char* name() const override { return m_name.c_str(); }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -144,6 +148,10 @@ public:
return get_kimpl(m_algorithm, param);
}
AlgoAttribute attribute() const override {
return m_algorithm->attribute();
}
//! select matmul to the highest preference
bool is_preferred(const NCBKernSizeParam& param) const override;
......@@ -169,7 +177,6 @@ private:
////////////////////////// convolutionbackwarddata ////////////////////////
class ConvolutionBackwardDataImpl::AlgoNaive final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "DeconvNaive"; }
bool usable(ConvolutionBackwardDataImpl* opr,
const NCBKernSizeParam& param) const override;
......@@ -178,12 +185,14 @@ public:
ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override;
bool is_naive() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
MEGDNN_DECL_ALGO_TYPE(FB_NAIVE)
};
class ConvolutionBackwardDataImpl::AlgoDirect final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "DeconvDirect"; }
bool usable(ConvolutionBackwardDataImpl* opr,
const NCBKernSizeParam& param) const override;
......@@ -191,12 +200,14 @@ public:
const NCBKernSizeParam& param) const override;
ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override;
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(FB_DIRECT)
};
class ConvolutionBackwardDataImpl::AlgoMatrixMul final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "DeconvMatmul"; }
bool usable(ConvolutionBackwardDataImpl* opr,
const NCBKernSizeParam& param) const override;
......@@ -205,6 +216,9 @@ public:
ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override;
bool is_preferred(const NCBKernSizeParam& param) const override;
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(FB_MATMUL)
};
......
......@@ -736,7 +736,7 @@ ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic(
for (auto i : ncb_1g_get_all_algorithms(param)) {
if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) {
if (reproducible) {
if (i->is_reproducible()) {
if (i->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
return i;
}
} else {
......
......@@ -237,10 +237,12 @@ public:
virtual bool is_preferred(const NCBKernSizeParam&) const {
return false;
}
bool usable_reproducible(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy,
bool reproducible = true) const {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
usable(param, algo_selection_strategy);
}
......@@ -422,7 +424,9 @@ protected:
bool usable_reproducible(ConvolutionBackwardDataImpl* opr,
const NCBKernSizeParam& param,
bool reproducible = true) const {
return (!reproducible || is_reproducible()) && usable(opr, param);
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
usable(opr, param);
}
virtual bool is_preferred(const NCBKernSizeParam&) const {
return false;
......
......@@ -21,18 +21,19 @@ namespace fallback {
class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "FB_F32_K8X12X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
MEGDNN_DECL_ALGO_TYPE(FB_F32K8x12x1)
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoGemv final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "FB_GEMV"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......@@ -40,6 +41,9 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
MEGDNN_DECL_ALGO_TYPE(FB_GEMV)
MEGDNN_OVERRIDE_MATMUL_DESC(
8, 16, 1, 4,
......@@ -54,7 +58,9 @@ public:
class MatrixMulImpl::AlgoNaive final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
const char* name() const override { return "FB_NAIVE"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
......
......@@ -225,7 +225,9 @@ public:
};
bool preferred_reproducible(const KernSizeParam& param,
bool reproducible = true) {
return (!reproducible || is_reproducible()) && preferred(param);
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
preferred(param);
};
virtual MatmulDescription matmul_description() const = 0;
......
......@@ -129,7 +129,7 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic(
auto algo = static_cast<HandleImpl*>(handle())
->default_batch_conv_bias_fwd_algo();
if (reproducible) {
megdnn_assert(algo->is_reproducible(),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
"require reproducible algorithm, but heuristic "
"algorithm(%s) is not "
"reproducible",
......
......@@ -250,7 +250,7 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
auto algo =
static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo();
if (reproducible) {
megdnn_assert(algo->is_reproducible(),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
"require reproducible algorithm, but heuristic "
"algorithm(%s) is not "
"reproducible",
......
......@@ -11,39 +11,50 @@
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/common/algo_base.h"
namespace megdnn {
namespace naive {
class DefaultConvolutionForwardAlgorithm final
: public megdnn::ConvolutionForward::Algorithm {
bool is_reproducible() const override { return true; }
const char* name() const override { return "DEFAULT"; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
uint32_t type() const override { return 0; }
const char* name() const override { return "DEFAULT"; }
};
class DefaultConvolutionBackwardDataAlgorithm final
: public megdnn::ConvolutionBackwardData::Algorithm {
bool is_reproducible() const override { return true; }
const char* name() const override { return "DEFAULT"; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
uint32_t type() const override { return 0; }
const char* name() const override { return "DEFAULT"; }
};
class DefaultConvolutionBackwardFilterAlgorithm final
: public megdnn::ConvolutionBackwardFilter::Algorithm {
bool is_reproducible() const override { return true; }
const char* name() const override { return "DEFAULT"; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
uint32_t type() const override { return 0; }
const char* name() const override { return "DEFAULT"; }
};
class DefaultConvBiasForwardAlgorithm final
: public megdnn::ConvBiasForward::Algorithm {
bool is_reproducible() const override { return true; }
const char* name() const override { return "DEFAULT"; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
uint32_t type() const override { return 0; }
const char* name() const override { return "DEFAULT"; }
};
class DefaultBatchConvBiasForwardAlgorithm final
: public megdnn::BatchConvBiasForward::Algorithm {
bool is_reproducible() const override { return true; }
const char* name() const override { return "DEFAULT"; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
uint32_t type() const override { return 0; }
const char* name() const override { return "DEFAULT"; }
};
} // namespace naive
......
......@@ -276,7 +276,7 @@ ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic(
auto algo =
static_cast<HandleImpl*>(handle())->default_conv_fwd_algo();
if (reproducible) {
megdnn_assert(algo->is_reproducible(),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
"require reproducible algorithm, but heuristic "
"algorithm(%s) is not "
"reproducible",
......@@ -308,7 +308,7 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
auto algo =
static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo();
if (reproducible) {
megdnn_assert(algo->is_reproducible(),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
"require reproducible algorithm, but heuristic "
"algorithm(%s) is not "
"reproducible",
......@@ -341,7 +341,7 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
auto algo =
static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo();
if (reproducible) {
megdnn_assert(algo->is_reproducible(),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
"require reproducible algorithm, but heuristic "
"algorithm(%s) is not "
"reproducible",
......
......@@ -10,25 +10,32 @@
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/common/algo_base.h"
namespace megdnn {
namespace naive {
class DefaultConvolution3DForwardAlgorithm final
: public megdnn::Convolution3DForward::Algorithm {
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
const char* name() const override { return "DEFAULT"; }
uint32_t type() const override { return 0; }
};
class DefaultConvolution3DBackwardDataAlgorithm final
: public megdnn::Convolution3DBackwardData::Algorithm {
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
const char* name() const override { return "DEFAULT"; }
uint32_t type() const override { return 0; }
};
class DefaultConvolution3DBackwardFilterAlgorithm final
: public megdnn::Convolution3DBackwardFilter::Algorithm {
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
const char* name() const override { return "DEFAULT"; }
uint32_t type() const override { return 0; }
};
......
......@@ -123,7 +123,7 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
bool reproducible) {
auto algo = static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo();
if (reproducible) {
megdnn_assert(algo->is_reproducible(),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
"require reproducible algorithm, but heuristic "
"algorithm(%s) is not "
"reproducible",
......@@ -156,7 +156,7 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic(
auto algo =
static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo();
if (reproducible) {
megdnn_assert(algo->is_reproducible(),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
"require reproducible algorithm, but heuristic "
"algorithm(%s) is not "
"reproducible",
......@@ -191,7 +191,7 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
auto algo = static_cast<HandleImpl*>(handle())
->default_conv3d_bwd_filter_algo();
if (reproducible) {
megdnn_assert(algo->is_reproducible(),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
"require reproducible algorithm, but heuristic "
"algorithm(%s) is not "
"reproducible",
......
......@@ -11,6 +11,7 @@
#pragma once
#include "megdnn/basic_types.h"
#include "megdnn/oprs/base.h"
#include "src/common/handle_impl.h"
#include "src/naive/convolution/algorithms.h"
#include "src/naive/matrix_mul/algorithms.h"
......
......@@ -11,27 +11,34 @@
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/common/algo_base.h"
namespace megdnn {
namespace naive {
class DefaultLocalShareForwardAlgorithm final
: public megdnn::LocalShareForward::Algorithm {
bool is_reproducible() const override { return true; }
const char* name() const override { return "DEFAULT"; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
uint32_t type() const override { return 0; }
const char* name() const override { return "DEFAULT"; }
};
class DefaultLocalShareBackwardDataAlgorithm final
: public megdnn::LocalShareBackwardData::Algorithm {
bool is_reproducible() const override { return true; }
const char* name() const override { return "DEFAULT"; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
uint32_t type() const override { return 0; }
const char* name() const override { return "DEFAULT"; }
};
class DefaultLocalShareBackwardFilterAlgorithm final
: public megdnn::LocalShareBackwardFilter::Algorithm {
bool is_reproducible() const override { return true; }
const char* name() const override { return "DEFAULT"; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
uint32_t type() const override { return 0; }
const char* name() const override { return "DEFAULT"; }
};
} // namespace naive
} // namespace megdnn
......
......@@ -166,7 +166,7 @@ LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic(
auto algo =
static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo();
if (reproducible) {
megdnn_assert(algo->is_reproducible(),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
"require reproducible algorithm, but heuristic "
"algorithm(%s) is not "
"reproducible",
......@@ -200,7 +200,7 @@ LocalShareBackwardDataImpl::get_algorithm_heuristic(
auto algo = static_cast<HandleImpl*>(handle())
->default_local_share_bwd_data_algo();
if (reproducible) {
megdnn_assert(algo->is_reproducible(),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
"require reproducible algorithm, but heuristic "
"algorithm(%s) is not "
"reproducible",
......@@ -234,7 +234,7 @@ LocalShareBackwardFilterImpl::get_algorithm_heuristic(
auto algo = static_cast<HandleImpl*>(handle())
->default_local_share_bwd_filter_algo();
if (reproducible) {
megdnn_assert(algo->is_reproducible(),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
"require reproducible algorithm, but heuristic "
"algorithm(%s) is not "
"reproducible",
......
......@@ -17,14 +17,18 @@ namespace naive {
class DefaultMatrixMulAlgorithm final
: public megdnn::MatrixMulForward::Algorithm {
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "DEFAULT"; }
uint32_t type() const override { return 0; }
};
class DefaultBatchedMatrixMulAlgorithm final
: public megdnn::BatchedMatrixMulForward::Algorithm {
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "DEFAULT"; }
uint32_t type() const override { return 0; }
};
......
......@@ -73,7 +73,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) const {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -96,7 +97,9 @@ public:
}
const char* name() const override { return "BLAS"; }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS)
};
......
......@@ -77,7 +77,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
......@@ -107,8 +108,13 @@ public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return m_is_reproducible; }
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_is_reproducible) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
const char* name() const override {
return "MIOpenConvolutionBackwardData";
......@@ -137,8 +143,10 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "MATMUL"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(ROCM_MATMUL)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase {
......@@ -148,8 +156,10 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(ROCM_CHANWISE)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
......
......@@ -34,6 +34,7 @@ public:
ROCM_CHANWISE
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; }
struct SizeArgs {
HandleImpl* handle;
......@@ -73,7 +74,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
......@@ -104,8 +106,13 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return m_is_reproducible; }
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_is_reproducible) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
const char* name() const override {
return "MIOpenConvolutionBackwardFilter";
}
......@@ -133,8 +140,10 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "MATMUL"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(ROCM_MATMUL)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
class ConvolutionBackwardFilterImpl::AlgoChanwise final : public AlgoBase {
......@@ -144,8 +153,10 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(ROCM_CHANWISE)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj {
......
......@@ -33,7 +33,6 @@ namespace rocm {
class ConvolutionForwardImpl::AlgoBase : public Algorithm {
protected:
~AlgoBase() = default;
public:
enum class AlgoType : uint32_t {
ROCM_MIOPEN,
......@@ -77,7 +76,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
......@@ -107,7 +107,13 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return m_is_reproducible; }
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_is_reproducible) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
const char* name() const override { return "MIOpenConvolutionForward"; }
......@@ -134,7 +140,9 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "MATMUL"; }
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(ROCM_MATMUL)
};
......@@ -146,8 +154,10 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "INPLACE_MATMUL"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(ROCM_INPLACE_MATMUL)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
//! optimized 1x1 conv
......@@ -161,8 +171,10 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "1x1"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(ROCM_1X1)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
//! optimized 1x1 conv when input data batchsize is larger than 32
......@@ -176,8 +188,10 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "LARGE_BATCH_1x1"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(ROCM_1X1_LARGE_BATCH)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
class ConvolutionForwardImpl::AlgoChanwise final : public AlgoBase {
......@@ -187,8 +201,10 @@ public:
void exec(const ExecArgs& args) const override;
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(ROCM_CHANWISE)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
};
class ConvolutionForwardImpl::AlgoPack : NonCopyableObj {
......
......@@ -73,7 +73,8 @@ public:
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) const {
return (!reproducible || is_reproducible()) &&
return (!reproducible ||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -96,7 +97,9 @@ public:
}
const char* name() const override { return "BLAS"; }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS)
};
......
......@@ -32,7 +32,9 @@ class ConvBiasImpl::AlgoDirect final : public AlgoBase {
const CpuNDRange& workspace_ids);
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP";
}
......@@ -68,7 +70,9 @@ class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase {
const CpuNDRange& workspace_ids);
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP";
}
......@@ -101,6 +105,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(X86_WINOGRAD_F63_8x8_F32)
};
......@@ -117,6 +124,9 @@ public:
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(X86_WINOGRAD_F23_8x8_F32)
};
......@@ -128,7 +138,9 @@ class ConvBiasImpl::AlgoMkldnnConv final : public AlgoBase {
public:
AlgoMkldnnConv() {}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "MKLDNN_CONV_FP32"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const override {
......
......@@ -21,7 +21,9 @@ class ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8 final : public AlgoBase {
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1";
}
......@@ -46,7 +48,9 @@ class ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8 final : public AlgoBase {
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE2";
}
......@@ -71,7 +75,9 @@ class ConvBiasImpl::AlgoDirectAvx2Stride1Int8 final : public AlgoBase {
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "X86_CONV_BIAS_DIRECT_AVX2_INT8_STRIDE1";
}
......@@ -96,7 +102,9 @@ class ConvBiasImpl::AlgoAVX2DirectConvStride2 final : public AlgoBase {
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "X86_CONV_BIAS_DIRECT_AVX2_INT8_STRIDE2";
}
......@@ -124,7 +132,9 @@ class ConvBiasImpl::AlgoMkldnnQint8 final : public AlgoBase {
public:
AlgoMkldnnQint8() {}
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "MKLDNN_INT8"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const override;
......@@ -163,7 +173,9 @@ class ConvBiasImpl::AlgoMkldnnMatmulQint8 final : public AlgoBase {
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "MKLDNN_MATMUL_INT8"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const override;
......
......@@ -20,11 +20,13 @@ namespace x86 {
class MatrixMulImpl::AlgoF32Blas : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "X86_F32_BLAS"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
MEGDNN_DECL_ALGO_TYPE(X86_F32_BLAS)
......@@ -33,7 +35,9 @@ public:
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
class MatrixMulImpl::AlgoF32MKLPackA : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "X86_F32_MKL_PACKA"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
......@@ -55,7 +59,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16 : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "X86_INT8X8X32_AVX2_2X4X16"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -66,7 +72,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2 : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "X86_INT8X8X32_AVX2_4X16X2"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -81,7 +89,9 @@ private:
const MatrixMulImpl::KernParam& kern_param);
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "X86_INT8X8X16_AVX2"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -97,7 +107,9 @@ private:
const MatrixMulImpl::KernParam& kern_param);
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "X86_INT8X8X16_SSE"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -109,7 +121,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2 : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "X86_INT8X8X32_SSE_4X8X2"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -120,7 +134,9 @@ public:
class MatrixMulImpl::AlgoF32MK8_8x8 : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "X86_F32MK8_8X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -133,7 +149,9 @@ public:
#if MEGDNN_X86_WITH_VNNI
class MatrixMulImpl::AlgoInt8x8x32Vnni : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "X86_INT8X8X32_VNNI"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -146,7 +164,9 @@ public:
#if MEGDNN_X86_WITH_MKL_DNN
class MatrixMulImpl::AlgoInt8x8x32Mkldnn : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "X86_INT8X8X32_MKLDNN"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
......
......@@ -420,7 +420,9 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts,
mgb_assert(palgo, "Unknown algo description");
ret.append("): algo=" + std::string(palgo->name()));
ret.append(ssprintf(" workspace=%.2fMiB reproducible=%d",
workspace / (1024 * 1024.0), palgo->is_reproducible()));
workspace / (1024 * 1024.0),
palgo->contain_attribute(
megdnn::AlgoAttribute::REPRODUCIBLE)));
mgb_log_debug("%s", ret.c_str());
megdnn_opr->execution_policy() = policy;
......@@ -715,8 +717,10 @@ AlgoChooser<Opr>::ExeContext::profile_single_algo(
if (!rst.valid())
return None;
return AlgoChooserProfileCache::ResultEntry{
palgo->name(), palgo->is_reproducible(), rst.val().time,
param.workspace};
palgo->name(),
palgo->contain_attribute(
megdnn::AlgoAttribute::REPRODUCIBLE),
rst.val().time, param.workspace};
}
template <typename Opr>
......
......@@ -2127,7 +2127,8 @@ TEST(TestOprDNN, HeuristicReproducible) {
megdnn_opr->get_algorithm_from_desc(algo);
mgb_assert(palgo, "Unknown algo description");
if (strategy == S::HEURISTIC_REPRODUCIBLE) {
EXPECT_TRUE(palgo->is_reproducible());
EXPECT_TRUE(palgo->contain_attribute(
megdnn::AlgoAttribute::REPRODUCIBLE));
}
algo_name0 = palgo->name();
}
......@@ -2338,7 +2339,9 @@ class MockAlgorithm : public megdnn::detail::Algorithm {
public:
MockAlgorithm(const char* name = "NotImportant") : m_name(name) {}
bool is_reproducible() const override { return true; }
Attribute attribute() const override {
return Attribute::REPRODUCIBLE;
}
const char* name() const override { return m_name; }
uint32_t type() const override {
return megdnn::detail::Algorithm::INVALID_ALGO_TYPE;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册