From 288c2e08b52f215eada60e91d79469bb097e6071 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Mon, 4 Jan 2021 16:44:44 +0800
Subject: [PATCH] fix(mge/autodiff): fix expand_dims and grad rule fallback

GitOrigin-RevId: 4aae771222aa1e1a0d8bebd589f4e32c59044f4c
---
 imperative/python/src/grad.cpp              |  4 ++-
 imperative/python/src/grad_override.cpp     | 28 +++++++++++++++----
 .../python/test/unit/core/test_autodiff.py  | 22 +++++++++++++++
 3 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp
index ba340ab3..f4ae8cfc 100644
--- a/imperative/python/src/grad.cpp
+++ b/imperative/python/src/grad.cpp
@@ -309,6 +309,8 @@ public:
     auto& emplace(Args&&... args) {
         return get()->backward.emplace<T>(std::forward<Args>(args)...);
     }
+
+    void reset() { grad_fn = nullptr; }
 };
 
 apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
@@ -398,7 +400,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
             maker.finalize();
             return ret;
         } catch (GradRuleFallback&) {
-            grad_fn_holder.emplace();
+            grad_fn_holder.reset();
         }
     }
     return backward_graph_grad_rule(ctx, grad_fn_holder);
diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp
index 9691290d..64217d44 100644
--- a/imperative/python/src/grad_override.cpp
+++ b/imperative/python/src/grad_override.cpp
@@ -177,11 +177,27 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker)
     throw GradRuleFallback();
 }
 
-template<typename T, typename U>
-apply_result_t axisAddRemove_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
-    auto&& op = ctx.op->cast_final_safe<T>();
+apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+    auto&& op = ctx.op->cast_final_safe<AddAxis>();
     mgb_assert(ctx.nargs == 1);
-    auto&& grad_op = U::make(op.axis);
+    auto&& grad_op = RemoveAxis::make(op.axis);
+    std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
+        mgb_assert(ngrads == 1);
+        Tensor* grad = grads[0];
+        apply_result_t ret(1);
+        ret[0] = python::apply(grad_op_, grad)[0];
+        return ret;
+    });
+    return apply(ctx);
+}
+
+apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+    auto&& op = ctx.op->cast_final_safe<RemoveAxis>();
+    mgb_assert(ctx.nargs == 1);
+    auto&& grad_op = AddAxis::make(op.axis);
+    std::sort(grad_op->axis.begin(), grad_op->axis.end());
     maker.output_size(1).output_captured(0, false);
     maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
         mgb_assert(ngrads == 1);
@@ -201,8 +217,8 @@ struct Init {
         reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule);
         reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
         reg.emplace(Reduce::typeinfo(), reduce_grad_rule);
-        reg.emplace(AddAxis::typeinfo(), axisAddRemove_grad_rule<AddAxis, RemoveAxis>);
-        reg.emplace(RemoveAxis::typeinfo(), axisAddRemove_grad_rule<RemoveAxis, AddAxis>);
+        reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule);
+        reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule);
     }
 } _;
 
diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py
index 9c1efdf7..2e7fc29b 100644
--- a/imperative/python/test/unit/core/test_autodiff.py
+++ b/imperative/python/test/unit/core/test_autodiff.py
@@ -335,3 +335,25 @@ def test_Reduce_mean():
 
     grad(y, F.ones_like(y))
     np.testing.assert_equal(np.ones((3, 3), dtype=np.float32) / 3, x.grad.numpy())
+
+
+def test_addAxis():
+    x_np = np.random.rand(3, 3).astype("float32")
+    x = mge.Tensor(x_np)
+
+    grad = Grad().wrt(x, callback=save_to(x))
+    y = F.expand_dims(x, [2, 3])
+
+    grad(y, F.ones_like(y))
+    np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy())
+
+
+def test_removeAxis():
+    x_np = np.random.rand(3, 3, 1, 1).astype("float32")
+    x = mge.Tensor(x_np)
+
+    grad = Grad().wrt(x, callback=save_to(x))
+    y = F.squeeze(x, [2, 3])
+
+    grad(y, F.ones_like(y))
+    np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy())
--
GitLab
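
Note for reviewers (not part of the patch): the two std::sort calls added in grad_override.cpp encode an axis-ordering invariant. Axes are added or removed one at a time, so the gradient of AddAxis must remove axes from largest to smallest (otherwise each removal would shift the remaining target indices), while the gradient of RemoveAxis must re-insert them from smallest to largest. A minimal NumPy sketch of the same invariant, using np.squeeze and np.expand_dims as stand-ins for the RemoveAxis and AddAxis ops; values and shapes mirror the new tests:

import numpy as np

x = np.ones((3, 3, 1, 1), dtype=np.float32)

# Removing axes [2, 3] one at a time: descending order keeps the
# remaining indices valid (removing axis 2 first would shift the old
# axis 3 down to position 2, and np.squeeze(x, 3) would then fail).
for axis in sorted([2, 3], reverse=True):
    x = np.squeeze(x, axis)
assert x.shape == (3, 3)

# Re-inserting is the mirror case: ascending order, because each
# insertion shifts all later positions up by one.
for axis in sorted([2, 3]):
    x = np.expand_dims(x, axis)
assert x.shape == (3, 3, 1, 1)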