commit ca717806
Author: Megvii Engine Team
Date:   Oct 27, 2020
fix(mgb): fix wrong use of macro MGB_ENABLE_GRAD
GitOrigin-RevId: e66dabee019ffc5e59f204e173a5158c61413ba9
Parent: e8bf5bc0
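The bug class being fixed: #ifdef tests only whether a macro is defined, while #if evaluates its value. MGB_ENABLE_GRAD is meant as a 0/1 switch, so a build that defines it as 0 to disable gradients would still compile every gradient implementation guarded by #ifdef MGB_ENABLE_GRAD. A minimal standalone sketch of the difference (FLAG is a hypothetical stand-in for such a 0/1 switch, not MegEngine's actual build configuration):

#include <cstdio>

#define FLAG 0      // feature disabled, but the macro IS defined

#ifdef FLAG         // wrong guard: asks "is FLAG defined?" -- true even for 0
static const char* ifdef_result = "compiled";
#else
static const char* ifdef_result = "excluded";
#endif

#if FLAG            // correct guard: asks "is FLAG nonzero?" -- false here
static const char* if_result = "compiled";
#else
static const char* if_result = "excluded";
#endif

int main() {
    // prints "#ifdef: compiled, #if: excluded" when FLAG is defined as 0
    std::printf("#ifdef: %s, #if: %s\n", ifdef_result, if_result);
    return 0;
}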
Showing 24 changed files with 79 additions and 79 deletions (+79 −79).
src/opr-mm/impl/collective_comm.cpp       +1 −1
src/opr-mm/impl/io_remote.cpp             +1 −1
src/opr/impl/basic_arith.cpp              +5 −5
src/opr/impl/blas.cpp                     +6 −6
src/opr/impl/cond.cpp                     +2 −2
src/opr/impl/dnn/adaptive_pooling.cpp     +1 −1
src/opr/impl/dnn/batch_norm.cpp           +1 −1
src/opr/impl/dnn/convolution.cpp          +9 −9
src/opr/impl/dnn/images2neibs.cpp         +1 −1
src/opr/impl/dnn/local.cpp                +2 −2
src/opr/impl/dnn/lrn.cpp                  +1 −1
src/opr/impl/dnn/pooling.cpp              +1 −1
src/opr/impl/dnn/roi_align.cpp            +1 −1
src/opr/impl/dnn/roi_pooling.cpp          +2 −2
src/opr/impl/imgproc.cpp                  +3 −3
src/opr/impl/indexing.cpp                 +12 −12
src/opr/impl/io.cpp                       +1 −1
src/opr/impl/loop/forward.cpp             +1 −1
src/opr/impl/misc.cpp                     +6 −6
src/opr/impl/muxing.cpp                   +1 −1
src/opr/impl/rand.cpp                     +1 −1
src/opr/impl/tensor_gen.cpp               +3 −3
src/opr/impl/tensor_manip.cpp             +11 −11
src/opr/impl/utility.cpp                  +6 −6
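Every hunk below makes the same one-line substitution in the guard around a gradient registration. The recurring shape, with SomeOpr as a placeholder operator name and the body elided:

-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(SomeOpr) {
     ...
 }
 #endif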
src/opr-mm/impl/collective_comm.cpp
@@ -760,7 +760,7 @@ VarNode* CollectiveComm::grad(VarNode* out_grad) const {
     return ModeTrait::from_mode(m_param.mode).grad(out_grad, this);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CollectiveComm) {
     mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad");
     return opr.grad(out_grad[0]);
src/opr-mm/impl/io_remote.cpp
@@ -119,7 +119,7 @@ cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const {
     return prop;
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(RemoteSend) {
     mgb_assert(opr.is_grad());
     return RemoteRecv::make(opr.key() + ":grad",
src/opr/impl/basic_arith.cpp
@@ -552,7 +552,7 @@ void Elemwise::call_megdnn_opr_exec(
     opr->exec(inp, out);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Elemwise) {
     SymbolVar i[5];
     SymbolVar i0(opr.input(0)), i1, i2, out(opr.output(0)),
@@ -822,7 +822,7 @@ TypeCvt::NodeProp* TypeCvt::do_make_node_prop() const {
     return ret;
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(TypeCvt) {
     MGB_MARK_USED_VAR(wrt_idx);
     auto itype = opr.input(0)->dtype(), otype = opr.output(0)->dtype();
@@ -973,7 +973,7 @@ void AddUpdate::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(AddUpdate) {
     // actually valid, just not implemented
     return InvalidGrad::make(opr, wrt_idx);
@@ -1712,7 +1712,7 @@ void Reduce::create_megdnn_opr() {
             create_operator<megdnn::Reduce>());
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Reduce) {
     for (size_t i = 1; i < opr.output().size(); ++i)
         mgb_assert(!out_grad[i]);
@@ -1798,7 +1798,7 @@ void PowC::init_output_static_infer_desc() {
             {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(PowC) {
     auto exp = opr.param().exp;
     return (exp * SymbolVar{out_grad[0]} *
src/opr/impl/blas.cpp
@@ -106,7 +106,7 @@ void MatrixMul::scn_do_execute() {
     MGB_FINALLY({ tparam = this->param(); });
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(MatrixMul) {
     mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
                "only float data type supported for grad");
@@ -226,7 +226,7 @@ void BatchedMatrixMul::scn_do_execute() {
     MGB_FINALLY({ tparam = this->param(); });
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
     mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
                "only float data type supported for grad");
@@ -331,7 +331,7 @@ void Dot::add_input_layout_constraint() {
     input(1)->add_layout_constraint(check);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Dot) {
     auto other_input = opr.input(wrt_idx == 0 ? 1 : 0);
     auto ishp0 = opr::GetVarShape::make(opr.input(0)),
@@ -357,7 +357,7 @@ void Dot::record_execute_deps(ExecDependencyArray &deps) {
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixInverse);
 MEGDNN_OPR_INIT1(MatrixInverse, "matrix_inv")
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(MatrixInverse) {
     SymbolVar a = opr.output(0);
     // TODO: use unified MatrixMul interface when we have it
@@ -395,7 +395,7 @@ SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config) :
     }
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 namespace {
 /*!
@@ -489,7 +489,7 @@ OP(*, {}, {})
 }  // anonymous namespace
 #endif
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(SVD) {
     /**
      * The formula is copied from
src/opr/impl/cond.cpp
@@ -818,7 +818,7 @@ SymbolVar CondExecMark::mark_if_need(SymbolVar maybe_ppv, SymbolVar input,
     return input;
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CondExecMark) {
     if (wrt_idx == opr.input().size() - 1 || !out_grad.at(wrt_idx)) {
         return nullptr;
@@ -1227,7 +1227,7 @@ CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const {
     return ret;
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CondExecMerge) {
     using Mode = CondExecMerge::Param::Mode;
     if (opr.param().mode == Mode::SUM_COND_OUT &&
src/opr/impl/dnn/adaptive_pooling.cpp
@@ -91,7 +91,7 @@ void AdaptivePoolingForward::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(AdaptivePoolingForward) {
     if (wrt_idx == 0) {
         // wrt src
src/opr/impl/dnn/batch_norm.cpp
@@ -240,7 +240,7 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() {
     }
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchNormForward) {
     mgb_assert(opr.param().fwd_mode == BatchNorm::Param::FwdMode::TRAINING,
                "batch norm could only take grad in training mode");
src/opr/impl/dnn/convolution.cpp
@@ -1012,7 +1012,7 @@ void ConvolutionForward::init_output_dtype() {
     output(0)->dtype(output_dtype);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ConvolutionForward) {
     mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
                "only float data type supported for grad");
@@ -1175,7 +1175,7 @@ void ConvolutionBackwardData::scn_do_execute() {
             intl::get_megdnn_workspace_from_var(output(1)));
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) {
     mgb_assert(!out_grad[1]);
     if (wrt_idx == 0) {
@@ -1229,7 +1229,7 @@ size_t ConvolutionBackwardFilter::get_workspace_size_bytes(
             megdnn_opr(), this);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) {
     mgb_assert(!out_grad[1]);
     if (wrt_idx == 0) {
@@ -1285,7 +1285,7 @@ void Convolution3DForward::init_output_dtype() {
     }
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Convolution3DForward) {
     mgb_assert(opr.param().data_type ==
                        Convolution3DForward::Param::DataType::FLOAT,
@@ -1380,7 +1380,7 @@ void Convolution3DBackwardData::scn_do_execute() {
             intl::get_megdnn_workspace_from_var(output(1)));
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) {
     mgb_assert(!out_grad[1]);
     if (wrt_idx == 0) {
@@ -1781,7 +1781,7 @@ size_t LocalShareForward::get_workspace_size_bytes(
             megdnn_opr(), this);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(LocalShareForward) {
     mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
                "only float data type supported for grad");
@@ -1862,7 +1862,7 @@ void LocalShareBackwardData::scn_do_execute() {
             intl::get_megdnn_workspace_from_var(output(1)));
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(LocalShareBackwardData) {
     mgb_assert(!out_grad[1]);
     if (wrt_idx == 0) {
@@ -1919,7 +1919,7 @@ size_t LocalShareBackwardFilter::get_workspace_size_bytes(
             megdnn_opr(), this);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) {
     mgb_assert(!out_grad[1]);
     if (wrt_idx == 0) {
@@ -1998,7 +1998,7 @@ size_t DeformableConvForward::get_workspace_size_bytes(
             megdnn_opr(), this);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(DeformableConvForward) {
     mgb_assert(opr.input(0)->dtype() == dtype::Float32(),
                "only float data type supported for grad");
src/opr/impl/dnn/images2neibs.cpp
@@ -20,7 +20,7 @@ using namespace opr;
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsForward);
 MEGDNN_OPR_INIT1(Images2NeibsForward, "images2neibs")
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Images2NeibsForward) {
     mgb_assert(wrt_idx == 0 && out_grad.size() == 2 && !out_grad[1]);
     return Images2NeibsBackward::make(
src/opr/impl/dnn/local.cpp
@@ -21,7 +21,7 @@ using namespace opr;
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalForward);
 MEGDNN_OPR_INIT2(LocalForward, "local")
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(LocalForward) {
     return intl::conv_grad<LocalBackwardData, LocalBackwardFilter>(
             opr, wrt_idx, out_grad);
@@ -38,7 +38,7 @@ MEGDNN_OPR_INIT3(LocalBackwardFilter, "local_bwd_filter", 2, false);
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalForward);
 MEGDNN_OPR_INIT2(GroupLocalForward, "glocal")
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(GroupLocalForward) {
     return intl::conv_grad<GroupLocalBackwardData, GroupLocalBackwardFilter>(
             opr, wrt_idx, out_grad);
src/opr/impl/dnn/lrn.cpp
@@ -20,7 +20,7 @@ using namespace opr;
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNForward);
 MEGDNN_OPR_INIT1(LRNForward, "lrn")
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(LRNForward) {
     mgb_assert(wrt_idx == 0);
     SymbolVar grad = LRNBackward::make(
src/opr/impl/dnn/pooling.cpp
@@ -19,7 +19,7 @@ using namespace opr;
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingForward);
 MEGDNN_OPR_INIT1(PoolingForward, "pooling")
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(PoolingForward) {
     mgb_assert(wrt_idx == 0);
     SymbolVar grad = PoolingBackward::make(
src/opr/impl/dnn/roi_align.cpp
@@ -40,7 +40,7 @@ SymbolVar ROIAlignForward::make(SymbolVar src, SymbolVar rois,
             src.node(), rois.node(), param, config);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ROIAlignForward) {
     if (wrt_idx == 0) {
         // wrt src
src/opr/impl/dnn/roi_pooling.cpp
@@ -84,7 +84,7 @@ size_t ROIPoolingForward::get_workspace_size_bytes(
             input_shapes, output_shapes);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ROIPoolingForward) {
     if (wrt_idx == 2) {
         return InvalidGrad::make(opr, wrt_idx);
@@ -148,7 +148,7 @@ SymbolVar DeformablePSROIPoolingForward::make(
     return all[0];
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) {
     mgb_assert(wrt_idx <= 2);
     // wrt_idx = 0 or 1 or 2
src/opr/impl/imgproc.cpp
@@ -126,7 +126,7 @@ void WarpPerspectiveForward::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) {
     if (opr.input().size() == 4) {
         if (wrt_idx == 0) {
@@ -351,7 +351,7 @@ void ResizeForward::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ResizeForward) {
     mgb_assert(opr.input().size() == 2);
     if (wrt_idx == 0) {
@@ -443,7 +443,7 @@ void RemapForward::init_output_dtype() {
     output(0)->dtype(input(0)->dtype());
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(RemapForward) {
     mgb_assert(opr.input().size() == 2);
     if (wrt_idx == 0) {
src/opr/impl/indexing.cpp
@@ -83,7 +83,7 @@ void IndexingOneHot::init_output_dtype() {
     output(0)->dtype(input(0)->dtype());
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingOneHot) {
     if (wrt_idx == 0) {
         return IndexingSetOneHot::make(
@@ -135,7 +135,7 @@ void IndexingSetOneHot::scn_do_execute() {
             intl::get_megdnn_workspace_from_var(output(1)));
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingSetOneHot) {
     SymbolVar index{opr.input(1)}, sub{opr.input(2)}, og{out_grad.at(0)};
     if (wrt_idx == 0) {
@@ -169,7 +169,7 @@ void IndexingRemap::init_output_dtype() {
     output(0)->dtype(input(0)->dtype());
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingRemap) {
     if (wrt_idx == 1)
         return InvalidGrad::make(opr, wrt_idx);
@@ -466,7 +466,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
 MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
         IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false);
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) {
     if (wrt_idx)
         return InvalidGrad::make(opr, wrt_idx);
@@ -477,7 +477,7 @@ MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) {
 }
 #endif
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) {
     if (wrt_idx >= 2)
         return InvalidGrad::make(opr, wrt_idx);
@@ -490,7 +490,7 @@ MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) {
 }
 #endif
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) {
     if (wrt_idx >= 2)
         return InvalidGrad::make(opr, wrt_idx);
@@ -510,7 +510,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_GET(
         BatchedMeshIndexing, "batched_mesh_indexing", false,
         output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(MeshIndexing) {
     if (wrt_idx != 0) {
         return InvalidGrad::make(opr, wrt_idx);
@@ -522,7 +522,7 @@ MGB_IMPL_OPR_GRAD(MeshIndexing) {
 }
 #endif
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) {
     if (wrt_idx != 0) {
         return InvalidGrad::make(opr, wrt_idx);
@@ -539,7 +539,7 @@ MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) {
 MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrMeshIndexing, "incr_mesh_indexing",
                                    false);
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IncrMeshIndexing) {
     if (wrt_idx > 2) {
         return opr::InvalidGrad::make(opr, wrt_idx);
@@ -553,7 +553,7 @@ MGB_IMPL_OPR_GRAD(IncrMeshIndexing) {
 MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedIncrMeshIndexing,
                                    "batched_incr_mesh_indexing", false);
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) {
     if (wrt_idx > 2) {
         return opr::InvalidGrad::make(opr, wrt_idx);
@@ -568,7 +568,7 @@ MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) {
 /* ======================== SetMeshIndexing =========================== */
 MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetMeshIndexing, "set_mesh_indexing",
                                    false);
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(SetMeshIndexing) {
     if (wrt_idx >= 2) {
         return opr::InvalidGrad::make(opr, wrt_idx);
@@ -587,7 +587,7 @@ MGB_IMPL_OPR_GRAD(SetMeshIndexing) {
 MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedSetMeshIndexing,
                                    "batched_set_mesh_indexing", false);
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) {
     if (wrt_idx > 2) {
         return opr::InvalidGrad::make(opr, wrt_idx);
src/opr/impl/io.cpp
@@ -766,7 +766,7 @@ Copy::NodeProp* Copy::do_make_node_prop() const {
     return rst;
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Copy) {
     mgb_assert(wrt_idx == 0);
     return Copy::make(out_grad[0],
src/opr/impl/loop/forward.cpp
@@ -268,7 +268,7 @@ VarNode* Loop::grad(Loop &opr, size_t wrt_idx, const VarNodeArray &out_grad) {
     return gopr->get_grad_var(wrt_idx);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Loop) {
     return Loop::grad(const_cast<Loop&>(opr), wrt_idx, out_grad);
 }
src/opr/impl/misc.cpp
@@ -48,7 +48,7 @@ namespace intl {
 
 /* ================= Argmxx ================= */
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Argmax) {
     MGB_MARK_USED_VAR(out_grad);
     MGB_MARK_USED_VAR(opr);
@@ -60,7 +60,7 @@ MGB_IMPL_OPR_GRAD(Argmax) {
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmax);
 MEGDNN_OPR_INIT1(Argmax, "argmax")
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Argmin) {
     MGB_MARK_USED_VAR(out_grad);
     MGB_MARK_USED_VAR(opr);
@@ -87,7 +87,7 @@ std::array<SymbolVar, 2> ArgsortForward::make(
     return {node->output(0), node->output(1)};
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ArgsortForward) {
     mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);
     if (!out_grad[0])
@@ -112,7 +112,7 @@ Cumsum::Cumsum(VarNode* opr, const Param& param,
     add_input({opr}, AddInputSortType::CUR_ADDED);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Cumsum) {
     mgb_assert(out_grad[0] && !out_grad[1]);
     auto param = opr.param();
@@ -263,7 +263,7 @@ CondTake::CondTake(VarNode *data, VarNode *mask,
     }
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CondTake) {
     mgb_assert(out_grad.size() == 3 && !out_grad[2]);
     if (wrt_idx == 0 && out_grad[0]) {
@@ -413,7 +413,7 @@ void TopK::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(TopK) {
     if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) {
         mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]);
src/opr/impl/muxing.cpp
@@ -316,7 +316,7 @@ VarNodeArray AllGather::grad(const VarNodeArray &out_grad) {
             OperatorNodeConfig().comp_node_arr(sp_cn)));
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(AllGather) {
     return const_cast<AllGather&>(opr).grad(out_grad);
 }
src/opr/impl/rand.cpp
@@ -123,7 +123,7 @@ namespace opr {
 namespace intl {
 template class RNGOpr<::megdnn::GaussianRNG>;
 template class RNGOpr<::megdnn::UniformRNG>;
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 IMPL(GaussianRNG);
 IMPL(UniformRNG);
 #endif
src/opr/impl/tensor_gen.cpp
@@ -46,7 +46,7 @@ void Alloc::outshape_by_symvar_do_get_output_shape(
 void Alloc::scn_do_execute() {
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Alloc) {
     MGB_MARK_USED_VAR(wrt_idx);
     MGB_MARK_USED_VAR(out_grad);
@@ -125,7 +125,7 @@ void Linspace::record_execute_deps(ExecDependencyArray& deps) {
             std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr)));
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Linspace) {
     if (wrt_idx == 2)
         return InvalidGrad::make(opr, wrt_idx);
@@ -199,7 +199,7 @@ void Eye::record_execute_deps(ExecDependencyArray& deps) {
             std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr)));
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Eye) {
     return InvalidGrad::make(opr, wrt_idx);
 }
src/opr/impl/tensor_manip.cpp
@@ -165,7 +165,7 @@ void GetVarShape::init_output_static_infer_desc() {
     mgr.register_value_infer(output(0), {SourceType::DEP, deps, infer_value});
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(GetVarShape) {
     MGB_MARK_USED_VAR(wrt_idx);
     MGB_MARK_USED_VAR(out_grad);
@@ -372,7 +372,7 @@ SymbolVar Reshape::make(SymbolVar inp, SymbolVar tshp,
             inp.node(), tshp.node(), unspec_axis, config);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Reshape) {
     if (wrt_idx)
         return InvalidGrad::make(opr, wrt_idx);
@@ -441,7 +441,7 @@ SymbolVar Broadcast::make(SymbolVar inp, SymbolVar tshp,
             inp.node(), tshp.node(), config);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Broadcast) {
     if (wrt_idx)
         return InvalidGrad::make(opr, wrt_idx);
@@ -586,7 +586,7 @@ VarNode* Dimshuffle::grad(
     return Dimshuffle::make(out_grad.at(0), back, m_pattern.size()).node();
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Dimshuffle) {
     return opr.grad(wrt_idx, out_grad);
 }
@@ -649,7 +649,7 @@ TensorLayout AxisAddRemove::axis_manip_get_output_layout(
     return layout;
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(AxisAddRemove) {
     MGB_MARK_USED_VAR(wrt_idx);
     return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node();
@@ -662,7 +662,7 @@ MGB_IMPL_OPR_GRAD(AxisAddRemove) {
 MGB_IMPL_FANCY_INDEXING_OPR_GET(Subtensor, "subtensor", true);
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Subtensor) {
     if (wrt_idx)
         return InvalidGrad::make(opr, wrt_idx);
@@ -806,7 +806,7 @@ void SetSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) {
     sub.copy_from_fixlayout(val);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(SetSubtensor) {
     if (wrt_idx >= 2)
         return InvalidGrad::make(opr, wrt_idx);
@@ -838,7 +838,7 @@ void IncrSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) {
     opr->exec(sub.as_megdnn(), val.as_megdnn());
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IncrSubtensor) {
     if (wrt_idx >= 2)
         return InvalidGrad::make(opr, wrt_idx);
@@ -1112,7 +1112,7 @@ void Split::do_execute(ExecEnv &env) {
     }
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Split) {
     if (wrt_idx)
         return InvalidGrad::make(opr, wrt_idx);
@@ -1265,7 +1265,7 @@ SymbolVar Concat::make(const VarNodeArrayView& inp, int axis,
             axis, config);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Concat) {
     auto axis = opr.axis();
     mgb_assert(out_grad.size() == 1);
@@ -1549,7 +1549,7 @@ void ParamPackSplit::scn_do_execute() {
     mgb_assert(inp_size == m_offsets.back(), "input shape should match offsets");
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ParamPackSplit) {
     mgb_assert(out_grad.size() == opr.output().size());
     SmallVector<SymbolVar> grad;
src/opr/impl/utility.cpp
@@ -255,7 +255,7 @@ void MarkDynamicVar::scn_do_execute() {
     o->dev_tensor().copy_from_fixlayout(i->dev_tensor());
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(MarkDynamicVar) {
     return MarkDynamicVar::make(out_grad.at(0)).node();
 }
@@ -383,7 +383,7 @@ CallbackInjector::mixin_get_static_infer_desc(OperatorNodeBase &opr) {
     }
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CallbackInjector) {
     MGB_MARK_USED_VAR(wrt_idx);
     return out_grad.at(0);
@@ -408,7 +408,7 @@ SymbolVar MarkNoBroadcastElemwise::make(
             input.node(), config);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(MarkNoBroadcastElemwise) {
     return out_grad.at(0);
 }
@@ -435,7 +435,7 @@ SymbolVar Identity::make(
     return input.insert_single_output_opr<Identity>(input.node(), config);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Identity) {
     return out_grad.at(0);
 }
@@ -538,7 +538,7 @@ SymbolVar SetGrad::make(SymbolVar input, const GradGetter& grad_getter,
             input.node(), grad_getter, config);
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(SetGrad) {
     MGB_MARK_USED_VAR(wrt_idx);
     MGB_MARK_USED_VAR(out_grad);
@@ -700,7 +700,7 @@ VirtualLoss::NodeProp* VirtualLoss::do_make_node_prop() const {
     return ret;
 }
 
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(VirtualLoss) {
     mgb_assert(out_grad.size() == 1);
     auto mid = opr.input().size() / 2;