提交 c338e876 编写于 作者: M Megvii Engine Team

refactor(mgb/dnn): add negative attribute for algo

GitOrigin-RevId: 88b1ce94a514d3b8c69f786dc1eb4c5f2606c5c6
上级 ec1a99ac
......@@ -165,7 +165,15 @@ public:
virtual std::string param() const { return {}; }
virtual uint32_t type() const = 0;
bool contain_attribute(const Attribute& attr) const;
//! if algo contain all of the attribute in attr
bool contain_attribute_all(const Attribute& attr) const;
//! if algo contain any attribute in attr
bool contain_attribute_any(const Attribute& attr) const;
void check_attribute(
const Attribute& positive_attr = Attribute::DEFAULT,
const Attribute& negative_attr = Attribute::DEFAULT) const;
static std::string attribute_str(const Attribute& attr);
......@@ -342,9 +350,10 @@ public:
const TensorLayout& p2,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const AlgoAttribute& attr = AlgoAttribute::DEFAULT) {
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes,
attr)
positive_attr, negative_attr)
->info();
}
......@@ -367,7 +376,8 @@ protected:
const TensorLayout& p2,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0;
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0;
};
//! specializae for nargs == 4
......@@ -402,9 +412,10 @@ public:
const TensorLayout& p2, const TensorLayout& p3,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const AlgoAttribute& attr = AlgoAttribute::DEFAULT) {
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes,
attr)
positive_attr, negative_attr)
->info();
}
......@@ -427,7 +438,8 @@ protected:
const TensorLayout& p2, const TensorLayout& p3,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0;
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0;
};
//! specializae for nargs == 5
......@@ -464,9 +476,11 @@ public:
const TensorLayout& p4,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const AlgoAttribute& attr = AlgoAttribute::DEFAULT) {
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, p3, p4,
workspace_limit_in_bytes, attr)
workspace_limit_in_bytes, positive_attr,
negative_attr)
->info();
}
......@@ -491,7 +505,8 @@ protected:
const TensorLayout& p4,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0;
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0;
};
//! specializae for nargs == 8
......@@ -528,9 +543,11 @@ public:
const TensorLayout& p6, const TensorLayout& p7,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const AlgoAttribute& attr = AlgoAttribute::DEFAULT) {
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7,
workspace_limit_in_bytes, attr)
workspace_limit_in_bytes, positive_attr,
negative_attr)
->info();
}
......@@ -557,7 +574,8 @@ protected:
const TensorLayout& p6, const TensorLayout& p7,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0;
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0;
};
} // namespace detail
......
......@@ -27,7 +27,7 @@ inline const char* attr_str(const AlgoAttribute& attr) {
return #attr;
switch (attr) { FOREACH_ALGO_ATTRIBUTE(cb) }
#undef cb
return "unknown arch";
return "UNKNOWN";
}
} // namespace
......@@ -43,11 +43,30 @@ std::string Algorithm::attribute_str(const Attribute& attr) {
ret.append(attr_str(sub_attr));
attr_val = attr_val & (attr_val - 1);
}
if (ret.empty()) {
ret = "DEFAULT";
}
return ret;
}
bool Algorithm::contain_attribute(const Attribute& attr) const {
//! whether the algorithm's attribute bitmask contains EVERY bit set in
//! \p attr (i.e. attr is a subset of attribute()); trivially true for
//! attr == DEFAULT (empty mask)
bool Algorithm::contain_attribute_all(const Attribute& attr) const {
    return attr == static_cast<Attribute>(attribute() & attr);
}
//! whether the algorithm's attribute bitmask contains AT LEAST ONE bit set
//! in \p attr; always false for attr == DEFAULT (empty mask) — used to
//! reject algorithms carrying any of the negative attributes
bool Algorithm::contain_attribute_any(const Attribute& attr) const {
    return static_cast<bool>(attribute() & attr);
}
/*!
 * \brief assert that this algorithm carries all bits of \p positive_attr
 *        and none of the bits of \p negative_attr
 *
 * Aborts via megdnn_assert with a diagnostic naming the required/forbidden
 * attribute sets and the actual attributes of this algorithm.
 */
void Algorithm::check_attribute(const Attribute& positive_attr,
                                const Attribute& negative_attr) const {
    // NOTE: adjacent string literals concatenate; the trailing space after
    // "but get" keeps the message readable ("but get algorithm(...)").
    megdnn_assert(contain_attribute_all(positive_attr) &&
                          !contain_attribute_any(negative_attr),
                  "require algorithm with attribute(%s) and without "
                  "attribute(%s), but get "
                  "algorithm(%s) with attribute(%s)",
                  Algorithm::attribute_str(positive_attr).c_str(),
                  Algorithm::attribute_str(negative_attr).c_str(), name(),
                  Algorithm::attribute_str(attribute()).c_str());
}
// vim: syntax=cpp.doxygen
......@@ -32,7 +32,7 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
} else {
ret = opr->get_algorithm_info_heuristic(
std::forward<Args>(args)..., std::numeric_limits<size_t>::max(),
AlgoAttribute::DEFAULT).desc;
AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT).desc;
}
return static_cast<typename Opr::AlgoBase*>(
opr->get_algorithm_from_desc(ret));
......@@ -51,6 +51,7 @@ typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) {
return static_cast<typename Opr::AlgoBase*>(
opr->get_algorithm_heuristic(std::forward<Args>(args)...,
std::numeric_limits<size_t>::max(),
AlgoAttribute::DEFAULT,
AlgoAttribute::DEFAULT));
}
}
......@@ -74,34 +75,37 @@ std::vector<typename Opr::Algorithm*> get_all_algorithms(
}
/*!
* \brief a helper function to get an algorithm with attribute. If require a
* algorithm with specified attribute, and the given algorithm has that
* \brief a helper function to get an algorithm match attribute. If require a
* algorithm with specified attribute, and the given algorithm match that
* attribute, return the given algorithm. Otherwise return nullptr
*/
template <typename Opr>
typename Opr::Algorithm* get_algo_with_attribute(typename Opr::AlgoBase* algo,
const AlgoAttribute& attr) {
if (algo->contain_attribute(attr)) {
typename Opr::Algorithm* get_algo_match_attribute(
typename Opr::AlgoBase* algo, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
if (algo->contain_attribute_all(positive_attr) &&
!algo->contain_attribute_any(negative_attr)) {
return algo;
}
return nullptr;
}
template <typename Opr>
typename Opr::Algorithm* get_algo_with_attribute(
typename Opr::Algorithm* get_algo_match_attribute(
const std::vector<typename Opr::AlgoBase*>& algos,
const typename Opr::AlgoBase::SizeArgs& args,
size_t workspace_limit_in_bytes, const char* name,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) {
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
size_t min_workspace_limit_in_bytes = std::numeric_limits<size_t>::max();
bool available_but_limited_by_workspace = false;
bool available_but_without_attribute = false;
bool available_but_attribute_mismatch = false;
for (auto i : algos) {
if (i->is_available_attribute(args, attr,
if (i->is_available_attribute(args, positive_attr, negative_attr,
workspace_limit_in_bytes)) {
return i;
}
if (i->is_available_attribute(args)) {
if (i->is_available_attribute(args, positive_attr, negative_attr)) {
if (i->get_workspace_in_bytes(args) > workspace_limit_in_bytes) {
available_but_limited_by_workspace = true;
min_workspace_limit_in_bytes =
......@@ -110,53 +114,27 @@ typename Opr::Algorithm* get_algo_with_attribute(
}
}
if (i->is_available(args)) {
if (!i->contain_attribute(attr))
available_but_without_attribute = true;
if (!(i->contain_attribute_all(positive_attr) &&
!i->contain_attribute_any(negative_attr)))
available_but_attribute_mismatch = true;
}
}
MEGDNN_MARK_USED_VAR(name);
if (available_but_limited_by_workspace) {
megdnn_throw(
ssprintf("no %s algorithm without attribute(%s) with "
"attribute(%s) : %s workspace limit %zu is "
"less than mini workspace limit %zu",
name, Algorithm::attribute_str(negative_attr).c_str(),
Algorithm::attribute_str(positive_attr).c_str(),
args.to_string().c_str(), workspace_limit_in_bytes,
min_workspace_limit_in_bytes));
} else if (available_but_attribute_mismatch) {
megdnn_throw(ssprintf(
"no %s algorithm with attribute:%s : %s workspace limit %zu is "
"less than mini workspace limit %zu",
name, Algorithm::attribute_str(attr).c_str(),
args.to_string().c_str(), workspace_limit_in_bytes,
min_workspace_limit_in_bytes));
} else if (available_but_without_attribute) {
megdnn_throw(ssprintf("no %s algorithm with attribute:%s", name,
Algorithm::attribute_str(attr).c_str()));
} else {
megdnn_throw(ssprintf("no usable %s algorithm", name));
}
}
template <typename Opr>
typename Opr::Algorithm* get_usable_algo(
const std::vector<typename Opr::AlgoBase*>& algos,
const typename Opr::AlgoBase::SizeArgs& args,
size_t workspace_limit_in_bytes, const char* name) {
size_t min_workspace_limit_in_bytes = std::numeric_limits<size_t>::max();
bool available_but_limited_by_workspace = false;
for (auto i : algos) {
if (i->is_available_wk(args, workspace_limit_in_bytes)) {
return i;
}
if (i->is_available(args)) {
available_but_limited_by_workspace = true;
min_workspace_limit_in_bytes =
std::min(min_workspace_limit_in_bytes,
i->get_workspace_in_bytes(args));
}
}
MEGDNN_MARK_USED_VAR(name);
if (available_but_limited_by_workspace) {
megdnn_throw(ssprintf(
"no usable %s algorithm: %s workspace limit %zu is "
"less than mini workspace limit %zu",
name, args.to_string().c_str(), workspace_limit_in_bytes,
min_workspace_limit_in_bytes));
"no %s algorithm without attribute(%s) with attribute(%s)", name,
Algorithm::attribute_str(negative_attr).c_str(),
Algorithm::attribute_str(positive_attr).c_str()));
} else {
megdnn_throw(ssprintf("no usable %s algorithm", name));
}
......
......@@ -67,9 +67,12 @@ public:
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......
......@@ -22,21 +22,24 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, src, filter, bias, z, dst);
if (sm_algo_pack.int8_nchw4_gemm_dotprod.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.int8_nchw4_gemm_dotprod;
}
if (sm_algo_pack.int8_nchw4_implicit_gemm_dotprod.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.int8_nchw4_implicit_gemm_dotprod;
}
megdnn_throw(ssprintf(
"no batch conv bias algorithm with attribute%s args(%s) and "
"workspace limit (%zu bytes)",
Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(),
workspace_limit_in_bytes));
megdnn_throw(
ssprintf("no batch conv bias algorithm without attribute(%s) with "
"attribute(%s) args(%s) and "
"workspace limit (%zu bytes)",
Algorithm::attribute_str(negative_attr).c_str(),
Algorithm::attribute_str(positive_attr).c_str(),
args.to_string().c_str(), workspace_limit_in_bytes));
}
std::vector<BatchConvBiasForwardImpl::Algorithm*>
......
......@@ -42,13 +42,12 @@ protected:
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias,
const TensorLayout& z,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
static AlgoPack sm_algo_pack;
......
......@@ -70,9 +70,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
......
......@@ -55,26 +55,37 @@ std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms(
Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes);
AlgoBase::SizeArgs args(this, A, B, C);
if (sm_algo_pack.cublas.is_available_attribute(args, attr)) {
if (sm_algo_pack.cublas.is_available_attribute(args, positive_attr,
negative_attr)) {
return &sm_algo_pack.cublas;
}
#if CUDA_VERSION >= 10010
else if (sm_algo_pack.cublasLt.is_available_attribute(args, attr)) {
else if (sm_algo_pack.cublasLt.is_available_attribute(args, positive_attr,
negative_attr)) {
return &sm_algo_pack.cublasLt;
}
#endif
else if (sm_algo_pack.int8x8x32.is_available_attribute(args, attr)) {
else if (sm_algo_pack.int8x8x32.is_available_attribute(args, positive_attr,
negative_attr)) {
return &sm_algo_pack.int8x8x32;
} else {
if (sm_algo_pack.brute_force.is_available_attribute(args, attr)) {
if (sm_algo_pack.brute_force.is_available_attribute(args, positive_attr,
negative_attr)) {
return &sm_algo_pack.brute_force;
}
}
megdnn_throw("No usable algo for batched_matrix_mul");
megdnn_throw(ssprintf(
"no batched_matrix_mul algorithm without attribute(%s) with "
"attribute(%s) args(%s) and "
"workspace limit (%zu bytes)",
Algorithm::attribute_str(negative_attr).c_str(),
Algorithm::attribute_str(positive_attr).c_str(),
args.to_string().c_str(), workspace_limit_in_bytes));
return nullptr;
};
......
......@@ -45,11 +45,10 @@ protected:
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
static AlgoPack sm_algo_pack;
......
......@@ -129,9 +129,12 @@ public:
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -426,7 +429,7 @@ public:
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
......
......@@ -51,7 +51,8 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
using namespace conv_bias;
AlgoBase::SizeArgs args{this, src, filter, bias, z, dst};
auto dst_layout = *args.dst_layout;
......@@ -74,7 +75,8 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
};
auto get_cudnn_algo =
[this, &conv_args, &args, workspace_limit_in_bytes, attr](
[this, &conv_args, &args, workspace_limit_in_bytes, positive_attr,
negative_attr](
const thin_function<AlgoBase*(cudnnConvolutionFwdAlgo_t)>&
cb) -> AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle());
......@@ -93,7 +95,8 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
for (int i = 0; i < ret_count; ++i) {
auto conv_bias_algo = cb(algo_perf[i].algo);
if (conv_bias_algo->is_available_attribute(
args, attr, workspace_limit_in_bytes))
args, positive_attr, negative_attr,
workspace_limit_in_bytes))
return conv_bias_algo;
}
#else
......@@ -105,18 +108,20 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
workspace_limit_in_bytes, &algo));
auto conv_bias_algo = cb(algo);
if (conv_bias_algo->is_available_attribute(args, attr,
if (conv_bias_algo->is_available_attribute(args, positive_attr,
negative_attr,
workspace_limit_in_bytes))
return conv_bias_algo;
#endif
return nullptr;
};
auto get_1x1_algo = [workspace_limit_in_bytes,
attr](const AlgoBase::SizeArgs& size_arg)
auto get_1x1_algo = [workspace_limit_in_bytes, positive_attr,
negative_attr](const AlgoBase::SizeArgs& size_arg)
-> ConvBiasForwardImpl::AlgoBase* {
if (sm_algo_pack.batched_matmul.is_available_attribute(
size_arg, attr, workspace_limit_in_bytes)) {
size_arg, positive_attr, negative_attr,
workspace_limit_in_bytes)) {
return &sm_algo_pack.batched_matmul;
}
return nullptr;
......@@ -145,10 +150,12 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
if (is_chanwise) {
if (prefer_dnn_chanwise) {
if (sm_algo_pack.chanwise.is_available_attribute(
args, attr, workspace_limit_in_bytes))
args, positive_attr, negative_attr,
workspace_limit_in_bytes))
return &sm_algo_pack.chanwise;
if (sm_algo_pack.chanwise8x8x32.is_available_attribute(
args, attr, workspace_limit_in_bytes))
args, positive_attr, negative_attr,
workspace_limit_in_bytes))
return &sm_algo_pack.chanwise8x8x32;
} else {
conv_args.dst_layout = &dst_layout;
......@@ -163,7 +170,8 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
//! Prefer CUDNN CONVBIAS.
bool cudnn_conv_bias_act_supported = false;
for (auto&& algo : sm_algo_pack.cudnn_conv_bias_activations) {
if (algo.is_available_attribute(args, attr, workspace_limit_in_bytes)) {
if (algo.is_available_attribute(args, positive_attr, negative_attr,
workspace_limit_in_bytes)) {
cudnn_conv_bias_act_supported = true;
break;
}
......@@ -201,30 +209,18 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
}
if (sm_algo_pack.fallback_nchw_qs8.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.fallback_nchw_qs8;
}
if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda convbias fwd", attr);
} else {
return megdnn::get_usable_algo<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda convbias fwd");
}
return megdnn::get_algo_match_attribute<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd", positive_attr, negative_attr);
} else {
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<ConvBiasForwardImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd", attr);
} else {
return megdnn::get_usable_algo<ConvBiasForwardImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd");
}
return megdnn::get_algo_match_attribute<ConvBiasForwardImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd", positive_attr, negative_attr);
}
}
......
......@@ -76,13 +76,12 @@ public:
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias,
const TensorLayout& z,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
static AlgoPack sm_algo_pack;
......
......@@ -84,9 +84,12 @@ public:
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -229,7 +232,7 @@ public:
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
......
......@@ -80,9 +80,12 @@ public:
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......@@ -214,7 +217,7 @@ public:
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
......
......@@ -65,9 +65,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......
......@@ -33,14 +33,15 @@ using namespace convolution;
/* ============== ConvolutionForwardImpl ============== */
ConvolutionForwardImpl::Algorithm*
ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
ConvolutionForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args{this, src, filter, dst};
MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes);
MEGDNN_MARK_USED_VAR(attr);
MEGDNN_MARK_USED_VAR(positive_attr);
MEGDNN_MARK_USED_VAR(negative_attr);
return &sm_algo_pack.algo_default;
}
......@@ -101,46 +102,45 @@ ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto fm = check_layout_fwd(grad, filter, diff);
return get_algorithm_heuristic(filter, fm, diff, grad,
workspace_limit_in_bytes, attr);
workspace_limit_in_bytes, positive_attr,
negative_attr);
}
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, filter, filter_meta, diff, grad);
if (args.filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
// prefer special chanwise impl
return &sm_algo_pack.chanwise;
}
if (args.filter_layout->dtype.enumv() ==
DTypeTrait<dtype::QuantizedS8>::enumv) {
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.int8_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data", attr);
} else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.int8_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data");
}
return megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.int8_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data", positive_attr, negative_attr);
}
auto get_cudnn_algo = [this, &args, workspace_limit_in_bytes,
attr]() -> ConvolutionBackwardDataImpl::AlgoBase* {
auto get_cudnn_algo =
[this, &args, workspace_limit_in_bytes, positive_attr,
negative_attr]() -> ConvolutionBackwardDataImpl::AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle());
CUDNNBwdDataDescs desc;
args.init_desc(desc);
#if CUDNN_MAJOR >= 7
MEGDNN_MARK_USED_VAR(negative_attr);
int max_count = 0;
cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(
cudnn_handle, &max_count));
......@@ -153,7 +153,7 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter,
for (int i = 0; i < ret_count; ++i) {
if (algo_perf[i].memory > workspace_limit_in_bytes)
continue;
if (attr & AlgoAttribute::REPRODUCIBLE) {
if ((positive_attr & AlgoAttribute::REPRODUCIBLE)) {
if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) {
return reinterpret_cast<AlgoBase*>(
sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
......@@ -174,8 +174,8 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter,
auto&& cast_algo =
reinterpret_cast<AlgoBase*>(sm_algo_pack.cudnn_from_enum(algo));
return reinterpret_cast<AlgoBase*>(
megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
cast_algo, attr));
megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>(
cast_algo, positive_attr, negative_attr));
#endif
};
......@@ -197,25 +197,13 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter,
if (args.filter_layout->dtype.enumv() !=
DTypeTrait<dtype::BFloat16>::enumv) {
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_data", attr);
} else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_data");
}
return megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data", positive_attr, negative_attr);
} else {
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data", attr);
} else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data");
}
return megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data", positive_attr, negative_attr);
}
}
......@@ -255,29 +243,33 @@ ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto fm = check_layout_fwd(src, grad, diff);
return get_algorithm_heuristic(src, diff, grad, fm,
workspace_limit_in_bytes, attr);
workspace_limit_in_bytes, positive_attr,
negative_attr);
}
ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, src, diff, grad, grad_meta);
if (args.grad_filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
// prefer special chanwise impl
return &sm_algo_pack.chanwise;
}
auto get_cudnn_algo =
[this, &args, workspace_limit_in_bytes,
attr]() -> ConvolutionBackwardFilterImpl::AlgoBase* {
[this, &args, workspace_limit_in_bytes, positive_attr,
negative_attr]() -> ConvolutionBackwardFilterImpl::AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle());
CUDNNBwdFilterDescs desc;
args.init_desc(desc);
......@@ -293,6 +285,7 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
}
#endif
#if CUDNN_MAJOR >= 7
MEGDNN_MARK_USED_VAR(negative_attr);
int max_count = 0;
cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(
cudnn_handle, &max_count));
......@@ -305,7 +298,7 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
for (int i = 0; i < ret_count; ++i) {
if (algo_perf[i].memory > workspace_limit_in_bytes)
continue;
if (attr & AlgoAttribute::REPRODUCIBLE) {
if ((positive_attr & AlgoAttribute::REPRODUCIBLE)) {
if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) {
return reinterpret_cast<AlgoBase*>(
sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
......@@ -326,8 +319,8 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
auto&& cast_algo =
reinterpret_cast<AlgoBase*>(sm_algo_pack.cudnn_from_enum(algo));
return reinterpret_cast<AlgoBase*>(
megdnn::get_algo_with_attribute<ConvolutionBackwardFilterImpl>(
cast_algo, attr));
megdnn::get_algo_match_attribute<ConvolutionBackwardFilterImpl>(
cast_algo, positive_attr, negative_attr));
#endif
};
......@@ -348,27 +341,13 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
}
if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<
ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_filter", attr);
} else {
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_filter");
}
return megdnn::get_algo_match_attribute<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_filter", positive_attr, negative_attr);
} else {
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<
ConvolutionBackwardFilterImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_filter", attr);
} else {
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_filter");
}
return megdnn::get_algo_match_attribute<ConvolutionBackwardFilterImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_filter", positive_attr, negative_attr);
}
}
......
......@@ -59,11 +59,11 @@ protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
static AlgoPack sm_algo_pack;
......@@ -77,19 +77,22 @@ public:
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& filter, const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
return get_algorithm_heuristic(filter, filter_meta, diff, grad,
workspace_limit_in_bytes, attr)
workspace_limit_in_bytes, positive_attr,
negative_attr)
->info();
}
AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
return get_algorithm_heuristic(filter, diff, grad,
workspace_limit_in_bytes, attr)
workspace_limit_in_bytes, positive_attr,
negative_attr)
->info();
}
......@@ -118,11 +121,11 @@ protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
......@@ -130,7 +133,8 @@ private:
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr);
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr);
static AlgoPack sm_algo_pack;
};
......@@ -146,19 +150,22 @@ public:
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
return get_algorithm_heuristic(src, diff, grad, grad_meta,
workspace_limit_in_bytes, attr)
workspace_limit_in_bytes, positive_attr,
negative_attr)
->info();
}
AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
return get_algorithm_heuristic(filter, diff, grad,
workspace_limit_in_bytes, attr)
workspace_limit_in_bytes, positive_attr,
negative_attr)
->info();
}
......@@ -181,11 +188,11 @@ protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
......@@ -193,7 +200,8 @@ private:
const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr);
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr);
static AlgoPack sm_algo_pack;
};
......
......@@ -77,9 +77,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
......@@ -164,7 +167,7 @@ public:
TensorLayout& grad_pg);
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
......
......@@ -71,9 +71,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
......@@ -170,7 +173,7 @@ public:
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
......
......@@ -76,9 +76,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
......@@ -124,7 +127,7 @@ public:
AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) {
ret |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
......
......@@ -97,8 +97,10 @@ namespace convolution3d {
const cudnnConvolutionDescriptor_t conv_desc,
const cudnnTensorDescriptor_t y_desc,
size_t workspace_limit_in_bytes, cudnnConvolutionFwdAlgo_t* algo,
const AlgoAttribute& attr) {
MEGDNN_MARK_USED_VAR(attr);
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
MEGDNN_MARK_USED_VAR(positive_attr);
MEGDNN_MARK_USED_VAR(negative_attr);
#if CUDNN_MAJOR >= 7
int algo_max_count = 0;
cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount(
......@@ -118,7 +120,7 @@ namespace convolution3d {
cudnn_handle, x_desc, w_desc, conv_desc, y_desc,
algo_perf[i].algo, &workspace_size));
if (workspace_size > workspace_limit_in_bytes) continue;
if (!(attr & AlgoAttribute::REPRODUCIBLE)) {
if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) {
*algo = algo_perf[i].algo;
return true;
} else {
......@@ -144,8 +146,11 @@ namespace convolution3d {
const cudnnConvolutionDescriptor_t conv_desc,
const cudnnTensorDescriptor_t dx_desc,
size_t workspace_limit_in_bytes,
cudnnConvolutionBwdDataAlgo_t* algo, const AlgoAttribute& attr) {
MEGDNN_MARK_USED_VAR(attr);
cudnnConvolutionBwdDataAlgo_t* algo,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
MEGDNN_MARK_USED_VAR(positive_attr);
MEGDNN_MARK_USED_VAR(negative_attr);
#if CUDNN_MAJOR >= 7
int algo_max_count = 0;
cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(
......@@ -166,7 +171,7 @@ namespace convolution3d {
cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc,
algo_perf[i].algo, &workspace_size));
if (workspace_size > workspace_limit_in_bytes) continue;
if (!(attr & AlgoAttribute::REPRODUCIBLE)) {
if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) {
*algo = algo_perf[i].algo;
return true;
} else {
......@@ -193,8 +198,11 @@ namespace convolution3d {
const cudnnConvolutionDescriptor_t conv_desc,
const cudnnFilterDescriptor_t dw_desc,
size_t workspace_limit_in_bytes,
cudnnConvolutionBwdFilterAlgo_t* algo, const AlgoAttribute& attr) {
MEGDNN_MARK_USED_VAR(attr);
cudnnConvolutionBwdFilterAlgo_t* algo,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
MEGDNN_MARK_USED_VAR(positive_attr);
MEGDNN_MARK_USED_VAR(negative_attr);
#if CUDNN_MAJOR >= 7
int algo_max_count = 0;
cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(
......@@ -215,7 +223,7 @@ namespace convolution3d {
cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc,
algo_perf[i].algo, &workspace_size));
if (workspace_size > workspace_limit_in_bytes) continue;
if (!(attr & AlgoAttribute::REPRODUCIBLE)) {
if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) {
*algo = algo_perf[i].algo;
return true;
} else {
......@@ -235,7 +243,6 @@ namespace convolution3d {
#endif
}
} // namespace convolution3d
} // namespace cuda
} // namespace megdnn
......
......@@ -33,16 +33,18 @@ Convolution3DForwardImpl::Algorithm*
Convolution3DForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto fm = check_layout_fwd(src, filter, dst);
return get_algorithm_heuristic(src, fm, dst, workspace_limit_in_bytes,
attr);
positive_attr, negative_attr);
}
Convolution3DForwardImpl::Algorithm*
Convolution3DForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const CanonizedFilterMeta& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, src, filter, dst);
#if CUDNN_MAJOR < 7 || (CUDNN_MAJOR == 7 && CUDNN_MINOR < 5)
......@@ -51,25 +53,27 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
// version is lower than v7.5.0 is still slower than our implementation
// in many channel-wise cases
if (sm_algo_pack.chanwise.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr,
workspace_limit_in_bytes)) {
return &sm_algo_pack.chanwise;
}
}
#endif
auto prefer_1x1x1 = [&args, attr, workspace_limit_in_bytes]() {
auto prefer_1x1x1 = [&args, positive_attr, negative_attr,
workspace_limit_in_bytes]() {
const size_t MAX_BATCH_SIZE_FOR_1x1x1_MAT_ALGO = 4;
size_t batch_size = args.src_layout->shape[0];
if (batch_size > MAX_BATCH_SIZE_FOR_1x1x1_MAT_ALGO) {
return false;
}
return sm_algo_pack.a1x1x1.is_available_attribute(
args, attr, workspace_limit_in_bytes);
args, positive_attr, negative_attr, workspace_limit_in_bytes);
};
auto get_cudnn_algo =
[this, &args, workspace_limit_in_bytes,
attr]() -> Convolution3DForwardImpl::AlgoBase* {
[this, &args, workspace_limit_in_bytes, positive_attr,
negative_attr]() -> Convolution3DForwardImpl::AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle());
cudnnConvolutionFwdAlgo_t algo;
CUDNNForwardDescs desc;
......@@ -78,11 +82,12 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
bool got = cudnn_get_convolution_fwd_algo_helper(
cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
desc.conv_desc.desc, desc.dst_desc.desc,
workspace_limit_in_bytes, &algo, attr);
workspace_limit_in_bytes, &algo, positive_attr, negative_attr);
if (got) {
return static_cast<AlgoBase*>(
megdnn::get_algo_with_attribute<Convolution3DForwardImpl>(
sm_algo_pack.cudnn_from_enum(algo), attr));
megdnn::get_algo_match_attribute<Convolution3DForwardImpl>(
sm_algo_pack.cudnn_from_enum(algo), positive_attr,
negative_attr));
} else {
return nullptr;
}
......@@ -108,15 +113,9 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
args = orig_args;
}
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<Convolution3DForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv3d fwd", attr);
} else {
return megdnn::get_usable_algo<Convolution3DForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv3d fwd");
}
return megdnn::get_algo_match_attribute<Convolution3DForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv3d fwd", positive_attr, negative_attr);
}
std::vector<Convolution3DForwardImpl::Algorithm*>
......@@ -169,28 +168,30 @@ Convolution3DBackwardDataImpl::Algorithm*
Convolution3DBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto fm = check_layout_fwd(grad, filter, diff);
return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes,
attr);
positive_attr, negative_attr);
}
Convolution3DBackwardDataImpl::Algorithm*
Convolution3DBackwardDataImpl::get_algorithm_heuristic(
const CanonizedFilterMeta& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, filter, diff, grad);
if (args.filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.chanwise;
}
auto get_cudnn_algo =
[this, &args, workspace_limit_in_bytes,
attr]() -> Convolution3DBackwardDataImpl::AlgoBase* {
[this, &args, workspace_limit_in_bytes, positive_attr,
negative_attr]() -> Convolution3DBackwardDataImpl::AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle());
cudnnConvolutionBwdDataAlgo_t algo;
CUDNNBwdDataDescs desc;
......@@ -198,11 +199,12 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic(
bool got = cudnn_get_convolution_bwd_data_algo_helper(
cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc,
desc.conv_desc.desc, desc.grad_desc.desc,
workspace_limit_in_bytes, &algo, attr);
workspace_limit_in_bytes, &algo, positive_attr, negative_attr);
if (got) {
return static_cast<AlgoBase*>(megdnn::get_algo_with_attribute<
return static_cast<AlgoBase*>(megdnn::get_algo_match_attribute<
Convolution3DBackwardDataImpl>(
sm_algo_pack.cudnn_from_enum(algo), attr));
sm_algo_pack.cudnn_from_enum(algo), positive_attr,
negative_attr));
} else {
return nullptr;
}
......@@ -224,15 +226,9 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic(
args = orig_args;
}
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<Convolution3DBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv3d bwd data", attr);
} else {
return megdnn::get_usable_algo<Convolution3DBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv3d bwd data");
}
return megdnn::get_algo_match_attribute<Convolution3DBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv3d bwd data", positive_attr, negative_attr);
}
size_t Convolution3DBackwardDataImpl::get_workspace_in_bytes(
......@@ -269,28 +265,30 @@ Convolution3DBackwardFilterImpl::Algorithm*
Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto fm = check_layout_fwd(src, grad, diff);
return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes,
attr);
positive_attr, negative_attr);
}
Convolution3DBackwardFilterImpl::Algorithm*
Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, src, diff, grad);
if (args.grad_filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.chanwise;
}
auto get_cudnn_algo =
[this, &args, workspace_limit_in_bytes,
attr]() -> Convolution3DBackwardFilterImpl::AlgoBase* {
[this, &args, workspace_limit_in_bytes, positive_attr,
negative_attr]() -> Convolution3DBackwardFilterImpl::AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle());
cudnnConvolutionBwdFilterAlgo_t algo;
CUDNNBwdFilterDescs desc;
......@@ -298,11 +296,12 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
bool got = cudnn_get_convolution_bwd_filter_algo_helper(
cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc,
desc.conv_desc.desc, desc.grad_desc.desc,
workspace_limit_in_bytes, &algo, attr);
workspace_limit_in_bytes, &algo, positive_attr, negative_attr);
if (got) {
return static_cast<AlgoBase*>(megdnn::get_algo_with_attribute<
return static_cast<AlgoBase*>(megdnn::get_algo_match_attribute<
Convolution3DBackwardFilterImpl>(
sm_algo_pack.cudnn_from_enum(algo), attr));
sm_algo_pack.cudnn_from_enum(algo), positive_attr,
negative_attr));
} else {
return nullptr;
}
......@@ -323,15 +322,9 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
args = orig_args;
}
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<Convolution3DBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv3d bwd filter", attr);
} else {
return megdnn::get_usable_algo<Convolution3DBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv3d bwd filter");
}
return megdnn::get_algo_match_attribute<Convolution3DBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv3d bwd filter", positive_attr, negative_attr);
}
size_t Convolution3DBackwardFilterImpl::get_workspace_in_bytes(
......
......@@ -25,9 +25,11 @@ public:
const CanonizedFilterMeta& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
return get_algorithm_heuristic(src, filter, dst,
workspace_limit_in_bytes, attr)
workspace_limit_in_bytes, positive_attr,
negative_attr)
->info();
}
size_t get_workspace_in_bytes(const TensorLayout& src,
......@@ -48,19 +50,19 @@ protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const CanonizedFilterMeta& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr);
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr);
static AlgoPack sm_algo_pack;
};
......@@ -73,9 +75,11 @@ public:
AlgorithmInfo get_algorithm_info_heuristic(
const CanonizedFilterMeta& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
return get_algorithm_heuristic(filter, diff, grad,
workspace_limit_in_bytes, attr)
workspace_limit_in_bytes, positive_attr,
negative_attr)
->info();
}
size_t get_workspace_in_bytes(const TensorLayout& filter,
......@@ -98,18 +102,19 @@ protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr);
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr);
static AlgoPack sm_algo_pack;
};
......@@ -122,13 +127,14 @@ public:
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad) override;
AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const CanonizedFilterMeta& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
return get_algorithm_heuristic(src, diff, grad,
workspace_limit_in_bytes, attr)
workspace_limit_in_bytes, positive_attr,
negative_attr)
->info();
}
......@@ -149,18 +155,19 @@ protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const CanonizedFilterMeta& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr);
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr);
static AlgoPack sm_algo_pack;
};
......
......@@ -82,9 +82,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
......
......@@ -75,9 +75,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
......
......@@ -70,9 +70,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
......
......@@ -59,10 +59,12 @@ AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im,
const TensorLayout& mask,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto fm = make_canonized_filter_meta(im.ndim, filter, offset);
return get_algorithm_heuristic(im, fm, offset, mask, dst,
workspace_limit_in_bytes, attr);
workspace_limit_in_bytes, positive_attr,
negative_attr);
}
AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im,
......@@ -71,17 +73,20 @@ AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im,
const TensorLayout& mask,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, im, filter, offset, mask, dst);
if (sm_algo_pack.algo_matmul.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.algo_matmul;
}
megdnn_throw(ssprintf(
"no deformable conv fwd algorithm with attribute%s , args(%s) and "
"workspace limit (%zu bytes)",
Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(),
workspace_limit_in_bytes));
megdnn_throw(
ssprintf("no deformable conv fwd algorithm without attribute(%s) "
"with attribute(%s) , args(%s) and "
"workspace limit (%zu bytes)",
Algorithm::attribute_str(negative_attr).c_str(),
Algorithm::attribute_str(positive_attr).c_str(),
args.to_string().c_str(), workspace_limit_in_bytes));
}
const char* Fwd::get_algorithm_set_name() const {
......@@ -114,28 +119,33 @@ std::vector<AlgoBwdFlt*> BwdFlt::get_all_algorithms(const TensorLayout& /* im */
AlgoBwdFlt* BwdFlt::get_algorithm_heuristic(
const TensorLayout& im, const TensorLayout& offset,
const TensorLayout& mask, const TensorLayout& out_grad,
const TensorLayout& filter_grad,
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
const TensorLayout& filter_grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto fm = make_canonized_filter_meta(im.ndim, filter_grad, offset);
return get_algorithm_heuristic(im, offset, mask, out_grad, fm,
workspace_limit_in_bytes, attr);
workspace_limit_in_bytes, positive_attr,
negative_attr);
}
AlgoBwdFlt* BwdFlt::get_algorithm_heuristic(
const TensorLayout& im, const TensorLayout& offset,
const TensorLayout& mask, const TensorLayout& out_grad,
const CanonizedFilterMeta& filter_grad,
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
const CanonizedFilterMeta& filter_grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, im, offset, mask, out_grad, filter_grad);
if (sm_algo_pack.algo_matmul.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.algo_matmul;
}
megdnn_throw(
ssprintf("no deformable conv bwd filter algorithm with "
"attribute%s, args(%s) and "
ssprintf("no deformable conv bwd filter algorithm without "
"attribute(%s) with "
"attribute(%s), args(%s) and "
"workspace limit (%zu bytes)",
Algorithm::attribute_str(attr).c_str(),
Algorithm::attribute_str(negative_attr).c_str(),
Algorithm::attribute_str(positive_attr).c_str(),
args.to_string().c_str(), workspace_limit_in_bytes));
}
......@@ -176,11 +186,12 @@ AlgoBwdData* BwdData::get_algorithm_heuristic(
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad, const TensorLayout& mask_grad,
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto fm = make_canonized_filter_meta(im.ndim, filter, offset);
return get_algorithm_heuristic(im, fm, offset, mask, out_grad, im_grad,
offset_grad, mask_grad,
workspace_limit_in_bytes, attr);
return get_algorithm_heuristic(
im, fm, offset, mask, out_grad, im_grad, offset_grad, mask_grad,
workspace_limit_in_bytes, positive_attr, negative_attr);
}
AlgoBwdData* BwdData::get_algorithm_heuristic(
......@@ -188,18 +199,21 @@ AlgoBwdData* BwdData::get_algorithm_heuristic(
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad, const TensorLayout& mask_grad,
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, im, filter, offset, mask, out_grad, im_grad,
offset_grad, mask_grad);
if (sm_algo_pack.algo_matmul.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.algo_matmul;
}
megdnn_throw(
ssprintf("no deformable conv bwd data algorithm with attribute%s, "
ssprintf("no deformable conv bwd data algorithm without "
"attribute(%s) with attribute(%s), "
"args(%s) and "
"workspace limit (%zu bytes)",
Algorithm::attribute_str(attr).c_str(),
Algorithm::attribute_str(negative_attr).c_str(),
Algorithm::attribute_str(positive_attr).c_str(),
args.to_string().c_str(), workspace_limit_in_bytes));
}
......
......@@ -36,7 +36,8 @@ public:
const TensorLayout& mask,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr);
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr);
const char* get_algorithm_set_name() const override;
......@@ -54,13 +55,12 @@ protected:
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& im,
const TensorLayout& filter,
const TensorLayout& offset,
const TensorLayout& mask,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& im, const TensorLayout& filter,
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
static AlgoPack sm_algo_pack;
......@@ -81,7 +81,8 @@ public:
const TensorLayout& out_grad,
const CanonizedFilterMeta& filter_grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr);
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr);
size_t get_workspace_in_bytes(const TensorLayout& im,
const TensorLayout& offset,
......@@ -105,13 +106,12 @@ protected:
const TensorLayout& mask, const TensorLayout& out_grad,
const TensorLayout& filter_grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& im,
const TensorLayout& offset,
const TensorLayout& mask,
const TensorLayout& out_grad,
const TensorLayout& filter_grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& im, const TensorLayout& offset,
const TensorLayout& mask, const TensorLayout& out_grad,
const TensorLayout& filter_grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
static AlgoPack sm_algo_pack;
......@@ -132,7 +132,8 @@ public:
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad, const TensorLayout& mask_grad,
size_t workspace_limit_in_bytes, const AlgoAttribute& attr);
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr);
size_t get_workspace_in_bytes(const TensorLayout& im,
const TensorLayout& filter,
......@@ -166,8 +167,8 @@ protected:
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad, const TensorLayout& mask_grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
static AlgoPack sm_algo_pack;
......
......@@ -61,9 +61,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
......
......@@ -61,9 +61,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
......
......@@ -62,9 +62,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
......
......@@ -20,30 +20,32 @@ using namespace cuda;
/* ============== LocalShareForwardImpl ============== */
LocalShareForwardImpl::Algorithm*
LocalShareForwardImpl::get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
LocalShareForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, src, filter, dst);
if (sm_algo_pack.batch_size_aware_chwn_small_image
.is_available_attribute(args, attr,
.is_available_attribute(args, positive_attr, negative_attr,
workspace_limit_in_bytes)) {
return &sm_algo_pack.batch_size_aware_chwn_small_image;
}
if (sm_algo_pack.batch_size_aware_chwn.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.batch_size_aware_chwn;
}
if (sm_algo_pack.batched_matmul.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.batched_matmul;
}
megdnn_throw(ssprintf(
"no local share conv algorithm with attribute%s, args(%s) and "
"workspace limit (%zu bytes)",
Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(),
workspace_limit_in_bytes));
megdnn_throw(
ssprintf("no local share conv algorithm without attribute(%s) with "
"attribute(%s), args(%s) and "
"workspace limit (%zu bytes)",
Algorithm::attribute_str(negative_attr).c_str(),
Algorithm::attribute_str(positive_attr).c_str(),
args.to_string().c_str(), workspace_limit_in_bytes));
}
std::vector<LocalShareForwardImpl::Algorithm*>
......@@ -79,21 +81,24 @@ LocalShareBackwardDataImpl::Algorithm*
LocalShareBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, filter, diff, grad);
if (sm_algo_pack.implicit_gemm.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.implicit_gemm;
}
if (sm_algo_pack.batched_matmul.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.batched_matmul;
}
megdnn_throw(ssprintf(
"no local share bwd data algorithm with attribute%s args(%s) and "
"workspace limit (%zu bytes)",
Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(),
workspace_limit_in_bytes));
megdnn_throw(
ssprintf("no local share bwd data algorithm without attribute(%s) "
"with attribute(%s) args(%s) and "
"workspace limit (%zu bytes)",
Algorithm::attribute_str(negative_attr).c_str(),
Algorithm::attribute_str(positive_attr).c_str(),
args.to_string().c_str(), workspace_limit_in_bytes));
}
std::vector<LocalShareBackwardDataImpl::Algorithm*>
......@@ -129,21 +134,24 @@ LocalShareBackwardFilterImpl::Algorithm*
LocalShareBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, src, diff, grad);
if (sm_algo_pack.implicit_gemm.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.implicit_gemm;
}
if (sm_algo_pack.batched_matmul.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.batched_matmul;
}
megdnn_throw(
ssprintf("no local share bwd filter algorithm with attribute%s, "
ssprintf("no local share bwd filter algorithm without "
"attribute(%s) with attribute(%s), "
"args(%s) and "
"workspace limit (%zu bytes)",
Algorithm::attribute_str(attr).c_str(),
Algorithm::attribute_str(negative_attr).c_str(),
Algorithm::attribute_str(positive_attr).c_str(),
args.to_string().c_str(), workspace_limit_in_bytes));
}
......
......@@ -39,11 +39,12 @@ protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
static AlgoPack sm_algo_pack;
};
......@@ -71,11 +72,11 @@ protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
static AlgoPack sm_algo_pack;
......@@ -104,11 +105,11 @@ protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
static AlgoPack sm_algo_pack;
......
......@@ -85,9 +85,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) const {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
......
......@@ -30,35 +30,30 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args{this, A, B, C};
if (sm_algo_pack.cublas.is_available_attribute(args, attr,
workspace_limit_in_bytes)) {
if (sm_algo_pack.cublas.is_available_attribute(
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.cublas;
}
#if CUDA_VERSION >= 10010
if (sm_algo_pack.cublas_lt.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.cublas_lt;
}
#endif
#if CUDA_VERSION >= 10000
if (sm_algo_pack.wmma_uint4x4x32.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.wmma_uint4x4x32;
}
#endif
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<MatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"matrix mul forward", attr);
} else {
return megdnn::get_usable_algo<MatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"matrix mul forward");
}
return megdnn::get_algo_match_attribute<MatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"matrix mul forward", positive_attr, negative_attr);
}
size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A,
......
......@@ -57,11 +57,10 @@ protected:
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
static AlgoPack sm_algo_pack;
......
......@@ -65,9 +65,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) const {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
......
......@@ -31,21 +31,16 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
BatchedMatrixMulForwardImpl::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args{this, A, B, C};
if (sm_algo_pack.algo_default.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.algo_default;
}
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<BatchedMatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"batched matrix mul forward", attr);
} else {
return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"batched matrix mul forward");
}
return megdnn::get_algo_match_attribute<BatchedMatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"batched matrix mul forward", positive_attr, negative_attr);
}
size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes(
......
......@@ -36,11 +36,11 @@ private:
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*attr*/) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) override;
const char* get_algorithm_set_name() const override {
return "FALLBACK BATCHED MATMUL";
......
......@@ -280,20 +280,23 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
auto result = get_algorithm_heuristic_with_ncb(
fparam, workspace_limit_in_bytes, attr);
fparam, workspace_limit_in_bytes, positive_attr, negative_attr);
if (result == nullptr) {
result = naive::ConvBiasForwardImpl::get_algorithm_heuristic(
src, filter, bias, z, dst, workspace_limit_in_bytes, attr);
src, filter, bias, z, dst, workspace_limit_in_bytes,
positive_attr, negative_attr);
}
return result;
}
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto algo_data_type = param.deduce_algo_data_type();
auto suggest_category_order = suggest_algo_category_order(param);
for (auto category : suggest_category_order) {
......@@ -301,7 +304,8 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb(
ConvBiasImpl::Algorithm* heuristic_algo = nullptr;
for (auto i : origin_algos) {
bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute(
param, AlgoSelectionStrategy::HEURISTIC, attr);
param, AlgoSelectionStrategy::HEURISTIC, positive_attr,
negative_attr);
if (usable_attribute &&
static_cast<AlgoBase*>(i)->get_workspace(param) <=
workspace_limit_in_bytes) {
......@@ -497,7 +501,8 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm(
if (!m_prev_selected_algo ||
memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
m_prev_selected_algo = get_algorithm_heuristic_with_ncb(
param, workspace_size, AlgoAttribute::DEFAULT);
param, workspace_size, AlgoAttribute::DEFAULT,
AlgoAttribute::DEFAULT);
m_prev_selected_algo_sizep = param;
}
return m_prev_selected_algo;
......
......@@ -89,13 +89,12 @@ public:
const TensorLayout& dst) override;
//! implemented by get_algorithm_heuristic_with_ncb()
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias,
const TensorLayout& z,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
//! size param for kernels with non-contiguous batch
struct NCBKernSizeParam : ConvolutionImpl::NCBKernSizeParam {
......@@ -319,11 +318,14 @@ public:
return false;
}
bool usable_attribute(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) const {
return contain_attribute(attr) &&
bool usable_attribute(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy,
const AlgoAttribute& positive_attr =
AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr =
AlgoAttribute::DEFAULT) const {
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
usable(param, algo_selection_strategy);
}
......@@ -361,7 +363,8 @@ protected:
virtual Algorithm* get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr);
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr);
const char* get_algorithm_set_name() const override;
......
......@@ -198,13 +198,15 @@ std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms(
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
auto result = get_algorithm_heuristic_with_ncb(
fparam, workspace_limit_in_bytes, attr);
fparam, workspace_limit_in_bytes, positive_attr, negative_attr);
if (result == nullptr) {
result = naive::ConvolutionForwardImpl::get_algorithm_heuristic(
src, filter, dst, workspace_limit_in_bytes, attr);
src, filter, dst, workspace_limit_in_bytes, positive_attr,
negative_attr);
}
return result;
}
......@@ -312,7 +314,8 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto algo_data_type = param.deduce_algo_data_type();
auto suggest_category_order = suggest_algo_category_order(param);
for (auto category : suggest_category_order) {
......@@ -320,7 +323,8 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
ConvolutionImpl::Algorithm* heuristic_algo = nullptr;
for (auto i : origin_algos) {
bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute(
param, AlgoSelectionStrategy::HEURISTIC, attr);
param, AlgoSelectionStrategy::HEURISTIC, positive_attr,
negative_attr);
if (usable_attribute &&
static_cast<AlgoBase*>(i)->get_workspace(param) <=
workspace_limit_in_bytes) {
......@@ -391,7 +395,8 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
if (!m_prev_selected_algo ||
memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
m_prev_selected_algo = get_algorithm_heuristic_with_ncb(
param, workspace_size, AlgoAttribute::DEFAULT);
param, workspace_size, AlgoAttribute::DEFAULT,
AlgoAttribute::DEFAULT);
m_prev_selected_algo_sizep = param;
}
return m_prev_selected_algo;
......@@ -513,15 +518,17 @@ ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
if (param().format == param::Convolution::Format::NHWCD4 ||
param().format == param::Convolution::Format::NCHW4) {
return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic(
filter, diff, grad, workspace_limit_in_bytes, attr);
filter, diff, grad, workspace_limit_in_bytes, positive_attr,
negative_attr);
}
auto fparam = make_ncb_kern_size_param(filter, diff, grad);
return get_algorithm_heuristic_with_ncb(fparam, workspace_limit_in_bytes,
attr);
positive_attr, negative_attr);
}
ConvolutionBackwardDataImpl::NCBKernSizeParam
......@@ -666,15 +673,16 @@ ConvolutionBackwardDataImpl::get_all_algorithms_with_ncb(
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
if (param.filter_meta.group != 1) {
auto p1g = param;
p1g.filter_meta.group = 1;
return ncb_1g_get_algorithm_heuristic(p1g, workspace_limit_in_bytes,
attr);
positive_attr, negative_attr);
}
return ncb_1g_get_algorithm_heuristic(param, workspace_limit_in_bytes,
attr);
positive_attr, negative_attr);
}
size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
......@@ -729,10 +737,12 @@ ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
for (auto i : ncb_1g_get_all_algorithms(param)) {
if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) {
if (i->contain_attribute(attr)) {
if (i->contain_attribute_all(positive_attr) &&
!i->contain_attribute_any(negative_attr)) {
return i;
}
}
......@@ -783,7 +793,7 @@ ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) {
memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
m_prev_selected_algo = ncb_1g_get_algorithm_heuristic(
param, std::numeric_limits<size_t>::max(),
AlgoAttribute::DEFAULT);
AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);
m_prev_selected_algo_sizep = param;
}
return m_prev_selected_algo;
......
......@@ -86,11 +86,11 @@ public:
const TensorLayout& dst) override;
//! implemented by get_algorithm_heuristic_with_ncb()
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
//! size param for kernels with non-contiguous batch
struct NCBKernSizeParam {
......@@ -238,11 +238,14 @@ public:
return false;
}
bool usable_attribute(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) const {
return contain_attribute(attr) &&
bool usable_attribute(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy,
const AlgoAttribute& positive_attr =
AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr =
AlgoAttribute::DEFAULT) const {
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
usable(param, algo_selection_strategy);
}
......@@ -272,7 +275,8 @@ protected:
virtual Algorithm* get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr);
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr);
const char* get_algorithm_set_name() const override;
......@@ -322,11 +326,11 @@ public:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
const char* get_algorithm_set_name() const override;
//! size param for kernels with non-contiguous batch
......@@ -421,10 +425,14 @@ protected:
virtual ncb_kern_t dispatch_kern(
ConvolutionBackwardDataImpl* opr,
const NCBKernSizeParam& param) const = 0;
bool usable_attribute(
ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) const {
return contain_attribute(attr) && usable(opr, param);
bool usable_attribute(ConvolutionBackwardDataImpl* opr,
const NCBKernSizeParam& param,
const AlgoAttribute& positive_attr =
AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr =
AlgoAttribute::DEFAULT) const {
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) && usable(opr, param);
}
virtual bool is_preferred(const NCBKernSizeParam&) const {
return false;
......@@ -449,7 +457,8 @@ protected:
//! default impl calls ncb_1g_get_algorithm_heuristic()
virtual Algorithm* get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr);
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr);
//! get kernel pointer for float32 non-contiguous batch 1-group kernel
virtual ncb_kern_t ncb_1g_dispatch_kern(Algorithm* algo,
......@@ -467,7 +476,8 @@ protected:
*/
virtual Algorithm* ncb_1g_get_algorithm_heuristic(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr);
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr);
static bool is_matrix_mul_preferred(const NCBKernSizeParam& param);
/**
......
......@@ -131,20 +131,24 @@ MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc(
MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto kern_size_param = make_kern_size_param(A, B, C);
if (auto algo = static_cast<AlgoBase*>(
get_algorithm_from_desc(execution_policy().algo))) {
megdnn_assert(algo->get_workspace(kern_size_param) <
workspace_limit_in_bytes);
auto cur = megdnn::get_algo_with_attribute<MatrixMulImpl>(algo, attr);
auto cur = megdnn::get_algo_match_attribute<MatrixMulImpl>(
algo, positive_attr, negative_attr);
if (cur)
return cur;
megdnn_throw(ssprintf(
"require algorithm with attribute%s, but given algorithm with "
"attribute%s",
Algorithm::attribute_str(attr).c_str(),
Algorithm::attribute_str(algo->attribute()).c_str()));
megdnn_throw(
ssprintf("require algorithm without attribute(%s) with "
"attribute(%s), but given algorithm with "
"attribute(%s)",
Algorithm::attribute_str(negative_attr).c_str(),
Algorithm::attribute_str(positive_attr).c_str(),
Algorithm::attribute_str(algo->attribute()).c_str()));
}
AlgoTypePack algo_type;
algo_type.data_type = kern_size_param.deduce_algo_data_type();
......@@ -157,7 +161,7 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic(
static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <=
workspace_limit_in_bytes) {
if (static_cast<AlgoBase*>(algo)->preferred_attribute(
kern_size_param, attr)) {
kern_size_param, positive_attr, negative_attr)) {
//! use gemv algo if it's prefered
if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
return algo;
......@@ -215,9 +219,9 @@ MatrixMulImpl::KernParam MatrixMulImpl::make_kern_param(
size_t MatrixMulImpl::get_workspace_in_bytes(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) {
if (auto algo = get_algorithm_heuristic(A, B, C,
std::numeric_limits<size_t>::max(),
AlgoAttribute::DEFAULT)) {
if (auto algo = get_algorithm_heuristic(
A, B, C, std::numeric_limits<size_t>::max(),
AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT)) {
auto kern_size_param = make_kern_size_param(A, B, C);
return static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param);
}
......@@ -230,6 +234,7 @@ void MatrixMulImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
if (auto algo = get_algorithm_heuristic(A.layout, B.layout, C.layout,
std::numeric_limits<size_t>::max(),
AlgoAttribute::DEFAULT,
AlgoAttribute::DEFAULT)) {
auto kern_param = make_kern_param(A, B, C, workspace);
auto kern = static_cast<AlgoBase*>(algo)->get_kern(kern_param);
......
......@@ -225,8 +225,11 @@ public:
};
bool preferred_attribute(
const KernSizeParam& param,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) {
return contain_attribute(attr) && preferred(param);
const AlgoAttribute& positive_attr =
AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) && preferred(param);
};
virtual MatmulDescription matmul_description() const = 0;
......@@ -267,12 +270,10 @@ protected:
const TensorLayout& B,
const TensorLayout& C) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
};
} // namespace fallback
......
......@@ -125,14 +125,11 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic(
const TensorLayout& /* bias */, const TensorLayout& /* z */,
const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */
,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto algo = static_cast<HandleImpl*>(handle())
->default_batch_conv_bias_fwd_algo();
megdnn_assert(algo->contain_attribute(attr),
"require algorithm with attribute%s, but heuristic "
"algorithm(%s) with attribute%s ",
Algorithm::attribute_str(attr).c_str(), algo->name(),
Algorithm::attribute_str(algo->attribute()).c_str());
algo->check_attribute(positive_attr, negative_attr);
return algo;
}
......
......@@ -31,13 +31,12 @@ public:
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias,
const TensorLayout& z,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
......
......@@ -76,7 +76,8 @@ BatchedMatrixMulForward::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*attr*/) {
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) {
return static_cast<HandleImpl*>(handle())
->default_batched_matmul_fwd_algo();
}
......
......@@ -28,11 +28,11 @@ public:
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*attr*/) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
......
......@@ -246,14 +246,11 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* filter */,
const TensorLayout& /* bias */, const TensorLayout& /* z */,
const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto algo =
static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo();
megdnn_assert(algo->contain_attribute(attr),
"require algorithm with attribute%s, but heuristic "
"algorithm(%s) with attribute%s ",
Algorithm::attribute_str(attr).c_str(), algo->name(),
Algorithm::attribute_str(algo->attribute()).c_str());
algo->check_attribute(positive_attr, negative_attr);
return algo;
}
......
......@@ -31,13 +31,12 @@ public:
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias,
const TensorLayout& z,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
......
......@@ -272,14 +272,11 @@ ConvolutionForwardImpl:: get_all_algorithms(const TensorLayout &,
ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* diff */,
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto algo =
static_cast<HandleImpl*>(handle())->default_conv_fwd_algo();
megdnn_assert(algo->contain_attribute(attr),
"require algorithm with attribute%s, but heuristic "
"algorithm(%s) with attribute%s ",
Algorithm::attribute_str(attr).c_str(), algo->name(),
Algorithm::attribute_str(algo->attribute()).c_str());
algo->check_attribute(positive_attr, negative_attr);
return algo;
}
......@@ -302,14 +299,11 @@ ConvolutionBackwardData::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& /* filter */, const TensorLayout& /* diff */,
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto algo =
static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo();
megdnn_assert(algo->contain_attribute(attr),
"require algorithm with attribute%s, but heuristic "
"algorithm(%s) with attribute%s ",
Algorithm::attribute_str(attr).c_str(), algo->name(),
Algorithm::attribute_str(algo->attribute()).c_str());
algo->check_attribute(positive_attr, negative_attr);
return algo;
}
......@@ -333,14 +327,11 @@ ConvolutionBackwardFilter::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* diff */,
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto algo =
static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo();
megdnn_assert(algo->contain_attribute(attr),
"require algorithm with attribute%s, but heuristic "
"algorithm(%s) with attribute%s ",
Algorithm::attribute_str(attr).c_str(), algo->name(),
Algorithm::attribute_str(algo->attribute()).c_str());
algo->check_attribute(positive_attr, negative_attr);
return algo;
}
......
......@@ -25,11 +25,11 @@ class ConvolutionForwardImpl: public ConvolutionForward {
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src,
const TensorLayout &filter,
const TensorLayout &dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const PreprocessedFilter*) override {
......@@ -67,11 +67,11 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter,
const TensorLayout &diff,
const TensorLayout &grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override;
......@@ -90,11 +90,11 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src,
const TensorLayout &diff,
const TensorLayout &grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override;
......
......@@ -120,13 +120,10 @@ Convolution3DForward::Algorithm*
Convolution3DForwardImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* filter */,
const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto algo = static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo();
megdnn_assert(algo->contain_attribute(attr),
"require algorithm with attribute%s, but heuristic "
"algorithm(%s) with attribute%s ",
Algorithm::attribute_str(attr).c_str(), algo->name(),
Algorithm::attribute_str(algo->attribute()).c_str());
algo->check_attribute(positive_attr, negative_attr);
return algo;
}
......@@ -150,14 +147,11 @@ Convolution3DBackwardData::Algorithm*
Convolution3DBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& /* filter */, const TensorLayout& /* diff */,
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto algo =
static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo();
megdnn_assert(algo->contain_attribute(attr),
"require algorithm with attribute%s, but heuristic "
"algorithm(%s) with attribute%s ",
Algorithm::attribute_str(attr).c_str(), algo->name(),
Algorithm::attribute_str(algo->attribute()).c_str());
algo->check_attribute(positive_attr, negative_attr);
return algo;
}
......@@ -183,14 +177,11 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* diff */,
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */
,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto algo = static_cast<HandleImpl*>(handle())
->default_conv3d_bwd_filter_algo();
megdnn_assert(algo->contain_attribute(attr),
"require algorithm with attribute%s, but heuristic "
"algorithm(%s) with attribute%s ",
Algorithm::attribute_str(attr).c_str(), algo->name(),
Algorithm::attribute_str(algo->attribute()).c_str());
algo->check_attribute(positive_attr, negative_attr);
return algo;
}
......
......@@ -22,11 +22,11 @@ public:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
......@@ -44,11 +44,11 @@ public:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
......@@ -66,11 +66,11 @@ public:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
......
......@@ -26,13 +26,13 @@ public:
return std::vector<Algorithm*>();
};
Algorithm* get_algorithm_heuristic(const TensorLayout& /* src */,
const TensorLayout& /* filter */,
const TensorLayout& /* offset */,
const TensorLayout& /* mask */,
const TensorLayout& /* dst */,
size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& /*attr*/) override {
Algorithm* get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* dst */,
size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) override {
return nullptr;
};
......@@ -68,13 +68,13 @@ public:
return std::vector<Algorithm*>();
};
Algorithm* get_algorithm_heuristic(const TensorLayout& /* im */,
const TensorLayout& /* offset */,
const TensorLayout& /* mask */,
const TensorLayout& /* out_grad */,
const TensorLayout& /* filter_grad */,
size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& /*attr*/) override {
Algorithm* get_algorithm_heuristic(
const TensorLayout& /* im */, const TensorLayout& /* offset */,
const TensorLayout& /* mask */, const TensorLayout& /* out_grad */,
const TensorLayout& /* filter_grad */,
size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) override {
return nullptr;
};
......@@ -112,16 +112,16 @@ public:
return std::vector<Algorithm*>();
};
Algorithm* get_algorithm_heuristic(const TensorLayout& /* im */,
const TensorLayout& /* filter */,
const TensorLayout& /* offset */,
const TensorLayout& /* mask */,
const TensorLayout& /* out_grad */,
const TensorLayout& /* im_grad */,
const TensorLayout& /* offset_grad */,
const TensorLayout& /* mask_grad */,
size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& /*attr*/) override {
Algorithm* get_algorithm_heuristic(
const TensorLayout& /* im */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* out_grad */,
const TensorLayout& /* im_grad */,
const TensorLayout& /* offset_grad */,
const TensorLayout& /* mask_grad */,
size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) override {
return nullptr;
};
......
......@@ -162,14 +162,11 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout&,
LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* diff */,
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto algo =
static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo();
megdnn_assert(algo->contain_attribute(attr),
"require algorithm with attribute%s, but heuristic "
"algorithm(%s) with attribute%s ",
Algorithm::attribute_str(attr).c_str(), algo->name(),
Algorithm::attribute_str(algo->attribute()).c_str());
algo->check_attribute(positive_attr, negative_attr);
return algo;
}
......@@ -194,14 +191,11 @@ LocalShareBackwardData::Algorithm*
LocalShareBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& /* filter */, const TensorLayout& /* diff */,
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto algo = static_cast<HandleImpl*>(handle())
->default_local_share_bwd_data_algo();
megdnn_assert(algo->contain_attribute(attr),
"require algorithm with attribute%s, but heuristic "
"algorithm(%s) with attribute%s ",
Algorithm::attribute_str(attr).c_str(), algo->name(),
Algorithm::attribute_str(algo->attribute()).c_str());
algo->check_attribute(positive_attr, negative_attr);
return algo;
}
......@@ -226,14 +220,11 @@ LocalShareBackwardFilter::Algorithm*
LocalShareBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* diff */,
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto algo = static_cast<HandleImpl*>(handle())
->default_local_share_bwd_filter_algo();
megdnn_assert(algo->contain_attribute(attr),
"require algorithm with attribute%s, but heuristic "
"algorithm(%s) with attribute%s ",
Algorithm::attribute_str(attr).c_str(), algo->name(),
Algorithm::attribute_str(algo->attribute()).c_str());
algo->check_attribute(positive_attr, negative_attr);
return algo;
}
......
......@@ -30,11 +30,11 @@ public:
const TensorLayout& /*src*/, const TensorLayout& /*filter*/,
const TensorLayout& /*dst*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*src*/,
const TensorLayout& /*filter*/,
const TensorLayout& /*dst*/,
size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*attr*/) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& /*src*/, const TensorLayout& /*filter*/,
const TensorLayout& /*dst*/, size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
......@@ -55,11 +55,11 @@ public:
const TensorLayout& /*filter*/, const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*filter*/,
const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/,
size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*attr*/) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& /*filter*/, const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
......@@ -80,11 +80,11 @@ public:
const TensorLayout& /*src*/, const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*src*/,
const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/,
size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*attr*/) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& /*src*/, const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
......
......@@ -91,7 +91,8 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*attr*/) {
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) {
return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo();
}
......
......@@ -29,11 +29,11 @@ public:
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*attr*/) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
......
......@@ -72,9 +72,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) const {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
......
......@@ -32,21 +32,17 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
BatchedMatrixMulForwardImpl::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args{this, A, B, C};
if (sm_algo_pack.blas.is_available_attribute(args, attr,
workspace_limit_in_bytes)) {
if (sm_algo_pack.blas.is_available_attribute(
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.blas;
}
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<BatchedMatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"batched matrix mul forward", attr);
} else {
return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"batched matrix mul forward");
}
return megdnn::get_algo_match_attribute<BatchedMatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"batched matrix mul forward", positive_attr, negative_attr);
}
size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes(
......
......@@ -36,11 +36,11 @@ private:
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*attr*/) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) override;
const char* get_algorithm_set_name() const override {
return "ROCM BATCHED MATMUL";
......
......@@ -76,9 +76,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......
......@@ -73,9 +73,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......
......@@ -75,9 +75,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
......
......@@ -29,40 +29,43 @@ using namespace rocm;
/* ============== ConvolutionForwardImpl ============== */
ConvolutionForwardImpl::Algorithm*
ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
ConvolutionForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto fm = check_layout_fwd(src, filter, dst);
return get_algorithm_heuristic(src, fm, dst, workspace_limit_in_bytes,
attr);
positive_attr, negative_attr);
}
ConvolutionForwardImpl::Algorithm*
ConvolutionForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const CanonizedFilterMeta& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, src, filter, dst);
//! MIOpen auto-tuning need to run with actual tensors, so we cannot get
//! best algorithm here.
if (is_miopen_supported(args)) {
auto algo = megdnn::get_algo_with_attribute<ConvolutionForwardImpl>(
sm_algo_pack.miopen_algos[0], attr);
auto algo = megdnn::get_algo_match_attribute<ConvolutionForwardImpl>(
sm_algo_pack.miopen_algos[0], positive_attr, negative_attr);
if (algo)
return algo;
}
if (args.filter_meta.group > 1) {
if (sm_algo_pack.chanwise.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr,
workspace_limit_in_bytes)) {
return &sm_algo_pack.chanwise;
}
}
auto prefer_1x1 = [&args, attr, workspace_limit_in_bytes]() {
auto prefer_1x1 = [&args, positive_attr, negative_attr,
workspace_limit_in_bytes]() {
const size_t MAX_BATCH_SIZE_FOR_1x1_MAT_ALGO = 4;
size_t batch_size = args.src_layout->shape[0];
......@@ -70,14 +73,15 @@ ConvolutionForwardImpl::get_algorithm_heuristic(
return false;
}
return sm_algo_pack.a1x1.is_available_attribute(
args, attr, workspace_limit_in_bytes);
args, positive_attr, negative_attr, workspace_limit_in_bytes);
};
if (prefer_1x1()) {
return &sm_algo_pack.a1x1;
}
auto prefer_1x1_large_batch = [&args, attr, workspace_limit_in_bytes]() {
auto prefer_1x1_large_batch = [&args, positive_attr, negative_attr,
workspace_limit_in_bytes]() {
const size_t MIN_BATCH_SIZE_FOR_1x1_LARGE_BATCH_ALGO = 32;
size_t batch_size = args.src_layout->shape[0];
......@@ -85,22 +89,16 @@ ConvolutionForwardImpl::get_algorithm_heuristic(
return false;
}
return sm_algo_pack.batched_matrix_mul.is_available_attribute(
args, attr, workspace_limit_in_bytes);
args, positive_attr, negative_attr, workspace_limit_in_bytes);
};
if (prefer_1x1_large_batch()) {
return &sm_algo_pack.batched_matrix_mul;
}
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<ConvolutionForwardImpl>(
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes,
"rocm conv fwd", attr);
} else {
return megdnn::get_usable_algo<ConvolutionForwardImpl>(
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes,
"rocm conv fwd");
}
return megdnn::get_algo_match_attribute<ConvolutionForwardImpl>(
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes,
"rocm conv fwd", positive_attr, negative_attr);
}
std::vector<ConvolutionForwardImpl::Algorithm*>
......@@ -156,41 +154,39 @@ ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto fm = check_layout_fwd(grad, filter, diff);
return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes,
attr);
positive_attr, negative_attr);
}
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const CanonizedFilterMeta& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, filter, diff, grad);
if (is_miopen_supported(args.as_fwd_args())) {
auto algo = megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.miopen_algos[0], attr);
auto algo =
megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.miopen_algos[0], positive_attr,
negative_attr);
if (algo)
return algo;
}
if (args.filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.chanwise;
}
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes,
"rocm conv bwd_data", attr);
} else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes,
"rocm conv bwd_data");
}
return megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes,
"rocm conv bwd_data", positive_attr, negative_attr);
}
size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
......@@ -229,43 +225,40 @@ ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto fm = check_layout_fwd(src, grad, diff);
return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes,
attr);
positive_attr, negative_attr);
}
ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, src, diff, grad);
if (is_miopen_supported(args.as_fwd_args())) {
auto algo =
megdnn::get_algo_with_attribute<ConvolutionBackwardFilterImpl>(
sm_algo_pack.miopen_algos[0], attr);
megdnn::get_algo_match_attribute<ConvolutionBackwardFilterImpl>(
sm_algo_pack.miopen_algos[0], positive_attr,
negative_attr);
if (algo)
return algo;
}
if (args.grad_filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_attribute(
args, attr, workspace_limit_in_bytes)) {
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
// prefer special chanwise impl
return &sm_algo_pack.chanwise;
}
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes,
"rocm conv bwd_filter", attr);
} else {
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes,
"rocm conv bwd_filter");
}
return megdnn::get_algo_match_attribute<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes,
"rocm conv bwd_filter", positive_attr, negative_attr);
}
size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes(
......
......@@ -26,9 +26,11 @@ public:
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& src, const CanonizedFilterMeta& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
return get_algorithm_heuristic(src, filter, dst,
workspace_limit_in_bytes, attr)
workspace_limit_in_bytes, positive_attr,
negative_attr)
->info();
}
size_t get_workspace_in_bytes(const TensorLayout& src,
......@@ -72,16 +74,17 @@ private:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const CanonizedFilterMeta& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr);
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr);
static AlgoPack sm_algo_pack;
};
......@@ -94,9 +97,11 @@ public:
AlgorithmInfo get_algorithm_info_heuristic(
const CanonizedFilterMeta& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
return get_algorithm_heuristic(filter, diff, grad,
workspace_limit_in_bytes, attr)
workspace_limit_in_bytes, positive_attr,
negative_attr)
->info();
}
size_t get_workspace_in_bytes(const TensorLayout& filter,
......@@ -118,16 +123,17 @@ private:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr);
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr);
static AlgoPack sm_algo_pack;
};
......@@ -137,13 +143,14 @@ public:
using ConvolutionBackwardFilter::ConvolutionBackwardFilter;
void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const CanonizedFilterMeta& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) {
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
return get_algorithm_heuristic(src, diff, grad,
workspace_limit_in_bytes, attr)
workspace_limit_in_bytes, positive_attr,
negative_attr)
->info();
}
size_t get_workspace_in_bytes(const TensorLayout& src,
......@@ -165,16 +172,17 @@ private:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const CanonizedFilterMeta& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr);
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr);
static AlgoPack sm_algo_pack;
};
......
......@@ -72,9 +72,12 @@ public:
}
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
size_t limit = std::numeric_limits<size_t>::max()) const {
return contain_attribute(attr) && is_available_wk(args, limit);
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
......
......@@ -29,21 +29,16 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args{this, A, B, C};
if (sm_algo_pack.blas.is_available_attribute(args, attr,
workspace_limit_in_bytes)) {
if (sm_algo_pack.blas.is_available_attribute(
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.blas;
}
if (attr != AlgoAttribute::DEFAULT) {
return megdnn::get_algo_with_attribute<MatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"matrix mul forward", attr);
} else {
return megdnn::get_usable_algo<MatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"matrix mul forward");
}
return megdnn::get_algo_match_attribute<MatrixMulForwardImpl>(
sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
"matrix mul forward", positive_attr, negative_attr);
}
size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A,
......
......@@ -36,11 +36,11 @@ private:
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*attr*/) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) override;
const char* get_algorithm_set_name() const override {
return "ROCM MATMUL";
......
......@@ -278,27 +278,21 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
return ret;
}
AlgoAttribute extract_algo_attribute_from_execution_strategy(
//! return pair<positive_attr, negative_attr>
std::pair<AlgoAttribute, AlgoAttribute>
extract_algo_attribute_from_execution_strategy(
const ExecutionStrategy& strategy) {
AlgoAttribute ret = AlgoAttribute::DEFAULT;
std::pair<AlgoAttribute, AlgoAttribute> ret =
std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);
if (strategy & ExecutionStrategy::REPRODUCIBLE) {
ret |= AlgoAttribute::REPRODUCIBLE;
ret.first |= AlgoAttribute::REPRODUCIBLE;
}
return ret;
}
//! Test whether the algo attribute of a algo match the require
//! algo_strategy
static bool algo_attribute_match_strategy(AlgoAttribute attribute,
ExecutionStrategy selected_strategy) {
bool ret = true;
if (selected_strategy & ExecutionStrategy::OPTMIZED) {
ret &= (!static_cast<bool>(AlgoAttribute::NAIVE & attribute));
} else if (selected_strategy & ExecutionStrategy::REPRODUCIBLE) {
ret &= static_cast<bool>(AlgoAttribute::REPRODUCIBLE & attribute);
if (strategy & ExecutionStrategy::OPTMIZED) {
ret.second |= AlgoAttribute::NAIVE;
}
return ret;
}
} // namespace
namespace mgb {
......@@ -311,7 +305,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
return;
AlgoChooserProfileCache::Result prof_rst;
auto target_attribute =
auto target_attr =
extract_algo_attribute_from_execution_strategy(selected_strategy);
std::string layouts_str = format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out);
double cur_timeout = 0;
......@@ -332,14 +326,16 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
continue;
}
auto palgo = ctx.megdnn_opr()->get_algorithm_from_desc(policy.algo);
if (!algo_attribute_match_strategy(palgo->attribute(),
selected_strategy)) {
if (!(palgo->contain_attribute_all(target_attr.first) &&
!palgo->contain_attribute_any(target_attr.second))) {
mgb_log_debug(
"skip algo %s with attribute%s, which is not match the "
"profile strategy required attribute%s.",
"skip algo %s with attribute(%s), which is not match the "
"profile strategy required contain attribute(%s) and not "
"contain attribute(%s).",
algo.name.c_str(),
Algorithm::attribute_str(palgo->attribute()).c_str(),
Algorithm::attribute_str(target_attribute).c_str());
Algorithm::attribute_str(target_attr.first).c_str(),
Algorithm::attribute_str(target_attr.second).c_str());
continue;
}
......@@ -370,10 +366,12 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
rst.workspace, rst.time);
prof_rst.push_back(rst);
}
std::string msg =
ssprintf("no usable %s algorithm %s with attribute(%s)",
ctx.mgb_opr()->dyn_typeinfo()->name, layouts_str.c_str(),
Algorithm::attribute_str(target_attribute).c_str());
std::string msg = ssprintf(
"no usable %s algorithm %s with attribute(%s) and without "
"attribute(%s)",
ctx.mgb_opr()->dyn_typeinfo()->name, layouts_str.c_str(),
Algorithm::attribute_str(target_attr.first).c_str(),
Algorithm::attribute_str(target_attr.second).c_str());
mgb_assert(!prof_rst.empty(), "%s", msg.c_str());
FixedTensorLayouts origin_layouts = ctx.layouts();
......@@ -460,9 +458,9 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts,
Algorithm* palgo = megdnn_opr->get_algorithm_from_desc(policy.algo);
mgb_assert(palgo, "Unknown algo description");
ret.append("): algo=" + std::string(palgo->name()));
ret.append(ssprintf(" workspace=%.2fMiB attirbute=%d",
ret.append(ssprintf(" workspace=%.2fMiB attirbute(%s)",
workspace / (1024 * 1024.0),
static_cast<uint32_t>(palgo->attribute())));
Algorithm::attribute_str(palgo->attribute()).c_str()));
mgb_log_debug("%s", ret.c_str());
megdnn_opr->execution_policy() = policy;
......@@ -602,13 +600,14 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
}
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit);
auto attr =
extract_algo_attribute_from_execution_strategy(selected_strategy);
ImplExecutionPolicy policy;
policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
args..., workspace_limit,
extract_algo_attribute_from_execution_strategy(
selected_strategy)),
m_layouts)
.desc;
policy.algo =
APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
args..., workspace_limit, attr.first, attr.second),
m_layouts)
.desc;
Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo);
mgb_assert(algo, "Unknown algo description");
......@@ -666,13 +665,14 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
} else {
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit);
policy.algo =
APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
args..., workspace_limit,
extract_algo_attribute_from_execution_strategy(
selected_strategy)),
m_layouts)
.desc;
auto attr = extract_algo_attribute_from_execution_strategy(
selected_strategy);
policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
args..., workspace_limit, attr.first,
attr.second),
m_layouts)
.desc;
}
mgb_assert(policy.algo.valid(),
"No algo found from cache or heuristic, maybe some error "
......
......@@ -2189,7 +2189,7 @@ TEST(TestOprDNN, HeuristicReproducible) {
megdnn_opr->get_algorithm_from_desc(algo);
mgb_assert(palgo, "Unknown algo description");
if (strategy == S(S::HEURISTIC | S::REPRODUCIBLE)) {
EXPECT_TRUE(palgo->contain_attribute(
EXPECT_TRUE(palgo->contain_attribute_all(
megdnn::AlgoAttribute::REPRODUCIBLE));
}
algo_name0 = palgo->name();
......@@ -2371,21 +2371,23 @@ public:
std::vector<AlgorithmInfo>(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2));
MOCK_METHOD5(get_algorithm_info_heuristic,
MOCK_METHOD6(get_algorithm_info_heuristic,
AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr));
const TensorLayout& p2,
size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr));
MOCK_METHOD3(get_all_algorithms,
std::vector<Algorithm*>(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2));
MOCK_METHOD5(get_algorithm_heuristic,
MOCK_METHOD6(get_algorithm_heuristic,
Algorithm*(const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2,
size_t workspace_limit_in_bytes,
const AlgoAttribute& attr));
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr));
MOCK_METHOD1(get_algorithm_from_desc,
Algorithm*(const AlgorithmDesc&));
......@@ -2468,7 +2470,7 @@ TEST_F(TestWeightPreprocess, NoPreprocessNeeded) {
auto& mock = mock_conv();
MockAlgorithm algo;
EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _))
EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _, _))
.WillRepeatedly(Return(&algo));
EXPECT_CALL(mock, get_algorithm_from_desc(_))
.WillRepeatedly(Return(&algo));
......@@ -2508,7 +2510,7 @@ TEST_F(TestWeightPreprocess, PreprocessCalledOnlyOnce) {
.WillRepeatedly(Return(&algo));
Expectation algo_call =
EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _))
EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _, _))
.WillOnce(Return(&algo));
Expectation ws_call = EXPECT_CALL(mock, get_workspace_in_bytes(_, _, _, _))
.After(algo_call)
......@@ -2567,7 +2569,7 @@ TEST_F(TestNoWeightPreprocess, NoPreprocess) {
auto& mock = mock_conv();
MockAlgorithm algo;
EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _))
EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _, _))
.WillRepeatedly(Return(&algo));
EXPECT_CALL(mock, get_algorithm_from_desc(_))
.WillRepeatedly(Return(&algo));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册