diff --git a/tools/image/image_to_tensor.py b/tools/image/image_to_tensor.py
index 936d8f3972cc1be4112fbda1faab03550e03580f..d39c07c3e32f6b2f635549d621355ea1c002080c 100644
--- a/tools/image/image_to_tensor.py
+++ b/tools/image/image_to_tensor.py
@@ -3,6 +3,8 @@ import os
 import sys
 import tensorflow as tf
 
+# TODO(liyin): use dataset api and estimator with distributed strategy
+
 FLAGS = None
 
 
@@ -32,32 +34,38 @@ def parse_args():
 
 
 def images_to_tensors(input_files, image_shape, mean_values=None):
-    for i in xrange(len(input_files)):
-        with tf.Session() as sess:
-            with tf.gfile.FastGFile(input_files[i], 'rb') as f:
-                image_data = f.read()
-                image_data = tf.image.decode_image(image_data,
-                                                   channels=image_shape[2])
-                if mean_values:
-                    image_data = tf.cast(image_data, dtype=tf.float32)
-                    mean_tensor = tf.constant(mean_values, dtype=tf.float32,
-                                              shape=[1, 1, image_shape[2]])
-                    image_data = (image_data - mean_tensor) / 255.0
-                else:
-                    image_data = tf.image.convert_image_dtype(image_data,
-                                                              dtype=tf.float32)
-                    image_data = tf.subtract(image_data, 0.5)
-                    image_data = tf.multiply(image_data, 2.0)
+    with tf.Graph().as_default():
+        image_data = tf.placeholder(tf.string, name='input')
+        image_data = tf.image.decode_image(image_data,
+                                           channels=image_shape[2])
+        if mean_values:
+            image_data = tf.cast(image_data, dtype=tf.float32)
+            mean_tensor = tf.constant(mean_values, dtype=tf.float32,
+                                      shape=[1, 1, image_shape[2]])
+            image_data = (image_data - mean_tensor) / 255.0
+        else:
+            image_data = tf.image.convert_image_dtype(image_data,
+                                                      dtype=tf.float32)
+            image_data = tf.subtract(image_data, 0.5)
+            image_data = tf.multiply(image_data, 2.0)
 
-                image_data = tf.expand_dims(image_data, 0)
-                image_data = tf.image.resize_bilinear(image_data,
-                                                      image_shape[:2],
-                                                      align_corners=False)
+        image_data = tf.expand_dims(image_data, 0)
+        image_data = tf.image.resize_bilinear(image_data,
+                                              image_shape[:2],
+                                              align_corners=False)
 
-                image = sess.run(image_data)
-                output_file = os.path.join(FLAGS.output_dir, os.path.splitext(
-                    os.path.basename(input_files[i]))[0] + '.dat')
-                image.tofile(output_file)
+        with tf.Session() as sess:
+            for i in xrange(len(input_files)):
+                with tf.gfile.FastGFile(input_files[i], 'rb') as f:
+                    src_image = f.read()
+                dst_image = sess.run(image_data,
+                                     feed_dict={'input:0': src_image})
+                output_file = os.path.join(FLAGS.output_dir,
+                                           os.path.splitext(
+                                               os.path.basename(
+                                                   input_files[i]))[
+                                               0] + '.dat')
+                dst_image.tofile(output_file)
 
 
 def main(unused_args):
diff --git a/tools/image/tensor_to_image.py b/tools/image/tensor_to_image.py
index dc903f0132e1b29aca3544fc0082e04ad7e14428..ce18628eae52c98ada0c4abb06ffc61be836d640 100644
--- a/tools/image/tensor_to_image.py
+++ b/tools/image/tensor_to_image.py
@@ -4,6 +4,8 @@ import sys
 import numpy as np
 import tensorflow as tf
 
+# TODO(liyin): use dataset api and estimator with distributed strategy
+
 FLAGS = None
 
 
@@ -27,22 +29,26 @@ def parse_args():
 
 
 def tensors_to_images(input_files, image_shape):
-    for i in xrange(len(input_files)):
+    with tf.Graph().as_default():
+        input = tf.placeholder(tf.float32, shape=image_shape, name='input')
+        output = tf.placeholder(tf.string, name='output_file')
+        # use the second channel if it is gray image
+        if image_shape[2] == 2:
+            _, input = tf.split(input, 2, axis=2)
+        tensor_data = tf.image.convert_image_dtype(input,
+                                                   tf.uint8,
+                                                   saturate=True)
+        image_data = tf.image.encode_jpeg(tensor_data, quality=100)
+        writer = tf.write_file(output, image_data, name='output_writer')
+
         with tf.Session() as sess:
-            tensor_data = np.fromfile(input_files[i], dtype=np.float32) \
-                .reshape(image_shape)
-            # use the second channel if it is gray image
-            if image_shape[2] == 2:
-                _, tensor_data = tf.split(tensor_data, 2, axis=2)
-            tensor_data = tf.image.convert_image_dtype(tensor_data,
-                                                       tf.uint8,
-                                                       saturate=True)
-            image_data = tf.image.encode_jpeg(tensor_data, quality=100)
-            image = sess.run(image_data)
-            output_file = os.path.join(FLAGS.output_dir, os.path.splitext(
-                os.path.basename(input_files[i]))[0] + '.jpg')
-            writer = tf.write_file(output_file, image)
-            sess.run(writer)
+            for i in xrange(len(input_files)):
+                input_data = np.fromfile(input_files[i], dtype=np.float32) \
+                    .reshape(image_shape)
+                output_file = os.path.join(FLAGS.output_dir, os.path.splitext(
+                    os.path.basename(input_files[i]))[0] + '.jpg')
+                sess.run(writer, feed_dict={'input:0': input_data,
+                                            'output_file:0': output_file})
 
 
 def main(unused_args):