提交 a3ea1f15 编写于 作者: M Megvii Engine Team

feat(mgb/opr): add fast profile and combined Execution strategy

GitOrigin-RevId: 843dc3a7907bc6ec9a728ec6425b7910d9c136c5
上级 80f00643
...@@ -506,10 +506,66 @@ struct DynOutMallocPolicyCall { ...@@ -506,10 +506,66 @@ struct DynOutMallocPolicyCall {
} }
}; };
/*!
 * \brief wrapper over the underlying integral value of an enum class, so
 *      bitwise operators can be applied to enums used as bit masks
 *
 * The wrapper converts implicitly back to T, and explicitly to bool
 * (true iff any bit is set).
 */
template <typename T>
class EnumClassBit {
    std::underlying_type_t<T> m_raw;

    // private: only the public enum ctor and the operators below may
    // construct from a raw integral value
    constexpr EnumClassBit(std::underlying_type_t<T> raw) : m_raw(raw) {}

public:
    constexpr EnumClassBit(T val)
            : m_raw(static_cast<std::underlying_type_t<T>>(val)) {}

    //! implicit conversion back to the enum type
    constexpr operator T() const { return static_cast<T>(m_raw); }

    //! true iff any bit is set
    constexpr explicit operator bool() const { return m_raw; }

    constexpr EnumClassBit operator&(const EnumClassBit& rhs) const {
        return m_raw & rhs.m_raw;
    }

    constexpr EnumClassBit operator|(const EnumClassBit& rhs) const {
        return m_raw | rhs.m_raw;
    }

    constexpr EnumClassBit operator^(const EnumClassBit& rhs) const {
        return m_raw ^ rhs.m_raw;
    }

    constexpr EnumClassBit operator~() const { return ~m_raw; }
};
#endif // MEGDNN_CC_HOST #endif // MEGDNN_CC_HOST
} // namespace megdnn } // namespace megdnn
//! define bitwise operator \p op for the operand combinations
//! (cls, cls) and (EnumClassBit<cls>, cls); both return EnumClassBit<cls>
//! so results chain and convert back to cls implicitly
#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \
inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \
return ::megdnn::EnumClassBit<cls>(x) \
op ::megdnn::EnumClassBit<cls>(y); \
} \
inline constexpr ::megdnn::EnumClassBit<cls> operator op( \
::megdnn::EnumClassBit<cls> x, cls y) { \
return x op ::megdnn::EnumClassBit<cls>(y); \
}

//! define the compound-assignment (op=) form of a bitwise operator
#define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \
inline constexpr cls& operator op##=(cls& x, cls y) { \
x = x op ::megdnn::EnumClassBit<cls>(y); \
return x; \
}

//! invoke at namespace scope to define &, |, ^, &=, |=, ^= and ~ for
//! the enum class \p cls (requires megdnn::EnumClassBit to be visible)
#define MEGDNN_DEF_ENUM_CLASS_BIT_OPR(cls) \
_MEGDNN_DECBO_SINGLE_OPR(cls, &) \
_MEGDNN_DECBO_SINGLE_OPR(cls, |) \
_MEGDNN_DECBO_SINGLE_OPR(cls, ^) \
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, &) \
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, |) \
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, ^) \
inline constexpr ::megdnn::EnumClassBit<cls> operator~(cls x) { \
return ~::megdnn::EnumClassBit<cls>(x); \
}
#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -251,6 +251,8 @@ protected: ...@@ -251,6 +251,8 @@ protected:
Handle::HandleType m_handle_type = Handle::HandleType::NAIVE; Handle::HandleType m_handle_type = Handle::HandleType::NAIVE;
}; };
MEGDNN_DEF_ENUM_CLASS_BIT_OPR(Algorithm::Attribute)
//! policy for executing the operator //! policy for executing the operator
struct ExecutionPolicy { struct ExecutionPolicy {
//! INVALID_ALGO_TYPE algo_type means using heuristic //! INVALID_ALGO_TYPE algo_type means using heuristic
......
...@@ -53,9 +53,13 @@ class FlatBuffersWriter(IndentWriterBase): ...@@ -53,9 +53,13 @@ class FlatBuffersWriter(IndentWriterBase):
e = self._enums[(p, e)] e = self._enums[(p, e)]
self._write_doc(e.name) self._write_doc(e.name)
self._write("enum %s%s : uint {", p, e.name, indent=1) self._write("enum %s%s : uint {", p, e.name, indent=1)
for member in e.members: for idx, member in enumerate(e.members):
self._write_doc(member) self._write_doc(member)
self._write("%s,", scramble_enum_member_name(str(member))) if e.combined:
self._write("%s=%d,", scramble_enum_member_name(str(member)),
1<<idx)
else:
self._write("%s,", scramble_enum_member_name(str(member)))
self._write("}\n", indent=-1) self._write("}\n", indent=-1)
def _write_doc(self, doc): def _write_doc(self, doc):
......
...@@ -80,13 +80,13 @@ class member_defs: ...@@ -80,13 +80,13 @@ class member_defs:
:attr member_alias: list of (member, alias) pairs :attr member_alias: list of (member, alias) pairs
""" """
__slots__ = ['name', 'name_field', 'members', 'default', __slots__ = ['name', 'name_field', 'members', 'default',
'member_alias'] 'member_alias', 'combined']
all_enums = {} all_enums = {}
"""(param_name, name) => enum""" """(param_name, name) => enum"""
def __init__(self, param_name, name, name_field, members, default, def __init__(self, param_name, name, name_field, members, default,
member_alias): member_alias, combined = False):
name = member_defs.Doc.make(name) name = member_defs.Doc.make(name)
assert name.id[0].isupper() assert name.id[0].isupper()
members = tuple(map(member_defs.Doc.make, members)) members = tuple(map(member_defs.Doc.make, members))
...@@ -97,6 +97,7 @@ class member_defs: ...@@ -97,6 +97,7 @@ class member_defs:
default = name_field.index(default) default = name_field.index(default)
assert isinstance(default, int) assert isinstance(default, int)
self.name = name self.name = name
self.combined = combined
self.name_field = self.get_name_field(name.id, name_field) self.name_field = self.get_name_field(name.id, name_field)
self.members = members self.members = members
self.default = default self.default = default
...@@ -197,6 +198,12 @@ class ParamDef: ...@@ -197,6 +198,12 @@ class ParamDef:
self.name.id, name, name_field, members, default, member_alias)) self.name.id, name, name_field, members, default, member_alias))
return self return self
def add_bit_combination_enum(self, name, *members, default=0,
name_field=None, member_alias=[]):
self.members.append(member_defs.Enum(
self.name.id, name, name_field, members, default, member_alias, True))
return self
def add_enum_alias(self, name, src_class, src_name=None, name_field=None, def add_enum_alias(self, name, src_class, src_name=None, name_field=None,
default=None): default=None):
self.members.append(member_defs.EnumAlias( self.members.append(member_defs.EnumAlias(
...@@ -463,8 +470,12 @@ class SerializedDType(_ParamDefBase): ...@@ -463,8 +470,12 @@ class SerializedDType(_ParamDefBase):
for idx, emem in enumerate(e.members): for idx, emem in enumerate(e.members):
self._write('%s = "%s"', emem, emem) self._write('%s = "%s"', emem, emem)
self._write_doc(emem) self._write_doc(emem)
self._enum_member2num.append('id({}.{}):{}'.format( if e.combined:
qualname, emem, idx)) self._enum_member2num.append('id({}.{}):{}'.format(
qualname, emem, 1<<idx))
else:
self._enum_member2num.append('id({}.{}):{}'.format(
qualname, emem, idx))
for emem, emem_alis in e.member_alias: for emem, emem_alis in e.member_alias:
self._write('%s = %s', emem_alis, emem) self._write('%s = %s', emem_alis, emem)
...@@ -622,6 +633,8 @@ class CPPWriter(IndentWriterBase): ...@@ -622,6 +633,8 @@ class CPPWriter(IndentWriterBase):
for idx, i in enumerate(e.members): for idx, i in enumerate(e.members):
self._write_doc(i) self._write_doc(i)
v = '{} = {}'.format(i, idx) v = '{} = {}'.format(i, idx)
if e.combined:
v = '{} = 1 << {}'.format(i, idx)
if i is not e.members[-1] or e.member_alias: if i is not e.members[-1] or e.member_alias:
v += ',' v += ','
self._write(v) self._write(v)
...@@ -672,7 +685,6 @@ class CPPEnumValueWriter(CPPWriter): ...@@ -672,7 +685,6 @@ class CPPEnumValueWriter(CPPWriter):
self._write('static const uint32_t %s = %s;', alias, mem) self._write('static const uint32_t %s = %s;', alias, mem)
self._write('};', indent=-1) self._write('};', indent=-1)
def _on_member_enum_alias(self, e): def _on_member_enum_alias(self, e):
s = e.src_enum s = e.src_enum
self._write('typedef %s::%s %s;', e.src_class, e.src_name, e.name) self._write('typedef %s::%s %s;', e.src_class, e.src_name, e.name)
......
...@@ -91,12 +91,17 @@ class ConverterWriter(IndentWriterBase): ...@@ -91,12 +91,17 @@ class ConverterWriter(IndentWriterBase):
def format(v): def format(v):
return '\"{}\"'.format(str(v)) return '\"{}\"'.format(str(v))
enum_def += ','.join(format(i) for i in e.members) enum_def += ','.join(format(i) for i in e.members)
enum_def += "]"
if e.combined:
enum_def += "], 1"
else:
enum_def += "], 0"
if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)):
enum_def += ", 1" # whether generate ToStringTrait enum_def += ", 1" # whether generate ToStringTrait
enum_def += ">" enum_def += ">"
self._write("def {} : {};".format(td_class, enum_def))
self._write("def {} : {};".format(td_class, enum_def))
if self._skip_current_param: if self._skip_current_param:
return return
......
...@@ -21,8 +21,6 @@ ...@@ -21,8 +21,6 @@
namespace megdnn { namespace megdnn {
MEGDNN_DEF_ENUM_CLASS_BIT_OPR(AlgoAttribute)
#define MEGDNN_DECL_ALGO_TYPE(_type) \ #define MEGDNN_DECL_ALGO_TYPE(_type) \
uint32_t type() const override { \ uint32_t type() const override { \
return static_cast<std::underlying_type<AlgoType>::type>( \ return static_cast<std::underlying_type<AlgoType>::type>( \
......
...@@ -692,61 +692,6 @@ inline void* get_origin_ptr(const TensorND* tensor, void* ptr) { ...@@ -692,61 +692,6 @@ inline void* get_origin_ptr(const TensorND* tensor, void* ptr) {
tensor->layout.span().low_byte); tensor->layout.span().low_byte);
} }
//! wrapper over the underlying integral value of an enum class, so bitwise
//! operators can be applied to enums used as bit masks; converts implicitly
//! back to T and explicitly to bool (true iff any bit is set)
template <typename T>
class EnumClassBit {
std::underlying_type_t<T> m_val;

//! private: only the enum ctor and the operators may wrap a raw value
constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {}

public:
constexpr EnumClassBit(T v)
: m_val(static_cast<std::underlying_type_t<T>>(v)) {}

//! implicit conversion back to the enum type
constexpr operator T() const { return static_cast<T>(m_val); }

//! true iff any bit is set
constexpr explicit operator bool() const { return m_val; }

#define DEF_OPR(op) \
constexpr EnumClassBit operator op(const EnumClassBit& rhs) const { \
return m_val op rhs.m_val; \
}

DEF_OPR(&)
DEF_OPR(|)
DEF_OPR (^)

constexpr EnumClassBit operator~() const { return ~m_val; }

#undef DEF_OPR
};

//! define bitwise operator \p op for the operand combinations
//! (cls, cls) and (EnumClassBit<cls>, cls)
#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \
inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \
return ::megdnn::EnumClassBit<cls>(x) \
op ::megdnn::EnumClassBit<cls>(y); \
} \
inline constexpr ::megdnn::EnumClassBit<cls> operator op( \
::megdnn::EnumClassBit<cls> x, cls y) { \
return x op ::megdnn::EnumClassBit<cls>(y); \
}

//! define the compound-assignment (op=) form of a bitwise operator
#define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \
inline constexpr cls& operator op##=(cls& x, cls y) { \
x = x op ::megdnn::EnumClassBit<cls>(y); \
return x; \
}

//! invoke at namespace scope to define &, |, ^, &=, |=, ^= and ~ for \p cls
#define MEGDNN_DEF_ENUM_CLASS_BIT_OPR(cls) \
_MEGDNN_DECBO_SINGLE_OPR(cls, &) \
_MEGDNN_DECBO_SINGLE_OPR(cls, |) \
_MEGDNN_DECBO_SINGLE_OPR(cls, ^) \
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, &) \
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, |) \
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, ^) \
inline constexpr ::megdnn::EnumClassBit<cls> operator~(cls x) { \
return ~::megdnn::EnumClassBit<cls>(x); \
}
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -218,4 +218,3 @@ public: ...@@ -218,4 +218,3 @@ public:
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -8,9 +8,12 @@ ...@@ -8,9 +8,12 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os import os
from ..core.ops import builtin
from ..logger import get_logger from ..logger import get_logger
from ..utils.deprecation import deprecated from ..utils.deprecation import deprecated
Strategy = builtin.ops.Convolution.Strategy
_execution_strategy = os.getenv("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC") _execution_strategy = os.getenv("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC")
if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None: if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None:
...@@ -19,7 +22,7 @@ if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None: ...@@ -19,7 +22,7 @@ if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None:
) )
def get_execution_strategy() -> str: def get_execution_strategy() -> Strategy:
""" """
Returns the execution strategy of :class:`~.Conv2d` and :func:'~.matmul' Returns the execution strategy of :class:`~.Conv2d` and :func:'~.matmul'
...@@ -28,12 +31,22 @@ def get_execution_strategy() -> str: ...@@ -28,12 +31,22 @@ def get_execution_strategy() -> str:
return _execution_strategy return _execution_strategy
def set_execution_strategy(option: str): def set_execution_strategy(option):
""" """
Sets the execution strategy of :class:`~.Conv2d` and :func:'~.matmul' Sets the execution strategy of :class:`~.Conv2d` and :func:'~.matmul'
:param option: Decides how :class:`~.Conv2d` and :func:'~.matmul' algorithms are chosen. :param option: Decides how :class:`~.Conv2d`and :func:'~.matmul' algorithms are chosen.
Available values: Available value Strategy
* HEURISTIC uses heuristic to choose the fastest algorithm.
* PROFILE runs possible algorithms on real device to find the best one.
* REPRODUCIBLE uses the algorithms that is reproducible.
* OPTMIZED uses the algorithms that is optimized.
The default strategy is HEURISTIC, this options can be combined to
form a combination option, e.g. PROFILE | REPRODUCIBLE
can combined a option that uses the fastest of profiling result that is also reproducible.
Available values string:
* 'HEURISTIC' uses heuristic to choose the fastest algorithm. * 'HEURISTIC' uses heuristic to choose the fastest algorithm.
* 'PROFILE' runs possible algorithms on real device to find the best one. * 'PROFILE' runs possible algorithms on real device to find the best one.
...@@ -45,18 +58,29 @@ def set_execution_strategy(option: str): ...@@ -45,18 +58,29 @@ def set_execution_strategy(option: str):
It can also be set through the environment variable 'MEGENGINE_EXECUTION_STRATEGY'. It can also be set through the environment variable 'MEGENGINE_EXECUTION_STRATEGY'.
""" """
valid_option = ( valid_string_option = {
"HEURISTIC", "REPRODUCIBLE": Strategy.REPRODUCIBLE,
"PROFILE", "HEURISTIC": Strategy.HEURISTIC,
"PROFILE_HEURISTIC", "PROFILE": Strategy.PROFILE,
"PROFILE_REPRODUCIBLE", }
"HEURISTIC_REPRODUCIBLE",
)
if not option in valid_option:
raise ValueError("Valid option can only be one of {}".format(valid_option))
global _execution_strategy # pylint: disable=global-statement global _execution_strategy # pylint: disable=global-statement
_execution_strategy = option if isinstance(option, Strategy):
_execution_strategy = option
return
assert isinstance(option, str)
strategy_tmp = Strategy(0)
for opt in option.split("_"):
if not opt in valid_string_option:
raise ValueError(
"Valid option can only be one of {}, or combine them with '_'.".format(
valid_string_option.keys()
)
)
strategy_tmp = strategy_tmp | valid_string_option[opt]
_execution_strategy = strategy_tmp
@deprecated(version="1.3", reason="use get_execution_strategy() instead") @deprecated(version="1.3", reason="use get_execution_strategy() instead")
......
...@@ -19,6 +19,7 @@ import megengine.autodiff as ad ...@@ -19,6 +19,7 @@ import megengine.autodiff as ad
import megengine.functional as F import megengine.functional as F
from megengine import jit from megengine import jit
from megengine.core._trace_option import set_symbolic_shape from megengine.core._trace_option import set_symbolic_shape
from megengine.core.ops import builtin
from megengine.core.tensor.utils import make_shape_tuple from megengine.core.tensor.utils import make_shape_tuple
from megengine.functional.debug_param import set_execution_strategy from megengine.functional.debug_param import set_execution_strategy
from megengine.jit import SublinearMemoryConfig from megengine.jit import SublinearMemoryConfig
...@@ -33,6 +34,8 @@ from megengine.module import ( ...@@ -33,6 +34,8 @@ from megengine.module import (
from megengine.optimizer import SGD from megengine.optimizer import SGD
from megengine.tensor import Tensor from megengine.tensor import Tensor
Strategy = builtin.ops.Convolution.Strategy
def get_gpu_name(): def get_gpu_name():
try: try:
...@@ -242,7 +245,7 @@ def test_correctness(): ...@@ -242,7 +245,7 @@ def test_correctness():
else: else:
model_name = "mnist_model_with_test_cpu.mge" model_name = "mnist_model_with_test_cpu.mge"
model_path = os.path.join(os.path.dirname(__file__), model_name) model_path = os.path.join(os.path.dirname(__file__), model_name)
set_execution_strategy("HEURISTIC_REPRODUCIBLE") set_execution_strategy(Strategy.HEURISTIC | Strategy.REPRODUCIBLE)
run_train(model_path, False, False, max_err=1e-5) run_train(model_path, False, False, max_err=1e-5)
run_train(model_path, True, False, max_err=1e-5) run_train(model_path, True, False, max_err=1e-5)
......
...@@ -337,6 +337,20 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ...@@ -337,6 +337,20 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext&
className, attr->getEnumName(), i className, attr->getEnumName(), i
)); ));
} }
if (attr->getEnumCombinedFlag()) {
//! define operator |
os << formatv(
"\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ "
"\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));"
"\n })",
className, attr->getEnumName());
//! define operator &
os << formatv(
"\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{"
"\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));"
"\n })",
className, attr->getEnumName());
}
os << formatv( os << formatv(
"\n .def(py::init([](const std::string& in) {" "\n .def(py::init([](const std::string& in) {"
"\n auto&& str = normalize_enum(in);" "\n auto&& str = normalize_enum(in);"
......
...@@ -77,6 +77,9 @@ struct MgbEnumAttrMixin : public MgbAttrWrapperBase { ...@@ -77,6 +77,9 @@ struct MgbEnumAttrMixin : public MgbAttrWrapperBase {
bool supportToString() const { bool supportToString() const {
return getBaseRecord()->getValueAsBit("supportToString"); return getBaseRecord()->getValueAsBit("supportToString");
} }
bool getEnumCombinedFlag() const {
return getBaseRecord()->getValueAsBit("enumCombined");
}
}; };
struct MgbHashableAttrMixin : public MgbAttrWrapperBase { struct MgbHashableAttrMixin : public MgbAttrWrapperBase {
......
...@@ -142,8 +142,16 @@ R"__usage__( ...@@ -142,8 +142,16 @@ R"__usage__(
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
R"__usage__( R"__usage__(
--fast-run --fast-run
Enable fast-run mode. Operators with multiple algorithms would be profiled This param will be deperated later, please replace with param --full-profile.
on the real device with actual input shapes. --full-profile
Enable full-profile mode. Operators with multiple algorithms would be profiled
on the real device with actual input shapes, all algorithms will be profiled
include naive algorithms.
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details.
--fast-profile
Enable fast-profile mode. Operators with multiple algorithms would be profiled
on the real device with actual input shapes, this mode will only profile the
well optimized algorithms to get the profile result fast.
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details.
)__usage__" )__usage__"
#endif #endif
...@@ -511,7 +519,8 @@ struct Args { ...@@ -511,7 +519,8 @@ struct Args {
bool disable_assert_throw = false; bool disable_assert_throw = false;
bool share_param_mem = false; bool share_param_mem = false;
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
bool use_fast_run = false; bool use_full_profile = false;
bool use_fast_profile = false;
#endif #endif
bool reproducible = false; bool reproducible = false;
std::string fast_run_cache_path; std::string fast_run_cache_path;
...@@ -695,18 +704,20 @@ void run_test_st(Args &env) { ...@@ -695,18 +704,20 @@ void run_test_st(Args &env) {
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
S strategy = S::HEURISTIC; S strategy = S::HEURISTIC;
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
if (env.use_fast_run) { if (env.use_full_profile) {
if (env.reproducible) { if (env.reproducible) {
strategy = S::PROFILE_REPRODUCIBLE; strategy = S::PROFILE | S::REPRODUCIBLE;
} else { } else {
strategy = S::PROFILE; strategy = S::PROFILE;
} }
} else if (env.use_fast_profile) {
strategy = S::PROFILE | S::OPTMIZED;
} else if (env.reproducible) { } else if (env.reproducible) {
strategy = S::HEURISTIC_REPRODUCIBLE; strategy = S::HEURISTIC | S::REPRODUCIBLE;
} }
#else #else
if (env.reproducible) { if (env.reproducible) {
strategy = S::HEURISTIC_REPRODUCIBLE; strategy = S::HEURISTIC | S::REPRODUCIBLE;
} }
#endif #endif
mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy); mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy);
...@@ -729,11 +740,12 @@ void run_test_st(Args &env) { ...@@ -729,11 +740,12 @@ void run_test_st(Args &env) {
std::make_shared<InFilePersistentCache>(buf.get(), flen)); std::make_shared<InFilePersistentCache>(buf.get(), flen));
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
} else { } else {
mgb_assert(env.use_fast_run, "fast-run should be enabled"); mgb_assert(env.use_full_profile || env.use_fast_profile,
"fast-run or fast-profile should be enabled");
PersistentCache::set_impl( PersistentCache::set_impl(
std::make_shared<InFilePersistentCache>()); std::make_shared<InFilePersistentCache>());
} }
if (!env.use_fast_run) if (!env.use_full_profile && !env.use_fast_profile)
#endif #endif
mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); mgb::gopt::enable_opr_use_profiling_cache_inplace(vars);
} }
...@@ -1314,7 +1326,18 @@ Args Args::from_argv(int argc, char **argv) { ...@@ -1314,7 +1326,18 @@ Args Args::from_argv(int argc, char **argv) {
} }
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
if (!strcmp(argv[i], "--fast-run")) { if (!strcmp(argv[i], "--fast-run")) {
ret.use_fast_run = true; mgb_log_warn(
"--fast-run param will be deperated later, please replace "
"with --full-profile or --fast-profile.");
ret.use_full_profile = true;
continue;
}
if (!strcmp(argv[i], "--full-profile")) {
ret.use_full_profile = true;
continue;
}
if (!strcmp(argv[i], "--fast-profile")) {
ret.use_fast_profile = true;
continue; continue;
} }
#endif #endif
......
...@@ -188,7 +188,7 @@ AlgoChooserProfileCache::get(const Key &key) { ...@@ -188,7 +188,7 @@ AlgoChooserProfileCache::get(const Key &key) {
auto entry_len = read_uint32(); auto entry_len = read_uint32();
mgb_assert(buf + entry_len <= buf_end); mgb_assert(buf + entry_len <= buf_end);
auto nr = sscanf(reinterpret_cast<const char*>(buf), ENTRY_FMT, auto nr = sscanf(reinterpret_cast<const char*>(buf), ENTRY_FMT,
&i.reproducible, &i.time, &i.workspace); &i.attribute, &i.time, &i.workspace);
mgb_assert(nr == 3); mgb_assert(nr == 3);
buf += entry_len; buf += entry_len;
} }
...@@ -210,10 +210,10 @@ void AlgoChooserProfileCache::put(const Key &key, Result &result) { ...@@ -210,10 +210,10 @@ void AlgoChooserProfileCache::put(const Key &key, Result &result) {
auto &&cur = result[i]; auto &&cur = result[i];
if (prev.workspace <= cur.workspace && if (prev.workspace <= cur.workspace &&
prev.reproducible == cur.reproducible) { prev.attribute == cur.attribute) {
result.erase(result.begin() + i); result.erase(result.begin() + i);
} else { } else {
++ i; ++i;
} }
} }
...@@ -235,8 +235,8 @@ void AlgoChooserProfileCache::put(const Key &key, Result &result) { ...@@ -235,8 +235,8 @@ void AlgoChooserProfileCache::put(const Key &key, Result &result) {
write_uint32(0); write_uint32(0);
pos = val.size(); pos = val.size();
val.resize(pos + SPR_SIZE); val.resize(pos + SPR_SIZE);
uint32_t nr = snprintf(&val[pos], SPR_SIZE, uint32_t nr = snprintf(&val[pos], SPR_SIZE, ENTRY_FMT, i.attribute,
ENTRY_FMT, i.reproducible, i.time, i.workspace); i.time, i.workspace);
//! for memory boundary failed, snprintf ret do not contain \0 //! for memory boundary failed, snprintf ret do not contain \0
nr += 1; nr += 1;
mgb_assert(nr < SPR_SIZE); mgb_assert(nr < SPR_SIZE);
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
#pragma once #pragma once
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#include "megbrain/opr/param_defs.h"
#include "megdnn/basic_types.h"
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -242,6 +244,16 @@ inline constexpr std::size_t operator"" _z(unsigned long long n) { ...@@ -242,6 +244,16 @@ inline constexpr std::size_t operator"" _z(unsigned long long n) {
return n; return n;
} }
#endif #endif
#define MGB_DEF_ENUM_CLASS_BIT_OPR(cls) \
MEGDNN_DEF_ENUM_CLASS_BIT_OPR(cls)
} // namespace mgb } // namespace mgb
namespace megdnn {
namespace param {
MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy)
}
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#pragma once #pragma once
#include "megbrain/utils/hash.h" #include "megbrain/utils/hash.h"
#include "megbrain/utils/enum_class_bit.h"
#include "megbrain/utils/metahelper.h" #include "megbrain/utils/metahelper.h"
#include "megbrain/utils/thin/hash_table.h" #include "megbrain/utils/thin/hash_table.h"
#include "megbrain/utils/thread.h" #include "megbrain/utils/thread.h"
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include "megbrain/graph/symbol_var.h" #include "megbrain/graph/symbol_var.h"
#include "megbrain/utils/hashable.h" #include "megbrain/utils/hashable.h"
#include "megbrain/utils/enum_class_bit.h"
#include "megbrain/utils/thin/hash_table.h" #include "megbrain/utils/thin/hash_table.h"
#include "megbrain/utils/small_vector.h" #include "megbrain/utils/small_vector.h"
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#pragma once #pragma once
#include "megbrain/graph/bases.h" #include "megbrain/graph/bases.h"
#include "megbrain/utils/enum_class_bit.h"
#include "megbrain/utils/comp_node_sync_manager.h" #include "megbrain/utils/comp_node_sync_manager.h"
#include "megbrain/utils/small_vector.h" #include "megbrain/utils/small_vector.h"
#include "megbrain/utils/mempool.h" #include "megbrain/utils/mempool.h"
......
...@@ -33,10 +33,11 @@ class MgbHashableAttrMixin { ...@@ -33,10 +33,11 @@ class MgbHashableAttrMixin {
string reprFunction = "std::to_string($0)"; string reprFunction = "std::to_string($0)";
} }
class MgbEnumAttrMixin<string namespace, string name, list<string> members, bit toString> { class MgbEnumAttrMixin<string namespace, string name, list<string> members, bit combined, bit toString> {
string parentNamespace = namespace; string parentNamespace = namespace;
string enumName = name; string enumName = name;
list<string> enumMembers = members; list<string> enumMembers = members;
bit enumCombined = combined;
bit supportToString = toString; bit supportToString = toString;
} }
...@@ -166,8 +167,8 @@ class MgbTupleAttr<list<MgbAttrWrapper> args>: ...@@ -166,8 +167,8 @@ class MgbTupleAttr<list<MgbAttrWrapper> args>:
} }
// -- enum types // -- enum types
class MgbEnumAttr<string namespace, string enumName, list<string> members, bit toString=0>: class MgbEnumAttr<string namespace, string enumName, list<string> members, bit combined, bit toString=0>:
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members, toString> { HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members, combined, toString> {
let storageType = "::mlir::IntegerAttr"; let storageType = "::mlir::IntegerAttr";
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
...@@ -176,7 +177,7 @@ class MgbEnumAttr<string namespace, string enumName, list<string> members, bit t ...@@ -176,7 +177,7 @@ class MgbEnumAttr<string namespace, string enumName, list<string> members, bit t
} }
class MgbEnumAliasAttr<string namespace, string enumName, MgbEnumAttr base>: class MgbEnumAliasAttr<string namespace, string enumName, MgbEnumAttr base>:
MgbEnumAttr<namespace, enumName, base.enumMembers>, MgbAliasAttrMixin<base>; MgbEnumAttr<namespace, enumName, base.enumMembers, 0>, MgbAliasAttrMixin<base>;
// -- other types // -- other types
def MgbDTypeAttr: HashableAttr<"::megdnn::DType"> { def MgbDTypeAttr: HashableAttr<"::megdnn::DType"> {
......
/**
 * \file src/core/include/megbrain/utils/enum_class_bit.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include <type_traits>

namespace mgb {

/*!
 * \brief wrapper over the underlying integral value of an enum class, so
 *      bitwise operators can be applied to enums used as bit masks
 *
 * The wrapper converts implicitly back to T, and explicitly to bool
 * (true iff any bit is set), so expressions like
 * `if (flags & Flag::BIT0)` work after MGB_DEF_ENUM_CLASS_BIT_OPR(Flag).
 */
template <typename T>
class EnumClassBit {
    std::underlying_type_t<T> m_val;

    //! private: only the enum ctor and the operators may wrap a raw value
    constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {}

public:
    constexpr EnumClassBit(T v)
            : m_val(static_cast<std::underlying_type_t<T>>(v)) {}

    //! implicit conversion back to the enum type
    constexpr operator T() const { return static_cast<T>(m_val); }

    //! true iff any bit is set
    constexpr explicit operator bool() const { return m_val; }

#define DEF_OPR(op)                                                       \
    constexpr EnumClassBit operator op(const EnumClassBit& rhs) const {   \
        return m_val op rhs.m_val;                                        \
    }

    DEF_OPR(&)
    DEF_OPR(|)
    DEF_OPR(^)

    constexpr EnumClassBit operator~() const { return ~m_val; }

#undef DEF_OPR
};

}  // namespace mgb

//! define bitwise operator \p op for the operand combinations
//! (cls, cls) and (EnumClassBit<cls>, cls)
#define _MGB_DECBO_SINGLE_OPR(cls, op)                                      \
    inline constexpr ::mgb::EnumClassBit<cls> operator op(cls x, cls y) {   \
        return ::mgb::EnumClassBit<cls>(x) op ::mgb::EnumClassBit<cls>(y);  \
    }                                                                       \
    inline constexpr ::mgb::EnumClassBit<cls> operator op(                  \
            ::mgb::EnumClassBit<cls> x, cls y) {                            \
        return x op ::mgb::EnumClassBit<cls>(y);                            \
    }

//! define the compound-assignment (op=) form of a bitwise operator
#define _MGB_DECBO_SINGLE_OPR_ASSIGN(cls, op)             \
    inline constexpr cls& operator op##=(cls& x, cls y) { \
        x = x op ::mgb::EnumClassBit<cls>(y);             \
        return x;                                         \
    }

//! invoke at namespace scope to define &, |, ^, &=, |=, ^= and ~ for \p cls
//! NOTE: no trailing backslash after the last line -- a continuation there
//! would silently splice the following source line into the macro body
#define MGB_DEF_ENUM_CLASS_BIT_OPR(cls)                          \
    _MGB_DECBO_SINGLE_OPR(cls, &)                                \
    _MGB_DECBO_SINGLE_OPR(cls, |)                                \
    _MGB_DECBO_SINGLE_OPR(cls, ^)                                \
    _MGB_DECBO_SINGLE_OPR_ASSIGN(cls, &)                         \
    _MGB_DECBO_SINGLE_OPR_ASSIGN(cls, |)                         \
    _MGB_DECBO_SINGLE_OPR_ASSIGN(cls, ^)                         \
    inline constexpr ::mgb::EnumClassBit<cls> operator~(cls x) { \
        return ~::mgb::EnumClassBit<cls>(x);                     \
    }

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -100,8 +100,7 @@ namespace mgb { ...@@ -100,8 +100,7 @@ namespace mgb {
struct ResultEntry { struct ResultEntry {
std::string algo; //! identifier of the algorithm std::string algo; //! identifier of the algorithm
//! sscanf will up bool as int uint32_t attribute; //! algo attribute, e.g. reproducible
int reproducible; //! whether algorithm is reproducible
double time; //! execution time in seconds double time; //! execution time in seconds
size_t workspace; //! workspace in bytes size_t workspace; //! workspace in bytes
}; };
......
...@@ -54,7 +54,6 @@ using namespace gopt; ...@@ -54,7 +54,6 @@ using namespace gopt;
namespace { namespace {
template <typename SharedDeviceTensor, typename MultipleDeviceTensorHolder> template <typename SharedDeviceTensor, typename MultipleDeviceTensorHolder>
void param_merge(OptState& opt_state) { void param_merge(OptState& opt_state) {
auto rewriter = opt_state.graph().make_rewriter(); auto rewriter = opt_state.graph().make_rewriter();
...@@ -102,7 +101,7 @@ void param_merge(OptState& opt_state) { ...@@ -102,7 +101,7 @@ void param_merge(OptState& opt_state) {
rewriter.apply_inplace(); rewriter.apply_inplace();
} }
} } // namespace
/* ================ global functions ================ */ /* ================ global functions ================ */
...@@ -190,12 +189,10 @@ void gopt::enable_opr_algo_profiling_inplace( ...@@ -190,12 +189,10 @@ void gopt::enable_opr_algo_profiling_inplace(
void gopt::enable_opr_use_profiling_cache_inplace( void gopt::enable_opr_use_profiling_cache_inplace(
const VarNodeArrayView& dest_vars) { const VarNodeArrayView& dest_vars) {
modify_opr_algo_strategy_inplace( using S = megdnn::param::ExecutionPolicy::Strategy;
dest_vars, opr::mixin::AlgoChooserHelper::ExecutionPolicy:: modify_opr_algo_strategy_inplace(dest_vars, S::PROFILE | S::HEURISTIC);
Strategy::PROFILE_HEURISTIC);
} }
void gopt::set_opr_algo_workspace_limit_inplace( void gopt::set_opr_algo_workspace_limit_inplace(
const VarNodeArrayView& dest_vars, size_t workspace_limit) { const VarNodeArrayView& dest_vars, size_t workspace_limit) {
static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)> static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)>
......
...@@ -1693,7 +1693,22 @@ TEST(TestGoptInference, ProfileCache) { ...@@ -1693,7 +1693,22 @@ TEST(TestGoptInference, ProfileCache) {
using S = opr::Convolution::ExecutionPolicy::Strategy; using S = opr::Convolution::ExecutionPolicy::Strategy;
ASSERT_EQ(S::HEURISTIC, conv.execution_policy_transient().strategy); ASSERT_EQ(S::HEURISTIC, conv.execution_policy_transient().strategy);
gopt::enable_opr_use_profiling_cache_inplace({z + 2.3f}); gopt::enable_opr_use_profiling_cache_inplace({z + 2.3f});
ASSERT_EQ(S::PROFILE_HEURISTIC, conv.execution_policy().strategy); ASSERT_EQ(S::PROFILE | S::HEURISTIC, conv.execution_policy().strategy);
}
TEST(TestGoptInference, FastProfileCache) {
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
auto host_x = gen({4, 3, 8, 9}), host_y = gen({2, 3, 3, 3});
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
y = opr::Host2DeviceCopy::make(*graph, host_y),
z = opr::Convolution::make(x, y);
auto&& conv = z.node()->owner_opr()->cast_final_safe<opr::Convolution>();
using S = opr::Convolution::ExecutionPolicy::Strategy;
ASSERT_EQ(S::HEURISTIC, conv.execution_policy_transient().strategy);
gopt::modify_opr_algo_strategy_inplace({z + 2.3f},
S::PROFILE | S::OPTMIZED);
ASSERT_EQ(S::PROFILE | S::OPTMIZED, conv.execution_policy().strategy);
} }
TEST(TestGoptInference, AlgoWorkspaceLimit) { TEST(TestGoptInference, AlgoWorkspaceLimit) {
......
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
#include "megbrain/opr/dnn/lrn.h" #include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/fake_quant.h" #include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/tqt.h" #include "megbrain/opr/dnn/tqt.h"
#include "megbrain/serialization/sereg.h" #include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h" #include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/nn.h" #include "megdnn/oprs/nn.h"
......
...@@ -284,8 +284,9 @@ namespace mgb { ...@@ -284,8 +284,9 @@ namespace mgb {
namespace opr { namespace opr {
template <typename Opr> template <typename Opr>
void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) { void AlgoChooser<Opr>::profile(ExeContext& ctx,
if (ctx.get_profile_result_from_cache(require_reproducible).valid()) ExecutionStrategy select_strategy) {
if (ctx.get_profile_result_from_cache(select_strategy).valid())
return; return;
AlgoChooserProfileCache::Result prof_rst; AlgoChooserProfileCache::Result prof_rst;
...@@ -305,7 +306,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) { ...@@ -305,7 +306,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) {
algo.name.c_str(), str_on_inp_shape.c_str()); algo.name.c_str(), str_on_inp_shape.c_str());
ImplExecutionPolicy policy; ImplExecutionPolicy policy;
policy.algo = algo.desc; policy.algo = algo.desc;
ctx.construct_execution_policy(require_reproducible, policy); ctx.construct_execution_policy(select_strategy, policy);
if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) if (ctx.get_workspace_size_bytes(policy) >= workspace_limit)
continue; continue;
...@@ -354,7 +355,8 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) { ...@@ -354,7 +355,8 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) {
template <typename Opr> template <typename Opr>
typename AlgoChooser<Opr>::ImplExecutionPolicy typename AlgoChooser<Opr>::ImplExecutionPolicy
AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, bool require_reproducible, AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx,
ExecutionStrategy select_strategy,
bool enable_update) { bool enable_update) {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile")))
if (ctx.owner_graph()->options().no_profiling_on_shape_change) { if (ctx.owner_graph()->options().no_profiling_on_shape_change) {
...@@ -376,11 +378,11 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, bool require_reproducible, ...@@ -376,11 +378,11 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, bool require_reproducible,
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(),
_item.param, ctx.mgb_opr(), ctx.comp_node(), _item.param, ctx.mgb_opr(), ctx.comp_node(),
ctx.execution_policy(), ctx.allow_weight_preprocess()); ctx.execution_policy(), ctx.allow_weight_preprocess());
AlgoChooser<_Opr>::profile(sub_ctx, require_reproducible); AlgoChooser<_Opr>::profile(sub_ctx, select_strategy);
}); });
} }
typename AlgoChooser<Opr>::ImplExecutionPolicy policy; typename AlgoChooser<Opr>::ImplExecutionPolicy policy;
ctx.construct_execution_policy(require_reproducible, policy); ctx.construct_execution_policy(select_strategy, policy);
return policy; return policy;
MIDOUT_E MIDOUT_E
} }
...@@ -402,11 +404,9 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts, ...@@ -402,11 +404,9 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts,
ImplExecutionPolicy policy; ImplExecutionPolicy policy;
if (auto algo_choose_hook = mgb_opr->algo_chooser()) { if (auto algo_choose_hook = mgb_opr->algo_chooser()) {
policy = algo_choose_hook(mgb_opr); policy = algo_choose_hook(mgb_opr);
ctx.construct_execution_policy( ctx.construct_execution_policy((ExecutionStrategy::HEURISTIC |
mgb_opr->execution_policy().strategy == ExecutionStrategy::REPRODUCIBLE),
mixin::AlgoChooserHelper::ExecutionPolicy::Strategy:: policy, false);
HEURISTIC_REPRODUCIBLE,
policy, false);
} }
if (!policy.algo.valid()) { if (!policy.algo.valid()) {
policy = get_policy(ctx); policy = get_policy(ctx);
...@@ -419,10 +419,9 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts, ...@@ -419,10 +419,9 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts,
Algorithm* palgo = megdnn_opr->get_algorithm_from_desc(policy.algo); Algorithm* palgo = megdnn_opr->get_algorithm_from_desc(policy.algo);
mgb_assert(palgo, "Unknown algo description"); mgb_assert(palgo, "Unknown algo description");
ret.append("): algo=" + std::string(palgo->name())); ret.append("): algo=" + std::string(palgo->name()));
ret.append(ssprintf(" workspace=%.2fMiB reproducible=%d", ret.append(ssprintf(" workspace=%.2fMiB attirbute=%d",
workspace / (1024 * 1024.0), workspace / (1024 * 1024.0),
palgo->contain_attribute( static_cast<uint32_t>(palgo->attribute())));
megdnn::AlgoAttribute::REPRODUCIBLE)));
mgb_log_debug("%s", ret.c_str()); mgb_log_debug("%s", ret.c_str());
megdnn_opr->execution_policy() = policy; megdnn_opr->execution_policy() = policy;
...@@ -432,41 +431,39 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts, ...@@ -432,41 +431,39 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts,
template <typename Opr> template <typename Opr>
typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy(
ExeContext& ctx) { ExeContext& ctx) {
using S = mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
MGB_MARK_USED_VAR(TIMEOUT_TOLERANCE); MGB_MARK_USED_VAR(TIMEOUT_TOLERANCE);
switch (ctx.execution_policy().strategy) { auto opr_strategy = ctx.execution_policy().strategy;
case S::HEURISTIC: if ((opr_strategy & ExecutionStrategy::HEURISTIC) &&
return ctx.choose_by_heuristic(); (opr_strategy & ExecutionStrategy::PROFILE)) {
case S::HEURISTIC_REPRODUCIBLE: ImplExecutionPolicy policy =
return ctx.choose_by_heuristic(true); choose_by_profile(ctx, opr_strategy, false);
case S::PROFILE_HEURISTIC: { if (!policy.algo.valid())
ImplExecutionPolicy policy = choose_by_profile(ctx, false, false); policy = ctx.choose_by_heuristic(opr_strategy);
if (!policy.algo.valid()) return policy;
policy = ctx.choose_by_heuristic(); } else if ((opr_strategy & ExecutionStrategy::HEURISTIC)) {
return policy; return ctx.choose_by_heuristic(opr_strategy);
} }
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
case S::PROFILE: else if (opr_strategy & ExecutionStrategy::PROFILE) {
return choose_by_profile(ctx, false); return choose_by_profile(ctx, opr_strategy);
case S::PROFILE_REPRODUCIBLE: }
return choose_by_profile(ctx, true);
#endif #endif
default: else {
mgb_throw(GraphError, "bad convolution ExecutionPolicy strategy"); mgb_throw(GraphError, "bad convolution ExecutionPolicy strategy");
} }
} }
#define INST(Opr) \ #define INST(Opr) \
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
AlgoChooser<megdnn::Opr>::get_policy(ExeContext& ctx); \ AlgoChooser<megdnn::Opr>::get_policy(ExeContext& ctx); \
template void AlgoChooser<megdnn::Opr>::profile( \ template void AlgoChooser<megdnn::Opr>::profile(ExeContext& ctx, \
ExeContext& ctx, bool require_reproducible); \ ExecutionStrategy); \
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
AlgoChooser<megdnn::Opr>::choose_by_profile( \ AlgoChooser<megdnn::Opr>::choose_by_profile( \
ExeContext& ctx, bool require_reproducible, bool enable_update); \ ExeContext& ctx, ExecutionStrategy, bool enable_update); \
template size_t AlgoChooser<megdnn::Opr>::setup_algo( \ template size_t AlgoChooser<megdnn::Opr>::setup_algo( \
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \ const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \
const MGBOpr* mgb_opr, bool allow_weight_preprocess); \ const MGBOpr* mgb_opr, bool allow_weight_preprocess);
MGB_FOREACH_FASTRUN_OPR(INST) MGB_FOREACH_FASTRUN_OPR(INST)
...@@ -498,7 +495,7 @@ AlgoChooser<Opr>::ExeContext::ExeContext( ...@@ -498,7 +495,7 @@ AlgoChooser<Opr>::ExeContext::ExeContext(
template <typename Opr> template <typename Opr>
typename AlgoChooser<Opr>::ImplAlgo typename AlgoChooser<Opr>::ImplAlgo
AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
bool require_reproducible) const { ExecutionStrategy select_strategy) const {
MIDOUT_B(Opr, MIDOUT_B(Opr,
midout_iv(MGB_HASH_STR( midout_iv(MGB_HASH_STR(
"AlgoChooser::ExeContext::get_profile_result_from_cache"))) "AlgoChooser::ExeContext::get_profile_result_from_cache")))
...@@ -522,7 +519,9 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( ...@@ -522,7 +519,9 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
if (prof.empty()) if (prof.empty())
return {}; return {};
for (auto&& i : prof) { for (auto&& i : prof) {
if ((!require_reproducible || i.reproducible)) { if (!(select_strategy & ExecutionStrategy::REPRODUCIBLE) ||
static_cast<AlgoAttribute>(i.attribute) &
AlgoAttribute::REPRODUCIBLE) {
auto iter = algo_map.find(i.algo); auto iter = algo_map.find(i.algo);
mgb_assert(iter != algo_map.end(), mgb_assert(iter != algo_map.end(),
"algorithm %s exists in " "algorithm %s exists in "
...@@ -550,7 +549,8 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( ...@@ -550,7 +549,8 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
template <typename Opr> template <typename Opr>
typename AlgoChooser<Opr>::ImplExecutionPolicy typename AlgoChooser<Opr>::ImplExecutionPolicy
AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const { AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
ExecutionStrategy select_strategy) const {
if (m_execution_policy.workspace_limit != if (m_execution_policy.workspace_limit !=
std::numeric_limits<decltype( std::numeric_limits<decltype(
m_execution_policy.workspace_limit)>::max()) { m_execution_policy.workspace_limit)>::max()) {
...@@ -558,6 +558,8 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const { ...@@ -558,6 +558,8 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const {
"workspace_limit should not be setted if choose algo by " "workspace_limit should not be setted if choose algo by "
"heuristic"); "heuristic");
} }
bool reproducible = static_cast<bool>(select_strategy &
ExecutionStrategy::REPRODUCIBLE);
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit); owner_graph(), m_cn, m_execution_policy.workspace_limit);
ImplExecutionPolicy policy; ImplExecutionPolicy policy;
...@@ -579,7 +581,8 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const { ...@@ -579,7 +581,8 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const {
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(),
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, _item.param, m_base_mgb_opr, m_cn, m_execution_policy,
m_allow_weight_preprocess); m_allow_weight_preprocess);
policy.sub_policy.push_back(sub_ctx.choose_by_heuristic(reproducible)); policy.sub_policy.push_back(
sub_ctx.choose_by_heuristic(select_strategy));
}); });
return policy; return policy;
...@@ -588,9 +591,8 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const { ...@@ -588,9 +591,8 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const {
template <typename Opr> template <typename Opr>
std::vector<typename AlgoChooser<Opr>::ImplAlgo> std::vector<typename AlgoChooser<Opr>::ImplAlgo>
AlgoChooser<Opr>::ExeContext::get_all_candidates() const { AlgoChooser<Opr>::ExeContext::get_all_candidates() const {
auto heu = choose_by_heuristic(); auto heu = choose_by_heuristic(ExecutionStrategy::HEURISTIC);
auto&& ret = auto&& ret = APPLY(m_megdnn_opr->get_all_algorithms_info(args...), m_layouts);
APPLY(m_megdnn_opr->get_all_algorithms_info(args...), m_layouts);
bool found = false; bool found = false;
for (size_t i = 0; i < ret.size(); ++i) { for (size_t i = 0; i < ret.size(); ++i) {
if (ret[i].desc == heu.algo) { if (ret[i].desc == heu.algo) {
...@@ -611,19 +613,21 @@ AlgoChooser<Opr>::ExeContext::get_all_candidates() const { ...@@ -611,19 +613,21 @@ AlgoChooser<Opr>::ExeContext::get_all_candidates() const {
template <typename Opr> template <typename Opr>
void AlgoChooser<Opr>::ExeContext::construct_execution_policy( void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
bool require_reproducible, ExecutionStrategy select_strategy,
typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, typename AlgoChooser<Opr>::ImplExecutionPolicy& policy,
bool retrive_from_cache) const { bool retrive_from_cache) const {
bool reproducible = static_cast<bool>(select_strategy &
ExecutionStrategy::REPRODUCIBLE);
if (!policy.algo.valid()) { if (!policy.algo.valid()) {
if (retrive_from_cache) { if (retrive_from_cache) {
policy.algo = policy.algo =
get_profile_result_from_cache(require_reproducible).desc; get_profile_result_from_cache(select_strategy).desc;
} else { } else {
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit); owner_graph(), m_cn, m_execution_policy.workspace_limit);
policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
args..., workspace_limit, args..., workspace_limit,
require_reproducible), reproducible),
m_layouts) m_layouts)
.desc; .desc;
} }
...@@ -647,7 +651,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( ...@@ -647,7 +651,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, _item.param, m_base_mgb_opr, m_cn, m_execution_policy,
m_allow_weight_preprocess); m_allow_weight_preprocess);
policy.sub_policy.push_back({}); policy.sub_policy.push_back({});
sub_ctx.construct_execution_policy(require_reproducible, sub_ctx.construct_execution_policy(select_strategy,
policy.sub_policy.back(), policy.sub_policy.back(),
retrive_from_cache); retrive_from_cache);
}); });
...@@ -718,8 +722,7 @@ AlgoChooser<Opr>::ExeContext::profile_single_algo( ...@@ -718,8 +722,7 @@ AlgoChooser<Opr>::ExeContext::profile_single_algo(
return None; return None;
return AlgoChooserProfileCache::ResultEntry{ return AlgoChooserProfileCache::ResultEntry{
palgo->name(), palgo->name(),
palgo->contain_attribute( static_cast<uint32_t>(palgo->attribute()),
megdnn::AlgoAttribute::REPRODUCIBLE),
rst.val().time, param.workspace}; rst.val().time, param.workspace};
} }
...@@ -768,10 +771,10 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const { ...@@ -768,10 +771,10 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const {
bool allow_weight_preprocess); \ bool allow_weight_preprocess); \
template typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ template typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
AlgoChooser<megdnn::Opr>::ExeContext::choose_by_heuristic( \ AlgoChooser<megdnn::Opr>::ExeContext::choose_by_heuristic( \
bool reproducible) const; \ ExecutionStrategy select_strategy) const; \
template typename AlgoChooser<megdnn::Opr>::ImplAlgo \ template typename AlgoChooser<megdnn::Opr>::ImplAlgo \
AlgoChooser<megdnn::Opr>::ExeContext::get_profile_result_from_cache( \ AlgoChooser<megdnn::Opr>::ExeContext::get_profile_result_from_cache( \
bool require_reproducible) const; \ ExecutionStrategy select_strategy) const; \
template std::vector<typename AlgoChooser<megdnn::Opr>::ImplAlgo> \ template std::vector<typename AlgoChooser<megdnn::Opr>::ImplAlgo> \
AlgoChooser<megdnn::Opr>::ExeContext::get_all_candidates() const; \ AlgoChooser<megdnn::Opr>::ExeContext::get_all_candidates() const; \
template size_t \ template size_t \
...@@ -780,7 +783,7 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const { ...@@ -780,7 +783,7 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const {
policy) const; \ policy) const; \
template void \ template void \
AlgoChooser<megdnn::Opr>::ExeContext::construct_execution_policy( \ AlgoChooser<megdnn::Opr>::ExeContext::construct_execution_policy( \
bool require_reproducible, \ ExecutionStrategy select_strategy, \
typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& policy, \ typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& policy, \
bool retrive_from_cache) const; \ bool retrive_from_cache) const; \
template Maybe<AlgoChooserProfileCache::ResultEntry> \ template Maybe<AlgoChooserProfileCache::ResultEntry> \
......
...@@ -35,6 +35,13 @@ MGB_FOREACH_FASTRUN_OPR(cb) ...@@ -35,6 +35,13 @@ MGB_FOREACH_FASTRUN_OPR(cb)
#undef cb #undef cb
namespace mgb { namespace mgb {
//! define logical operation of megdnn::param::ExecutionPolicy::Strategy::Enum
//! and megdnn::detail::AlgoAttribute enum
using ExecutionStrategy = megdnn::param::ExecutionPolicy::Strategy;
using AlgoAttribute = megdnn::AlgoAttribute;
namespace opr { namespace opr {
/* =================== AlgoChooser =================== */ /* =================== AlgoChooser =================== */
...@@ -103,7 +110,7 @@ public: ...@@ -103,7 +110,7 @@ public:
const FixedTensorLayouts& layouts() const { return m_layouts; } const FixedTensorLayouts& layouts() const { return m_layouts; }
ImplExecutionPolicy choose_by_heuristic( ImplExecutionPolicy choose_by_heuristic(
bool reproducible = false) const; ExecutionStrategy select_strategy) const;
//! get all candidate algos, and the one choose_by_heuristic() is //! get all candidate algos, and the one choose_by_heuristic() is
//! put first //! put first
...@@ -126,19 +133,20 @@ public: ...@@ -126,19 +133,20 @@ public:
const ImplExecutionPolicy& policy, double& timeout) const; const ImplExecutionPolicy& policy, double& timeout) const;
//! get all profile algorithm from cache, return invalid if not exists //! get all profile algorithm from cache, return invalid if not exists
ImplAlgo get_profile_result_from_cache(bool require_reproducible) const; ImplAlgo get_profile_result_from_cache(
ExecutionStrategy select_strategy) const;
/** /**
* \brief construct execution policy from cache or heuristic. * \brief construct execution policy from cache or heuristic.
* *
* \param require_reproducible select algo which is reproducible * \param select_strategy select algo which matched this strategy
* \param policy execution policy * \param policy execution policy
* \param retrive_from_cache retrive algo from cache if set True, get * \param retrive_from_cache retrive algo from cache if set True, get
* from heuristic otherwise. * from heuristic otherwise.
*/ */
void construct_execution_policy( void construct_execution_policy(ExecutionStrategy select_strategy,
bool require_reproducible, ImplExecutionPolicy& policy, ImplExecutionPolicy& policy,
bool retrive_from_cache = true) const; bool retrive_from_cache = true) const;
private: private:
Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const; Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const;
...@@ -153,11 +161,11 @@ private: ...@@ -153,11 +161,11 @@ private:
//! profile and save to cache //! profile and save to cache
static void profile(ExeContext& ctx, bool require_reproducible); static void profile(ExeContext& ctx, ExecutionStrategy select_strategy);
static ImplExecutionPolicy choose_by_profile(ExeContext& ctx, static ImplExecutionPolicy choose_by_profile(
bool require_reproducible, ExeContext& ctx, ExecutionStrategy select_strategy,
bool enable_update = true); bool enable_update = true);
public: public:
/*! /*!
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
#pragma once #pragma once
#include "megbrain/graph/operator_node.h" #include "megbrain/graph/operator_node.h"
#include "megbrain/opr/param_defs.h"
#include "megdnn/oprs/base.h" #include "megdnn/oprs/base.h"
#include "megdnn/oprs/nn.h" #include "megdnn/oprs/nn.h"
...@@ -73,7 +72,6 @@ protected: ...@@ -73,7 +72,6 @@ protected:
}; };
} // namespace mixin } // namespace mixin
} // namespace opr } // namespace opr
} // namespace mgb } // namespace mgb
......
...@@ -429,10 +429,11 @@ TEST(TestOprDNN, MatrixMulExePolicy) { ...@@ -429,10 +429,11 @@ TEST(TestOprDNN, MatrixMulExePolicy) {
auto cn = CompNode::load("cpux"); auto cn = CompNode::load("cpux");
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, for (auto strategy :
S::PROFILE_HEURISTIC}) { SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC}) {
#else #else
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) { for (auto strategy: {S:HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif #endif
auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
......
...@@ -355,11 +355,13 @@ TEST(TestOprDNN, ConvBiasExePolicy) { ...@@ -355,11 +355,13 @@ TEST(TestOprDNN, ConvBiasExePolicy) {
auto cn = CompNode::load("cpux"); auto cn = CompNode::load("cpux");
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy: {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, S::PROFILE_HEURISTIC}) { for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
#else #else
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) { for (auto strategy :
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif #endif
auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
HostTensorGenerator<> gen; HostTensorGenerator<> gen;
...@@ -397,7 +399,8 @@ TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) { ...@@ -397,7 +399,8 @@ TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) {
auto cn = CompNode::load("cpux"); auto cn = CompNode::load("cpux");
for (auto strategy: {S::PROFILE, S::PROFILE_REPRODUCIBLE}) { for (auto strategy :
SmallVector<S>{S::PROFILE, S::PROFILE | S::REPRODUCIBLE}) {
auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
HostTensorGenerator<> gen; HostTensorGenerator<> gen;
...@@ -439,10 +442,12 @@ TEST(TestOprDNN, ConvolutionExePolicy) { ...@@ -439,10 +442,12 @@ TEST(TestOprDNN, ConvolutionExePolicy) {
PersistentCacheHook cache_hook{on_get}; PersistentCacheHook cache_hook{on_get};
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, for (auto strategy :
S::PROFILE_HEURISTIC}) { SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
#else #else
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) { for (auto strategy :
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif #endif
using Checker = AutoOprChecker<2, 1>; using Checker = AutoOprChecker<2, 1>;
...@@ -522,10 +527,11 @@ TEST(TestOprDNN, ConvolutionBackwardDataBfloat16ExePolicy) { ...@@ -522,10 +527,11 @@ TEST(TestOprDNN, ConvolutionBackwardDataBfloat16ExePolicy) {
PersistentCacheHook cache_hook{on_get}; PersistentCacheHook cache_hook{on_get};
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, for (auto strategy :
S::PROFILE_HEURISTIC}) { {S::PROFILE, S::HEURISTIC, S(S::PROFILE | S::REPRODUCIBLE),
S(S::PROFILE | S::HEURISTIC)}) {
#else #else
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) { for (auto strategy: {S:HEURISTIC, S(S::PROFILE | S::HEURISTIC)}) {
#endif #endif
using Checker = AutoOprChecker<2, 1>; using Checker = AutoOprChecker<2, 1>;
...@@ -1183,9 +1189,12 @@ TEST(TestOprDNN, Convolution3DExePolicy) { ...@@ -1183,9 +1189,12 @@ TEST(TestOprDNN, Convolution3DExePolicy) {
using S = Policy::Strategy; using S = Policy::Strategy;
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy: {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, S::PROFILE_HEURISTIC}) { for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC}) {
#else #else
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) { for (auto strategy :
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif #endif
using Checker = AutoOprChecker<2, 1>; using Checker = AutoOprChecker<2, 1>;
...@@ -1660,10 +1669,12 @@ TEST(TestOprDNN, LocalShareForwardExecPolicy) { ...@@ -1660,10 +1669,12 @@ TEST(TestOprDNN, LocalShareForwardExecPolicy) {
PersistentCacheHook cache_hook{on_get}; PersistentCacheHook cache_hook{on_get};
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, for (auto strategy :
S::PROFILE_HEURISTIC}) { SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
#else #else
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) { for (auto strategy :
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif #endif
auto make_graph = [&](const Checker::SymInpArray& inputs) auto make_graph = [&](const Checker::SymInpArray& inputs)
-> Checker::SymOutArray { -> Checker::SymOutArray {
...@@ -1769,10 +1780,12 @@ TEST(TestOprDNN, DeformableConvForward) { ...@@ -1769,10 +1780,12 @@ TEST(TestOprDNN, DeformableConvForward) {
Param param; Param param;
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, for (auto strategy :
S::PROFILE_HEURISTIC}) { SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
#else #else
for (auto strategy : {S : HEURISTIC, S::PROFILE_HEURISTIC}) { for (auto strategy :
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif #endif
auto make_graph = [&](const Checker::SymInpArray& inputs) auto make_graph = [&](const Checker::SymInpArray& inputs)
-> Checker::SymOutArray { -> Checker::SymOutArray {
...@@ -1936,10 +1949,12 @@ TEST(TestOprDNN, BatchConvBiasForward) { ...@@ -1936,10 +1949,12 @@ TEST(TestOprDNN, BatchConvBiasForward) {
param.sparse = Param::Sparse::DENSE; param.sparse = Param::Sparse::DENSE;
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, for (auto strategy :
S::PROFILE_HEURISTIC}) { SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
#else #else
for (auto strategy : {S : HEURISTIC, S::PROFILE_HEURISTIC}) { for (auto strategy :
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif #endif
auto make_quantized = [&](SymbolVar x, const DType& dtype) { auto make_quantized = [&](SymbolVar x, const DType& dtype) {
...@@ -2080,7 +2095,8 @@ TEST(TestOprDNN, HeuristicReproducible) { ...@@ -2080,7 +2095,8 @@ TEST(TestOprDNN, HeuristicReproducible) {
constexpr size_t PH = 1, PW = 1, SH = 1, SW = 1; constexpr size_t PH = 1, PW = 1, SH = 1, SW = 1;
for (auto strategy : {S::HEURISTIC, S::HEURISTIC_REPRODUCIBLE}) { for (auto strategy :
SmallVector<S>{S::HEURISTIC, S::HEURISTIC | S::REPRODUCIBLE}) {
VarNode* bwd_flt; VarNode* bwd_flt;
auto make_graph = [&](const Checker::SymInpArray& inputs) auto make_graph = [&](const Checker::SymInpArray& inputs)
-> Checker::SymOutArray { -> Checker::SymOutArray {
...@@ -2126,7 +2142,7 @@ TEST(TestOprDNN, HeuristicReproducible) { ...@@ -2126,7 +2142,7 @@ TEST(TestOprDNN, HeuristicReproducible) {
megdnn::Algorithm* palgo = megdnn::Algorithm* palgo =
megdnn_opr->get_algorithm_from_desc(algo); megdnn_opr->get_algorithm_from_desc(algo);
mgb_assert(palgo, "Unknown algo description"); mgb_assert(palgo, "Unknown algo description");
if (strategy == S::HEURISTIC_REPRODUCIBLE) { if (strategy == S(S::HEURISTIC | S::REPRODUCIBLE)) {
EXPECT_TRUE(palgo->contain_attribute( EXPECT_TRUE(palgo->contain_attribute(
megdnn::AlgoAttribute::REPRODUCIBLE)); megdnn::AlgoAttribute::REPRODUCIBLE));
} }
......
...@@ -43,6 +43,7 @@ namespace megdnn { ...@@ -43,6 +43,7 @@ namespace megdnn {
std::ostream &ostr, const DType &dt) { std::ostream &ostr, const DType &dt) {
return ostr << dt.name(); return ostr << dt.name();
} }
} // namespace megdnn } // namespace megdnn
namespace mgb { namespace mgb {
......
...@@ -18,7 +18,7 @@ pdef('PersistentOutputStorage').add_fields( ...@@ -18,7 +18,7 @@ pdef('PersistentOutputStorage').add_fields(
add_const('int32', 'INVALID_AXIS', 'MAX_NDIM'). add_const('int32', 'INVALID_AXIS', 'MAX_NDIM').
add_fields('int32', 'axis', 'INVALID_AXIS')) add_fields('int32', 'axis', 'INVALID_AXIS'))
(pdef('ExecutionPolicy', 'specify how to select an algorithm for an operator'). (pdef('ExecutionPolicy', version=0, is_legacy=True).
add_enum('Strategy', add_enum('Strategy',
Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'), Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'),
Doc('HEURISTIC_REPRODUCIBLE', 'use heuristic to choose the fastest algorithm, ' Doc('HEURISTIC_REPRODUCIBLE', 'use heuristic to choose the fastest algorithm, '
...@@ -33,6 +33,20 @@ pdef('PersistentOutputStorage').add_fields( ...@@ -33,6 +33,20 @@ pdef('PersistentOutputStorage').add_fields(
Doc('workspace_limit', 'workspace limit in bytes'), Doc('workspace_limit', 'workspace limit in bytes'),
str(2**64-1)+'ull')) str(2**64-1)+'ull'))
(pdef('ExecutionPolicy', 'specify how to select an algorithm for an operator', version=1).
add_bit_combination_enum('Strategy',
Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'),
Doc('PROFILE',
'run possible algorithms on real device to find the best'),
Doc('REPRODUCIBLE',
'when profile or heuristic algo selection it require the algos'
'must be reproducible'),
Doc('OPTMIZED',
'profile require algos are optmized to achieve fast-profile')).
add_fields('uint64',
Doc('workspace_limit', 'workspace limit in bytes'),
str(2**64-1)+'ull'))
(pdef('AssertEqual'). (pdef('AssertEqual').
add_fields('float32', add_fields('float32',
Doc('maxerr', 'max allowed error; error is defined as the minimal ' Doc('maxerr', 'max allowed error; error is defined as the minimal '
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册