From c04eea1e288900e50f3a392cc305c71fbfab7844 Mon Sep 17 00:00:00 2001 From: Jonathan Huang Date: Thu, 22 Oct 2020 12:30:36 -0700 Subject: [PATCH] Add support in label map proto for LVIS specific fields. PiperOrigin-RevId: 338526646 --- .../protos/string_int_label_map.proto | 12 +++++ .../object_detection/utils/label_map_util.py | 12 +++++ .../utils/label_map_util_test.py | 54 ++++++++++++++++--- 3 files changed, 72 insertions(+), 6 deletions(-) diff --git a/research/object_detection/protos/string_int_label_map.proto b/research/object_detection/protos/string_int_label_map.proto index c9095a9d7..d77dd92af 100644 --- a/research/object_detection/protos/string_int_label_map.proto +++ b/research/object_detection/protos/string_int_label_map.proto @@ -6,6 +6,14 @@ syntax = "proto2"; package object_detection.protos; +// LVIS frequency: +enum LVISFrequency { + UNSPECIFIED = 0; + FREQUENT = 1; + COMMON = 2; + RARE = 3; +} + message StringIntLabelMapItem { // String name. The most common practice is to set this to a MID or synsets // id. @@ -38,6 +46,10 @@ message StringIntLabelMapItem { // current element. Value should correspond to another label id element. repeated int32 ancestor_ids = 5; repeated int32 descendant_ids = 6; + + // LVIS specific label map fields + optional LVISFrequency frequency = 7; + optional int32 instance_count = 8; }; message StringIntLabelMap { diff --git a/research/object_detection/utils/label_map_util.py b/research/object_detection/utils/label_map_util.py index 37c823a8d..ecf7d82fb 100644 --- a/research/object_detection/utils/label_map_util.py +++ b/research/object_detection/utils/label_map_util.py @@ -130,6 +130,18 @@ def convert_label_map_to_categories(label_map, if item.id not in list_of_ids_already_added: list_of_ids_already_added.append(item.id) category = {'id': item.id, 'name': name} + if item.HasField('frequency'): + if item.frequency == string_int_label_map_pb2.LVISFrequency.Value( + 'FREQUENT'): + category['frequency'] = 'f' + elif item.frequency == string_int_label_map_pb2.LVISFrequency.Value( + 'COMMON'): + category['frequency'] = 'c' + elif item.frequency == string_int_label_map_pb2.LVISFrequency.Value( + 'RARE'): + category['frequency'] = 'r' + if item.HasField('instance_count'): + category['instance_count'] = item.instance_count if item.keypoints: keypoints = {} list_of_keypoint_ids = [] diff --git a/research/object_detection/utils/label_map_util_test.py b/research/object_detection/utils/label_map_util_test.py index 969f3258b..cd5bb4169 100644 --- a/research/object_detection/utils/label_map_util_test.py +++ b/research/object_detection/utils/label_map_util_test.py @@ -201,7 +201,7 @@ class LabelMapUtilTest(tf.test.TestCase): name:'n00007846' } """ - text_format.Merge(label_map_string, label_map_proto) + text_format.Parse(label_map_string, label_map_proto) categories = label_map_util.convert_label_map_to_categories( label_map_proto, max_num_classes=3) self.assertListEqual([{ @@ -227,19 +227,61 @@ class LabelMapUtilTest(tf.test.TestCase): }] self.assertListEqual(expected_categories_list, categories) + def test_convert_label_map_to_categories_lvis_frequency_and_counts(self): + label_map_proto = string_int_label_map_pb2.StringIntLabelMap() + label_map_string = """ + item { + id:1 + name:'person' + frequency: FREQUENT + instance_count: 1000 + } + item { + id:2 + name:'dog' + frequency: COMMON + instance_count: 100 + } + item { + id:3 + name:'cat' + frequency: RARE + instance_count: 10 + } + """ + text_format.Parse(label_map_string, label_map_proto) + categories = label_map_util.convert_label_map_to_categories( + label_map_proto, max_num_classes=3) + self.assertListEqual([{ + 'id': 1, + 'name': u'person', + 'frequency': 'f', + 'instance_count': 1000 + }, { + 'id': 2, + 'name': u'dog', + 'frequency': 'c', + 'instance_count': 100 + }, { + 'id': 3, + 'name': u'cat', + 'frequency': 'r', + 'instance_count': 10 + }], categories) + def test_convert_label_map_to_categories(self): label_map_proto = self._generate_label_map(num_classes=4) categories = label_map_util.convert_label_map_to_categories( label_map_proto, max_num_classes=3) expected_categories_list = [{ 'name': u'1', - 'id': 1 + 'id': 1, }, { 'name': u'2', - 'id': 2 + 'id': 2, }, { 'name': u'3', - 'id': 3 + 'id': 3, }] self.assertListEqual(expected_categories_list, categories) @@ -259,7 +301,7 @@ class LabelMapUtilTest(tf.test.TestCase): } """ label_map_proto = string_int_label_map_pb2.StringIntLabelMap() - text_format.Merge(label_map_str, label_map_proto) + text_format.Parse(label_map_str, label_map_proto) categories = label_map_util.convert_label_map_to_categories( label_map_proto, max_num_classes=1) self.assertEqual('person', categories[0]['name']) @@ -291,7 +333,7 @@ class LabelMapUtilTest(tf.test.TestCase): } """ label_map_proto = string_int_label_map_pb2.StringIntLabelMap() - text_format.Merge(label_map_str, label_map_proto) + text_format.Parse(label_map_str, label_map_proto) with self.assertRaises(ValueError): label_map_util.convert_label_map_to_categories( label_map_proto, max_num_classes=2) -- GitLab