未验证 提交 23391a43 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

add dataformat of image (#801)

上级 5a19d2e6
......@@ -80,7 +80,65 @@ def imgarray2bytes(np_array):
return img_bin
def image(tag, image_array, step, walltime=None):
def make_grid(I, ncols=8):
assert isinstance(
I, np.ndarray), 'plugin error, should pass numpy array here'
if I.shape[1] == 1:
I = np.concatenate([I, I, I], 1)
assert I.ndim == 4 and I.shape[1] == 3 or I.shape[1] == 4
nimg = I.shape[0]
H = I.shape[2]
W = I.shape[3]
ncols = min(nimg, ncols)
nrows = int(np.ceil(float(nimg) / ncols))
canvas = np.zeros((I.shape[1], H * nrows, W * ncols), dtype=I.dtype)
i = 0
for y in range(nrows):
for x in range(ncols):
if i >= nimg:
break
canvas[:, y * H:(y + 1) * H, x * W:(x + 1) * W] = I[i]
i = i + 1
return canvas
def convert_to_HWC(tensor, input_format):
"""Convert `NCHW`, `HWC`, `HW` to `HWC`
Args:
tensor (numpy.ndarray): Value of image
input_format (string): Format of image
Return:
Image of format `HWC`.
"""
assert(len(set(input_format)) == len(input_format)), "You can not use the same dimension shordhand twice. \
input_format: {}".format(input_format)
assert(len(tensor.shape) == len(input_format)), "size of input tensor and input format are different. \
tensor shape: {}, input_format: {}".format(tensor.shape, input_format)
input_format = input_format.upper()
if len(input_format) == 4:
index = [input_format.find(c) for c in 'NCHW']
tensor_NCHW = tensor.transpose(index)
tensor_CHW = make_grid(tensor_NCHW)
return tensor_CHW.transpose(1, 2, 0)
if len(input_format) == 3:
index = [input_format.find(c) for c in 'HWC']
tensor_HWC = tensor.transpose(index)
if tensor_HWC.shape[2] == 1:
tensor_HWC = np.concatenate([tensor_HWC, tensor_HWC, tensor_HWC], 2)
return tensor_HWC
if len(input_format) == 2:
index = [input_format.find(c) for c in 'HW']
tensor = tensor.transpose(index)
tensor = np.stack([tensor, tensor, tensor], 2)
return tensor
def image(tag, image_array, step, walltime=None, dataformats="HWC"):
"""Package data to one image.
Args:
......@@ -92,6 +150,7 @@ def image(tag, image_array, step, walltime=None):
Return:
Package with format of record_pb2.Record
"""
image_array = convert_to_HWC(image_array, dataformats)
image_bytes = imgarray2bytes(image_array)
image = Record.Image(encoded_image_string=image_bytes)
return Record(values=[
......
......@@ -163,7 +163,7 @@ class LogWriter(object):
self._get_file_writer().add_record(
scalar(tag=tag, value=value, step=step, walltime=walltime))
def add_image(self, tag, img, step, walltime=None):
def add_image(self, tag, img, step, walltime=None, dataformats="HWC"):
"""Add an image to vdl record file.
Args:
......@@ -184,7 +184,8 @@ class LogWriter(object):
raise RuntimeError("% can't appear in tag!")
walltime = round(time.time() * 1000) if walltime is None else walltime
self._get_file_writer().add_record(
image(tag=tag, image_array=img, step=step, walltime=walltime))
image(tag=tag, image_array=img, step=step, walltime=walltime,
dataformats=dataformats))
def add_embeddings(self, tag, labels, hot_vectors, walltime=None):
"""Add embeddings to vdl record file.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册