提交 663d5973 编写于 作者: C candanzg

tensor assign with slice index

Signed-off-by: Ncandanzg <zhangshucheng@huawei.com>
上级 9edc69af
......@@ -18,7 +18,7 @@ Interfaces for parser module in c++.
from .parser import (Parser, create_obj_instance, generate_scope,
get_bprop_method_of_class, get_class_instance_type,
get_class_member_namespace_symbol,
get_class_member_namespace_symbol, create_slice_obj,
get_dataclass_attributes, get_dataclass_methods,
get_module_namespace, get_obj_type, get_object_key,
get_parse_method_of_class, get_scope_name,
......@@ -29,4 +29,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type',
'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol',
'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj',
'get_dataclass_methods', 'get_scope_name']
'get_dataclass_methods', 'get_scope_name', 'create_slice_obj']
......@@ -29,6 +29,7 @@ from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.api import _MindSporeFunction
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
from ..utils import Slice
# define return value
RET_SUCCESS = 0
......@@ -69,6 +70,10 @@ parse_expr_statement_white_list = (
"append",
)
def create_slice_obj(start, end, step):
"""Create Slice object"""
return Slice(start, end, step)
def parse_cb(func, parse_method=None):
"""Implements the function of parse."""
......
......@@ -19,6 +19,7 @@ import logging
import os
import inspect
from functools import wraps
from dataclasses import dataclass
def cal_sha256(file_path):
......@@ -99,3 +100,13 @@ def cell_attr_register(fn=None, attrs=None):
if fn is not None:
return wrap_cell(fn)
return wrap_cell
@dataclass
class Slice:
"""
Slice class
"""
start: int
end: int
step: int
......@@ -123,6 +123,9 @@ class ValueSlice : public Value {
abstract::AbstractBasePtr ToAbstract() override;
std::string DumpText() const override { return ToString(); }
ValuePtr start() const { return start_; }
ValuePtr stop() const { return stop_; }
ValuePtr step() const { return step_; }
private:
ValuePtr start_;
......
......@@ -79,6 +79,8 @@ const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement";
const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope";
const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";
const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
// define the common name
const char NAMED_PRIMITIVE_ITER[] = "iter";
const char NAMED_PRIMITIVE_NEXT[] = "next";
......
......@@ -289,6 +289,13 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
dic["shape"] = shape;
dic["dtype"] = abs_base->BuildType();
dic["value"] = BuildValue(abs_base->BuildValue());
} else if (abs_base->isa<AbstractSlice>()) {
auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
std::vector<int> shape;
dic["shape"] = shape;
dic["dtype"] = arg_slice->BuildType();
dic["value"] = BuildValue(arg_slice->BuildValue());
} else if (abs_base->isa<AbstractTuple>()) {
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
size_t len = arg_tuple->size();
......
......@@ -28,6 +28,7 @@
#include "ir/meta_tensor.h"
#include "pipeline/parse/parse.h"
#include "pipeline/parse/parse_base.h"
#include "ir/value.h"
namespace mindspore {
......@@ -97,6 +98,13 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
i++;
}
ret = rets;
} else if (value->isa<ValueSlice>()) {
auto slice = value->cast<ValueSlicePtr>();
auto start = ValuePtrToPyData(slice->start());
auto end = ValuePtrToPyData(slice->stop());
auto step = ValuePtrToPyData(slice->step());
ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
step);
} else if (value->isa<Type>()) {
py::tuple v(1);
v[0] = value->cast<TypePtr>();
......
......@@ -15,7 +15,43 @@
"""constexpr util"""
import numpy as np
from ...primitive import constexpr
from ....common.tensor import Tensor
from ....common import dtype as mstype
from ...._extends.utils import Slice
@constexpr
def check_equal(param1, param2, msg="{},{}"):
if param1 != param2:
raise ValueError(msg.format(param1, param2))
return param1
@constexpr
def check_tensor_setitem_index(index, element_type=None):
"""Check tuple index type of tensor assignment."""
if index is None:
raise ValueError("Tensor's index cannot be None.")
# eg. Tensor[Slice] = u
if isinstance(index, Slice):
return True
# eg. Tensor[Tuple] = u
if isinstance(index, tuple):
if not index:
raise ValueError("Tensor's index cannot be empty.")
# eg. Tensor[Tuple(Slice...)] = u
if not isinstance(index[0], Slice):
raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
return True
# eg. Tensor[Tensor[dtype=bool]] = u
if index == mstype.tensor:
if element_type is None or element_type != mstype.bool_:
raise ValueError(
"The index of tensor should be a bool type tensor. \
{} type is not supported yet.".format(element_type))
return True
raise ValueError("Index of type '{}' is not supported yet.".format(type(index)))
@constexpr
......@@ -43,3 +79,84 @@ def error_msg(msg="", format_values=""):
"""
raise ValueError(msg.format(*format_values))
def slice_expand(input_slices, shape):
"""
Convert slice to indices.
Inputs:
slices (List or Tuple(List, ...)): Slice tuple or slice.
shape (Tuple): The shape of a sensor is an integer element tuple.
Outputs:
(List, List, List), This is expressed as (begins, ends, strides).
"""
begin = []
end = []
strides = []
index = 0
slices = None
# Slice or Tuple(Slice...)
if isinstance(input_slices, Slice):
slices = (input_slices,)
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice):
slices = input_slices
else:
raise ValueError("Tensor's index type is not supported yet.")
for s in slices:
start = 0 if (s.start is None) else s.start
stop = shape[index] if (s.end is None) else s.end
step = 1 if (s.step is None) else s.step
begin.append(start)
end.append(stop)
strides.append(step)
index += 1
while index < len(shape):
begin.append(0)
end.append(shape[index])
strides.append(1)
index += 1
return begin, end, strides
@constexpr
def slice2indices(input_slices, shape):
"""
Convert slice to indices.
Inputs:
slices (List or Tuple(List, ...)): Slice tuple or slice.
shape (Tuple): The shape of a sensor is an integer element tuple.
Outputs:
Tensor, the shape is (n, 1).
"""
begin, end, strides = slice_expand(input_slices, shape)
np_r = []
for i, element in enumerate(shape):
s = begin[i] if (begin[i] >= 0) else (element + begin[i])
e = end[i] if (end[i] >= 0) else (element + end[i])
np_r.append(np.r_[s:e:strides[i]])
# Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
np_ix = np.ix_(*np_r)
ravel = np.ravel_multi_index(np_ix, shape)
ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32)
return ravel
@constexpr
def check_indices(indices_size, index):
if indices_size < 1:
raise ValueError("The tensor's index is unreasonable. index:{}".format(index))
return indices_size
@constexpr
def check_indices_value_size(indices_size, value_size):
if value_size < 1:
raise ValueError("The value assigned to tensor cannot be empty.")
if value_size > 1:
if value_size != indices_size:
raise ValueError(
"The value given to tensor does not match the index size. \
value size:{}, indics size:{}".format(value_size, indices_size))
return value_size
......@@ -138,25 +138,23 @@ def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
Outputs:
Tensor, element type and shape is same as data.
"""
result = None
index_dtype = F.dtype(index)
index_shape = F.shape(index)
is_bool = mult_util.is_same_type(index_dtype, mstype.bool_)
if not is_bool:
return mult_util.error_msg(
"The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,))
data_shape = F.shape(data)
if index_shape != data_shape:
return mult_util.error_msg(
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (data_shape, index_shape))
size = F.size(value_tensor)
if size != 1:
return mult_util.error_msg(
"When assign value is a tensor, its size should be 1, but current size is {}.", (size,))
dtype = F.dtype(data)
u_cast = F.cast(value_tensor, dtype)
one_data = F.ones_like(data)
u = F.tensor_mul(one_data, u_cast)
return F.select(index, u, data)
check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype)
if check_result:
data_shape = F.shape(data)
data_shape = mult_util.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
size = F.size(value_tensor)
size = mult_util.check_equal(1, size,
"When assign value is a tensor, its size should be {}, but current size is {}.")
dtype = F.dtype(data)
u_cast = F.cast(value_tensor, dtype)
one_data = F.ones_like(data)
u = F.tensor_mul(one_data, u_cast)
result = F.select(index, u, data)
return result
@setitem.register("Tensor", "Tensor", "Number")
......@@ -179,16 +177,162 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
result = None
index_dtype = F.dtype(index)
index_shape = F.shape(index)
is_bool = mult_util.is_same_type(index_dtype, mstype.bool_)
if not is_bool:
return mult_util.error_msg(
"The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,))
shape = F.shape(data)
if index_shape != shape:
return mult_util.error_msg(
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (shape, index_shape))
dtype = F.dtype(data)
u = F.fill(dtype, shape, value)
return F.select(index, u, data)
check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype)
if check_result:
shape = F.shape(data)
shape = mult_util.check_equal(
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
dtype = F.dtype(data)
u = F.fill(dtype, shape, value)
result = F.select(index, u, data)
return result
@setitem.register("Tensor", "Slice", "Tensor")
def _tensor_setitem_with_slice_v3(data, input_slice, value):
"""
Tensor assignment.
Note:
Syntax support: A[Slice] = U
Restraint condition: A is a Tensor
Slice like "1:3"
U is a Tensor(size=1) or Tensor(size>1)
Inputs:
data (Tensor): Assigned tensor.
input_slice (Slice): Slice expression.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return _tensor_assgin_tensor(data, input_slice, value)
@setitem.register("Tensor", "Tuple", "Tensor")
def _tensor_setitem_with_slice_v4(data, input_slice, value):
"""
Tensor assignment.
Note:
Syntax support: A[Slice] = U
Restraint condition: A is a Tensor
Slice like "1:3, ::, :4:-1"
U is a Tensor(size=1) or Tensor(size>1)
Inputs:
data (Tensor): Assigned tensor.
input_slice (Tuple(Slice)): Slice expression.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return _tensor_assgin_tensor(data, input_slice, value)
def _tensor_assgin_tensor(data, input_slice, value):
"""Given a tensor value assign to tensor by slice"""
# 1. condition
result = None
check_result = mult_util.check_tensor_setitem_index(input_slice)
if check_result:
data_shape = F.shape(data)
data_size = F.size(data)
data_dtype = F.dtype(data)
indices = mult_util.slice2indices(input_slice, data_shape)
indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, input_slice)
update = F.fill(data_dtype, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition_1d = F.cast(condition_1d, mstype.bool_)
condition = F.reshape(condition_1d, data_shape)
# 2. u
value_fill = None
value_size = F.size(value)
value_size = mult_util.check_indices_value_size(indices_size, value_size)
if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype)
value_fill = F.tensor_mul(value_fill, value)
elif value_size > 1:
value_fill = F.reshape(value, (indices_size,))
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
# A[slice]= u -> A[B]=U -> select(B, U, A)
result = F.select(condition, u, data)
return result
@setitem.register("Tensor", "Slice", "Number")
def _tensor_setitem_with_slice_v1(data, input_slice, value):
"""
Tensor assignment.
Note:
Syntax support: A[Slice] = u
Restraint condition: A is a Tensor.
Slice like "1:3"
u is a scalar
Inputs:
data (Tensor): Assigned tensor.
input_slice (Slice): slice expression.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return _tensor_assgin_number(data, input_slice, value)
@setitem.register("Tensor", "Tuple", "Number")
def _tensor_setitem_with_slice_v2(data, input_slice, value):
"""
Tensor assignment.
Note:
Syntax support: A[Slice] = u
Restraint condition: A is a Tensor.
Slice like "1:3, ::, :4:-1"
u is a scalar
Inputs:
data (Tensor): Assigned tensor.
input_slice (Tuple(Slice)): slice expression.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return _tensor_assgin_number(data, input_slice, value)
def _tensor_assgin_number(data, input_slice, value):
"""Given a scalar assign to tensor by slice"""
# 1. condition
check_result = mult_util.check_tensor_setitem_index(input_slice)
result = None
if check_result:
data_shape = F.shape(data)
data_size = F.size(data)
data_dtype = F.dtype(data)
indices = mult_util.slice2indices(input_slice, data_shape)
indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, input_slice)
update = F.fill(data_dtype, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition_1d = F.cast(condition_1d, mstype.bool_)
condition = F.reshape(condition_1d, data_shape)
# 2. u
value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
# A[slice]= u -> A[B]=U -> select(B, U, A)
result = F.select(condition, u, data)
return result
......@@ -68,6 +68,7 @@ tuple_to_array = P.TupleToArray()
scalar_cast = P.ScalarCast()
print_ = P.Print()
expand_dims = P.ExpandDims()
scatter_nd = P.ScatterNd()
tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem')
......
......@@ -94,10 +94,101 @@ class NetWorkReduceToScalar(Cell):
return ret
class TensorAssignWithSliceError1(Cell):
def __init__(self):
super(TensorAssignWithSliceError1, self).__init__()
def construct(self, a, b):
a[1:3:-1,::] = b
return a
class TensorAssignWithSliceError2(Cell):
def __init__(self):
super(TensorAssignWithSliceError2, self).__init__()
def construct(self, a, b):
a[1:3:-1] = b
return a
class TensorAssignWithSlice2(Cell):
def __init__(self):
super(TensorAssignWithSlice2, self).__init__()
def construct(self, a, b):
a[1:5] = b
a[3:4] = 5
a[-1:1:-1] = b
a[-1:3:-1] = 5
a[::] = b
a[::] = 9
return a
class TensorAssignWithSlice(Cell):
def __init__(self):
super(TensorAssignWithSlice, self).__init__()
self.c = 2
def construct(self, a, b):
a[1:3,::] = b
a[2:3:,3:] = b
a[::] = b
a[::] = self.c
a[::,::] = b
a[::,::] = self.c
a[2:3:,0:, 4:1:-1] = b
a[2:3:,0:, 4:1:-1] = self.c
z = a
return z
def test_tensor_assign_with_slice():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
net = TensorAssignWithSlice()
net2= TensorAssignWithSlice2()
net_e1 = TensorAssignWithSliceError1()
net_e2 = TensorAssignWithSliceError2()
a = np.arange(60).reshape(3,4,5)
b = Tensor([1])
Ta = Tensor(a)
Tb= Tensor([1,3])
Tc= Tensor([])
t = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
net(Ta, b)
net2(t, b)
# Error for A[Slice] = Number
# 1. A[Slice] = Number, Slice error
with pytest.raises(ValueError):
net_e2(t, 2)
# Error for A[Slice] = U, U is a Tensor
# 1. A[Slice] = U, u.size is error
with pytest.raises(ValueError):
net2(t, Tb)
# 2. A[Slice] = U, U is empty
with pytest.raises(ValueError):
net2(t, Tc)
# 3. A[Slice] = U, U.size error
with pytest.raises(ValueError):
net2(t, Tb)
# Error for A[Tuple(Slice...)] = Tensor
# 1. A[Tuple(Slice...)] = U, U is empty
with pytest.raises(ValueError):
net(Ta, Tc)
# 2. A[Tuple(Slice...)] = U, U.size error
with pytest.raises(ValueError):
net(Ta, Tb)
# 3. A[Tuple(Slice...)] = U, Slice error
with pytest.raises(ValueError):
net_e1(Ta, b)
# Error for A[Tuple(Slice...)] = Number
# 1. A[Tuple(Slice...)] = Number, Slice error
with pytest.raises(ValueError):
net_e1(Ta, 2)
class TensorAssignWithBoolTensorIndex(Cell):
def __init__(self):
super(TensorAssignWithBoolTensorIndex, self).__init__()
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float64)
def construct(self, a, b, c, u_tensor, _scalar):
a[c] = u_scalar
......@@ -119,6 +210,7 @@ class TensorAssignWithBoolTensorIndex2(Cell):
def __init__(self):
super(TensorAssignWithBoolTensorIndex2, self).__init__()
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float64)
def construct(self, a, u_tensor, _scalar):
a[a > 8] = u_tensor
......@@ -139,7 +231,7 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
return a
a = np.random.uniform(1, 10, [2, 3])
a = np.random.uniform(1,10,[3,4,5])
b = a > 5
c = a < 3
Ta = Tensor(a)
......@@ -148,13 +240,13 @@ Tc = Tensor(c)
Td = Tensor([True, True])
u_tensor = Tensor([1])
u_tensor_error = Tensor([1, 2])
t_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
u_scalar = 5
def test_tensor_assign_bool_index():
net1 = TensorAssignWithBoolTensorIndex()
net2 = TensorAssignWithBoolTensorIndex2()
net1(Ta, Tb, Tc, u_tensor, u_scalar)
net1(Ta, Tb, Tc, u_tensor, u_scalar)
with pytest.raises(ValueError):
net1(Ta, Td, Tc, u_tensor, u_scalar)
......@@ -180,8 +272,15 @@ def test_tensor_assign_bool_index():
with pytest.raises(AttributeError):
net4(Ta, u_scalar)
test_cases = [
('TensorAssignWithSlice', {
'block': TensorAssignWithSlice(),
'desc_inputs': [Ta, u_tensor],
}),
('TensorAssignWithSlice2', {
'block': TensorAssignWithSlice2(),
'desc_inputs': [t_1d, u_tensor],
}),
('TensorAssignWithBoolTensorIndex', {
'block': TensorAssignWithBoolTensorIndex(),
'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册