提交 d55c0f66 编写于 作者: Z zhaozhenlong

complement scatter add support dtype

上级 80dd3321
......@@ -31,6 +31,8 @@ scatter_add_op_info = TBERegOp("ScatterAdd") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.get_op_info()
......
......@@ -220,10 +220,10 @@ class ScatterMax(nn.Cell):
class ScatterAdd(nn.Cell):
"""ScatterAdd net definition"""
def __init__(self, ref_shape):
def __init__(self, ref_shape, dtype=np.float32):
super(ScatterAdd, self).__init__()
self.scatter_add = P.ScatterAdd()
self.ref = Parameter(Tensor(np.ones(ref_shape, np.float32)), name="ref")
self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")
def construct(self, indices, updates):
out = self.scatter_add(self.ref, indices, updates)
......@@ -1677,12 +1677,37 @@ test_case_other_ops = [
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2.0, 3.0, 4.0], np.float32))),
'skip': ['backward']}),
('ScatterAddScalar', {
'block': ScatterAdd((6,)),
'desc_inputs': (Tensor(np.array([2], np.int32)),
Tensor(np.array([2.0], np.float32))),
'skip': ['backward']}),
('ScatterAdd2d', {
'block': ScatterAdd((3, 4)),
'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)),
Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]],
[[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))),
'skip': ['backward']}),
('ScatterAddF16', {
'block': ScatterAdd((6,), np.float16),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2.0, 3.0, 4.0], np.float16))),
'skip': ['backward']}),
('ScatterAddI8', {
'block': ScatterAdd((6,), np.int8),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.int8))),
'skip': ['backward']}),
('ScatterAddI32', {
'block': ScatterAdd((6,), np.int32),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.int32))),
'skip': ['backward']}),
('ScatterAddU8', {
'block': ScatterAdd((6,), np.uint8),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.uint8))),
'skip': ['backward']}),
('SmoothL1Loss', {
'block': P.SmoothL1Loss(),
'desc_inputs': [[256, 4], [256, 4]],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册