Commit 80c47053
Authored on Jul 21, 2020 by Megvii Engine Team
perf(mgb): use midout in megbrain to reduce binary size
GitOrigin-RevId: ddc8af79af90737cb6f55ae5c1fc95cba722eef9
Parent: 35c71276
Showing 32 changed files with 361 additions and 25 deletions (+361, -25):
src/CMakeLists.txt                      +1   -0
src/gopt/impl/basic_arith/chain.cpp     +22  -1
src/gopt/impl/basic_arith/inplace.cpp   +12  -0
src/gopt/impl/basic_arith/trans.cpp     +18  -0
src/gopt/impl/inference.cpp             +31  -0
src/gopt/impl/misc.cpp                  +28  -0
src/gopt/impl/tensor_reformat.cpp       +25  -1
src/gopt/impl/weights_preprocess.cpp    +12  -0
src/opr-mm/impl/collective_comm.cpp     +2   -0
src/opr-mm/impl/io_remote.cpp           +2   -0
src/opr/impl/basic_arith.cpp            +10  -1
src/opr/impl/blas.cpp                   +13  -0
src/opr/impl/cond.cpp                   +4   -0
src/opr/impl/dnn/batch_norm.cpp         +2   -0
src/opr/impl/dnn/convolution.cpp        +51  -12
src/opr/impl/dnn/images2neibs.cpp       +2   -0
src/opr/impl/dnn/local.cpp              +6   -0
src/opr/impl/dnn/lrn.cpp                +2   -0
src/opr/impl/dnn/pooling.cpp            +2   -0
src/opr/impl/dnn/roi_align.cpp          +2   -0
src/opr/impl/dnn/roi_pooling.cpp        +4   -0
src/opr/impl/imgproc.cpp                +4   -0
src/opr/impl/indexing.cpp               +26  -0
src/opr/impl/io.cpp                     +2   -0
src/opr/impl/loop/forward.cpp           +2   -0
src/opr/impl/misc.cpp                   +12  -1
src/opr/impl/muxing.cpp                 +2   -0
src/opr/impl/rand.cpp                   +9   -7
src/opr/impl/tensor_gen.cpp             +6   -1
src/opr/impl/tensor_manip.cpp           +22  -1
src/opr/impl/utility.cpp                +12  -0
src/plugin/impl/opr_footprint.cpp       +13  -0
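Before the per-file hunks, a minimal sketch of the pattern they all follow, assembled only from the macro definitions visible below; the tag namespace megbrain_example and the pass DummyPass are hypothetical stand-ins. Each wrapped body becomes a named midout region keyed by a compile-time string hash, which is what lets a later trace-guided build identify and trim regions that were never entered.

    // Per-file boilerplate: declare a midout tag namespace and two wrapper
    // macros. "megbrain_example" is illustrative; each real file below
    // declares its own (megbrain_chain, megbrain_inference, ...).
    #include "megbrain/utils/hash_ct.h"
    #include "midout.h"

    MIDOUT_DECL(megbrain_example)
    #define MIDOUT_B(tag) \
        MIDOUT_BEGIN(megbrain_example, midout_iv(MGB_HASH_STR(tag))) {
    #define MIDOUT_E \
        } \
        MIDOUT_END();

    // Each pass entry point is then bracketed by the pair. DummyPass is a
    // hypothetical stand-in for the passes instrumented in the hunks below.
    void DummyPass::apply(OptState& opt) const {
        MIDOUT_B("DummyPass::apply")
        // ... original pass body, unchanged ...
        MIDOUT_E
    }

Because the tag is hashed at compile time via MGB_HASH_STR, the region key adds no runtime string handling to the instrumented code.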
src/CMakeLists.txt
@@ -43,6 +43,7 @@ add_library(megbrain OBJECT EXCLUDE_FROM_ALL ${SOURCES})
 target_link_libraries(megbrain PUBLIC mgb_opr_param_defs)
 target_include_directories(megbrain
     PUBLIC $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>
+    PRIVATE ${PROJECT_SOURCE_DIR}/third_party/midout/src
 )
 foreach(INCPATH IN LISTS MGB_INC)
     target_include_directories(megbrain
src/gopt/impl/basic_arith/chain.cpp
@@ -15,6 +15,20 @@
 #include <deque>
 
+//! TODO: some megdnn oprs need to be known here before midout.h is included;
+//! fix this if there is a more graceful way.
+#include "megdnn/oprs.h"
+#include "megbrain/utils/hash_ct.h"
+#include "midout.h"
+
+MIDOUT_DECL(megbrain_chain)
+#define MIDOUT_B(tag) \
+    MIDOUT_BEGIN(megbrain_chain, midout_iv(MGB_HASH_STR(tag))) {
+#define MIDOUT_E \
+    } \
+    MIDOUT_END();
+
 using namespace mgb;
 using namespace gopt;
 using namespace opr;

@@ -132,6 +146,7 @@ const char* ExpandFusedArithPass::name() const {
 }
 
 void ExpandFusedArithPass::apply(OptState &opt) const {
+    MIDOUT_B("ExpandFusedArithPass::apply")
     auto rewriter = opt.graph().make_rewriter();
     auto on_opr = [&](OperatorNodeBase *opr) {
         using Mode = Elemwise::Mode;

@@ -172,6 +187,7 @@ void ExpandFusedArithPass::apply(OptState &opt) const {
     opt.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 
 /* ================ NormalizeArithChainPass ================ */

@@ -529,7 +545,9 @@ const char* NormalizeArithChainPass::name() const {
 }
 
 void NormalizeArithChainPass::apply(OptState &opt) const {
+    MIDOUT_B("NormalizeArithChainPass::apply")
     Impl{opt};
+    MIDOUT_E
 }
 
 /* ================ ReorderArithChainPass ================ */

@@ -737,7 +755,9 @@ const char* ReorderArithChainPass::name() const {
 }
 
 void ReorderArithChainPass::apply(OptState &opt) const {
+    MIDOUT_B("ReorderArithChainPass::apply")
     Impl{*this, opt};
+    MIDOUT_E
 }
 
 /* ================ ArithFusePass ================ */

@@ -944,8 +964,9 @@ const char* ArithFusePass::name() const {
 }
 
 void ArithFusePass::apply(OptState &opt) const {
+    MIDOUT_B("ArithFusePass::apply")
     Impl{opt};
+    MIDOUT_E
 }
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/gopt/impl/basic_arith/inplace.cpp
@@ -19,6 +19,16 @@
 #include <cmath>
 
+#include "megbrain/utils/hash_ct.h"
+#include "midout.h"
+
+MIDOUT_DECL(megbrain_inplace)
+#define MIDOUT_B(tag) \
+    MIDOUT_BEGIN(megbrain_inplace, midout_iv(MGB_HASH_STR(tag))) {
+#define MIDOUT_E \
+    } \
+    MIDOUT_END();
+
 using namespace mgb;
 using namespace opr;
 using namespace gopt;

@@ -150,8 +160,10 @@ bool gopt::has_inplace_basic_arith_opt(const cg::OperatorNodeBase& opr) {
 
 const inplace_optimize::OptimizerRegistry&
 inplace_optimize::optimizer_registry() {
+    MIDOUT_B("inplace_optimize::optimizer_registry")
     static OptimizerRegistry ret = make_optimizer_registry();
     return ret;
+    MIDOUT_E
 }
 
 inplace_optimize::OptimizerRegistry
src/gopt/impl/basic_arith/trans.cpp
@@ -13,6 +13,20 @@
 #include "megbrain/gopt/basic_arith.h"
 #include "megbrain/serialization/serializer.h"
 
+//! TODO: some megdnn oprs need to be known here before midout.h is included;
+//! fix this if there is a more graceful way.
+#include "megdnn/oprs.h"
+#include "megbrain/utils/hash_ct.h"
+#include "midout.h"
+
+MIDOUT_DECL(megbrain_trans)
+#define MIDOUT_B(tag) \
+    MIDOUT_BEGIN(megbrain_trans, midout_iv(MGB_HASH_STR(tag))) {
+#define MIDOUT_E \
+    } \
+    MIDOUT_END();
+
 using namespace mgb;
 using namespace gopt;

@@ -284,7 +298,9 @@ const char* ArithMulDistributePass::name() const {
 }
 
 void ArithMulDistributePass::apply(OptState &opt) const {
+    MIDOUT_B("ArithMulDistributePass::apply")
     Impl{*this, opt};
+    MIDOUT_E
 }
 
 /* ================ FinalArithTransformPass ================ */

@@ -488,7 +504,9 @@ const char* FinalArithTransformPass::name() const {
 }
 
 void FinalArithTransformPass::apply(OptState &opt) const {
+    MIDOUT_B("FinalArithTransformPass::apply")
     Impl{*this, opt};
+    MIDOUT_E
 }
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/gopt/impl/inference.cpp
@@ -27,6 +27,7 @@
 #include "megbrain/opr/imgproc.h"
 #include "megbrain/opr/nn_int.h"
 #include "megbrain/opr/tensor_gen.h"
+#include "megbrain/utils/hash_ct.h"
 #include "megdnn/tensor_format.h"

@@ -36,6 +37,16 @@
 #include "megbrain/gopt/misc.h"
 #include "megbrain/utils/hash_ct.h"
+#include "midout.h"
+
+MIDOUT_DECL(megbrain_inference)
+#define MIDOUT_B(tag) \
+    MIDOUT_BEGIN(megbrain_inference, midout_iv(MGB_HASH_STR(tag))) {
+#define MIDOUT_E \
+    } \
+    MIDOUT_END();
+
 using namespace mgb;
 using namespace gopt;

@@ -430,7 +441,9 @@ ParamRedistributePass::Impl::Impl(OptState &state):
 }
 
 void ParamRedistributePass::apply(OptState &state) const {
+    MIDOUT_B("ParamRedistributePass::apply")
     Impl{state};
+    MIDOUT_E
 }
 
 /* ================ ParamFusePass ================ */

@@ -512,6 +525,7 @@ const char* ParamFusePass::name() const {
 }
 
 void ParamFusePass::apply(OptState &state) const {
+    MIDOUT_B("ParamFusePass::apply")
     auto rewriter = state.graph().make_rewriter();
     auto cg = state.graph().comp_graph();

@@ -613,6 +627,7 @@ void ParamFusePass::apply(OptState &state) const {
     state.graph().iter(replace_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 
 /* ================ One2OneOprReplacePass ================ */

@@ -621,6 +636,7 @@ const char* ConvertF32ToF16Pass::name() const {
 }
 
 void ConvertF32ToF16Pass::apply(OptState& state) const {
+    MIDOUT_B("ConvertF32ToF16Pass::apply")
     state.set_var_replace_check_flag(m_var_replace_check_flag);
     auto rewriter = state.graph().make_rewriter();
     VarNodeArray new_inp_cache;

@@ -674,6 +690,7 @@ void ConvertF32ToF16Pass::apply(OptState& state) const {
     auto opr = endpoints[0].node()->owner_opr();
     state.call_with_opr(opr, replace_output, OprPropertyFlag::NONE);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 
 std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(

@@ -940,6 +957,7 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
 /* ================ ConvertFormatPass ================ */
 
 void ConvertFormatPass::apply(OptState& state) const {
+    MIDOUT_B("ConvertFormatPass::apply")
     state.set_var_replace_check_flag(m_var_replace_check_flag);
     auto rewriter = state.graph().make_rewriter();
     VarNodeArray new_inp_cache;

@@ -994,9 +1012,11 @@ void ConvertFormatPass::apply(OptState& state) const {
     };
     state.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 
 std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
+    MIDOUT_B("ConvertFormatPass::make")
     auto filter_mode = [](const megdnn::param::Convolution::Sparse conv_mode,
                           const VarNode* filter)
             -> megdnn::param::RelayoutFormat::Mode {

@@ -1551,6 +1571,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
     replace_func[opr::GroupLocalForward::typeinfo()] = relayout_first_inp_to_chw;
     return ret;
+    MIDOUT_E
 }
 
 /* ================ ConvertBatchNormPass ================ */

@@ -1559,6 +1580,7 @@ const char* ConvertBatchNormToElemwisePass::name() const {
 }
 
 void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
+    MIDOUT_B("ConvertBatchNormToElemwisePass::apply")
     auto rewriter = state.graph().make_rewriter();
     auto on_opr = [&](OperatorNodeBase* opr) {
         if (auto bn = try_cast_as_op<opr::BatchNorm>(opr)) {

@@ -1586,6 +1608,7 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
     state.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 
 /* ================ FuseConvBiasNonlinPass ================ */

@@ -1594,6 +1617,7 @@ const char* FuseConvBiasNonlinPass::name() const {
 }
 
 void FuseConvBiasNonlinPass::apply(OptState& state) const {
+    MIDOUT_B("FuseConvBiasNonlinPass::apply")
     std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
     state.graph().iter([&m_deps](OperatorNodeBase* opr) {
         for (auto& inp : opr->input()) {

@@ -1843,6 +1867,7 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
     state.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 
 /* ================ FuseConvBiasZPass ================ */

@@ -1851,6 +1876,7 @@ const char* FuseConvBiasZPass::name() const {
 }
 
 void FuseConvBiasZPass::apply(OptState& state) const {
+    MIDOUT_B("FuseConvBiasZPass::apply")
     UniqReaderCheck uniq_reader_check{state.graph()};
     auto rewriter = state.graph().make_rewriter();

@@ -1977,6 +2003,7 @@ void FuseConvBiasZPass::apply(OptState& state) const {
     state.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 
 /* ================ FuseDeconvCvtPass ================ */

@@ -1986,6 +2013,7 @@ const char* FuseDeconvCvtPass::name() const {
 
 void FuseDeconvCvtPass::apply(OptState& state) const {
+    MIDOUT_B("FuseDeconvCvtPass::apply")
     std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
     state.graph().iter([&m_deps](OperatorNodeBase* opr) {
         for (auto& inp : opr->input()) {

@@ -2036,6 +2064,7 @@ void FuseDeconvCvtPass::apply(OptState& state) const {
     state.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 
 /* ================ ParamMergePass ================ */

@@ -2044,10 +2073,12 @@ const char* ParamMergePass::name() const {
 }
 
 void ParamMergePass::apply(OptState& opt_state) const {
+    MIDOUT_B("ParamMergePass::apply")
     param_merge<opr::SharedDeviceTensor, opr::MultipleDeviceTensorHolder>(
             opt_state);
     param_merge<opr::SharedDeviceTensorWithFormat,
                 opr::MultipleDeviceTensorWithFormatHolder>(opt_state);
+    MIDOUT_E
 }
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
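One detail worth noting from the hunks above: the tags are free-form strings rather than function names, so ConvertFormatPass::make_nhwcd4_converter registers under the shorter tag "ConvertFormatPass::make". A reduced sketch of a factory instrumented the same way (SomePass and its members are hypothetical); note that `return ret;` sits inside the region, exactly as the diff does it:

    std::unique_ptr<SomePass> SomePass::make_converter() {
        MIDOUT_B("SomePass::make")   // tag need not match the function name
        auto ret = std::make_unique<SomePass>();
        // ... populate replace rules ...
        return ret;                  // returning from inside the region is fine
        MIDOUT_E
    }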
src/gopt/impl/misc.cpp
@@ -19,6 +19,16 @@
 #include "megbrain/serialization/opr_shallow_copy.h"
 #include "../../core/impl/graph/cg_impl.h"
 
+#include "megbrain/utils/hash_ct.h"
+#include "midout.h"
+
+MIDOUT_DECL(megbrain_misc)
+#define MIDOUT_B(tag) \
+    MIDOUT_BEGIN(megbrain_misc, midout_iv(MGB_HASH_STR(tag))) {
+#define MIDOUT_E \
+    } \
+    MIDOUT_END();
+
 using namespace mgb;
 using namespace gopt;

@@ -29,6 +39,7 @@ const char* RemoveNonComputingOprPass::name() const {
 }
 
 void RemoveNonComputingOprPass::apply(OptState& opt) const {
+    MIDOUT_B("RemoveNonComputingOprPass::apply")
     auto rewriter = opt.graph().make_rewriter();
     auto on_opr = [&](OperatorNodeBase* opr) {
         auto type = opr->dyn_typeinfo();

@@ -75,6 +86,7 @@ void RemoveNonComputingOprPass::apply(OptState& opt) const {
     opt.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 
 /* ================ ExpandVirtualGradPass ================ */

@@ -84,6 +96,7 @@ const char* ExpandVirtualGradPass::name() const {
 }
 
 void ExpandVirtualGradPass::apply(OptState& opt) const {
+    MIDOUT_B("ExpandVirtualGradPass::apply")
 #if MGB_ENABLE_GRAD
     opt.set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     auto rewriter = opt.graph().make_rewriter();

@@ -111,6 +124,7 @@ void ExpandVirtualGradPass::apply(OptState& opt) const {
 #else
     MGB_MARK_USED_VAR(opt);
 #endif
+    MIDOUT_E
 }
 
 /* ================= DelayBroadcastPass ================ */

@@ -144,6 +158,7 @@ void DelayBroadcastPass::apply(OptState& opt) const {
     // remove them from the chain, and add them back right after the endpoint.
     // TypeCvt's order may change, so disable the check.
+    MIDOUT_B("DelayBroadcastPass::apply")
     opt.set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     auto unique_reader_chk = UniqReaderCheck{opt.graph()};

@@ -325,6 +340,7 @@ void DelayBroadcastPass::apply(OptState& opt) const {
     opt.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 
 /* ======================= RecompTypeCvtPass ====================== */

@@ -334,6 +350,7 @@ const char* RecompTypeCvtPass::name() const {
 }
 
 void RecompTypeCvtPass::apply(OptState& opt) const {
+    MIDOUT_B("RecompTypeCvtPass::apply")
     auto rewriter = opt.graph().make_rewriter();
     auto allowed_typecvt = [](OperatorNodeBase* opr) -> OperatorNodeBase* {

@@ -399,6 +416,7 @@ void RecompTypeCvtPass::apply(OptState& opt) const {
     };
     opt.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 
 /* ======================= CombineAstypeAndReducePass ====================== */

@@ -408,6 +426,7 @@ const char* CombineAstypeAndReducePass::name() const {
 }
 
 void CombineAstypeAndReducePass::apply(OptState& opt) const {
+    MIDOUT_B("CombineAstypeAndReducePass::apply")
     auto rewriter = opt.graph().make_rewriter();
     using DataType = opr::Reduce::Param::DataType;

@@ -453,6 +472,7 @@ void CombineAstypeAndReducePass::apply(OptState& opt) const {
     opt.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 
 /* ================ CondExecConstPredicateFolding ================ */

@@ -462,6 +482,7 @@ const char* CondExecConstPredicateFolding::name() const {
 
 void CondExecConstPredicateFolding::apply(OptState& opt) const {
 #if MGB_ENABLE_COND_EXEC
+    MIDOUT_B("CondExecConstPredicateFolding::apply")
     if (!cg::ExecutionMask::have_alive_instance()) {
         return;
     }

@@ -605,6 +626,7 @@ void CondExecConstPredicateFolding::apply(OptState& opt) const {
     }
     rewriter.apply_inplace();
+    MIDOUT_E
 #endif  // MGB_ENABLE_COND_EXEC
 }

@@ -632,6 +654,7 @@ bool RemoveRedundantTypeCvtPass::should_remove(DType A, DType B) {
 }
 
 void RemoveRedundantTypeCvtPass::apply(OptState& opt) const {
+    MIDOUT_B("RemoveRedundantTypeCvtPass::apply")
     auto rewriter = opt.graph().make_rewriter();
     auto on_opr = [&](OperatorNodeBase* opr) {

@@ -656,6 +679,7 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const {
     opt.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 
 #if MGB_ENABLE_OPR_MM

@@ -668,6 +692,7 @@ const char* PackAllReduceScanPass::name() const {
 }
 
 void PackAllReduceScanPass::apply(OptState& opt) const {
+    MIDOUT_B("PackAllReduceScanPass::apply")
     auto comp_graph = opt.graph().comp_graph();
     if (comp_graph->options().allreduce_pack_max_size == 0)
         return;
     auto cb_scan = [this](OperatorNodeBase* opr) {

@@ -682,6 +707,7 @@ void PackAllReduceScanPass::apply(OptState& opt) const {
         }
     };
     opt.graph().iter(cb_scan);
+    MIDOUT_E
 }
 
 bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) {

@@ -856,6 +882,7 @@ void PackAllReduceReplacePass::insert_packed_oprs(
 }
 
 void PackAllReduceReplacePass::apply(OptState& opt) const {
+    MIDOUT_B("PackAllReduceReplacePass::apply")
     // get graph options
     auto comp_graph = opt.graph().comp_graph();
     size_t max_size =
             comp_graph->options().allreduce_pack_max_size * 1024 * 1024;

@@ -917,6 +944,7 @@ void PackAllReduceReplacePass::apply(OptState& opt) const {
     };
     opt.graph().iter(cb_replace);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 
 #else
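Two placements of the markers coexist in misc.cpp above: ExpandVirtualGradPass::apply opens its region before the #if MGB_ENABLE_GRAD guard, so the region always exists, while CondExecConstPredicateFolding::apply opens it inside #if MGB_ENABLE_COND_EXEC, so builds without the feature contain neither the body nor the region bookkeeping. A reduced sketch of the second placement (GuardedPass is a hypothetical name):

    void GuardedPass::apply(OptState& opt) const {
    #if MGB_ENABLE_COND_EXEC
        MIDOUT_B("GuardedPass::apply")  // region exists only when the feature does
        // ... pass body ...
        MIDOUT_E
    #endif
    }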
src/gopt/impl/tensor_reformat.cpp
@@ -36,6 +36,16 @@
 #endif
 #include "megbrain/gopt/misc.h"
 
+#include "megbrain/utils/hash_ct.h"
+#include "midout.h"
+
+MIDOUT_DECL(megbrain_tensor_reformat)
+#define MIDOUT_B(tag) \
+    MIDOUT_BEGIN(megbrain_tensor_reformat, midout_iv(MGB_HASH_STR(tag))) {
+#define MIDOUT_E \
+    } \
+    MIDOUT_END();
+
 using namespace mgb;
 using namespace gopt;

@@ -755,8 +765,10 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
 }
 
 void TensorReformatPass::apply(OptState& opt) const {
+    MIDOUT_B("TensorReformatPass::apply")
     insert_pass(opt);
     translate_pass(opt);
+    MIDOUT_E
 }
 
 /* ================ EnableTensorCorePass =============== */

@@ -773,6 +785,7 @@ VarNode* EnableTensorCorePass::on_graph_endpoint_var(VarNode* new_var,
 std::unique_ptr<EnableTensorCorePass>
 EnableTensorCorePass::make_tensorcore_converter() {
+    MIDOUT_B("EnableTensorCorePass::make")
     // replace rule for conv bias opr
     auto replace_conv_bias_opr = [](OperatorNodeBase* opr,
                                     const VarNodeArray& new_inp) {

@@ -1111,6 +1124,7 @@ EnableTensorCorePass::make_tensorcore_converter() {
     replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw4;
     replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4;
     return ret;
+    MIDOUT_E
 }
 
 /* ================ EnableCHWN4Pass =============== */

@@ -1125,6 +1139,7 @@ VarNode* EnableCHWN4Pass::on_graph_endpoint_var(VarNode* new_var,
 }
 
 std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
+    MIDOUT_B("EnableCHWN4Pass::make")
     auto ret = std::make_unique<EnableCHWN4Pass>();
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     auto&& replace_func = ret->m_opr_replace_func;

@@ -1381,6 +1396,7 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
     replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4;
     replace_func[opr::BatchConvBias::typeinfo()] = replace_inps_to_nchw4;
     return ret;
+    MIDOUT_E
 }
 
 /* ================ EnableNCHW4Pass ================ */

@@ -1395,6 +1411,7 @@ VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var,
 }
 
 std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
+    MIDOUT_B("EnableNCHW4Pass::make")
     auto ret = std::make_unique<EnableNCHW4Pass>();
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     using RelayoutMode = RelayoutPlaceholder::LayoutType;

@@ -1772,6 +1789,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
     replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
     return ret;
+    MIDOUT_E
 }
 
 /* ================ EnableNchwxxPass =============== */

@@ -2140,6 +2158,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){
 std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
         size_t pack_c_size) {
+    MIDOUT_B("EnableNchwxxPass::make")
     auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     std::string convter_pass_name = "conv_format_nchw88";

@@ -2149,6 +2168,7 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
     ret->fill_opr_convert_fun(pack_c_size);
     ret->set_name(convter_pass_name);
     return ret;
+    MIDOUT_E
 }
 
 /* ================ EnableNchw44DotPass =============== */

@@ -2164,6 +2184,7 @@ VarNode* EnableNchw44DotPass::on_graph_endpoint_var(VarNode* new_var,
 std::unique_ptr<EnableNchw44DotPass>
 EnableNchw44DotPass::make_nchw44_dot_converter() {
+    MIDOUT_B("EnableNchw44DotPass::make")
     auto ret = std::make_unique<EnableNchw44DotPass>();
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     //! First is whether the conv can trans to nchwxx, second is the filter

@@ -2384,6 +2405,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
     replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
     replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
     return ret;
+    MIDOUT_E
 }
 
 /* ==================== ShuffleShuffleRemovePass ================= */

@@ -2961,9 +2983,11 @@ const char* ShuffleShuffleRemovePass::name() const {
 }
 
 void ShuffleShuffleRemovePass::apply(OptState& opt) const {
+    MIDOUT_B("ShuffleShuffleRemovePass::apply")
     opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_SHAPE |
                                    VarReplaceCheckFlag::CHECK_DTYPE);
     Impl{opt};
+    MIDOUT_E
 }
 
-// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/gopt/impl/weights_preprocess.cpp
@@ -14,6 +14,16 @@
 #include "megbrain/opr/dnn/convolution.h"
 #include "megbrain/opr/tensor_manip.h"
 
+#include "megbrain/utils/hash_ct.h"
+#include "midout.h"
+
+MIDOUT_DECL(megbrain_weight_preprocess)
+#define MIDOUT_B(tag) \
+    MIDOUT_BEGIN(megbrain_weight_preprocess, midout_iv(MGB_HASH_STR(tag))) {
+#define MIDOUT_E \
+    } \
+    MIDOUT_END();
+
 using namespace mgb;
 using namespace gopt;
 using namespace cg;

@@ -23,6 +33,7 @@ const char* WinogradTransformReplacePass::name() const {
 }
 
 void WinogradTransformReplacePass::apply(OptState& opt) const {
+    MIDOUT_B("WinogradTransformReplacePass::apply")
     auto rewriter = opt.graph().make_rewriter();
     ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
     opt.graph().iter([&cvprop](OperatorNodeBase* opr) {

@@ -174,6 +185,7 @@ void WinogradTransformReplacePass::apply(OptState& opt) const {
     opt.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 
 /**
src/opr-mm/impl/collective_comm.cpp
@@ -855,10 +855,12 @@ VarNode* CollectiveComm::grad(VarNode* out_grad) const {
     return ModeTrait::from_mode(m_param.mode).grad(out_grad, this);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CollectiveComm) {
     mgb_assert(out_grad.size() == 1,
                "CollectiveComm should only have one grad");
     return opr.grad(out_grad[0]);
 }
+#endif
 
 /* ===================== shallow copy ===================== */
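Starting with collective_comm.cpp above, the remaining operator files apply the diff's second size lever: every MGB_IMPL_OPR_GRAD body is wrapped in #ifdef MGB_ENABLE_GRAD / #endif, so gradient registration compiles away entirely in builds configured without gradients. The recurring shape, reduced (SomeOpr stands in for each operator):

    #ifdef MGB_ENABLE_GRAD
    MGB_IMPL_OPR_GRAD(SomeOpr) {
        // ... construct and return the grad var for input wrt_idx ...
        return InvalidGrad::make(opr, wrt_idx);
    }
    #endif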
src/opr-mm/impl/io_remote.cpp
@@ -109,6 +109,7 @@ cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const {
     return prop;
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(RemoteSend) {
     mgb_assert(opr.is_grad());
     return RemoteRecv::make(opr.key() + ":grad",

@@ -118,6 +119,7 @@ MGB_IMPL_OPR_GRAD(RemoteSend) {
                             opr.input(0)->shape(), opr.input(0)->dtype())
             .node();
 }
+#endif
 
 /* ===================== RemoteRecv ===================== */
src/opr/impl/basic_arith.cpp
@@ -552,6 +552,7 @@ void Elemwise::call_megdnn_opr_exec(
     opr->exec(inp, out);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Elemwise) {
     SymbolVar i[5];
     SymbolVar i0(opr.input(0)), i1, i2, out(opr.output(0)),

@@ -730,6 +731,7 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
         result = -result;
     return result.node();
 }
+#endif
 
 VarNode* Elemwise::sum_grad_list(VarNode* wrt, VarNodeArray& grads) {
     mgb_assert(!grads.empty());

@@ -814,6 +816,7 @@ TypeCvt::NodeProp* TypeCvt::do_make_node_prop() const {
     return ret;
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(TypeCvt) {
     MGB_MARK_USED_VAR(wrt_idx);
     auto itype = opr.input(0)->dtype(), otype = opr.output(0)->dtype();

@@ -826,6 +829,7 @@ MGB_IMPL_OPR_GRAD(TypeCvt) {
     }
     return TypeCvt::make(out_grad[0], opr.input(0)->dtype()).node();
 }
+#endif
 
 void TypeCvt::mem_plan_fwd_in2out_writable() {
     if (input(0)->dtype().size() == output(0)->dtype().size() &&

@@ -963,10 +967,12 @@ void AddUpdate::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(AddUpdate) {
     // actually valid, just not implemented
     return InvalidGrad::make(opr, wrt_idx);
 }
+#endif
 
 /* =========================== Reduce =========================== */

@@ -1698,6 +1704,7 @@ void Reduce::create_megdnn_opr() {
             create_operator<megdnn::Reduce>());
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Reduce) {
     for (size_t i = 1; i < opr.output().size(); ++i)
         mgb_assert(!out_grad[i]);

@@ -1733,7 +1740,7 @@ MGB_IMPL_OPR_GRAD(Reduce) {
         grad = TypeCvt::make(grad, iv.dtype());
     return grad.node();
 }
+#endif
 
 void Reduce::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);

@@ -1783,11 +1790,13 @@ void PowC::init_output_static_infer_desc() {
             {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(PowC) {
     auto exp = opr.param().exp;
     return (exp * SymbolVar{out_grad[0]} *
             PowC::make(opr.input(0), exp - 1, opr.config()))
             .node();
 }
+#endif
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/opr/impl/blas.cpp
@@ -106,6 +106,7 @@ void MatrixMul::scn_do_execute() {
     MGB_FINALLY({ tparam = this->param(); });
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(MatrixMul) {
     mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
                "only float data type supported for grad");

@@ -128,6 +129,7 @@ MGB_IMPL_OPR_GRAD(MatrixMul) {
     }
     return grad.node();
 }
+#endif
 
 /* ================= BatchedMatrixMul ================= */

@@ -224,6 +226,7 @@ void BatchedMatrixMul::scn_do_execute() {
     MGB_FINALLY({ tparam = this->param(); });
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
     mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
                "only float data type supported for grad");

@@ -251,6 +254,7 @@ MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
     }
     return grad.node();
 }
+#endif
 
 /* ================= Dot ================= */

@@ -327,6 +331,7 @@ void Dot::add_input_layout_constraint() {
     input(1)->add_layout_constraint(check);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Dot) {
     auto other_input = opr.input(wrt_idx == 0 ? 1 : 0);
     auto ishp0 = opr::GetVarShape::make(opr.input(0)),

@@ -336,6 +341,7 @@ MGB_IMPL_OPR_GRAD(Dot) {
             Broadcast::make(mul(out_grad[0], other_input), max_ishp),
             wrt_idx ? ishp1 : ishp0).node();
 }
+#endif
 
 SymbolVar Dot::make(SymbolVar opr0, SymbolVar opr1,
                     const OperatorNodeConfig& config) {

@@ -350,6 +356,8 @@ void Dot::record_execute_deps(ExecDependencyArray &deps) {
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixInverse);
 MEGDNN_OPR_INIT1(MatrixInverse, "matrix_inv")
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(MatrixInverse) {
     SymbolVar a = opr.output(0);
     // TODO: use unified MatrixMul interface when we have it

@@ -364,6 +372,7 @@ MGB_IMPL_OPR_GRAD(MatrixInverse) {
                 a_bnn);
     return da.reshape(a.symshape()).node();
 }
+#endif
 
 /* ================= SVD ================= */

@@ -386,6 +395,7 @@ SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config) :
     }
 }
 
+#ifdef MGB_ENABLE_GRAD
 namespace {
 
 /*!

@@ -477,7 +487,9 @@ OP(*, {}, {})
 #undef OP
 
 }  // anonymous namespace
+#endif
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(SVD) {
     /**
      * The formula is copied from

@@ -555,6 +567,7 @@ MGB_IMPL_OPR_GRAD(SVD) {
                            I_n - matmul(v, v, param01)));
     return ret.reshape(a.symshape()).node();
 }
+#endif
 
 SymbolVarArray SVD::make(const SymbolVar& src, const Param& param,
                          const OperatorNodeConfig& config) {
src/opr/impl/cond.cpp
@@ -818,6 +818,7 @@ SymbolVar CondExecMark::mark_if_need(SymbolVar maybe_ppv, SymbolVar input,
     return input;
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CondExecMark) {
     if (wrt_idx == opr.input().size() - 1 || !out_grad.at(wrt_idx)) {
         return nullptr;

@@ -841,6 +842,7 @@ MGB_IMPL_OPR_GRAD(CondExecMark) {
                 {1, grad_mode}, OperatorNodeConfig{})
             ->output(0);
 }
+#endif
 
 /* ============================= CondExecMerge ============================= */
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondExecMerge);

@@ -1225,6 +1227,7 @@ CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const {
     return ret;
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CondExecMerge) {
     using Mode = CondExecMerge::Param::Mode;
     if (opr.param().mode == Mode::SUM_COND_OUT &&

@@ -1259,6 +1262,7 @@ MGB_IMPL_OPR_GRAD(CondExecMerge) {
                            OperatorNodeConfig{og->comp_node()})
             ->output(0);
 }
+#endif
 
 void CondExecMerge::modify_grad_sum_list(VarNode* wrt, VarNodeArray& grads) {
     if (!ExecutionMask::have_alive_instance()) {
src/opr/impl/dnn/batch_norm.cpp
@@ -230,6 +230,7 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() {
     }
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchNormForward) {
     mgb_assert(wrt_idx < 5);
     if (wrt_idx < 3) {

@@ -242,6 +243,7 @@ MGB_IMPL_OPR_GRAD(BatchNormForward) {
         return nullptr;
     }
 }
+#endif
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNormBackward);
src/opr/impl/dnn/convolution.cpp
@@ -18,6 +18,19 @@
 #include "megdnn/oprs/utils.h"
 
+//! TODO: some megdnn oprs need to be known here before midout.h is included;
+//! fix this if there is a more graceful way.
+#include "megdnn/oprs.h"
+#include "midout.h"
+
+MIDOUT_DECL(megbrain_opr_convolution)
+#define MIDOUT_B(...) \
+    MIDOUT_BEGIN(megbrain_opr_convolution, __VA_ARGS__) {
+#define MIDOUT_E \
+    } \
+    MIDOUT_END();
+
 #include "../internal/megdnn_opr_wrapper.inl"
 
 #include <array>

@@ -230,6 +243,7 @@ class TimedProfiler {
     static constexpr int arity_in = OprArityTrait<Opr>::arity_in;
     static constexpr int arity_out = OprArityTrait<Opr>::arity_out;
     static constexpr int arity = OprArityTrait<Opr>::arity;
+    using ConvTensorShapes = std::array<TensorShape, arity>;
 
 public:

@@ -295,6 +309,7 @@ double TimedProfiler<Opr>::init_timeout_setting() {
 template <typename Opr>
 typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
         const TParam& raw_param) {
+    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("TimedProfiler::prof_impl")))
     auto&& param = raw_param.as_single_pod<Param>();
     CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc);
     auto megdnn_opr = intl::create_megdnn_opr<Opr>(cn);

@@ -401,14 +416,17 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
     mgb_assert(ev_start->finished());
     return TResult::from_pod(Result{ev_start->elapsed_time_until(*ev_end)});
+    MIDOUT_E
 };
 
 template <typename Opr>
 void TimedProfiler<Opr>::prof_init_device(const TParam& raw_param) {
+    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("TimedProfiler::prof_init_device")))
     auto&& param = raw_param.as_single_pod<Param>();
     CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc);
     // wait for cuda init, so its time does not get accounted in timeout
     cn.sync();
+    MIDOUT_E
 }
 
 /* =================== AlgoChooser =================== */

@@ -426,6 +444,7 @@ class AlgoChooser {
     static constexpr int arity_in = OprArityTrait<Opr>::arity_in;
     static constexpr int arity_out = OprArityTrait<Opr>::arity_out;
     static constexpr int arity = OprArityTrait<Opr>::arity;
     using ImplAlgo = typename Opr::Algorithm*;
     using MGBOpr = typename MegDNNOpr2MGBOpr<Opr>::MGBOpr;
+    using ConvTensorLayouts = std::array<TensorLayout, arity>;

@@ -473,8 +492,8 @@ class AlgoChooser {
     //! put first
     std::vector<ImplAlgo> get_all_candidates() const {
         auto heu = choose_by_heuristic();
-        auto&& ret = OprArityTrait<Opr>::get_all_algorithms(m_megdnn_opr, m_layouts);
+        auto&& ret = OprArityTrait<Opr>::get_all_algorithms(m_megdnn_opr, m_layouts);
         bool found = false;
         for (size_t i = 0; i < ret.size(); ++i) {
             if (ret[i] == heu) {

@@ -491,7 +510,7 @@ class AlgoChooser {
     //! get candidate algos with workspace limit.
     std::vector<ImplAlgo> get_all_candidates_with_workspace_limit() const {
-        auto&& all_algos = get_all_candidates();
+        auto&& all_algos = get_all_candidates();
         auto opr = m_mgb_opr;
         auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
                 opr->owner_graph(), opr->comp_node(),

@@ -633,16 +652,16 @@ AlgoChooserProfileCache::Result AlgoChooser<Opr>::get_profile_result(
                 algo->name(), str_on_inp_shape.c_str());
         timer.reset();
         MGB_TRY { cur_rst = ctx.profile_single_algo(algo, cur_timeout); }
-        MGB_CATCH(std::exception& exc, {
-            mgb_log_warn("caught exception during %s: %s", msg.c_str(), exc.what());
-            continue;
-        })
+        MGB_CATCH(std::exception& exc, {
+            mgb_log_warn("caught exception during %s: %s", msg.c_str(), exc.what());
+            continue;
+        })
         MGB_CATCH(..., {
             mgb_log_warn("caught exception during %s", msg.c_str());
             continue;
-        })
-        if (!cur_rst.valid()) {
+        })
+        if (!cur_rst.valid()) {
             mgb_log_warn("timeout when %s; timeout setting: %.3fsec",
                          msg.c_str(), cur_timeout);
             continue;

@@ -680,6 +699,7 @@ void AlgoChooser<megdnn::ConvBias>::get_origin_param_and_layouts(
 template <typename Opr>
 typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::choose_by_profile(
         ExeContext& ctx, bool require_reproducible, bool enable_update) {
+    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile")))
     auto opr = ctx.mgb_opr();
     if (opr->owner_graph()->options().no_profiling_on_shape_change) {
         auto algo = ctx.megdnn_opr()->execution_policy().algorithm;

@@ -720,6 +740,7 @@ typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::choose_by_profile(
             opr->owner_graph(), opr->comp_node(),
             opr->execution_policy().workspace_limit));
     mgb_trap();
+    MIDOUT_E
 }
 
 template <>

@@ -748,7 +769,7 @@ void AlgoChooser<megdnn::ConvBias>::ExeContext::
     if (m_layouts[1].dtype.enumv() == DTypeEnum::QuantizedS8 &&
         param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW44) {
         if (winograd_preprocess_opr->param().format ==
-            megdnn::param::MatrixMul::Format::MK4){
+            megdnn::param::MatrixMul::Format::MK4) {
             winograd_preprocess_opr->param().compute_mode =
                     ConvBias::Param::ComputeMode::FLOAT32;
             param.opr_param.compute_mode =

@@ -941,6 +962,7 @@ void ConvolutionForward::init_output_dtype() {
     output(0)->dtype(output_dtype);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ConvolutionForward) {
     mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
                "only float data type supported for grad");

@@ -960,6 +982,7 @@ MGB_IMPL_OPR_GRAD(ConvolutionForward) {
         return grad.node();
     }
 }
+#endif
 
 size_t ConvolutionForward::get_workspace_size_bytes(
         const TensorShapeArray& input_shapes,

@@ -1086,6 +1109,7 @@ void ConvolutionBackwardData::scn_do_execute() {
             intl::get_megdnn_workspace_from_var(output(1)));
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) {
     mgb_assert(!out_grad[1]);
     if (wrt_idx == 0) {

@@ -1101,6 +1125,7 @@ MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) {
     }
     return nullptr;
 }
+#endif
 
 /* ==================== ConvolutionBackwardFilter ==================== */
 IMPL_CONV(ConvolutionBackwardFilter, "conv_bwd_filter");

@@ -1138,6 +1163,7 @@ size_t ConvolutionBackwardFilter::get_workspace_size_bytes(
             megdnn_opr(), this);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) {
     mgb_assert(!out_grad[1]);
     if (wrt_idx == 0) {

@@ -1153,6 +1179,7 @@ MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) {
     }
     return nullptr;
 }
+#endif
 
 /* ==================== Convolution3DForward ==================== */
 IMPL_CONV(Convolution3DForward, "conv3d_fwd");

@@ -1192,6 +1219,7 @@ void Convolution3DForward::init_output_dtype() {
     }
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Convolution3DForward) {
     mgb_assert(opr.param().data_type ==
                        Convolution3DForward::Param::DataType::FLOAT,

@@ -1212,6 +1240,7 @@ MGB_IMPL_OPR_GRAD(Convolution3DForward) {
         return grad.node();
     }
 }
+#endif
 
 size_t Convolution3DForward::get_workspace_size_bytes(
         const TensorShapeArray& input_shapes,

@@ -1285,6 +1314,7 @@ void Convolution3DBackwardData::scn_do_execute() {
             intl::get_megdnn_workspace_from_var(output(1)));
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) {
     mgb_assert(!out_grad[1]);
     if (wrt_idx == 0) {

@@ -1300,6 +1330,7 @@ MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) {
     }
     return nullptr;
 }
+#endif
 
 /* ==================== Convolution3DBackwardFilter ==================== */
 IMPL_CONV(Convolution3DBackwardFilter, "conv3d_bwd_filter");

@@ -1658,6 +1689,7 @@ size_t LocalShareForward::get_workspace_size_bytes(
             megdnn_opr(), this);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(LocalShareForward) {
     mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
                "only float data type supported for grad");

@@ -1677,6 +1709,7 @@ MGB_IMPL_OPR_GRAD(LocalShareForward) {
         return grad.node();
     }
 }
+#endif
 
 /* ===================== LocalShareBackwardData ==================== */

@@ -1737,6 +1770,7 @@ void LocalShareBackwardData::scn_do_execute() {
             intl::get_megdnn_workspace_from_var(output(1)));
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(LocalShareBackwardData) {
     mgb_assert(!out_grad[1]);
     if (wrt_idx == 0) {

@@ -1752,6 +1786,7 @@ MGB_IMPL_OPR_GRAD(LocalShareBackwardData) {
     }
     return nullptr;
 }
+#endif
 
 /* ==================== LocalShareBackwardFilter ==================== */

@@ -1792,6 +1827,7 @@ size_t LocalShareBackwardFilter::get_workspace_size_bytes(
             megdnn_opr(), this);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) {
     mgb_assert(!out_grad[1]);
     if (wrt_idx == 0) {

@@ -1805,6 +1841,7 @@ MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) {
     }
     return nullptr;
 }
+#endif
 
 /* ===================== DeformableConvForward ==================== */

@@ -1869,6 +1906,7 @@ size_t DeformableConvForward::get_workspace_size_bytes(
             megdnn_opr(), this);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(DeformableConvForward) {
     mgb_assert(opr.input(0)->dtype() == dtype::Float32(),
               "only float data type supported for grad");

@@ -1888,6 +1926,7 @@ MGB_IMPL_OPR_GRAD(DeformableConvForward) {
     SymbolVarArray grads = {grad_arr[0], filter_grad, grad_arr[1], grad_arr[2]};
     return grads[wrt_idx].node();
 }
+#endif
 
 /* ==================== DeformableConvBackwardData ==================== */

@@ -2265,4 +2304,4 @@ void BatchConvBiasForward::init_output_format() {
 #undef IMPL_CONV
 #undef MGB_FOREACH_FASTRUN_OPR
 
-// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
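convolution.cpp defines a variadic MIDOUT_B instead of the single-tag form used elsewhere, because its instrumented functions are templates: the region key combines the Opr type with a hashed string, so each template instantiation gets its own independently trimmable region. A reduced sketch of that pattern, taken from the hunks above (the body comment marks elided code):

    #define MIDOUT_B(...) \
        MIDOUT_BEGIN(megbrain_opr_convolution, __VA_ARGS__) {
    #define MIDOUT_E \
        } \
        MIDOUT_END();

    template <typename Opr>
    void TimedProfiler<Opr>::prof_init_device(const TParam& raw_param) {
        // region keyed on (Opr, hash("TimedProfiler::prof_init_device"))
        MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("TimedProfiler::prof_init_device")))
        // ... device warm-up elided; see the hunk above ...
        MIDOUT_E
    }

This way a trace that only ever profiles, say, ConvolutionForward lets the other instantiations of the same template be trimmed.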
src/opr/impl/dnn/images2neibs.cpp
@@ -20,11 +20,13 @@ using namespace opr;
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsForward);
 MEGDNN_OPR_INIT1(Images2NeibsForward, "images2neibs")
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Images2NeibsForward) {
     mgb_assert(wrt_idx == 0 && out_grad.size() == 2 && !out_grad[1]);
     return Images2NeibsBackward::make(out_grad[0], opr.input(0),
                                       opr.param()).node();
 }
+#endif
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsBackward);
 MEGDNN_OPR_INIT2(Images2NeibsBackward, "images2neibs_grad", 1, false);
src/opr/impl/dnn/local.cpp
@@ -20,10 +20,13 @@ using namespace opr;
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalForward);
 MEGDNN_OPR_INIT2(LocalForward, "local")
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(LocalForward) {
     return intl::conv_grad<LocalBackwardData, LocalBackwardFilter>(
             opr, wrt_idx, out_grad);
 }
+#endif
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalBackwardData);
 MEGDNN_OPR_INIT3(LocalBackwardData, "local_bwd_data", 2, false);

@@ -34,10 +37,13 @@ MEGDNN_OPR_INIT3(LocalBackwardFilter, "local_bwd_filter", 2, false);
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalForward);
 MEGDNN_OPR_INIT2(GroupLocalForward, "glocal")
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(GroupLocalForward) {
     return intl::conv_grad<GroupLocalBackwardData, GroupLocalBackwardFilter>(
             opr, wrt_idx, out_grad);
 }
+#endif
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalBackwardData);
 MEGDNN_OPR_INIT3(GroupLocalBackwardData, "glocal_bwd_data", 2, false);
src/opr/impl/dnn/lrn.cpp
@@ -20,12 +20,14 @@ using namespace opr;
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNForward);
 MEGDNN_OPR_INIT1(LRNForward, "lrn")
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(LRNForward) {
     mgb_assert(wrt_idx == 0);
     SymbolVar grad = LRNBackward::make(opr.input(0), opr.output(0),
                                        out_grad[0], opr.param());
     return grad.node();
 }
+#endif
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNBackward);
 MEGDNN_OPR_INIT3(LRNBackward, "lrn_bwd", 0, true);
src/opr/impl/dnn/pooling.cpp
@@ -19,12 +19,14 @@ using namespace opr;
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingForward);
 MEGDNN_OPR_INIT1(PoolingForward, "pooling")
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(PoolingForward) {
     mgb_assert(wrt_idx == 0);
     SymbolVar grad = PoolingBackward::make(opr.input(0), opr.output(0),
                                            out_grad[0], opr.param());
     return grad.node();
 }
+#endif
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingBackward);
 MEGDNN_OPR_INIT3(PoolingBackward, "pooling_bwd", 0, true);
src/opr/impl/dnn/roi_align.cpp
@@ -40,6 +40,7 @@ SymbolVar ROIAlignForward::make(SymbolVar src, SymbolVar rois,
             src.node(), rois.node(), param, config);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ROIAlignForward) {
     if (out_grad[1]) {
         return InvalidGrad::make(opr, wrt_idx);

@@ -55,6 +56,7 @@ MGB_IMPL_OPR_GRAD(ROIAlignForward) {
         return nullptr;
     }
 }
+#endif
 
 /* ==================== ROIAlignBackward ==================== */
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(ROIAlignBackward);
src/opr/impl/dnn/roi_pooling.cpp
@@ -84,6 +84,7 @@ size_t ROIPoolingForward::get_workspace_size_bytes(
             input_shapes, output_shapes);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ROIPoolingForward) {
     if (out_grad[1] || wrt_idx == 2) {
         return InvalidGrad::make(opr, wrt_idx);

@@ -98,6 +99,7 @@ MGB_IMPL_OPR_GRAD(ROIPoolingForward) {
         return nullptr;
     }
 }
+#endif
 
 void ROIPoolingForward::scn_do_execute() {
     return intl::MegDNNOprMethInvoker<megdnn::ROIPoolingForward>::

@@ -146,6 +148,7 @@ SymbolVar DeformablePSROIPoolingForward::make(
     return all[0];
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) {
     mgb_assert(wrt_idx <= 2);  // wrt_idx = 0 or 1 or 2

@@ -168,6 +171,7 @@ MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) {
     }
     return nullptr;
 }
+#endif
 
 /* ==================== DeformablePSROIPoolingBackward ==================== */
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(DeformablePSROIPoolingBackward);
src/opr/impl/imgproc.cpp
@@ -127,6 +127,7 @@ void WarpPerspectiveForward::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) {
     mgb_assert(opr.input().size() == 3,
                "backward with mat_idx is currently unsupported");

@@ -145,6 +146,7 @@ MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) {
     } else
         return InvalidGrad::make(opr, wrt_idx);
 }
+#endif
 
 /* ====================== WarpPerspectiveBackwardData ====================== */

@@ -234,6 +236,7 @@ void ResizeForward::record_execute_deps(ExecDependencyArray &deps) {
     record_megdnn_opr(deps);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ResizeForward) {
     mgb_assert(opr.input().size() == 2);
     if (wrt_idx == 0) {

@@ -243,6 +246,7 @@ MGB_IMPL_OPR_GRAD(ResizeForward) {
     } else
         return InvalidGrad::make(opr, wrt_idx);
 }
+#endif
 
 /* ====================== ResizeBackward ====================== */
src/opr/impl/indexing.cpp
@@ -83,6 +83,7 @@ void IndexingOneHot::init_output_dtype() {
     output(0)->dtype(input(0)->dtype());
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingOneHot) {
     if (wrt_idx == 0) {
         return IndexingSetOneHot::make(

@@ -91,6 +92,7 @@ MGB_IMPL_OPR_GRAD(IndexingOneHot) {
     }
     return InvalidGrad::make(opr, wrt_idx);
 }
+#endif
 
 /* ==================== IndexingSetOneHot ==================== */

@@ -133,6 +135,7 @@ void IndexingSetOneHot::scn_do_execute() {
             intl::get_megdnn_workspace_from_var(output(1)));
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingSetOneHot) {
     SymbolVar index{opr.input(1)}, sub{opr.input(2)}, og{out_grad.at(0)};
     if (wrt_idx == 0) {

@@ -144,6 +147,7 @@ MGB_IMPL_OPR_GRAD(IndexingSetOneHot) {
     }
     return InvalidGrad::make(opr, wrt_idx);
 }
+#endif
 
 size_t IndexingSetOneHot::get_workspace_size_bytes(
         const TensorShapeArray& input_shapes,

@@ -165,6 +169,7 @@ void IndexingRemap::init_output_dtype() {
     output(0)->dtype(input(0)->dtype());
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingRemap) {
     if (wrt_idx == 1)
         return InvalidGrad::make(opr, wrt_idx);

@@ -172,6 +177,7 @@ MGB_IMPL_OPR_GRAD(IndexingRemap) {
     return IndexingRemapBackward::make(
             out_grad[0], opr.input(1), opr.input(0), opr.param()).node();
 }
+#endif
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingRemapBackward);
 MEGDNN_OPR_INIT3(IndexingRemapBackward, "indexing_remap_bwd", 2, false);

@@ -460,6 +466,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
 MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
         IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false);
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) {
     if (wrt_idx)
         return InvalidGrad::make(opr, wrt_idx);

@@ -468,7 +475,9 @@ MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) {
             SymbolVar{opr.input(0)}.fill_retain_dtype(0),
             out_grad.at(0), opr.index_desc()).node();
 }
+#endif
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) {
     if (wrt_idx >= 2)
         return InvalidGrad::make(opr, wrt_idx);

@@ -479,7 +488,9 @@ MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) {
     }
     return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node();
 }
+#endif
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) {
     if (wrt_idx >= 2)
         return InvalidGrad::make(opr, wrt_idx);

@@ -488,6 +499,7 @@ MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) {
     }
     return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node();
 }
+#endif
 
 /* ============================= Mesh Indexing ============================ */

@@ -498,6 +510,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_GET(
         BatchedMeshIndexing, "batched_mesh_indexing", false,
         output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(MeshIndexing) {
     if (wrt_idx != 0) {
         return InvalidGrad::make(opr, wrt_idx);

@@ -507,6 +520,9 @@ MGB_IMPL_OPR_GRAD(MeshIndexing) {
                    opr.index_desc())
             .node();
 }
+#endif
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) {
     if (wrt_idx != 0) {
         return InvalidGrad::make(opr, wrt_idx);

@@ -516,11 +532,14 @@ MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) {
                    opr.index_desc())
             .node();
 }
+#endif
 
 /* ========================= IncrMeshIndexing ========================= */
 MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrMeshIndexing, "incr_mesh_indexing", false);
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IncrMeshIndexing) {
     if (wrt_idx > 2) {
         return opr::InvalidGrad::make(opr, wrt_idx);

@@ -530,9 +549,11 @@ MGB_IMPL_OPR_GRAD(IncrMeshIndexing) {
     }
     return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
 }
+#endif
 
 MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedIncrMeshIndexing,
                                    "batched_incr_mesh_indexing", false);
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) {
     if (wrt_idx > 2) {
         return opr::InvalidGrad::make(opr, wrt_idx);

@@ -542,10 +563,12 @@ MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) {
     }
     return BatchedMeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
 }
+#endif
 
 /* ======================== SetMeshIndexing =========================== */
 MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetMeshIndexing, "set_mesh_indexing", false);
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(SetMeshIndexing) {
     if (wrt_idx >= 2) {
         return opr::InvalidGrad::make(opr, wrt_idx);

@@ -560,9 +583,11 @@ MGB_IMPL_OPR_GRAD(SetMeshIndexing) {
         return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
     }
 }
+#endif
 
 MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedSetMeshIndexing,
                                    "batched_set_mesh_indexing", false);
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) {
     if (wrt_idx > 2) {
         return opr::InvalidGrad::make(opr, wrt_idx);

@@ -578,5 +603,6 @@ MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) {
             .node();
     }
 }
+#endif
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/opr/impl/io.cpp
@@ -764,11 +764,13 @@ Copy::NodeProp* Copy::do_make_node_prop() const {
     return rst;
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Copy) {
     mgb_assert(wrt_idx == 0);
     return Copy::make(out_grad[0],
                       OperatorNodeConfig{}.follow_comp_node(opr.input(0)))
             .node();
 }
+#endif
 
 void Copy::add_input_layout_constraint() {
     if (input(0)->comp_node() != output(0)->comp_node()) {
src/opr/impl/loop/forward.cpp
@@ -268,9 +268,11 @@ VarNode* Loop::grad(Loop &opr, size_t wrt_idx, const VarNodeArray &out_grad) {
     return gopr->get_grad_var(wrt_idx);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Loop) {
     return Loop::grad(const_cast<Loop&>(opr), wrt_idx, out_grad);
 }
+#endif
 
 cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const {
     auto prop = LoopImpl::do_make_node_prop();
src/opr/impl/misc.cpp
@@ -48,23 +48,26 @@ namespace intl {
 /* ================= Argmxx ================= */
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Argmax) {
     MGB_MARK_USED_VAR(out_grad);
     MGB_MARK_USED_VAR(opr);
     mgb_assert(!wrt_idx);
     return nullptr;
 }
+#endif
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmax);
 MEGDNN_OPR_INIT1(Argmax, "argmax")
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Argmin) {
     MGB_MARK_USED_VAR(out_grad);
     MGB_MARK_USED_VAR(opr);
     mgb_assert(!wrt_idx);
     return nullptr;
 }
+#endif
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmin);
 MEGDNN_OPR_INIT1(Argmin, "argmin")

@@ -84,12 +87,14 @@ std::array<SymbolVar, 2> ArgsortForward::make(
     return {node->output(0), node->output(1)};
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ArgsortForward) {
     mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);
     if (!out_grad[0])
         return nullptr;
     return ArgsortBackward::make(out_grad[0], opr.output(1)).node();
 }
+#endif
 
 /* ================= ArgsortBackward ================= */

@@ -107,12 +112,14 @@ Cumsum::Cumsum(VarNode* opr, const Param& param,
     add_input({opr}, AddInputSortType::CUR_ADDED);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Cumsum) {
     mgb_assert(out_grad[0] && !out_grad[1]);
     auto param = opr.param();
     param.reverse = !param.reverse;
     return Cumsum::make(out_grad[0], param).node();
 }
+#endif
 
 SymbolVar Cumsum::make(SymbolVar opr, const Param& param,
                        const OperatorNodeConfig& config) {

@@ -170,6 +177,7 @@ CondTake::CondTake(VarNode *data, VarNode *mask,
     }
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CondTake) {
     mgb_assert(out_grad.size() == 3 && !out_grad[2]);
     if (wrt_idx == 0 && out_grad[0]) {

@@ -181,6 +189,7 @@ MGB_IMPL_OPR_GRAD(CondTake) {
     }
     return nullptr;
 }
+#endif
 
 std::array<SymbolVar, 2> CondTake::make(SymbolVar data, SymbolVar mask,

@@ -318,6 +327,7 @@ void TopK::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(TopK) {
     if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) {
         mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]);

@@ -334,5 +344,6 @@ MGB_IMPL_OPR_GRAD(TopK) {
     return ArgsortBackward::make(out_grad[0], opr.output(1), opr.input(0))
             .node();
 }
+#endif
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/opr/impl/muxing.cpp
@@ -316,9 +316,11 @@ VarNodeArray AllGather::grad(const VarNodeArray &out_grad) {
             OperatorNodeConfig().comp_node_arr(sp_cn)));
 }
 
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(AllGather) {
     return const_cast<AllGather&>(opr).grad(out_grad);
 }
+#endif
 
 void AllGather::on_output_comp_node_stream_changed() {
 }
src/opr/impl/rand.cpp
@@ -112,19 +112,21 @@ UniqPtrWithCN<megdnn::RNGBase> RNGOpr<MegDNNOpr>::create_megdnn_opr() {
     return opr;
 }
 
-#define IMPL(_cls)                                  \
-    template class RNGOpr<::megdnn::_cls>;          \
-    MGB_IMPL_OPR_GRAD(_cls) {                       \
-        MGB_MARK_USED_VAR(out_grad);                \
-        return InvalidGrad::make(opr, wrt_idx);     \
-    }                                               \
+#define IMPL(_cls)                                  \
+    MGB_IMPL_OPR_GRAD(_cls) {                       \
+        MGB_MARK_USED_VAR(out_grad);                \
+        return InvalidGrad::make(opr, wrt_idx);     \
+    }
 
 namespace mgb {
 namespace opr {
 namespace intl {
+template class RNGOpr<::megdnn::GaussianRNG>;
+template class RNGOpr<::megdnn::UniformRNG>;
+#ifdef MGB_ENABLE_GRAD
 IMPL(GaussianRNG);
 IMPL(UniformRNG);
+#endif
 }
 }
 }
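rand.cpp is the one file where the grad guard required restructuring: the old IMPL macro both instantiated RNGOpr and defined its grad, so guarding it would also have removed the instantiation. The commit splits the two, keeping the explicit instantiations unconditional and guarding only the grad definitions, as the hunk above shows. The before/after shape, reduced:

    // Before: one macro did both jobs, so it could not be grad-guarded.
    // #define IMPL(_cls)                         \
    //     template class RNGOpr<::megdnn::_cls>; \
    //     MGB_IMPL_OPR_GRAD(_cls) { ... }

    // After: instantiation stays unconditional; only the grads are guarded.
    template class RNGOpr<::megdnn::GaussianRNG>;
    template class RNGOpr<::megdnn::UniformRNG>;
    #ifdef MGB_ENABLE_GRAD
    IMPL(GaussianRNG);
    IMPL(UniformRNG);
    #endif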
src/opr/impl/tensor_gen.cpp
浏览文件 @
80c47053
...
...
@@ -46,11 +46,13 @@ void Alloc::outshape_by_symvar_do_get_output_shape(
void
Alloc
::
scn_do_execute
()
{
}
#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD
(
Alloc
)
{
MGB_MARK_USED_VAR
(
wrt_idx
);
MGB_MARK_USED_VAR
(
out_grad
);
return
InvalidGrad
::
make
(
opr
,
0
);
}
#endif
/* ======================= Linspace ======================= */
...
...
@@ -123,6 +125,7 @@ void Linspace::record_execute_deps(ExecDependencyArray& deps) {
            std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr)));
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Linspace) {
    if (wrt_idx == 2)
        return InvalidGrad::make(opr, wrt_idx);
...
...
@@ -134,6 +137,7 @@ MGB_IMPL_OPR_GRAD(Linspace) {
    return opr::Dot::make(og,
            opr::Linspace::make(i0, i1, opr.input(2), opr.param())).node();
}
#endif
/* ======================= Eye ======================= */
...
...
@@ -195,9 +199,10 @@ void Eye::record_execute_deps(ExecDependencyArray& deps) {
            std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr)));
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Eye) {
    return InvalidGrad::make(opr, wrt_idx);
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/opr/impl/tensor_manip.cpp
...
...
@@ -165,12 +165,13 @@ void GetVarShape::init_output_static_infer_desc() {
    mgr.register_value_infer(output(0), {SourceType::DEP, deps, infer_value});
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(GetVarShape) {
    MGB_MARK_USED_VAR(wrt_idx);
    MGB_MARK_USED_VAR(out_grad);
    return nullptr;
}
#endif
SymbolVar GetVarShape::make(const VarNodeArrayView& inp, Param param,
                            const OperatorNodeConfig& config) {
...
...
@@ -362,11 +363,13 @@ SymbolVar Reshape::make(SymbolVar inp, SymbolVar tshp,
            inp.node(), tshp.node(), unspec_axis, config);
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Reshape) {
    if (wrt_idx)
        return InvalidGrad::make(opr, wrt_idx);
    return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node();
}
#endif
Maybe<TensorLayout> Reshape::reshapebrdcast_get_dest_layout(
        const TensorLayout& src, const TensorShape& tshape) const {
...
...
@@ -429,12 +432,14 @@ SymbolVar Broadcast::make(SymbolVar inp, SymbolVar tshp,
            inp.node(), tshp.node(), config);
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Broadcast) {
    if (wrt_idx)
        return InvalidGrad::make(opr, wrt_idx);
    return Reduce::make(out_grad.at(0), Reduce::Mode::SUM,
                        GetVarShape::make(opr.input(0))).node();
}
#endif
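Broadcast's gradient rule above is the dual of the forward copy: each broadcast copy of an input element receives its own output gradient, so the backward pass SUM-reduces the output gradient down to the input shape (hence GetVarShape::make(opr.input(0))). A minimal numeric sketch in plain C++ (not MegBrain code), broadcasting a scalar to three elements:

#include <cstdio>

int main() {
    double gy[3] = {0.1, 0.2, 0.3};  // gradient w.r.t. the broadcast output
    double gx = 0;
    for (double g : gy)
        gx += g;                     // SUM-reduce back to the input shape
    std::printf("%g\n", gx);         // prints: 0.6
    return 0;
}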
Maybe<TensorLayout> Broadcast::reshapebrdcast_get_dest_layout(
        const TensorLayout& src, const TensorShape& tshape) const {
...
...
@@ -562,9 +567,11 @@ VarNode* Dimshuffle::grad(
    return Dimshuffle::make(out_grad.at(0), back, m_pattern.size()).node();
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Dimshuffle) {
    return opr.grad(wrt_idx, out_grad);
}
#endif
// f}}}
...
...
@@ -631,10 +638,12 @@ AxisAddRemove::NodeProp* AxisAddRemove::do_make_node_prop() const {
    return ret;
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(AxisAddRemove) {
    MGB_MARK_USED_VAR(wrt_idx);
    return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node();
}
#endif
// f}}}
...
...
@@ -642,6 +651,7 @@ MGB_IMPL_OPR_GRAD(AxisAddRemove) {
MGB_IMPL_FANCY_INDEXING_OPR_GET(Subtensor, "subtensor", true);

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Subtensor) {
    if (wrt_idx)
        return InvalidGrad::make(opr, wrt_idx);
...
...
@@ -650,6 +660,7 @@ MGB_IMPL_OPR_GRAD(Subtensor) {
            SymbolVar{opr.input(0)}.fill_retain_dtype(0),
            out_grad.at(0), opr.index_desc()).node();
}
#endif
void Subtensor::init_output_static_infer_desc() {
    using namespace cg::static_infer;
...
...
@@ -783,6 +794,7 @@ void SetSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) {
    sub.copy_from_fixlayout(val);
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(SetSubtensor) {
    if (wrt_idx >= 2)
        return InvalidGrad::make(opr, wrt_idx);
...
...
@@ -793,6 +805,7 @@ MGB_IMPL_OPR_GRAD(SetSubtensor) {
    }
    return Subtensor::make(out_grad.at(0), opr.index_desc()).node();
}
#endif
// f}}}
...
...
@@ -813,6 +826,7 @@ void IncrSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) {
    opr->exec(sub.as_megdnn(), val.as_megdnn());
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IncrSubtensor) {
    if (wrt_idx >= 2)
        return InvalidGrad::make(opr, wrt_idx);
...
...
@@ -821,6 +835,7 @@ MGB_IMPL_OPR_GRAD(IncrSubtensor) {
    }
    return Subtensor::make(out_grad.at(0), opr.index_desc()).node();
}
#endif
// f}}}
...
...
@@ -1085,6 +1100,7 @@ void Split::do_execute(ExecEnv &env) {
    }
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Split) {
    if (wrt_idx)
        return InvalidGrad::make(opr, wrt_idx);
...
...
@@ -1100,6 +1116,7 @@ MGB_IMPL_OPR_GRAD(Split) {
    return Concat::make(grad, opr.options().axis,
            OperatorNodeConfig{}.follow_comp_node(opr.input(0))).node();
}
#endif
void Split::mem_plan_fwd_in2out_readonly() {
    m_readonly_fwd_called = true;
...
...
@@ -1236,6 +1253,7 @@ SymbolVar Concat::make(const VarNodeArrayView& inp, int axis,
            axis, config);
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Concat) {
    auto axis = opr.axis();
    mgb_assert(out_grad.size() == 1);
...
...
@@ -1250,6 +1268,7 @@ MGB_IMPL_OPR_GRAD(Concat) {
            OperatorNodeConfig().comp_node_arr(comp_node));
    return cg::to_var_node_array(ret);
}
#endif
void Concat::scn_do_execute() {
    auto&& out = output(0)->dev_tensor();
...
...
@@ -1507,6 +1526,7 @@ void ParamPackSplit::init_output_static_infer_desc() {
    }
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ParamPackSplit) {
    mgb_assert(out_grad.size() == opr.output().size());
    SmallVector<SymbolVar> grad;
...
...
@@ -1531,6 +1551,7 @@ MGB_IMPL_OPR_GRAD(ParamPackSplit) {
            OperatorNodeConfig{}.follow_comp_node(opr.input(0)))
            .node();
}
#endif
// f}}}
/* f{{{ ======================= RelayoutFormat ======================= */
...
...
src/opr/impl/utility.cpp
...
...
@@ -255,9 +255,11 @@ void MarkDynamicVar::scn_do_execute() {
        o->dev_tensor().copy_from_fixlayout(i->dev_tensor());
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MarkDynamicVar) {
    return MarkDynamicVar::make(out_grad.at(0)).node();
}
#endif

MarkDynamicVar::MarkDynamicVar(VarNode *node, const OperatorNodeConfig &config):
    Super{node->owner_graph(), config, "mark_dyn", {node}}
...
...
@@ -381,10 +383,12 @@ CallbackInjector::mixin_get_static_infer_desc(OperatorNodeBase &opr) {
    }
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(CallbackInjector) {
    MGB_MARK_USED_VAR(wrt_idx);
    return out_grad.at(0);
}
#endif
/* ===================== MarkNoBroadcastElemwise ===================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MarkNoBroadcastElemwise);
...
...
@@ -404,9 +408,11 @@ SymbolVar MarkNoBroadcastElemwise::make(
            input.node(), config);
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MarkNoBroadcastElemwise) {
    return out_grad.at(0);
}
#endif
/* ===================== Identity ===================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Identity);
...
...
@@ -429,9 +435,11 @@ SymbolVar Identity::make(
    return input.insert_single_output_opr<Identity>(input.node(), config);
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Identity) {
    return out_grad.at(0);
}
#endif
/* ===================== AssertEqual ===================== */
...
...
@@ -530,6 +538,7 @@ SymbolVar SetGrad::make(SymbolVar input, const GradGetter& grad_getter,
            input.node(), grad_getter, config);
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(SetGrad) {
    MGB_MARK_USED_VAR(wrt_idx);
    MGB_MARK_USED_VAR(out_grad);
...
...
@@ -538,6 +547,7 @@ MGB_IMPL_OPR_GRAD(SetGrad) {
            "var returned by grad_getter belongs to a different comp graph");
    return grad.node();
}
#endif
/* ===================== InvalidGrad ===================== */
...
...
@@ -690,6 +700,7 @@ VirtualLoss::NodeProp* VirtualLoss::do_make_node_prop() const {
    return ret;
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(VirtualLoss) {
    mgb_assert(out_grad.size() == 1);
    auto mid = opr.input().size() / 2;
...
...
@@ -698,6 +709,7 @@ MGB_IMPL_OPR_GRAD(VirtualLoss) {
    }
    return nullptr;
}
#endif
#else
VarNode* InvalidGrad::make(const OperatorNodeBase&, size_t) {
...
...
src/plugin/impl/opr_footprint.cpp
...
...
@@ -24,6 +24,16 @@
#include "megdnn/opr_param_json.h"
#endif

#include "megbrain/utils/hash_ct.h"
#include "midout.h"

MIDOUT_DECL(megbrain_opr_footprint)

#define MIDOUT_B(...) \
    MIDOUT_BEGIN(megbrain_opr_footprint, __VA_ARGS__) {
#define MIDOUT_E \
    }            \
    MIDOUT_END();

using namespace mgb;

namespace {
...
...
@@ -581,9 +591,12 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::Subtensor>(
template <class OprType>
void OprFootprint::add_single_comp_footprint() {
    MIDOUT_B(OprType,
             midout_iv(MGB_HASH_STR("OprFootprint::add_single_comp_footprint")))
    auto&& record = m_type2comp_footprint.emplace(OprType::typeinfo(),
                                                  opr_footprint_func<OprType>);
    mgb_assert(record.second, "duplicate opr typeinfo");
    MIDOUT_E
}
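MIDOUT_B/MIDOUT_E bracket each instantiation of add_single_comp_footprint in a midout region keyed on OprType plus a compile-time string hash, so midout's trace-then-rebuild flow can compile out instantiations that a given workload never reaches; that is the binary-size mechanism this commit introduces. A self-contained sketch of the bracketing pattern, with stub macros standing in for the real midout.h (whose trace/strip machinery is not reproduced here):

#include <cstdio>

// Stub region macros standing in for midout's MIDOUT_BEGIN/MIDOUT_END; the
// real ones record (or, in a stripped rebuild, eliminate) the enclosed block.
#define MIDOUT_B(...) {
#define MIDOUT_E }

template <class OprType>
void add_single_comp_footprint() {
    MIDOUT_B(OprType)
    // region body: the work midout can strip when this opr is never reached
    std::printf("footprint registered\n");
    MIDOUT_E
}

struct ConvFwd {};  // hypothetical operator tag

int main() {
    add_single_comp_footprint<ConvFwd>();
    return 0;
}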
#if MGB_ENABLE_JSON
...
...