提交 7007d9e3 编写于 作者: Z Zhichao Lu 提交者: pkulzc

Updating transform_input_data to resize original image. This is necessary for...

Updating transform_input_data to resize original image. This is necessary for result_dict_for_single_example(), since it expects the input image and groundtruth masks to be of the same spatial dimensions.

PiperOrigin-RevId: 189786443
上级 6f1756bc
......@@ -58,7 +58,8 @@ def transform_input_data(tensor_dict,
Data transformation functions are applied in the following order.
1. data_augmentation_fn (optional): applied on tensor_dict.
2. model_preprocess_fn: applied only on image tensor in tensor_dict.
3. image_resizer_fn: applied only on instance mask tensor in tensor_dict.
3. image_resizer_fn: applied on original image and instance mask tensor in
tensor_dict.
4. one_hot_encoding: applied to classes tensor in tensor_dict.
5. merge_multiple_boxes (optional): when groundtruth boxes are exactly the
same they can be merged into a single box with an associated k-hot class
......@@ -70,10 +71,11 @@ def transform_input_data(tensor_dict,
model_preprocess_fn: model's preprocess function to apply on image tensor.
This function must take in a 4-D float tensor and return a 4-D preprocess
float tensor and a tensor containing the true image shape.
image_resizer_fn: image resizer function to apply on groundtruth instance
masks. This function must take a 4-D float tensor of image and a 4-D
tensor of instances masks and return resized version of these along with
the true shapes.
image_resizer_fn: image resizer function to apply on original image (if
`retain_original_image` is True) and groundtruth instance masks. This
function must take a 3-D float tensor of an image and a 3-D tensor of
instance masks and return a resized version of these along with the true
shapes.
num_classes: number of max classes to one-hot (or k-hot) encode the class
labels.
data_augmentation_fn: (optional) data augmentation function to apply on
......@@ -88,17 +90,19 @@ def transform_input_data(tensor_dict,
after applying all the transformations.
"""
if retain_original_image:
tensor_dict[fields.InputDataFields.
original_image] = tensor_dict[fields.InputDataFields.image]
original_image_resized, _ = image_resizer_fn(
tensor_dict[fields.InputDataFields.image])
tensor_dict[fields.InputDataFields.original_image] = tf.cast(
original_image_resized, tf.uint8)
# Apply data augmentation ops.
if data_augmentation_fn is not None:
tensor_dict = data_augmentation_fn(tensor_dict)
# Apply model preprocessing ops and resize instance masks.
image = tf.expand_dims(
tf.to_float(tensor_dict[fields.InputDataFields.image]), axis=0)
preprocessed_resized_image, true_image_shape = model_preprocess_fn(image)
image = tensor_dict[fields.InputDataFields.image]
preprocessed_resized_image, true_image_shape = model_preprocess_fn(
tf.expand_dims(tf.to_float(image), axis=0))
tensor_dict[fields.InputDataFields.image] = tf.squeeze(
preprocessed_resized_image, axis=0)
tensor_dict[fields.InputDataFields.true_image_shape] = tf.squeeze(
......
......@@ -462,22 +462,31 @@ class DataTransformationFnTest(tf.test.TestCase):
fields.InputDataFields.groundtruth_classes:
tf.constant(np.array([3, 1], np.int32))
}
def fake_image_resizer_fn(image, masks):
def fake_image_resizer_fn(image, masks=None):
resized_image = tf.image.resize_images(image, [8, 8])
resized_masks = tf.transpose(
tf.image.resize_images(tf.transpose(masks, [1, 2, 0]), [8, 8]),
[2, 0, 1])
return resized_image, resized_masks, tf.shape(resized_image)
results = [resized_image]
if masks is not None:
resized_masks = tf.transpose(
tf.image.resize_images(tf.transpose(masks, [1, 2, 0]), [8, 8]),
[2, 0, 1])
results.append(resized_masks)
results.append(tf.shape(resized_image))
return results
num_classes = 3
input_transformation_fn = functools.partial(
inputs.transform_input_data,
model_preprocess_fn=_fake_model_preprocessor_fn,
image_resizer_fn=fake_image_resizer_fn,
num_classes=num_classes)
num_classes=num_classes,
retain_original_image=True)
with self.test_session() as sess:
transformed_inputs = sess.run(
input_transformation_fn(tensor_dict=tensor_dict))
self.assertAllEqual(transformed_inputs[
fields.InputDataFields.original_image].dtype, tf.uint8)
self.assertAllEqual(transformed_inputs[
fields.InputDataFields.original_image].shape, [8, 8, 3])
self.assertAllEqual(transformed_inputs[
fields.InputDataFields.groundtruth_instance_masks].shape, [2, 8, 8])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册