提交 99f2be70 编写于 作者: F fangzehua

unsortedsegsum grad

上级 3e7ba14e
...@@ -673,6 +673,16 @@ def _GatherDropNegatives(params, ...@@ -673,6 +673,16 @@ def _GatherDropNegatives(params,
return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive) return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
@bprop_getters.register(P.UnsortedSegmentSum)
def get_bprop_unsorted_segment_sum(self):
"""Generate bprop for UnsortedSegmentSum"""
def bprop(x, segment_ids, num_segments, out, dout):
return _GatherDropNegatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments)
return bprop
@bprop_getters.register(P.UnsortedSegmentMin) @bprop_getters.register(P.UnsortedSegmentMin)
def get_bprop_unsorted_segment_min(self): def get_bprop_unsorted_segment_min(self):
"""Generate bprop for UnsortedSegmentMin""" """Generate bprop for UnsortedSegmentMin"""
......
...@@ -1447,14 +1447,12 @@ test_case_nn_ops = [ ...@@ -1447,14 +1447,12 @@ test_case_nn_ops = [
'block': P.UnsortedSegmentSum(), 'block': P.UnsortedSegmentSum(),
'desc_const': [1280], 'desc_const': [1280],
'desc_inputs': [[1280, 1024], Tensor(np.ones(1280).astype(np.int32))], 'desc_inputs': [[1280, 1024], Tensor(np.ones(1280).astype(np.int32))],
'desc_bprop': [[8192, 1024]], 'desc_bprop': [[1280, 1024]]}),
'skip': ['backward']}),
('UnsortedSegmentSum_1', { ('UnsortedSegmentSum_1', {
'block': P.UnsortedSegmentSum(), 'block': P.UnsortedSegmentSum(),
'desc_const': [4], 'desc_const': [4],
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))], 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))],
'desc_bprop': [[4, 1, 3]], 'desc_bprop': [[4, 1, 3]]}),
'skip': ['backward']}),
('UnsortedSegmentMin', { ('UnsortedSegmentMin', {
'block': P.UnsortedSegmentMin(), 'block': P.UnsortedSegmentMin(),
'desc_const': [4], 'desc_const': [4],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册