From 6b9ac894d3202ffc0b5ca272622f205a9b917686 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 20 Nov 2020 20:38:50 +0800 Subject: [PATCH] fix(mgb/topk): fix topk grad GitOrigin-RevId: a49154ff06aa8f0ede19e8841d359dd2ffab17ab --- src/opr/impl/misc.cpp | 2 ++ src/opr/test/misc.cpp | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp index c8fc54dd..26bc0a33 100644 --- a/src/opr/impl/misc.cpp +++ b/src/opr/impl/misc.cpp @@ -415,6 +415,8 @@ void TopK::record_execute_deps(ExecDependencyArray& deps) { #if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(TopK) { + // TopK has no gradient on the input k + if (wrt_idx) return nullptr; if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) { mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]); auto add_axis = [](SymbolVar x) { diff --git a/src/opr/test/misc.cpp b/src/opr/test/misc.cpp index 2dbb381a..083cdb16 100644 --- a/src/opr/test/misc.cpp +++ b/src/opr/test/misc.cpp @@ -416,4 +416,19 @@ TEST(TestOprMisc, TopKSortedIdxOnly) { MGB_ASSERT_TENSOR_EQ(host_gx, *host_y); } +TEST(TestOprMisc, TopKGrad) { + HostTensorGenerator<> gen; + auto graph = ComputingGraph::make(); + std::shared_ptr host_x = gen({2, 5}); + std::shared_ptr host_k = gen({1}); + host_k->ptr()[0] = 3; + auto x = opr::Host2DeviceCopy::make(*graph, host_x), + k = opr::Host2DeviceCopy::make(*graph, host_k), + ki = opr::TypeCvt::make(k, dtype::Int32{}), + val = opr::TopK::make(x, ki, + opr::TopK::Param::Mode::VALUE_IDX_SORTED)[0], + gk = cg::grad(opr::reduce_sum(val, val.make_scalar(1)), ki, true, false); + EXPECT_TRUE(gk == nullptr); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} -- GitLab