diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index e3bc3f13447a1e8b4d8b382a74f9766cb2917544..b14f81d38346252d555102e133898e6921b5f25c 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -760,7 +760,7 @@ VarNode* CollectiveComm::grad(VarNode* out_grad) const { return ModeTrait::from_mode(m_param.mode).grad(out_grad, this); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(CollectiveComm) { mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad"); return opr.grad(out_grad[0]); diff --git a/src/opr-mm/impl/io_remote.cpp b/src/opr-mm/impl/io_remote.cpp index cddd08fcec5cf0274b65888848dc7548e999c052..53e5b1c3bc1282059c1ca6432cc64e7c8e7568b9 100644 --- a/src/opr-mm/impl/io_remote.cpp +++ b/src/opr-mm/impl/io_remote.cpp @@ -119,7 +119,7 @@ cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const { return prop; } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(RemoteSend) { mgb_assert(opr.is_grad()); return RemoteRecv::make(opr.key() + ":grad", diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index 6c126651c0b740a83499cf229521da573bd9f81e..c7916909782f2a4c6ff55177a3d2e0c339c19db9 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -552,7 +552,7 @@ void Elemwise::call_megdnn_opr_exec( opr->exec(inp, out); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Elemwise) { SymbolVar i[5]; SymbolVar i0(opr.input(0)), i1, i2, out(opr.output(0)), @@ -822,7 +822,7 @@ TypeCvt::NodeProp* TypeCvt::do_make_node_prop() const { return ret; } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(TypeCvt) { MGB_MARK_USED_VAR(wrt_idx); auto itype = opr.input(0)->dtype(), otype = opr.output(0)->dtype(); @@ -973,7 +973,7 @@ void AddUpdate::record_execute_deps(ExecDependencyArray& deps) { record_megdnn_opr(deps); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(AddUpdate) { // actually valid, just not implemented return InvalidGrad::make(opr, wrt_idx); @@ -1712,7 +1712,7 @@ void Reduce::create_megdnn_opr() { create_operator()); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Reduce) { for (size_t i = 1; i < opr.output().size(); ++ i) mgb_assert(!out_grad[i]); @@ -1798,7 +1798,7 @@ void PowC::init_output_static_infer_desc() { {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value}); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(PowC) { auto exp = opr.param().exp; return (exp * SymbolVar{out_grad[0]} * diff --git a/src/opr/impl/blas.cpp b/src/opr/impl/blas.cpp index ebc142942a4a4fd7f23af16f3b41744468c41fa1..769216696d0229511a28596a8aa16d43ecebfdda 100644 --- a/src/opr/impl/blas.cpp +++ b/src/opr/impl/blas.cpp @@ -106,7 +106,7 @@ void MatrixMul::scn_do_execute() { MGB_FINALLY({ tparam = this->param(); }); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(MatrixMul) { mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, "only float data type supported for grad"); @@ -226,7 +226,7 @@ void BatchedMatrixMul::scn_do_execute() { MGB_FINALLY({ tparam = this->param(); }); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(BatchedMatrixMul) { mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, "only float data type supported for grad"); @@ -331,7 +331,7 @@ void Dot::add_input_layout_constraint() { input(1)->add_layout_constraint(check); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Dot) { auto other_input = opr.input(wrt_idx == 0 ? 1 : 0); auto ishp0 = opr::GetVarShape::make(opr.input(0)), @@ -357,7 +357,7 @@ void Dot::record_execute_deps(ExecDependencyArray &deps) { MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixInverse); MEGDNN_OPR_INIT1(MatrixInverse, "matrix_inv") -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(MatrixInverse) { SymbolVar a = opr.output(0); // TODO: use unified MatrixMul interface when we have it @@ -395,7 +395,7 @@ SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config) : } } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD namespace { /*! @@ -489,7 +489,7 @@ OP(*, {}, {}) } // anonymous namespace #endif -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(SVD) { /** * The formula is copied from diff --git a/src/opr/impl/cond.cpp b/src/opr/impl/cond.cpp index 03e3dd03c6dd55d8ce04cbf3f6a06c22b405b8b9..66f29b240aee54520ee7cab29174372045f5158f 100644 --- a/src/opr/impl/cond.cpp +++ b/src/opr/impl/cond.cpp @@ -818,7 +818,7 @@ SymbolVar CondExecMark::mark_if_need(SymbolVar maybe_ppv, SymbolVar input, return input; } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(CondExecMark) { if (wrt_idx == opr.input().size() - 1 || !out_grad.at(wrt_idx)) { return nullptr; @@ -1227,7 +1227,7 @@ CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const { return ret; } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(CondExecMerge) { using Mode = CondExecMerge::Param::Mode; if (opr.param().mode == Mode::SUM_COND_OUT && diff --git a/src/opr/impl/dnn/adaptive_pooling.cpp b/src/opr/impl/dnn/adaptive_pooling.cpp index 93d829d3d12dd31761bd69874d08c1fd50308751..14ec425e41ffe79e7464d2afa8fd75bcfe079c3e 100644 --- a/src/opr/impl/dnn/adaptive_pooling.cpp +++ b/src/opr/impl/dnn/adaptive_pooling.cpp @@ -91,7 +91,7 @@ void AdaptivePoolingForward::record_execute_deps(ExecDependencyArray& deps) { record_megdnn_opr(deps); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(AdaptivePoolingForward) { if (wrt_idx == 0) { // wrt src diff --git a/src/opr/impl/dnn/batch_norm.cpp b/src/opr/impl/dnn/batch_norm.cpp index 5d403fbf4ffadc73837e7edd98bb7ac343d923d9..8fb4e3a3d2cb24d59ae82c95f16cd6eb6aecd3c4 100644 --- a/src/opr/impl/dnn/batch_norm.cpp +++ b/src/opr/impl/dnn/batch_norm.cpp @@ -240,7 +240,7 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() { } } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(BatchNormForward) { mgb_assert(opr.param().fwd_mode == BatchNorm::Param::FwdMode::TRAINING, "batch norm could only take grad in training mode"); diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp index f4a2abfb7b7bc2adbac4c8473e64dabaa591afd3..e3469707f541a8179c33ba718db0bcbb83a13706 100644 --- a/src/opr/impl/dnn/convolution.cpp +++ b/src/opr/impl/dnn/convolution.cpp @@ -1012,7 +1012,7 @@ void ConvolutionForward::init_output_dtype() { output(0)->dtype(output_dtype); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ConvolutionForward) { mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, "only float data type supported for grad"); @@ -1175,7 +1175,7 @@ void ConvolutionBackwardData::scn_do_execute() { intl::get_megdnn_workspace_from_var(output(1))); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) { mgb_assert(!out_grad[1]); if (wrt_idx == 0) { @@ -1229,7 +1229,7 @@ size_t ConvolutionBackwardFilter::get_workspace_size_bytes( megdnn_opr(), this); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) { mgb_assert(!out_grad[1]); if (wrt_idx == 0) { @@ -1285,7 +1285,7 @@ void Convolution3DForward::init_output_dtype() { } } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Convolution3DForward) { mgb_assert(opr.param().data_type == Convolution3DForward::Param::DataType::FLOAT, @@ -1380,7 +1380,7 @@ void Convolution3DBackwardData::scn_do_execute() { intl::get_megdnn_workspace_from_var(output(1))); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) { mgb_assert(!out_grad[1]); if (wrt_idx == 0) { @@ -1781,7 +1781,7 @@ size_t LocalShareForward::get_workspace_size_bytes( megdnn_opr(), this); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(LocalShareForward) { mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, "only float data type supported for grad"); @@ -1862,7 +1862,7 @@ void LocalShareBackwardData::scn_do_execute() { intl::get_megdnn_workspace_from_var(output(1))); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(LocalShareBackwardData) { mgb_assert(!out_grad[1]); if (wrt_idx == 0) { @@ -1919,7 +1919,7 @@ size_t LocalShareBackwardFilter::get_workspace_size_bytes( megdnn_opr(), this); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) { mgb_assert(!out_grad[1]); if (wrt_idx == 0) { @@ -1998,7 +1998,7 @@ size_t DeformableConvForward::get_workspace_size_bytes( megdnn_opr(), this); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(DeformableConvForward) { mgb_assert(opr.input(0)->dtype() == dtype::Float32(), "only float data type supported for grad"); diff --git a/src/opr/impl/dnn/images2neibs.cpp b/src/opr/impl/dnn/images2neibs.cpp index 904ba503eb38212c2e66b098fca9392f98a040a3..e3bba1520195b3769bb616b1f23c32989f7a3b62 100644 --- a/src/opr/impl/dnn/images2neibs.cpp +++ b/src/opr/impl/dnn/images2neibs.cpp @@ -20,7 +20,7 @@ using namespace opr; MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsForward); MEGDNN_OPR_INIT1(Images2NeibsForward, "images2neibs") -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Images2NeibsForward) { mgb_assert(wrt_idx == 0 && out_grad.size() == 2 && !out_grad[1]); return Images2NeibsBackward::make( diff --git a/src/opr/impl/dnn/local.cpp b/src/opr/impl/dnn/local.cpp index 44211c4fe53948364cd6cdd3e2f36c9d36636c77..6795ed7ebc28910662235535dc4e7a081d57ca93 100644 --- a/src/opr/impl/dnn/local.cpp +++ b/src/opr/impl/dnn/local.cpp @@ -21,7 +21,7 @@ using namespace opr; MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalForward); MEGDNN_OPR_INIT2(LocalForward, "local") -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(LocalForward) { return intl::conv_grad( opr, wrt_idx, out_grad); @@ -38,7 +38,7 @@ MEGDNN_OPR_INIT3(LocalBackwardFilter, "local_bwd_filter", 2, false); MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalForward); MEGDNN_OPR_INIT2(GroupLocalForward, "glocal") -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(GroupLocalForward) { return intl::conv_grad( opr, wrt_idx, out_grad); diff --git a/src/opr/impl/dnn/lrn.cpp b/src/opr/impl/dnn/lrn.cpp index e192af2ae4c285211e1a834cad492acaf8c3b2a4..26395e64caa688ef13cb904d2e4f89cdf955aa03 100644 --- a/src/opr/impl/dnn/lrn.cpp +++ b/src/opr/impl/dnn/lrn.cpp @@ -20,7 +20,7 @@ using namespace opr; MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNForward); MEGDNN_OPR_INIT1(LRNForward, "lrn") -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(LRNForward) { mgb_assert(wrt_idx == 0); SymbolVar grad = LRNBackward::make( diff --git a/src/opr/impl/dnn/pooling.cpp b/src/opr/impl/dnn/pooling.cpp index 5045b4988f91eaa3f5e75a636f6af015661fdc77..01f3742ae56b0c514acef487389bfe1dd5181355 100644 --- a/src/opr/impl/dnn/pooling.cpp +++ b/src/opr/impl/dnn/pooling.cpp @@ -19,7 +19,7 @@ using namespace opr; MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingForward); MEGDNN_OPR_INIT1(PoolingForward, "pooling") -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(PoolingForward) { mgb_assert(wrt_idx == 0); SymbolVar grad = PoolingBackward::make( diff --git a/src/opr/impl/dnn/roi_align.cpp b/src/opr/impl/dnn/roi_align.cpp index e89ec8dcdd65f1c50d71b3bc54072f86ecc6cd23..5ebf401d2ca1d5586919fa9975214583aeb63971 100644 --- a/src/opr/impl/dnn/roi_align.cpp +++ b/src/opr/impl/dnn/roi_align.cpp @@ -40,7 +40,7 @@ SymbolVar ROIAlignForward::make(SymbolVar src, SymbolVar rois, src.node(), rois.node(), param, config); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ROIAlignForward) { if (wrt_idx == 0) { // wrt src diff --git a/src/opr/impl/dnn/roi_pooling.cpp b/src/opr/impl/dnn/roi_pooling.cpp index 7c2d3df9e25e46717de9608d83780c8817821ad7..90d03092b22e3b4860f119e62f0e330cf21120a4 100644 --- a/src/opr/impl/dnn/roi_pooling.cpp +++ b/src/opr/impl/dnn/roi_pooling.cpp @@ -84,7 +84,7 @@ size_t ROIPoolingForward::get_workspace_size_bytes( input_shapes, output_shapes); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ROIPoolingForward) { if (wrt_idx == 2) { return InvalidGrad::make(opr, wrt_idx); @@ -148,7 +148,7 @@ SymbolVar DeformablePSROIPoolingForward::make( return all[0]; } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) { mgb_assert(wrt_idx <= 2); // wrt_idx = 0 or 1 or 2 diff --git a/src/opr/impl/imgproc.cpp b/src/opr/impl/imgproc.cpp index ed6c3dd529a28b221d4160a7081515d33e1a9b52..695b527dd3826ee4c68d866f27dc5680e8167126 100644 --- a/src/opr/impl/imgproc.cpp +++ b/src/opr/impl/imgproc.cpp @@ -126,7 +126,7 @@ void WarpPerspectiveForward::record_execute_deps(ExecDependencyArray& deps) { record_megdnn_opr(deps); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) { if (opr.input().size() == 4) { if (wrt_idx == 0) { @@ -351,7 +351,7 @@ void ResizeForward::record_execute_deps(ExecDependencyArray& deps) { record_megdnn_opr(deps); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ResizeForward) { mgb_assert(opr.input().size() == 2); if (wrt_idx == 0) { @@ -443,7 +443,7 @@ void RemapForward::init_output_dtype() { output(0)->dtype(input(0)->dtype()); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(RemapForward) { mgb_assert(opr.input().size() == 2); if (wrt_idx == 0) { diff --git a/src/opr/impl/indexing.cpp b/src/opr/impl/indexing.cpp index f7371c6fecc39b542506d285586e6fccc24d227d..f41129cd22ede642842419fd82403bdf0890fc39 100644 --- a/src/opr/impl/indexing.cpp +++ b/src/opr/impl/indexing.cpp @@ -83,7 +83,7 @@ void IndexingOneHot::init_output_dtype() { output(0)->dtype(input(0)->dtype()); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IndexingOneHot) { if (wrt_idx == 0) { return IndexingSetOneHot::make( @@ -135,7 +135,7 @@ void IndexingSetOneHot::scn_do_execute() { intl::get_megdnn_workspace_from_var(output(1))); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IndexingSetOneHot) { SymbolVar index{opr.input(1)}, sub{opr.input(2)}, og{out_grad.at(0)}; if (wrt_idx == 0) { @@ -169,7 +169,7 @@ void IndexingRemap::init_output_dtype() { output(0)->dtype(input(0)->dtype()); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IndexingRemap) { if (wrt_idx == 1) return InvalidGrad::make(opr, wrt_idx); @@ -466,7 +466,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false); -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) { if (wrt_idx) return InvalidGrad::make(opr, wrt_idx); @@ -477,7 +477,7 @@ MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) { } #endif -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) { if (wrt_idx >= 2) return InvalidGrad::make(opr, wrt_idx); @@ -490,7 +490,7 @@ MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) { } #endif -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) { if (wrt_idx >= 2) return InvalidGrad::make(opr, wrt_idx); @@ -510,7 +510,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_GET( BatchedMeshIndexing, "batched_mesh_indexing", false, output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);); -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(MeshIndexing) { if (wrt_idx != 0) { return InvalidGrad::make(opr, wrt_idx); @@ -522,7 +522,7 @@ MGB_IMPL_OPR_GRAD(MeshIndexing) { } #endif -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) { if (wrt_idx != 0) { return InvalidGrad::make(opr, wrt_idx); @@ -539,7 +539,7 @@ MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) { MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrMeshIndexing, "incr_mesh_indexing", false); -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IncrMeshIndexing) { if (wrt_idx > 2) { return opr::InvalidGrad::make(opr, wrt_idx); @@ -553,7 +553,7 @@ MGB_IMPL_OPR_GRAD(IncrMeshIndexing) { MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedIncrMeshIndexing, "batched_incr_mesh_indexing", false); -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) { if (wrt_idx > 2) { return opr::InvalidGrad::make(opr, wrt_idx); @@ -568,7 +568,7 @@ MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) { /* ======================== SetMeshIndexing =========================== */ MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetMeshIndexing, "set_mesh_indexing", false); -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(SetMeshIndexing) { if (wrt_idx >= 2) { return opr::InvalidGrad::make(opr, wrt_idx); @@ -587,7 +587,7 @@ MGB_IMPL_OPR_GRAD(SetMeshIndexing) { MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedSetMeshIndexing, "batched_set_mesh_indexing", false); -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) { if (wrt_idx > 2) { return opr::InvalidGrad::make(opr, wrt_idx); diff --git a/src/opr/impl/io.cpp b/src/opr/impl/io.cpp index bff006e1057dbec873352e837107902898da9e95..a31049394b7064ff9fa2a2ef590f5d0b0d34e109 100644 --- a/src/opr/impl/io.cpp +++ b/src/opr/impl/io.cpp @@ -766,7 +766,7 @@ Copy::NodeProp* Copy::do_make_node_prop() const { return rst; } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Copy) { mgb_assert(wrt_idx == 0); return Copy::make(out_grad[0], diff --git a/src/opr/impl/loop/forward.cpp b/src/opr/impl/loop/forward.cpp index 0c7eee012fd7451678a671012879603549960828..b8593c739c3509ad0fc748abd4f6d425de6785fb 100644 --- a/src/opr/impl/loop/forward.cpp +++ b/src/opr/impl/loop/forward.cpp @@ -268,7 +268,7 @@ VarNode* Loop::grad(Loop &opr, size_t wrt_idx, const VarNodeArray &out_grad) { return gopr->get_grad_var(wrt_idx); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Loop) { return Loop::grad(const_cast(opr), wrt_idx, out_grad); } diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp index 2125f1fab576e3fe87c5bff818440c7a0f2acdb4..c8fc54dd741ed0573324ec5ede50f5fff390d19f 100644 --- a/src/opr/impl/misc.cpp +++ b/src/opr/impl/misc.cpp @@ -48,7 +48,7 @@ namespace intl { /* ================= Argmxx ================= */ -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Argmax) { MGB_MARK_USED_VAR(out_grad); MGB_MARK_USED_VAR(opr); @@ -60,7 +60,7 @@ MGB_IMPL_OPR_GRAD(Argmax) { MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmax); MEGDNN_OPR_INIT1(Argmax, "argmax") -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Argmin) { MGB_MARK_USED_VAR(out_grad); MGB_MARK_USED_VAR(opr); @@ -87,7 +87,7 @@ std::array ArgsortForward::make( return {node->output(0), node->output(1)}; } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ArgsortForward) { mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); if (!out_grad[0]) @@ -112,7 +112,7 @@ Cumsum::Cumsum(VarNode* opr, const Param& param, add_input({opr}, AddInputSortType::CUR_ADDED); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Cumsum) { mgb_assert(out_grad[0] && !out_grad[1]); auto param = opr.param(); @@ -263,7 +263,7 @@ CondTake::CondTake(VarNode *data, VarNode *mask, } } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(CondTake) { mgb_assert(out_grad.size() == 3 && !out_grad[2]); if (wrt_idx == 0 && out_grad[0]) { @@ -413,7 +413,7 @@ void TopK::record_execute_deps(ExecDependencyArray& deps) { record_megdnn_opr(deps); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(TopK) { if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) { mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]); diff --git a/src/opr/impl/muxing.cpp b/src/opr/impl/muxing.cpp index 571dc7d1d9a239465e1e6064b41d5d2d17e1d324..4d8a4a5f5d05d490459890075e1ae6d1bd90b571 100644 --- a/src/opr/impl/muxing.cpp +++ b/src/opr/impl/muxing.cpp @@ -316,7 +316,7 @@ VarNodeArray AllGather::grad(const VarNodeArray &out_grad) { OperatorNodeConfig().comp_node_arr(sp_cn))); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(AllGather) { return const_cast(opr).grad(out_grad); } diff --git a/src/opr/impl/rand.cpp b/src/opr/impl/rand.cpp index b5c16c59f3cc91ad8bd455c7faba08a017f51015..cbbcc64562598fd34f20070de7581952562d650e 100644 --- a/src/opr/impl/rand.cpp +++ b/src/opr/impl/rand.cpp @@ -123,7 +123,7 @@ namespace opr { namespace intl { template class RNGOpr<::megdnn::GaussianRNG>; template class RNGOpr<::megdnn::UniformRNG>; -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD IMPL(GaussianRNG); IMPL(UniformRNG); #endif diff --git a/src/opr/impl/tensor_gen.cpp b/src/opr/impl/tensor_gen.cpp index 24967f83dc847792a2c0ee9acb5c9deabc3f9815..47b102fefea08721837a9d98d72e595dc17cf6e3 100644 --- a/src/opr/impl/tensor_gen.cpp +++ b/src/opr/impl/tensor_gen.cpp @@ -46,7 +46,7 @@ void Alloc::outshape_by_symvar_do_get_output_shape( void Alloc::scn_do_execute() { } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Alloc) { MGB_MARK_USED_VAR(wrt_idx); MGB_MARK_USED_VAR(out_grad); @@ -125,7 +125,7 @@ void Linspace::record_execute_deps(ExecDependencyArray& deps) { std::make_unique(std::move(m_megdnn_opr))); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Linspace) { if (wrt_idx == 2) return InvalidGrad::make(opr, wrt_idx); @@ -199,7 +199,7 @@ void Eye::record_execute_deps(ExecDependencyArray& deps) { std::make_unique(std::move(m_megdnn_opr))); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Eye) { return InvalidGrad::make(opr, wrt_idx); } diff --git a/src/opr/impl/tensor_manip.cpp b/src/opr/impl/tensor_manip.cpp index 250a83827658223847f4758ea309643b013ae1ea..a350bf02a3d9a52979d3775ca75b286df36408e0 100644 --- a/src/opr/impl/tensor_manip.cpp +++ b/src/opr/impl/tensor_manip.cpp @@ -165,7 +165,7 @@ void GetVarShape::init_output_static_infer_desc() { mgr.register_value_infer(output(0), {SourceType::DEP, deps, infer_value}); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(GetVarShape) { MGB_MARK_USED_VAR(wrt_idx); MGB_MARK_USED_VAR(out_grad); @@ -372,7 +372,7 @@ SymbolVar Reshape::make(SymbolVar inp, SymbolVar tshp, inp.node(), tshp.node(), unspec_axis, config); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Reshape) { if (wrt_idx) return InvalidGrad::make(opr, wrt_idx); @@ -441,7 +441,7 @@ SymbolVar Broadcast::make(SymbolVar inp, SymbolVar tshp, inp.node(), tshp.node(), config); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Broadcast) { if (wrt_idx) return InvalidGrad::make(opr, wrt_idx); @@ -586,7 +586,7 @@ VarNode* Dimshuffle::grad( return Dimshuffle::make(out_grad.at(0), back, m_pattern.size()).node(); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Dimshuffle) { return opr.grad(wrt_idx, out_grad); } @@ -649,7 +649,7 @@ TensorLayout AxisAddRemove::axis_manip_get_output_layout( return layout; } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(AxisAddRemove) { MGB_MARK_USED_VAR(wrt_idx); return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node(); @@ -662,7 +662,7 @@ MGB_IMPL_OPR_GRAD(AxisAddRemove) { MGB_IMPL_FANCY_INDEXING_OPR_GET(Subtensor, "subtensor", true); -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Subtensor) { if (wrt_idx) return InvalidGrad::make(opr, wrt_idx); @@ -806,7 +806,7 @@ void SetSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) { sub.copy_from_fixlayout(val); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(SetSubtensor) { if (wrt_idx >= 2) return InvalidGrad::make(opr, wrt_idx); @@ -838,7 +838,7 @@ void IncrSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) { opr->exec(sub.as_megdnn(), val.as_megdnn()); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IncrSubtensor) { if (wrt_idx >= 2) return InvalidGrad::make(opr, wrt_idx); @@ -1112,7 +1112,7 @@ void Split::do_execute(ExecEnv &env) { } } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Split) { if (wrt_idx) return InvalidGrad::make(opr, wrt_idx); @@ -1265,7 +1265,7 @@ SymbolVar Concat::make(const VarNodeArrayView& inp, int axis, axis, config); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Concat) { auto axis = opr.axis(); mgb_assert(out_grad.size() == 1); @@ -1549,7 +1549,7 @@ void ParamPackSplit::scn_do_execute() { mgb_assert(inp_size == m_offsets.back(), "input shape should match offsets"); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ParamPackSplit) { mgb_assert(out_grad.size() == opr.output().size()); SmallVector grad; diff --git a/src/opr/impl/utility.cpp b/src/opr/impl/utility.cpp index e9f1f6e0da70c16c784ad6c7e602e45c4819aebb..bfac02e5f35763201cbd5b24e014dc1fe4c7572c 100644 --- a/src/opr/impl/utility.cpp +++ b/src/opr/impl/utility.cpp @@ -255,7 +255,7 @@ void MarkDynamicVar::scn_do_execute() { o->dev_tensor().copy_from_fixlayout(i->dev_tensor()); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(MarkDynamicVar) { return MarkDynamicVar::make(out_grad.at(0)).node(); } @@ -383,7 +383,7 @@ CallbackInjector::mixin_get_static_infer_desc(OperatorNodeBase &opr) { } } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(CallbackInjector) { MGB_MARK_USED_VAR(wrt_idx); return out_grad.at(0); @@ -408,7 +408,7 @@ SymbolVar MarkNoBroadcastElemwise::make( input.node(), config); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(MarkNoBroadcastElemwise) { return out_grad.at(0); } @@ -435,7 +435,7 @@ SymbolVar Identity::make( return input.insert_single_output_opr(input.node(), config); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Identity) { return out_grad.at(0); } @@ -538,7 +538,7 @@ SymbolVar SetGrad::make(SymbolVar input, const GradGetter& grad_getter, input.node(), grad_getter, config); } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(SetGrad) { MGB_MARK_USED_VAR(wrt_idx); MGB_MARK_USED_VAR(out_grad); @@ -700,7 +700,7 @@ VirtualLoss::NodeProp* VirtualLoss::do_make_node_prop() const { return ret; } -#ifdef MGB_ENABLE_GRAD +#if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(VirtualLoss) { mgb_assert(out_grad.size() == 1); auto mid = opr.input().size() / 2;