提交 1edaf6de 编写于 作者: Y yanzhenxiang2020

add bprop for ScatterMax

上级 60de9089
......@@ -496,6 +496,17 @@ def get_bprop_tensor_scatter_update(self):
return bprop
@bprop_getters.register(P.ScatterMax)
def get_bprop_scatter_max(self):
"""Generate bprop for ScatterMax"""
gather = P.GatherV2()
def bprop(x, indices, update, out, dout):
return dout, zeros_like(indices), gather(dout, indices, 0)
return bprop
@bprop_getters.register(P.Argmax)
def get_bprop_argmax(self):
"""Generate bprop for Argmax"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册