diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp
index ba43bd8ec63f86ac226e1c4fd2c4aa3ce5eca0a3..f4546885b1fe3988de39cea207ed01434c811d52 100644
--- a/imperative/python/src/grad.cpp
+++ b/imperative/python/src/grad.cpp
@@ -393,13 +393,11 @@ apply_result_t apply_grad(ApplyContext& ctx) {
         auto&& it = registry.find(ctx.op->dyn_typeinfo());
         if (it != registry.end()) {
             auto&& maker = grad_fn_holder.emplace().maker(ctx);
-            try {
-                auto ret = it->second(ctx, maker);
+            if (auto ret = it->second(ctx, maker)) {
                 maker.finalize();
-                return ret;
-            } catch (GradRuleFallback&) {
-                grad_fn_holder.reset();
+                return *ret;
             }
+            grad_fn_holder.reset();
         }
         return backward_graph_grad_rule(ctx, grad_fn_holder);
     }();
diff --git a/imperative/python/src/grad.h b/imperative/python/src/grad.h
index c5068e207e93241e2614e4c619fccf05e816e80e..4ba63485db3128dc3b3b25565c6d7718f0fe9034 100644
--- a/imperative/python/src/grad.h
+++ b/imperative/python/src/grad.h
@@ -16,6 +16,7 @@
 
 #include
 #include
+#include <optional>
 
 namespace mgb::imperative::python {
 
@@ -154,7 +155,7 @@ public:
     Maker maker(ApplyContext& ctx) {return {*this, ctx};}
 };
 
-using GradRuleFn = std::function<apply_result_t(ApplyContext&, CustomBackward::Maker&)>;
+using GradRuleFn = std::function<std::optional<apply_result_t>(ApplyContext&, CustomBackward::Maker&)>;
 
 std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();
 
diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp
index 2f495b337ab9a7f9f1c8675ff3f61464f4630305..c192241a1cbf350e710bba9ced4ba015a25c94b2 100644
--- a/imperative/python/src/grad_override.cpp
+++ b/imperative/python/src/grad_override.cpp
@@ -37,14 +37,14 @@ std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
 
 std::shared_ptr<Tensor> make_empty_tensor(CompNode cn, Tensor* shape, DType dtype) {
     HostTensorND scalar{cn, {{1}, dtype}};
-    std:memset(scalar.raw_ptr(), 0, dtype.size());
+    std::memset(scalar.raw_ptr(), 0, dtype.size());
     interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false);
     auto&& t = std::make_shared<Tensor>(handle);
     auto res = broadcast_to(t.get(), shape);
     return res;
 }
 
-apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto& op = ctx.op->cast_final_safe<Elemwise>();
     if (op.mode == Elemwise::Mode::ADD) {
         mgb_assert(ctx.nargs == 2);
@@ -71,10 +71,10 @@ apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& make
         });
         return apply(ctx);
     }
-    throw GradRuleFallback();
+    return {};
 }
 
-apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     mgb_assert(ctx.nargs == 2);
     std::array<std::shared_ptr<Tensor>, 2> input_shapes;
     for (size_t i = 0; i < 2; ++i) {
@@ -100,7 +100,7 @@ apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker
     return apply(ctx);
 }
 
-apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<Subtensor>();
     auto&& grad_op = SetSubtensor::make(op.items);
     SmallVector<std::shared_ptr<Tensor>> inputs;
@@ -130,7 +130,7 @@ apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& mak
     return apply(ctx);
 }
 
-apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>();
     auto&& grad_op = IndexingSetMultiAxisVec::make(op.items);
     SmallVector<std::shared_ptr<Tensor>> inputs;
@@ -160,11 +160,11 @@ apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward:
     return apply(ctx);
 }
 
-apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto& op = ctx.op->cast_final_safe<Reduce>();
     if (op.mode == Reduce::Mode::SUM) {
         if (ctx.nargs != 1) {
-            throw GradRuleFallback();
+            return {};
         }
         std::array<std::shared_ptr<Tensor>, 1> input_shapes;
         if (input_requires_grad(ctx, 0)) {
@@ -182,10 +182,10 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker)
         });
         return apply(ctx);
     }
-    throw GradRuleFallback();
+    return {};
 }
 
-apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<AddAxis>();
     mgb_assert(ctx.nargs == 1);
     bool flag = input_requires_grad(ctx, 0);
@@ -204,7 +204,7 @@ apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker
     return apply(ctx);
 }
 
-apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<RemoveAxis>();
     mgb_assert(ctx.nargs == 1);
     bool flag = input_requires_grad(ctx, 0);
@@ -223,7 +223,7 @@ apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& ma
     return apply(ctx);
 }
 
-apply_result_t fastpathcopy_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> fastpathcopy_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     mgb_assert(ctx.nargs == 1);
     maker.output_size(1).output_captured(0, false);
     maker.backward([](BackwardContext&, Tensor*const* grads, size_t ngrads) {
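Note (not part of the patch): the sketch below is a standalone, simplified illustration of the fallback convention this diff introduces. A per-op grad rule now returns std::optional<apply_result_t>, and an empty optional, rather than a thrown GradRuleFallback, tells the dispatcher to use the generic backward-graph rule. All names in the sketch (Result, RuleFn, registry, dispatch, backward_graph_fallback) are hypothetical stand-ins, not MegEngine APIs.

// Standalone sketch with hypothetical names; mirrors the control flow of the
// patched apply_grad, not the real MegEngine types.
#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>

using Result = std::string;                                           // stand-in for apply_result_t
using RuleFn = std::function<std::optional<Result>(int /*nargs*/)>;   // stand-in for GradRuleFn

std::unordered_map<std::string, RuleFn>& registry() {
    static std::unordered_map<std::string, RuleFn> r;                 // stand-in for grad_rule_registry()
    return r;
}

Result backward_graph_fallback(const std::string& op) {
    return "backward_graph(" + op + ")";                              // stand-in for backward_graph_grad_rule
}

Result dispatch(const std::string& op, int nargs) {
    auto it = registry().find(op);
    if (it != registry().end()) {
        // Same shape as the patched apply_grad: the optional doubles as the
        // success flag, so no exception is needed on the fallback path.
        if (auto ret = it->second(nargs)) {
            return *ret;
        }
        // Empty optional: the rule declined (previously: catch (GradRuleFallback&)).
    }
    return backward_graph_fallback(op);
}

int main() {
    // Register a rule that only handles the single-input case, like reduce_grad_rule.
    registry()["reduce_sum"] = [](int nargs) -> std::optional<Result> {
        if (nargs != 1) return {};   // decline; caller falls back
        return "custom_reduce_sum_grad";
    };
    std::cout << dispatch("reduce_sum", 1) << '\n';   // custom rule used
    std::cout << dispatch("reduce_sum", 2) << '\n';   // rule declines, fallback used
    std::cout << dispatch("matmul", 2) << '\n';       // no rule registered, fallback used
}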