提交 ec1a99ac 编写于 作者: M Megvii Engine Team

refactor(mgb/dnn): replace reproducible with attribute

GitOrigin-RevId: d49015714c26432d9c965231d5c6c4d60efae4bd
上级 6af0299c
...@@ -105,6 +105,10 @@ public: ...@@ -105,6 +105,10 @@ public:
* *
*/ */
enum class Attribute : uint32_t { enum class Attribute : uint32_t {
/**
* \brief general algo.
*/
DEFAULT = 0,
/** /**
* \brief whether the execution result is * \brief whether the execution result is
...@@ -163,6 +167,8 @@ public: ...@@ -163,6 +167,8 @@ public:
bool contain_attribute(const Attribute& attr) const; bool contain_attribute(const Attribute& attr) const;
static std::string attribute_str(const Attribute& attr);
Handle::HandleType handle_type() const { return m_handle_type; } Handle::HandleType handle_type() const { return m_handle_type; }
Info info() const { Info info() const {
return {{handle_type(), type(), param()}, name(), attribute()}; return {{handle_type(), type(), param()}, name(), attribute()};
...@@ -311,6 +317,7 @@ class MultiAlgoOpr<Opr, 3> : public MultiAlgoOpr<Opr, -1> { ...@@ -311,6 +317,7 @@ class MultiAlgoOpr<Opr, 3> : public MultiAlgoOpr<Opr, -1> {
public: public:
using Algorithm = detail::Algorithm; using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info; using AlgorithmInfo = detail::Algorithm::Info;
using AlgoAttribute = detail::Algorithm::Attribute;
//! get all possible algorithm decriptions for the specified layouts //! get all possible algorithm decriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
...@@ -335,9 +342,9 @@ public: ...@@ -335,9 +342,9 @@ public:
const TensorLayout& p2, const TensorLayout& p2,
size_t workspace_limit_in_bytes = size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
bool reproducible = false) { const AlgoAttribute& attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes, return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes,
reproducible) attr)
->info(); ->info();
} }
...@@ -360,7 +367,7 @@ protected: ...@@ -360,7 +367,7 @@ protected:
const TensorLayout& p2, const TensorLayout& p2,
size_t workspace_limit_in_bytes = size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
bool reproducible = false) = 0; const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0;
}; };
//! specializae for nargs == 4 //! specializae for nargs == 4
...@@ -369,6 +376,7 @@ class MultiAlgoOpr<Opr, 4> : public MultiAlgoOpr<Opr, -1> { ...@@ -369,6 +376,7 @@ class MultiAlgoOpr<Opr, 4> : public MultiAlgoOpr<Opr, -1> {
public: public:
using Algorithm = detail::Algorithm; using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info; using AlgorithmInfo = detail::Algorithm::Info;
using AlgoAttribute = detail::Algorithm::Attribute;
//! get all possible algorithm decriptions for the specified layouts //! get all possible algorithm decriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
...@@ -394,9 +402,9 @@ public: ...@@ -394,9 +402,9 @@ public:
const TensorLayout& p2, const TensorLayout& p3, const TensorLayout& p2, const TensorLayout& p3,
size_t workspace_limit_in_bytes = size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
bool reproducible = false) { const AlgoAttribute& attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes, return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes,
reproducible) attr)
->info(); ->info();
} }
...@@ -419,7 +427,7 @@ protected: ...@@ -419,7 +427,7 @@ protected:
const TensorLayout& p2, const TensorLayout& p3, const TensorLayout& p2, const TensorLayout& p3,
size_t workspace_limit_in_bytes = size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
bool reproducible = false) = 0; const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0;
}; };
//! specializae for nargs == 5 //! specializae for nargs == 5
...@@ -428,6 +436,7 @@ class MultiAlgoOpr<Opr, 5> : public MultiAlgoOpr<Opr, -1> { ...@@ -428,6 +436,7 @@ class MultiAlgoOpr<Opr, 5> : public MultiAlgoOpr<Opr, -1> {
public: public:
using Algorithm = detail::Algorithm; using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info; using AlgorithmInfo = detail::Algorithm::Info;
using AlgoAttribute = detail::Algorithm::Attribute;
//! get all possible algorithm decriptions for the specified layouts //! get all possible algorithm decriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
...@@ -455,9 +464,9 @@ public: ...@@ -455,9 +464,9 @@ public:
const TensorLayout& p4, const TensorLayout& p4,
size_t workspace_limit_in_bytes = size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
bool reproducible = false) { const AlgoAttribute& attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, p3, p4, return get_algorithm_heuristic(p0, p1, p2, p3, p4,
workspace_limit_in_bytes, reproducible) workspace_limit_in_bytes, attr)
->info(); ->info();
} }
...@@ -482,7 +491,7 @@ protected: ...@@ -482,7 +491,7 @@ protected:
const TensorLayout& p4, const TensorLayout& p4,
size_t workspace_limit_in_bytes = size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
bool reproducible = false) = 0; const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0;
}; };
//! specializae for nargs == 8 //! specializae for nargs == 8
...@@ -491,6 +500,7 @@ class MultiAlgoOpr<Opr, 8> : public MultiAlgoOpr<Opr, -1> { ...@@ -491,6 +500,7 @@ class MultiAlgoOpr<Opr, 8> : public MultiAlgoOpr<Opr, -1> {
public: public:
using Algorithm = detail::Algorithm; using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info; using AlgorithmInfo = detail::Algorithm::Info;
using AlgoAttribute = detail::Algorithm::Attribute;
//! get all possible algorithm decriptions for the specified layouts //! get all possible algorithm decriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info( std::vector<AlgorithmInfo> get_all_algorithms_info(
...@@ -518,9 +528,9 @@ public: ...@@ -518,9 +528,9 @@ public:
const TensorLayout& p6, const TensorLayout& p7, const TensorLayout& p6, const TensorLayout& p7,
size_t workspace_limit_in_bytes = size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
bool reproducible = false) { const AlgoAttribute& attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7, return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7,
workspace_limit_in_bytes, reproducible) workspace_limit_in_bytes, attr)
->info(); ->info();
} }
...@@ -547,7 +557,7 @@ protected: ...@@ -547,7 +557,7 @@ protected:
const TensorLayout& p6, const TensorLayout& p7, const TensorLayout& p6, const TensorLayout& p7,
size_t workspace_limit_in_bytes = size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
bool reproducible = false) = 0; const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0;
}; };
} // namespace detail } // namespace detail
......
...@@ -15,8 +15,39 @@ ...@@ -15,8 +15,39 @@
using namespace megdnn; using namespace megdnn;
// X-macro listing every algorithm attribute enumerator; keeps the
// string table in attr_str() in sync with the Attribute enum.
#define FOREACH_ALGO_ATTRIBUTE(cb) \
    cb(DEFAULT)                    \
    cb(REPRODUCIBLE)               \
    cb(NAIVE)

namespace {
/*!
 * \brief map a single-bit Attribute value to its enumerator name.
 *
 * \param attr exactly one attribute bit (callers must pre-isolate bits;
 *        combined masks fall through to the fallback string).
 * \return static string naming the attribute, or a fallback for values
 *         not listed in FOREACH_ALGO_ATTRIBUTE.
 */
inline const char* attr_str(const AlgoAttribute& attr) {
#define cb(attr)              \
    case AlgoAttribute::attr: \
        return #attr;
    switch (attr) { FOREACH_ALGO_ATTRIBUTE(cb) }
#undef cb
    // Fixed copy-paste bug: fallback previously said "unknown arch",
    // but this helper describes an algorithm *attribute*, not an arch.
    return "unknown attribute";
}
}  // namespace
/*!
 * \brief render an attribute bit-set as a human readable string.
 *
 * Each set bit is translated via attr_str() and the names are joined
 * with " | "; an empty mask (DEFAULT) yields an empty string.
 */
std::string Algorithm::attribute_str(const Attribute& attr) {
    uint32_t remaining = static_cast<uint32_t>(attr);
    std::string desc;
    while (remaining) {
        // isolate the lowest set bit: x & ~(x - 1)
        uint32_t lowest = remaining & ~(remaining - 1);
        if (!desc.empty())
            desc.append(" | ");
        desc.append(attr_str(static_cast<Attribute>(lowest)));
        // clear the bit just rendered and continue with the rest
        remaining &= remaining - 1;
    }
    return desc;
}
bool Algorithm::contain_attribute(const Attribute& attr) const { bool Algorithm::contain_attribute(const Attribute& attr) const {
return bool(attribute() & attr); return attr == static_cast<Attribute>(attribute() & attr);
} }
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -32,7 +32,7 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { ...@@ -32,7 +32,7 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
} else { } else {
ret = opr->get_algorithm_info_heuristic( ret = opr->get_algorithm_info_heuristic(
std::forward<Args>(args)..., std::numeric_limits<size_t>::max(), std::forward<Args>(args)..., std::numeric_limits<size_t>::max(),
false).desc; AlgoAttribute::DEFAULT).desc;
} }
return static_cast<typename Opr::AlgoBase*>( return static_cast<typename Opr::AlgoBase*>(
opr->get_algorithm_from_desc(ret)); opr->get_algorithm_from_desc(ret));
...@@ -51,7 +51,7 @@ typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) { ...@@ -51,7 +51,7 @@ typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) {
return static_cast<typename Opr::AlgoBase*>( return static_cast<typename Opr::AlgoBase*>(
opr->get_algorithm_heuristic(std::forward<Args>(args)..., opr->get_algorithm_heuristic(std::forward<Args>(args)...,
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
false)); AlgoAttribute::DEFAULT));
} }
} }
...@@ -74,37 +74,34 @@ std::vector<typename Opr::Algorithm*> get_all_algorithms( ...@@ -74,37 +74,34 @@ std::vector<typename Opr::Algorithm*> get_all_algorithms(
} }
/*! /*!
* \brief a helper function to get a reproducible algorithm. If require a * \brief a helper function to get an algorithm with attribute. If require a
* reproducible algorithm, and the given algorithm is reproducible, return the * algorithm with specified attribute, and the given algorithm has that
* given algorithm. Otherwise return nullptr * attribute, return the given algorithm. Otherwise return nullptr
*/ */
template <typename Opr> template <typename Opr>
typename Opr::Algorithm* get_reproducible_algo(typename Opr::AlgoBase* algo, typename Opr::Algorithm* get_algo_with_attribute(typename Opr::AlgoBase* algo,
bool reproducible) { const AlgoAttribute& attr) {
if (reproducible) { if (algo->contain_attribute(attr)) {
if (algo->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
return algo;
}
} else {
return algo; return algo;
} }
return nullptr; return nullptr;
} }
template <typename Opr> template <typename Opr>
typename Opr::Algorithm* get_reproducible_algo( typename Opr::Algorithm* get_algo_with_attribute(
const std::vector<typename Opr::AlgoBase*>& algos, const std::vector<typename Opr::AlgoBase*>& algos,
const typename Opr::AlgoBase::SizeArgs& args, const typename Opr::AlgoBase::SizeArgs& args,
size_t workspace_limit_in_bytes, const char* name) { size_t workspace_limit_in_bytes, const char* name,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) {
size_t min_workspace_limit_in_bytes = std::numeric_limits<size_t>::max(); size_t min_workspace_limit_in_bytes = std::numeric_limits<size_t>::max();
bool available_but_limited_by_workspace = false; bool available_but_limited_by_workspace = false;
bool available_but_not_reproducible = false; bool available_but_without_attribute = false;
for (auto i : algos) { for (auto i : algos) {
if (i->is_available_reproducible(args, true, if (i->is_available_attribute(args, attr,
workspace_limit_in_bytes)) { workspace_limit_in_bytes)) {
return i; return i;
} }
if (i->is_available_reproducible(args)) { if (i->is_available_attribute(args)) {
if (i->get_workspace_in_bytes(args) > workspace_limit_in_bytes) { if (i->get_workspace_in_bytes(args) > workspace_limit_in_bytes) {
available_but_limited_by_workspace = true; available_but_limited_by_workspace = true;
min_workspace_limit_in_bytes = min_workspace_limit_in_bytes =
...@@ -113,20 +110,22 @@ typename Opr::Algorithm* get_reproducible_algo( ...@@ -113,20 +110,22 @@ typename Opr::Algorithm* get_reproducible_algo(
} }
} }
if (i->is_available(args)) { if (i->is_available(args)) {
if (!i->contain_attribute(AlgoAttribute::REPRODUCIBLE)) if (!i->contain_attribute(attr))
available_but_not_reproducible = true; available_but_without_attribute = true;
} }
} }
MEGDNN_MARK_USED_VAR(name); MEGDNN_MARK_USED_VAR(name);
if (available_but_limited_by_workspace) { if (available_but_limited_by_workspace) {
megdnn_throw(ssprintf( megdnn_throw(ssprintf(
"no reproducible %s algorithm: %s workspace limit %zu is " "no %s algorithm with attribute:%s : %s workspace limit %zu is "
"less than mini workspace limit %zu", "less than mini workspace limit %zu",
name, args.to_string().c_str(), workspace_limit_in_bytes, name, Algorithm::attribute_str(attr).c_str(),
args.to_string().c_str(), workspace_limit_in_bytes,
min_workspace_limit_in_bytes)); min_workspace_limit_in_bytes));
} else if (available_but_not_reproducible) { } else if (available_but_without_attribute) {
megdnn_throw(ssprintf("no reproducible %s algorithm", name)); megdnn_throw(ssprintf("no %s algorithm with attribute:%s", name,
Algorithm::attribute_str(attr).c_str()));
} else { } else {
megdnn_throw(ssprintf("no usable %s algorithm", name)); megdnn_throw(ssprintf("no usable %s algorithm", name));
} }
......
...@@ -65,12 +65,11 @@ public: ...@@ -65,12 +65,11 @@ public:
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
......
...@@ -22,21 +22,21 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic( ...@@ -22,21 +22,21 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_limit_in_bytes, const TensorLayout& dst, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, src, filter, bias, z, dst); AlgoBase::SizeArgs args(this, src, filter, bias, z, dst);
if (sm_algo_pack.int8_nchw4_gemm_dotprod.is_available_reproducible( if (sm_algo_pack.int8_nchw4_gemm_dotprod.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.int8_nchw4_gemm_dotprod; return &sm_algo_pack.int8_nchw4_gemm_dotprod;
} }
if (sm_algo_pack.int8_nchw4_implicit_gemm_dotprod.is_available_reproducible( if (sm_algo_pack.int8_nchw4_implicit_gemm_dotprod.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.int8_nchw4_implicit_gemm_dotprod; return &sm_algo_pack.int8_nchw4_implicit_gemm_dotprod;
} }
megdnn_throw( megdnn_throw(ssprintf(
ssprintf("no %s batch conv bias algorithm with args(%s) and " "no batch conv bias algorithm with attribute%s args(%s) and "
"workspace limit (%zu bytes)", "workspace limit (%zu bytes)",
reproducible ? "reproducible" : "usable", Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(),
args.to_string().c_str(), workspace_limit_in_bytes)); workspace_limit_in_bytes));
} }
std::vector<BatchConvBiasForwardImpl::Algorithm*> std::vector<BatchConvBiasForwardImpl::Algorithm*>
......
...@@ -48,7 +48,7 @@ protected: ...@@ -48,7 +48,7 @@ protected:
const TensorLayout& z, const TensorLayout& z,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
......
...@@ -68,12 +68,11 @@ public: ...@@ -68,12 +68,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) { bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) { const Workspace& workspace) {
......
...@@ -55,24 +55,21 @@ std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms( ...@@ -55,24 +55,21 @@ std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms(
Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, bool reproducible) { size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes);
AlgoBase::SizeArgs args(this, A, B, C); AlgoBase::SizeArgs args(this, A, B, C);
if (sm_algo_pack.cublas.is_available_reproducible(args, reproducible)) { if (sm_algo_pack.cublas.is_available_attribute(args, attr)) {
return &sm_algo_pack.cublas; return &sm_algo_pack.cublas;
} }
#if CUDA_VERSION >= 10010 #if CUDA_VERSION >= 10010
else if (sm_algo_pack.cublasLt.is_available_reproducible(args, else if (sm_algo_pack.cublasLt.is_available_attribute(args, attr)) {
reproducible)) {
return &sm_algo_pack.cublasLt; return &sm_algo_pack.cublasLt;
} }
#endif #endif
else if (sm_algo_pack.int8x8x32.is_available_reproducible(args, else if (sm_algo_pack.int8x8x32.is_available_attribute(args, attr)) {
reproducible)) {
return &sm_algo_pack.int8x8x32; return &sm_algo_pack.int8x8x32;
} else { } else {
if (sm_algo_pack.brute_force.is_available_reproducible(args, if (sm_algo_pack.brute_force.is_available_attribute(args, attr)) {
reproducible)) {
return &sm_algo_pack.brute_force; return &sm_algo_pack.brute_force;
} }
} }
......
...@@ -49,7 +49,7 @@ protected: ...@@ -49,7 +49,7 @@ protected:
const TensorLayout& B, const TensorLayout& B,
const TensorLayout& C, const TensorLayout& C,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
......
...@@ -127,12 +127,11 @@ public: ...@@ -127,12 +127,11 @@ public:
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
......
...@@ -51,7 +51,7 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( ...@@ -51,7 +51,7 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_limit_in_bytes, const TensorLayout& dst, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
using namespace conv_bias; using namespace conv_bias;
AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; AlgoBase::SizeArgs args{this, src, filter, bias, z, dst};
auto dst_layout = *args.dst_layout; auto dst_layout = *args.dst_layout;
...@@ -74,7 +74,7 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( ...@@ -74,7 +74,7 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
}; };
auto get_cudnn_algo = auto get_cudnn_algo =
[this, &conv_args, &args, workspace_limit_in_bytes, reproducible]( [this, &conv_args, &args, workspace_limit_in_bytes, attr](
const thin_function<AlgoBase*(cudnnConvolutionFwdAlgo_t)>& const thin_function<AlgoBase*(cudnnConvolutionFwdAlgo_t)>&
cb) -> AlgoBase* { cb) -> AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle()); auto cudnn_handle = cuda::cudnn_handle(this->handle());
...@@ -92,8 +92,8 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( ...@@ -92,8 +92,8 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
&ret_count, algo_perf.data())); &ret_count, algo_perf.data()));
for (int i = 0; i < ret_count; ++i) { for (int i = 0; i < ret_count; ++i) {
auto conv_bias_algo = cb(algo_perf[i].algo); auto conv_bias_algo = cb(algo_perf[i].algo);
if (conv_bias_algo->is_available_reproducible( if (conv_bias_algo->is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) args, attr, workspace_limit_in_bytes))
return conv_bias_algo; return conv_bias_algo;
} }
#else #else
...@@ -105,18 +105,18 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( ...@@ -105,18 +105,18 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
workspace_limit_in_bytes, &algo)); workspace_limit_in_bytes, &algo));
auto conv_bias_algo = cb(algo); auto conv_bias_algo = cb(algo);
if (conv_bias_algo->is_available_reproducible(args, reproducible, if (conv_bias_algo->is_available_attribute(args, attr,
workspace_limit_in_bytes)) workspace_limit_in_bytes))
return conv_bias_algo; return conv_bias_algo;
#endif #endif
return nullptr; return nullptr;
}; };
auto get_1x1_algo = [workspace_limit_in_bytes, auto get_1x1_algo = [workspace_limit_in_bytes,
reproducible](const AlgoBase::SizeArgs& size_arg) attr](const AlgoBase::SizeArgs& size_arg)
-> ConvBiasForwardImpl::AlgoBase* { -> ConvBiasForwardImpl::AlgoBase* {
if (sm_algo_pack.batched_matmul.is_available_reproducible( if (sm_algo_pack.batched_matmul.is_available_attribute(
size_arg, reproducible, workspace_limit_in_bytes)) { size_arg, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.batched_matmul; return &sm_algo_pack.batched_matmul;
} }
return nullptr; return nullptr;
...@@ -144,11 +144,11 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( ...@@ -144,11 +144,11 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
//! avoid bad case in cudnn, check dnn chanwise impl first //! avoid bad case in cudnn, check dnn chanwise impl first
if (is_chanwise) { if (is_chanwise) {
if (prefer_dnn_chanwise) { if (prefer_dnn_chanwise) {
if (sm_algo_pack.chanwise.is_available_reproducible( if (sm_algo_pack.chanwise.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) args, attr, workspace_limit_in_bytes))
return &sm_algo_pack.chanwise; return &sm_algo_pack.chanwise;
if (sm_algo_pack.chanwise8x8x32.is_available_reproducible( if (sm_algo_pack.chanwise8x8x32.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) args, attr, workspace_limit_in_bytes))
return &sm_algo_pack.chanwise8x8x32; return &sm_algo_pack.chanwise8x8x32;
} else { } else {
conv_args.dst_layout = &dst_layout; conv_args.dst_layout = &dst_layout;
...@@ -163,8 +163,7 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( ...@@ -163,8 +163,7 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
//! Prefer CUDNN CONVBIAS. //! Prefer CUDNN CONVBIAS.
bool cudnn_conv_bias_act_supported = false; bool cudnn_conv_bias_act_supported = false;
for (auto&& algo : sm_algo_pack.cudnn_conv_bias_activations) { for (auto&& algo : sm_algo_pack.cudnn_conv_bias_activations) {
if (algo.is_available_reproducible(args, reproducible, if (algo.is_available_attribute(args, attr, workspace_limit_in_bytes)) {
workspace_limit_in_bytes)) {
cudnn_conv_bias_act_supported = true; cudnn_conv_bias_act_supported = true;
break; break;
} }
...@@ -201,26 +200,26 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( ...@@ -201,26 +200,26 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
return algo; return algo;
} }
if (sm_algo_pack.fallback_nchw_qs8.is_available_reproducible( if (sm_algo_pack.fallback_nchw_qs8.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.fallback_nchw_qs8; return &sm_algo_pack.fallback_nchw_qs8;
} }
if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) { if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<ConvBiasForwardImpl>( return megdnn::get_algo_with_attribute<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda convbias fwd"); workspace_limit_in_bytes, "cuda convbias fwd", attr);
} else { } else {
return megdnn::get_usable_algo<ConvBiasForwardImpl>( return megdnn::get_usable_algo<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda convbias fwd"); workspace_limit_in_bytes, "cuda convbias fwd");
} }
} else { } else {
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<ConvBiasForwardImpl>( return megdnn::get_algo_with_attribute<ConvBiasForwardImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd"); "cuda convbias fwd", attr);
} else { } else {
return megdnn::get_usable_algo<ConvBiasForwardImpl>( return megdnn::get_usable_algo<ConvBiasForwardImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
......
...@@ -82,7 +82,7 @@ public: ...@@ -82,7 +82,7 @@ public:
const TensorLayout& z, const TensorLayout& z,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
......
...@@ -82,12 +82,11 @@ public: ...@@ -82,12 +82,11 @@ public:
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
......
...@@ -78,12 +78,11 @@ public: ...@@ -78,12 +78,11 @@ public:
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
......
...@@ -63,13 +63,13 @@ public: ...@@ -63,13 +63,13 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) const { bool is_available_wk(const SizeArgs& args, size_t limit) const {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
size_t limit = std::numeric_limits<size_t>::max()) const { const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
return (!reproducible || size_t limit = std::numeric_limits<size_t>::max()) {
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && return contain_attribute(attr) && is_available_wk(args, limit);
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) { const Workspace& workspace) {
auto req = get_workspace_in_bytes(args); auto req = get_workspace_in_bytes(args);
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "src/cuda/convolution/opr_impl.h" #include "src/cuda/convolution/opr_impl.h"
#include "megdnn/dtype.h" #include "megdnn/dtype.h"
#include "src/common/algo_chooser.h"
#include "src/cuda/convolution/helper.h" #include "src/cuda/convolution/helper.h"
#include "src/cuda/convolution/forward/algos.h" #include "src/cuda/convolution/forward/algos.h"
#include "src/cuda/convolution/backward_data/algo.h" #include "src/cuda/convolution/backward_data/algo.h"
...@@ -36,10 +37,10 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src, ...@@ -36,10 +37,10 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter, const TensorLayout& filter,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
AlgoBase::SizeArgs args{this, src, filter, dst}; AlgoBase::SizeArgs args{this, src, filter, dst};
MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes);
MEGDNN_MARK_USED_VAR(reproducible); MEGDNN_MARK_USED_VAR(attr);
return &sm_algo_pack.algo_default; return &sm_algo_pack.algo_default;
} }
...@@ -100,32 +101,32 @@ ConvolutionBackwardDataImpl::Algorithm* ...@@ -100,32 +101,32 @@ ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic( ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
auto fm = check_layout_fwd(grad, filter, diff); auto fm = check_layout_fwd(grad, filter, diff);
return get_algorithm_heuristic(filter, fm, diff, grad, return get_algorithm_heuristic(filter, fm, diff, grad,
workspace_limit_in_bytes, reproducible); workspace_limit_in_bytes, attr);
} }
ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic( ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& filter, const CanonizedFilterMeta& filter_meta, const CanonizedFilterMeta& filter_meta, const TensorLayout& diff,
const TensorLayout& diff, const TensorLayout& grad, const TensorLayout& grad, size_t workspace_limit_in_bytes,
size_t workspace_limit_in_bytes, bool reproducible) { const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, filter, filter_meta, diff, grad); AlgoBase::SizeArgs args(this, filter, filter_meta, diff, grad);
if (args.filter_meta.group > 1 && if (args.filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_reproducible( sm_algo_pack.chanwise.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
// prefer special chanwise impl // prefer special chanwise impl
return &sm_algo_pack.chanwise; return &sm_algo_pack.chanwise;
} }
if (args.filter_layout->dtype.enumv() == if (args.filter_layout->dtype.enumv() ==
DTypeTrait<dtype::QuantizedS8>::enumv) { DTypeTrait<dtype::QuantizedS8>::enumv) {
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.int8_algos, args, workspace_limit_in_bytes, sm_algo_pack.int8_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data"); "cuda conv bwd_data", attr);
} else { } else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.int8_algos, args, workspace_limit_in_bytes, sm_algo_pack.int8_algos, args, workspace_limit_in_bytes,
...@@ -133,9 +134,8 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( ...@@ -133,9 +134,8 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
} }
} }
auto get_cudnn_algo = auto get_cudnn_algo = [this, &args, workspace_limit_in_bytes,
[this, &args, workspace_limit_in_bytes, attr]() -> ConvolutionBackwardDataImpl::AlgoBase* {
reproducible]() -> ConvolutionBackwardDataImpl::AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle()); auto cudnn_handle = cuda::cudnn_handle(this->handle());
CUDNNBwdDataDescs desc; CUDNNBwdDataDescs desc;
args.init_desc(desc); args.init_desc(desc);
...@@ -153,7 +153,7 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( ...@@ -153,7 +153,7 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
for (int i = 0; i < ret_count; ++i) { for (int i = 0; i < ret_count; ++i) {
if (algo_perf[i].memory > workspace_limit_in_bytes) if (algo_perf[i].memory > workspace_limit_in_bytes)
continue; continue;
if (reproducible) { if (attr & AlgoAttribute::REPRODUCIBLE) {
if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) { if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) {
return reinterpret_cast<AlgoBase*>( return reinterpret_cast<AlgoBase*>(
sm_algo_pack.cudnn_from_enum(algo_perf[i].algo)); sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
...@@ -174,8 +174,8 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( ...@@ -174,8 +174,8 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
auto&& cast_algo = auto&& cast_algo =
reinterpret_cast<AlgoBase*>(sm_algo_pack.cudnn_from_enum(algo)); reinterpret_cast<AlgoBase*>(sm_algo_pack.cudnn_from_enum(algo));
return reinterpret_cast<AlgoBase*>( return reinterpret_cast<AlgoBase*>(
megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
cast_algo, reproducible)); cast_algo, attr));
#endif #endif
}; };
...@@ -197,20 +197,20 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( ...@@ -197,20 +197,20 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
if (args.filter_layout->dtype.enumv() != if (args.filter_layout->dtype.enumv() !=
DTypeTrait<dtype::BFloat16>::enumv) { DTypeTrait<dtype::BFloat16>::enumv) {
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args, sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_data"); workspace_limit_in_bytes, "cuda conv bwd_data", attr);
} else { } else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args, sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_data"); workspace_limit_in_bytes, "cuda conv bwd_data");
} }
} else { } else {
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data"); "cuda conv bwd_data", attr);
} else { } else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
...@@ -255,29 +255,29 @@ ConvolutionBackwardFilterImpl::Algorithm* ...@@ -255,29 +255,29 @@ ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
auto fm = check_layout_fwd(src, grad, diff); auto fm = check_layout_fwd(src, grad, diff);
return get_algorithm_heuristic(src, diff, grad, fm, return get_algorithm_heuristic(src, diff, grad, fm,
workspace_limit_in_bytes, reproducible); workspace_limit_in_bytes, attr);
} }
ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, const CanonizedFilterMeta& grad_meta, const TensorLayout& grad, const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes, bool reproducible) { size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, src, diff, grad, grad_meta); AlgoBase::SizeArgs args(this, src, diff, grad, grad_meta);
if (args.grad_filter_meta.group > 1 && if (args.grad_filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_reproducible( sm_algo_pack.chanwise.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
// prefer special chanwise impl // prefer special chanwise impl
return &sm_algo_pack.chanwise; return &sm_algo_pack.chanwise;
} }
auto get_cudnn_algo = auto get_cudnn_algo =
[this, &args, workspace_limit_in_bytes, [this, &args, workspace_limit_in_bytes,
reproducible]() -> ConvolutionBackwardFilterImpl::AlgoBase* { attr]() -> ConvolutionBackwardFilterImpl::AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle()); auto cudnn_handle = cuda::cudnn_handle(this->handle());
CUDNNBwdFilterDescs desc; CUDNNBwdFilterDescs desc;
args.init_desc(desc); args.init_desc(desc);
...@@ -305,7 +305,7 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( ...@@ -305,7 +305,7 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
for (int i = 0; i < ret_count; ++i) { for (int i = 0; i < ret_count; ++i) {
if (algo_perf[i].memory > workspace_limit_in_bytes) if (algo_perf[i].memory > workspace_limit_in_bytes)
continue; continue;
if (reproducible) { if (attr & AlgoAttribute::REPRODUCIBLE) {
if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) { if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) {
return reinterpret_cast<AlgoBase*>( return reinterpret_cast<AlgoBase*>(
sm_algo_pack.cudnn_from_enum(algo_perf[i].algo)); sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
...@@ -326,8 +326,8 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( ...@@ -326,8 +326,8 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
auto&& cast_algo = auto&& cast_algo =
reinterpret_cast<AlgoBase*>(sm_algo_pack.cudnn_from_enum(algo)); reinterpret_cast<AlgoBase*>(sm_algo_pack.cudnn_from_enum(algo));
return reinterpret_cast<AlgoBase*>( return reinterpret_cast<AlgoBase*>(
megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( megdnn::get_algo_with_attribute<ConvolutionBackwardFilterImpl>(
cast_algo, reproducible)); cast_algo, attr));
#endif #endif
}; };
...@@ -348,20 +348,22 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( ...@@ -348,20 +348,22 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
} }
if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) { if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( return megdnn::get_algo_with_attribute<
ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args, sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_filter"); workspace_limit_in_bytes, "cuda conv bwd_filter", attr);
} else { } else {
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args, sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_filter"); workspace_limit_in_bytes, "cuda conv bwd_filter");
} }
} else { } else {
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( return megdnn::get_algo_with_attribute<
ConvolutionBackwardFilterImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_filter"); "cuda conv bwd_filter", attr);
} else { } else {
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
......
...@@ -63,7 +63,7 @@ protected: ...@@ -63,7 +63,7 @@ protected:
const TensorLayout& filter, const TensorLayout& filter,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
...@@ -77,9 +77,9 @@ public: ...@@ -77,9 +77,9 @@ public:
AlgorithmInfo get_algorithm_info_heuristic( AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& filter, const CanonizedFilterMeta& filter_meta, const TensorLayout& filter, const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad, const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_limit_in_bytes, bool reproducible) { size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
return get_algorithm_heuristic(filter, filter_meta, diff, grad, return get_algorithm_heuristic(filter, filter_meta, diff, grad,
workspace_limit_in_bytes, reproducible) workspace_limit_in_bytes, attr)
->info(); ->info();
} }
...@@ -87,9 +87,9 @@ public: ...@@ -87,9 +87,9 @@ public:
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
return get_algorithm_heuristic(filter, diff, grad, return get_algorithm_heuristic(filter, diff, grad,
workspace_limit_in_bytes, reproducible) workspace_limit_in_bytes, attr)
->info(); ->info();
} }
...@@ -122,7 +122,7 @@ protected: ...@@ -122,7 +122,7 @@ protected:
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
private: private:
Algorithm* get_algorithm_heuristic(const TensorLayout& filter, Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
...@@ -130,7 +130,7 @@ private: ...@@ -130,7 +130,7 @@ private:
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible); const AlgoAttribute& attr);
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
}; };
...@@ -146,9 +146,9 @@ public: ...@@ -146,9 +146,9 @@ public:
AlgorithmInfo get_algorithm_info_heuristic( AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, const CanonizedFilterMeta& grad_meta, const TensorLayout& grad, const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes, bool reproducible) { size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
return get_algorithm_heuristic(src, diff, grad, grad_meta, return get_algorithm_heuristic(src, diff, grad, grad_meta,
workspace_limit_in_bytes, reproducible) workspace_limit_in_bytes, attr)
->info(); ->info();
} }
...@@ -156,9 +156,9 @@ public: ...@@ -156,9 +156,9 @@ public:
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
return get_algorithm_heuristic(filter, diff, grad, return get_algorithm_heuristic(filter, diff, grad,
workspace_limit_in_bytes, reproducible) workspace_limit_in_bytes, attr)
->info(); ->info();
} }
...@@ -185,7 +185,7 @@ protected: ...@@ -185,7 +185,7 @@ protected:
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
private: private:
Algorithm* get_algorithm_heuristic(const TensorLayout& src, Algorithm* get_algorithm_heuristic(const TensorLayout& src,
...@@ -193,7 +193,7 @@ private: ...@@ -193,7 +193,7 @@ private:
const TensorLayout& grad, const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta, const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible); const AlgoAttribute& attr);
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
}; };
......
...@@ -75,12 +75,11 @@ public: ...@@ -75,12 +75,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) { bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) { const Workspace& workspace) {
......
...@@ -69,12 +69,11 @@ public: ...@@ -69,12 +69,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) { bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) { const Workspace& workspace) {
......
...@@ -74,12 +74,11 @@ public: ...@@ -74,12 +74,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) { bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) { const Workspace& workspace) {
......
...@@ -97,8 +97,8 @@ namespace convolution3d { ...@@ -97,8 +97,8 @@ namespace convolution3d {
const cudnnConvolutionDescriptor_t conv_desc, const cudnnConvolutionDescriptor_t conv_desc,
const cudnnTensorDescriptor_t y_desc, const cudnnTensorDescriptor_t y_desc,
size_t workspace_limit_in_bytes, cudnnConvolutionFwdAlgo_t* algo, size_t workspace_limit_in_bytes, cudnnConvolutionFwdAlgo_t* algo,
bool reproducible) { const AlgoAttribute& attr) {
MEGDNN_MARK_USED_VAR(reproducible); MEGDNN_MARK_USED_VAR(attr);
#if CUDNN_MAJOR >= 7 #if CUDNN_MAJOR >= 7
int algo_max_count = 0; int algo_max_count = 0;
cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount( cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount(
...@@ -118,7 +118,7 @@ namespace convolution3d { ...@@ -118,7 +118,7 @@ namespace convolution3d {
cudnn_handle, x_desc, w_desc, conv_desc, y_desc, cudnn_handle, x_desc, w_desc, conv_desc, y_desc,
algo_perf[i].algo, &workspace_size)); algo_perf[i].algo, &workspace_size));
if (workspace_size > workspace_limit_in_bytes) continue; if (workspace_size > workspace_limit_in_bytes) continue;
if (!reproducible) { if (!(attr & AlgoAttribute::REPRODUCIBLE)) {
*algo = algo_perf[i].algo; *algo = algo_perf[i].algo;
return true; return true;
} else { } else {
...@@ -144,8 +144,8 @@ namespace convolution3d { ...@@ -144,8 +144,8 @@ namespace convolution3d {
const cudnnConvolutionDescriptor_t conv_desc, const cudnnConvolutionDescriptor_t conv_desc,
const cudnnTensorDescriptor_t dx_desc, const cudnnTensorDescriptor_t dx_desc,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
cudnnConvolutionBwdDataAlgo_t* algo, bool reproducible) { cudnnConvolutionBwdDataAlgo_t* algo, const AlgoAttribute& attr) {
MEGDNN_MARK_USED_VAR(reproducible); MEGDNN_MARK_USED_VAR(attr);
#if CUDNN_MAJOR >= 7 #if CUDNN_MAJOR >= 7
int algo_max_count = 0; int algo_max_count = 0;
cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(
...@@ -166,7 +166,7 @@ namespace convolution3d { ...@@ -166,7 +166,7 @@ namespace convolution3d {
cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc, cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc,
algo_perf[i].algo, &workspace_size)); algo_perf[i].algo, &workspace_size));
if (workspace_size > workspace_limit_in_bytes) continue; if (workspace_size > workspace_limit_in_bytes) continue;
if (!reproducible) { if (!(attr & AlgoAttribute::REPRODUCIBLE)) {
*algo = algo_perf[i].algo; *algo = algo_perf[i].algo;
return true; return true;
} else { } else {
...@@ -193,8 +193,8 @@ namespace convolution3d { ...@@ -193,8 +193,8 @@ namespace convolution3d {
const cudnnConvolutionDescriptor_t conv_desc, const cudnnConvolutionDescriptor_t conv_desc,
const cudnnFilterDescriptor_t dw_desc, const cudnnFilterDescriptor_t dw_desc,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
cudnnConvolutionBwdFilterAlgo_t* algo, bool reproducible) { cudnnConvolutionBwdFilterAlgo_t* algo, const AlgoAttribute& attr) {
MEGDNN_MARK_USED_VAR(reproducible); MEGDNN_MARK_USED_VAR(attr);
#if CUDNN_MAJOR >= 7 #if CUDNN_MAJOR >= 7
int algo_max_count = 0; int algo_max_count = 0;
cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(
...@@ -207,14 +207,15 @@ namespace convolution3d { ...@@ -207,14 +207,15 @@ namespace convolution3d {
algo_max_count, &algo_count, algo_perf.data())); algo_max_count, &algo_count, algo_perf.data()));
for (int i = 0; i < algo_count; ++i) { for (int i = 0; i < algo_count; ++i) {
if (algo_perf[i].algo == if (algo_perf[i].algo ==
cudnnConvolutionBwdFilterAlgo_t::CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING) cudnnConvolutionBwdFilterAlgo_t::
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING)
continue; continue;
size_t workspace_size = 0; size_t workspace_size = 0;
cudnn_check(cudnnGetConvolutionBackwardFilterWorkspaceSize( cudnn_check(cudnnGetConvolutionBackwardFilterWorkspaceSize(
cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc, cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc,
algo_perf[i].algo, &workspace_size)); algo_perf[i].algo, &workspace_size));
if (workspace_size > workspace_limit_in_bytes) continue; if (workspace_size > workspace_limit_in_bytes) continue;
if (!reproducible) { if (!(attr & AlgoAttribute::REPRODUCIBLE)) {
*algo = algo_perf[i].algo; *algo = algo_perf[i].algo;
return true; return true;
} else { } else {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "./forward/algo.h" #include "./forward/algo.h"
#include "./helper.h" #include "./helper.h"
#include "src/common/algo_chooser.h"
#include "src/cuda/utils.h" #include "src/cuda/utils.h"
using namespace megdnn; using namespace megdnn;
...@@ -32,16 +33,16 @@ Convolution3DForwardImpl::Algorithm* ...@@ -32,16 +33,16 @@ Convolution3DForwardImpl::Algorithm*
Convolution3DForwardImpl::get_algorithm_heuristic( Convolution3DForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes, const TensorLayout& dst, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
auto fm = check_layout_fwd(src, filter, dst); auto fm = check_layout_fwd(src, filter, dst);
return get_algorithm_heuristic(src, fm, dst, workspace_limit_in_bytes, return get_algorithm_heuristic(src, fm, dst, workspace_limit_in_bytes,
reproducible); attr);
} }
Convolution3DForwardImpl::Algorithm* Convolution3DForwardImpl::Algorithm*
Convolution3DForwardImpl::get_algorithm_heuristic( Convolution3DForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const CanonizedFilterMeta& filter, const TensorLayout& src, const CanonizedFilterMeta& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes, const TensorLayout& dst, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, src, filter, dst); AlgoBase::SizeArgs args(this, src, filter, dst);
#if CUDNN_MAJOR < 7 || (CUDNN_MAJOR == 7 && CUDNN_MINOR < 5) #if CUDNN_MAJOR < 7 || (CUDNN_MAJOR == 7 && CUDNN_MINOR < 5)
...@@ -49,26 +50,26 @@ Convolution3DForwardImpl::get_algorithm_heuristic( ...@@ -49,26 +50,26 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
// prefer special chanwise impl since as the group conv of cudnn whose // prefer special chanwise impl since as the group conv of cudnn whose
// version is lower than v7.5.0 is still slower than our implementation // version is lower than v7.5.0 is still slower than our implementation
// in many channel-wise cases // in many channel-wise cases
if (sm_algo_pack.chanwise.is_available_reproducible( if (sm_algo_pack.chanwise.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.chanwise; return &sm_algo_pack.chanwise;
} }
} }
#endif #endif
auto prefer_1x1x1 = [&args, reproducible, workspace_limit_in_bytes]() { auto prefer_1x1x1 = [&args, attr, workspace_limit_in_bytes]() {
const size_t MAX_BATCH_SIZE_FOR_1x1x1_MAT_ALGO = 4; const size_t MAX_BATCH_SIZE_FOR_1x1x1_MAT_ALGO = 4;
size_t batch_size = args.src_layout->shape[0]; size_t batch_size = args.src_layout->shape[0];
if (batch_size > MAX_BATCH_SIZE_FOR_1x1x1_MAT_ALGO) { if (batch_size > MAX_BATCH_SIZE_FOR_1x1x1_MAT_ALGO) {
return false; return false;
} }
return sm_algo_pack.a1x1x1.is_available_reproducible( return sm_algo_pack.a1x1x1.is_available_attribute(
args, reproducible, workspace_limit_in_bytes); args, attr, workspace_limit_in_bytes);
}; };
auto get_cudnn_algo = auto get_cudnn_algo =
[this, &args, workspace_limit_in_bytes, [this, &args, workspace_limit_in_bytes,
reproducible]() -> Convolution3DForwardImpl::AlgoBase* { attr]() -> Convolution3DForwardImpl::AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle()); auto cudnn_handle = cuda::cudnn_handle(this->handle());
cudnnConvolutionFwdAlgo_t algo; cudnnConvolutionFwdAlgo_t algo;
CUDNNForwardDescs desc; CUDNNForwardDescs desc;
...@@ -77,11 +78,11 @@ Convolution3DForwardImpl::get_algorithm_heuristic( ...@@ -77,11 +78,11 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
bool got = cudnn_get_convolution_fwd_algo_helper( bool got = cudnn_get_convolution_fwd_algo_helper(
cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc, cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
desc.conv_desc.desc, desc.dst_desc.desc, desc.conv_desc.desc, desc.dst_desc.desc,
workspace_limit_in_bytes, &algo, reproducible); workspace_limit_in_bytes, &algo, attr);
if (got) { if (got) {
return static_cast<AlgoBase*>( return static_cast<AlgoBase*>(
megdnn::get_reproducible_algo<Convolution3DForwardImpl>( megdnn::get_algo_with_attribute<Convolution3DForwardImpl>(
sm_algo_pack.cudnn_from_enum(algo), reproducible)); sm_algo_pack.cudnn_from_enum(algo), attr));
} else { } else {
return nullptr; return nullptr;
} }
...@@ -107,10 +108,10 @@ Convolution3DForwardImpl::get_algorithm_heuristic( ...@@ -107,10 +108,10 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
args = orig_args; args = orig_args;
} }
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<Convolution3DForwardImpl>( return megdnn::get_algo_with_attribute<Convolution3DForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv3d fwd"); "cuda conv3d fwd", attr);
} else { } else {
return megdnn::get_usable_algo<Convolution3DForwardImpl>( return megdnn::get_usable_algo<Convolution3DForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
...@@ -168,28 +169,28 @@ Convolution3DBackwardDataImpl::Algorithm* ...@@ -168,28 +169,28 @@ Convolution3DBackwardDataImpl::Algorithm*
Convolution3DBackwardDataImpl::get_algorithm_heuristic( Convolution3DBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
auto fm = check_layout_fwd(grad, filter, diff); auto fm = check_layout_fwd(grad, filter, diff);
return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes, return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes,
reproducible); attr);
} }
Convolution3DBackwardDataImpl::Algorithm* Convolution3DBackwardDataImpl::Algorithm*
Convolution3DBackwardDataImpl::get_algorithm_heuristic( Convolution3DBackwardDataImpl::get_algorithm_heuristic(
const CanonizedFilterMeta& filter, const TensorLayout& diff, const CanonizedFilterMeta& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, filter, diff, grad); AlgoBase::SizeArgs args(this, filter, diff, grad);
if (args.filter_meta.group > 1 && if (args.filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_reproducible( sm_algo_pack.chanwise.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.chanwise; return &sm_algo_pack.chanwise;
} }
auto get_cudnn_algo = auto get_cudnn_algo =
[this, &args, workspace_limit_in_bytes, [this, &args, workspace_limit_in_bytes,
reproducible]() -> Convolution3DBackwardDataImpl::AlgoBase* { attr]() -> Convolution3DBackwardDataImpl::AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle()); auto cudnn_handle = cuda::cudnn_handle(this->handle());
cudnnConvolutionBwdDataAlgo_t algo; cudnnConvolutionBwdDataAlgo_t algo;
CUDNNBwdDataDescs desc; CUDNNBwdDataDescs desc;
...@@ -197,11 +198,11 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic( ...@@ -197,11 +198,11 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic(
bool got = cudnn_get_convolution_bwd_data_algo_helper( bool got = cudnn_get_convolution_bwd_data_algo_helper(
cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc, cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc,
desc.conv_desc.desc, desc.grad_desc.desc, desc.conv_desc.desc, desc.grad_desc.desc,
workspace_limit_in_bytes, &algo, reproducible); workspace_limit_in_bytes, &algo, attr);
if (got) { if (got) {
return static_cast<AlgoBase*>(megdnn::get_reproducible_algo< return static_cast<AlgoBase*>(megdnn::get_algo_with_attribute<
Convolution3DBackwardDataImpl>( Convolution3DBackwardDataImpl>(
sm_algo_pack.cudnn_from_enum(algo), reproducible)); sm_algo_pack.cudnn_from_enum(algo), attr));
} else { } else {
return nullptr; return nullptr;
} }
...@@ -223,10 +224,10 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic( ...@@ -223,10 +224,10 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic(
args = orig_args; args = orig_args;
} }
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<Convolution3DBackwardDataImpl>( return megdnn::get_algo_with_attribute<Convolution3DBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv3d bwd data"); "cuda conv3d bwd data", attr);
} else { } else {
return megdnn::get_usable_algo<Convolution3DBackwardDataImpl>( return megdnn::get_usable_algo<Convolution3DBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
...@@ -268,28 +269,28 @@ Convolution3DBackwardFilterImpl::Algorithm* ...@@ -268,28 +269,28 @@ Convolution3DBackwardFilterImpl::Algorithm*
Convolution3DBackwardFilterImpl::get_algorithm_heuristic( Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
auto fm = check_layout_fwd(src, grad, diff); auto fm = check_layout_fwd(src, grad, diff);
return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes, return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes,
reproducible); attr);
} }
Convolution3DBackwardFilterImpl::Algorithm* Convolution3DBackwardFilterImpl::Algorithm*
Convolution3DBackwardFilterImpl::get_algorithm_heuristic( Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, src, diff, grad); AlgoBase::SizeArgs args(this, src, diff, grad);
if (args.grad_filter_meta.group > 1 && if (args.grad_filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_reproducible( sm_algo_pack.chanwise.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.chanwise; return &sm_algo_pack.chanwise;
} }
auto get_cudnn_algo = auto get_cudnn_algo =
[this, &args, workspace_limit_in_bytes, [this, &args, workspace_limit_in_bytes,
reproducible]() -> Convolution3DBackwardFilterImpl::AlgoBase* { attr]() -> Convolution3DBackwardFilterImpl::AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle()); auto cudnn_handle = cuda::cudnn_handle(this->handle());
cudnnConvolutionBwdFilterAlgo_t algo; cudnnConvolutionBwdFilterAlgo_t algo;
CUDNNBwdFilterDescs desc; CUDNNBwdFilterDescs desc;
...@@ -297,11 +298,11 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic( ...@@ -297,11 +298,11 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
bool got = cudnn_get_convolution_bwd_filter_algo_helper( bool got = cudnn_get_convolution_bwd_filter_algo_helper(
cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc, cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc,
desc.conv_desc.desc, desc.grad_desc.desc, desc.conv_desc.desc, desc.grad_desc.desc,
workspace_limit_in_bytes, &algo, reproducible); workspace_limit_in_bytes, &algo, attr);
if (got) { if (got) {
return static_cast<AlgoBase*>(megdnn::get_reproducible_algo< return static_cast<AlgoBase*>(megdnn::get_algo_with_attribute<
Convolution3DBackwardFilterImpl>( Convolution3DBackwardFilterImpl>(
sm_algo_pack.cudnn_from_enum(algo), reproducible)); sm_algo_pack.cudnn_from_enum(algo), attr));
} else { } else {
return nullptr; return nullptr;
} }
...@@ -322,10 +323,10 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic( ...@@ -322,10 +323,10 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
args = orig_args; args = orig_args;
} }
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<Convolution3DBackwardFilterImpl>( return megdnn::get_algo_with_attribute<Convolution3DBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv3d bwd filter"); "cuda conv3d bwd filter", attr);
} else { } else {
return megdnn::get_usable_algo<Convolution3DBackwardFilterImpl>( return megdnn::get_usable_algo<Convolution3DBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
......
...@@ -25,9 +25,9 @@ public: ...@@ -25,9 +25,9 @@ public:
const CanonizedFilterMeta& filter, const CanonizedFilterMeta& filter,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
return get_algorithm_heuristic(src, filter, dst, return get_algorithm_heuristic(src, filter, dst,
workspace_limit_in_bytes, reproducible) workspace_limit_in_bytes, attr)
->info(); ->info();
} }
size_t get_workspace_in_bytes(const TensorLayout& src, size_t get_workspace_in_bytes(const TensorLayout& src,
...@@ -52,14 +52,14 @@ protected: ...@@ -52,14 +52,14 @@ protected:
const TensorLayout& filter, const TensorLayout& filter,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
private: private:
Algorithm* get_algorithm_heuristic(const TensorLayout& src, Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const CanonizedFilterMeta& filter, const CanonizedFilterMeta& filter,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible); const AlgoAttribute& attr);
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
...@@ -73,9 +73,9 @@ public: ...@@ -73,9 +73,9 @@ public:
AlgorithmInfo get_algorithm_info_heuristic( AlgorithmInfo get_algorithm_info_heuristic(
const CanonizedFilterMeta& filter, const TensorLayout& diff, const CanonizedFilterMeta& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
return get_algorithm_heuristic(filter, diff, grad, return get_algorithm_heuristic(filter, diff, grad,
workspace_limit_in_bytes, reproducible) workspace_limit_in_bytes, attr)
->info(); ->info();
} }
size_t get_workspace_in_bytes(const TensorLayout& filter, size_t get_workspace_in_bytes(const TensorLayout& filter,
...@@ -102,14 +102,14 @@ protected: ...@@ -102,14 +102,14 @@ protected:
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
private: private:
Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter,
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible); const AlgoAttribute& attr);
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
}; };
...@@ -126,9 +126,9 @@ public: ...@@ -126,9 +126,9 @@ public:
const TensorLayout& diff, const TensorLayout& diff,
const CanonizedFilterMeta& grad, const CanonizedFilterMeta& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
return get_algorithm_heuristic(src, diff, grad, return get_algorithm_heuristic(src, diff, grad,
workspace_limit_in_bytes, reproducible) workspace_limit_in_bytes, attr)
->info(); ->info();
} }
...@@ -153,14 +153,14 @@ protected: ...@@ -153,14 +153,14 @@ protected:
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
private: private:
Algorithm* get_algorithm_heuristic(const TensorLayout& src, Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff, const TensorLayout& diff,
const CanonizedFilterMeta& grad, const CanonizedFilterMeta& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible); const AlgoAttribute& attr);
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
}; };
......
...@@ -80,12 +80,11 @@ public: ...@@ -80,12 +80,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) { bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) { const Workspace& workspace) {
......
...@@ -73,12 +73,11 @@ public: ...@@ -73,12 +73,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) { bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) { const Workspace& workspace) {
......
...@@ -68,12 +68,11 @@ public: ...@@ -68,12 +68,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) { bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) { const Workspace& workspace) {
......
...@@ -59,10 +59,10 @@ AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, ...@@ -59,10 +59,10 @@ AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im,
const TensorLayout& mask, const TensorLayout& mask,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
auto fm = make_canonized_filter_meta(im.ndim, filter, offset); auto fm = make_canonized_filter_meta(im.ndim, filter, offset);
return get_algorithm_heuristic(im, fm, offset, mask, dst, return get_algorithm_heuristic(im, fm, offset, mask, dst,
workspace_limit_in_bytes, reproducible); workspace_limit_in_bytes, attr);
} }
AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im,
...@@ -71,17 +71,17 @@ AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, ...@@ -71,17 +71,17 @@ AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im,
const TensorLayout& mask, const TensorLayout& mask,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, im, filter, offset, mask, dst); AlgoBase::SizeArgs args(this, im, filter, offset, mask, dst);
if (sm_algo_pack.algo_matmul.is_available_reproducible( if (sm_algo_pack.algo_matmul.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.algo_matmul; return &sm_algo_pack.algo_matmul;
} }
megdnn_throw( megdnn_throw(ssprintf(
ssprintf("no %s deformable conv fwd algorithm with args(%s) and " "no deformable conv fwd algorithm with attribute%s , args(%s) and "
"workspace limit (%zu bytes)", "workspace limit (%zu bytes)",
reproducible ? "reproducible" : "usable", Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(),
args.to_string().c_str(), workspace_limit_in_bytes)); workspace_limit_in_bytes));
} }
const char* Fwd::get_algorithm_set_name() const { const char* Fwd::get_algorithm_set_name() const {
...@@ -115,27 +115,28 @@ AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( ...@@ -115,27 +115,28 @@ AlgoBwdFlt* BwdFlt::get_algorithm_heuristic(
const TensorLayout& im, const TensorLayout& offset, const TensorLayout& im, const TensorLayout& offset,
const TensorLayout& mask, const TensorLayout& out_grad, const TensorLayout& mask, const TensorLayout& out_grad,
const TensorLayout& filter_grad, const TensorLayout& filter_grad,
size_t workspace_limit_in_bytes, bool reproducible) { size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
auto fm = make_canonized_filter_meta(im.ndim, filter_grad, offset); auto fm = make_canonized_filter_meta(im.ndim, filter_grad, offset);
return get_algorithm_heuristic(im, offset, mask, out_grad, fm, return get_algorithm_heuristic(im, offset, mask, out_grad, fm,
workspace_limit_in_bytes, reproducible); workspace_limit_in_bytes, attr);
} }
AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( AlgoBwdFlt* BwdFlt::get_algorithm_heuristic(
const TensorLayout& im, const TensorLayout& offset, const TensorLayout& im, const TensorLayout& offset,
const TensorLayout& mask, const TensorLayout& out_grad, const TensorLayout& mask, const TensorLayout& out_grad,
const CanonizedFilterMeta& filter_grad, const CanonizedFilterMeta& filter_grad,
size_t workspace_limit_in_bytes, bool reproducible) { size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, im, offset, mask, out_grad, filter_grad); AlgoBase::SizeArgs args(this, im, offset, mask, out_grad, filter_grad);
if (sm_algo_pack.algo_matmul.is_available_reproducible( if (sm_algo_pack.algo_matmul.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.algo_matmul; return &sm_algo_pack.algo_matmul;
} }
megdnn_throw(ssprintf( megdnn_throw(
"no %s deformable conv bwd filter algorithm with args(%s) and " ssprintf("no deformable conv bwd filter algorithm with "
"workspace limit (%zu bytes)", "attribute%s, args(%s) and "
reproducible ? "reproducible" : "usable", args.to_string().c_str(), "workspace limit (%zu bytes)",
workspace_limit_in_bytes)); Algorithm::attribute_str(attr).c_str(),
args.to_string().c_str(), workspace_limit_in_bytes));
} }
size_t BwdFlt::get_workspace_in_bytes( size_t BwdFlt::get_workspace_in_bytes(
...@@ -175,11 +176,11 @@ AlgoBwdData* BwdData::get_algorithm_heuristic( ...@@ -175,11 +176,11 @@ AlgoBwdData* BwdData::get_algorithm_heuristic(
const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& im_grad, const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad, const TensorLayout& mask_grad, const TensorLayout& offset_grad, const TensorLayout& mask_grad,
size_t workspace_limit_in_bytes, bool reproducible) { size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
auto fm = make_canonized_filter_meta(im.ndim, filter, offset); auto fm = make_canonized_filter_meta(im.ndim, filter, offset);
return get_algorithm_heuristic(im, fm, offset, mask, out_grad, im_grad, return get_algorithm_heuristic(im, fm, offset, mask, out_grad, im_grad,
offset_grad, mask_grad, offset_grad, mask_grad,
workspace_limit_in_bytes, reproducible); workspace_limit_in_bytes, attr);
} }
AlgoBwdData* BwdData::get_algorithm_heuristic( AlgoBwdData* BwdData::get_algorithm_heuristic(
...@@ -187,18 +188,19 @@ AlgoBwdData* BwdData::get_algorithm_heuristic( ...@@ -187,18 +188,19 @@ AlgoBwdData* BwdData::get_algorithm_heuristic(
const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& im_grad, const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad, const TensorLayout& mask_grad, const TensorLayout& offset_grad, const TensorLayout& mask_grad,
size_t workspace_limit_in_bytes, bool reproducible) { size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, im, filter, offset, mask, out_grad, im_grad, AlgoBase::SizeArgs args(this, im, filter, offset, mask, out_grad, im_grad,
offset_grad, mask_grad); offset_grad, mask_grad);
if (sm_algo_pack.algo_matmul.is_available_reproducible( if (sm_algo_pack.algo_matmul.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.algo_matmul; return &sm_algo_pack.algo_matmul;
} }
megdnn_throw(ssprintf( megdnn_throw(
"no %s deformable conv bwd data algorithm with args(%s) and " ssprintf("no deformable conv bwd data algorithm with attribute%s, "
"workspace limit (%zu bytes)", "args(%s) and "
reproducible ? "reproducible" : "usable", args.to_string().c_str(), "workspace limit (%zu bytes)",
workspace_limit_in_bytes)); Algorithm::attribute_str(attr).c_str(),
args.to_string().c_str(), workspace_limit_in_bytes));
} }
size_t BwdData::get_workspace_in_bytes( size_t BwdData::get_workspace_in_bytes(
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
const TensorLayout& mask, const TensorLayout& mask,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible); const AlgoAttribute& attr);
const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;
...@@ -60,7 +60,7 @@ protected: ...@@ -60,7 +60,7 @@ protected:
const TensorLayout& mask, const TensorLayout& mask,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
...@@ -81,7 +81,7 @@ public: ...@@ -81,7 +81,7 @@ public:
const TensorLayout& out_grad, const TensorLayout& out_grad,
const CanonizedFilterMeta& filter_grad, const CanonizedFilterMeta& filter_grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible); const AlgoAttribute& attr);
size_t get_workspace_in_bytes(const TensorLayout& im, size_t get_workspace_in_bytes(const TensorLayout& im,
const TensorLayout& offset, const TensorLayout& offset,
...@@ -111,7 +111,7 @@ protected: ...@@ -111,7 +111,7 @@ protected:
const TensorLayout& out_grad, const TensorLayout& out_grad,
const TensorLayout& filter_grad, const TensorLayout& filter_grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
...@@ -132,7 +132,7 @@ public: ...@@ -132,7 +132,7 @@ public:
const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& im_grad, const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad, const TensorLayout& mask_grad, const TensorLayout& offset_grad, const TensorLayout& mask_grad,
size_t workspace_limit_in_bytes, bool reproducible); size_t workspace_limit_in_bytes, const AlgoAttribute& attr);
size_t get_workspace_in_bytes(const TensorLayout& im, size_t get_workspace_in_bytes(const TensorLayout& im,
const TensorLayout& filter, const TensorLayout& filter,
...@@ -166,7 +166,8 @@ protected: ...@@ -166,7 +166,8 @@ protected:
const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& im_grad, const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad, const TensorLayout& mask_grad, const TensorLayout& offset_grad, const TensorLayout& mask_grad,
size_t workspace_limit_in_bytes, bool reproducible) override; size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
......
...@@ -59,12 +59,11 @@ public: ...@@ -59,12 +59,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) { bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) { const Workspace& workspace) {
......
...@@ -59,12 +59,11 @@ public: ...@@ -59,12 +59,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) { bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) { const Workspace& workspace) {
......
...@@ -60,12 +60,11 @@ public: ...@@ -60,12 +60,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) { bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) { const Workspace& workspace) {
......
...@@ -24,26 +24,26 @@ LocalShareForwardImpl::get_algorithm_heuristic(const TensorLayout& src, ...@@ -24,26 +24,26 @@ LocalShareForwardImpl::get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter, const TensorLayout& filter,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, src, filter, dst); AlgoBase::SizeArgs args(this, src, filter, dst);
if (sm_algo_pack.batch_size_aware_chwn_small_image if (sm_algo_pack.batch_size_aware_chwn_small_image
.is_available_reproducible(args, reproducible, .is_available_attribute(args, attr,
workspace_limit_in_bytes)) { workspace_limit_in_bytes)) {
return &sm_algo_pack.batch_size_aware_chwn_small_image; return &sm_algo_pack.batch_size_aware_chwn_small_image;
} }
if (sm_algo_pack.batch_size_aware_chwn.is_available_reproducible( if (sm_algo_pack.batch_size_aware_chwn.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.batch_size_aware_chwn; return &sm_algo_pack.batch_size_aware_chwn;
} }
if (sm_algo_pack.batched_matmul.is_available_reproducible( if (sm_algo_pack.batched_matmul.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.batched_matmul; return &sm_algo_pack.batched_matmul;
} }
megdnn_throw( megdnn_throw(ssprintf(
ssprintf("no %s local share conv algorithm with args(%s) and " "no local share conv algorithm with attribute%s, args(%s) and "
"workspace limit (%zu bytes)", "workspace limit (%zu bytes)",
reproducible ? "reproducible" : "usable", Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(),
args.to_string().c_str(), workspace_limit_in_bytes)); workspace_limit_in_bytes));
} }
std::vector<LocalShareForwardImpl::Algorithm*> std::vector<LocalShareForwardImpl::Algorithm*>
...@@ -79,21 +79,21 @@ LocalShareBackwardDataImpl::Algorithm* ...@@ -79,21 +79,21 @@ LocalShareBackwardDataImpl::Algorithm*
LocalShareBackwardDataImpl::get_algorithm_heuristic( LocalShareBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, filter, diff, grad); AlgoBase::SizeArgs args(this, filter, diff, grad);
if (sm_algo_pack.implicit_gemm.is_available_reproducible( if (sm_algo_pack.implicit_gemm.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.implicit_gemm; return &sm_algo_pack.implicit_gemm;
} }
if (sm_algo_pack.batched_matmul.is_available_reproducible( if (sm_algo_pack.batched_matmul.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.batched_matmul; return &sm_algo_pack.batched_matmul;
} }
megdnn_throw( megdnn_throw(ssprintf(
ssprintf("no %s local share bwd data algorithm with args(%s) and " "no local share bwd data algorithm with attribute%s args(%s) and "
"workspace limit (%zu bytes)", "workspace limit (%zu bytes)",
reproducible ? "reproducible" : "usable", Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(),
args.to_string().c_str(), workspace_limit_in_bytes)); workspace_limit_in_bytes));
} }
std::vector<LocalShareBackwardDataImpl::Algorithm*> std::vector<LocalShareBackwardDataImpl::Algorithm*>
...@@ -129,20 +129,21 @@ LocalShareBackwardFilterImpl::Algorithm* ...@@ -129,20 +129,21 @@ LocalShareBackwardFilterImpl::Algorithm*
LocalShareBackwardFilterImpl::get_algorithm_heuristic( LocalShareBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, src, diff, grad); AlgoBase::SizeArgs args(this, src, diff, grad);
if (sm_algo_pack.implicit_gemm.is_available_reproducible( if (sm_algo_pack.implicit_gemm.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.implicit_gemm; return &sm_algo_pack.implicit_gemm;
} }
if (sm_algo_pack.batched_matmul.is_available_reproducible( if (sm_algo_pack.batched_matmul.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.batched_matmul; return &sm_algo_pack.batched_matmul;
} }
megdnn_throw( megdnn_throw(
ssprintf("no %s local share bwd filter algorithm with args(%s) and " ssprintf("no local share bwd filter algorithm with attribute%s, "
"args(%s) and "
"workspace limit (%zu bytes)", "workspace limit (%zu bytes)",
reproducible ? "reproducible" : "usable", Algorithm::attribute_str(attr).c_str(),
args.to_string().c_str(), workspace_limit_in_bytes)); args.to_string().c_str(), workspace_limit_in_bytes));
} }
......
...@@ -43,7 +43,7 @@ protected: ...@@ -43,7 +43,7 @@ protected:
const TensorLayout& filter, const TensorLayout& filter,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
}; };
...@@ -75,7 +75,7 @@ protected: ...@@ -75,7 +75,7 @@ protected:
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
...@@ -108,7 +108,7 @@ protected: ...@@ -108,7 +108,7 @@ protected:
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
......
...@@ -83,12 +83,11 @@ public: ...@@ -83,12 +83,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) const { bool is_available_wk(const SizeArgs& args, size_t limit) const {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) const { size_t limit = std::numeric_limits<size_t>::max()) const {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) { const Workspace& workspace) {
......
...@@ -30,30 +30,30 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, ...@@ -30,30 +30,30 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, bool reproducible) { size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
AlgoBase::SizeArgs args{this, A, B, C}; AlgoBase::SizeArgs args{this, A, B, C};
if (sm_algo_pack.cublas.is_available_reproducible( if (sm_algo_pack.cublas.is_available_attribute(args, attr,
args, reproducible, workspace_limit_in_bytes)) { workspace_limit_in_bytes)) {
return &sm_algo_pack.cublas; return &sm_algo_pack.cublas;
} }
#if CUDA_VERSION >= 10010 #if CUDA_VERSION >= 10010
if (sm_algo_pack.cublas_lt.is_available_reproducible( if (sm_algo_pack.cublas_lt.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.cublas_lt; return &sm_algo_pack.cublas_lt;
} }
#endif #endif
#if CUDA_VERSION >= 10000 #if CUDA_VERSION >= 10000
if (sm_algo_pack.wmma_uint4x4x32.is_available_reproducible( if (sm_algo_pack.wmma_uint4x4x32.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.wmma_uint4x4x32; return &sm_algo_pack.wmma_uint4x4x32;
} }
#endif #endif
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<MatrixMulForwardImpl>( return megdnn::get_algo_with_attribute<MatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"matrix mul forward"); "matrix mul forward", attr);
} else { } else {
return megdnn::get_usable_algo<MatrixMulForwardImpl>( return megdnn::get_usable_algo<MatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
......
...@@ -61,7 +61,7 @@ protected: ...@@ -61,7 +61,7 @@ protected:
const TensorLayout& B, const TensorLayout& B,
const TensorLayout& C, const TensorLayout& C,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
......
...@@ -63,12 +63,11 @@ public: ...@@ -63,12 +63,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) const { bool is_available_wk(const SizeArgs& args, size_t limit) const {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) const { size_t limit = std::numeric_limits<size_t>::max()) const {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) { const Workspace& workspace) {
......
...@@ -31,16 +31,16 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, ...@@ -31,16 +31,16 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
BatchedMatrixMulForwardImpl::Algorithm* BatchedMatrixMulForwardImpl::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic( BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, bool reproducible) { size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
AlgoBase::SizeArgs args{this, A, B, C}; AlgoBase::SizeArgs args{this, A, B, C};
if (sm_algo_pack.algo_default.is_available_reproducible( if (sm_algo_pack.algo_default.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.algo_default; return &sm_algo_pack.algo_default;
} }
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<BatchedMatrixMulForwardImpl>( return megdnn::get_algo_with_attribute<BatchedMatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"batched matrix mul forward"); "batched matrix mul forward", attr);
} else { } else {
return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>( return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
......
...@@ -40,7 +40,7 @@ private: ...@@ -40,7 +40,7 @@ private:
const TensorLayout& /*B*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override; const AlgoAttribute& /*attr*/) override;
const char* get_algorithm_set_name() const override { const char* get_algorithm_set_name() const override {
return "FALLBACK BATCHED MATMUL"; return "FALLBACK BATCHED MATMUL";
......
...@@ -280,32 +280,29 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic( ...@@ -280,32 +280,29 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_limit_in_bytes, const TensorLayout& dst, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
auto result = get_algorithm_heuristic_with_ncb( auto result = get_algorithm_heuristic_with_ncb(
fparam, workspace_limit_in_bytes, reproducible); fparam, workspace_limit_in_bytes, attr);
if (result == nullptr) { if (result == nullptr) {
result = naive::ConvBiasForwardImpl::get_algorithm_heuristic( result = naive::ConvBiasForwardImpl::get_algorithm_heuristic(
src, filter, bias, z, dst, workspace_limit_in_bytes, src, filter, bias, z, dst, workspace_limit_in_bytes, attr);
reproducible);
} }
return result; return result;
} }
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb( ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
auto algo_data_type = param.deduce_algo_data_type(); auto algo_data_type = param.deduce_algo_data_type();
auto suggest_category_order = suggest_algo_category_order(param); auto suggest_category_order = suggest_algo_category_order(param);
for (auto category : suggest_category_order) { for (auto category : suggest_category_order) {
auto&& origin_algos = select_algo_type({algo_data_type, category}); auto&& origin_algos = select_algo_type({algo_data_type, category});
ConvBiasImpl::Algorithm* heuristic_algo = nullptr; ConvBiasImpl::Algorithm* heuristic_algo = nullptr;
for (auto i : origin_algos) { for (auto i : origin_algos) {
bool usable_reproducible = bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute(
static_cast<AlgoBase*>(i)->usable_reproducible( param, AlgoSelectionStrategy::HEURISTIC, attr);
param, AlgoSelectionStrategy::HEURISTIC, if (usable_attribute &&
reproducible);
if (usable_reproducible &&
static_cast<AlgoBase*>(i)->get_workspace(param) <= static_cast<AlgoBase*>(i)->get_workspace(param) <=
workspace_limit_in_bytes) { workspace_limit_in_bytes) {
//! store the first usable algo if no prefer algo, choose it as //! store the first usable algo if no prefer algo, choose it as
...@@ -499,8 +496,8 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( ...@@ -499,8 +496,8 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm(
} }
if (!m_prev_selected_algo || if (!m_prev_selected_algo ||
memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) { memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
m_prev_selected_algo = m_prev_selected_algo = get_algorithm_heuristic_with_ncb(
get_algorithm_heuristic_with_ncb(param, workspace_size); param, workspace_size, AlgoAttribute::DEFAULT);
m_prev_selected_algo_sizep = param; m_prev_selected_algo_sizep = param;
} }
return m_prev_selected_algo; return m_prev_selected_algo;
......
...@@ -95,9 +95,7 @@ public: ...@@ -95,9 +95,7 @@ public:
const TensorLayout& z, const TensorLayout& z,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
//! size param for kernels with non-contiguous batch //! size param for kernels with non-contiguous batch
struct NCBKernSizeParam : ConvolutionImpl::NCBKernSizeParam { struct NCBKernSizeParam : ConvolutionImpl::NCBKernSizeParam {
...@@ -321,11 +319,11 @@ public: ...@@ -321,11 +319,11 @@ public:
return false; return false;
} }
bool usable_reproducible(const NCBKernSizeParam& param, bool usable_attribute(
AlgoSelectionStrategy algo_selection_strategy, const NCBKernSizeParam& param,
bool reproducible = true) const { AlgoSelectionStrategy algo_selection_strategy,
return (!reproducible || const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) const {
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && return contain_attribute(attr) &&
usable(param, algo_selection_strategy); usable(param, algo_selection_strategy);
} }
...@@ -363,7 +361,7 @@ protected: ...@@ -363,7 +361,7 @@ protected:
virtual Algorithm* get_algorithm_heuristic_with_ncb( virtual Algorithm* get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
bool reproducible = false); const AlgoAttribute& attr);
const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;
......
...@@ -198,13 +198,13 @@ std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms( ...@@ -198,13 +198,13 @@ std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms(
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic( ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes, const TensorLayout& dst, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr); auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
auto result = get_algorithm_heuristic_with_ncb( auto result = get_algorithm_heuristic_with_ncb(
fparam, workspace_limit_in_bytes, reproducible); fparam, workspace_limit_in_bytes, attr);
if (result == nullptr) { if (result == nullptr) {
result = naive::ConvolutionForwardImpl::get_algorithm_heuristic( result = naive::ConvolutionForwardImpl::get_algorithm_heuristic(
src, filter, dst, workspace_limit_in_bytes, reproducible); src, filter, dst, workspace_limit_in_bytes, attr);
} }
return result; return result;
} }
...@@ -312,18 +312,16 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, ...@@ -312,18 +312,16 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb( ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
auto algo_data_type = param.deduce_algo_data_type(); auto algo_data_type = param.deduce_algo_data_type();
auto suggest_category_order = suggest_algo_category_order(param); auto suggest_category_order = suggest_algo_category_order(param);
for (auto category : suggest_category_order) { for (auto category : suggest_category_order) {
auto&& origin_algos = select_algo_type({algo_data_type, category}); auto&& origin_algos = select_algo_type({algo_data_type, category});
ConvolutionImpl::Algorithm* heuristic_algo = nullptr; ConvolutionImpl::Algorithm* heuristic_algo = nullptr;
for (auto i : origin_algos) { for (auto i : origin_algos) {
bool usable_reproducible = bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute(
static_cast<AlgoBase*>(i)->usable_reproducible( param, AlgoSelectionStrategy::HEURISTIC, attr);
param, AlgoSelectionStrategy::HEURISTIC, if (usable_attribute &&
reproducible);
if (usable_reproducible &&
static_cast<AlgoBase*>(i)->get_workspace(param) <= static_cast<AlgoBase*>(i)->get_workspace(param) <=
workspace_limit_in_bytes) { workspace_limit_in_bytes) {
//! store the first usable algo if no prefer algo, choose it as //! store the first usable algo if no prefer algo, choose it as
...@@ -392,8 +390,8 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm( ...@@ -392,8 +390,8 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
} }
if (!m_prev_selected_algo || if (!m_prev_selected_algo ||
memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) { memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
m_prev_selected_algo = m_prev_selected_algo = get_algorithm_heuristic_with_ncb(
get_algorithm_heuristic_with_ncb(param, workspace_size); param, workspace_size, AlgoAttribute::DEFAULT);
m_prev_selected_algo_sizep = param; m_prev_selected_algo_sizep = param;
} }
return m_prev_selected_algo; return m_prev_selected_algo;
...@@ -515,15 +513,15 @@ ConvolutionBackwardDataImpl::Algorithm* ...@@ -515,15 +513,15 @@ ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic( ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
if (param().format == param::Convolution::Format::NHWCD4 || if (param().format == param::Convolution::Format::NHWCD4 ||
param().format == param::Convolution::Format::NCHW4) { param().format == param::Convolution::Format::NCHW4) {
return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic( return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic(
filter, diff, grad, workspace_limit_in_bytes, reproducible); filter, diff, grad, workspace_limit_in_bytes, attr);
} }
auto fparam = make_ncb_kern_size_param(filter, diff, grad); auto fparam = make_ncb_kern_size_param(filter, diff, grad);
return get_algorithm_heuristic_with_ncb(fparam, workspace_limit_in_bytes, return get_algorithm_heuristic_with_ncb(fparam, workspace_limit_in_bytes,
reproducible); attr);
} }
ConvolutionBackwardDataImpl::NCBKernSizeParam ConvolutionBackwardDataImpl::NCBKernSizeParam
...@@ -668,15 +666,15 @@ ConvolutionBackwardDataImpl::get_all_algorithms_with_ncb( ...@@ -668,15 +666,15 @@ ConvolutionBackwardDataImpl::get_all_algorithms_with_ncb(
ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb( ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
if (param.filter_meta.group != 1) { if (param.filter_meta.group != 1) {
auto p1g = param; auto p1g = param;
p1g.filter_meta.group = 1; p1g.filter_meta.group = 1;
return ncb_1g_get_algorithm_heuristic(p1g, workspace_limit_in_bytes, return ncb_1g_get_algorithm_heuristic(p1g, workspace_limit_in_bytes,
reproducible); attr);
} }
return ncb_1g_get_algorithm_heuristic(param, workspace_limit_in_bytes, return ncb_1g_get_algorithm_heuristic(param, workspace_limit_in_bytes,
reproducible); attr);
} }
size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace( size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
...@@ -731,14 +729,10 @@ ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms( ...@@ -731,14 +729,10 @@ ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic( ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
for (auto i : ncb_1g_get_all_algorithms(param)) { for (auto i : ncb_1g_get_all_algorithms(param)) {
if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) { if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) {
if (reproducible) { if (i->contain_attribute(attr)) {
if (i->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
return i;
}
} else {
return i; return i;
} }
} }
...@@ -788,7 +782,8 @@ ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) { ...@@ -788,7 +782,8 @@ ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) {
if (!m_prev_selected_algo || if (!m_prev_selected_algo ||
memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) { memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
m_prev_selected_algo = ncb_1g_get_algorithm_heuristic( m_prev_selected_algo = ncb_1g_get_algorithm_heuristic(
param, std::numeric_limits<size_t>::max()); param, std::numeric_limits<size_t>::max(),
AlgoAttribute::DEFAULT);
m_prev_selected_algo_sizep = param; m_prev_selected_algo_sizep = param;
} }
return m_prev_selected_algo; return m_prev_selected_algo;
......
...@@ -90,7 +90,7 @@ public: ...@@ -90,7 +90,7 @@ public:
const TensorLayout& filter, const TensorLayout& filter,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
//! size param for kernels with non-contiguous batch //! size param for kernels with non-contiguous batch
struct NCBKernSizeParam { struct NCBKernSizeParam {
...@@ -238,11 +238,11 @@ public: ...@@ -238,11 +238,11 @@ public:
return false; return false;
} }
bool usable_reproducible(const NCBKernSizeParam& param, bool usable_attribute(
AlgoSelectionStrategy algo_selection_strategy, const NCBKernSizeParam& param,
bool reproducible = true) const { AlgoSelectionStrategy algo_selection_strategy,
return (!reproducible || const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) const {
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && return contain_attribute(attr) &&
usable(param, algo_selection_strategy); usable(param, algo_selection_strategy);
} }
...@@ -272,7 +272,7 @@ protected: ...@@ -272,7 +272,7 @@ protected:
virtual Algorithm* get_algorithm_heuristic_with_ncb( virtual Algorithm* get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
bool reproducible = false); const AlgoAttribute& attr);
const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;
...@@ -326,7 +326,7 @@ public: ...@@ -326,7 +326,7 @@ public:
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;
//! size param for kernels with non-contiguous batch //! size param for kernels with non-contiguous batch
...@@ -421,12 +421,10 @@ protected: ...@@ -421,12 +421,10 @@ protected:
virtual ncb_kern_t dispatch_kern( virtual ncb_kern_t dispatch_kern(
ConvolutionBackwardDataImpl* opr, ConvolutionBackwardDataImpl* opr,
const NCBKernSizeParam& param) const = 0; const NCBKernSizeParam& param) const = 0;
bool usable_reproducible(ConvolutionBackwardDataImpl* opr, bool usable_attribute(
const NCBKernSizeParam& param, ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param,
bool reproducible = true) const { const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) const {
return (!reproducible || return contain_attribute(attr) && usable(opr, param);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
usable(opr, param);
} }
virtual bool is_preferred(const NCBKernSizeParam&) const { virtual bool is_preferred(const NCBKernSizeParam&) const {
return false; return false;
...@@ -451,7 +449,7 @@ protected: ...@@ -451,7 +449,7 @@ protected:
//! default impl calls ncb_1g_get_algorithm_heuristic() //! default impl calls ncb_1g_get_algorithm_heuristic()
virtual Algorithm* get_algorithm_heuristic_with_ncb( virtual Algorithm* get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
bool reproducible = false); const AlgoAttribute& attr);
//! get kernel pointer for float32 non-contiguous batch 1-group kernel //! get kernel pointer for float32 non-contiguous batch 1-group kernel
virtual ncb_kern_t ncb_1g_dispatch_kern(Algorithm* algo, virtual ncb_kern_t ncb_1g_dispatch_kern(Algorithm* algo,
...@@ -469,7 +467,7 @@ protected: ...@@ -469,7 +467,7 @@ protected:
*/ */
virtual Algorithm* ncb_1g_get_algorithm_heuristic( virtual Algorithm* ncb_1g_get_algorithm_heuristic(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
bool reproducible = false); const AlgoAttribute& attr);
static bool is_matrix_mul_preferred(const NCBKernSizeParam& param); static bool is_matrix_mul_preferred(const NCBKernSizeParam& param);
/** /**
......
...@@ -131,19 +131,20 @@ MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc( ...@@ -131,19 +131,20 @@ MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc(
MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic( MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, bool reproducible) { size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
auto kern_size_param = make_kern_size_param(A, B, C); auto kern_size_param = make_kern_size_param(A, B, C);
if (auto algo = static_cast<AlgoBase*>( if (auto algo = static_cast<AlgoBase*>(
get_algorithm_from_desc(execution_policy().algo))) { get_algorithm_from_desc(execution_policy().algo))) {
megdnn_assert(algo->get_workspace(kern_size_param) < megdnn_assert(algo->get_workspace(kern_size_param) <
workspace_limit_in_bytes); workspace_limit_in_bytes);
auto cur = megdnn::get_reproducible_algo<MatrixMulImpl>(algo, auto cur = megdnn::get_algo_with_attribute<MatrixMulImpl>(algo, attr);
reproducible);
if (cur) if (cur)
return cur; return cur;
megdnn_throw( megdnn_throw(ssprintf(
"require reproducible algorithm, but given algorithm is not " "require algorithm with attribute%s, but given algorithm with "
"reproducible"); "attribute%s",
Algorithm::attribute_str(attr).c_str(),
Algorithm::attribute_str(algo->attribute()).c_str()));
} }
AlgoTypePack algo_type; AlgoTypePack algo_type;
algo_type.data_type = kern_size_param.deduce_algo_data_type(); algo_type.data_type = kern_size_param.deduce_algo_data_type();
...@@ -155,8 +156,8 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic( ...@@ -155,8 +156,8 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic(
if (static_cast<AlgoBase*>(algo)->usable(kern_size_param) && if (static_cast<AlgoBase*>(algo)->usable(kern_size_param) &&
static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <= static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <=
workspace_limit_in_bytes) { workspace_limit_in_bytes) {
if (static_cast<AlgoBase*>(algo)->preferred_reproducible( if (static_cast<AlgoBase*>(algo)->preferred_attribute(
kern_size_param, reproducible)) { kern_size_param, attr)) {
//! use gemv algo if it's prefered //! use gemv algo if it's prefered
if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
return algo; return algo;
...@@ -214,8 +215,9 @@ MatrixMulImpl::KernParam MatrixMulImpl::make_kern_param( ...@@ -214,8 +215,9 @@ MatrixMulImpl::KernParam MatrixMulImpl::make_kern_param(
size_t MatrixMulImpl::get_workspace_in_bytes(const TensorLayout& A, size_t MatrixMulImpl::get_workspace_in_bytes(const TensorLayout& A,
const TensorLayout& B, const TensorLayout& B,
const TensorLayout& C) { const TensorLayout& C) {
if (auto algo = get_algorithm_heuristic( if (auto algo = get_algorithm_heuristic(A, B, C,
A, B, C, std::numeric_limits<size_t>::max(), false)) { std::numeric_limits<size_t>::max(),
AlgoAttribute::DEFAULT)) {
auto kern_size_param = make_kern_size_param(A, B, C); auto kern_size_param = make_kern_size_param(A, B, C);
return static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param); return static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param);
} }
...@@ -228,7 +230,7 @@ void MatrixMulImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, ...@@ -228,7 +230,7 @@ void MatrixMulImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
if (auto algo = get_algorithm_heuristic(A.layout, B.layout, C.layout, if (auto algo = get_algorithm_heuristic(A.layout, B.layout, C.layout,
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
false)) { AlgoAttribute::DEFAULT)) {
auto kern_param = make_kern_param(A, B, C, workspace); auto kern_param = make_kern_param(A, B, C, workspace);
auto kern = static_cast<AlgoBase*>(algo)->get_kern(kern_param); auto kern = static_cast<AlgoBase*>(algo)->get_kern(kern_param);
auto run = [kern, kern_param]() { kern(kern_param); }; auto run = [kern, kern_param]() { kern(kern_param); };
......
...@@ -223,11 +223,10 @@ public: ...@@ -223,11 +223,10 @@ public:
virtual InnerBlockSize get_inner_block_size() const { virtual InnerBlockSize get_inner_block_size() const {
megdnn_assert(0); megdnn_assert(0);
}; };
bool preferred_reproducible(const KernSizeParam& param, bool preferred_attribute(
bool reproducible = true) { const KernSizeParam& param,
return (!reproducible || const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) {
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && return contain_attribute(attr) && preferred(param);
preferred(param);
}; };
virtual MatmulDescription matmul_description() const = 0; virtual MatmulDescription matmul_description() const = 0;
...@@ -272,7 +271,7 @@ protected: ...@@ -272,7 +271,7 @@ protected:
const TensorLayout& B, const TensorLayout& B,
const TensorLayout& C, const TensorLayout& C,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
}; };
......
...@@ -125,16 +125,14 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic( ...@@ -125,16 +125,14 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic(
const TensorLayout& /* bias */, const TensorLayout& /* z */, const TensorLayout& /* bias */, const TensorLayout& /* z */,
const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */ const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */
, ,
bool reproducible) { const AlgoAttribute& attr) {
auto algo = static_cast<HandleImpl*>(handle()) auto algo = static_cast<HandleImpl*>(handle())
->default_batch_conv_bias_fwd_algo(); ->default_batch_conv_bias_fwd_algo();
if (reproducible) { megdnn_assert(algo->contain_attribute(attr),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), "require algorithm with attribute%s, but heuristic "
"require reproducible algorithm, but heuristic " "algorithm(%s) with attribute%s ",
"algorithm(%s) is not " Algorithm::attribute_str(attr).c_str(), algo->name(),
"reproducible", Algorithm::attribute_str(algo->attribute()).c_str());
algo->name());
}
return algo; return algo;
} }
......
...@@ -37,7 +37,7 @@ public: ...@@ -37,7 +37,7 @@ public:
const TensorLayout& z, const TensorLayout& z,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
......
...@@ -76,7 +76,7 @@ BatchedMatrixMulForward::Algorithm* ...@@ -76,7 +76,7 @@ BatchedMatrixMulForward::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic( BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) { const AlgoAttribute& /*attr*/) {
return static_cast<HandleImpl*>(handle()) return static_cast<HandleImpl*>(handle())
->default_batched_matmul_fwd_algo(); ->default_batched_matmul_fwd_algo();
} }
......
...@@ -32,7 +32,7 @@ public: ...@@ -32,7 +32,7 @@ public:
const TensorLayout& /*B*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) override; const AlgoAttribute& /*attr*/) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
......
...@@ -246,16 +246,14 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( ...@@ -246,16 +246,14 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* filter */, const TensorLayout& /* src */, const TensorLayout& /* filter */,
const TensorLayout& /* bias */, const TensorLayout& /* z */, const TensorLayout& /* bias */, const TensorLayout& /* z */,
const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */, const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
bool reproducible) { const AlgoAttribute& attr) {
auto algo = auto algo =
static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo(); static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo();
if (reproducible) { megdnn_assert(algo->contain_attribute(attr),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), "require algorithm with attribute%s, but heuristic "
"require reproducible algorithm, but heuristic " "algorithm(%s) with attribute%s ",
"algorithm(%s) is not " Algorithm::attribute_str(attr).c_str(), algo->name(),
"reproducible", Algorithm::attribute_str(algo->attribute()).c_str());
algo->name());
}
return algo; return algo;
} }
......
...@@ -37,7 +37,7 @@ public: ...@@ -37,7 +37,7 @@ public:
const TensorLayout& z, const TensorLayout& z,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
size_t get_workspace_in_bytes( size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
......
...@@ -272,16 +272,14 @@ ConvolutionForwardImpl:: get_all_algorithms(const TensorLayout &, ...@@ -272,16 +272,14 @@ ConvolutionForwardImpl:: get_all_algorithms(const TensorLayout &,
ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic(
        const TensorLayout& /* src */, const TensorLayout& /* diff */,
        const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
        const AlgoAttribute& attr) {
    //! this handle exposes a single default forward algorithm; assert it
    //! carries every attribute requested by \p attr before handing it out
    auto algo =
            static_cast<HandleImpl*>(handle())->default_conv_fwd_algo();
    megdnn_assert(algo->contain_attribute(attr),
                  "require algorithm with attribute%s, but heuristic "
                  "algorithm(%s) with attribute%s ",
                  Algorithm::attribute_str(attr).c_str(), algo->name(),
                  Algorithm::attribute_str(algo->attribute()).c_str());
    return algo;
}
...@@ -304,16 +302,14 @@ ConvolutionBackwardData::Algorithm* ...@@ -304,16 +302,14 @@ ConvolutionBackwardData::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic( ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& /* filter */, const TensorLayout& /* diff */, const TensorLayout& /* filter */, const TensorLayout& /* diff */,
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
bool reproducible) { const AlgoAttribute& attr) {
auto algo = auto algo =
static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo(); static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo();
if (reproducible) { megdnn_assert(algo->contain_attribute(attr),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), "require algorithm with attribute%s, but heuristic "
"require reproducible algorithm, but heuristic " "algorithm(%s) with attribute%s ",
"algorithm(%s) is not " Algorithm::attribute_str(attr).c_str(), algo->name(),
"reproducible", Algorithm::attribute_str(algo->attribute()).c_str());
algo->name());
}
return algo; return algo;
} }
...@@ -337,16 +333,14 @@ ConvolutionBackwardFilter::Algorithm* ...@@ -337,16 +333,14 @@ ConvolutionBackwardFilter::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* diff */, const TensorLayout& /* src */, const TensorLayout& /* diff */,
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
bool reproducible) { const AlgoAttribute& attr) {
auto algo = auto algo =
static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo(); static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo();
if (reproducible) { megdnn_assert(algo->contain_attribute(attr),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), "require algorithm with attribute%s, but heuristic "
"require reproducible algorithm, but heuristic " "algorithm(%s) with attribute%s ",
"algorithm(%s) is not " Algorithm::attribute_str(attr).c_str(), algo->name(),
"reproducible", Algorithm::attribute_str(algo->attribute()).c_str());
algo->name());
}
return algo; return algo;
} }
......
...@@ -29,7 +29,7 @@ class ConvolutionForwardImpl: public ConvolutionForward { ...@@ -29,7 +29,7 @@ class ConvolutionForwardImpl: public ConvolutionForward {
const TensorLayout& filter, const TensorLayout& filter,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const PreprocessedFilter*) override { const PreprocessedFilter*) override {
...@@ -71,7 +71,7 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { ...@@ -71,7 +71,7 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override; const TensorLayout&) override;
...@@ -94,7 +94,7 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { ...@@ -94,7 +94,7 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override; const TensorLayout&) override;
......
...@@ -120,15 +120,13 @@ Convolution3DForward::Algorithm* ...@@ -120,15 +120,13 @@ Convolution3DForward::Algorithm*
Convolution3DForwardImpl::get_algorithm_heuristic( Convolution3DForwardImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* filter */, const TensorLayout& /* src */, const TensorLayout& /* filter */,
const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */, const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
bool reproducible) { const AlgoAttribute& attr) {
auto algo = static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo(); auto algo = static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo();
if (reproducible) { megdnn_assert(algo->contain_attribute(attr),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), "require algorithm with attribute%s, but heuristic "
"require reproducible algorithm, but heuristic " "algorithm(%s) with attribute%s ",
"algorithm(%s) is not " Algorithm::attribute_str(attr).c_str(), algo->name(),
"reproducible", Algorithm::attribute_str(algo->attribute()).c_str());
algo->name());
}
return algo; return algo;
} }
...@@ -152,16 +150,14 @@ Convolution3DBackwardData::Algorithm* ...@@ -152,16 +150,14 @@ Convolution3DBackwardData::Algorithm*
Convolution3DBackwardDataImpl::get_algorithm_heuristic( Convolution3DBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& /* filter */, const TensorLayout& /* diff */, const TensorLayout& /* filter */, const TensorLayout& /* diff */,
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
bool reproducible) { const AlgoAttribute& attr) {
auto algo = auto algo =
static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo(); static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo();
if (reproducible) { megdnn_assert(algo->contain_attribute(attr),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), "require algorithm with attribute%s, but heuristic "
"require reproducible algorithm, but heuristic " "algorithm(%s) with attribute%s ",
"algorithm(%s) is not " Algorithm::attribute_str(attr).c_str(), algo->name(),
"reproducible", Algorithm::attribute_str(algo->attribute()).c_str());
algo->name());
}
return algo; return algo;
} }
...@@ -187,16 +183,14 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic( ...@@ -187,16 +183,14 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* diff */, const TensorLayout& /* src */, const TensorLayout& /* diff */,
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */ const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */
, ,
bool reproducible) { const AlgoAttribute& attr) {
auto algo = static_cast<HandleImpl*>(handle()) auto algo = static_cast<HandleImpl*>(handle())
->default_conv3d_bwd_filter_algo(); ->default_conv3d_bwd_filter_algo();
if (reproducible) { megdnn_assert(algo->contain_attribute(attr),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), "require algorithm with attribute%s, but heuristic "
"require reproducible algorithm, but heuristic " "algorithm(%s) with attribute%s ",
"algorithm(%s) is not " Algorithm::attribute_str(attr).c_str(), algo->name(),
"reproducible", Algorithm::attribute_str(algo->attribute()).c_str());
algo->name());
}
return algo; return algo;
} }
......
...@@ -26,7 +26,7 @@ public: ...@@ -26,7 +26,7 @@ public:
const TensorLayout& filter, const TensorLayout& filter,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override { const TensorLayout&) override {
return 0; return 0;
...@@ -48,7 +48,7 @@ public: ...@@ -48,7 +48,7 @@ public:
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override { const TensorLayout&) override {
return 0; return 0;
...@@ -70,7 +70,7 @@ public: ...@@ -70,7 +70,7 @@ public:
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override { const TensorLayout&) override {
return 0; return 0;
......
...@@ -32,7 +32,7 @@ public: ...@@ -32,7 +32,7 @@ public:
const TensorLayout& /* mask */, const TensorLayout& /* mask */,
const TensorLayout& /* dst */, const TensorLayout& /* dst */,
size_t /* workspace_limit_in_bytes */, size_t /* workspace_limit_in_bytes */,
bool /* reproducible */) override { const AlgoAttribute& /*attr*/) override {
return nullptr; return nullptr;
}; };
...@@ -74,7 +74,7 @@ public: ...@@ -74,7 +74,7 @@ public:
const TensorLayout& /* out_grad */, const TensorLayout& /* out_grad */,
const TensorLayout& /* filter_grad */, const TensorLayout& /* filter_grad */,
size_t /* workspace_limit_in_bytes */, size_t /* workspace_limit_in_bytes */,
bool /* reproducible */) override { const AlgoAttribute& /*attr*/) override {
return nullptr; return nullptr;
}; };
...@@ -121,7 +121,7 @@ public: ...@@ -121,7 +121,7 @@ public:
const TensorLayout& /* offset_grad */, const TensorLayout& /* offset_grad */,
const TensorLayout& /* mask_grad */, const TensorLayout& /* mask_grad */,
size_t /* workspace_limit_in_bytes */, size_t /* workspace_limit_in_bytes */,
bool /* reproducible */) override { const AlgoAttribute& /*attr*/) override {
return nullptr; return nullptr;
}; };
......
...@@ -162,16 +162,14 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout&, ...@@ -162,16 +162,14 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout&,
LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic(
        const TensorLayout& /* src */, const TensorLayout& /* diff */,
        const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
        const AlgoAttribute& attr) {
    //! single default algorithm on this handle; verify the requested
    //! attribute is satisfied before returning it
    auto algo =
            static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo();
    megdnn_assert(algo->contain_attribute(attr),
                  "require algorithm with attribute%s, but heuristic "
                  "algorithm(%s) with attribute%s ",
                  Algorithm::attribute_str(attr).c_str(), algo->name(),
                  Algorithm::attribute_str(algo->attribute()).c_str());
    return algo;
}
...@@ -196,16 +194,14 @@ LocalShareBackwardData::Algorithm* ...@@ -196,16 +194,14 @@ LocalShareBackwardData::Algorithm*
LocalShareBackwardDataImpl::get_algorithm_heuristic( LocalShareBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& /* filter */, const TensorLayout& /* diff */, const TensorLayout& /* filter */, const TensorLayout& /* diff */,
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
bool reproducible) { const AlgoAttribute& attr) {
auto algo = static_cast<HandleImpl*>(handle()) auto algo = static_cast<HandleImpl*>(handle())
->default_local_share_bwd_data_algo(); ->default_local_share_bwd_data_algo();
if (reproducible) { megdnn_assert(algo->contain_attribute(attr),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), "require algorithm with attribute%s, but heuristic "
"require reproducible algorithm, but heuristic " "algorithm(%s) with attribute%s ",
"algorithm(%s) is not " Algorithm::attribute_str(attr).c_str(), algo->name(),
"reproducible", Algorithm::attribute_str(algo->attribute()).c_str());
algo->name());
}
return algo; return algo;
} }
...@@ -230,16 +226,14 @@ LocalShareBackwardFilter::Algorithm* ...@@ -230,16 +226,14 @@ LocalShareBackwardFilter::Algorithm*
LocalShareBackwardFilterImpl::get_algorithm_heuristic( LocalShareBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* diff */, const TensorLayout& /* src */, const TensorLayout& /* diff */,
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
bool reproducible) { const AlgoAttribute& attr) {
auto algo = static_cast<HandleImpl*>(handle()) auto algo = static_cast<HandleImpl*>(handle())
->default_local_share_bwd_filter_algo(); ->default_local_share_bwd_filter_algo();
if (reproducible) { megdnn_assert(algo->contain_attribute(attr),
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), "require algorithm with attribute%s, but heuristic "
"require reproducible algorithm, but heuristic " "algorithm(%s) with attribute%s ",
"algorithm(%s) is not " Algorithm::attribute_str(attr).c_str(), algo->name(),
"reproducible", Algorithm::attribute_str(algo->attribute()).c_str());
algo->name());
}
return algo; return algo;
} }
......
...@@ -34,7 +34,7 @@ public: ...@@ -34,7 +34,7 @@ public:
const TensorLayout& /*filter*/, const TensorLayout& /*filter*/,
const TensorLayout& /*dst*/, const TensorLayout& /*dst*/,
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override; const AlgoAttribute& /*attr*/) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; } const char* get_algorithm_set_name() const override { return "DEFAULT"; }
...@@ -59,7 +59,7 @@ public: ...@@ -59,7 +59,7 @@ public:
const TensorLayout& /*diff*/, const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/, const TensorLayout& /*grad*/,
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override; const AlgoAttribute& /*attr*/) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; } const char* get_algorithm_set_name() const override { return "DEFAULT"; }
...@@ -84,7 +84,7 @@ public: ...@@ -84,7 +84,7 @@ public:
const TensorLayout& /*diff*/, const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/, const TensorLayout& /*grad*/,
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override; const AlgoAttribute& /*attr*/) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; } const char* get_algorithm_set_name() const override { return "DEFAULT"; }
......
...@@ -91,7 +91,7 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, ...@@ -91,7 +91,7 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
        const TensorLayout& /*A*/, const TensorLayout& /*B*/,
        const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
        const AlgoAttribute& /*attr*/) {
    //! only one matmul algorithm exists on this handle, so all hints
    //! (layouts, workspace limit, attribute) are ignored
    return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo();
}
......
...@@ -33,7 +33,7 @@ public: ...@@ -33,7 +33,7 @@ public:
const TensorLayout& /*B*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) override; const AlgoAttribute& /*attr*/) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
......
...@@ -70,12 +70,11 @@ public: ...@@ -70,12 +70,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) const { bool is_available_wk(const SizeArgs& args, size_t limit) const {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) const { size_t limit = std::numeric_limits<size_t>::max()) const {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) { const Workspace& workspace) {
......
...@@ -32,16 +32,16 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, ...@@ -32,16 +32,16 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
BatchedMatrixMulForwardImpl::Algorithm* BatchedMatrixMulForwardImpl::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic( BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, bool reproducible) { size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
AlgoBase::SizeArgs args{this, A, B, C}; AlgoBase::SizeArgs args{this, A, B, C};
if (sm_algo_pack.blas.is_available_reproducible(args, reproducible, if (sm_algo_pack.blas.is_available_attribute(args, attr,
workspace_limit_in_bytes)) { workspace_limit_in_bytes)) {
return &sm_algo_pack.blas; return &sm_algo_pack.blas;
} }
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<BatchedMatrixMulForwardImpl>( return megdnn::get_algo_with_attribute<BatchedMatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"batched matrix mul forward"); "batched matrix mul forward", attr);
} else { } else {
return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>( return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
......
...@@ -40,7 +40,7 @@ private: ...@@ -40,7 +40,7 @@ private:
const TensorLayout& /*B*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override; const AlgoAttribute& /*attr*/) override;
const char* get_algorithm_set_name() const override { const char* get_algorithm_set_name() const override {
return "ROCM BATCHED MATMUL"; return "ROCM BATCHED MATMUL";
......
...@@ -74,12 +74,11 @@ public: ...@@ -74,12 +74,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) { bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
...@@ -96,24 +95,20 @@ public: ...@@ -96,24 +95,20 @@ public:
}; };
class ConvolutionBackwardDataImpl::AlgoMIOpen final : public AlgoBase {
    //! attribute reported by this algo; fixed at construction time
    AlgoAttribute m_algo_attribute;
    const char* m_name;

    miopenConvBwdDataAlgorithm_t find_best_algo(const ExecArgs& args);

public:
    AlgoMIOpen() = delete;
    AlgoMIOpen(AlgoAttribute attr) : m_algo_attribute(attr) {}

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
AlgoAttribute attribute() const override { AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0); return m_algo_attribute;
if (m_is_reproducible) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
} }
const char* name() const override { const char* name() const override {
...@@ -124,7 +119,7 @@ public: ...@@ -124,7 +119,7 @@ public:
MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN) MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN)
std::string param() const override { std::string param() const override {
std::string ret; std::string ret;
serialize_write_pod(m_is_reproducible, ret); serialize_write_pod(m_algo_attribute, ret);
return ret; return ret;
} }
...@@ -170,7 +165,7 @@ class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { ...@@ -170,7 +165,7 @@ class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
public:
    AlgoPack();

    //! the MIOpen algo is constructed with the REPRODUCIBLE attribute
    AlgoMIOpen miopen{AlgoAttribute::REPRODUCIBLE};
    AlgoMatmul matmul;
    AlgoChanwise chanwise;
......
...@@ -71,12 +71,11 @@ public: ...@@ -71,12 +71,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) { bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
...@@ -93,25 +92,21 @@ public: ...@@ -93,25 +92,21 @@ public:
}; };
class ConvolutionBackwardFilterImpl::AlgoMIOpen final : public AlgoBase {
    //! attribute reported by this algo; fixed at construction time
    AlgoAttribute m_algo_attribute;
    const char* m_name;

    miopenConvBwdWeightsAlgorithm_t find_best_algo(const ExecArgs& args);

public:
    AlgoMIOpen() = delete;
    AlgoMIOpen(AlgoAttribute attr) : m_algo_attribute(attr) {}

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
AlgoAttribute attribute() const override { AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0); return m_algo_attribute;
if (m_is_reproducible) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
} }
const char* name() const override { const char* name() const override {
return "MIOpenConvolutionBackwardFilter"; return "MIOpenConvolutionBackwardFilter";
...@@ -121,7 +116,7 @@ public: ...@@ -121,7 +116,7 @@ public:
MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN) MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN)
std::string param() const override { std::string param() const override {
std::string ret; std::string ret;
serialize_write_pod(m_is_reproducible, ret); serialize_write_pod(m_algo_attribute, ret);
return ret; return ret;
} }
...@@ -166,7 +161,7 @@ class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj { ...@@ -166,7 +161,7 @@ class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj {
public:
    AlgoPack();

    //! the MIOpen algo is constructed with the REPRODUCIBLE attribute
    AlgoMIOpen miopen{AlgoAttribute::REPRODUCIBLE};
    AlgoMatmul matmul;
    AlgoChanwise chanwise;
......
...@@ -73,12 +73,11 @@ public: ...@@ -73,12 +73,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) { bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) { size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
...@@ -94,25 +93,21 @@ public: ...@@ -94,25 +93,21 @@ public:
}; };
class ConvolutionForwardImpl::AlgoMIOpen final : public AlgoBase {
    //! attribute reported by this algo; fixed at construction time
    AlgoAttribute m_algo_attribute;
    const char* m_name;

    miopenConvFwdAlgorithm_t find_best_algo(const ExecArgs& args);

public:
    AlgoMIOpen() = delete;
    AlgoMIOpen(AlgoAttribute attr) : m_algo_attribute(attr) {}

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
AlgoAttribute attribute() const override { AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0); return m_algo_attribute;
if (m_is_reproducible) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
} }
const char* name() const override { return "MIOpenConvolutionForward"; } const char* name() const override { return "MIOpenConvolutionForward"; }
...@@ -121,7 +116,7 @@ public: ...@@ -121,7 +116,7 @@ public:
MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN) MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN)
std::string param() const override { std::string param() const override {
std::string ret; std::string ret;
serialize_write_pod(m_is_reproducible, ret); serialize_write_pod(m_algo_attribute, ret);
return ret; return ret;
} }
...@@ -215,7 +210,7 @@ class ConvolutionForwardImpl::AlgoPack : NonCopyableObj { ...@@ -215,7 +210,7 @@ class ConvolutionForwardImpl::AlgoPack : NonCopyableObj {
public:
    AlgoPack();

    //! the MIOpen algo is constructed with the REPRODUCIBLE attribute
    AlgoMIOpen miopen{AlgoAttribute::REPRODUCIBLE};
    AlgoMatmul matmul;
    AlgoInplaceMatmul inplace_matmul;
    Algo1x1 a1x1;
......
...@@ -33,70 +33,69 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src, ...@@ -33,70 +33,69 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter, const TensorLayout& filter,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
auto fm = check_layout_fwd(src, filter, dst); auto fm = check_layout_fwd(src, filter, dst);
return get_algorithm_heuristic(src, fm, dst, workspace_limit_in_bytes, return get_algorithm_heuristic(src, fm, dst, workspace_limit_in_bytes,
reproducible); attr);
} }
ConvolutionForwardImpl::Algorithm* ConvolutionForwardImpl::Algorithm*
ConvolutionForwardImpl::get_algorithm_heuristic( ConvolutionForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const CanonizedFilterMeta& filter, const TensorLayout& src, const CanonizedFilterMeta& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes, const TensorLayout& dst, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, src, filter, dst); AlgoBase::SizeArgs args(this, src, filter, dst);
//! MIOpen auto-tuning need to run with actual tensors, so we cannot get //! MIOpen auto-tuning need to run with actual tensors, so we cannot get
//! best algorithm here. //! best algorithm here.
if (is_miopen_supported(args)) { if (is_miopen_supported(args)) {
auto algo = megdnn::get_reproducible_algo<ConvolutionForwardImpl>( auto algo = megdnn::get_algo_with_attribute<ConvolutionForwardImpl>(
sm_algo_pack.miopen_algos[0], reproducible); sm_algo_pack.miopen_algos[0], attr);
if (algo) if (algo)
return algo; return algo;
} }
if (args.filter_meta.group > 1) { if (args.filter_meta.group > 1) {
if (sm_algo_pack.chanwise.is_available_reproducible( if (sm_algo_pack.chanwise.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.chanwise; return &sm_algo_pack.chanwise;
} }
} }
auto prefer_1x1 = [&args, reproducible, workspace_limit_in_bytes]() { auto prefer_1x1 = [&args, attr, workspace_limit_in_bytes]() {
const size_t MAX_BATCH_SIZE_FOR_1x1_MAT_ALGO = 4; const size_t MAX_BATCH_SIZE_FOR_1x1_MAT_ALGO = 4;
size_t batch_size = args.src_layout->shape[0]; size_t batch_size = args.src_layout->shape[0];
if (batch_size > MAX_BATCH_SIZE_FOR_1x1_MAT_ALGO) { if (batch_size > MAX_BATCH_SIZE_FOR_1x1_MAT_ALGO) {
return false; return false;
} }
return sm_algo_pack.a1x1.is_available_reproducible( return sm_algo_pack.a1x1.is_available_attribute(
args, reproducible, workspace_limit_in_bytes); args, attr, workspace_limit_in_bytes);
}; };
if (prefer_1x1()) { if (prefer_1x1()) {
return &sm_algo_pack.a1x1; return &sm_algo_pack.a1x1;
} }
auto prefer_1x1_large_batch = [&args, reproducible, auto prefer_1x1_large_batch = [&args, attr, workspace_limit_in_bytes]() {
workspace_limit_in_bytes]() {
const size_t MIN_BATCH_SIZE_FOR_1x1_LARGE_BATCH_ALGO = 32; const size_t MIN_BATCH_SIZE_FOR_1x1_LARGE_BATCH_ALGO = 32;
size_t batch_size = args.src_layout->shape[0]; size_t batch_size = args.src_layout->shape[0];
if (batch_size < MIN_BATCH_SIZE_FOR_1x1_LARGE_BATCH_ALGO) { if (batch_size < MIN_BATCH_SIZE_FOR_1x1_LARGE_BATCH_ALGO) {
return false; return false;
} }
return sm_algo_pack.batched_matrix_mul.is_available_reproducible( return sm_algo_pack.batched_matrix_mul.is_available_attribute(
args, reproducible, workspace_limit_in_bytes); args, attr, workspace_limit_in_bytes);
}; };
if (prefer_1x1_large_batch()) { if (prefer_1x1_large_batch()) {
return &sm_algo_pack.batched_matrix_mul; return &sm_algo_pack.batched_matrix_mul;
} }
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<ConvolutionForwardImpl>( return megdnn::get_algo_with_attribute<ConvolutionForwardImpl>(
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes,
"rocm conv fwd"); "rocm conv fwd", attr);
} else { } else {
return megdnn::get_usable_algo<ConvolutionForwardImpl>( return megdnn::get_usable_algo<ConvolutionForwardImpl>(
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes,
...@@ -157,36 +156,36 @@ ConvolutionBackwardDataImpl::Algorithm* ...@@ -157,36 +156,36 @@ ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic( ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
auto fm = check_layout_fwd(grad, filter, diff); auto fm = check_layout_fwd(grad, filter, diff);
return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes, return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes,
reproducible); attr);
} }
ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic( ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const CanonizedFilterMeta& filter, const TensorLayout& diff, const CanonizedFilterMeta& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, filter, diff, grad); AlgoBase::SizeArgs args(this, filter, diff, grad);
if (is_miopen_supported(args.as_fwd_args())) { if (is_miopen_supported(args.as_fwd_args())) {
auto algo = megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( auto algo = megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.miopen_algos[0], reproducible); sm_algo_pack.miopen_algos[0], attr);
if (algo) if (algo)
return algo; return algo;
} }
if (args.filter_meta.group > 1 && if (args.filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_reproducible( sm_algo_pack.chanwise.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.chanwise; return &sm_algo_pack.chanwise;
} }
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes,
"rocm conv bwd_data"); "rocm conv bwd_data", attr);
} else { } else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes,
...@@ -230,38 +229,38 @@ ConvolutionBackwardFilterImpl::Algorithm* ...@@ -230,38 +229,38 @@ ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
auto fm = check_layout_fwd(src, grad, diff); auto fm = check_layout_fwd(src, grad, diff);
return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes, return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes,
reproducible); attr);
} }
ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, src, diff, grad); AlgoBase::SizeArgs args(this, src, diff, grad);
if (is_miopen_supported(args.as_fwd_args())) { if (is_miopen_supported(args.as_fwd_args())) {
auto algo = auto algo =
megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( megdnn::get_algo_with_attribute<ConvolutionBackwardFilterImpl>(
sm_algo_pack.miopen_algos[0], reproducible); sm_algo_pack.miopen_algos[0], attr);
if (algo) if (algo)
return algo; return algo;
} }
if (args.grad_filter_meta.group > 1 && if (args.grad_filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_reproducible( sm_algo_pack.chanwise.is_available_attribute(
args, reproducible, workspace_limit_in_bytes)) { args, attr, workspace_limit_in_bytes)) {
// prefer special chanwise impl // prefer special chanwise impl
return &sm_algo_pack.chanwise; return &sm_algo_pack.chanwise;
} }
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( return megdnn::get_algo_with_attribute<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes,
"rocm conv bwd_filter"); "rocm conv bwd_filter", attr);
} else { } else {
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes,
......
...@@ -26,9 +26,9 @@ public: ...@@ -26,9 +26,9 @@ public:
AlgorithmInfo get_algorithm_info_heuristic( AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& src, const CanonizedFilterMeta& filter, const TensorLayout& src, const CanonizedFilterMeta& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes, const TensorLayout& dst, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
return get_algorithm_heuristic(src, filter, dst, return get_algorithm_heuristic(src, filter, dst,
workspace_limit_in_bytes, reproducible) workspace_limit_in_bytes, attr)
->info(); ->info();
} }
size_t get_workspace_in_bytes(const TensorLayout& src, size_t get_workspace_in_bytes(const TensorLayout& src,
...@@ -76,12 +76,12 @@ private: ...@@ -76,12 +76,12 @@ private:
const TensorLayout& filter, const TensorLayout& filter,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src, Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const CanonizedFilterMeta& filter, const CanonizedFilterMeta& filter,
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible); const AlgoAttribute& attr);
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
}; };
...@@ -94,9 +94,9 @@ public: ...@@ -94,9 +94,9 @@ public:
AlgorithmInfo get_algorithm_info_heuristic( AlgorithmInfo get_algorithm_info_heuristic(
const CanonizedFilterMeta& filter, const TensorLayout& diff, const CanonizedFilterMeta& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
return get_algorithm_heuristic(filter, diff, grad, return get_algorithm_heuristic(filter, diff, grad,
workspace_limit_in_bytes, reproducible) workspace_limit_in_bytes, attr)
->info(); ->info();
} }
size_t get_workspace_in_bytes(const TensorLayout& filter, size_t get_workspace_in_bytes(const TensorLayout& filter,
...@@ -122,12 +122,12 @@ private: ...@@ -122,12 +122,12 @@ private:
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter,
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible); const AlgoAttribute& attr);
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
}; };
...@@ -141,9 +141,9 @@ public: ...@@ -141,9 +141,9 @@ public:
const TensorLayout& diff, const TensorLayout& diff,
const CanonizedFilterMeta& grad, const CanonizedFilterMeta& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) { const AlgoAttribute& attr) {
return get_algorithm_heuristic(src, diff, grad, return get_algorithm_heuristic(src, diff, grad,
workspace_limit_in_bytes, reproducible) workspace_limit_in_bytes, attr)
->info(); ->info();
} }
size_t get_workspace_in_bytes(const TensorLayout& src, size_t get_workspace_in_bytes(const TensorLayout& src,
...@@ -169,12 +169,12 @@ private: ...@@ -169,12 +169,12 @@ private:
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src, Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff, const TensorLayout& diff,
const CanonizedFilterMeta& grad, const CanonizedFilterMeta& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible); const AlgoAttribute& attr);
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
}; };
......
...@@ -70,12 +70,11 @@ public: ...@@ -70,12 +70,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) const { bool is_available_wk(const SizeArgs& args, size_t limit) const {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_attribute(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) const { size_t limit = std::numeric_limits<size_t>::max()) const {
return (!reproducible || return contain_attribute(attr) && is_available_wk(args, limit);
contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
is_available_wk(args, limit);
} }
AlgoBase& check_workspace(const SizeArgs& args, AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) { const Workspace& workspace) {
......
...@@ -29,16 +29,16 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, ...@@ -29,16 +29,16 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, bool reproducible) { size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
AlgoBase::SizeArgs args{this, A, B, C}; AlgoBase::SizeArgs args{this, A, B, C};
if (sm_algo_pack.blas.is_available_reproducible( if (sm_algo_pack.blas.is_available_attribute(args, attr,
args, reproducible, workspace_limit_in_bytes)) { workspace_limit_in_bytes)) {
return &sm_algo_pack.blas; return &sm_algo_pack.blas;
} }
if (reproducible) { if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_reproducible_algo<MatrixMulForwardImpl>( return megdnn::get_algo_with_attribute<MatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"matrix mul forward"); "matrix mul forward", attr);
} else { } else {
return megdnn::get_usable_algo<MatrixMulForwardImpl>( return megdnn::get_usable_algo<MatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
......
...@@ -40,7 +40,7 @@ private: ...@@ -40,7 +40,7 @@ private:
const TensorLayout& /*B*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override; const AlgoAttribute& /*attr*/) override;
const char* get_algorithm_set_name() const override { const char* get_algorithm_set_name() const override {
return "ROCM MATMUL"; return "ROCM MATMUL";
......
...@@ -278,6 +278,15 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( ...@@ -278,6 +278,15 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
return ret; return ret;
} }
AlgoAttribute extract_algo_attribute_from_execution_strategy(
const ExecutionStrategy& strategy) {
AlgoAttribute ret = AlgoAttribute::DEFAULT;
if (strategy & ExecutionStrategy::REPRODUCIBLE) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
//! Test whether the algo attribute of a algo match the require //! Test whether the algo attribute of a algo match the require
//! algo_strategy //! algo_strategy
static bool algo_attribute_match_strategy(AlgoAttribute attribute, static bool algo_attribute_match_strategy(AlgoAttribute attribute,
...@@ -290,7 +299,6 @@ static bool algo_attribute_match_strategy(AlgoAttribute attribute, ...@@ -290,7 +299,6 @@ static bool algo_attribute_match_strategy(AlgoAttribute attribute,
} }
return ret; return ret;
} }
} // namespace } // namespace
namespace mgb { namespace mgb {
...@@ -303,9 +311,9 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, ...@@ -303,9 +311,9 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
return; return;
AlgoChooserProfileCache::Result prof_rst; AlgoChooserProfileCache::Result prof_rst;
std::string str_on_inp_shape = ssprintf( auto target_attribute =
"on input layouts (%s, %s)", ctx.layouts()[0].to_string().c_str(), extract_algo_attribute_from_execution_strategy(selected_strategy);
ctx.layouts()[1].to_string().c_str()); std::string layouts_str = format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out);
double cur_timeout = 0; double cur_timeout = 0;
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
...@@ -316,20 +324,22 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, ...@@ -316,20 +324,22 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst; Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst;
std::string msg = ssprintf("profiling %s algorithm %s %s", std::string msg = ssprintf("profiling %s algorithm %s %s",
ctx.mgb_opr()->dyn_typeinfo()->name, ctx.mgb_opr()->dyn_typeinfo()->name,
algo.name.c_str(), str_on_inp_shape.c_str()); algo.name.c_str(), layouts_str.c_str());
ImplExecutionPolicy policy; ImplExecutionPolicy policy;
policy.algo = algo.desc; policy.algo = algo.desc;
ctx.construct_execution_policy(selected_strategy, policy); ctx.construct_execution_policy(selected_strategy, policy);
if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) { if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) {
continue; continue;
} }
auto algo_attribute = ctx.megdnn_opr() auto palgo = ctx.megdnn_opr()->get_algorithm_from_desc(policy.algo);
->get_algorithm_from_desc(policy.algo) if (!algo_attribute_match_strategy(palgo->attribute(),
->attribute(); selected_strategy)) {
if (!algo_attribute_match_strategy(algo_attribute, selected_strategy)) {
mgb_log_debug( mgb_log_debug(
"skip algo %s, which is not match the profile strategy.", "skip algo %s with attribute%s, which is not match the "
algo.name.c_str()); "profile strategy required attribute%s.",
algo.name.c_str(),
Algorithm::attribute_str(palgo->attribute()).c_str(),
Algorithm::attribute_str(target_attribute).c_str());
continue; continue;
} }
...@@ -360,9 +370,10 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, ...@@ -360,9 +370,10 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
rst.workspace, rst.time); rst.workspace, rst.time);
prof_rst.push_back(rst); prof_rst.push_back(rst);
} }
std::string msg = ssprintf("no usable %s algorithm %s", std::string msg =
ctx.mgb_opr()->dyn_typeinfo()->name, ssprintf("no usable %s algorithm %s with attribute(%s)",
str_on_inp_shape.c_str()); ctx.mgb_opr()->dyn_typeinfo()->name, layouts_str.c_str(),
Algorithm::attribute_str(target_attribute).c_str());
mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); mgb_assert(!prof_rst.empty(), "%s", msg.c_str());
FixedTensorLayouts origin_layouts = ctx.layouts(); FixedTensorLayouts origin_layouts = ctx.layouts();
...@@ -589,14 +600,15 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( ...@@ -589,14 +600,15 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
"workspace_limit should not be setted if choose algo by " "workspace_limit should not be setted if choose algo by "
"heuristic"); "heuristic");
} }
bool reproducible = static_cast<bool>(selected_strategy &
ExecutionStrategy::REPRODUCIBLE);
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit); owner_graph(), m_cn, m_execution_policy.workspace_limit);
ImplExecutionPolicy policy; ImplExecutionPolicy policy;
policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
args..., workspace_limit, reproducible), args..., workspace_limit,
m_layouts).desc; extract_algo_attribute_from_execution_strategy(
selected_strategy)),
m_layouts)
.desc;
Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo);
mgb_assert(algo, "Unknown algo description"); mgb_assert(algo, "Unknown algo description");
...@@ -647,8 +659,6 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( ...@@ -647,8 +659,6 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
ExecutionStrategy selected_strategy, ExecutionStrategy selected_strategy,
typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, typename AlgoChooser<Opr>::ImplExecutionPolicy& policy,
bool retrive_from_cache) const { bool retrive_from_cache) const {
bool reproducible = static_cast<bool>(selected_strategy &
ExecutionStrategy::REPRODUCIBLE);
if (!policy.algo.valid()) { if (!policy.algo.valid()) {
if (retrive_from_cache) { if (retrive_from_cache) {
policy.algo = policy.algo =
...@@ -656,11 +666,13 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( ...@@ -656,11 +666,13 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
} else { } else {
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit); owner_graph(), m_cn, m_execution_policy.workspace_limit);
policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( policy.algo =
args..., workspace_limit, APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
reproducible), args..., workspace_limit,
m_layouts) extract_algo_attribute_from_execution_strategy(
.desc; selected_strategy)),
m_layouts)
.desc;
} }
mgb_assert(policy.algo.valid(), mgb_assert(policy.algo.valid(),
"No algo found from cache or heuristic, maybe some error " "No algo found from cache or heuristic, maybe some error "
......
...@@ -2375,7 +2375,7 @@ public: ...@@ -2375,7 +2375,7 @@ public:
AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1, AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p2,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible)); const AlgoAttribute& attr));
MOCK_METHOD3(get_all_algorithms, MOCK_METHOD3(get_all_algorithms,
std::vector<Algorithm*>(const TensorLayout& p0, std::vector<Algorithm*>(const TensorLayout& p0,
...@@ -2385,7 +2385,7 @@ public: ...@@ -2385,7 +2385,7 @@ public:
Algorithm*(const TensorLayout& p0, const TensorLayout& p1, Algorithm*(const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p2,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible)); const AlgoAttribute& attr));
MOCK_METHOD1(get_algorithm_from_desc, MOCK_METHOD1(get_algorithm_from_desc,
Algorithm*(const AlgorithmDesc&)); Algorithm*(const AlgorithmDesc&));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册