diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc
index 778600dc0a91d92fecfb2dd17e309e57b6775581..f15723d64d354b4286656a72113e8183fbfc3bd9 100644
--- a/mindspore/ccsrc/pipeline/action.cc
+++ b/mindspore/ccsrc/pipeline/action.cc
@@ -130,7 +130,7 @@ bool ParseAction(const ResourcePtr &res) {
 // This step do this optimize: graph1(x){xx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)}->
 // graph1(x){base_graph(x, fv1, fv2)}, graph1(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx}
 // all obj_map's graph shared base_graph
-bool CombineLikeGraphs(const ResourcePtr &) {
+bool CombineLikeGraphs(const ResourcePtr &res) {
   auto &obj_map = parse::data_converter::GetObjGraphs();
 
   for (auto it : obj_map) {
@@ -147,13 +147,15 @@ bool CombineLikeGraphs(const ResourcePtr &) {
     if (fg->paramter_obj_nodes().size() == 0 || graphs.size() <= 1) {
       continue;
     }
-    auto mng = Manage(base_graph, false);
     for (auto &fv : fg->paramter_obj_nodes()) {
       TraceManager::DebugTrace(std::make_shared<TraceCombileLikeGraphs>(fv->debug_info()));
       auto param = base_graph->add_parameter();
       TraceManager::EndTrace();
-      auto repl_node = (*cloner->cloned_node())[fv];
-      (void)mng->Replace(repl_node, param);
+      auto &node_users = res->manager()->node_users()[fv];
+      for (auto &n : node_users) {
+        auto repl_n = (*cloner->cloned_node())[n.first]->cast<CNodePtr>();
+        repl_n->set_input(n.second, param);
+      }
     }
 
     MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size();
diff --git a/tests/st/control/test_multigraph_sink.py b/tests/st/control/test_multigraph_sink.py
index b2732a63d49a422d45cad752de5788421d2c907c..2b9a1a020aae323658352bcb9da4c24c7870dff7 100644
--- a/tests/st/control/test_multigraph_sink.py
+++ b/tests/st/control/test_multigraph_sink.py
@@ -24,9 +24,7 @@ from mindspore.ops import operations as P
 
 
 def setup_module(module):
-    context.set_context(mode = context.PYNATIVE_MODE, save_graphs = True, device_target = "Ascend")
-    context.set_context(enable_task_sink = True, device_id = 0)
-
+    context.set_context(mode = context.PYNATIVE_MODE, device_target = "Ascend")
 
 c1 = Tensor([2], mstype.int32)
 c2 = Tensor([14], mstype.int32)
@@ -135,6 +133,10 @@ def while_in_while_in_while(x, y, z):
     return out
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
 def test_simple_if():
     output = simple_if(c1, c2, c3)
     expect = Tensor([6], mstype.int32)
@@ -153,30 +155,49 @@ def test_if_in_if():
     assert output == expect
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
 def test_simple_while():
     output = simple_while(c1, c2, c3)
     expect = Tensor([21], mstype.int32)
     assert output == expect
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
 def test_while_by_while():
     output = while_by_while(c1, c2, c3)
     expect = Tensor([28], mstype.int32)
     assert output == expect
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
 def test_while_in_while():
     output = while_in_while(c1, c2, c3)
     expect = Tensor([1274], mstype.int32)
     assert output == expect
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
 def test_while_by_while_in_while():
     output = while_by_while_in_while(c1, c2, c3)
     expect = Tensor([350], mstype.int32)
     assert output == expect
 
-
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
 def test_while_in_while_in_while():
     output = while_in_while_in_while(c1, c2, c3)
     expect = Tensor([2534], mstype.int32)
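Note for reviewers: the action.cc change stops building a manager for the cloned base graph (`Manage(base_graph, false)` + `mng->Replace(...)`) and instead walks the users of each free-variable node recorded in `res->manager()->node_users()`, pointing each user's input slot at the newly added parameter. The sketch below is a minimal, standalone illustration of that rewiring pattern; `Node`, `NodeUsers`, and `RewireUsers` are hypothetical stand-ins for illustration only, not MindSpore's actual AnfNode/CNode/FuncGraphManager classes.

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical node type: a name plus a vector of input edges.
struct Node {
  explicit Node(std::string n) : name(std::move(n)) {}
  std::string name;
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

// For every node, the list of (user, input index) pairs that reference it --
// the same shape of information the patch reads from node_users()[fv].
using NodeUsers = std::unordered_map<NodePtr, std::vector<std::pair<NodePtr, size_t>>>;

// Rewire every user of `old_node` to read from `new_param` by overwriting the
// recorded input slot in place (the pattern the patch applies via set_input),
// rather than asking a graph manager to Replace(old, new).
void RewireUsers(const NodeUsers &users, const NodePtr &old_node, const NodePtr &new_param) {
  auto it = users.find(old_node);
  if (it == users.end()) {
    return;  // the free variable has no users; nothing to rewire
  }
  for (const auto &[user, index] : it->second) {
    user->inputs[index] = new_param;
  }
}

int main() {
  auto fv = std::make_shared<Node>("free_variable");
  auto add = std::make_shared<Node>("add");
  add->inputs = {fv, fv};

  NodeUsers users;
  users[fv] = {{add, 0}, {add, 1}};  // both inputs of `add` use the free variable

  auto param = std::make_shared<Node>("new_parameter");
  RewireUsers(users, fv, param);

  for (const auto &in : add->inputs) {
    std::cout << "add input: " << in->name << "\n";  // prints "new_parameter" twice
  }
  return 0;
}
```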