Commit 323a80c6 authored by mindspore-ci-bot, committed by Gitee

!2765 fix large for loop segment fault

Merge pull request !2765 from fary86/fix_large_for_loop
......@@ -294,6 +294,12 @@ extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
extern const PrimitivePtr kPrimIsIndexedSlices;
// attribute 'unroll_flag' of primitive 'switch'; when 'unroll_flag' is '0', 'switch' will not be unrolled
const char SWITCH_UNROLL_FLAG[] = "unroll_flag";
// max loop count of a for statement; when the loop count is less than this value, the for loop will be unrolled, otherwise it
// will be sunk (i.e. not unrolled)
const int MAX_FOR_LOOP_COUNT = 200;
class DoSignaturePrimitive : public Primitive {
public:
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
......
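For illustration, a minimal plain-Python sketch of the decision this threshold drives (the function and strings below are illustrative, not MindSpore API): loops shorter than MAX_FOR_LOOP_COUNT are unrolled, longer ones are kept as a single sunk loop so the generated graph stays bounded.

MAX_FOR_LOOP_COUNT = 200  # mirrors prim::MAX_FOR_LOOP_COUNT above

def choose_for_lowering(xs):
    # Short loops are unrolled: every iteration becomes its own piece of graph.
    # Long loops are sunk: one loop body is kept, driven by an index variable.
    if len(xs) < MAX_FOR_LOOP_COUNT:
        return "unrolled (ParseForIter path)"
    return "sunk (ParseForLoop path)"

print(choose_for_lowering(range(10)))     # unrolled (ParseForIter path)
print(choose_for_lowering(range(19000)))  # sunk (ParseForLoop path), avoids building a huge unrolled graph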
......@@ -95,7 +95,7 @@ AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &prim
return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(param));
}
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &prim,
const AbstractBasePtrList &args_spec_list) {
// Inputs: condition, true branch, false branch
if (args_spec_list.size() != 3) {
......@@ -108,6 +108,11 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
auto fb = args_spec_list[2];
MS_EXCEPTION_IF_NULL(cond);
auto unroll_flag = prim->GetAttr(prim::SWITCH_UNROLL_FLAG);
if (unroll_flag != nullptr && GetValue<int>(unroll_flag) == 0) {
return tb->Join(fb);
}
ValuePtr v = cond->GetValueTrack();
MS_EXCEPTION_IF_NULL(v);
// for a tensor condition, keep both the true and false branches.
......
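Conceptually, the added check makes type inference for a non-unrolled 'switch' keep both branches. A rough plain-Python model of that rule (purely illustrative; abstract values are reduced to plain type names here):

def infer_switch(cond_value, true_type, false_type, unroll_flag=1):
    # cond_value: a concrete bool if the condition is known at compile time, else None.
    if unroll_flag == 0:
        # Not unrolled: both branches stay reachable, so the result is their join.
        return true_type if true_type == false_type else ("join", true_type, false_type)
    if cond_value is not None:
        # Unrolled and condition known: specialize to the taken branch.
        return true_type if cond_value else false_type
    # Condition unknown at compile time (e.g. a tensor): keep both branches as well.
    return true_type if true_type == false_type else ("join", true_type, false_type)

print(infer_switch(None, float, float, unroll_flag=0))  # <class 'float'>
print(infer_switch(True, float, int, unroll_flag=1))    # <class 'float'>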
......@@ -208,6 +208,11 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra
ValuePtr index_value = index->BuildValue();
if (!index_value->isa<Int32Imm>()) {
// when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element
// and continue
if (dyn_cast<AbstractScalar>(queue->elements()[0]) != nullptr) {
return std::make_shared<AbstractScalar>(queue->elements()[0]->BuildType());
}
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
<< index_value->ToString();
}
......
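A small plain-Python sketch of the fallback this hunk introduces (names are illustrative): when the getitem index is only known at run time, as with the loop variable of a sunk for loop, infer the element type from the first element instead of raising.

def infer_getitem(element_types, index):
    # element_types: the element types of a tuple/list, e.g. [int, int, int].
    # index: a concrete int when known at compile time, or None for a run-time value (AnyValue).
    if isinstance(index, int):
        return element_types[index]
    if element_types and all(t in (int, float, bool) for t in element_types):
        # Run-time index over scalar elements: assume the type of the first element and continue.
        return element_types[0]
    raise IndexError("getitem evaluator index should be an int32 number")

print(infer_getitem([int, int, int], 2))    # <class 'int'>, constant index
print(infer_getitem([float, float], None))  # <class 'float'>, index only known at run time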
......@@ -294,13 +294,18 @@ void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node)
// Perform a conditional jump using the switch operation.
// The first CNode selects the graph with the condition, and then executes this graph
void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &true_block,
const FunctionBlockPtr &false_block) {
const FunctionBlockPtr &false_block, bool unroll_loop) {
if (func_graph()->get_return() != nullptr) {
MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
<< trace::GetDebugInfo(func_graph()->get_return()->debug_info());
}
// Here we need to set an attribute on primitive 'switch', so we create a new variable instead of using the global 'kPrimSwitch'
auto prim_switch = std::make_shared<Primitive>(prim::kPrimSwitch->name());
if (!unroll_loop) {
prim_switch->AddAttr(prim::SWITCH_UNROLL_FLAG, MakeValue(0));
}
CNodePtr switch_app =
func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()),
func_graph()->NewCNode({NewValueNode(prim_switch), condNode, NewValueNode(true_block->func_graph()),
NewValueNode(false_block->func_graph())});
CNodePtr switch_app_new = func_graph()->NewCNode({switch_app});
func_graph()->set_output(switch_app_new);
......
......@@ -59,7 +59,8 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
CNodePtr ForceToWhileCond(const AnfNodePtr &cond);
void Jump(const FunctionBlockPtr &block, AnfNodePtr node);
AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi);
void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock);
void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock,
bool unroll_loop = true);
// record the assign statement of the self.xx weight parameter, which will use the state_setitem op
void SetStateAssgin(const AnfNodePtr &target, const std::string &readid);
void AddAutoDepend(const AnfNodePtr &target);
......
......@@ -1002,6 +1002,7 @@ CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py::
AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node);
return block->func_graph()->NewCNode({op_iter, iter_anf_node});
}
CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
const AnfNodePtr &op_hasnext) {
MS_EXCEPTION_IF_NULL(header_block);
......@@ -1018,12 +1019,57 @@ FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
// A for loop will generate 3 functions: the test, the body, and the continuation
// for x in xs:
//     body
// it compiled to be following statement
// it is compiled into the following statements
// if len(xs) < max_loop_cnt:
//     ParseForIter()  // use iter to implement the for loop, which always unrolls the loop
// else:
//     ParseForLoop()  // use a loop variable to implement the for loop, which always sinks the loop (i.e. does not unroll it)
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast For, create an if else statement";
MS_EXCEPTION_IF_NULL(block);
// create statement 'len(xs) < prim::MAX_FOR_LOOP_COUNT'
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
CNodePtr bool_node = block->func_graph()->NewCNode(
{NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(prim::MAX_FOR_LOOP_COUNT)});
// create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
TraceManager::DebugTrace(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
FunctionBlockPtr true_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
TraceManager::DebugTrace(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
FunctionBlockPtr false_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
MakeConditionBlocks(block, true_block, false_block);
TraceManager::DebugTrace(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
FunctionBlockPtr after_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
FunctionBlockPtr true_end = ParseForIter(true_block, node);
true_end->Jump(after_block, nullptr);
FunctionBlockPtr false_end = ParseForLoop(false_block, node);
false_end->Jump(after_block, nullptr);
block->ConditionalJump(bool_node, true_block, false_block);
after_block->Mature();
return after_block;
}
// A for loop will generate 3 functions: the test, the body, and the continuation
// for x in xs:
//     body
// it is compiled into the following statements
// it = iter(xs)
// while hasnext(it):
//     x, it = next(it)
//     body
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast For";
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
......@@ -1088,6 +1134,91 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec
// No 'break', no end_block.
return after_block;
}
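In plain Python, the iterator-based lowering sketched in the comment above behaves roughly like the following (Python has no 'hasnext', so a sentinel stands in for it; the helper is illustrative, not MindSpore API):

def for_iter_lowering(xs, body, x):
    # Models: it = iter(xs); while hasnext(it): x, it = next(it); body
    it = iter(xs)
    done = object()
    elem = next(it, done)
    while elem is not done:    # hasnext(it)
        x = body(x, elem)      # body, with the current element bound to the loop target
        elem = next(it, done)  # x, it = next(it)
    return x

print(for_iter_lowering([1, 2, 3], lambda acc, e: acc + e, 0))  # 6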
// A for loop will generate 3 functions: the test, the body, and the continuation
// for x in xs:
//     body
// it is compiled into the following statements
// i = 0
// while i < len(xs):
//     x = xs[i]
//     i = i + 1
//     body
FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast For by loop variable";
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
// get the variable name of 'x' in the statement 'for x in xs'
py::object target_node = python_adapter::GetPyObjAttr(node, "target");
auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(target_node, "id"));
// create statement 'len(xs)'
py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
MS_EXCEPTION_IF_NULL(iter_node);
CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
FunctionBlockPtr header_block =
GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
MS_EXCEPTION_IF_NULL(header_block);
// create loop variable 'i'
ParameterPtr loop_var = header_block->func_graph()->add_parameter();
// create loop condition 'i < len(xs)'
CNodePtr cond_node = header_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), loop_var, len_iter});
// generate the body of the for statement
FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
MS_EXCEPTION_IF_NULL(body_block);
body_block->AddPrevBlock(header_block);
// create 'x = xs[i]'
CNodePtr target_var = body_block->func_graph()->NewCNode({op_getitem, iter_node, loop_var});
target_var->debug_info()->set_name(name_id);
body_block->WriteVariable(name_id, target_var);
// create 'i = i + 1'
CNodePtr loop_var_inc =
body_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarAdd), loop_var, NewValueNode(1)});
body_block->WriteVariable(loop_var->name(), loop_var_inc);
loop_var_inc->debug_info()->set_name(name_id);
// link the variable name with the target
auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
loop_var->debug_info()->set_trace_info(it_info);
len_iter->debug_info()->set_trace_info(it_info);
TraceManager::DebugTrace(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
FunctionBlockPtr after_block = MakeFunctionBlock(*this);
MS_EXCEPTION_IF_NULL(after_block);
TraceManager::EndTrace();
after_block->AddPrevBlock(header_block);
block->Jump(header_block, NewValueNode(0));
body_block->Mature();
header_block->ConditionalJump(cond_node, body_block, after_block, false);
// Parse loop body statements with loop context.
LoopContext loop_context{&loops_, header_block, loop_var_inc};
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
if (after_body_block->func_graph()->get_return() == nullptr) {
after_body_block->Jump(header_block, loop_var_inc);
}
header_block->Mature();
after_block->Mature();
auto &end_block = loop_context.EndBlock();
if (end_block) {
// end_block exists if we encounter 'break' in loop body.
after_block->Jump(end_block, nullptr);
end_block->Mature();
return end_block;
}
// No 'break', no end_block.
return after_block;
}
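The index-based lowering built here corresponds to the following plain-Python model (an illustrative helper matching the pseudo code in the comment above; unlike the iterator form, it keeps a single loop body regardless of the loop count):

def for_loop_lowering(xs, body, x):
    # Models: i = 0; while i < len(xs): x = xs[i]; i = i + 1; body
    i = 0
    while i < len(xs):
        elem = xs[i]       # getitem(xs, i), the target variable of the for statement
        i = i + 1          # loop variable increment
        x = body(x, elem)  # body
    return x

print(for_loop_lowering([1, 2, 3], lambda acc, e: acc + e, 0))  # 6

Both lowerings compute the same result; only the index-based one keeps the graph size independent of the number of iterations.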
AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast IfExp";
MS_EXCEPTION_IF_NULL(block);
......
......@@ -106,6 +106,8 @@ class Parser {
FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node);
// process a for statement
FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node);
FunctionBlockPtr ParseForIter(const FunctionBlockPtr &block, const py::object &node);
FunctionBlockPtr ParseForLoop(const FunctionBlockPtr &block, const py::object &node);
// process a function def statement
FunctionBlockPtr ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node);
// process a augment assign
......
......@@ -87,6 +87,7 @@ const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
const char PYTHON_MOD_GET_DEFAULT_INPUT[] = "get_default_input";
// define the common name
const char NAMED_PRIMITIVE_LEN[] = "len";
const char NAMED_PRIMITIVE_ITER[] = "iter";
const char NAMED_PRIMITIVE_NEXT[] = "next";
const char NAMED_PRIMITIVE_GETITEM[] = "getitem";
......
......@@ -621,11 +621,8 @@ void Pipeline::Run() {
draw::Draw(base_name + ".dot", graph);
// generate IR file in human readable format
DumpIR(base_name + ".ir", graph);
// generate IR file in a heavily commented format, which can also be reloaded
if (action.first != "parse") {
ExportIR(base_name + ".dat", std::to_string(i), graph);
}
ExportIR(base_name + ".dat", std::to_string(i), graph);
}
#ifdef MS_DEBUG
// Dump graph cnode list
......
......@@ -600,3 +600,42 @@ def test_while_tensor():
x = Tensor(np.ones([6, 8, 10], np.int32))
y = Tensor(np.ones([6, 8, 10], np.int32))
out = net(x, y)
def test_large_for_loop():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.flatten = P.ReLU() #nn.Flatten()
def construct(self, x):
for elem in range(1, 19000):
x = self.flatten(x + elem)
return x
t = Tensor(np.ones([2, 3], dtype=np.float32))
net = Net()
net(t)
def test_large_for_loop_with_continue_break():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.flatten = P.ReLU() #nn.Flatten()
def construct(self, x):
idx = 0
for elem1 in range(200):
idx = idx + 1
if idx < 10:
x = x + 0.5
continue
if idx > 500:
break
x = self.flatten(x + elem1)
return x
t = Tensor(np.ones([2, 3], dtype=np.float32))
net = Net()
net(t)