diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index f6fe6e376ae4e3d2cff9ed385ea76ba498887105..1be108d3a77fa35593344228cab5a02fe87477ed 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -496,6 +496,17 @@ def get_bprop_tensor_scatter_update(self): return bprop +@bprop_getters.register(P.ScatterMax) +def get_bprop_scatter_max(self): + """Generate bprop for ScatterMax""" + gather = P.GatherV2() + + def bprop(x, indices, update, out, dout): + return dout, zeros_like(indices), gather(dout, indices, 0) + + return bprop + + @bprop_getters.register(P.Argmax) def get_bprop_argmax(self): """Generate bprop for Argmax"""