提交 6b9ac894 编写于 作者: M Megvii Engine Team

fix(mgb/topk): fix topk grad

GitOrigin-RevId: a49154ff06aa8f0ede19e8841d359dd2ffab17ab
上级 6856ce9c
......@@ -415,6 +415,8 @@ void TopK::record_execute_deps(ExecDependencyArray& deps) {
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(TopK) {
// TopK has no gradient on the input k
if (wrt_idx) return nullptr;
if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) {
mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]);
auto add_axis = [](SymbolVar x) {
......
......@@ -416,4 +416,19 @@ TEST(TestOprMisc, TopKSortedIdxOnly) {
MGB_ASSERT_TENSOR_EQ(host_gx, *host_y);
}
TEST(TestOprMisc, TopKGrad) {
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
std::shared_ptr<HostTensorND> host_x = gen({2, 5});
std::shared_ptr<HostTensorND> host_k = gen({1});
host_k->ptr<float>()[0] = 3;
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
k = opr::Host2DeviceCopy::make(*graph, host_k),
ki = opr::TypeCvt::make(k, dtype::Int32{}),
val = opr::TopK::make(x, ki,
opr::TopK::Param::Mode::VALUE_IDX_SORTED)[0],
gk = cg::grad(opr::reduce_sum(val, val.make_scalar(1)), ki, true, false);
EXPECT_TRUE(gk == nullptr);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册