Commit 000517c6 authored by Megvii Engine Team

fix(grad): stop using exception in grad_override

GitOrigin-RevId: 00ae38d48b91e93f6de31d208de7ca469d79a938
Parent 0bb058c1
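For context, a minimal standalone sketch (simplified names, not code from this repository) of the control-flow change this commit makes: a grad rule used to throw GradRuleFallback when it could not handle an op; after this commit it returns std::optional<apply_result_t>, and an empty optional tells the caller to fall back to the generic backward-graph rule.

// Minimal sketch of the fallback-signalling change; apply_result_t and the
// rule/dispatch names below are simplified stand-ins, not the repository's code.
#include <cstdio>
#include <optional>
#include <vector>

using apply_result_t = std::vector<int>;   // stand-in result type

// New style: an empty optional means "this rule declines, use the fallback".
std::optional<apply_result_t> specialized_rule(int mode) {
    if (mode != 0) {
        return {};                         // was: throw GradRuleFallback();
    }
    return apply_result_t{1, 2};           // handled: return the real result
}

apply_result_t generic_fallback_rule() {   // plays the role of backward_graph_grad_rule
    return {0};
}

apply_result_t dispatch(int mode) {
    // std::optional converts to bool in the if-initializer: true iff it holds a value.
    if (auto ret = specialized_rule(mode)) {
        return *ret;                       // unwrap, mirroring `return *ret;` in apply_grad
    }
    return generic_fallback_rule();        // no exception unwinding on the common decline path
}

int main() {
    std::printf("%zu %zu\n", dispatch(0).size(), dispatch(1).size());  // prints: 2 1
}

The hunks below apply exactly this change to apply_grad, to the GradRuleFn signature, and to each rule in grad_override.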
@@ -393,13 +393,11 @@ apply_result_t apply_grad(ApplyContext& ctx) {
         auto&& it = registry.find(ctx.op->dyn_typeinfo());
         if (it != registry.end()) {
             auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx);
-            try {
-                auto ret = it->second(ctx, maker);
+            if (auto ret = it->second(ctx, maker)) {
                 maker.finalize();
-                return ret;
-            } catch (GradRuleFallback&) {
-                grad_fn_holder.reset();
+                return *ret;
             }
+            grad_fn_holder.reset();
         }
         return backward_graph_grad_rule(ctx, grad_fn_holder);
     }();
@@ -16,6 +16,7 @@
 #include <megbrain/utils/small_vector.h>
 #include <memory>
+#include <optional>
 namespace mgb::imperative::python {
@@ -154,7 +155,7 @@ public:
     Maker maker(ApplyContext& ctx) {return {*this, ctx};}
 };
-using GradRuleFn = std::function<apply_result_t(ApplyContext&, CustomBackward::Maker&)>;
+using GradRuleFn = std::function<std::optional<apply_result_t>(ApplyContext&, CustomBackward::Maker&)>;
 std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();
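As a usage note for the declarations in the hunk above, the sketch below imitates the registry pattern with standard C++ only; every name in it (result_t, rule_fn, rule_registry, AddOp) is an illustrative stand-in rather than MegEngine's API, which keys on Typeinfo* and passes an ApplyContext as shown in the first hunk.

// Self-contained sketch of the registry pattern implied by GradRuleFn and
// grad_rule_registry(); all names here are illustrative stand-ins.
#include <functional>
#include <optional>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <vector>

using result_t = std::vector<int>;         // stand-in for apply_result_t
using rule_fn = std::function<std::optional<result_t>(int /*nargs, stand-in for ApplyContext*/)>;

std::unordered_map<std::type_index, rule_fn>& rule_registry() {
    static std::unordered_map<std::type_index, rule_fn> reg;   // keyed by type, like Typeinfo*
    return reg;
}

struct AddOp {};                           // placeholder op type

int main() {
    // Register a rule that may decline by returning an empty optional.
    rule_registry().emplace(typeid(AddOp), [](int nargs) -> std::optional<result_t> {
        if (nargs != 2) return {};         // decline: caller uses the generic rule
        return result_t{nargs};            // handled
    });
    // Lookup mirrors `registry.find(ctx.op->dyn_typeinfo())` in apply_grad.
    auto it = rule_registry().find(typeid(AddOp));
    if (it != rule_registry().end()) {
        if (auto ret = it->second(2)) {    // non-empty: the rule produced a result
            return static_cast<int>(ret->size());
        }
    }
    return 0;                              // fallback path
}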
@@ -37,14 +37,14 @@ std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
 std::shared_ptr<Tensor> make_empty_tensor(CompNode cn, Tensor* shape, DType dtype) {
     HostTensorND scalar{cn, {{1}, dtype}};
-    std:memset(scalar.raw_ptr(), 0, dtype.size());
+    std::memset(scalar.raw_ptr(), 0, dtype.size());
     interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false);
     auto&& t = std::make_shared<Tensor>(handle);
     auto res = broadcast_to(t.get(), shape);
     return res;
 }
-apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto& op = ctx.op->cast_final_safe<Elemwise>();
     if (op.mode == Elemwise::Mode::ADD) {
         mgb_assert(ctx.nargs == 2);
@@ -71,10 +71,10 @@ apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
         });
         return apply(ctx);
     }
-    throw GradRuleFallback();
+    return {};
 }
-apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     mgb_assert(ctx.nargs == 2);
     std::array<std::shared_ptr<Tensor>, 2> input_shapes;
     for (size_t i = 0; i < 2; ++i) {
@@ -100,7 +100,7 @@ apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     return apply(ctx);
 }
-apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<Subtensor>();
     auto&& grad_op = SetSubtensor::make(op.items);
     SmallVector<std::shared_ptr<Tensor>> inputs;
@@ -130,7 +130,7 @@ apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     return apply(ctx);
 }
-apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>();
     auto&& grad_op = IndexingSetMultiAxisVec::make(op.items);
     SmallVector<std::shared_ptr<Tensor>> inputs;
@@ -160,11 +160,11 @@ apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     return apply(ctx);
 }
-apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto& op = ctx.op->cast_final_safe<Reduce>();
     if (op.mode == Reduce::Mode::SUM) {
         if (ctx.nargs != 1) {
-            throw GradRuleFallback();
+            return {};
         }
         std::array<std::shared_ptr<Tensor>, 1> input_shapes;
         if (input_requires_grad(ctx, 0)) {
@@ -182,10 +182,10 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
         });
         return apply(ctx);
     }
-    throw GradRuleFallback();
+    return {};
 }
-apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<AddAxis>();
     mgb_assert(ctx.nargs == 1);
     bool flag = input_requires_grad(ctx, 0);
@@ -204,7 +204,7 @@ apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     return apply(ctx);
 }
-apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<RemoveAxis>();
     mgb_assert(ctx.nargs == 1);
     bool flag = input_requires_grad(ctx, 0);
@@ -223,7 +223,7 @@ apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     return apply(ctx);
 }
-apply_result_t fastpathcopy_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> fastpathcopy_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     mgb_assert(ctx.nargs == 1);
     maker.output_size(1).output_captured(0, false);
     maker.backward([](BackwardContext&, Tensor*const* grads, size_t ngrads) {