Commit 65cd4184 authored by kpy

change manager logic to speed up the combine_like step

Parent e3d5fa90
@@ -130,7 +130,7 @@ bool ParseAction(const ResourcePtr &res) {
 // This step do this optimize: graph1(x){xx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)}->
 // graph1(x){base_graph(x, fv1, fv2)}, graph1(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx}
 // all obj_map's graph shared base_graph
-bool CombineLikeGraphs(const ResourcePtr &) {
+bool CombineLikeGraphs(const ResourcePtr &res) {
   auto &obj_map = parse::data_converter::GetObjGraphs();
   for (auto it : obj_map) {
@@ -147,13 +147,15 @@ bool CombineLikeGraphs(const ResourcePtr &) {
     if (fg->paramter_obj_nodes().size() == 0 || graphs.size() <= 1) {
       continue;
     }
-    auto mng = Manage(base_graph, false);
     for (auto &fv : fg->paramter_obj_nodes()) {
       TraceManager::DebugTrace(std::make_shared<TraceCombileLikeGraphs>(fv->debug_info()));
       auto param = base_graph->add_parameter();
       TraceManager::EndTrace();
-      auto repl_node = (*cloner->cloned_node())[fv];
-      (void)mng->Replace(repl_node, param);
+      auto &node_users = res->manager()->node_users()[fv];
+      for (auto &n : node_users) {
+        auto repl_n = (*cloner->cloned_node())[n.first]->cast<CNodePtr>();
+        repl_n->set_input(n.second, param);
+      }
     }
     MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size();
...
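What the change does: rather than creating a fresh manager for each base_graph (the removed Manage(base_graph, false) call) and asking it to Replace() the cloned free-variable node, the pass now reads the (user node, input index) pairs for each free variable from the resource's existing manager via node_users() and rewires those inputs directly to the new parameter. A minimal, self-contained Python sketch of that rewiring idea follows; the Node class and the build_node_users / replace_all_uses helpers are illustrative stand-ins, not the MindSpore API.

# Sketch of "replace every use of a node by patching (user, input_index) pairs"
# instead of rebuilding a manager and calling Replace() per graph.

class Node:
    def __init__(self, name, inputs=None):
        self.name = name
        self.inputs = list(inputs or [])

    def set_input(self, index, new_node):
        # Rewire a single input slot, mirroring repl_n->set_input(n.second, param).
        self.inputs[index] = new_node


def build_node_users(nodes):
    """Index every node by who uses it: node -> [(user, input_index), ...]."""
    users = {}
    for node in nodes:
        for idx, inp in enumerate(node.inputs):
            users.setdefault(inp, []).append((node, idx))
    return users


def replace_all_uses(node_users, old_node, new_node):
    """Patch each recorded use of old_node so it points at new_node."""
    for user, idx in node_users.get(old_node, []):
        user.set_input(idx, new_node)


# Usage: replace a free variable `fv` with a freshly added parameter `param`.
fv = Node("fv")
param = Node("param")
user_a = Node("user_a", inputs=[fv])
user_b = Node("user_b", inputs=[fv, fv])

users = build_node_users([user_a, user_b])
replace_all_uses(users, fv, param)
assert all(inp is param for inp in user_a.inputs + user_b.inputs)

Patching recorded uses in place reuses the index the resource manager already maintains instead of rebuilding one per combined graph, which is presumably where the "faster combine_like step" of the commit message comes from.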
@@ -24,9 +24,7 @@ from mindspore.ops import operations as P
 def setup_module(module):
-    context.set_context(mode = context.PYNATIVE_MODE, save_graphs = True, device_target = "Ascend")
-    context.set_context(enable_task_sink = True, device_id = 0)
+    context.set_context(mode = context.PYNATIVE_MODE, device_target = "Ascend")

 c1 = Tensor([2], mstype.int32)
 c2 = Tensor([14], mstype.int32)
@@ -135,6 +133,10 @@ def while_in_while_in_while(x, y, z):
     return out

+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
 def test_simple_if():
     output = simple_if(c1, c2, c3)
     expect = Tensor([6], mstype.int32)
@@ -153,30 +155,49 @@ def test_if_in_if():
     assert output == expect

+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
 def test_simple_while():
     output = simple_while(c1, c2, c3)
     expect = Tensor([21], mstype.int32)
     assert output == expect

+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
 def test_while_by_while():
     output = while_by_while(c1, c2, c3)
     expect = Tensor([28], mstype.int32)
     assert output == expect

+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
 def test_while_in_while():
     output = while_in_while(c1, c2, c3)
     expect = Tensor([1274], mstype.int32)
     assert output == expect

+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
 def test_while_by_while_in_while():
     output = while_by_while_in_while(c1, c2, c3)
     expect = Tensor([350], mstype.int32)
     assert output == expect

+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
 def test_while_in_while_in_while():
     output = while_in_while_in_while(c1, c2, c3)
     expect = Tensor([2534], mstype.int32)
...
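The test file now tags every control-flow test with CI scheduling marks (level0, the two Ascend training platforms, env_onecard) and drops the save_graphs / enable_task_sink context settings. As a hedged sketch of how such custom marks are typically handled: they can be registered in a conftest.py so pytest does not warn about unknown markers, and then used to select subsets of tests with -m. The registration below uses the mark names from the diff; the descriptions and the conftest.py placement are assumptions, not taken from this repository.

# conftest.py (illustrative)

def pytest_configure(config):
    # Register the custom marks used by these tests so pytest recognizes them;
    # the descriptions here are placeholders.
    for mark in ("level0",
                 "platform_x86_ascend_training",
                 "platform_arm_ascend_training",
                 "env_onecard"):
        config.addinivalue_line("markers", mark + ": CI scheduling mark")

# Marked tests can then be selected on the command line, for example:
#   pytest -m "level0 and env_onecard"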