From 2f3d185de67548fc6ccf3ae8b8086218e3aea862 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 14 May 2020 15:17:26 +0800
Subject: [PATCH] fix(mgb): register invalid grad for AddUpdate

GitOrigin-RevId: f9bbf570dcdc2aca714f3d9f3b7831c82bc7c3a9
---
 src/opr/impl/basic_arith.cpp        |  5 +++++
 src/opr/test/basic_arith/others.cpp | 23 ++++++++++++++++++++++-
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp
index 489d9689..249e6504 100644
--- a/src/opr/impl/basic_arith.cpp
+++ b/src/opr/impl/basic_arith.cpp
@@ -947,6 +947,11 @@ void AddUpdate::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
 
+MGB_IMPL_OPR_GRAD(AddUpdate) {
+    // actually valid, just not implemented
+    return InvalidGrad::make(opr, wrt_idx);
+}
+
 /* =========================== Reduce =========================== */
 
 class Reduce::KernScheduler {
diff --git a/src/opr/test/basic_arith/others.cpp b/src/opr/test/basic_arith/others.cpp
index 40ec5dc0..aec4238f 100644
--- a/src/opr/test/basic_arith/others.cpp
+++ b/src/opr/test/basic_arith/others.cpp
@@ -372,7 +372,7 @@ TEST(TestOprBasicArith, AddUpdateVolatile) {
     for (size_t i = 0; i < SIZE * 2; i ++) {
         MGB_ASSERT_FLOAT_EQ(expect(i), z[i]);
     }
-    mgb_assert(host_sub.shape().total_nr_elems() == 4 && 
+    mgb_assert(host_sub.shape().total_nr_elems() == 4 &&
                host_sub.layout().is_contiguous());
     for (size_t i = 0; i < 4; ++ i) {
         size_t idx = i * (SIZE >> 1);
@@ -390,6 +390,27 @@
     }
 }
 
+// AddUpdate in gradient path but no gradient flows through it
+TEST(TestOprBasicArith, AddUpdateInGradPath) {
+    auto graph = ComputingGraph::make();
+    HostTensorGenerator<> gen;
+    auto dest = opr::SharedDeviceTensor::make(*graph, *gen({42}));
+    auto host_x = gen({42});
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+    // delta depends on x, but not differentiable wrt x
+    // a invalid grad is registered for AddUpdate to fix this case
+    auto delta = opr::VirtualDep::make({opr::SetGrad::make(x, nullptr), x});
+    auto updated = opr::AddUpdate::make(dest, delta);
+    auto y = opr::reduce_ax_sum(updated + x, 0);
+    auto dx = cg::grad(y, x);
+    HostTensorND host_dx;
+    auto func = graph->compile({make_callback_copy(dx, host_dx)});
+    func->execute();
+    for (size_t i = 0; i < host_dx.shape(0); ++i) {
+        MGB_ASSERT_FLOAT_EQ(host_dx.ptr()[i], 1.f);
+    }
+}
+
 TEST(TestOprBasicArith, MemFwd) {
     constexpr size_t SIZE = 12321;
     HostTensorGenerator<> gen;
-- 
GitLab