Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ea475637
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ea475637
编写于
7月 01, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 01, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2720 fix assign used in while loop
Merge pull request !2720 from xychow/fix-assign-in-while
上级
1332a049
d5255fe3
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
69 addition
and
21 deletion
+69
-21
mindspore/ccsrc/pipeline/pipeline.cc
mindspore/ccsrc/pipeline/pipeline.cc
+1
-1
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
+20
-0
mindspore/ccsrc/pipeline/static_analysis/abstract_value.h
mindspore/ccsrc/pipeline/static_analysis/abstract_value.h
+1
-0
mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
+2
-0
tests/ut/python/pipeline/infer/test_net_infer.py
tests/ut/python/pipeline/infer/test_net_infer.py
+45
-20
未找到文件。
mindspore/ccsrc/pipeline/pipeline.cc
浏览文件 @
ea475637
...
...
@@ -314,7 +314,7 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
auto
weight_name
=
weight_node
->
cast
<
ParameterPtr
>
()
->
name
();
// find the fakequant from input
int
count
=
0
;
int
max_depth
=
5
;
const
int
max_depth
=
5
;
while
(
!
is_quant_cnode
(
x
))
{
if
(
count
>=
max_depth
)
{
break
;
...
...
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
浏览文件 @
ea475637
...
...
@@ -451,6 +451,11 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
if
(
other_tensor
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Join failed as type mismatch, this: "
<<
ToString
()
<<
", other: "
<<
other
->
ToString
();
}
if
(
*
this
==
*
other
)
{
if
(
sparse_grad
()
==
other
->
sparse_grad
())
{
return
shared_from_base
<
AbstractBase
>
();
}
}
auto
element
=
element_
->
Join
(
other_tensor
->
element_
);
auto
shape
=
ShapeJoin
(
this
->
shape
(),
other_tensor
->
shape
());
auto
ret
=
std
::
make_shared
<
AbstractTensor
>
(
element
,
shape
);
...
...
@@ -830,6 +835,21 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
return
false
;
}
AbstractBasePtr
AbstractRef
::
Join
(
const
AbstractBasePtr
&
other
)
{
auto
other_ref
=
other
->
cast
<
AbstractRefPtr
>
();
if
(
other_ref
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Join failed as type mismatch, this: "
<<
ToString
()
<<
", other: "
<<
other
->
ToString
();
}
if
(
*
this
==
*
other
)
{
return
shared_from_base
<
AbstractBase
>
();
}
auto
ref_key
=
ref_key_
->
Join
(
other_ref
->
ref_key_
);
auto
ref
=
ref_
->
Join
(
other_ref
->
ref
());
auto
ref_origin
=
ref_origin_
->
Join
(
other_ref
->
ref_origin_
);
return
std
::
make_shared
<
AbstractRef
>
(
ref_key
,
ref
,
ref_origin
);
}
std
::
string
AbstractRef
::
ToString
()
const
{
std
::
ostringstream
buffer
;
buffer
<<
type_name
()
<<
"("
...
...
mindspore/ccsrc/pipeline/static_analysis/abstract_value.h
浏览文件 @
ea475637
...
...
@@ -578,6 +578,7 @@ class AbstractRef : public AbstractBase {
AbstractBasePtr
Broaden
()
const
override
{
return
std
::
make_shared
<
AbstractRef
>
(
ref_key_
->
Broaden
(),
ref_
->
Broaden
(),
ref_origin_
->
Broaden
());
}
AbstractBasePtr
Join
(
const
AbstractBasePtr
&
other
)
override
;
std
::
size_t
hash
()
const
override
{
return
ref_key_
->
hash
()
^
ref_
->
hash
()
^
ref_origin_
->
hash
()
^
(
std
::
hash
<
uint32_t
>
{}(
this
->
tid
())
<<
1
);
}
...
...
mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
浏览文件 @
ea475637
...
...
@@ -166,6 +166,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
if
(
!
(
joined_args_spec_list
==
args_spec_list
))
{
func_graph_
->
set_flag
(
FUNC_GRAPH_FLAG_IGNORE_VALUES
,
true
);
MS_LOG
(
DEBUG
)
<<
"Set "
<<
func_graph_
->
ToString
()
<<
" with IGNORE_VALUES flag."
;
}
return
joined_args_spec_list
;
}
...
...
@@ -179,6 +180,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
if
(
!
(
joined_args_spec_list
==
args_spec_list
))
{
trace_
.
push_back
(
joined_args_spec_list
);
func_graph_
->
set_flag
(
FUNC_GRAPH_FLAG_IGNORE_VALUES
,
true
);
MS_LOG
(
DEBUG
)
<<
"Set "
<<
func_graph_
->
ToString
()
<<
" with IGNORE_VALUES flag."
;
}
MS_LOG
(
DEBUG
)
<<
"Joined eval args: "
<<
::
mindspore
::
ToString
(
joined_args_spec_list
);
return
joined_args_spec_list
;
...
...
tests/ut/python/pipeline/infer/test_net_infer.py
浏览文件 @
ea475637
...
...
@@ -16,10 +16,14 @@
import
numpy
as
np
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore
import
Tensor
,
context
from
mindspore.common.parameter
import
Parameter
from
mindspore.common.initializer
import
initializer
import
mindspore.ops.operations
as
op
class
Net
(
nn
.
Cell
):
def
test_net_infer
():
""" test_net_infer """
class
Net
(
nn
.
Cell
):
""" Net definition """
def
__init__
(
self
):
...
...
@@ -36,9 +40,30 @@ class Net(nn.Cell):
x
=
self
.
flatten
(
x
)
out
=
self
.
fc
(
x
)
return
out
def
test_net_infer
():
""" test_net_infer """
Tensor
(
np
.
random
.
randint
(
0
,
255
,
[
1
,
3
,
224
,
224
]))
Net
()
def
test_assign_in_while
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
class
Net
(
nn
.
Cell
):
def
__init__
(
self
,
input_shape
):
super
().
__init__
()
self
.
assign
=
op
.
Assign
()
self
.
inputdata
=
Parameter
(
initializer
(
1
,
input_shape
),
name
=
"global_step"
)
def
construct
(
self
,
x
,
y
,
z
):
out
=
z
while
x
<
y
:
inputdata
=
self
.
inputdata
x
=
x
+
1
out
=
self
.
assign
(
inputdata
,
z
)
return
out
x
=
Tensor
(
np
.
array
(
1
).
astype
(
np
.
int32
))
y
=
Tensor
(
np
.
array
(
3
).
astype
(
np
.
int32
))
input_shape
=
(
1024
,
512
)
z
=
Tensor
(
np
.
random
.
randn
(
*
input_shape
).
astype
(
np
.
float32
))
net
=
Net
(
input_shape
)
ret
=
net
(
x
,
y
,
z
)
assert
ret
==
z
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录