Commit ca717806 authored by Megvii Engine Team

fix(mgb): fix wrong use of macro MGB_ENABLE_GRAD

GitOrigin-RevId: e66dabee019ffc5e59f204e173a5158c61413ba9
Parent e8bf5bc0
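Why the change matters: switching `#ifdef` to `#if` only makes a difference when the macro is always defined and carries a 0/1 value, which is what this fix implies for MGB_ENABLE_GRAD. With `#ifdef`, the guarded gradient code is compiled in even when the flag is defined to 0. A minimal standalone sketch of the difference (plain C++, not MegEngine code; the hard-coded value is for illustration only):

#include <cstdio>

#define MGB_ENABLE_GRAD 0  // gradients disabled, but the macro IS defined

int main() {
#ifdef MGB_ENABLE_GRAD
    // Taken even though grad support is off: #ifdef only checks whether
    // the name is defined, not what it expands to.
    std::printf("#ifdef branch: compiled in\n");
#endif

#if MGB_ENABLE_GRAD
    // Correctly skipped: #if evaluates the macro's value (0 here).
    std::printf("#if branch: compiled in\n");
#endif
    return 0;
}

With the value flipped to 1 both branches compile in, which is why a wrong `#ifdef` test only shows up in builds that disable gradient support.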
@@ -760,7 +760,7 @@ VarNode* CollectiveComm::grad(VarNode* out_grad) const {
     return ModeTrait::from_mode(m_param.mode).grad(out_grad, this);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CollectiveComm) {
     mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad");
     return opr.grad(out_grad[0]);
...
@@ -119,7 +119,7 @@ cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const {
     return prop;
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(RemoteSend) {
     mgb_assert(opr.is_grad());
     return RemoteRecv::make(opr.key() + ":grad",
...
@@ -552,7 +552,7 @@ void Elemwise::call_megdnn_opr_exec(
     opr->exec(inp, out);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Elemwise) {
     SymbolVar i[5];
     SymbolVar i0(opr.input(0)), i1, i2, out(opr.output(0)),
@@ -822,7 +822,7 @@ TypeCvt::NodeProp* TypeCvt::do_make_node_prop() const {
     return ret;
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(TypeCvt) {
     MGB_MARK_USED_VAR(wrt_idx);
     auto itype = opr.input(0)->dtype(), otype = opr.output(0)->dtype();
@@ -973,7 +973,7 @@ void AddUpdate::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(AddUpdate) {
     // actually valid, just not implemented
     return InvalidGrad::make(opr, wrt_idx);
@@ -1712,7 +1712,7 @@ void Reduce::create_megdnn_opr() {
             create_operator<megdnn::Reduce>());
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Reduce) {
     for (size_t i = 1; i < opr.output().size(); ++ i)
         mgb_assert(!out_grad[i]);
@@ -1798,7 +1798,7 @@ void PowC::init_output_static_infer_desc() {
             {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(PowC) {
     auto exp = opr.param().exp;
     return (exp * SymbolVar{out_grad[0]} *
...
@@ -106,7 +106,7 @@ void MatrixMul::scn_do_execute() {
     MGB_FINALLY({ tparam = this->param(); });
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(MatrixMul) {
     mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
             "only float data type supported for grad");
@@ -226,7 +226,7 @@ void BatchedMatrixMul::scn_do_execute() {
     MGB_FINALLY({ tparam = this->param(); });
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
     mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
             "only float data type supported for grad");
@@ -331,7 +331,7 @@ void Dot::add_input_layout_constraint() {
     input(1)->add_layout_constraint(check);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Dot) {
     auto other_input = opr.input(wrt_idx == 0 ? 1 : 0);
     auto ishp0 = opr::GetVarShape::make(opr.input(0)),
@@ -357,7 +357,7 @@ void Dot::record_execute_deps(ExecDependencyArray &deps) {
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixInverse);
 MEGDNN_OPR_INIT1(MatrixInverse, "matrix_inv")
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(MatrixInverse) {
     SymbolVar a = opr.output(0);
     // TODO: use unified MatrixMul interface when we have it
@@ -395,7 +395,7 @@ SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config) :
     }
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 namespace {
 /*!
@@ -489,7 +489,7 @@ OP(*, {}, {})
 } // anonymous namespace
 #endif
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(SVD) {
     /**
      * The formula is copied from
...
@@ -818,7 +818,7 @@ SymbolVar CondExecMark::mark_if_need(SymbolVar maybe_ppv, SymbolVar input,
     return input;
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CondExecMark) {
     if (wrt_idx == opr.input().size() - 1 || !out_grad.at(wrt_idx)) {
         return nullptr;
@@ -1227,7 +1227,7 @@ CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const {
     return ret;
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CondExecMerge) {
     using Mode = CondExecMerge::Param::Mode;
     if (opr.param().mode == Mode::SUM_COND_OUT &&
...
@@ -91,7 +91,7 @@ void AdaptivePoolingForward::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(AdaptivePoolingForward) {
     if (wrt_idx == 0) {
         // wrt src
...
@@ -240,7 +240,7 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() {
     }
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchNormForward) {
     mgb_assert(opr.param().fwd_mode == BatchNorm::Param::FwdMode::TRAINING,
             "batch norm could only take grad in training mode");
...
@@ -1012,7 +1012,7 @@ void ConvolutionForward::init_output_dtype() {
     output(0)->dtype(output_dtype);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ConvolutionForward) {
     mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
             "only float data type supported for grad");
@@ -1175,7 +1175,7 @@ void ConvolutionBackwardData::scn_do_execute() {
             intl::get_megdnn_workspace_from_var(output(1)));
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) {
     mgb_assert(!out_grad[1]);
     if (wrt_idx == 0) {
@@ -1229,7 +1229,7 @@ size_t ConvolutionBackwardFilter::get_workspace_size_bytes(
             megdnn_opr(), this);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) {
     mgb_assert(!out_grad[1]);
     if (wrt_idx == 0) {
@@ -1285,7 +1285,7 @@ void Convolution3DForward::init_output_dtype() {
     }
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Convolution3DForward) {
     mgb_assert(opr.param().data_type ==
                Convolution3DForward::Param::DataType::FLOAT,
@@ -1380,7 +1380,7 @@ void Convolution3DBackwardData::scn_do_execute() {
             intl::get_megdnn_workspace_from_var(output(1)));
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) {
     mgb_assert(!out_grad[1]);
     if (wrt_idx == 0) {
@@ -1781,7 +1781,7 @@ size_t LocalShareForward::get_workspace_size_bytes(
             megdnn_opr(), this);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(LocalShareForward) {
     mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
             "only float data type supported for grad");
@@ -1862,7 +1862,7 @@ void LocalShareBackwardData::scn_do_execute() {
             intl::get_megdnn_workspace_from_var(output(1)));
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(LocalShareBackwardData) {
     mgb_assert(!out_grad[1]);
     if (wrt_idx == 0) {
@@ -1919,7 +1919,7 @@ size_t LocalShareBackwardFilter::get_workspace_size_bytes(
             megdnn_opr(), this);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) {
     mgb_assert(!out_grad[1]);
     if (wrt_idx == 0) {
@@ -1998,7 +1998,7 @@ size_t DeformableConvForward::get_workspace_size_bytes(
             megdnn_opr(), this);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(DeformableConvForward) {
     mgb_assert(opr.input(0)->dtype() == dtype::Float32(),
             "only float data type supported for grad");
...
@@ -20,7 +20,7 @@ using namespace opr;
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsForward);
 MEGDNN_OPR_INIT1(Images2NeibsForward, "images2neibs")
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Images2NeibsForward) {
     mgb_assert(wrt_idx == 0 && out_grad.size() == 2 && !out_grad[1]);
     return Images2NeibsBackward::make(
...
@@ -21,7 +21,7 @@ using namespace opr;
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalForward);
 MEGDNN_OPR_INIT2(LocalForward, "local")
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(LocalForward) {
     return intl::conv_grad<LocalBackwardData, LocalBackwardFilter>(
             opr, wrt_idx, out_grad);
@@ -38,7 +38,7 @@ MEGDNN_OPR_INIT3(LocalBackwardFilter, "local_bwd_filter", 2, false);
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalForward);
 MEGDNN_OPR_INIT2(GroupLocalForward, "glocal")
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(GroupLocalForward) {
     return intl::conv_grad<GroupLocalBackwardData, GroupLocalBackwardFilter>(
             opr, wrt_idx, out_grad);
...
@@ -20,7 +20,7 @@ using namespace opr;
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNForward);
 MEGDNN_OPR_INIT1(LRNForward, "lrn")
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(LRNForward) {
     mgb_assert(wrt_idx == 0);
     SymbolVar grad = LRNBackward::make(
...
@@ -19,7 +19,7 @@ using namespace opr;
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingForward);
 MEGDNN_OPR_INIT1(PoolingForward, "pooling")
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(PoolingForward) {
     mgb_assert(wrt_idx == 0);
     SymbolVar grad = PoolingBackward::make(
...
@@ -40,7 +40,7 @@ SymbolVar ROIAlignForward::make(SymbolVar src, SymbolVar rois,
             src.node(), rois.node(), param, config);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ROIAlignForward) {
     if (wrt_idx == 0) {
         // wrt src
...
@@ -84,7 +84,7 @@ size_t ROIPoolingForward::get_workspace_size_bytes(
             input_shapes, output_shapes);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ROIPoolingForward) {
     if (wrt_idx == 2) {
         return InvalidGrad::make(opr, wrt_idx);
@@ -148,7 +148,7 @@ SymbolVar DeformablePSROIPoolingForward::make(
     return all[0];
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) {
     mgb_assert(wrt_idx <= 2); // wrt_idx = 0 or 1 or 2
...
@@ -126,7 +126,7 @@ void WarpPerspectiveForward::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) {
     if (opr.input().size() == 4) {
         if (wrt_idx == 0) {
@@ -351,7 +351,7 @@ void ResizeForward::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ResizeForward) {
     mgb_assert(opr.input().size() == 2);
     if (wrt_idx == 0) {
@@ -443,7 +443,7 @@ void RemapForward::init_output_dtype() {
     output(0)->dtype(input(0)->dtype());
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(RemapForward) {
     mgb_assert(opr.input().size() == 2);
     if (wrt_idx == 0) {
...
@@ -83,7 +83,7 @@ void IndexingOneHot::init_output_dtype() {
     output(0)->dtype(input(0)->dtype());
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingOneHot) {
     if (wrt_idx == 0) {
         return IndexingSetOneHot::make(
@@ -135,7 +135,7 @@ void IndexingSetOneHot::scn_do_execute() {
             intl::get_megdnn_workspace_from_var(output(1)));
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingSetOneHot) {
     SymbolVar index{opr.input(1)}, sub{opr.input(2)}, og{out_grad.at(0)};
     if (wrt_idx == 0) {
@@ -169,7 +169,7 @@ void IndexingRemap::init_output_dtype() {
     output(0)->dtype(input(0)->dtype());
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingRemap) {
     if (wrt_idx == 1)
         return InvalidGrad::make(opr, wrt_idx);
@@ -466,7 +466,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
 MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
         IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false);
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) {
     if (wrt_idx)
         return InvalidGrad::make(opr, wrt_idx);
@@ -477,7 +477,7 @@ MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) {
 }
 #endif
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) {
     if (wrt_idx >= 2)
         return InvalidGrad::make(opr, wrt_idx);
@@ -490,7 +490,7 @@ MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) {
 }
 #endif
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) {
     if (wrt_idx >= 2)
         return InvalidGrad::make(opr, wrt_idx);
@@ -510,7 +510,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_GET(
         BatchedMeshIndexing, "batched_mesh_indexing", false,
         output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(MeshIndexing) {
     if (wrt_idx != 0) {
         return InvalidGrad::make(opr, wrt_idx);
@@ -522,7 +522,7 @@ MGB_IMPL_OPR_GRAD(MeshIndexing) {
 }
 #endif
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) {
     if (wrt_idx != 0) {
         return InvalidGrad::make(opr, wrt_idx);
@@ -539,7 +539,7 @@ MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) {
 MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrMeshIndexing, "incr_mesh_indexing",
                                    false);
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IncrMeshIndexing) {
     if (wrt_idx > 2) {
         return opr::InvalidGrad::make(opr, wrt_idx);
@@ -553,7 +553,7 @@ MGB_IMPL_OPR_GRAD(IncrMeshIndexing) {
 MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedIncrMeshIndexing,
                                    "batched_incr_mesh_indexing", false);
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) {
     if (wrt_idx > 2) {
         return opr::InvalidGrad::make(opr, wrt_idx);
@@ -568,7 +568,7 @@ MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) {
 /* ======================== SetMeshIndexing =========================== */
 MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetMeshIndexing, "set_mesh_indexing", false);
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(SetMeshIndexing) {
     if (wrt_idx >= 2) {
         return opr::InvalidGrad::make(opr, wrt_idx);
@@ -587,7 +587,7 @@ MGB_IMPL_OPR_GRAD(SetMeshIndexing) {
 MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedSetMeshIndexing,
                                    "batched_set_mesh_indexing", false);
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) {
     if (wrt_idx > 2) {
         return opr::InvalidGrad::make(opr, wrt_idx);
...
@@ -766,7 +766,7 @@ Copy::NodeProp* Copy::do_make_node_prop() const {
     return rst;
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Copy) {
     mgb_assert(wrt_idx == 0);
     return Copy::make(out_grad[0],
...
@@ -268,7 +268,7 @@ VarNode* Loop::grad(Loop &opr, size_t wrt_idx, const VarNodeArray &out_grad) {
     return gopr->get_grad_var(wrt_idx);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Loop) {
     return Loop::grad(const_cast<Loop&>(opr), wrt_idx, out_grad);
 }
...
@@ -48,7 +48,7 @@ namespace intl {
 /* ================= Argmxx ================= */
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Argmax) {
     MGB_MARK_USED_VAR(out_grad);
     MGB_MARK_USED_VAR(opr);
@@ -60,7 +60,7 @@ MGB_IMPL_OPR_GRAD(Argmax) {
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmax);
 MEGDNN_OPR_INIT1(Argmax, "argmax")
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Argmin) {
     MGB_MARK_USED_VAR(out_grad);
     MGB_MARK_USED_VAR(opr);
@@ -87,7 +87,7 @@ std::array<SymbolVar, 2> ArgsortForward::make(
     return {node->output(0), node->output(1)};
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ArgsortForward) {
     mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);
     if (!out_grad[0])
@@ -112,7 +112,7 @@ Cumsum::Cumsum(VarNode* opr, const Param& param,
     add_input({opr}, AddInputSortType::CUR_ADDED);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Cumsum) {
     mgb_assert(out_grad[0] && !out_grad[1]);
     auto param = opr.param();
@@ -263,7 +263,7 @@ CondTake::CondTake(VarNode *data, VarNode *mask,
     }
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CondTake) {
     mgb_assert(out_grad.size() == 3 && !out_grad[2]);
     if (wrt_idx == 0 && out_grad[0]) {
@@ -413,7 +413,7 @@ void TopK::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(TopK) {
     if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) {
         mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]);
...
@@ -316,7 +316,7 @@ VarNodeArray AllGather::grad(const VarNodeArray &out_grad) {
             OperatorNodeConfig().comp_node_arr(sp_cn)));
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(AllGather) {
     return const_cast<AllGather&>(opr).grad(out_grad);
 }
...
@@ -123,7 +123,7 @@ namespace opr {
 namespace intl {
 template class RNGOpr<::megdnn::GaussianRNG>;
 template class RNGOpr<::megdnn::UniformRNG>;
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 IMPL(GaussianRNG);
 IMPL(UniformRNG);
 #endif
...
@@ -46,7 +46,7 @@ void Alloc::outshape_by_symvar_do_get_output_shape(
 void Alloc::scn_do_execute() {
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Alloc) {
     MGB_MARK_USED_VAR(wrt_idx);
     MGB_MARK_USED_VAR(out_grad);
@@ -125,7 +125,7 @@ void Linspace::record_execute_deps(ExecDependencyArray& deps) {
             std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr)));
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Linspace) {
     if (wrt_idx == 2)
         return InvalidGrad::make(opr, wrt_idx);
@@ -199,7 +199,7 @@ void Eye::record_execute_deps(ExecDependencyArray& deps) {
             std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr)));
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Eye) {
     return InvalidGrad::make(opr, wrt_idx);
 }
...
@@ -165,7 +165,7 @@ void GetVarShape::init_output_static_infer_desc() {
     mgr.register_value_infer(output(0),
             {SourceType::DEP, deps, infer_value});
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(GetVarShape) {
     MGB_MARK_USED_VAR(wrt_idx);
     MGB_MARK_USED_VAR(out_grad);
@@ -372,7 +372,7 @@ SymbolVar Reshape::make(SymbolVar inp, SymbolVar tshp,
             inp.node(), tshp.node(), unspec_axis, config);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Reshape) {
     if (wrt_idx)
         return InvalidGrad::make(opr, wrt_idx);
@@ -441,7 +441,7 @@ SymbolVar Broadcast::make(SymbolVar inp, SymbolVar tshp,
             inp.node(), tshp.node(), config);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Broadcast) {
     if (wrt_idx)
         return InvalidGrad::make(opr, wrt_idx);
@@ -586,7 +586,7 @@ VarNode* Dimshuffle::grad(
     return Dimshuffle::make(out_grad.at(0), back, m_pattern.size()).node();
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Dimshuffle) {
     return opr.grad(wrt_idx, out_grad);
 }
@@ -649,7 +649,7 @@ TensorLayout AxisAddRemove::axis_manip_get_output_layout(
     return layout;
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(AxisAddRemove) {
     MGB_MARK_USED_VAR(wrt_idx);
     return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node();
@@ -662,7 +662,7 @@ MGB_IMPL_OPR_GRAD(AxisAddRemove) {
 MGB_IMPL_FANCY_INDEXING_OPR_GET(Subtensor, "subtensor", true);
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Subtensor) {
     if (wrt_idx)
         return InvalidGrad::make(opr, wrt_idx);
@@ -806,7 +806,7 @@ void SetSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) {
     sub.copy_from_fixlayout(val);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(SetSubtensor) {
     if (wrt_idx >= 2)
         return InvalidGrad::make(opr, wrt_idx);
@@ -838,7 +838,7 @@ void IncrSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) {
     opr->exec(sub.as_megdnn(), val.as_megdnn());
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(IncrSubtensor) {
     if (wrt_idx >= 2)
         return InvalidGrad::make(opr, wrt_idx);
@@ -1112,7 +1112,7 @@ void Split::do_execute(ExecEnv &env) {
     }
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Split) {
     if (wrt_idx)
         return InvalidGrad::make(opr, wrt_idx);
@@ -1265,7 +1265,7 @@ SymbolVar Concat::make(const VarNodeArrayView& inp, int axis,
             axis, config);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Concat) {
     auto axis = opr.axis();
     mgb_assert(out_grad.size() == 1);
@@ -1549,7 +1549,7 @@ void ParamPackSplit::scn_do_execute() {
     mgb_assert(inp_size == m_offsets.back(), "input shape should match offsets");
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ParamPackSplit) {
     mgb_assert(out_grad.size() == opr.output().size());
     SmallVector<SymbolVar> grad;
...
@@ -255,7 +255,7 @@ void MarkDynamicVar::scn_do_execute() {
     o->dev_tensor().copy_from_fixlayout(i->dev_tensor());
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(MarkDynamicVar) {
     return MarkDynamicVar::make(out_grad.at(0)).node();
 }
@@ -383,7 +383,7 @@ CallbackInjector::mixin_get_static_infer_desc(OperatorNodeBase &opr) {
     }
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CallbackInjector) {
     MGB_MARK_USED_VAR(wrt_idx);
     return out_grad.at(0);
@@ -408,7 +408,7 @@ SymbolVar MarkNoBroadcastElemwise::make(
             input.node(), config);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(MarkNoBroadcastElemwise) {
     return out_grad.at(0);
 }
@@ -435,7 +435,7 @@ SymbolVar Identity::make(
     return input.insert_single_output_opr<Identity>(input.node(), config);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Identity) {
     return out_grad.at(0);
 }
@@ -538,7 +538,7 @@ SymbolVar SetGrad::make(SymbolVar input, const GradGetter& grad_getter,
             input.node(), grad_getter, config);
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(SetGrad) {
     MGB_MARK_USED_VAR(wrt_idx);
     MGB_MARK_USED_VAR(out_grad);
@@ -700,7 +700,7 @@ VirtualLoss::NodeProp* VirtualLoss::do_make_node_prop() const {
     return ret;
 }
-#ifdef MGB_ENABLE_GRAD
+#if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(VirtualLoss) {
     mgb_assert(out_grad.size() == 1);
     auto mid = opr.input().size() / 2;
...