Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7602054a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7602054a
编写于
8月 06, 2020
作者:
F
fary86
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix do concat in while loop specialize error
上级
ef292bb9
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
42 addition
and
0 deletion
+42
-0
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
+16
-0
mindspore/core/ir/func_graph.h
mindspore/core/ir/func_graph.h
+1
-0
mindspore/core/ir/func_graph_cloner.cc
mindspore/core/ir/func_graph_cloner.cc
+1
-0
tests/ut/python/ops/test_control_ops.py
tests/ut/python/ops/test_control_ops.py
+24
-0
未找到文件。
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
浏览文件 @
7602054a
...
...
@@ -144,6 +144,14 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
MS_EXCEPTION_IF_NULL
(
arg
);
return
arg
->
Broaden
();
});
if
(
func_graph_
->
joined_shapes_
.
size
()
!=
broaded_list
.
size
())
{
MS_EXCEPTION
(
ValueError
)
<<
"Number of input arguments "
<<
broaded_list
.
size
()
<<
" does not equal to number of original buffer arguments "
<<
func_graph_
->
joined_shapes_
.
size
();
}
for
(
size_t
i
=
0
;
i
<
broaded_list
.
size
();
++
i
)
{
broaded_list
[
i
]
->
set_shape
(
func_graph_
->
joined_shapes_
[
i
]);
}
MS_LOG
(
DEBUG
)
<<
func_graph_
->
ToString
()
<<
" original: "
<<
mindspore
::
ToString
(
args_spec_list
)
<<
", broaded: "
<<
mindspore
::
ToString
(
broaded_list
);
return
broaded_list
;
...
...
@@ -171,6 +179,10 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
if
(
!
(
joined_args_spec_list
==
args_spec_list
))
{
func_graph_
->
set_flag
(
FUNC_GRAPH_FLAG_IGNORE_VALUES
,
true
);
func_graph_
->
joined_shapes_
.
clear
();
std
::
transform
(
joined_args_spec_list
.
begin
(),
joined_args_spec_list
.
end
(),
std
::
back_inserter
(
func_graph_
->
joined_shapes_
),
[](
const
AbstractBasePtr
&
arg_spec
)
{
return
arg_spec
->
GetShapeTrack
();
});
MS_LOG
(
DEBUG
)
<<
"Set "
<<
func_graph_
->
ToString
()
<<
" with IGNORE_VALUES flag."
;
}
return
joined_args_spec_list
;
...
...
@@ -185,6 +197,10 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
if
(
!
(
joined_args_spec_list
==
args_spec_list
))
{
trace_
.
push_back
(
joined_args_spec_list
);
func_graph_
->
set_flag
(
FUNC_GRAPH_FLAG_IGNORE_VALUES
,
true
);
func_graph_
->
joined_shapes_
.
clear
();
std
::
transform
(
joined_args_spec_list
.
begin
(),
joined_args_spec_list
.
end
(),
std
::
back_inserter
(
func_graph_
->
joined_shapes_
),
[](
const
AbstractBasePtr
&
arg_spec
)
{
return
arg_spec
->
GetShapeTrack
();
});
MS_LOG
(
DEBUG
)
<<
"Set "
<<
func_graph_
->
ToString
()
<<
" with IGNORE_VALUES flag."
;
}
MS_LOG
(
DEBUG
)
<<
"Joined eval args: "
<<
::
mindspore
::
ToString
(
joined_args_spec_list
);
...
...
mindspore/core/ir/func_graph.h
浏览文件 @
7602054a
...
...
@@ -332,6 +332,7 @@ class FuncGraph : public FuncGraphBase {
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
&
make_ref_params
()
{
return
make_ref_params_
;
}
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
attrs_
;
std
::
vector
<
BaseShapePtr
>
joined_shapes_
;
std
::
unordered_map
<
std
::
string
,
FuncGraphTransform
>
transforms_
;
// parameter default value
std
::
map
<
std
::
string
,
AnfNodePtr
>
parameter_default_value_
;
...
...
mindspore/core/ir/func_graph_cloner.cc
浏览文件 @
7602054a
...
...
@@ -220,6 +220,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
TraceManager
::
DebugTrace
(
func_graph
->
debug_info
(),
target_relation_
);
*
target_func_graph
=
std
::
make_shared
<
FuncGraph
>
();
(
*
target_func_graph
)
->
set_attrs
(
func_graph
->
attrs
());
(
*
target_func_graph
)
->
joined_shapes_
=
func_graph
->
joined_shapes_
;
(
*
target_func_graph
)
->
set_transforms
(
func_graph
->
transforms
());
(
*
target_func_graph
)
->
set_has_vararg
(
func_graph
->
has_vararg
());
(
*
target_func_graph
)
->
set_has_kwarg
(
func_graph
->
has_kwarg
());
...
...
tests/ut/python/ops/test_control_ops.py
浏览文件 @
7602054a
...
...
@@ -645,3 +645,27 @@ def test_mixed_precision_cast():
x
=
Tensor
(
np
.
ones
([
2
,
3
],
dtype
=
np
.
float32
))
z
=
F
.
mixed_precision_cast
(
mstype
.
float16
,
x
)
assert
z
.
dtype
==
mstype
.
float16
def
test_while_concat
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
,
data
):
super
(
Net
,
self
).
__init__
()
self
.
start
=
Tensor
(
0
,
dtype
=
mstype
.
int32
)
self
.
end
=
Tensor
(
2
,
dtype
=
mstype
.
int32
)
self
.
out
=
Tensor
(
np
.
zeros
([
2
,
3
],
dtype
=
np
.
float32
))
self
.
concat
=
P
.
Concat
()
def
construct
(
self
,
inputs
):
idx
=
self
.
start
end
=
self
.
end
out
=
self
.
out
while
idx
<
end
:
xi
=
inputs
[
idx
,
:,
:]
out
=
self
.
concat
((
out
,
xi
))
idx
=
idx
+
1
return
out
x
=
Tensor
(
np
.
arange
(
10
*
2
*
3
).
reshape
(
10
,
2
,
3
).
astype
(
np
.
float32
))
net
=
Net
(
x
)
net
(
x
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录