提交 3b4caa8f 编写于 作者: A A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 525618180
上级 e545ea9b
......@@ -82,7 +82,8 @@ class Parser(parser.Parser):
output_size: `Tensor` or `list` for [height, width] of output image. The
output_size should be divided by the largest feature stride 2^max_level.
num_classes: `float`, number of classes.
image_field_key: `str`, the key name to encoded image in tf.Example.
image_field_key: `str`, the key name to encoded image or decoded image
matrix in tf.Example.
label_field_key: `str`, the key name to label in tf.Example.
decode_jpeg_only: `bool`, if True, only JPEG format is decoded, this is
faster than decoding other types. Default is True.
......@@ -185,8 +186,15 @@ class Parser(parser.Parser):
def _parse_train_image(self, decoded_tensors):
"""Parses image data for training."""
image_bytes = decoded_tensors[self._image_field_key]
if self._decode_jpeg_only and self._aug_crop:
require_decoding = (
not tf.is_tensor(image_bytes) or image_bytes.dtype == tf.dtypes.string
)
if (
require_decoding
and self._decode_jpeg_only
and self._aug_crop
):
image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Crops image.
......@@ -197,9 +205,13 @@ class Parser(parser.Parser):
lambda: preprocess_ops.center_crop_image_v2(image_bytes, image_shape),
lambda: cropped_image)
else:
# Decodes image.
image = tf.io.decode_image(image_bytes, channels=3)
image.set_shape([None, None, 3])
if require_decoding:
# Decodes image.
image = tf.io.decode_image(image_bytes, channels=3)
image.set_shape([None, None, 3])
else:
# Already decoded image matrix
image = image_bytes
# Crops image.
if self._aug_crop:
......@@ -252,17 +264,28 @@ class Parser(parser.Parser):
def _parse_eval_image(self, decoded_tensors):
"""Parses image data for evaluation."""
image_bytes = decoded_tensors[self._image_field_key]
if self._decode_jpeg_only and self._aug_crop:
require_decoding = (
not tf.is_tensor(image_bytes) or image_bytes.dtype == tf.dtypes.string
)
if (
require_decoding
and self._decode_jpeg_only
and self._aug_crop
):
image_shape = tf.image.extract_jpeg_shape(image_bytes)
# Center crops.
image = preprocess_ops.center_crop_image_v2(
image_bytes, image_shape, self._center_crop_fraction)
else:
# Decodes image.
image = tf.io.decode_image(image_bytes, channels=3)
image.set_shape([None, None, 3])
if require_decoding:
# Decodes image.
image = tf.io.decode_image(image_bytes, channels=3)
image.set_shape([None, None, 3])
else:
# Already decoded image matrix
image = image_bytes
# Center crops.
if self._aug_crop:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册