提交 79058d35 编写于 作者: H huangdongrun

add support for parameter

support for tensor setitem

add support for tensor assgin
上级 c55b81e9
......@@ -92,6 +92,10 @@ Tensor &Tensor::operator=(const Tensor &tensor) {
}
return *this;
}
Tensor &Tensor::AssignValue(const Tensor &tensor) {
*this = tensor;
return *this;
}
bool Tensor::operator==(const Tensor &tensor) const {
return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
......@@ -470,6 +474,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.set_dtype(mindspore.int32)
mindspore.int32
)mydelimiter")
.def("assign_value", &Tensor::AssignValue, R"mydelimiter(
Assign another tensor value to this.
Arg:
value (:class:`mindspore.tensor`): The value tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
>>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32))
>>> data.assign_value(data2)
>>> data.shape
(2, 2)
)mydelimiter")
.def("__str__", &Tensor::ToString)
.def("__repr__", &Tensor::ToStringRepr)
.def(py::pickle(
......
......@@ -173,6 +173,9 @@ class Tensor : public MetaTensor {
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
bool ValueEqual(const Tensor &other) const;
// assgin value to this tensor
Tensor &AssignValue(const Tensor &tensor);
bool operator==(const Value &other) const override {
if (other.isa<Tensor>()) {
auto other_ = static_cast<const Tensor &>(other);
......
......@@ -203,6 +203,8 @@ class Parameter:
return self.default_input / other
def __setitem__(self, index, value):
default_input = self.default_input
default_input[index] = value
return self
def set_parameter_data(self, data):
......
......@@ -150,6 +150,8 @@ class Tensor(Tensor_):
return out
def __setitem__(self, index, value):
out = tensor_operator_registry.get('__setitem__')(self, index, value)
self.assign_value(out)
return self
def __gt__(self, other):
......
......@@ -16,10 +16,8 @@
"""Implementation for setitem."""
from . import _compile_utils as compile_utils
from . import _constexpr_utils as const_utils
from ... import functional as F
from ...composite import base
from ....common import dtype as mstype
setitem = base.MultitypeFuncGraph('setitem')
......@@ -139,11 +137,7 @@ def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == const_utils.INT_:
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
return compile_utils.tensor_setitem_by_tensor_with_tensor(data, index, value_tensor)
@setitem.register("Tensor", "Tensor", "Number")
......@@ -166,11 +160,7 @@ def _tensor_setitem_by_tensor_with_number(data, index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == const_utils.BOOL_:
return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value)
return _tensor_setitem_by_int_tensor_with_scalar(data, index, value)
return compile_utils.tensor_setitem_by_tensor_with_number(data, index, value)
@setitem.register("Tensor", "Tuple", "Number")
......@@ -191,24 +181,7 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return _tensor_assgin_number(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = compile_utils.generate_updates_from_scalar(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return F.scatter_nd_update(data, indices, updates)
return compile_utils.tensor_setitem_by_tuple_with_number(data, tuple_index, value)
@setitem.register("Tensor", "Tuple", "Tensor")
......@@ -229,24 +202,7 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return _tensor_assgin_tensor(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = compile_utils.generate_updates_from_tensor(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return F.scatter_nd_update(data, indices, updates)
return compile_utils.tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
@setitem.register("Tensor", "Tuple", "Tuple")
......@@ -268,22 +224,7 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.ALL_TENSOR:
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = compile_utils.generate_updates_from_tuple(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return F.scatter_nd_update(data, indices, updates)
return compile_utils.tensor_setitem_by_tuple_with_tuple(data, tuple_index, value)
@setitem.register("Tensor", "Tensor", "Tuple")
......@@ -299,12 +240,7 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM)
result = None
if check_dtype:
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
return result
return compile_utils.tensor_setitem_by_tensor_with_tuple(data, index, value)
@setitem.register("Tensor", "Slice", "Tensor")
......@@ -326,7 +262,7 @@ def _tensor_setitem_with_slice_v3(data, input_slice, value):
Outputs:
Tensor, element type and shape is same as data.
"""
return _tensor_assgin_tensor(data, input_slice, value)
return compile_utils.tensor_setitem_by_slice_with_tensor(data, input_slice, value)
@setitem.register("Tensor", "Slice", "Number")
......@@ -348,168 +284,28 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value):
Outputs:
Tensor, element type and shape is same as data.
"""
return _tensor_assgin_number(data, input_slice, value)
def _tensor_assgin_number(data, input_slice, value):
"""Givens a scalar assign to tensor by slice"""
check_result = const_utils.check_tensor_setitem_index(input_slice)
result = None
if check_result:
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
return result
return compile_utils.tensor_setitem_by_slice_with_number(data, input_slice, value)
@setitem.register("Tensor", "Number", "Number")
def _tensor_setitem_with_int_v1(data, index, value):
"""Syntax: A[1] = 3"""
data_shape = F.shape(data)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_number(data, data_shape, index, indices, value)
return compile_utils.tensor_setitem_by_number_with_number(data, index, value)
@setitem.register("Tensor", "Number", "Tensor")
def _tensor_setitem_with_int_v2(data, index, value):
"""Syntax: A[1] = Tensor"""
data_shape = F.shape(data)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value)
return compile_utils.tensor_setitem_by_number_with_tensor(data, index, value)
@setitem.register("Tensor", "Ellipsis", "Number")
def _tensor_setitem_with_ellipsis_v1(data, index, value):
"""Syntax: A[...] = number."""
data_shape = F.shape(data)
data_dtype = F.dtype(data)
return F.fill(data_dtype, data_shape, value)
return compile_utils.tensor_setitem_by_ellipsis_with_number(data, index, value)
@setitem.register("Tensor", "Ellipsis", "Tensor")
def _tensor_setitem_with_ellipsis_v2(data, index, value):
"""Syntax: A[...] = Tensor."""
result = None
data_shape = F.shape(data)
data_dtype = F.dtype(data)
data_size = F.size(data)
value_shape = F.shape(value)
value_size = F.size(value)
check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
if check_result:
if data_size == value_size:
result = F.reshape(value, data_shape)
result = F.cast(result, data_dtype)
elif value_size == 1:
param1 = F.fill(data_dtype, data_shape, 1)
param2 = F.cast(value, data_dtype)
result = F.tensor_mul(param1, param2)
return result
def _tensor_assgin_tensor(data, input_slice, value):
"""Assigns a tensor value to the tensor by slice."""
result = None
check_result = const_utils.check_tensor_setitem_index(input_slice)
if check_result:
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
return result
def _tensor_indices_tensor(data, data_shape, index, indices, value):
"""Assigns a tensor value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = None
value_size = F.size(value)
value_size = const_utils.check_indices_value_size(indices_size, value_size)
if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype)
value_fill = F.tensor_mul(value_fill, value)
elif value_size > 1:
value_fill = F.reshape(value, (indices_size,))
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
def _tensor_indices_number(data, data_shape, index, indices, value):
"""Assigns a scalar value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
"""Set a tensor item by a tensor with a tuple."""
updates = compile_utils.generate_updates_from_tuple(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
result = F.scatter_update(data, index, updates)
return result
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
"""Set a tensor item by a int tensor with a scalar."""
updates = compile_utils.generate_updates_from_scalar(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates)
def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
"""Set a tensor item by a bool tensor with a scalar."""
index_shape = F.shape(index)
shape = F.shape(data)
shape = const_utils.check_equal(
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
dtype = F.dtype(data)
u = F.fill(dtype, shape, value)
return F.select(index, u, data)
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
"""Set a tensor item by a int tensor with a tensor."""
updates = compile_utils.generate_updates_from_tensor(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates)
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
"""Set a tensor item by a bool tensor with a tensor."""
index_shape = F.shape(index)
data_shape = F.shape(data)
data_shape = const_utils.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
size = F.size(value)
size = const_utils.check_equal(1, size,
"When assign value is a tensor, its size should be {}, but current size is {}.")
dtype = F.dtype(data)
u_cast = F.cast(value, dtype)
one_data = F.ones_like(data)
u = F.tensor_mul(one_data, u_cast)
result = F.select(index, u, data)
return result
return compile_utils.tensor_setitem_by_ellipsis_with_tensor(data, index, value)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册