diff --git a/mindinsight/profiler/analyser/analyser.py b/mindinsight/profiler/analyser/analyser.py index 6b40eca828a5a89826c67c505d0ca5ee1eea9675..8dc8838fda30481f9f0495a701c9824e8d10ffc5 100644 --- a/mindinsight/profiler/analyser/analyser.py +++ b/mindinsight/profiler/analyser/analyser.py @@ -124,6 +124,8 @@ class AicoreDetailAnalyser(BaseAnalyser): result = [] for op_type in op_type_order: detail_infos = type_detail_cache.get(op_type) + if detail_infos is None: + continue detail_infos.sort(key=lambda item: item[2], reverse=True) result.extend(detail_infos) diff --git a/tests/ut/profiler/analyser/test_analyser_aicore_detail.py b/tests/ut/profiler/analyser/test_analyser_aicore_detail.py index 71be4f92ab19e506bc4b2a4d91141d5f71006d23..083baff4e9e568af446ee2bff404771aae954fcf 100644 --- a/tests/ut/profiler/analyser/test_analyser_aicore_detail.py +++ b/tests/ut/profiler/analyser/test_analyser_aicore_detail.py @@ -267,7 +267,7 @@ class TestAicoreDetailAnalyser(TestCase): result = self._analyser.query(condition) self.assertDictEqual(expect_result, result) - def test_query_and_sort_by_op_type(self): + def test_query_and_sort_by_op_type_1(self): """Test the success of the querying and sorting function by operator type.""" detail_infos = get_detail_infos(indexes=[9, 0, 2, 1, 5, 3, 4]) expect_result = { @@ -289,6 +289,31 @@ class TestAicoreDetailAnalyser(TestCase): ) self.assertDictEqual(expect_result, result) + def test_query_and_sort_by_op_type_2(self): + """Test the success of the querying and sorting function by operator type.""" + detail_infos = get_detail_infos(indexes=[9, 0, 2, 1, 3, 4, 8, 6]) + expect_result = { + 'col_name': AicoreDetailAnalyser.__col_names__[0:4], + 'object': [item[0:4] for item in detail_infos] + } + + filter_condition = { + 'op_type': {}, + 'subgraph': { + 'in': ['Default'] + }, + 'is_display_detail': False, + 'is_display_full_op_name': False + } + op_type_order = [ + 'MatMul', 'AtomicAddrClean', 'Cast', 'Conv2D', 'TransData' + ] + result = self._analyser.query_and_sort_by_op_type( + filter_condition, op_type_order + ) + print(result) + self.assertDictEqual(expect_result, result) + def test_col_names(self): """Test the querying column names function.""" self.assertListEqual(