提交 c04eea1e 编写于 作者: J Jonathan Huang 提交者: TF Object Detection Team

Add support in label map proto for LVIS specific fields.

PiperOrigin-RevId: 338526646
上级 aca137c1
......@@ -6,6 +6,14 @@ syntax = "proto2";
package object_detection.protos;
// LVIS frequency:
enum LVISFrequency {
UNSPECIFIED = 0;
FREQUENT = 1;
COMMON = 2;
RARE = 3;
}
message StringIntLabelMapItem {
// String name. The most common practice is to set this to a MID or synsets
// id.
......@@ -38,6 +46,10 @@ message StringIntLabelMapItem {
// current element. Value should correspond to another label id element.
repeated int32 ancestor_ids = 5;
repeated int32 descendant_ids = 6;
// LVIS specific label map fields
optional LVISFrequency frequency = 7;
optional int32 instance_count = 8;
};
message StringIntLabelMap {
......
......@@ -130,6 +130,18 @@ def convert_label_map_to_categories(label_map,
if item.id not in list_of_ids_already_added:
list_of_ids_already_added.append(item.id)
category = {'id': item.id, 'name': name}
if item.HasField('frequency'):
if item.frequency == string_int_label_map_pb2.LVISFrequency.Value(
'FREQUENT'):
category['frequency'] = 'f'
elif item.frequency == string_int_label_map_pb2.LVISFrequency.Value(
'COMMON'):
category['frequency'] = 'c'
elif item.frequency == string_int_label_map_pb2.LVISFrequency.Value(
'RARE'):
category['frequency'] = 'r'
if item.HasField('instance_count'):
category['instance_count'] = item.instance_count
if item.keypoints:
keypoints = {}
list_of_keypoint_ids = []
......
......@@ -201,7 +201,7 @@ class LabelMapUtilTest(tf.test.TestCase):
name:'n00007846'
}
"""
text_format.Merge(label_map_string, label_map_proto)
text_format.Parse(label_map_string, label_map_proto)
categories = label_map_util.convert_label_map_to_categories(
label_map_proto, max_num_classes=3)
self.assertListEqual([{
......@@ -227,19 +227,61 @@ class LabelMapUtilTest(tf.test.TestCase):
}]
self.assertListEqual(expected_categories_list, categories)
def test_convert_label_map_to_categories_lvis_frequency_and_counts(self):
label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
label_map_string = """
item {
id:1
name:'person'
frequency: FREQUENT
instance_count: 1000
}
item {
id:2
name:'dog'
frequency: COMMON
instance_count: 100
}
item {
id:3
name:'cat'
frequency: RARE
instance_count: 10
}
"""
text_format.Parse(label_map_string, label_map_proto)
categories = label_map_util.convert_label_map_to_categories(
label_map_proto, max_num_classes=3)
self.assertListEqual([{
'id': 1,
'name': u'person',
'frequency': 'f',
'instance_count': 1000
}, {
'id': 2,
'name': u'dog',
'frequency': 'c',
'instance_count': 100
}, {
'id': 3,
'name': u'cat',
'frequency': 'r',
'instance_count': 10
}], categories)
def test_convert_label_map_to_categories(self):
label_map_proto = self._generate_label_map(num_classes=4)
categories = label_map_util.convert_label_map_to_categories(
label_map_proto, max_num_classes=3)
expected_categories_list = [{
'name': u'1',
'id': 1
'id': 1,
}, {
'name': u'2',
'id': 2
'id': 2,
}, {
'name': u'3',
'id': 3
'id': 3,
}]
self.assertListEqual(expected_categories_list, categories)
......@@ -259,7 +301,7 @@ class LabelMapUtilTest(tf.test.TestCase):
}
"""
label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
text_format.Merge(label_map_str, label_map_proto)
text_format.Parse(label_map_str, label_map_proto)
categories = label_map_util.convert_label_map_to_categories(
label_map_proto, max_num_classes=1)
self.assertEqual('person', categories[0]['name'])
......@@ -291,7 +333,7 @@ class LabelMapUtilTest(tf.test.TestCase):
}
"""
label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
text_format.Merge(label_map_str, label_map_proto)
text_format.Parse(label_map_str, label_map_proto)
with self.assertRaises(ValueError):
label_map_util.convert_label_map_to_categories(
label_map_proto, max_num_classes=2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册