diff --git a/visualdl/component/base_component.py b/visualdl/component/base_component.py index 04516ad84f8f15a33defe075655c138e02683509..932be4a5c77d32828ecc3fd9c96e523468999327 100644 --- a/visualdl/component/base_component.py +++ b/visualdl/component/base_component.py @@ -80,7 +80,65 @@ def imgarray2bytes(np_array): return img_bin -def image(tag, image_array, step, walltime=None): +def make_grid(I, ncols=8): + assert isinstance( + I, np.ndarray), 'plugin error, should pass numpy array here' + if I.shape[1] == 1: + I = np.concatenate([I, I, I], 1) + assert I.ndim == 4 and I.shape[1] == 3 or I.shape[1] == 4 + nimg = I.shape[0] + H = I.shape[2] + W = I.shape[3] + ncols = min(nimg, ncols) + nrows = int(np.ceil(float(nimg) / ncols)) + canvas = np.zeros((I.shape[1], H * nrows, W * ncols), dtype=I.dtype) + i = 0 + for y in range(nrows): + for x in range(ncols): + if i >= nimg: + break + canvas[:, y * H:(y + 1) * H, x * W:(x + 1) * W] = I[i] + i = i + 1 + return canvas + + +def convert_to_HWC(tensor, input_format): + """Convert `NCHW`, `HWC`, `HW` to `HWC` + + Args: + tensor (numpy.ndarray): Value of image + input_format (string): Format of image + + Return: + Image of format `HWC`. + """ + assert(len(set(input_format)) == len(input_format)), "You can not use the same dimension shordhand twice. \ + input_format: {}".format(input_format) + assert(len(tensor.shape) == len(input_format)), "size of input tensor and input format are different. \ + tensor shape: {}, input_format: {}".format(tensor.shape, input_format) + input_format = input_format.upper() + + if len(input_format) == 4: + index = [input_format.find(c) for c in 'NCHW'] + tensor_NCHW = tensor.transpose(index) + tensor_CHW = make_grid(tensor_NCHW) + return tensor_CHW.transpose(1, 2, 0) + + if len(input_format) == 3: + index = [input_format.find(c) for c in 'HWC'] + tensor_HWC = tensor.transpose(index) + if tensor_HWC.shape[2] == 1: + tensor_HWC = np.concatenate([tensor_HWC, tensor_HWC, tensor_HWC], 2) + return tensor_HWC + + if len(input_format) == 2: + index = [input_format.find(c) for c in 'HW'] + tensor = tensor.transpose(index) + tensor = np.stack([tensor, tensor, tensor], 2) + return tensor + + +def image(tag, image_array, step, walltime=None, dataformats="HWC"): """Package data to one image. Args: @@ -92,6 +150,7 @@ def image(tag, image_array, step, walltime=None): Return: Package with format of record_pb2.Record """ + image_array = convert_to_HWC(image_array, dataformats) image_bytes = imgarray2bytes(image_array) image = Record.Image(encoded_image_string=image_bytes) return Record(values=[ diff --git a/visualdl/writer/writer.py b/visualdl/writer/writer.py index 140c244da871eb8979127540b918791a5d4d5658..ea8eaea9badaf6eba82d0ede8f0c90b33b3b6e17 100644 --- a/visualdl/writer/writer.py +++ b/visualdl/writer/writer.py @@ -163,7 +163,7 @@ class LogWriter(object): self._get_file_writer().add_record( scalar(tag=tag, value=value, step=step, walltime=walltime)) - def add_image(self, tag, img, step, walltime=None): + def add_image(self, tag, img, step, walltime=None, dataformats="HWC"): """Add an image to vdl record file. Args: @@ -184,7 +184,8 @@ class LogWriter(object): raise RuntimeError("% can't appear in tag!") walltime = round(time.time() * 1000) if walltime is None else walltime self._get_file_writer().add_record( - image(tag=tag, image_array=img, step=step, walltime=walltime)) + image(tag=tag, image_array=img, step=step, walltime=walltime, + dataformats=dataformats)) def add_embeddings(self, tag, labels, hot_vectors, walltime=None): """Add embeddings to vdl record file.