Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
e19b9af1
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
e19b9af1
编写于
3月 09, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative): add bit combined enum to python C extension
GitOrigin-RevId: 92307dd2ca077ea5606657f7cb7b321fd0dc8129
上级
a3ea1f15
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
359 addition
and
114 deletion
+359
-114
imperative/python/src/ops.cpp
imperative/python/src/ops.cpp
+119
-4
imperative/tablegen/autogen.cpp
imperative/tablegen/autogen.cpp
+171
-56
sdk/load-and-run/src/mgblar.cpp
sdk/load-and-run/src/mgblar.cpp
+14
-23
src/core/include/megbrain/common.h
src/core/include/megbrain/common.h
+0
-7
src/core/include/megbrain/graph/operator_node.h
src/core/include/megbrain/graph/operator_node.h
+7
-0
src/opr/impl/search_policy/algo_chooser.cpp
src/opr/impl/search_policy/algo_chooser.cpp
+42
-18
src/opr/include/megbrain/opr/search_policy/algo_chooser.h
src/opr/include/megbrain/opr/search_policy/algo_chooser.h
+6
-6
未找到文件。
imperative/python/src/ops.cpp
浏览文件 @
e19b9af1
...
@@ -73,7 +73,7 @@ PyTypeObject PyOpType(name);
...
@@ -73,7 +73,7 @@ PyTypeObject PyOpType(name);
} \
} \
} while (0)
} while (0)
template
<
typename
T
,
typename
SFINAE
=
void
>
template
<
typename
T
,
typename
SFINAE
=
void
>
struct
pyobj_convert_generic
{
struct
pyobj_convert_generic
{
static
T
from
(
PyObject
*
obj
)
{
static
T
from
(
PyObject
*
obj
)
{
// TODO: remove this guard which is used for pybind11 implicit conversion
// TODO: remove this guard which is used for pybind11 implicit conversion
...
@@ -87,7 +87,12 @@ struct pyobj_convert_generic {
...
@@ -87,7 +87,12 @@ struct pyobj_convert_generic {
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
EnumTrait
{
static
constexpr
bool
is_bit_combined
=
false
;
};
template
<
typename
T
>
PyObject
*
py_new_generic
(
PyTypeObject
*
type
,
PyObject
*
,
PyObject
*
)
{
PyObject
*
py_new_generic
(
PyTypeObject
*
type
,
PyObject
*
,
PyObject
*
)
{
PyObject
*
obj
=
type
->
tp_alloc
(
type
,
0
);
PyObject
*
obj
=
type
->
tp_alloc
(
type
,
0
);
T
*
self
=
reinterpret_cast
<
T
*>
(
obj
);
T
*
self
=
reinterpret_cast
<
T
*>
(
obj
);
...
@@ -203,9 +208,10 @@ struct EnumWrapper {
...
@@ -203,9 +208,10 @@ struct EnumWrapper {
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
pyobj_convert_generic
<
T
,
struct
pyobj_convert_generic
<
T
,
std
::
enable_if_t
<
std
::
is_enum_v
<
std
::
decay_t
<
T
>>>>
{
std
::
enable_if_t
<
std
::
is_enum_v
<
std
::
decay_t
<
T
>>
&&
!
EnumTrait
<
T
>::
is_bit_combined
>>
{
using
Wrapper
=
EnumWrapper
<
T
>
;
using
Wrapper
=
EnumWrapper
<
T
>
;
static
T
from
(
PyObject
*
obj
)
{
static
T
from
(
PyObject
*
obj
)
{
if
(
PyObject_TypeCheck
(
obj
,
&
Wrapper
::
type
))
{
if
(
PyObject_TypeCheck
(
obj
,
&
Wrapper
::
type
))
{
...
@@ -223,6 +229,115 @@ struct pyobj_convert_generic<T,
...
@@ -223,6 +229,115 @@ struct pyobj_convert_generic<T,
}
}
};
};
template
<
typename
T
>
struct
BitCombinedEnumWrapper
{
static_assert
(
std
::
is_enum_v
<
T
>
);
PyObject_HEAD
T
value
;
static
const
char
*
name
;
static
PyTypeObject
type
;
static
std
::
unordered_map
<
T
,
std
::
string
>
type2str
;
static
std
::
unordered_map
<
std
::
string
,
T
>
str2type
;
static
PyNumberMethods
number_methods
;
BitCombinedEnumWrapper
()
=
default
;
BitCombinedEnumWrapper
(
T
v
)
:
value
(
v
)
{}
BitCombinedEnumWrapper
(
std
::
string
&&
str
)
:
BitCombinedEnumWrapper
(
str2type
.
at
(
normalize_enum
(
str
)))
{}
std
::
string
to_string
()
const
{
if
(
static_cast
<
uint32_t
>
(
value
)
==
0
)
{
return
"None"
;
}
else
{
auto
ret
=
std
::
string
();
bool
first
=
true
;
for
(
uint32_t
i
=
0
;
i
<
32
;
i
++
)
{
uint32_t
value_int
=
static_cast
<
uint32_t
>
(
value
);
auto
it
=
type2str
.
find
(
static_cast
<
T
>
((
1
<<
i
)
&
value_int
));
if
(
it
!=
type2str
.
end
())
{
if
(
!
first
)
{
ret
+=
" + "
;
}
else
{
first
=
false
;
}
ret
+=
(
std
::
string
(
name
)
+
"."
+
it
->
second
);
}
}
return
ret
;
}
}
static
PyObject
*
py_new_combined_enum
(
PyTypeObject
*
type
,
PyObject
*
,
PyObject
*
)
{
PyObject
*
obj
=
type
->
tp_alloc
(
type
,
0
);
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
obj
)
->
value
=
static_cast
<
T
>
(
1
);
return
obj
;
}
static
int
py_init
(
PyObject
*
self
,
PyObject
*
args
,
PyObject
*
)
{
int
input
=
1
;
if
(
PyArg_ParseTuple
(
args
,
"|i"
,
&
input
)){
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
self
)
->
value
=
static_cast
<
T
>
(
input
);
}
return
0
;
}
static
PyObject
*
py_repr
(
PyObject
*
self
)
{
return
pyobj_convert_generic
<
std
::
string
>::
to
(
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
self
)
->
to_string
());
}
static
PyObject
*
py_or
(
PyObject
*
self
,
PyObject
*
other
)
{
if
(
!
(
self
->
ob_type
==
other
->
ob_type
)){
return
PyErr_Format
(
PyExc_RuntimeError
,
"Operand in or operator must be the same type."
);
}
PyObject
*
obj
=
type
.
tp_alloc
(
&
type
,
0
);
T
lhs
=
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
self
)
->
value
,
rhs
=
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
other
)
->
value
;
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
obj
)
->
value
=
static_cast
<
T
>
(
static_cast
<
uint32_t
>
(
lhs
)
|
static_cast
<
uint32_t
>
(
rhs
));
return
obj
;
}
static
PyObject
*
py_and
(
PyObject
*
self
,
PyObject
*
other
)
{
if
(
!
(
self
->
ob_type
==
other
->
ob_type
))
{
return
PyErr_Format
(
PyExc_RuntimeError
,
"Operand in and operator must be the same type."
);
}
PyObject
*
obj
=
type
.
tp_alloc
(
&
type
,
0
);
T
lhs
=
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
self
)
->
value
,
rhs
=
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
other
)
->
value
;
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
obj
)
->
value
=
static_cast
<
T
>
(
static_cast
<
uint32_t
>
(
lhs
)
&
static_cast
<
uint32_t
>
(
rhs
));
return
obj
;
}
static
PyObject
*
tp_richcompare
(
PyObject
*
self
,
PyObject
*
other
,
int
op
)
{
T
lhs
=
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
self
)
->
value
,
rhs
=
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
other
)
->
value
;
if
(
op
==
Py_EQ
||
op
==
Py_NE
)
{
RETURN_RICHCOMPARE
(
lhs
,
rhs
,
op
);
}
Py_RETURN_NOTIMPLEMENTED
;
}
};
template
<
typename
T
>
struct
pyobj_convert_generic
<
T
,
std
::
enable_if_t
<
std
::
is_enum_v
<
std
::
decay_t
<
T
>>
&&
EnumTrait
<
T
>::
is_bit_combined
>>
{
using
Wrapper
=
BitCombinedEnumWrapper
<
T
>
;
static
T
from
(
PyObject
*
obj
)
{
if
(
PyObject_TypeCheck
(
obj
,
&
Wrapper
::
type
))
{
return
reinterpret_cast
<
Wrapper
*>
(
obj
)
->
value
;
}
// try as string
// TODO: type checkcd
return
Wrapper
(
pyobj_convert_generic
<
std
::
string
>::
from
(
obj
)).
value
;
}
static
PyObject
*
to
(
T
t
)
{
PyTypeObject
*
pytype
=
&
Wrapper
::
type
;
PyObject
*
obj
=
pytype
->
tp_alloc
(
pytype
,
0
);
reinterpret_cast
<
Wrapper
*>
(
obj
)
->
value
=
t
;
return
obj
;
}
};
void
_init_py_op_def
(
py
::
module
m
)
{
void
_init_py_op_def
(
py
::
module
m
)
{
using
py_op
=
PyOp
(
OpDef
);
using
py_op
=
PyOp
(
OpDef
);
auto
&
py_type
=
PyOpType
(
OpDef
);
auto
&
py_type
=
PyOpType
(
OpDef
);
...
...
imperative/tablegen/autogen.cpp
浏览文件 @
e19b9af1
...
@@ -408,19 +408,14 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext&
...
@@ -408,19 +408,14 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext&
os
<<
";
\n\n
"
;
os
<<
";
\n\n
"
;
}
}
static
void
gen_op_def_python_c_extension_single
(
raw_ostream
&
os
,
MgbOp
&
op
,
EnumContext
&
ctx
)
{
static
std
::
string
gen_op_def_python_c_extension_enum
(
auto
className
=
op
.
getCppClassName
();
raw_ostream
&
os
,
EnumContext
&
ctx
,
MgbEnumAttr
*
attr
,
llvm
::
StringRef
className
)
{
std
::
string
body
;
std
::
string
body
;
// generate PyType for enum class member
for
(
auto
&&
i
:
op
.
getMgbAttributes
())
{
if
(
auto
attr
=
llvm
::
dyn_cast
<
MgbEnumAttr
>
(
&
i
.
attr
))
{
unsigned
int
enumID
;
unsigned
int
enumID
;
if
(
auto
alias
=
llvm
::
dyn_cast
<
MgbAliasAttr
>
(
attr
))
{
if
(
auto
alias
=
llvm
::
dyn_cast
<
MgbAliasAttr
>
(
attr
))
{
auto
&&
aliasBase
=
alias
->
getAliasBase
();
auto
&&
aliasBase
=
alias
->
getAliasBase
();
enumID
=
enumID
=
llvm
::
cast
<
MgbEnumAttr
>
(
aliasBase
).
getBaseRecord
()
->
getID
();
llvm
::
cast
<
MgbEnumAttr
>
(
aliasBase
)
.
getBaseRecord
()
->
getID
();
}
else
{
}
else
{
enumID
=
attr
->
getBaseRecord
()
->
getID
();
enumID
=
attr
->
getBaseRecord
()
->
getID
();
}
}
...
@@ -428,20 +423,20 @@ static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, Enu
...
@@ -428,20 +423,20 @@ static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, Enu
auto
&&
iter
=
enumAlias
.
find
(
enumID
);
auto
&&
iter
=
enumAlias
.
find
(
enumID
);
auto
enumName
=
attr
->
getEnumName
();
auto
enumName
=
attr
->
getEnumName
();
body
+=
"{
\n
"
;
body
+=
"{
\n
"
;
body
+=
formatv
(
body
+=
formatv
(
"auto& e_type = EnumWrapper<{0}::{1}>::type;"
,
className
,
"auto& e_type = EnumWrapper<{0}::{1}>::type;"
,
className
,
enumName
enumName
);
);
if
(
iter
==
enumAlias
.
end
())
{
if
(
iter
==
enumAlias
.
end
())
{
os
<<
formatv
(
os
<<
formatv
(
"template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};
\n
"
,
"template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};
\n
"
,
className
,
enumName
);
className
,
enumName
);
os
<<
formatv
(
os
<<
formatv
(
"template<> const char* EnumWrapper<{0}::{1}>::name =
\"
{0}.{1}
\"
;
\n
"
,
"template<> const char* EnumWrapper<{0}::{1}>::name = "
"
\"
{0}.{1}
\"
;
\n
"
,
className
,
enumName
);
className
,
enumName
);
std
::
vector
<
std
::
string
>
pairStr
;
std
::
vector
<
std
::
string
>
pairStr
;
for
(
auto
&&
i
:
attr
->
getEnumMembers
())
{
for
(
auto
&&
i
:
attr
->
getEnumMembers
())
{
pairStr
.
push_back
(
formatv
(
pairStr
.
push_back
(
"{{normalize_enum(
\"
{2}
\"
), {0}::{1}::{2}}"
,
formatv
(
"{{normalize_enum(
\"
{2}
\"
), {0}::{1}::{2}}"
,
className
,
enumName
,
i
));
className
,
enumName
,
i
));
}
}
os
<<
formatv
(
R"(
os
<<
formatv
(
R"(
...
@@ -449,11 +444,12 @@ template<> std::unordered_map<std::string, {0}::{1}>
...
@@ -449,11 +444,12 @@ template<> std::unordered_map<std::string, {0}::{1}>
EnumWrapper<{0}::{1}>::str2type = {{
EnumWrapper<{0}::{1}>::str2type = {{
{2}
{2}
};
};
)"
,
className
,
enumName
,
llvm
::
join
(
pairStr
,
", "
));
)"
,
className
,
enumName
,
llvm
::
join
(
pairStr
,
", "
));
pairStr
.
clear
();
pairStr
.
clear
();
for
(
auto
&&
i
:
attr
->
getEnumMembers
())
{
for
(
auto
&&
i
:
attr
->
getEnumMembers
())
{
pairStr
.
push_back
(
formatv
(
pairStr
.
push_back
(
"{{{0}::{1}::{2}, normalize_enum(
\"
{2}
\"
)}"
,
formatv
(
"{{{0}::{1}::{2}, normalize_enum(
\"
{2}
\"
)}"
,
className
,
enumName
,
i
));
className
,
enumName
,
i
));
}
}
os
<<
formatv
(
R"(
os
<<
formatv
(
R"(
...
@@ -461,7 +457,8 @@ template<> std::unordered_map<{0}::{1}, std::string>
...
@@ -461,7 +457,8 @@ template<> std::unordered_map<{0}::{1}, std::string>
EnumWrapper<{0}::{1}>::type2str = {{
EnumWrapper<{0}::{1}>::type2str = {{
{2}
{2}
};
};
)"
,
className
,
enumName
,
llvm
::
join
(
pairStr
,
", "
));
)"
,
className
,
enumName
,
llvm
::
join
(
pairStr
,
", "
));
body
+=
formatv
(
R"(
body
+=
formatv
(
R"(
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}";
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}";
...
@@ -472,13 +469,113 @@ EnumWrapper<{0}::{1}>::type2str = {{
...
@@ -472,13 +469,113 @@ EnumWrapper<{0}::{1}>::type2str = {{
e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr;
e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr;
e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare;
e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare;
mgb_assert(PyType_Ready(&e_type) >= 0);
mgb_assert(PyType_Ready(&e_type) >= 0);
)"
,
className
,
enumName
);
)"
,
for
(
auto
&&
i
:
attr
->
getEnumMembers
())
{
className
,
enumName
);
for
(
auto
&&
i
:
attr
->
getEnumMembers
())
{
body
+=
formatv
(
R"({{
body
+=
formatv
(
R"({{
PyObject* inst = e_type.tp_alloc(&e_type, 0);
PyObject* inst = e_type.tp_alloc(&e_type, 0);
reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2};
reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2};
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0);
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0);
})"
,
className
,
enumName
,
i
);
})"
,
className
,
enumName
,
i
);
}
enumAlias
.
emplace
(
enumID
,
std
::
make_pair
(
className
,
enumName
));
}
body
+=
formatv
(
R"(
PyType_Modified(&e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0);
)"
,
enumName
);
body
+=
"}
\n
"
;
return
body
;
}
static
std
::
string
gen_op_def_python_c_extension_bit_combined_enum
(
raw_ostream
&
os
,
EnumContext
&
ctx
,
MgbEnumAttr
*
attr
,
llvm
::
StringRef
className
)
{
std
::
string
body
;
unsigned
int
enumID
;
if
(
auto
alias
=
llvm
::
dyn_cast
<
MgbAliasAttr
>
(
attr
))
{
auto
&&
aliasBase
=
alias
->
getAliasBase
();
enumID
=
llvm
::
cast
<
MgbEnumAttr
>
(
aliasBase
).
getBaseRecord
()
->
getID
();
}
else
{
enumID
=
attr
->
getBaseRecord
()
->
getID
();
}
auto
&&
enumAlias
=
ctx
.
enumAlias
;
auto
&&
iter
=
enumAlias
.
find
(
enumID
);
auto
enumName
=
attr
->
getEnumName
();
body
+=
"{
\n
"
;
body
+=
formatv
(
"auto& e_type = BitCombinedEnumWrapper<{0}::{1}>::type;"
,
className
,
enumName
);
if
(
iter
==
enumAlias
.
end
())
{
os
<<
formatv
(
"template<> PyTypeObject "
"BitCombinedEnumWrapper<{0}::{1}>::type={{};
\n
"
,
className
,
enumName
);
os
<<
formatv
(
"template<> PyNumberMethods "
"BitCombinedEnumWrapper<{0}::{1}>::number_methods={{};
\n
"
,
className
,
enumName
);
os
<<
formatv
(
"template<> const char* BitCombinedEnumWrapper<{0}::{1}>::name "
"=
\"
{0}.{1}
\"
;
\n
"
,
className
,
enumName
);
os
<<
formatv
(
"template<> struct EnumTrait<{0}::{1}> {{ static constexpr "
"bool is_bit_combined = true;};
\n
"
,
className
,
enumName
);
std
::
vector
<
std
::
string
>
pairStr
;
for
(
auto
&&
i
:
attr
->
getEnumMembers
())
{
pairStr
.
push_back
(
formatv
(
"{{normalize_enum(
\"
{2}
\"
), {0}::{1}::{2}}"
,
className
,
enumName
,
i
));
}
os
<<
formatv
(
R"(
template<> std::unordered_map<std::string, {0}::{1}>
BitCombinedEnumWrapper<{0}::{1}>::str2type = {{
{2}
};
)"
,
className
,
enumName
,
llvm
::
join
(
pairStr
,
", "
));
pairStr
.
clear
();
for
(
auto
&&
i
:
attr
->
getEnumMembers
())
{
pairStr
.
push_back
(
formatv
(
"{{{0}::{1}::{2}, normalize_enum(
\"
{2}
\"
)}"
,
className
,
enumName
,
i
));
}
os
<<
formatv
(
R"(
template<> std::unordered_map<{0}::{1}, std::string>
BitCombinedEnumWrapper<{0}::{1}>::type2str = {{
{2}
};
)"
,
className
,
enumName
,
llvm
::
join
(
pairStr
,
", "
));
body
+=
formatv
(
R"(
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}";
e_type.tp_basicsize = sizeof(BitCombinedEnumWrapper<{0}::{1}>);
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
e_type.tp_doc = "{0}.{1}";
e_type.tp_base = &PyBaseObject_Type;
e_type.tp_new = BitCombinedEnumWrapper<{0}::{1}>::py_new_combined_enum;
e_type.tp_init = BitCombinedEnumWrapper<{0}::{1}>::py_init;
e_type.tp_repr = BitCombinedEnumWrapper<{0}::{1}>::py_repr;
e_type.tp_richcompare = BitCombinedEnumWrapper<{0}::{1}>::tp_richcompare;
auto& number_method = BitCombinedEnumWrapper<{0}::{1}>::number_methods;
number_method.nb_or = BitCombinedEnumWrapper<{0}::{1}>::py_or;
number_method.nb_and = BitCombinedEnumWrapper<{0}::{1}>::py_and;
e_type.tp_as_number = &number_method;
mgb_assert(PyType_Ready(&e_type) >= 0);
)"
,
className
,
enumName
);
for
(
auto
&&
i
:
attr
->
getEnumMembers
())
{
body
+=
formatv
(
R"({{
PyObject* inst = e_type.tp_alloc(&e_type, 0);
reinterpret_cast<BitCombinedEnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2};
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0);
})"
,
className
,
enumName
,
i
);
}
}
enumAlias
.
emplace
(
enumID
,
std
::
make_pair
(
className
,
enumName
));
enumAlias
.
emplace
(
enumID
,
std
::
make_pair
(
className
,
enumName
));
}
}
...
@@ -486,8 +583,26 @@ EnumWrapper<{0}::{1}>::type2str = {{
...
@@ -486,8 +583,26 @@ EnumWrapper<{0}::{1}>::type2str = {{
PyType_Modified(&e_type);
PyType_Modified(&e_type);
mgb_assert(PyDict_SetItemString(
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0);
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0);
)"
,
enumName
);
)"
,
enumName
);
body
+=
"}
\n
"
;
body
+=
"}
\n
"
;
return
body
;
}
static
void
gen_op_def_python_c_extension_single
(
raw_ostream
&
os
,
MgbOp
&
op
,
EnumContext
&
ctx
)
{
auto
className
=
op
.
getCppClassName
();
std
::
string
body
;
// generate PyType for enum class member
for
(
auto
&&
i
:
op
.
getMgbAttributes
())
{
if
(
auto
attr
=
llvm
::
dyn_cast
<
MgbEnumAttr
>
(
&
i
.
attr
))
{
if
(
attr
->
getEnumCombinedFlag
())
{
body
+=
gen_op_def_python_c_extension_bit_combined_enum
(
os
,
ctx
,
attr
,
className
);
}
else
{
body
+=
gen_op_def_python_c_extension_enum
(
os
,
ctx
,
attr
,
className
);
}
}
}
}
}
...
...
sdk/load-and-run/src/mgblar.cpp
浏览文件 @
e19b9af1
...
@@ -141,15 +141,13 @@ R"__usage__(
...
@@ -141,15 +141,13 @@ R"__usage__(
)__usage__"
)__usage__"
#if MGB_ENABLE_FASTRUN
#if MGB_ENABLE_FASTRUN
R"__usage__(
R"__usage__(
--fast-run
--full-run
This param will be deperated later, please replace with param --full-profile.
Enable full-run mode. Operators with multiple algorithms would be profiled
--full-profile
Enable full-profile mode. Operators with multiple algorithms would be profiled
on the real device with actual input shapes, all algorithms will be profiled
on the real device with actual input shapes, all algorithms will be profiled
include naive algorithms.
include naive algorithms.
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details.
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details.
--fast-
profile
--fast-
run
Enable fast-
profile
mode. Operators with multiple algorithms would be profiled
Enable fast-
run
mode. Operators with multiple algorithms would be profiled
on the real device with actual input shapes, this mode will only profile the
on the real device with actual input shapes, this mode will only profile the
well optimized algorithms to get the profile result fast.
well optimized algorithms to get the profile result fast.
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details.
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details.
...
@@ -519,8 +517,8 @@ struct Args {
...
@@ -519,8 +517,8 @@ struct Args {
bool
disable_assert_throw
=
false
;
bool
disable_assert_throw
=
false
;
bool
share_param_mem
=
false
;
bool
share_param_mem
=
false
;
#if MGB_ENABLE_FASTRUN
#if MGB_ENABLE_FASTRUN
bool
use_full_
profile
=
false
;
bool
use_full_
run
=
false
;
bool
use_fast_
profile
=
false
;
bool
use_fast_
run
=
false
;
#endif
#endif
bool
reproducible
=
false
;
bool
reproducible
=
false
;
std
::
string
fast_run_cache_path
;
std
::
string
fast_run_cache_path
;
...
@@ -704,13 +702,13 @@ void run_test_st(Args &env) {
...
@@ -704,13 +702,13 @@ void run_test_st(Args &env) {
using
S
=
opr
::
mixin
::
AlgoChooserHelper
::
ExecutionPolicy
::
Strategy
;
using
S
=
opr
::
mixin
::
AlgoChooserHelper
::
ExecutionPolicy
::
Strategy
;
S
strategy
=
S
::
HEURISTIC
;
S
strategy
=
S
::
HEURISTIC
;
#if MGB_ENABLE_FASTRUN
#if MGB_ENABLE_FASTRUN
if
(
env
.
use_full_
profile
)
{
if
(
env
.
use_full_
run
)
{
if
(
env
.
reproducible
)
{
if
(
env
.
reproducible
)
{
strategy
=
S
::
PROFILE
|
S
::
REPRODUCIBLE
;
strategy
=
S
::
PROFILE
|
S
::
REPRODUCIBLE
;
}
else
{
}
else
{
strategy
=
S
::
PROFILE
;
strategy
=
S
::
PROFILE
;
}
}
}
else
if
(
env
.
use_fast_
profile
)
{
}
else
if
(
env
.
use_fast_
run
)
{
strategy
=
S
::
PROFILE
|
S
::
OPTMIZED
;
strategy
=
S
::
PROFILE
|
S
::
OPTMIZED
;
}
else
if
(
env
.
reproducible
)
{
}
else
if
(
env
.
reproducible
)
{
strategy
=
S
::
HEURISTIC
|
S
::
REPRODUCIBLE
;
strategy
=
S
::
HEURISTIC
|
S
::
REPRODUCIBLE
;
...
@@ -740,12 +738,12 @@ void run_test_st(Args &env) {
...
@@ -740,12 +738,12 @@ void run_test_st(Args &env) {
std
::
make_shared
<
InFilePersistentCache
>
(
buf
.
get
(),
flen
));
std
::
make_shared
<
InFilePersistentCache
>
(
buf
.
get
(),
flen
));
#if MGB_ENABLE_FASTRUN
#if MGB_ENABLE_FASTRUN
}
else
{
}
else
{
mgb_assert
(
env
.
use_full_
profile
||
env
.
use_fast_profile
,
mgb_assert
(
env
.
use_full_
run
||
env
.
use_fast_run
,
"fast-run or fast-
profile
should be enabled"
);
"fast-run or fast-
run
should be enabled"
);
PersistentCache
::
set_impl
(
PersistentCache
::
set_impl
(
std
::
make_shared
<
InFilePersistentCache
>
());
std
::
make_shared
<
InFilePersistentCache
>
());
}
}
if
(
!
env
.
use_full_
profile
&&
!
env
.
use_fast_profile
)
if
(
!
env
.
use_full_
run
&&
!
env
.
use_fast_run
)
#endif
#endif
mgb
::
gopt
::
enable_opr_use_profiling_cache_inplace
(
vars
);
mgb
::
gopt
::
enable_opr_use_profiling_cache_inplace
(
vars
);
}
}
...
@@ -1326,18 +1324,11 @@ Args Args::from_argv(int argc, char **argv) {
...
@@ -1326,18 +1324,11 @@ Args Args::from_argv(int argc, char **argv) {
}
}
#if MGB_ENABLE_FASTRUN
#if MGB_ENABLE_FASTRUN
if
(
!
strcmp
(
argv
[
i
],
"--fast-run"
))
{
if
(
!
strcmp
(
argv
[
i
],
"--fast-run"
))
{
mgb_log_warn
(
ret
.
use_fast_run
=
true
;
"--fast-run param will be deperated later, please replace "
"with --full-profile or --fast-profile."
);
ret
.
use_full_profile
=
true
;
continue
;
}
if
(
!
strcmp
(
argv
[
i
],
"--full-profile"
))
{
ret
.
use_full_profile
=
true
;
continue
;
continue
;
}
}
if
(
!
strcmp
(
argv
[
i
],
"--f
ast-profile
"
))
{
if
(
!
strcmp
(
argv
[
i
],
"--f
ull-run
"
))
{
ret
.
use_f
ast_profile
=
true
;
ret
.
use_f
ull_run
=
true
;
continue
;
continue
;
}
}
#endif
#endif
...
...
src/core/include/megbrain/common.h
浏览文件 @
e19b9af1
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
#pragma once
#pragma once
#include "megbrain_build_config.h"
#include "megbrain_build_config.h"
#include "megbrain/opr/param_defs.h"
#include "megdnn/basic_types.h"
#include "megdnn/basic_types.h"
#include <memory>
#include <memory>
...
@@ -250,10 +249,4 @@ inline constexpr std::size_t operator"" _z(unsigned long long n) {
...
@@ -250,10 +249,4 @@ inline constexpr std::size_t operator"" _z(unsigned long long n) {
}
// namespace mgb
}
// namespace mgb
namespace
megdnn
{
namespace
param
{
MGB_DEF_ENUM_CLASS_BIT_OPR
(
ExecutionPolicy
::
Strategy
)
}
}
// namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/core/include/megbrain/graph/operator_node.h
浏览文件 @
e19b9af1
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include "megbrain/utils/hashable.h"
#include "megbrain/utils/hashable.h"
#include "megbrain/utils/thin/hash_table.h"
#include "megbrain/utils/thin/hash_table.h"
#include "megbrain/utils/small_vector.h"
#include "megbrain/utils/small_vector.h"
#include "megbrain/opr/param_defs.h"
#include <type_traits>
#include <type_traits>
...
@@ -1043,4 +1044,10 @@ MGB_DEFINE_CLS_WITH_SUPER(_name final, _base ,##__VA_ARGS__) \
...
@@ -1043,4 +1044,10 @@ MGB_DEFINE_CLS_WITH_SUPER(_name final, _base ,##__VA_ARGS__) \
}
// namespace cg
}
// namespace cg
}
// namespace mgb
}
// namespace mgb
namespace
megdnn
{
namespace
param
{
MGB_DEF_ENUM_CLASS_BIT_OPR
(
ExecutionPolicy
::
Strategy
)
}
}
// namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/opr/impl/search_policy/algo_chooser.cpp
浏览文件 @
e19b9af1
...
@@ -278,6 +278,19 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
...
@@ -278,6 +278,19 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
return
ret
;
return
ret
;
}
}
//! Test whether the algo attribute of a algo match the require
//! algo_strategy
static
bool
algo_attribute_match_strategy
(
AlgoAttribute
attribute
,
ExecutionStrategy
selected_strategy
)
{
bool
ret
=
true
;
if
(
selected_strategy
&
ExecutionStrategy
::
OPTMIZED
)
{
ret
&=
(
!
static_cast
<
bool
>
(
AlgoAttribute
::
NAIVE
&
attribute
));
}
else
if
(
selected_strategy
&
ExecutionStrategy
::
REPRODUCIBLE
)
{
ret
&=
static_cast
<
bool
>
(
AlgoAttribute
::
REPRODUCIBLE
&
attribute
);
}
return
ret
;
}
}
// namespace
}
// namespace
namespace
mgb
{
namespace
mgb
{
...
@@ -285,8 +298,8 @@ namespace opr {
...
@@ -285,8 +298,8 @@ namespace opr {
template
<
typename
Opr
>
template
<
typename
Opr
>
void
AlgoChooser
<
Opr
>::
profile
(
ExeContext
&
ctx
,
void
AlgoChooser
<
Opr
>::
profile
(
ExeContext
&
ctx
,
ExecutionStrategy
select_strategy
)
{
ExecutionStrategy
select
ed
_strategy
)
{
if
(
ctx
.
get_profile_result_from_cache
(
select_strategy
).
valid
())
if
(
ctx
.
get_profile_result_from_cache
(
select
ed
_strategy
).
valid
())
return
;
return
;
AlgoChooserProfileCache
::
Result
prof_rst
;
AlgoChooserProfileCache
::
Result
prof_rst
;
...
@@ -306,9 +319,19 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
...
@@ -306,9 +319,19 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
algo
.
name
.
c_str
(),
str_on_inp_shape
.
c_str
());
algo
.
name
.
c_str
(),
str_on_inp_shape
.
c_str
());
ImplExecutionPolicy
policy
;
ImplExecutionPolicy
policy
;
policy
.
algo
=
algo
.
desc
;
policy
.
algo
=
algo
.
desc
;
ctx
.
construct_execution_policy
(
select_strategy
,
policy
);
ctx
.
construct_execution_policy
(
select
ed
_strategy
,
policy
);
if
(
ctx
.
get_workspace_size_bytes
(
policy
)
>=
workspace_limit
)
if
(
ctx
.
get_workspace_size_bytes
(
policy
)
>=
workspace_limit
)
{
continue
;
continue
;
}
auto
algo_attribute
=
ctx
.
megdnn_opr
()
->
get_algorithm_from_desc
(
policy
.
algo
)
->
attribute
();
if
(
!
algo_attribute_match_strategy
(
algo_attribute
,
selected_strategy
))
{
mgb_log_debug
(
"skip algo %s, which is not match the profile strategy."
,
algo
.
name
.
c_str
());
continue
;
}
timer
.
reset
();
timer
.
reset
();
MGB_TRY
{
cur_rst
=
ctx
.
profile_single_algo
(
policy
,
cur_timeout
);
}
MGB_TRY
{
cur_rst
=
ctx
.
profile_single_algo
(
policy
,
cur_timeout
);
}
...
@@ -356,7 +379,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
...
@@ -356,7 +379,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
template
<
typename
Opr
>
template
<
typename
Opr
>
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
AlgoChooser
<
Opr
>::
choose_by_profile
(
ExeContext
&
ctx
,
AlgoChooser
<
Opr
>::
choose_by_profile
(
ExeContext
&
ctx
,
ExecutionStrategy
select_strategy
,
ExecutionStrategy
select
ed
_strategy
,
bool
enable_update
)
{
bool
enable_update
)
{
MIDOUT_B
(
Opr
,
midout_iv
(
MGB_HASH_STR
(
"AlgoChooser::choose_by_profile"
)))
MIDOUT_B
(
Opr
,
midout_iv
(
MGB_HASH_STR
(
"AlgoChooser::choose_by_profile"
)))
if
(
ctx
.
owner_graph
()
->
options
().
no_profiling_on_shape_change
)
{
if
(
ctx
.
owner_graph
()
->
options
().
no_profiling_on_shape_change
)
{
...
@@ -378,11 +401,11 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx,
...
@@ -378,11 +401,11 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx,
to_fixed_layouts
<
_Opr
>
(
_item
.
layouts
),
megdnn_opr
.
get
(),
to_fixed_layouts
<
_Opr
>
(
_item
.
layouts
),
megdnn_opr
.
get
(),
_item
.
param
,
ctx
.
mgb_opr
(),
ctx
.
comp_node
(),
_item
.
param
,
ctx
.
mgb_opr
(),
ctx
.
comp_node
(),
ctx
.
execution_policy
(),
ctx
.
allow_weight_preprocess
());
ctx
.
execution_policy
(),
ctx
.
allow_weight_preprocess
());
AlgoChooser
<
_Opr
>::
profile
(
sub_ctx
,
select_strategy
);
AlgoChooser
<
_Opr
>::
profile
(
sub_ctx
,
select
ed
_strategy
);
});
});
}
}
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
policy
;
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
policy
;
ctx
.
construct_execution_policy
(
select_strategy
,
policy
);
ctx
.
construct_execution_policy
(
select
ed
_strategy
,
policy
);
return
policy
;
return
policy
;
MIDOUT_E
MIDOUT_E
}
}
...
@@ -440,7 +463,8 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy(
...
@@ -440,7 +463,8 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy(
if
(
!
policy
.
algo
.
valid
())
if
(
!
policy
.
algo
.
valid
())
policy
=
ctx
.
choose_by_heuristic
(
opr_strategy
);
policy
=
ctx
.
choose_by_heuristic
(
opr_strategy
);
return
policy
;
return
policy
;
}
else
if
((
opr_strategy
&
ExecutionStrategy
::
HEURISTIC
))
{
}
else
if
(
!
static_cast
<
int
>
(
opr_strategy
)
||
(
opr_strategy
&
ExecutionStrategy
::
HEURISTIC
))
{
return
ctx
.
choose_by_heuristic
(
opr_strategy
);
return
ctx
.
choose_by_heuristic
(
opr_strategy
);
}
}
#if MGB_ENABLE_FASTRUN
#if MGB_ENABLE_FASTRUN
...
@@ -449,7 +473,7 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy(
...
@@ -449,7 +473,7 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy(
}
}
#endif
#endif
else
{
else
{
mgb_throw
(
GraphError
,
"bad
convolution
ExecutionPolicy strategy"
);
mgb_throw
(
GraphError
,
"bad ExecutionPolicy strategy"
);
}
}
}
}
...
@@ -495,7 +519,7 @@ AlgoChooser<Opr>::ExeContext::ExeContext(
...
@@ -495,7 +519,7 @@ AlgoChooser<Opr>::ExeContext::ExeContext(
template
<
typename
Opr
>
template
<
typename
Opr
>
typename
AlgoChooser
<
Opr
>::
ImplAlgo
typename
AlgoChooser
<
Opr
>::
ImplAlgo
AlgoChooser
<
Opr
>::
ExeContext
::
get_profile_result_from_cache
(
AlgoChooser
<
Opr
>::
ExeContext
::
get_profile_result_from_cache
(
ExecutionStrategy
select_strategy
)
const
{
ExecutionStrategy
select
ed
_strategy
)
const
{
MIDOUT_B
(
Opr
,
MIDOUT_B
(
Opr
,
midout_iv
(
MGB_HASH_STR
(
midout_iv
(
MGB_HASH_STR
(
"AlgoChooser::ExeContext::get_profile_result_from_cache"
)))
"AlgoChooser::ExeContext::get_profile_result_from_cache"
)))
...
@@ -519,7 +543,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
...
@@ -519,7 +543,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
if
(
prof
.
empty
())
if
(
prof
.
empty
())
return
{};
return
{};
for
(
auto
&&
i
:
prof
)
{
for
(
auto
&&
i
:
prof
)
{
if
(
!
(
select_strategy
&
ExecutionStrategy
::
REPRODUCIBLE
)
||
if
(
!
(
select
ed
_strategy
&
ExecutionStrategy
::
REPRODUCIBLE
)
||
static_cast
<
AlgoAttribute
>
(
i
.
attribute
)
&
static_cast
<
AlgoAttribute
>
(
i
.
attribute
)
&
AlgoAttribute
::
REPRODUCIBLE
)
{
AlgoAttribute
::
REPRODUCIBLE
)
{
auto
iter
=
algo_map
.
find
(
i
.
algo
);
auto
iter
=
algo_map
.
find
(
i
.
algo
);
...
@@ -550,7 +574,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
...
@@ -550,7 +574,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
template
<
typename
Opr
>
template
<
typename
Opr
>
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
AlgoChooser
<
Opr
>::
ExeContext
::
choose_by_heuristic
(
AlgoChooser
<
Opr
>::
ExeContext
::
choose_by_heuristic
(
ExecutionStrategy
select_strategy
)
const
{
ExecutionStrategy
select
ed
_strategy
)
const
{
if
(
m_execution_policy
.
workspace_limit
!=
if
(
m_execution_policy
.
workspace_limit
!=
std
::
numeric_limits
<
decltype
(
std
::
numeric_limits
<
decltype
(
m_execution_policy
.
workspace_limit
)
>::
max
())
{
m_execution_policy
.
workspace_limit
)
>::
max
())
{
...
@@ -558,7 +582,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
...
@@ -558,7 +582,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
"workspace_limit should not be setted if choose algo by "
"workspace_limit should not be setted if choose algo by "
"heuristic"
);
"heuristic"
);
}
}
bool
reproducible
=
static_cast
<
bool
>
(
select_strategy
&
bool
reproducible
=
static_cast
<
bool
>
(
select
ed
_strategy
&
ExecutionStrategy
::
REPRODUCIBLE
);
ExecutionStrategy
::
REPRODUCIBLE
);
auto
workspace_limit
=
WorkspaceLimitGetter
::
get_workspace_limit
(
auto
workspace_limit
=
WorkspaceLimitGetter
::
get_workspace_limit
(
owner_graph
(),
m_cn
,
m_execution_policy
.
workspace_limit
);
owner_graph
(),
m_cn
,
m_execution_policy
.
workspace_limit
);
...
@@ -582,7 +606,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
...
@@ -582,7 +606,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
_item
.
param
,
m_base_mgb_opr
,
m_cn
,
m_execution_policy
,
_item
.
param
,
m_base_mgb_opr
,
m_cn
,
m_execution_policy
,
m_allow_weight_preprocess
);
m_allow_weight_preprocess
);
policy
.
sub_policy
.
push_back
(
policy
.
sub_policy
.
push_back
(
sub_ctx
.
choose_by_heuristic
(
select_strategy
));
sub_ctx
.
choose_by_heuristic
(
select
ed
_strategy
));
});
});
return
policy
;
return
policy
;
...
@@ -613,15 +637,15 @@ AlgoChooser<Opr>::ExeContext::get_all_candidates() const {
...
@@ -613,15 +637,15 @@ AlgoChooser<Opr>::ExeContext::get_all_candidates() const {
template
<
typename
Opr
>
template
<
typename
Opr
>
void
AlgoChooser
<
Opr
>::
ExeContext
::
construct_execution_policy
(
void
AlgoChooser
<
Opr
>::
ExeContext
::
construct_execution_policy
(
ExecutionStrategy
select_strategy
,
ExecutionStrategy
select
ed
_strategy
,
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
&
policy
,
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
&
policy
,
bool
retrive_from_cache
)
const
{
bool
retrive_from_cache
)
const
{
bool
reproducible
=
static_cast
<
bool
>
(
select_strategy
&
bool
reproducible
=
static_cast
<
bool
>
(
select
ed
_strategy
&
ExecutionStrategy
::
REPRODUCIBLE
);
ExecutionStrategy
::
REPRODUCIBLE
);
if
(
!
policy
.
algo
.
valid
())
{
if
(
!
policy
.
algo
.
valid
())
{
if
(
retrive_from_cache
)
{
if
(
retrive_from_cache
)
{
policy
.
algo
=
policy
.
algo
=
get_profile_result_from_cache
(
select_strategy
).
desc
;
get_profile_result_from_cache
(
select
ed
_strategy
).
desc
;
}
else
{
}
else
{
auto
workspace_limit
=
WorkspaceLimitGetter
::
get_workspace_limit
(
auto
workspace_limit
=
WorkspaceLimitGetter
::
get_workspace_limit
(
owner_graph
(),
m_cn
,
m_execution_policy
.
workspace_limit
);
owner_graph
(),
m_cn
,
m_execution_policy
.
workspace_limit
);
...
@@ -651,7 +675,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
...
@@ -651,7 +675,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
_item
.
param
,
m_base_mgb_opr
,
m_cn
,
m_execution_policy
,
_item
.
param
,
m_base_mgb_opr
,
m_cn
,
m_execution_policy
,
m_allow_weight_preprocess
);
m_allow_weight_preprocess
);
policy
.
sub_policy
.
push_back
({});
policy
.
sub_policy
.
push_back
({});
sub_ctx
.
construct_execution_policy
(
select_strategy
,
sub_ctx
.
construct_execution_policy
(
select
ed
_strategy
,
policy
.
sub_policy
.
back
(),
policy
.
sub_policy
.
back
(),
retrive_from_cache
);
retrive_from_cache
);
});
});
...
...
src/opr/include/megbrain/opr/search_policy/algo_chooser.h
浏览文件 @
e19b9af1
...
@@ -110,7 +110,7 @@ public:
...
@@ -110,7 +110,7 @@ public:
const
FixedTensorLayouts
&
layouts
()
const
{
return
m_layouts
;
}
const
FixedTensorLayouts
&
layouts
()
const
{
return
m_layouts
;
}
ImplExecutionPolicy
choose_by_heuristic
(
ImplExecutionPolicy
choose_by_heuristic
(
ExecutionStrategy
select_strategy
)
const
;
ExecutionStrategy
select
ed
_strategy
)
const
;
//! get all candidate algos, and the one choose_by_heuristic() is
//! get all candidate algos, and the one choose_by_heuristic() is
//! put first
//! put first
...
@@ -134,17 +134,17 @@ public:
...
@@ -134,17 +134,17 @@ public:
//! get all profile algorithm from cache, return invalid if not exists
//! get all profile algorithm from cache, return invalid if not exists
ImplAlgo
get_profile_result_from_cache
(
ImplAlgo
get_profile_result_from_cache
(
ExecutionStrategy
select_strategy
)
const
;
ExecutionStrategy
select
ed
_strategy
)
const
;
/**
/**
* \brief construct execution policy from cache or heuristic.
* \brief construct execution policy from cache or heuristic.
*
*
* \param select_strategy select algo which matched this strategy
* \param select
ed
_strategy select algo which matched this strategy
* \param policy execution policy
* \param policy execution policy
* \param retrive_from_cache retrive algo from cache if set True, get
* \param retrive_from_cache retrive algo from cache if set True, get
* from heuristic otherwise.
* from heuristic otherwise.
*/
*/
void
construct_execution_policy
(
ExecutionStrategy
select_strategy
,
void
construct_execution_policy
(
ExecutionStrategy
select
ed
_strategy
,
ImplExecutionPolicy
&
policy
,
ImplExecutionPolicy
&
policy
,
bool
retrive_from_cache
=
true
)
const
;
bool
retrive_from_cache
=
true
)
const
;
...
@@ -161,10 +161,10 @@ private:
...
@@ -161,10 +161,10 @@ private:
//! profile and save to cache
//! profile and save to cache
static
void
profile
(
ExeContext
&
ctx
,
ExecutionStrategy
select_strategy
);
static
void
profile
(
ExeContext
&
ctx
,
ExecutionStrategy
select
ed
_strategy
);
static
ImplExecutionPolicy
choose_by_profile
(
static
ImplExecutionPolicy
choose_by_profile
(
ExeContext
&
ctx
,
ExecutionStrategy
select_strategy
,
ExeContext
&
ctx
,
ExecutionStrategy
select
ed
_strategy
,
bool
enable_update
=
true
);
bool
enable_update
=
true
);
public:
public:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录