提交 373eeb61 编写于 作者: S sineeli

test case for missing object annotation

上级 b054d2eb
......@@ -319,6 +319,7 @@ def create_tf_example(image,
feature_dict = tfrecord_lib.image_info_to_feature_dict(
image_height, image_width, filename, image_id, encoded_jpg, 'jpg')
feature_dict_len = len(feature_dict)
num_annotations_skipped = 0
if bbox_annotations:
box_feature_dict, num_skipped = bbox_annotations_to_feature_dict(
......@@ -352,7 +353,11 @@ def create_tf_example(image,
encoded_panoptic_masks['instance_mask'])
})
example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
if feature_dict_len == len(feature_dict):
example = None
else:
example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
return example, num_annotations_skipped
......
......@@ -21,6 +21,7 @@ from absl.testing import parameterized
import tensorflow as tf
from official.vision.data import tfrecord_lib
from official.vision.data.create_coco_tf_record import generate_annotations, create_tf_example
FLAGS = flags.FLAGS
......@@ -87,7 +88,90 @@ class TfrecordLibTest(parameterized.TestCase):
proto = tfrecord_lib.convert_to_feature([b'123', b'456'])
self.assertSequenceAlmostEqual(proto.bytes_list.value, [b'123', b'456'])
def test_obj_annotation_tf_example(self):
images = [
{
"id": 0,
"file_name": "example1.jpg",
"height": 512,
"width": 512,
},
{
"id": 1,
"file_name": "example2.jpg",
"height": 512,
"width": 512,
}
]
img_to_obj_annotation = {
0:
[
{
"id": 0,
"image_id": 0,
"category_id": 1,
"bbox": [3, 1, 511, 510],
"area": 260610.00,
"segmentation": [],
"iscrowd": 0
}
],
1:
[
{
"id": 1,
"image_id": 1,
"category_id": 1,
"bbox": [1, 1, 100, 150],
"area": 15000.00,
"segmentation": [],
"iscrowd": 0
}
]
}
id_to_name_map = {
0: 'Super-Class',
1: 'Class-1'
}
temp_dir = FLAGS.test_tmpdir
image_dir = os.path.join(temp_dir, 'data')
if not os.path.exists(image_dir):
os.mkdir(image_dir)
for image in images:
image_path = os.path.join(image_dir, image['file_name'])
tf.keras.utils.save_img(
image_path, tf.ones(shape=(image['height'], image['width'], 3)).numpy())
output_path = os.path.join(image_dir, 'train')
coco_annotations_iter = generate_annotations(
images=images,
image_dirs=[image_dir],
panoptic_masks_dir=None,
img_to_obj_annotation=img_to_obj_annotation,
img_to_caption_annotation=None,
img_to_panoptic_annotation=None,
is_category_thing=None,
id_to_name_map=id_to_name_map,
include_panoptic_masks=False,
include_masks=False)
tfrecord_lib.write_tf_record_dataset(
output_path, coco_annotations_iter, create_tf_example, 1, multiple_processes=0)
tfrecord_files = tf.io.gfile.glob(output_path + '*')
self.assertLen(tfrecord_files, 1)
ds = tf.data.TFRecordDataset(tfrecord_files)
assertion_count = 0
for _ in ds:
assertion_count += 1
self.assertEqual(assertion_count, 1)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册