提交 99f2be70 编写于 作者: F fangzehua

unsortedsegsum grad

上级 3e7ba14e
......@@ -673,6 +673,16 @@ def _GatherDropNegatives(params,
return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
@bprop_getters.register(P.UnsortedSegmentSum)
def get_bprop_unsorted_segment_sum(self):
"""Generate bprop for UnsortedSegmentSum"""
def bprop(x, segment_ids, num_segments, out, dout):
return _GatherDropNegatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments)
return bprop
@bprop_getters.register(P.UnsortedSegmentMin)
def get_bprop_unsorted_segment_min(self):
"""Generate bprop for UnsortedSegmentMin"""
......
......@@ -1447,14 +1447,12 @@ test_case_nn_ops = [
'block': P.UnsortedSegmentSum(),
'desc_const': [1280],
'desc_inputs': [[1280, 1024], Tensor(np.ones(1280).astype(np.int32))],
'desc_bprop': [[8192, 1024]],
'skip': ['backward']}),
'desc_bprop': [[1280, 1024]]}),
('UnsortedSegmentSum_1', {
'block': P.UnsortedSegmentSum(),
'desc_const': [4],
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))],
'desc_bprop': [[4, 1, 3]],
'skip': ['backward']}),
'desc_bprop': [[4, 1, 3]]}),
('UnsortedSegmentMin', {
'block': P.UnsortedSegmentMin(),
'desc_const': [4],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册