提交 6fb55381 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2661 add support dtype for scatter_add vm

Merge pull request !2661 from zhaozhenlong/op/scatter-add
...@@ -31,6 +31,8 @@ scatter_add_op_info = TBERegOp("ScatterAdd") \ ...@@ -31,6 +31,8 @@ scatter_add_op_info = TBERegOp("ScatterAdd") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.get_op_info() .get_op_info()
......
...@@ -220,10 +220,10 @@ class ScatterMax(nn.Cell): ...@@ -220,10 +220,10 @@ class ScatterMax(nn.Cell):
class ScatterAdd(nn.Cell): class ScatterAdd(nn.Cell):
"""ScatterAdd net definition""" """ScatterAdd net definition"""
def __init__(self, ref_shape): def __init__(self, ref_shape, dtype=np.float32):
super(ScatterAdd, self).__init__() super(ScatterAdd, self).__init__()
self.scatter_add = P.ScatterAdd() self.scatter_add = P.ScatterAdd()
self.ref = Parameter(Tensor(np.ones(ref_shape, np.float32)), name="ref") self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")
def construct(self, indices, updates): def construct(self, indices, updates):
out = self.scatter_add(self.ref, indices, updates) out = self.scatter_add(self.ref, indices, updates)
...@@ -1677,12 +1677,37 @@ test_case_other_ops = [ ...@@ -1677,12 +1677,37 @@ test_case_other_ops = [
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2.0, 3.0, 4.0], np.float32))), Tensor(np.array([2.0, 3.0, 4.0], np.float32))),
'skip': ['backward']}), 'skip': ['backward']}),
('ScatterAddScalar', {
'block': ScatterAdd((6,)),
'desc_inputs': (Tensor(np.array([2], np.int32)),
Tensor(np.array([2.0], np.float32))),
'skip': ['backward']}),
('ScatterAdd2d', { ('ScatterAdd2d', {
'block': ScatterAdd((3, 4)), 'block': ScatterAdd((3, 4)),
'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)), 'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)),
Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]], Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]],
[[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))), [[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))),
'skip': ['backward']}), 'skip': ['backward']}),
('ScatterAddF16', {
'block': ScatterAdd((6,), np.float16),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2.0, 3.0, 4.0], np.float16))),
'skip': ['backward']}),
('ScatterAddI8', {
'block': ScatterAdd((6,), np.int8),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.int8))),
'skip': ['backward']}),
('ScatterAddI32', {
'block': ScatterAdd((6,), np.int32),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.int32))),
'skip': ['backward']}),
('ScatterAddU8', {
'block': ScatterAdd((6,), np.uint8),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.uint8))),
'skip': ['backward']}),
('SmoothL1Loss', { ('SmoothL1Loss', {
'block': P.SmoothL1Loss(), 'block': P.SmoothL1Loss(),
'desc_inputs': [[256, 4], [256, 4]], 'desc_inputs': [[256, 4], [256, 4]],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册