Commit 288c2e08 authored by Megvii Engine Team

fix(mge/autodiff): fix expand_dims and grad rule fallback

GitOrigin-RevId: 4aae771222aa1e1a0d8bebd589f4e32c59044f4c
Parent a5609f3b
@@ -309,6 +309,8 @@ public:
     auto& emplace(Args&&... args) {
         return get()->backward.emplace<T>(std::forward<Args>(args)...);
     }
+
+    void reset() { grad_fn = nullptr; }
 };
 
 apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
@@ -398,7 +400,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
             maker.finalize();
             return ret;
         } catch (GradRuleFallback&) {
-            grad_fn_holder.emplace<std::monostate>();
+            grad_fn_holder.reset();
         }
     }
     return backward_graph_grad_rule(ctx, grad_fn_holder);
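
Review note on the two hunks above: with the old `emplace<std::monostate>()`, a rule that threw `GradRuleFallback` after partially building its grad fn left an empty placeholder backward in `grad_fn_holder`; the new `reset()` discards the half-built grad fn so `backward_graph_grad_rule` starts from a clean holder. A minimal, self-contained Python sketch of this control flow; all names here are illustrative stand-ins, not MegEngine's actual API:

```python
class GradRuleFallback(Exception):
    """Raised by a custom grad rule that declines to handle a case."""

class GradFnHolder:
    def __init__(self):
        self.grad_fn = None

    def emplace(self, backward):
        # Lazily creates/overwrites the backward of the held grad fn.
        self.grad_fn = backward
        return backward

    def reset(self):
        # The fix: drop the half-built grad fn entirely.
        self.grad_fn = None

def apply_grad(op, custom_rules, holder):
    rule = custom_rules.get(op)
    if rule is not None:
        try:
            return rule(holder)
        except GradRuleFallback:
            holder.reset()  # was: holder.emplace(<empty placeholder>)
    # Fall back to the generic backward-graph path.
    return "backward_graph_grad_rule"

def picky_rule(holder):
    holder.emplace(lambda dy: dy)  # starts building a custom backward...
    raise GradRuleFallback()       # ...then bails out partway through

holder = GradFnHolder()
assert apply_grad("Reduce", {"Reduce": picky_rule}, holder) == "backward_graph_grad_rule"
assert holder.grad_fn is None  # nothing stale survives the fallback
```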
@@ -177,11 +177,27 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker)
     throw GradRuleFallback();
 }
 
-template<typename T, typename U>
-apply_result_t axisAddRemove_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
-    auto&& op = ctx.op->cast_final_safe<T>();
+apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+    auto&& op = ctx.op->cast_final_safe<AddAxis>();
     mgb_assert(ctx.nargs == 1);
-    auto&& grad_op = U::make(op.axis);
+    auto&& grad_op = RemoveAxis::make(op.axis);
+    std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
+        mgb_assert(ngrads == 1);
+        Tensor* grad = grads[0];
+        apply_result_t ret(1);
+        ret[0] = python::apply(grad_op_, grad)[0];
+        return ret;
+    });
+    return apply(ctx);
+}
+
+apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+    auto&& op = ctx.op->cast_final_safe<RemoveAxis>();
+    mgb_assert(ctx.nargs == 1);
+    auto&& grad_op = AddAxis::make(op.axis);
+    std::sort(grad_op->axis.begin(), grad_op->axis.end());
     maker.output_size(1).output_captured(0, false);
     maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
         mgb_assert(ngrads == 1);
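
The heart of the fix: the old template reused one rule for both directions and never sorted the axes. The new `std::sort` calls suggest `AddAxis`/`RemoveAxis` apply their axis list one entry at a time, so the gradient op's axes must be ordered so that each step leaves the remaining indices valid: remove from the highest axis down, insert from the lowest axis up. A standalone numpy sketch of that ordering argument (numpy stands in for the MegEngine ops purely for illustration):

```python
import numpy as np

x = np.ones((3, 3, 1, 1), dtype=np.float32)

# RemoveAxis direction: squeeze the highest axis first. Removing axis 3
# leaves axis 2 where it was; the reverse order would not.
y = x
for ax in sorted([2, 3], reverse=True):
    y = np.squeeze(y, axis=ax)
assert y.shape == (3, 3)

# Ascending removal fails: after squeezing axis 2, the original axis 3
# has shifted down to index 2, so axis=3 is out of bounds.
try:
    np.squeeze(np.squeeze(x, axis=2), axis=3)
except Exception as err:  # numpy raises an AxisError here
    print("ascending removal fails:", err)

# AddAxis direction: insert the lowest axis first. Each insertion only
# shifts axes after the insertion point, so ascending order stays valid.
z = np.ones((3, 3), dtype=np.float32)
for ax in sorted([2, 3]):
    z = np.expand_dims(z, axis=ax)
assert z.shape == (3, 3, 1, 1)
```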
@@ -201,8 +217,8 @@ struct Init {
         reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule);
         reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
         reg.emplace(Reduce::typeinfo(), reduce_grad_rule);
-        reg.emplace(AddAxis::typeinfo(), axisAddRemove_grad_rule<AddAxis, RemoveAxis>);
-        reg.emplace(RemoveAxis::typeinfo(), axisAddRemove_grad_rule<RemoveAxis, AddAxis>);
+        reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule);
+        reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule);
     }
 } _;
@@ -335,3 +335,25 @@ def test_Reduce_mean():
     grad(y, F.ones_like(y))
     np.testing.assert_equal(np.ones((3, 3), dtype=np.float32) / 3, x.grad.numpy())
+
+
+def test_addAxis():
+    x_np = np.random.rand(3, 3).astype("float32")
+    x = mge.Tensor(x_np)
+    grad = Grad().wrt(x, callback=save_to(x))
+    y = F.expand_dims(x, [2, 3])
+    grad(y, F.ones_like(y))
+    np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy())
+
+
+def test_removeAxis():
+    x_np = np.random.rand(3, 3, 1, 1).astype("float32")
+    x = mge.Tensor(x_np)
+    grad = Grad().wrt(x, callback=save_to(x))
+    y = F.squeeze(x, [2, 3])
+    grad(y, F.ones_like(y))
+    np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy())
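
As a usage footnote to the tests, here is the forward round trip they exercise, using the same public API (`F.expand_dims` / `F.squeeze` with an axis list); output shapes are noted in comments:

```python
import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.Tensor(np.zeros((3, 3), dtype="float32"))
y = F.expand_dims(x, [2, 3])   # (3, 3) -> (3, 3, 1, 1)
z = F.squeeze(y, [2, 3])       # (3, 3, 1, 1) -> (3, 3)
print(y.numpy().shape, z.numpy().shape)
```

Each op's gradient is computed by applying the other, which is exactly what the new `addAxis_grad_rule` / `removeAxis_grad_rule` pair registers.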