Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
323a80c6
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
323a80c6
编写于
7月 09, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 09, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2765 fix large for loop segment fault
Merge pull request !2765 from fary86/fix_large_for_loop
上级
5fc6dcfb
e17469bf
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
202 addition
and
10 deletion
+202
-10
mindspore/ccsrc/operator/ops.h
mindspore/ccsrc/operator/ops.h
+6
-0
mindspore/ccsrc/operator/prim_statement.cc
mindspore/ccsrc/operator/prim_statement.cc
+6
-1
mindspore/ccsrc/operator/prim_structures.cc
mindspore/ccsrc/operator/prim_structures.cc
+5
-0
mindspore/ccsrc/pipeline/parse/function_block.cc
mindspore/ccsrc/pipeline/parse/function_block.cc
+7
-2
mindspore/ccsrc/pipeline/parse/function_block.h
mindspore/ccsrc/pipeline/parse/function_block.h
+2
-1
mindspore/ccsrc/pipeline/parse/parse.cc
mindspore/ccsrc/pipeline/parse/parse.cc
+133
-2
mindspore/ccsrc/pipeline/parse/parse.h
mindspore/ccsrc/pipeline/parse/parse.h
+2
-0
mindspore/ccsrc/pipeline/parse/parse_base.h
mindspore/ccsrc/pipeline/parse/parse_base.h
+1
-0
mindspore/ccsrc/pipeline/pipeline.cc
mindspore/ccsrc/pipeline/pipeline.cc
+1
-4
tests/ut/python/ops/test_control_ops.py
tests/ut/python/ops/test_control_ops.py
+39
-0
未找到文件。
mindspore/ccsrc/operator/ops.h
浏览文件 @
323a80c6
...
...
@@ -294,6 +294,12 @@ extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
extern
const
PrimitivePtr
kPrimIndexedSlicesGetDenseShape
;
extern
const
PrimitivePtr
kPrimIsIndexedSlices
;
// attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll
const
char
SWITCH_UNROLL_FLAG
[]
=
"unroll_flag"
;
// max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it
// will be sunk(i.e. not unrolled)
const
int
MAX_FOR_LOOP_COUNT
=
200
;
class
DoSignaturePrimitive
:
public
Primitive
{
public:
explicit
DoSignaturePrimitive
(
const
std
::
string
&
name
,
const
ValuePtr
&
function
)
...
...
mindspore/ccsrc/operator/prim_statement.cc
浏览文件 @
323a80c6
...
...
@@ -95,7 +95,7 @@ AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &prim
return
std
::
make_shared
<
AbstractTensor
>
(
input_x
->
element
(),
std
::
make_shared
<
Shape
>
(
param
));
}
AbstractBasePtr
InferImplSwitch
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
AbstractBasePtr
InferImplSwitch
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
prim
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: condition, true branch, false branch
if
(
args_spec_list
.
size
()
!=
3
)
{
...
...
@@ -108,6 +108,11 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
auto
fb
=
args_spec_list
[
2
];
MS_EXCEPTION_IF_NULL
(
cond
);
auto
unroll_flag
=
prim
->
GetAttr
(
prim
::
SWITCH_UNROLL_FLAG
);
if
(
unroll_flag
!=
nullptr
&&
GetValue
<
int
>
(
unroll_flag
)
==
0
)
{
return
tb
->
Join
(
fb
);
}
ValuePtr
v
=
cond
->
GetValueTrack
();
MS_EXCEPTION_IF_NULL
(
v
);
// for tensor as condition, keeps both true and false branch.
...
...
mindspore/ccsrc/operator/prim_structures.cc
浏览文件 @
323a80c6
...
...
@@ -208,6 +208,11 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra
ValuePtr
index_value
=
index
->
BuildValue
();
if
(
!
index_value
->
isa
<
Int32Imm
>
())
{
// when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element
// and continue
if
(
dyn_cast
<
AbstractScalar
>
(
queue
->
elements
()[
0
])
!=
nullptr
)
{
return
std
::
make_shared
<
AbstractScalar
>
(
queue
->
elements
()[
0
]
->
BuildType
());
}
MS_EXCEPTION
(
IndexError
)
<<
op_name
<<
" evaluator index should be an int32 number, but got "
<<
index_value
->
ToString
();
}
...
...
mindspore/ccsrc/pipeline/parse/function_block.cc
浏览文件 @
323a80c6
...
...
@@ -294,13 +294,18 @@ void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node)
// Perform a conditional jump using switch operation.
// The first CNode select graph with condition, and than execute this graph
void
FunctionBlock
::
ConditionalJump
(
AnfNodePtr
condNode
,
const
FunctionBlockPtr
&
true_block
,
const
FunctionBlockPtr
&
false_block
)
{
const
FunctionBlockPtr
&
false_block
,
bool
unroll_loop
)
{
if
(
func_graph
()
->
get_return
()
!=
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Failure: have return node! NodeInfo: "
<<
trace
::
GetDebugInfo
(
func_graph
()
->
get_return
()
->
debug_info
());
}
// Here we need set an attribute to primtive 'switch', so we create a new variable instead of global 'kPrimSwitch'
auto
prim_switch
=
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimSwitch
->
name
());
if
(
!
unroll_loop
)
{
prim_switch
->
AddAttr
(
prim
::
SWITCH_UNROLL_FLAG
,
MakeValue
(
0
));
}
CNodePtr
switch_app
=
func_graph
()
->
NewCNode
({
NewValueNode
(
prim
::
kPrimS
witch
),
condNode
,
NewValueNode
(
true_block
->
func_graph
()),
func_graph
()
->
NewCNode
({
NewValueNode
(
prim
_s
witch
),
condNode
,
NewValueNode
(
true_block
->
func_graph
()),
NewValueNode
(
false_block
->
func_graph
())});
CNodePtr
switch_app_new
=
func_graph
()
->
NewCNode
({
switch_app
});
func_graph
()
->
set_output
(
switch_app_new
);
...
...
mindspore/ccsrc/pipeline/parse/function_block.h
浏览文件 @
323a80c6
...
...
@@ -59,7 +59,8 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
CNodePtr
ForceToWhileCond
(
const
AnfNodePtr
&
cond
);
void
Jump
(
const
FunctionBlockPtr
&
block
,
AnfNodePtr
node
);
AnfNodePtr
SearchReplaceNode
(
const
std
::
string
&
var
,
const
ParameterPtr
&
phi
);
void
ConditionalJump
(
AnfNodePtr
condNode
,
const
FunctionBlockPtr
&
trueBlock
,
const
FunctionBlockPtr
&
falseBlock
);
void
ConditionalJump
(
AnfNodePtr
condNode
,
const
FunctionBlockPtr
&
trueBlock
,
const
FunctionBlockPtr
&
falseBlock
,
bool
unroll_loop
=
true
);
// record the assign statement of self.xx weight parameter ,which will use state_setitem op
void
SetStateAssgin
(
const
AnfNodePtr
&
target
,
const
std
::
string
&
readid
);
void
AddAutoDepend
(
const
AnfNodePtr
&
target
);
...
...
mindspore/ccsrc/pipeline/parse/parse.cc
浏览文件 @
323a80c6
...
...
@@ -1002,6 +1002,7 @@ CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py::
AnfNodePtr
iter_anf_node
=
ParseExprNode
(
block
,
iter_node
);
return
block
->
func_graph
()
->
NewCNode
({
op_iter
,
iter_anf_node
});
}
CNodePtr
Parser
::
GenerateCondInFor
(
const
ParameterPtr
&
iter_param
,
const
FunctionBlockPtr
&
header_block
,
const
AnfNodePtr
&
op_hasnext
)
{
MS_EXCEPTION_IF_NULL
(
header_block
);
...
...
@@ -1018,12 +1019,57 @@ FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
// A for loop will generate 3 functions :the test, the body, and the continuation
// for x in xs:
// body
// it compiled to be following statement
// it is compiled to be following statement
// if len(xs) < max_loop_cnt:
// ParseForIter() // use iter to implement for loop, which always unroll loop
// else:
// ParseForLoop() // use loop var to implement for loop, which always sink loop
FunctionBlockPtr
Parser
::
ParseFor
(
const
FunctionBlockPtr
&
block
,
const
py
::
object
&
node
)
{
MS_LOG
(
DEBUG
)
<<
"Process ast For, create an if else statement"
;
MS_EXCEPTION_IF_NULL
(
block
);
// create statement 'len(xs) < prim::MAX_FOR_LOOP_COUNT'
AnfNodePtr
op_len
=
block
->
MakeResolveSymbol
(
NAMED_PRIMITIVE_LEN
);
py
::
object
iter_obj
=
python_adapter
::
GetPyObjAttr
(
node
,
NAMED_PRIMITIVE_ITER
);
AnfNodePtr
iter_node
=
ParseExprNode
(
block
,
iter_obj
);
CNodePtr
len_iter
=
block
->
func_graph
()
->
NewCNode
({
op_len
,
iter_node
});
CNodePtr
bool_node
=
block
->
func_graph
()
->
NewCNode
(
{
NewValueNode
(
prim
::
kPrimScalarLt
),
len_iter
,
NewValueNode
(
prim
::
MAX_FOR_LOOP_COUNT
)});
// create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
TraceManager
::
DebugTrace
(
std
::
make_shared
<
TraceIfStmtTrueBranch
>
(
block
->
func_graph
()
->
debug_info
()));
FunctionBlockPtr
true_block
=
MakeFunctionBlock
(
*
this
);
TraceManager
::
EndTrace
();
TraceManager
::
DebugTrace
(
std
::
make_shared
<
TraceIfStmtFalseBranch
>
(
block
->
func_graph
()
->
debug_info
()));
FunctionBlockPtr
false_block
=
MakeFunctionBlock
(
*
this
);
TraceManager
::
EndTrace
();
MakeConditionBlocks
(
block
,
true_block
,
false_block
);
TraceManager
::
DebugTrace
(
std
::
make_shared
<
TraceIfStmtAfterBranch
>
(
block
->
func_graph
()
->
debug_info
()));
FunctionBlockPtr
after_block
=
MakeFunctionBlock
(
*
this
);
TraceManager
::
EndTrace
();
FunctionBlockPtr
true_end
=
ParseForIter
(
true_block
,
node
);
true_end
->
Jump
(
after_block
,
nullptr
);
FunctionBlockPtr
false_end
=
ParseForLoop
(
false_block
,
node
);
false_end
->
Jump
(
after_block
,
nullptr
);
block
->
ConditionalJump
(
bool_node
,
true_block
,
false_block
);
after_block
->
Mature
();
return
after_block
;
}
// A for loop will generate 3 functions :the test, the body, and the continuation
// for x in xs:
// body
// it is compiled to be following statement
// it = iter(xs)
// while hastnext(it)
// x, it = next(it)
// body
FunctionBlockPtr
Parser
::
ParseFor
(
const
FunctionBlockPtr
&
block
,
const
py
::
object
&
node
)
{
FunctionBlockPtr
Parser
::
ParseFor
Iter
(
const
FunctionBlockPtr
&
block
,
const
py
::
object
&
node
)
{
MS_LOG
(
DEBUG
)
<<
"Process ast For"
;
MS_EXCEPTION_IF_NULL
(
block
);
AnfNodePtr
op_iter
=
block
->
MakeResolveOperation
(
NAMED_PRIMITIVE_ITER
);
...
...
@@ -1088,6 +1134,91 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec
// No 'break', no end_block.
return
after_block
;
}
// A for loop will generate 3 functions :the test, the body, and the continuation
// for x in xs:
// body
// it is compiled to be following statement
// i = 0
// while i < len(xs)
// x = xs[i]
// i = i + 1
// body
FunctionBlockPtr
Parser
::
ParseForLoop
(
const
FunctionBlockPtr
&
block
,
const
py
::
object
&
node
)
{
MS_LOG
(
DEBUG
)
<<
"Process ast For by loop variable"
;
MS_EXCEPTION_IF_NULL
(
block
);
AnfNodePtr
op_len
=
block
->
MakeResolveSymbol
(
NAMED_PRIMITIVE_LEN
);
AnfNodePtr
op_getitem
=
block
->
MakeResolveOperation
(
NAMED_PRIMITIVE_GETITEM
);
// get varibale name of 'x' in statement 'for x in xs'
py
::
object
target_node
=
python_adapter
::
GetPyObjAttr
(
node
,
"target"
);
auto
name_id
=
py
::
cast
<
std
::
string
>
(
python_adapter
::
GetPyObjAttr
(
target_node
,
"id"
));
// create statement 'len(xs)'
py
::
object
iter_obj
=
python_adapter
::
GetPyObjAttr
(
node
,
"iter"
);
AnfNodePtr
iter_node
=
ParseExprNode
(
block
,
iter_obj
);
MS_EXCEPTION_IF_NULL
(
iter_node
);
CNodePtr
len_iter
=
block
->
func_graph
()
->
NewCNode
({
op_len
,
iter_node
});
FunctionBlockPtr
header_block
=
GenerateBlockInFor
(
std
::
make_shared
<
TraceForHeader
>
(
block
->
func_graph
()
->
debug_info
()));
MS_EXCEPTION_IF_NULL
(
header_block
);
// create loop variable 'i'
ParameterPtr
loop_var
=
header_block
->
func_graph
()
->
add_parameter
();
// create loop condition 'i < len(xs)'
CNodePtr
cond_node
=
header_block
->
func_graph
()
->
NewCNode
({
NewValueNode
(
prim
::
kPrimScalarLt
),
loop_var
,
len_iter
});
// generate the body of the for statement
FunctionBlockPtr
body_block
=
GenerateBlockInFor
(
std
::
make_shared
<
TraceForBody
>
(
block
->
func_graph
()
->
debug_info
()));
MS_EXCEPTION_IF_NULL
(
body_block
);
body_block
->
AddPrevBlock
(
header_block
);
// create 'x = xs[i]'
CNodePtr
target_var
=
body_block
->
func_graph
()
->
NewCNode
({
op_getitem
,
iter_node
,
loop_var
});
target_var
->
debug_info
()
->
set_name
(
name_id
);
body_block
->
WriteVariable
(
name_id
,
target_var
);
// create 'i = i + 1'
CNodePtr
loop_var_inc
=
body_block
->
func_graph
()
->
NewCNode
({
NewValueNode
(
prim
::
kPrimScalarAdd
),
loop_var
,
NewValueNode
(
1
)});
body_block
->
WriteVariable
(
loop_var
->
name
(),
loop_var_inc
);
loop_var_inc
->
debug_info
()
->
set_name
(
name_id
);
// link the variable name with the target
auto
it_info
=
std
::
make_shared
<
TraceIterator
>
(
loop_var_inc
->
debug_info
());
loop_var
->
debug_info
()
->
set_trace_info
(
it_info
);
len_iter
->
debug_info
()
->
set_trace_info
(
it_info
);
TraceManager
::
DebugTrace
(
std
::
make_shared
<
TraceForAfter
>
(
block
->
func_graph
()
->
debug_info
()));
FunctionBlockPtr
after_block
=
MakeFunctionBlock
(
*
this
);
MS_EXCEPTION_IF_NULL
(
after_block
);
TraceManager
::
EndTrace
();
after_block
->
AddPrevBlock
(
header_block
);
block
->
Jump
(
header_block
,
NewValueNode
(
0
));
body_block
->
Mature
();
header_block
->
ConditionalJump
(
cond_node
,
body_block
,
after_block
,
false
);
// Parse loop body statements with loop context.
LoopContext
loop_context
{
&
loops_
,
header_block
,
loop_var_inc
};
py
::
object
body_node
=
python_adapter
::
GetPyObjAttr
(
node
,
"body"
);
FunctionBlockPtr
after_body_block
=
ParseStatements
(
body_block
,
body_node
);
if
(
after_body_block
->
func_graph
()
->
get_return
()
==
nullptr
)
{
after_body_block
->
Jump
(
header_block
,
loop_var_inc
);
}
header_block
->
Mature
();
after_block
->
Mature
();
auto
&
end_block
=
loop_context
.
EndBlock
();
if
(
end_block
)
{
// end_block exists if we encounter 'break' in loop body.
after_block
->
Jump
(
end_block
,
nullptr
);
end_block
->
Mature
();
return
end_block
;
}
// No 'break', no end_block.
return
after_block
;
}
AnfNodePtr
Parser
::
ParseIfExp
(
const
FunctionBlockPtr
&
block
,
const
py
::
object
&
node
)
{
MS_LOG
(
DEBUG
)
<<
"Process ast IfExp"
;
MS_EXCEPTION_IF_NULL
(
block
);
...
...
mindspore/ccsrc/pipeline/parse/parse.h
浏览文件 @
323a80c6
...
...
@@ -106,6 +106,8 @@ class Parser {
FunctionBlockPtr
ParseWhile
(
const
FunctionBlockPtr
&
block
,
const
py
::
object
&
node
);
// process a for statement
FunctionBlockPtr
ParseFor
(
const
FunctionBlockPtr
&
block
,
const
py
::
object
&
node
);
FunctionBlockPtr
ParseForIter
(
const
FunctionBlockPtr
&
block
,
const
py
::
object
&
node
);
FunctionBlockPtr
ParseForLoop
(
const
FunctionBlockPtr
&
block
,
const
py
::
object
&
node
);
// process a function def statement
FunctionBlockPtr
ParseFunctionDef
(
const
FunctionBlockPtr
&
block
,
const
py
::
object
&
node
);
// process a augment assign
...
...
mindspore/ccsrc/pipeline/parse/parse_base.h
浏览文件 @
323a80c6
...
...
@@ -87,6 +87,7 @@ const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
const
char
PYTHON_MOD_GET_DEFAULT_INPUT
[]
=
"get_default_input"
;
// define the common name
const
char
NAMED_PRIMITIVE_LEN
[]
=
"len"
;
const
char
NAMED_PRIMITIVE_ITER
[]
=
"iter"
;
const
char
NAMED_PRIMITIVE_NEXT
[]
=
"next"
;
const
char
NAMED_PRIMITIVE_GETITEM
[]
=
"getitem"
;
...
...
mindspore/ccsrc/pipeline/pipeline.cc
浏览文件 @
323a80c6
...
...
@@ -621,11 +621,8 @@ void Pipeline::Run() {
draw
::
Draw
(
base_name
+
".dot"
,
graph
);
// generate IR file in human readable format
DumpIR
(
base_name
+
".ir"
,
graph
);
// generate IR file in a heavily commented format, which can also be reloaded
if
(
action
.
first
!=
"parse"
)
{
ExportIR
(
base_name
+
".dat"
,
std
::
to_string
(
i
),
graph
);
}
ExportIR
(
base_name
+
".dat"
,
std
::
to_string
(
i
),
graph
);
}
#ifdef MS_DEBUG
// Dump graph cnode list
...
...
tests/ut/python/ops/test_control_ops.py
浏览文件 @
323a80c6
...
...
@@ -600,3 +600,42 @@ def test_while_tensor():
x
=
Tensor
(
np
.
ones
([
6
,
8
,
10
],
np
.
int32
))
y
=
Tensor
(
np
.
ones
([
6
,
8
,
10
],
np
.
int32
))
out
=
net
(
x
,
y
)
def
test_large_for_loop
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
flatten
=
P
.
ReLU
()
#nn.Flatten()
def
construct
(
self
,
x
):
for
elem
in
range
(
1
,
19000
):
x
=
self
.
flatten
(
x
+
elem
)
return
x
t
=
Tensor
(
np
.
ones
([
2
,
3
],
dtype
=
np
.
float32
))
net
=
Net
()
net
(
t
)
def
test_large_for_loop_with_continue_break
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
flatten
=
P
.
ReLU
()
#nn.Flatten()
def
construct
(
self
,
x
):
idx
=
0
for
elem1
in
range
(
200
):
idx
=
idx
+
1
if
idx
<
10
:
x
=
x
+
0.5
continue
if
idx
>
500
:
break
x
=
self
.
flatten
(
x
+
elem1
)
return
x
t
=
Tensor
(
np
.
ones
([
2
,
3
],
dtype
=
np
.
float32
))
net
=
Net
()
net
(
t
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录