From b381439349c4c6ebc96b2abe115dc0c4a4d21ab1 Mon Sep 17 00:00:00 2001 From: chenchao99 Date: Thu, 28 May 2020 13:17:26 +0800 Subject: [PATCH] fix the bug that when the profiler parameter subgraph is Default or Gradients, the profiler analyse will raise an exception --- mindinsight/profiler/analyser/analyser.py | 2 ++ .../analyser/test_analyser_aicore_detail.py | 27 ++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/mindinsight/profiler/analyser/analyser.py b/mindinsight/profiler/analyser/analyser.py index 6b40eca..8dc8838 100644 --- a/mindinsight/profiler/analyser/analyser.py +++ b/mindinsight/profiler/analyser/analyser.py @@ -124,6 +124,8 @@ class AicoreDetailAnalyser(BaseAnalyser): result = [] for op_type in op_type_order: detail_infos = type_detail_cache.get(op_type) + if detail_infos is None: + continue detail_infos.sort(key=lambda item: item[2], reverse=True) result.extend(detail_infos) diff --git a/tests/ut/profiler/analyser/test_analyser_aicore_detail.py b/tests/ut/profiler/analyser/test_analyser_aicore_detail.py index 71be4f9..083baff 100644 --- a/tests/ut/profiler/analyser/test_analyser_aicore_detail.py +++ b/tests/ut/profiler/analyser/test_analyser_aicore_detail.py @@ -267,7 +267,7 @@ class TestAicoreDetailAnalyser(TestCase): result = self._analyser.query(condition) self.assertDictEqual(expect_result, result) - def test_query_and_sort_by_op_type(self): + def test_query_and_sort_by_op_type_1(self): """Test the success of the querying and sorting function by operator type.""" detail_infos = get_detail_infos(indexes=[9, 0, 2, 1, 5, 3, 4]) expect_result = { @@ -289,6 +289,31 @@ class TestAicoreDetailAnalyser(TestCase): ) self.assertDictEqual(expect_result, result) + def test_query_and_sort_by_op_type_2(self): + """Test the success of the querying and sorting function by operator type.""" + detail_infos = get_detail_infos(indexes=[9, 0, 2, 1, 3, 4, 8, 6]) + expect_result = { + 'col_name': AicoreDetailAnalyser.__col_names__[0:4], + 'object': [item[0:4] for item in detail_infos] + } + + filter_condition = { + 'op_type': {}, + 'subgraph': { + 'in': ['Default'] + }, + 'is_display_detail': False, + 'is_display_full_op_name': False + } + op_type_order = [ + 'MatMul', 'AtomicAddrClean', 'Cast', 'Conv2D', 'TransData' + ] + result = self._analyser.query_and_sort_by_op_type( + filter_condition, op_type_order + ) + print(result) + self.assertDictEqual(expect_result, result) + def test_col_names(self): """Test the querying column names function.""" self.assertListEqual( -- GitLab