提交 ca717806 编写于 作者: M Megvii Engine Team

fix(mgb): fix wrong use of macro MGB_ENABLE_GRAD

GitOrigin-RevId: e66dabee019ffc5e59f204e173a5158c61413ba9
上级 e8bf5bc0
......@@ -760,7 +760,7 @@ VarNode* CollectiveComm::grad(VarNode* out_grad) const {
return ModeTrait::from_mode(m_param.mode).grad(out_grad, this);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(CollectiveComm) {
mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad");
return opr.grad(out_grad[0]);
......
......@@ -119,7 +119,7 @@ cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const {
return prop;
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(RemoteSend) {
mgb_assert(opr.is_grad());
return RemoteRecv::make(opr.key() + ":grad",
......
......@@ -552,7 +552,7 @@ void Elemwise::call_megdnn_opr_exec(
opr->exec(inp, out);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Elemwise) {
SymbolVar i[5];
SymbolVar i0(opr.input(0)), i1, i2, out(opr.output(0)),
......@@ -822,7 +822,7 @@ TypeCvt::NodeProp* TypeCvt::do_make_node_prop() const {
return ret;
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(TypeCvt) {
MGB_MARK_USED_VAR(wrt_idx);
auto itype = opr.input(0)->dtype(), otype = opr.output(0)->dtype();
......@@ -973,7 +973,7 @@ void AddUpdate::record_execute_deps(ExecDependencyArray& deps) {
record_megdnn_opr(deps);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(AddUpdate) {
// actually valid, just not implemented
return InvalidGrad::make(opr, wrt_idx);
......@@ -1712,7 +1712,7 @@ void Reduce::create_megdnn_opr() {
create_operator<megdnn::Reduce>());
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Reduce) {
for (size_t i = 1; i < opr.output().size(); ++ i)
mgb_assert(!out_grad[i]);
......@@ -1798,7 +1798,7 @@ void PowC::init_output_static_infer_desc() {
{SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(PowC) {
auto exp = opr.param().exp;
return (exp * SymbolVar{out_grad[0]} *
......
......@@ -106,7 +106,7 @@ void MatrixMul::scn_do_execute() {
MGB_FINALLY({ tparam = this->param(); });
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MatrixMul) {
mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
"only float data type supported for grad");
......@@ -226,7 +226,7 @@ void BatchedMatrixMul::scn_do_execute() {
MGB_FINALLY({ tparam = this->param(); });
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
"only float data type supported for grad");
......@@ -331,7 +331,7 @@ void Dot::add_input_layout_constraint() {
input(1)->add_layout_constraint(check);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Dot) {
auto other_input = opr.input(wrt_idx == 0 ? 1 : 0);
auto ishp0 = opr::GetVarShape::make(opr.input(0)),
......@@ -357,7 +357,7 @@ void Dot::record_execute_deps(ExecDependencyArray &deps) {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixInverse);
MEGDNN_OPR_INIT1(MatrixInverse, "matrix_inv")
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MatrixInverse) {
SymbolVar a = opr.output(0);
// TODO: use unified MatrixMul interface when we have it
......@@ -395,7 +395,7 @@ SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config) :
}
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
namespace {
/*!
......@@ -489,7 +489,7 @@ OP(*, {}, {})
} // anonymous namespace
#endif
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(SVD) {
/**
* The formula is copied from
......
......@@ -818,7 +818,7 @@ SymbolVar CondExecMark::mark_if_need(SymbolVar maybe_ppv, SymbolVar input,
return input;
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(CondExecMark) {
if (wrt_idx == opr.input().size() - 1 || !out_grad.at(wrt_idx)) {
return nullptr;
......@@ -1227,7 +1227,7 @@ CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const {
return ret;
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(CondExecMerge) {
using Mode = CondExecMerge::Param::Mode;
if (opr.param().mode == Mode::SUM_COND_OUT &&
......
......@@ -91,7 +91,7 @@ void AdaptivePoolingForward::record_execute_deps(ExecDependencyArray& deps) {
record_megdnn_opr(deps);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(AdaptivePoolingForward) {
if (wrt_idx == 0) {
// wrt src
......
......@@ -240,7 +240,7 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() {
}
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(BatchNormForward) {
mgb_assert(opr.param().fwd_mode == BatchNorm::Param::FwdMode::TRAINING,
"batch norm could only take grad in training mode");
......
......@@ -1012,7 +1012,7 @@ void ConvolutionForward::init_output_dtype() {
output(0)->dtype(output_dtype);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ConvolutionForward) {
mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
"only float data type supported for grad");
......@@ -1175,7 +1175,7 @@ void ConvolutionBackwardData::scn_do_execute() {
intl::get_megdnn_workspace_from_var(output(1)));
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) {
mgb_assert(!out_grad[1]);
if (wrt_idx == 0) {
......@@ -1229,7 +1229,7 @@ size_t ConvolutionBackwardFilter::get_workspace_size_bytes(
megdnn_opr(), this);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) {
mgb_assert(!out_grad[1]);
if (wrt_idx == 0) {
......@@ -1285,7 +1285,7 @@ void Convolution3DForward::init_output_dtype() {
}
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Convolution3DForward) {
mgb_assert(opr.param().data_type ==
Convolution3DForward::Param::DataType::FLOAT,
......@@ -1380,7 +1380,7 @@ void Convolution3DBackwardData::scn_do_execute() {
intl::get_megdnn_workspace_from_var(output(1)));
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) {
mgb_assert(!out_grad[1]);
if (wrt_idx == 0) {
......@@ -1781,7 +1781,7 @@ size_t LocalShareForward::get_workspace_size_bytes(
megdnn_opr(), this);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LocalShareForward) {
mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
"only float data type supported for grad");
......@@ -1862,7 +1862,7 @@ void LocalShareBackwardData::scn_do_execute() {
intl::get_megdnn_workspace_from_var(output(1)));
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LocalShareBackwardData) {
mgb_assert(!out_grad[1]);
if (wrt_idx == 0) {
......@@ -1919,7 +1919,7 @@ size_t LocalShareBackwardFilter::get_workspace_size_bytes(
megdnn_opr(), this);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) {
mgb_assert(!out_grad[1]);
if (wrt_idx == 0) {
......@@ -1998,7 +1998,7 @@ size_t DeformableConvForward::get_workspace_size_bytes(
megdnn_opr(), this);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(DeformableConvForward) {
mgb_assert(opr.input(0)->dtype() == dtype::Float32(),
"only float data type supported for grad");
......
......@@ -20,7 +20,7 @@ using namespace opr;
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsForward);
MEGDNN_OPR_INIT1(Images2NeibsForward, "images2neibs")
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Images2NeibsForward) {
mgb_assert(wrt_idx == 0 && out_grad.size() == 2 && !out_grad[1]);
return Images2NeibsBackward::make(
......
......@@ -21,7 +21,7 @@ using namespace opr;
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalForward);
MEGDNN_OPR_INIT2(LocalForward, "local")
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LocalForward) {
return intl::conv_grad<LocalBackwardData, LocalBackwardFilter>(
opr, wrt_idx, out_grad);
......@@ -38,7 +38,7 @@ MEGDNN_OPR_INIT3(LocalBackwardFilter, "local_bwd_filter", 2, false);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalForward);
MEGDNN_OPR_INIT2(GroupLocalForward, "glocal")
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(GroupLocalForward) {
return intl::conv_grad<GroupLocalBackwardData, GroupLocalBackwardFilter>(
opr, wrt_idx, out_grad);
......
......@@ -20,7 +20,7 @@ using namespace opr;
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNForward);
MEGDNN_OPR_INIT1(LRNForward, "lrn")
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LRNForward) {
mgb_assert(wrt_idx == 0);
SymbolVar grad = LRNBackward::make(
......
......@@ -19,7 +19,7 @@ using namespace opr;
MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingForward);
MEGDNN_OPR_INIT1(PoolingForward, "pooling")
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(PoolingForward) {
mgb_assert(wrt_idx == 0);
SymbolVar grad = PoolingBackward::make(
......
......@@ -40,7 +40,7 @@ SymbolVar ROIAlignForward::make(SymbolVar src, SymbolVar rois,
src.node(), rois.node(), param, config);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ROIAlignForward) {
if (wrt_idx == 0) {
// wrt src
......
......@@ -84,7 +84,7 @@ size_t ROIPoolingForward::get_workspace_size_bytes(
input_shapes, output_shapes);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ROIPoolingForward) {
if (wrt_idx == 2) {
return InvalidGrad::make(opr, wrt_idx);
......@@ -148,7 +148,7 @@ SymbolVar DeformablePSROIPoolingForward::make(
return all[0];
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) {
mgb_assert(wrt_idx <= 2); // wrt_idx = 0 or 1 or 2
......
......@@ -126,7 +126,7 @@ void WarpPerspectiveForward::record_execute_deps(ExecDependencyArray& deps) {
record_megdnn_opr(deps);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) {
if (opr.input().size() == 4) {
if (wrt_idx == 0) {
......@@ -351,7 +351,7 @@ void ResizeForward::record_execute_deps(ExecDependencyArray& deps) {
record_megdnn_opr(deps);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ResizeForward) {
mgb_assert(opr.input().size() == 2);
if (wrt_idx == 0) {
......@@ -443,7 +443,7 @@ void RemapForward::init_output_dtype() {
output(0)->dtype(input(0)->dtype());
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(RemapForward) {
mgb_assert(opr.input().size() == 2);
if (wrt_idx == 0) {
......
......@@ -83,7 +83,7 @@ void IndexingOneHot::init_output_dtype() {
output(0)->dtype(input(0)->dtype());
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IndexingOneHot) {
if (wrt_idx == 0) {
return IndexingSetOneHot::make(
......@@ -135,7 +135,7 @@ void IndexingSetOneHot::scn_do_execute() {
intl::get_megdnn_workspace_from_var(output(1)));
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IndexingSetOneHot) {
SymbolVar index{opr.input(1)}, sub{opr.input(2)}, og{out_grad.at(0)};
if (wrt_idx == 0) {
......@@ -169,7 +169,7 @@ void IndexingRemap::init_output_dtype() {
output(0)->dtype(input(0)->dtype());
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IndexingRemap) {
if (wrt_idx == 1)
return InvalidGrad::make(opr, wrt_idx);
......@@ -466,7 +466,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false);
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) {
if (wrt_idx)
return InvalidGrad::make(opr, wrt_idx);
......@@ -477,7 +477,7 @@ MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) {
}
#endif
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) {
if (wrt_idx >= 2)
return InvalidGrad::make(opr, wrt_idx);
......@@ -490,7 +490,7 @@ MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) {
}
#endif
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) {
if (wrt_idx >= 2)
return InvalidGrad::make(opr, wrt_idx);
......@@ -510,7 +510,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_GET(
BatchedMeshIndexing, "batched_mesh_indexing", false,
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MeshIndexing) {
if (wrt_idx != 0) {
return InvalidGrad::make(opr, wrt_idx);
......@@ -522,7 +522,7 @@ MGB_IMPL_OPR_GRAD(MeshIndexing) {
}
#endif
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) {
if (wrt_idx != 0) {
return InvalidGrad::make(opr, wrt_idx);
......@@ -539,7 +539,7 @@ MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) {
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrMeshIndexing, "incr_mesh_indexing",
false);
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IncrMeshIndexing) {
if (wrt_idx > 2) {
return opr::InvalidGrad::make(opr, wrt_idx);
......@@ -553,7 +553,7 @@ MGB_IMPL_OPR_GRAD(IncrMeshIndexing) {
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedIncrMeshIndexing,
"batched_incr_mesh_indexing", false);
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) {
if (wrt_idx > 2) {
return opr::InvalidGrad::make(opr, wrt_idx);
......@@ -568,7 +568,7 @@ MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) {
/* ======================== SetMeshIndexing =========================== */
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetMeshIndexing, "set_mesh_indexing", false);
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(SetMeshIndexing) {
if (wrt_idx >= 2) {
return opr::InvalidGrad::make(opr, wrt_idx);
......@@ -587,7 +587,7 @@ MGB_IMPL_OPR_GRAD(SetMeshIndexing) {
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedSetMeshIndexing,
"batched_set_mesh_indexing", false);
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) {
if (wrt_idx > 2) {
return opr::InvalidGrad::make(opr, wrt_idx);
......
......@@ -766,7 +766,7 @@ Copy::NodeProp* Copy::do_make_node_prop() const {
return rst;
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Copy) {
mgb_assert(wrt_idx == 0);
return Copy::make(out_grad[0],
......
......@@ -268,7 +268,7 @@ VarNode* Loop::grad(Loop &opr, size_t wrt_idx, const VarNodeArray &out_grad) {
return gopr->get_grad_var(wrt_idx);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Loop) {
return Loop::grad(const_cast<Loop&>(opr), wrt_idx, out_grad);
}
......
......@@ -48,7 +48,7 @@ namespace intl {
/* ================= Argmxx ================= */
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Argmax) {
MGB_MARK_USED_VAR(out_grad);
MGB_MARK_USED_VAR(opr);
......@@ -60,7 +60,7 @@ MGB_IMPL_OPR_GRAD(Argmax) {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmax);
MEGDNN_OPR_INIT1(Argmax, "argmax")
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Argmin) {
MGB_MARK_USED_VAR(out_grad);
MGB_MARK_USED_VAR(opr);
......@@ -87,7 +87,7 @@ std::array<SymbolVar, 2> ArgsortForward::make(
return {node->output(0), node->output(1)};
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ArgsortForward) {
mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);
if (!out_grad[0])
......@@ -112,7 +112,7 @@ Cumsum::Cumsum(VarNode* opr, const Param& param,
add_input({opr}, AddInputSortType::CUR_ADDED);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Cumsum) {
mgb_assert(out_grad[0] && !out_grad[1]);
auto param = opr.param();
......@@ -263,7 +263,7 @@ CondTake::CondTake(VarNode *data, VarNode *mask,
}
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(CondTake) {
mgb_assert(out_grad.size() == 3 && !out_grad[2]);
if (wrt_idx == 0 && out_grad[0]) {
......@@ -413,7 +413,7 @@ void TopK::record_execute_deps(ExecDependencyArray& deps) {
record_megdnn_opr(deps);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(TopK) {
if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) {
mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]);
......
......@@ -316,7 +316,7 @@ VarNodeArray AllGather::grad(const VarNodeArray &out_grad) {
OperatorNodeConfig().comp_node_arr(sp_cn)));
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(AllGather) {
return const_cast<AllGather&>(opr).grad(out_grad);
}
......
......@@ -123,7 +123,7 @@ namespace opr {
namespace intl {
template class RNGOpr<::megdnn::GaussianRNG>;
template class RNGOpr<::megdnn::UniformRNG>;
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
IMPL(GaussianRNG);
IMPL(UniformRNG);
#endif
......
......@@ -46,7 +46,7 @@ void Alloc::outshape_by_symvar_do_get_output_shape(
void Alloc::scn_do_execute() {
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Alloc) {
MGB_MARK_USED_VAR(wrt_idx);
MGB_MARK_USED_VAR(out_grad);
......@@ -125,7 +125,7 @@ void Linspace::record_execute_deps(ExecDependencyArray& deps) {
std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr)));
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Linspace) {
if (wrt_idx == 2)
return InvalidGrad::make(opr, wrt_idx);
......@@ -199,7 +199,7 @@ void Eye::record_execute_deps(ExecDependencyArray& deps) {
std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr)));
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Eye) {
return InvalidGrad::make(opr, wrt_idx);
}
......
......@@ -165,7 +165,7 @@ void GetVarShape::init_output_static_infer_desc() {
mgr.register_value_infer(output(0),
{SourceType::DEP, deps, infer_value});
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(GetVarShape) {
MGB_MARK_USED_VAR(wrt_idx);
MGB_MARK_USED_VAR(out_grad);
......@@ -372,7 +372,7 @@ SymbolVar Reshape::make(SymbolVar inp, SymbolVar tshp,
inp.node(), tshp.node(), unspec_axis, config);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Reshape) {
if (wrt_idx)
return InvalidGrad::make(opr, wrt_idx);
......@@ -441,7 +441,7 @@ SymbolVar Broadcast::make(SymbolVar inp, SymbolVar tshp,
inp.node(), tshp.node(), config);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Broadcast) {
if (wrt_idx)
return InvalidGrad::make(opr, wrt_idx);
......@@ -586,7 +586,7 @@ VarNode* Dimshuffle::grad(
return Dimshuffle::make(out_grad.at(0), back, m_pattern.size()).node();
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Dimshuffle) {
return opr.grad(wrt_idx, out_grad);
}
......@@ -649,7 +649,7 @@ TensorLayout AxisAddRemove::axis_manip_get_output_layout(
return layout;
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(AxisAddRemove) {
MGB_MARK_USED_VAR(wrt_idx);
return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node();
......@@ -662,7 +662,7 @@ MGB_IMPL_OPR_GRAD(AxisAddRemove) {
MGB_IMPL_FANCY_INDEXING_OPR_GET(Subtensor, "subtensor", true);
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Subtensor) {
if (wrt_idx)
return InvalidGrad::make(opr, wrt_idx);
......@@ -806,7 +806,7 @@ void SetSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) {
sub.copy_from_fixlayout(val);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(SetSubtensor) {
if (wrt_idx >= 2)
return InvalidGrad::make(opr, wrt_idx);
......@@ -838,7 +838,7 @@ void IncrSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) {
opr->exec(sub.as_megdnn(), val.as_megdnn());
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IncrSubtensor) {
if (wrt_idx >= 2)
return InvalidGrad::make(opr, wrt_idx);
......@@ -1112,7 +1112,7 @@ void Split::do_execute(ExecEnv &env) {
}
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Split) {
if (wrt_idx)
return InvalidGrad::make(opr, wrt_idx);
......@@ -1265,7 +1265,7 @@ SymbolVar Concat::make(const VarNodeArrayView& inp, int axis,
axis, config);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Concat) {
auto axis = opr.axis();
mgb_assert(out_grad.size() == 1);
......@@ -1549,7 +1549,7 @@ void ParamPackSplit::scn_do_execute() {
mgb_assert(inp_size == m_offsets.back(), "input shape should match offsets");
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ParamPackSplit) {
mgb_assert(out_grad.size() == opr.output().size());
SmallVector<SymbolVar> grad;
......
......@@ -255,7 +255,7 @@ void MarkDynamicVar::scn_do_execute() {
o->dev_tensor().copy_from_fixlayout(i->dev_tensor());
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MarkDynamicVar) {
return MarkDynamicVar::make(out_grad.at(0)).node();
}
......@@ -383,7 +383,7 @@ CallbackInjector::mixin_get_static_infer_desc(OperatorNodeBase &opr) {
}
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(CallbackInjector) {
MGB_MARK_USED_VAR(wrt_idx);
return out_grad.at(0);
......@@ -408,7 +408,7 @@ SymbolVar MarkNoBroadcastElemwise::make(
input.node(), config);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MarkNoBroadcastElemwise) {
return out_grad.at(0);
}
......@@ -435,7 +435,7 @@ SymbolVar Identity::make(
return input.insert_single_output_opr<Identity>(input.node(), config);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Identity) {
return out_grad.at(0);
}
......@@ -538,7 +538,7 @@ SymbolVar SetGrad::make(SymbolVar input, const GradGetter& grad_getter,
input.node(), grad_getter, config);
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(SetGrad) {
MGB_MARK_USED_VAR(wrt_idx);
MGB_MARK_USED_VAR(out_grad);
......@@ -700,7 +700,7 @@ VirtualLoss::NodeProp* VirtualLoss::do_make_node_prop() const {
return ret;
}
#ifdef MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(VirtualLoss) {
mgb_assert(out_grad.size() == 1);
auto mid = opr.input().size() / 2;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册