diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py index 742ee57166fadd0e2d450465a5d7a75846e71b08..a3e8c029628e9b9d064a1d8bd419f2eb973f0fb9 100644 --- a/mindspore/ops/composite/multitype_ops/setitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py @@ -247,10 +247,10 @@ def _tensor_assgin_tensor(data, input_slice, value): indices = mult_util.slice2indices(input_slice, data_shape) indices_size = F.size(indices) indices_size = mult_util.check_indices(indices_size, input_slice) - update = F.fill(data_dtype, (indices_size,), 1) + update = F.fill(mstype.int32, (indices_size,), 1) condition_1d = F.scatter_nd(indices, update, (data_size,)) - condition_1d = F.cast(condition_1d, mstype.bool_) condition = F.reshape(condition_1d, data_shape) + condition = F.cast(condition, mstype.bool_) # 2. u value_fill = None value_size = F.size(value) @@ -325,10 +325,10 @@ def _tensor_assgin_number(data, input_slice, value): indices = mult_util.slice2indices(input_slice, data_shape) indices_size = F.size(indices) indices_size = mult_util.check_indices(indices_size, input_slice) - update = F.fill(data_dtype, (indices_size,), 1) + update = F.fill(mstype.int32, (indices_size,), 1) condition_1d = F.scatter_nd(indices, update, (data_size,)) - condition_1d = F.cast(condition_1d, mstype.bool_) condition = F.reshape(condition_1d, data_shape) + condition = F.cast(condition, mstype.bool_) # 2. u value_fill = F.fill(data_dtype, (indices_size,), value) value_1d = F.scatter_nd(indices, value_fill, (data_size,))