提交 437bb8c2 编写于 作者: B buxue

support ellipsis and bool for tensor slice

上级 53b3d187
......@@ -495,6 +495,8 @@ TypePtr StringToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name.compare("None") == 0) {
type = std::make_shared<TypeNone>();
} else if (type_name.compare("Ellipsis") == 0) {
type = std::make_shared<Ellipsis>();
} else if (type_name.compare("TypeType") == 0) {
type = std::make_shared<TypeType>();
} else if (type_name.compare("SymbolicKeyType") == 0) {
......
......@@ -18,6 +18,5 @@
namespace mindspore {
const TypePtr kTypeNone = std::make_shared<TypeNone>();
const TypePtr kTypeAnything = std::make_shared<TypeAnything>();
const TypePtr kAnyType = std::make_shared<TypeAnything>();
} // namespace mindspore
......@@ -71,8 +71,20 @@ class TypeNull : public Type {
};
using TypeNullPtr = std::shared_ptr<TypeNull>;
class Ellipsis : public Type {
public:
Ellipsis() : Type(kMetaTypeEllipsis) {}
~Ellipsis() override {}
MS_DECLARE_PARENT(Ellipsis, Type)
TypeId generic_type_id() const override { return kMetaTypeEllipsis; }
TypePtr DeepCopy() const override { return std::make_shared<Ellipsis>(); }
std::string ToReprString() const override { return "Ellipsis"; }
std::string DumpText() const override { return "Ellipsis"; }
};
using EllipsisPtr = std::shared_ptr<Ellipsis>;
extern const TypePtr kTypeNone;
extern const TypePtr kTypeAnything;
extern const TypePtr kAnyType;
} // namespace mindspore
......
......@@ -49,6 +49,7 @@ enum TypeId : int {
kMetaTypeExternal,
kMetaTypeNone,
kMetaTypeNull,
kMetaTypeEllipsis,
kMetaTypeEnd,
//
// Object types
......
......@@ -31,5 +31,8 @@ abstract::AbstractBasePtr None::ToAbstract() { return std::make_shared<abstract:
const NamedPtr kNone = std::make_shared<None>();
abstract::AbstractBasePtr NullObj::ToAbstract() { return std::make_shared<abstract::AbstractNull>(); }
const NamedPtr kNullObj = std::make_shared<NullObj>();
const NamedPtr kNull = std::make_shared<NullObj>();
abstract::AbstractBasePtr EllipsisObj::ToAbstract() { return std::make_shared<abstract::AbstractEllipsis>(); }
const NamedPtr kEllipsis = std::make_shared<EllipsisObj>();
} // namespace mindspore
......@@ -61,7 +61,6 @@ class Named : public Value {
std::string name_;
std::size_t hash_id_;
};
using NamedPtr = std::shared_ptr<Named>;
class None : public Named {
......@@ -71,7 +70,6 @@ class None : public Named {
MS_DECLARE_PARENT(None, Named);
abstract::AbstractBasePtr ToAbstract() override;
};
extern const NamedPtr kNone;
class NullObj : public Named {
......@@ -81,7 +79,15 @@ class NullObj : public Named {
MS_DECLARE_PARENT(NullObj, Named);
abstract::AbstractBasePtr ToAbstract() override;
};
extern const NamedPtr kNull;
extern const NamedPtr kNullObj;
class EllipsisObj : public Named {
public:
EllipsisObj() : Named("Ellipsis") {}
~EllipsisObj() override = default;
MS_DECLARE_PARENT(EllipsisObj, Named);
abstract::AbstractBasePtr ToAbstract() override;
};
extern const NamedPtr kEllipsis;
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_NAMED_H_
......@@ -135,9 +135,9 @@ T InnerScalarMod(T x, T y) {
if (std::is_integral<T>::value) {
return static_cast<int>(x) % static_cast<int>(y);
}
float x_int = std::floor(x);
float y_int = std::ceil(y);
float max = x_int / y_int;
int x_int = std::floor(x);
int y_int = std::ceil(y);
int max = x_int / y_int;
float ret = x - y * max;
return ret;
}
......
......@@ -46,6 +46,8 @@ using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractClass;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractEllipsis;
using mindspore::abstract::AbstractEllipsisPtr;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractFunctionPtr;
using mindspore::abstract::AbstractList;
......@@ -1081,6 +1083,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
std::vector<unsigned int> shrink;
auto slice_tuple_eles = slice_tuple->elements();
size_t ellipsis_num = 0;
for (size_t index = 0; index < slice_tuple_size; index++) {
if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
......@@ -1098,7 +1101,20 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
continue;
}
MS_LOG(EXCEPTION) << "Slice tuple only could contain slice or int number, but got "
if (slice_tuple_eles[index]->isa<AbstractEllipsis>()) {
ellipsis_num++;
if (ellipsis_num > 1) {
MS_LOG(EXCEPTION) << "Tensor slice supports at most one ellipsis";
}
size_t ellipsis_len = shape_size - (slice_tuple_size - 1);
begin->insert(begin->end(), ellipsis_len, 0);
end->insert(end->end(), shape.begin() + index, shape.begin() + index + ellipsis_len);
strides->insert(strides->end(), ellipsis_len, 1);
shrink.insert(shrink.end(), ellipsis_len, 0);
continue;
}
MS_LOG(EXCEPTION) << "Slice tuple only could contain slice, int number or ellipsis, but got "
<< slice_tuple_eles[index]->ToString();
}
......@@ -1160,6 +1176,11 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
abstract::CheckArgsSize(op_name, args_spec_list, 2);
AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr tensor_node = ret_graph->add_parameter();
(void)ret_graph->add_parameter();
auto shape = tensorPtr->shape()->shape();
std::vector<int> begin;
std::vector<int> end;
......@@ -1174,23 +1195,28 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides);
} else if (args_spec_list[1]->isa<AbstractScalar>()) {
AbstractScalarPtr scalar_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]);
if (scalar_ptr->BuildValue()->isa<BoolImm>()) {
if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
return ExpandADim(ret_graph, tensor_node);
}
}
shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
} else if (args_spec_list[1]->isa<AbstractEllipsis>()) {
ret_graph->set_output(tensor_node);
return ret_graph;
} else if (args_spec_list[1]->isa<AbstractNone>()) {
return ExpandADim(ret_graph, tensor_node);
} else {
std::ostringstream args_info;
for (const auto &arg : args_spec_list) {
MS_EXCEPTION_IF_NULL(arg);
args_info << arg->ToString() << "\n";
}
MS_LOG(EXCEPTION) << "TensorSlice requires to input a tensor and a slice or slice tuple, but got "
<< args_info.str();
MS_LOG(EXCEPTION)
<< "TensorSlice requires the input should be one of [slice, ellipsis, int number, bool, none, tuple] , but got "
<< args_info.str();
}
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr tensor_node = ret_graph->add_parameter();
(void)ret_graph->add_parameter();
auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations");
auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0),
NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)});
......@@ -1199,6 +1225,12 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
return ret_graph;
}
FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const {
auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
return ret_graph;
}
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
(void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
.def(py::init<std::string &>());
......
......@@ -206,6 +206,8 @@ class TensorSlice : public MetaFuncGraph {
MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const;
};
using TensorSlicePtr = std::shared_ptr<TensorSlice>;
......
......@@ -109,6 +109,7 @@ void Parser::BuildMethodMap() {
expr_method_map_["Index"] = &Parser::ParseIndex;
expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp;
expr_method_map_["Dict"] = &Parser::ParseDict;
expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis;
}
void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); }
......@@ -187,7 +188,7 @@ void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block,
namelist_for_default_value.push_back(arg_name);
if (py::isinstance<py::none>(defaults[i])) {
default_values.push_back(NewValueNode(kNullObj));
default_values.push_back(NewValueNode(kNull));
} else {
default_values.push_back(ParseExprNode(block, defaults[i]));
}
......@@ -437,6 +438,11 @@ AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) {
return NewValueNode(kNone);
}
AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) {
MS_LOG(DEBUG) << "Process ast Ellipsis";
return NewValueNode(kEllipsis);
}
AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Num";
py::object obj = python_adapter::GetPyObjAttr(node, "n");
......
......@@ -92,6 +92,8 @@ class Parser {
AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node);
// process NoneType
AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node);
// process Ellipsis
AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node);
// process a integer or float number
AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node);
// process a string variable
......
......@@ -892,10 +892,27 @@ bool AbstractNull::operator==(const AbstractBase &other) const {
std::string AbstractNull::ToString() const {
std::ostringstream buffer;
buffer << type_name() << "("
<< "Value: "
<< "Null"
<< ")";
buffer << type_name() << "(Value: Null)";
return buffer.str();
}
bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; }
bool AbstractEllipsis::operator==(const AbstractBase &other) const {
if (&other == this) {
return true;
}
if (other.isa<AbstractEllipsis>()) {
auto other_none = static_cast<const AbstractEllipsis *>(&other);
return *this == *other_none;
} else {
return false;
}
}
std::string AbstractEllipsis::ToString() const {
std::ostringstream buffer;
buffer << type_name() << "(Value: Ellipsis)";
return buffer.str();
}
......
......@@ -498,7 +498,7 @@ using AbstractNonePtr = std::shared_ptr<AbstractNone>;
// the un assigned state value for variable, which means the variable is not assigned
class AbstractNull : public AbstractBase {
public:
AbstractNull() : AbstractBase(kNullObj) { set_type(std::make_shared<TypeNull>()); }
AbstractNull() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); }
~AbstractNull() override = default;
MS_DECLARE_PARENT(AbstractNull, AbstractBase)
......@@ -510,6 +510,20 @@ class AbstractNull : public AbstractBase {
};
using AbstractNullPtr = std::shared_ptr<AbstractNull>;
class AbstractEllipsis : public AbstractBase {
public:
AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared<Ellipsis>()); }
~AbstractEllipsis() override = default;
MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase)
TypePtr BuildType() const override { return std::make_shared<Ellipsis>(); }
bool operator==(const AbstractEllipsis &other) const;
bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override { return std::make_shared<AbstractEllipsis>(); }
std::string ToString() const override;
};
using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>;
class AbstractRefKey : public AbstractBase {
public:
AbstractRefKey() : AbstractBase() { set_type(std::make_shared<RefKeyType>()); }
......
......@@ -150,7 +150,7 @@ def _tensor_getitem_by_number(data, number_index):
@getitem.register("Tensor", "Slice")
def _tensor_getitem_by_slice(data, slice_index):
"""
Getting item of tensor by slice index.
Getting item of tensor by slice.
Inputs:
data (Tensor): A tensor.
......@@ -165,7 +165,7 @@ def _tensor_getitem_by_slice(data, slice_index):
@getitem.register("Tensor", "Tuple")
def _tensor_getitem_by_slice_tuple(data, slice_tuple_index):
"""
Getting item of tensor by slice tuple index.
Getting item of tensor by slice tuple.
Inputs:
data (Tensor): A tensor.
......@@ -175,3 +175,18 @@ def _tensor_getitem_by_slice_tuple(data, slice_tuple_index):
Tensor, element type is same as the element type of data.
"""
return _tensor_slice(data, slice_tuple_index)
@getitem.register("Tensor", "Ellipsis")
def _tensor_getitem_by_ellipsis(data, ellipsis_index):
"""
Getting item of tensor by Ellipsis.
Inputs:
data (Tensor): A tensor.
ellipsis (Ellipsis): A Ellipsis object.
Outputs:
Tensor, same as data.
"""
return _tensor_slice(data, ellipsis_index)
......@@ -67,6 +67,7 @@ scalar_to_tensor = P.ScalarToTensor()
tuple_to_array = P.TupleToArray()
scalar_cast = P.ScalarCast()
print_ = P.Print()
expand_dims = P.ExpandDims()
tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem')
......
......@@ -42,6 +42,20 @@ class NetWorkSlicePositive(Cell):
return ret0, ret1, ret2, ret3
class NetWorkSliceEllipsis(Cell):
def __init__(self):
super(NetWorkSliceEllipsis, self).__init__()
self.tensor_ret0 = Tensor(np.ones([2, 7, 8], np.int32))
self.tensor_ret1 = Tensor(np.ones([6, 7, 8, 9], np.int32))
self.tensor_ret2 = Tensor(np.ones([1, 6, 7, 8, 9], np.int32))
def construct(self, tensor):
ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0
ret1 = tensor[...] + self.tensor_ret1
ret2 = tensor[True] + self.tensor_ret2
return ret0, ret1, ret2
class NetWorkReduceDimension(Cell):
def __init__(self):
super(NetWorkReduceDimension, self).__init__()
......@@ -83,7 +97,7 @@ class NetWorkReduceToScalar(Cell):
class TensorAssignWithBoolTensorIndex(Cell):
def __init__(self):
super(TensorAssignWithBoolTensorIndex, self).__init__()
self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64)
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
def construct(self, a, b, c, u_tensor, _scalar):
a[c] = u_scalar
......@@ -104,14 +118,14 @@ class TensorAssignWithBoolTensorIndexError(Cell):
class TensorAssignWithBoolTensorIndex2(Cell):
def __init__(self):
super(TensorAssignWithBoolTensorIndex2, self).__init__()
self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64)
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
def construct(self, a, u_tensor, _scalar):
a[a>8] = u_tensor
a[a>=6] = u_scalar
a[a<3] = u_scalar
a[a<=5] = u_tensor
a[a==5] = u_scalar
a[a > 8] = u_tensor
a[a >= 6] = u_scalar
a[a < 3] = u_scalar
a[a <= 5] = u_tensor
a[a == 5] = u_scalar
z = a + self.t
return z
......@@ -121,11 +135,11 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
super(TensorAssignWithBoolTensorIndex2Error, self).__init__()
def construct(self, a, u_tensor):
a[a>8][a>5] = u_tensor
a[a > 8][a > 5] = u_tensor
return a
a = np.random.uniform(1,10,[2,3])
a = np.random.uniform(1, 10, [2, 3])
b = a > 5
c = a < 3
Ta = Tensor(a)
......@@ -152,7 +166,7 @@ def test_tensor_assign_bool_index():
net1(Ta, Tb, Ta, u_tensor, u_scalar)
with pytest.raises(ValueError):
net1(Ta, Tb, Tc, u_tensor_error, u_scalar)
#net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
# net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
with pytest.raises(ValueError):
net2(Ta, u_tensor_error, u_scalar)
net3 = TensorAssignWithBoolTensorIndexError()
......@@ -192,7 +206,10 @@ test_cases = [
'block': NetWorkReduceToScalar(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
}),
('NetWorkSliceEllipsis', {
'block': NetWorkSliceEllipsis(),
'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
}),
]
......
......@@ -162,14 +162,15 @@ def test_ops():
if self.int > self.float:
if [1, 2, 3] != None:
if self.str_a + self.str_b == "helloworld":
print("hello world")
return ret
if q == 86:
print("hello world")
return ret
return x
net = OpsNet(9, 2)
x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32))
y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32))
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
context.set_context(mode=context.GRAPH_MODE)
net(x, y)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册