Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
24a9f497
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
24a9f497
编写于
4月 22, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 22, 2020
浏览文件
操作
浏览文件
下载
差异文件
!536 support ellipsis and bool for tensor slice
Merge pull request !536 from zhangbuxue/support_elipis_for_tensor_slice
上级
715c0735
437bb8c2
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
170 addition
and
40 deletion
+170
-40
mindspore/ccsrc/ir/dtype.cc
mindspore/ccsrc/ir/dtype.cc
+2
-0
mindspore/ccsrc/ir/dtype/empty.cc
mindspore/ccsrc/ir/dtype/empty.cc
+0
-1
mindspore/ccsrc/ir/dtype/empty.h
mindspore/ccsrc/ir/dtype/empty.h
+13
-1
mindspore/ccsrc/ir/dtype/type.h
mindspore/ccsrc/ir/dtype/type.h
+1
-0
mindspore/ccsrc/ir/named.cc
mindspore/ccsrc/ir/named.cc
+4
-1
mindspore/ccsrc/ir/named.h
mindspore/ccsrc/ir/named.h
+9
-3
mindspore/ccsrc/operator/cc_implementations.cc
mindspore/ccsrc/operator/cc_implementations.cc
+3
-3
mindspore/ccsrc/operator/composite/composite.cc
mindspore/ccsrc/operator/composite/composite.cc
+41
-9
mindspore/ccsrc/operator/composite/composite.h
mindspore/ccsrc/operator/composite/composite.h
+2
-0
mindspore/ccsrc/pipeline/parse/parse.cc
mindspore/ccsrc/pipeline/parse/parse.cc
+7
-1
mindspore/ccsrc/pipeline/parse/parse.h
mindspore/ccsrc/pipeline/parse/parse.h
+2
-0
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
+21
-4
mindspore/ccsrc/pipeline/static_analysis/abstract_value.h
mindspore/ccsrc/pipeline/static_analysis/abstract_value.h
+15
-1
mindspore/ops/composite/multitype_ops/getitem_impl.py
mindspore/ops/composite/multitype_ops/getitem_impl.py
+17
-2
mindspore/ops/functional.py
mindspore/ops/functional.py
+1
-0
tests/ut/python/ops/test_tensor_slice.py
tests/ut/python/ops/test_tensor_slice.py
+28
-11
tests/ut/python/pipeline/parse/test_operator.py
tests/ut/python/pipeline/parse/test_operator.py
+4
-3
未找到文件。
mindspore/ccsrc/ir/dtype.cc
浏览文件 @
24a9f497
...
...
@@ -495,6 +495,8 @@ TypePtr StringToType(const std::string &type_name) {
TypePtr
type
=
nullptr
;
if
(
type_name
.
compare
(
"None"
)
==
0
)
{
type
=
std
::
make_shared
<
TypeNone
>
();
}
else
if
(
type_name
.
compare
(
"Ellipsis"
)
==
0
)
{
type
=
std
::
make_shared
<
Ellipsis
>
();
}
else
if
(
type_name
.
compare
(
"TypeType"
)
==
0
)
{
type
=
std
::
make_shared
<
TypeType
>
();
}
else
if
(
type_name
.
compare
(
"SymbolicKeyType"
)
==
0
)
{
...
...
mindspore/ccsrc/ir/dtype/empty.cc
浏览文件 @
24a9f497
...
...
@@ -18,6 +18,5 @@
namespace
mindspore
{
const
TypePtr
kTypeNone
=
std
::
make_shared
<
TypeNone
>
();
const
TypePtr
kTypeAnything
=
std
::
make_shared
<
TypeAnything
>
();
const
TypePtr
kAnyType
=
std
::
make_shared
<
TypeAnything
>
();
}
// namespace mindspore
mindspore/ccsrc/ir/dtype/empty.h
浏览文件 @
24a9f497
...
...
@@ -71,8 +71,20 @@ class TypeNull : public Type {
};
using
TypeNullPtr
=
std
::
shared_ptr
<
TypeNull
>
;
class
Ellipsis
:
public
Type
{
public:
Ellipsis
()
:
Type
(
kMetaTypeEllipsis
)
{}
~
Ellipsis
()
override
{}
MS_DECLARE_PARENT
(
Ellipsis
,
Type
)
TypeId
generic_type_id
()
const
override
{
return
kMetaTypeEllipsis
;
}
TypePtr
DeepCopy
()
const
override
{
return
std
::
make_shared
<
Ellipsis
>
();
}
std
::
string
ToReprString
()
const
override
{
return
"Ellipsis"
;
}
std
::
string
DumpText
()
const
override
{
return
"Ellipsis"
;
}
};
using
EllipsisPtr
=
std
::
shared_ptr
<
Ellipsis
>
;
extern
const
TypePtr
kTypeNone
;
extern
const
TypePtr
kTypeAnything
;
extern
const
TypePtr
kAnyType
;
}
// namespace mindspore
...
...
mindspore/ccsrc/ir/dtype/type.h
浏览文件 @
24a9f497
...
...
@@ -49,6 +49,7 @@ enum TypeId : int {
kMetaTypeExternal
,
kMetaTypeNone
,
kMetaTypeNull
,
kMetaTypeEllipsis
,
kMetaTypeEnd
,
//
// Object types
...
...
mindspore/ccsrc/ir/named.cc
浏览文件 @
24a9f497
...
...
@@ -31,5 +31,8 @@ abstract::AbstractBasePtr None::ToAbstract() { return std::make_shared<abstract:
const
NamedPtr
kNone
=
std
::
make_shared
<
None
>
();
abstract
::
AbstractBasePtr
NullObj
::
ToAbstract
()
{
return
std
::
make_shared
<
abstract
::
AbstractNull
>
();
}
const
NamedPtr
kNullObj
=
std
::
make_shared
<
NullObj
>
();
const
NamedPtr
kNull
=
std
::
make_shared
<
NullObj
>
();
abstract
::
AbstractBasePtr
EllipsisObj
::
ToAbstract
()
{
return
std
::
make_shared
<
abstract
::
AbstractEllipsis
>
();
}
const
NamedPtr
kEllipsis
=
std
::
make_shared
<
EllipsisObj
>
();
}
// namespace mindspore
mindspore/ccsrc/ir/named.h
浏览文件 @
24a9f497
...
...
@@ -61,7 +61,6 @@ class Named : public Value {
std
::
string
name_
;
std
::
size_t
hash_id_
;
};
using
NamedPtr
=
std
::
shared_ptr
<
Named
>
;
class
None
:
public
Named
{
...
...
@@ -71,7 +70,6 @@ class None : public Named {
MS_DECLARE_PARENT
(
None
,
Named
);
abstract
::
AbstractBasePtr
ToAbstract
()
override
;
};
extern
const
NamedPtr
kNone
;
class
NullObj
:
public
Named
{
...
...
@@ -81,7 +79,15 @@ class NullObj : public Named {
MS_DECLARE_PARENT
(
NullObj
,
Named
);
abstract
::
AbstractBasePtr
ToAbstract
()
override
;
};
extern
const
NamedPtr
kNull
;
extern
const
NamedPtr
kNullObj
;
class
EllipsisObj
:
public
Named
{
public:
EllipsisObj
()
:
Named
(
"Ellipsis"
)
{}
~
EllipsisObj
()
override
=
default
;
MS_DECLARE_PARENT
(
EllipsisObj
,
Named
);
abstract
::
AbstractBasePtr
ToAbstract
()
override
;
};
extern
const
NamedPtr
kEllipsis
;
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_IR_NAMED_H_
mindspore/ccsrc/operator/cc_implementations.cc
浏览文件 @
24a9f497
...
...
@@ -135,9 +135,9 @@ T InnerScalarMod(T x, T y) {
if
(
std
::
is_integral
<
T
>::
value
)
{
return
static_cast
<
int
>
(
x
)
%
static_cast
<
int
>
(
y
);
}
floa
t
x_int
=
std
::
floor
(
x
);
floa
t
y_int
=
std
::
ceil
(
y
);
floa
t
max
=
x_int
/
y_int
;
in
t
x_int
=
std
::
floor
(
x
);
in
t
y_int
=
std
::
ceil
(
y
);
in
t
max
=
x_int
/
y_int
;
float
ret
=
x
-
y
*
max
;
return
ret
;
}
...
...
mindspore/ccsrc/operator/composite/composite.cc
浏览文件 @
24a9f497
...
...
@@ -46,6 +46,8 @@ using mindspore::abstract::AbstractBase;
using
mindspore
::
abstract
::
AbstractClass
;
using
mindspore
::
abstract
::
AbstractDictionary
;
using
mindspore
::
abstract
::
AbstractDictionaryPtr
;
using
mindspore
::
abstract
::
AbstractEllipsis
;
using
mindspore
::
abstract
::
AbstractEllipsisPtr
;
using
mindspore
::
abstract
::
AbstractFunction
;
using
mindspore
::
abstract
::
AbstractFunctionPtr
;
using
mindspore
::
abstract
::
AbstractList
;
...
...
@@ -1081,6 +1083,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
std
::
vector
<
unsigned
int
>
shrink
;
auto
slice_tuple_eles
=
slice_tuple
->
elements
();
size_t
ellipsis_num
=
0
;
for
(
size_t
index
=
0
;
index
<
slice_tuple_size
;
index
++
)
{
if
(
slice_tuple_eles
[
index
]
->
isa
<
AbstractSlice
>
())
{
AbstractSlicePtr
slice
=
dyn_cast
<
AbstractSlice
>
(
slice_tuple_eles
[
index
]);
...
...
@@ -1098,7 +1101,20 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
continue
;
}
MS_LOG
(
EXCEPTION
)
<<
"Slice tuple only could contain slice or int number, but got "
if
(
slice_tuple_eles
[
index
]
->
isa
<
AbstractEllipsis
>
())
{
ellipsis_num
++
;
if
(
ellipsis_num
>
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"Tensor slice supports at most one ellipsis"
;
}
size_t
ellipsis_len
=
shape_size
-
(
slice_tuple_size
-
1
);
begin
->
insert
(
begin
->
end
(),
ellipsis_len
,
0
);
end
->
insert
(
end
->
end
(),
shape
.
begin
()
+
index
,
shape
.
begin
()
+
index
+
ellipsis_len
);
strides
->
insert
(
strides
->
end
(),
ellipsis_len
,
1
);
shrink
.
insert
(
shrink
.
end
(),
ellipsis_len
,
0
);
continue
;
}
MS_LOG
(
EXCEPTION
)
<<
"Slice tuple only could contain slice, int number or ellipsis, but got "
<<
slice_tuple_eles
[
index
]
->
ToString
();
}
...
...
@@ -1160,6 +1176,11 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
abstract
::
CheckArgsSize
(
op_name
,
args_spec_list
,
2
);
AbstractTensorPtr
tensorPtr
=
abstract
::
CheckArg
<
AbstractTensor
>
(
op_name
,
args_spec_list
,
0
);
FuncGraphPtr
ret_graph
=
std
::
make_shared
<
FuncGraph
>
();
ret_graph
->
set_flags
(
FUNC_GRAPH_FLAG_CORE
,
true
);
AnfNodePtr
tensor_node
=
ret_graph
->
add_parameter
();
(
void
)
ret_graph
->
add_parameter
();
auto
shape
=
tensorPtr
->
shape
()
->
shape
();
std
::
vector
<
int
>
begin
;
std
::
vector
<
int
>
end
;
...
...
@@ -1174,23 +1195,28 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
shrink_axis_mask
=
GenerateStridedSliceParametersFromSlice
(
slice_ptr
,
shape
,
&
begin
,
&
end
,
&
strides
);
}
else
if
(
args_spec_list
[
1
]
->
isa
<
AbstractScalar
>
())
{
AbstractScalarPtr
scalar_ptr
=
dyn_cast
<
AbstractScalar
>
(
args_spec_list
[
1
]);
if
(
scalar_ptr
->
BuildValue
()
->
isa
<
BoolImm
>
())
{
if
(
scalar_ptr
->
BuildValue
()
->
cast
<
BoolImmPtr
>
()
->
value
())
{
return
ExpandADim
(
ret_graph
,
tensor_node
);
}
}
shrink_axis_mask
=
GenerateStridedSliceParametersFromNumber
(
scalar_ptr
,
shape
,
&
begin
,
&
end
,
&
strides
);
}
else
if
(
args_spec_list
[
1
]
->
isa
<
AbstractEllipsis
>
())
{
ret_graph
->
set_output
(
tensor_node
);
return
ret_graph
;
}
else
if
(
args_spec_list
[
1
]
->
isa
<
AbstractNone
>
())
{
return
ExpandADim
(
ret_graph
,
tensor_node
);
}
else
{
std
::
ostringstream
args_info
;
for
(
const
auto
&
arg
:
args_spec_list
)
{
MS_EXCEPTION_IF_NULL
(
arg
);
args_info
<<
arg
->
ToString
()
<<
"
\n
"
;
}
MS_LOG
(
EXCEPTION
)
<<
"TensorSlice requires to input a tensor and a slice or slice tuple, but got "
<<
args_info
.
str
();
MS_LOG
(
EXCEPTION
)
<<
"TensorSlice requires the input should be one of [slice, ellipsis, int number, bool, none, tuple] , but got "
<<
args_info
.
str
();
}
FuncGraphPtr
ret_graph
=
std
::
make_shared
<
FuncGraph
>
();
ret_graph
->
set_flags
(
FUNC_GRAPH_FLAG_CORE
,
true
);
AnfNodePtr
tensor_node
=
ret_graph
->
add_parameter
();
(
void
)
ret_graph
->
add_parameter
();
auto
PrimStridedSliceClass
=
prim
::
GetPythonOps
(
"StridedSlice"
,
"mindspore.ops.operations"
);
auto
PrimStridedSlice
=
ret_graph
->
NewCNode
({
NewValueNode
(
PrimStridedSliceClass
),
NewValueNode
(
0
),
NewValueNode
(
0
),
NewValueNode
(
0
),
NewValueNode
(
0
),
NewValueNode
(
shrink_axis_mask
)});
...
...
@@ -1199,6 +1225,12 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
return
ret_graph
;
}
FuncGraphPtr
TensorSlice
::
ExpandADim
(
const
FuncGraphPtr
&
ret_graph
,
const
AnfNodePtr
&
tensor_node
)
const
{
auto
PrimExpandDims
=
GetPythonOps
(
"expand_dims"
,
"mindspore.ops.functional"
);
ret_graph
->
set_output
(
NewCNode
({
NewValueNode
(
PrimExpandDims
),
tensor_node
,
NewValueNode
(
0
)},
ret_graph
));
return
ret_graph
;
}
REGISTER_PYBIND_DEFINE
(
TupleAdd_
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
TupleAdd
,
MetaFuncGraph
,
std
::
shared_ptr
<
TupleAdd
>>
(
*
m
,
"TupleAdd_"
)
.
def
(
py
::
init
<
std
::
string
&>
());
...
...
mindspore/ccsrc/operator/composite/composite.h
浏览文件 @
24a9f497
...
...
@@ -206,6 +206,8 @@ class TensorSlice : public MetaFuncGraph {
MS_DECLARE_PARENT
(
TensorSlice
,
MetaFuncGraph
)
FuncGraphPtr
GenerateFuncGraph
(
const
AbstractBasePtrList
&
args_spec_list
)
override
;
friend
bool
operator
==
(
const
TensorSlice
&
lhs
,
const
TensorSlice
&
rhs
)
{
return
lhs
.
name_
==
rhs
.
name_
;
}
FuncGraphPtr
ExpandADim
(
const
FuncGraphPtr
&
ret_graph
,
const
AnfNodePtr
&
tensor_node
)
const
;
};
using
TensorSlicePtr
=
std
::
shared_ptr
<
TensorSlice
>
;
...
...
mindspore/ccsrc/pipeline/parse/parse.cc
浏览文件 @
24a9f497
...
...
@@ -109,6 +109,7 @@ void Parser::BuildMethodMap() {
expr_method_map_
[
"Index"
]
=
&
Parser
::
ParseIndex
;
expr_method_map_
[
"UnaryOp"
]
=
&
Parser
::
ParseUnaryOp
;
expr_method_map_
[
"Dict"
]
=
&
Parser
::
ParseDict
;
expr_method_map_
[
"Ellipsis"
]
=
&
Parser
::
ParseEllipsis
;
}
void
Parser
::
UpdateTopFuncGraph
(
const
FuncGraphPtr
&
func_graph
)
{
top_func_graph_
=
FuncGraphWeakPtr
(
func_graph
);
}
...
...
@@ -187,7 +188,7 @@ void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block,
namelist_for_default_value
.
push_back
(
arg_name
);
if
(
py
::
isinstance
<
py
::
none
>
(
defaults
[
i
]))
{
default_values
.
push_back
(
NewValueNode
(
kNull
Obj
));
default_values
.
push_back
(
NewValueNode
(
kNull
));
}
else
{
default_values
.
push_back
(
ParseExprNode
(
block
,
defaults
[
i
]));
}
...
...
@@ -437,6 +438,11 @@ AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) {
return
NewValueNode
(
kNone
);
}
AnfNodePtr
Parser
::
ParseEllipsis
(
const
FunctionBlockPtr
&
,
const
py
::
object
&
)
{
MS_LOG
(
DEBUG
)
<<
"Process ast Ellipsis"
;
return
NewValueNode
(
kEllipsis
);
}
AnfNodePtr
Parser
::
ParseNum
(
const
FunctionBlockPtr
&
,
const
py
::
object
&
node
)
{
MS_LOG
(
DEBUG
)
<<
"Process ast Num"
;
py
::
object
obj
=
python_adapter
::
GetPyObjAttr
(
node
,
"n"
);
...
...
mindspore/ccsrc/pipeline/parse/parse.h
浏览文件 @
24a9f497
...
...
@@ -92,6 +92,8 @@ class Parser {
AnfNodePtr
ParseName
(
const
FunctionBlockPtr
&
block
,
const
py
::
object
&
node
);
// process NoneType
AnfNodePtr
ParseNone
(
const
FunctionBlockPtr
&
block
,
const
py
::
object
&
node
);
// process Ellipsis
AnfNodePtr
ParseEllipsis
(
const
FunctionBlockPtr
&
block
,
const
py
::
object
&
node
);
// process a integer or float number
AnfNodePtr
ParseNum
(
const
FunctionBlockPtr
&
block
,
const
py
::
object
&
node
);
// process a string variable
...
...
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
浏览文件 @
24a9f497
...
...
@@ -892,10 +892,27 @@ bool AbstractNull::operator==(const AbstractBase &other) const {
std
::
string
AbstractNull
::
ToString
()
const
{
std
::
ostringstream
buffer
;
buffer
<<
type_name
()
<<
"("
<<
"Value: "
<<
"Null"
<<
")"
;
buffer
<<
type_name
()
<<
"(Value: Null)"
;
return
buffer
.
str
();
}
bool
AbstractEllipsis
::
operator
==
(
const
AbstractEllipsis
&
)
const
{
return
true
;
}
bool
AbstractEllipsis
::
operator
==
(
const
AbstractBase
&
other
)
const
{
if
(
&
other
==
this
)
{
return
true
;
}
if
(
other
.
isa
<
AbstractEllipsis
>
())
{
auto
other_none
=
static_cast
<
const
AbstractEllipsis
*>
(
&
other
);
return
*
this
==
*
other_none
;
}
else
{
return
false
;
}
}
std
::
string
AbstractEllipsis
::
ToString
()
const
{
std
::
ostringstream
buffer
;
buffer
<<
type_name
()
<<
"(Value: Ellipsis)"
;
return
buffer
.
str
();
}
...
...
mindspore/ccsrc/pipeline/static_analysis/abstract_value.h
浏览文件 @
24a9f497
...
...
@@ -498,7 +498,7 @@ using AbstractNonePtr = std::shared_ptr<AbstractNone>;
// the un assigned state value for variable, which means the variable is not assigned
class
AbstractNull
:
public
AbstractBase
{
public:
AbstractNull
()
:
AbstractBase
(
kNull
Obj
)
{
set_type
(
std
::
make_shared
<
TypeNull
>
());
}
AbstractNull
()
:
AbstractBase
(
kNull
)
{
set_type
(
std
::
make_shared
<
TypeNull
>
());
}
~
AbstractNull
()
override
=
default
;
MS_DECLARE_PARENT
(
AbstractNull
,
AbstractBase
)
...
...
@@ -510,6 +510,20 @@ class AbstractNull : public AbstractBase {
};
using
AbstractNullPtr
=
std
::
shared_ptr
<
AbstractNull
>
;
class
AbstractEllipsis
:
public
AbstractBase
{
public:
AbstractEllipsis
()
:
AbstractBase
(
kEllipsis
)
{
set_type
(
std
::
make_shared
<
Ellipsis
>
());
}
~
AbstractEllipsis
()
override
=
default
;
MS_DECLARE_PARENT
(
AbstractEllipsis
,
AbstractBase
)
TypePtr
BuildType
()
const
override
{
return
std
::
make_shared
<
Ellipsis
>
();
}
bool
operator
==
(
const
AbstractEllipsis
&
other
)
const
;
bool
operator
==
(
const
AbstractBase
&
other
)
const
override
;
AbstractBasePtr
Clone
()
const
override
{
return
std
::
make_shared
<
AbstractEllipsis
>
();
}
std
::
string
ToString
()
const
override
;
};
using
AbstractEllipsisPtr
=
std
::
shared_ptr
<
AbstractEllipsis
>
;
class
AbstractRefKey
:
public
AbstractBase
{
public:
AbstractRefKey
()
:
AbstractBase
()
{
set_type
(
std
::
make_shared
<
RefKeyType
>
());
}
...
...
mindspore/ops/composite/multitype_ops/getitem_impl.py
浏览文件 @
24a9f497
...
...
@@ -150,7 +150,7 @@ def _tensor_getitem_by_number(data, number_index):
@
getitem
.
register
(
"Tensor"
,
"Slice"
)
def
_tensor_getitem_by_slice
(
data
,
slice_index
):
"""
Getting item of tensor by slice
index
.
Getting item of tensor by slice.
Inputs:
data (Tensor): A tensor.
...
...
@@ -165,7 +165,7 @@ def _tensor_getitem_by_slice(data, slice_index):
@
getitem
.
register
(
"Tensor"
,
"Tuple"
)
def
_tensor_getitem_by_slice_tuple
(
data
,
slice_tuple_index
):
"""
Getting item of tensor by slice tuple
index
.
Getting item of tensor by slice tuple.
Inputs:
data (Tensor): A tensor.
...
...
@@ -175,3 +175,18 @@ def _tensor_getitem_by_slice_tuple(data, slice_tuple_index):
Tensor, element type is same as the element type of data.
"""
return
_tensor_slice
(
data
,
slice_tuple_index
)
@
getitem
.
register
(
"Tensor"
,
"Ellipsis"
)
def
_tensor_getitem_by_ellipsis
(
data
,
ellipsis_index
):
"""
Getting item of tensor by Ellipsis.
Inputs:
data (Tensor): A tensor.
ellipsis (Ellipsis): A Ellipsis object.
Outputs:
Tensor, same as data.
"""
return
_tensor_slice
(
data
,
ellipsis_index
)
mindspore/ops/functional.py
浏览文件 @
24a9f497
...
...
@@ -67,6 +67,7 @@ scalar_to_tensor = P.ScalarToTensor()
tuple_to_array
=
P
.
TupleToArray
()
scalar_cast
=
P
.
ScalarCast
()
print_
=
P
.
Print
()
expand_dims
=
P
.
ExpandDims
()
tuple_setitem
=
Primitive
(
'tuple_setitem'
)
tuple_getitem
=
Primitive
(
'tuple_getitem'
)
...
...
tests/ut/python/ops/test_tensor_slice.py
浏览文件 @
24a9f497
...
...
@@ -42,6 +42,20 @@ class NetWorkSlicePositive(Cell):
return
ret0
,
ret1
,
ret2
,
ret3
class
NetWorkSliceEllipsis
(
Cell
):
def
__init__
(
self
):
super
(
NetWorkSliceEllipsis
,
self
).
__init__
()
self
.
tensor_ret0
=
Tensor
(
np
.
ones
([
2
,
7
,
8
],
np
.
int32
))
self
.
tensor_ret1
=
Tensor
(
np
.
ones
([
6
,
7
,
8
,
9
],
np
.
int32
))
self
.
tensor_ret2
=
Tensor
(
np
.
ones
([
1
,
6
,
7
,
8
,
9
],
np
.
int32
))
def
construct
(
self
,
tensor
):
ret0
=
tensor
[
0
:
4
:
2
,
...,
1
]
+
self
.
tensor_ret0
ret1
=
tensor
[...]
+
self
.
tensor_ret1
ret2
=
tensor
[
True
]
+
self
.
tensor_ret2
return
ret0
,
ret1
,
ret2
class
NetWorkReduceDimension
(
Cell
):
def
__init__
(
self
):
super
(
NetWorkReduceDimension
,
self
).
__init__
()
...
...
@@ -83,7 +97,7 @@ class NetWorkReduceToScalar(Cell):
class
TensorAssignWithBoolTensorIndex
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithBoolTensorIndex
,
self
).
__init__
()
self
.
t
=
Tensor
(
np
.
arange
(
6
).
reshape
([
2
,
3
]),
dtype
=
mstype
.
float64
)
self
.
t
=
Tensor
(
np
.
arange
(
6
).
reshape
([
2
,
3
]),
dtype
=
mstype
.
float64
)
def
construct
(
self
,
a
,
b
,
c
,
u_tensor
,
_scalar
):
a
[
c
]
=
u_scalar
...
...
@@ -104,14 +118,14 @@ class TensorAssignWithBoolTensorIndexError(Cell):
class
TensorAssignWithBoolTensorIndex2
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithBoolTensorIndex2
,
self
).
__init__
()
self
.
t
=
Tensor
(
np
.
arange
(
6
).
reshape
([
2
,
3
]),
dtype
=
mstype
.
float64
)
self
.
t
=
Tensor
(
np
.
arange
(
6
).
reshape
([
2
,
3
]),
dtype
=
mstype
.
float64
)
def
construct
(
self
,
a
,
u_tensor
,
_scalar
):
a
[
a
>
8
]
=
u_tensor
a
[
a
>=
6
]
=
u_scalar
a
[
a
<
3
]
=
u_scalar
a
[
a
<=
5
]
=
u_tensor
a
[
a
==
5
]
=
u_scalar
a
[
a
>
8
]
=
u_tensor
a
[
a
>=
6
]
=
u_scalar
a
[
a
<
3
]
=
u_scalar
a
[
a
<=
5
]
=
u_tensor
a
[
a
==
5
]
=
u_scalar
z
=
a
+
self
.
t
return
z
...
...
@@ -121,11 +135,11 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
super
(
TensorAssignWithBoolTensorIndex2Error
,
self
).
__init__
()
def
construct
(
self
,
a
,
u_tensor
):
a
[
a
>
8
][
a
>
5
]
=
u_tensor
a
[
a
>
8
][
a
>
5
]
=
u_tensor
return
a
a
=
np
.
random
.
uniform
(
1
,
10
,[
2
,
3
])
a
=
np
.
random
.
uniform
(
1
,
10
,
[
2
,
3
])
b
=
a
>
5
c
=
a
<
3
Ta
=
Tensor
(
a
)
...
...
@@ -152,7 +166,7 @@ def test_tensor_assign_bool_index():
net1
(
Ta
,
Tb
,
Ta
,
u_tensor
,
u_scalar
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
Tb
,
Tc
,
u_tensor_error
,
u_scalar
)
#net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
#
net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
with
pytest
.
raises
(
ValueError
):
net2
(
Ta
,
u_tensor_error
,
u_scalar
)
net3
=
TensorAssignWithBoolTensorIndexError
()
...
...
@@ -192,7 +206,10 @@ test_cases = [
'block'
:
NetWorkReduceToScalar
(),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
6
,
8
,
10
],
np
.
int32
))],
}),
(
'NetWorkSliceEllipsis'
,
{
'block'
:
NetWorkSliceEllipsis
(),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
6
,
7
,
8
,
9
],
np
.
int32
))],
}),
]
...
...
tests/ut/python/pipeline/parse/test_operator.py
浏览文件 @
24a9f497
...
...
@@ -162,14 +162,15 @@ def test_ops():
if
self
.
int
>
self
.
float
:
if
[
1
,
2
,
3
]
!=
None
:
if
self
.
str_a
+
self
.
str_b
==
"helloworld"
:
print
(
"hello world"
)
return
ret
if
q
==
86
:
print
(
"hello world"
)
return
ret
return
x
net
=
OpsNet
(
9
,
2
)
x
=
Tensor
(
np
.
random
.
randint
(
low
=
1
,
high
=
10
,
size
=
(
2
,
3
,
4
),
dtype
=
np
.
int32
))
y
=
Tensor
(
np
.
random
.
randint
(
low
=
10
,
high
=
20
,
size
=
(
2
,
3
,
4
),
dtype
=
np
.
int32
))
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
net
(
x
,
y
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录