From de2ec2ea839a23fa7269f3f896f283fa09498ea7 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Tue, 23 Mar 2021 18:58:14 +0800 Subject: [PATCH] add add_figure for LogWriter (#941) * add add_figure for LogWriter * format code * update date Co-authored-by: ShenYuhan --- demo/components/figure_test.py | 29 +++++++++++++++++++++ requirements.txt | 1 + visualdl/utils/figure_util.py | 47 ++++++++++++++++++++++++++++++++++ visualdl/writer/writer.py | 28 ++++++++++++++++++++ 4 files changed, 105 insertions(+) create mode 100644 demo/components/figure_test.py create mode 100644 visualdl/utils/figure_util.py diff --git a/demo/components/figure_test.py b/demo/components/figure_test.py new file mode 100644 index 00000000..0d235764 --- /dev/null +++ b/demo/components/figure_test.py @@ -0,0 +1,29 @@ +# Copyright (c) 2021 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= + +# coding=utf-8 + +from visualdl import LogWriter +import numpy as np +from matplotlib import pyplot as plt + + +if __name__ == '__main__': + with LogWriter(logdir="./log/audio_test/train") as writer: + x = np.arange(100) + y = x ** 2 + 1 + plt.plot(x, y) + fig = plt.gcf() + writer.add_figure(tag="figure", figure=fig, step=0) diff --git a/requirements.txt b/requirements.txt index a2b1fe8d..5a4e5512 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ protobuf >= 3.11.0 requests shellcheck-py six >= 1.14.0 +matplotlib diff --git a/visualdl/utils/figure_util.py b/visualdl/utils/figure_util.py new file mode 100644 index 00000000..0e8b178d --- /dev/null +++ b/visualdl/utils/figure_util.py @@ -0,0 +1,47 @@ +# Copyright (c) 2021 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= + + +def figure_to_image(figures, close=True): + """Render matplotlib figure to numpy format. + + Note that this requires the ``matplotlib`` package. + + Args: + figure (matplotlib.pyplot.figure) : figure + close (bool): Flag to automatically close the figure + + Returns: + numpy.array: image in [HWC] order + """ + import numpy as np + try: + import matplotlib.pyplot as plt + import matplotlib.backends.backend_agg as plt_backend_agg + except ModuleNotFoundError: + print('please install matplotlib') + + def render_to_rgb(figure): + canvas = plt_backend_agg.FigureCanvasAgg(figure) + canvas.draw() + data = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8) + w, h = figure.canvas.get_width_height() + image_hwc = data.reshape([h, w, 4])[:, :, 0:3] + if close: + plt.close(figure) + return image_hwc + + image = render_to_rgb(figures) + return image diff --git a/visualdl/writer/writer.py b/visualdl/writer/writer.py index e285ef57..214708fd 100644 --- a/visualdl/writer/writer.py +++ b/visualdl/writer/writer.py @@ -19,6 +19,7 @@ import numpy as np from visualdl.writer.record_writer import RecordFileWriter from visualdl.server.log import logger from visualdl.utils.img_util import merge_images +from visualdl.utils.figure_util import figure_to_image from visualdl.component.base_component import scalar, image, embedding, audio, \ histogram, pr_curve, roc_curve, meta_data, text @@ -192,6 +193,33 @@ class LogWriter(object): image(tag=tag, image_array=img, step=step, walltime=walltime, dataformats=dataformats)) + def add_figure(self, tag, figure, step, walltime=None): + """Add an figure to vdl record file. + + Args: + tag (string): Data identifier + figure (matplotlib.figure.Figure): Image represented by a Figure + step (int): Step of image + walltime (int): Wall time of image + dataformats (string): Format of image + + Example: + form matplotlib import pyplot as plt + import numpy as np + + x = np.arange(100) + y = x ** 2 + 1 + plt.plot(x, y) + fig = plt.gcf() + writer.add_figure(tag="lll", figure=fig, step=0) + """ + if '%' in tag: + raise RuntimeError("% can't appear in tag!") + walltime = round(time.time() * 1000) if walltime is None else walltime + img = figure_to_image(figure) + self._get_file_writer().add_record( + image(tag=tag, image_array=img, step=step, walltime=walltime)) + def add_text(self, tag, text_string, step=None, walltime=None): """Add an text to vdl record file. Args: -- GitLab