Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
c8a1a24c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c8a1a24c
编写于
4月 23, 2020
作者:
B
buxue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix the infer of TruncatedNormal and a bug of structure output and a bug of tensorslice ellipsis
上级
ffdad1ac
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
60 addition
and
43 deletion
+60
-43
mindspore/ccsrc/operator/composite/composite.cc
mindspore/ccsrc/operator/composite/composite.cc
+8
-5
mindspore/ccsrc/pipeline/pipeline_ge.cc
mindspore/ccsrc/pipeline/pipeline_ge.cc
+18
-13
mindspore/ops/composite/multitype_ops/getitem_impl.py
mindspore/ops/composite/multitype_ops/getitem_impl.py
+15
-0
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+14
-21
tests/ut/python/ops/test_ops.py
tests/ut/python/ops/test_ops.py
+1
-1
tests/ut/python/ops/test_tensor_slice.py
tests/ut/python/ops/test_tensor_slice.py
+4
-3
未找到文件。
mindspore/ccsrc/operator/composite/composite.cc
浏览文件 @
c8a1a24c
...
...
@@ -1084,6 +1084,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
std
::
vector
<
unsigned
int
>
shrink
;
auto
slice_tuple_eles
=
slice_tuple
->
elements
();
size_t
ellipsis_num
=
0
;
for
(
size_t
index
=
0
;
index
<
slice_tuple_size
;
index
++
)
{
if
(
slice_tuple_eles
[
index
]
->
isa
<
AbstractSlice
>
())
{
AbstractSlicePtr
slice
=
dyn_cast
<
AbstractSlice
>
(
slice_tuple_eles
[
index
]);
...
...
@@ -1118,12 +1119,13 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
<<
slice_tuple_eles
[
index
]
->
ToString
();
}
for
(
size_t
index
=
slice_tuple_size
;
index
<
shape_size
;
index
++
)
{
begin
->
push_back
(
0
);
end
->
push_back
(
shape
[
index
]);
strides
->
push_back
(
1
);
if
(
ellipsis_num
==
0
)
{
for
(
size_t
index
=
slice_tuple_size
;
index
<
shape_size
;
index
++
)
{
begin
->
push_back
(
0
);
end
->
push_back
(
shape
[
index
]);
strides
->
push_back
(
1
);
}
}
return
ConvertBinaryToDecimal
(
shrink
);
}
...
...
@@ -1199,6 +1201,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
if
(
scalar_ptr
->
BuildValue
()
->
cast
<
BoolImmPtr
>
()
->
value
())
{
return
ExpandADim
(
ret_graph
,
tensor_node
);
}
MS_LOG
(
EXCEPTION
)
<<
"TensorSlice not support the index is False."
;
}
shrink_axis_mask
=
GenerateStridedSliceParametersFromNumber
(
scalar_ptr
,
shape
,
&
begin
,
&
end
,
&
strides
);
}
else
if
(
args_spec_list
[
1
]
->
isa
<
AbstractEllipsis
>
())
{
...
...
mindspore/ccsrc/pipeline/pipeline_ge.cc
浏览文件 @
c8a1a24c
...
...
@@ -319,19 +319,24 @@ void RunGEInitGraph(const py::dict &init_params, const std::string &phase) {
py
::
object
ExtractGeneralCnodeRet
(
const
AbstractBasePtr
&
cnode_data
,
const
py
::
tuple
&
data
,
size_t
*
count
)
{
MS_EXCEPTION_IF_NULL
(
cnode_data
);
if
(
*
count
>=
data
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"The number of elements in the outputs : "
<<
data
.
size
()
<<
" less than the number of elements required. "
;
}
if
(
cnode_data
->
isa
<
AbstractTensor
>
())
{
if
(
*
count
>=
data
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"The number of elements in the outputs : "
<<
data
.
size
()
<<
" less than the number of elements required. "
;
}
BaseShapePtr
shape
=
cnode_data
->
BuildShape
();
auto
shape_act
=
shape
->
cast
<
abstract
::
ShapePtr
>
()
->
shape
();
Tensor
tensor_exp
=
py
::
cast
<
Tensor
>
(
data
[
*
count
]);
if
(
shape_act
!=
tensor_exp
.
shape
())
{
MS_LOG
(
EXCEPTION
)
<<
"The shape of the tensor returned from GE is not the same as "
"the shape of the tensor derived from ME."
;
if
(
!
shape
->
isa
<
abstract
::
Shape
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"The shape of the tensor derived is not Shape, is "
<<
shape
->
ToString
();
}
auto
shape_me
=
shape
->
cast
<
abstract
::
ShapePtr
>
()
->
shape
();
auto
shape_ge
=
py
::
cast
<
Tensor
>
(
data
[
*
count
]).
shape
();
if
(
shape_ge
!=
shape_me
)
{
MS_LOG
(
EXCEPTION
)
<<
"The shape of the "
<<
*
count
<<
"th tensor returned: "
<<
shape_ge
<<
" is not the same as the shape of the tensor derived: "
<<
shape_me
;
}
return
data
[(
*
count
)
++
];
}
...
...
@@ -357,11 +362,11 @@ py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data,
return
ValuePtrToPyData
(
GetValueNode
(
output_node
));
}
if
(
*
count
>=
data
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"The number of elements in the outputs : "
<<
data
.
size
()
<<
" less than the number of elements required. "
;
}
if
(
output_node
->
isa
<
Parameter
>
())
{
if
(
*
count
>=
data
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"The number of elements in the outputs : "
<<
data
.
size
()
<<
" less than the number of elements required. "
;
}
return
data
[(
*
count
)
++
];
}
...
...
mindspore/ops/composite/multitype_ops/getitem_impl.py
浏览文件 @
c8a1a24c
...
...
@@ -147,6 +147,21 @@ def _tensor_getitem_by_number(data, number_index):
return
_tensor_slice
(
data
,
number_index
)
@
getitem
.
register
(
"Tensor"
,
"None"
)
def
_tensor_getitem_by_none
(
data
,
index
):
"""
Getting item of tensor by None.
Inputs:
data (Tensor): A tensor.
index (None): None.
Outputs:
Tensor, element type is as same as the element type of data.
"""
return
_tensor_slice
(
data
,
index
)
@
getitem
.
register
(
"Tensor"
,
"Slice"
)
def
_tensor_getitem_by_slice
(
data
,
slice_index
):
"""
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
c8a1a24c
...
...
@@ -633,7 +633,7 @@ class TruncatedNormal(PrimitiveWithInfer):
dtype (:class:`mindspore.dtype`): Data type. Default: mindspore.float32.
Inputs:
- **shape** (
Tensor) - Shape of output tensor. The shape is a 1-D tensor, and type is
int.
- **shape** (
tuple[int]) - Shape of output tensor, is a tuple of positive
int.
Outputs:
Tensor, type of output tensor is same as attribute `dtype`.
...
...
@@ -651,16 +651,10 @@ class TruncatedNormal(PrimitiveWithInfer):
validator
.
check_typename
(
'dtype'
,
dtype
,
mstype
.
number_type
)
def
__infer__
(
self
,
shape
):
shape_t
=
shape
[
'value'
]
validator
.
check_subclass
(
"shape"
,
shape
[
'dtype'
],
mstype
.
tensor
)
shape_n
=
shape_t
.
asnumpy
()
if
shape_n
.
ndim
!=
1
:
raise
ValueError
(
'The rank of input shape must be 1.'
)
if
shape_n
.
dtype
not
in
(
np
.
int32
,
np
.
int64
):
raise
TypeError
(
'The type of input shape must be int32 or int64.'
)
for
i
,
item
in
enumerate
(
shape_n
):
validator
.
check_integer
(
f
"shape[
{
i
}
]"
,
item
.
item
(),
0
,
Rel
.
GT
)
out
=
{
'shape'
:
tuple
(
shape_n
),
shape_value
=
shape
[
'value'
]
for
i
,
value
in
enumerate
(
shape_value
):
validator
.
check_integer
(
f
'
{
i
}
th value of shape'
,
value
,
0
,
Rel
.
GT
)
out
=
{
'shape'
:
shape_value
,
'dtype'
:
mstype
.
tensor_type
(
self
.
dtype
),
'value'
:
None
}
return
out
...
...
@@ -1648,20 +1642,19 @@ class StridedSlice(PrimitiveWithInfer):
validator
.
check_type
(
'shrink_axis_mask'
,
shrink_axis_mask
,
[
int
])
def
__infer__
(
self
,
x
,
begin
,
end
,
strides
):
begin_shape
,
end_shape
,
strides_shape
=
begin
[
'shape'
],
end
[
'shape'
],
strides
[
'shape'
]
if
begin_shape
!=
strides_shape
or
end_shape
!=
strides_shape
:
raise
ValueError
(
"The shape of begin, end and strides in 'StridedSlice' must be equal."
)
validator
.
check_const_input
(
"begin"
,
begin
[
'value'
])
validator
.
check_const_input
(
"end"
,
end
[
'value'
])
validator
.
check_const_input
(
"strides"
,
strides
[
'value'
])
x_shape
=
x
[
'shape'
]
x_shp_len
=
len
(
x_shape
)
begin_v
,
end_v
,
strides_v
=
begin
[
'value'
],
end
[
'value'
],
strides
[
'value'
]
validator
.
check_const_input
(
"begin"
,
begin_v
)
validator
.
check_const_input
(
"end"
,
end_v
)
validator
.
check_const_input
(
"strides"
,
strides_v
)
validator
.
check_type
(
"begin"
,
begin
[
'value'
],
[
tuple
])
validator
.
check_type
(
"end"
,
end
[
'value'
],
[
tuple
])
validator
.
check_type
(
"strides"
,
strides
[
'value'
],
[
tuple
])
if
len
(
begin_v
)
!=
x_shp_len
or
len
(
end_v
)
!=
x_shp_len
or
len
(
strides_v
)
!=
x_shp_len
:
raise
ValueError
(
f
"The length of begin index
{
begin_v
}
, end index
{
end_v
}
and strides
{
strides_v
}
"
f
"must be equal to the dims(
{
x_shp_len
}
) of input."
)
x_shape
=
x
[
'shape'
]
x_shp_len
=
len
(
x_shape
)
begin_v
,
end_v
,
strides_v
=
begin
[
'value'
],
end
[
'value'
],
strides
[
'value'
]
ret_shape
=
[]
append_dimensions
=
[]
shrink_pos
=
bin
(
self
.
shrink_axis_mask
)[::
-
1
]
...
...
tests/ut/python/ops/test_ops.py
浏览文件 @
c8a1a24c
...
...
@@ -372,7 +372,7 @@ test_case_math_ops = [
'desc_bprop'
:
[[
3
]]}),
(
'TruncatedNormal'
,
{
'block'
:
P
.
TruncatedNormal
(),
'desc_const'
:
[
Tensor
(
np
.
array
([
1
,
2
,
3
]))
],
'desc_const'
:
[
[
1
,
2
,
3
]
],
'desc_inputs'
:
[],
'skip'
:
[
'backward'
],
'add_fake_input'
:
True
}),
...
...
tests/ut/python/ops/test_tensor_slice.py
浏览文件 @
c8a1a24c
...
...
@@ -52,8 +52,9 @@ class NetWorkSliceEllipsis(Cell):
def
construct
(
self
,
tensor
):
ret0
=
tensor
[
0
:
4
:
2
,
...,
1
]
+
self
.
tensor_ret0
ret1
=
tensor
[...]
+
self
.
tensor_ret1
ret2
=
tensor
[
True
]
+
self
.
tensor_ret2
return
ret0
,
ret1
,
ret2
ret2
=
tensor
[
None
]
+
self
.
tensor_ret2
ret3
=
tensor
[
True
]
+
self
.
tensor_ret2
return
ret0
,
ret1
,
ret2
,
ret3
class
NetWorkReduceDimension
(
Cell
):
...
...
@@ -305,7 +306,7 @@ test_cases = [
'block'
:
NetWorkReduceToScalar
(),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
6
,
8
,
10
],
np
.
int32
))],
}),
(
'
NetWork
SliceEllipsis'
,
{
(
'
Tensor
SliceEllipsis'
,
{
'block'
:
NetWorkSliceEllipsis
(),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
6
,
7
,
8
,
9
],
np
.
int32
))],
}),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录