From 99f2be7064bdd3a2feb156c1dab2c66fdb313dd3 Mon Sep 17 00:00:00 2001 From: fangzehua Date: Mon, 3 Aug 2020 15:53:01 +0800 Subject: [PATCH] unsortedsegsum grad --- mindspore/ops/_grad/grad_array_ops.py | 10 ++++++++++ tests/ut/python/ops/test_ops.py | 6 ++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index d8d3328b9..319ae151d 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -673,6 +673,16 @@ def _GatherDropNegatives(params, return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive) +@bprop_getters.register(P.UnsortedSegmentSum) +def get_bprop_unsorted_segment_sum(self): + """Generate bprop for UnsortedSegmentSum""" + + def bprop(x, segment_ids, num_segments, out, dout): + return _GatherDropNegatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments) + + return bprop + + @bprop_getters.register(P.UnsortedSegmentMin) def get_bprop_unsorted_segment_min(self): """Generate bprop for UnsortedSegmentMin""" diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index f22366e13..55f4d6c3c 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -1447,14 +1447,12 @@ test_case_nn_ops = [ 'block': P.UnsortedSegmentSum(), 'desc_const': [1280], 'desc_inputs': [[1280, 1024], Tensor(np.ones(1280).astype(np.int32))], - 'desc_bprop': [[8192, 1024]], - 'skip': ['backward']}), + 'desc_bprop': [[1280, 1024]]}), ('UnsortedSegmentSum_1', { 'block': P.UnsortedSegmentSum(), 'desc_const': [4], 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))], - 'desc_bprop': [[4, 1, 3]], - 'skip': ['backward']}), + 'desc_bprop': [[4, 1, 3]]}), ('UnsortedSegmentMin', { 'block': P.UnsortedSegmentMin(), 'desc_const': [4], -- GitLab