提交 237fb749 编写于 作者: 李滨

Merge branch 'python_tool' into 'master'

Refactor image tensor tool

See merge request !881
...@@ -3,6 +3,8 @@ import os ...@@ -3,6 +3,8 @@ import os
import sys import sys
import tensorflow as tf import tensorflow as tf
# TODO(liyin): use dataset api and estimator with distributed strategy
FLAGS = None FLAGS = None
...@@ -32,10 +34,8 @@ def parse_args(): ...@@ -32,10 +34,8 @@ def parse_args():
def images_to_tensors(input_files, image_shape, mean_values=None): def images_to_tensors(input_files, image_shape, mean_values=None):
for i in xrange(len(input_files)): with tf.Graph().as_default():
with tf.Session() as sess: image_data = tf.placeholder(tf.string, name='input')
with tf.gfile.FastGFile(input_files[i], 'rb') as f:
image_data = f.read()
image_data = tf.image.decode_image(image_data, image_data = tf.image.decode_image(image_data,
channels=image_shape[2]) channels=image_shape[2])
if mean_values: if mean_values:
...@@ -54,10 +54,18 @@ def images_to_tensors(input_files, image_shape, mean_values=None): ...@@ -54,10 +54,18 @@ def images_to_tensors(input_files, image_shape, mean_values=None):
image_shape[:2], image_shape[:2],
align_corners=False) align_corners=False)
image = sess.run(image_data) with tf.Session() as sess:
output_file = os.path.join(FLAGS.output_dir, os.path.splitext( for i in xrange(len(input_files)):
os.path.basename(input_files[i]))[0] + '.dat') with tf.gfile.FastGFile(input_files[i], 'rb') as f:
image.tofile(output_file) src_image = f.read()
dst_image = sess.run(image_data,
feed_dict={'input:0': src_image})
output_file = os.path.join(FLAGS.output_dir,
os.path.splitext(
os.path.basename(
input_files[i]))[
0] + '.dat')
dst_image.tofile(output_file)
def main(unused_args): def main(unused_args):
......
...@@ -4,6 +4,8 @@ import sys ...@@ -4,6 +4,8 @@ import sys
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
# TODO(liyin): use dataset api and estimator with distributed strategy
FLAGS = None FLAGS = None
...@@ -27,22 +29,26 @@ def parse_args(): ...@@ -27,22 +29,26 @@ def parse_args():
def tensors_to_images(input_files, image_shape): def tensors_to_images(input_files, image_shape):
for i in xrange(len(input_files)): with tf.Graph().as_default():
with tf.Session() as sess: input = tf.placeholder(tf.float32, shape=image_shape, name='input')
tensor_data = np.fromfile(input_files[i], dtype=np.float32) \ output = tf.placeholder(tf.string, name='output_file')
.reshape(image_shape)
# use the second channel if it is gray image # use the second channel if it is gray image
if image_shape[2] == 2: if image_shape[2] == 2:
_, tensor_data = tf.split(tensor_data, 2, axis=2) _, input = tf.split(input, 2, axis=2)
tensor_data = tf.image.convert_image_dtype(tensor_data, tensor_data = tf.image.convert_image_dtype(input,
tf.uint8, tf.uint8,
saturate=True) saturate=True)
image_data = tf.image.encode_jpeg(tensor_data, quality=100) image_data = tf.image.encode_jpeg(tensor_data, quality=100)
image = sess.run(image_data) writer = tf.write_file(output, image_data, name='output_writer')
with tf.Session() as sess:
for i in xrange(len(input_files)):
input_data = np.fromfile(input_files[i], dtype=np.float32) \
.reshape(image_shape)
output_file = os.path.join(FLAGS.output_dir, os.path.splitext( output_file = os.path.join(FLAGS.output_dir, os.path.splitext(
os.path.basename(input_files[i]))[0] + '.jpg') os.path.basename(input_files[i]))[0] + '.jpg')
writer = tf.write_file(output_file, image) sess.run(writer, feed_dict={'input:0': input_data,
sess.run(writer) 'output_file:0': output_file})
def main(unused_args): def main(unused_args):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册