# progressbar.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import time
import numpy as np
19
import struct
20 21
from collections import namedtuple

22
__all__ = []
23 24


25
class ProgressBar:
    """Text progress reporter printing step counts, metrics, ETA and speed.

    With ``verbose == 1`` an in-place bar of the form
    ``step  3/10 [==>.......] - loss: 0.1234 - ETA: 2s - 10ms/step`` is
    redrawn on every update.  With ``verbose`` equal to 2 or 3 one plain
    line per update is printed instead.  Any other ``verbose`` value makes
    :meth:`update` print nothing.

    Args:
        num: Total number of steps, or ``None`` when it is unknown (no bar
            and no ETA are drawn in that case).
        width: Requested bar width in characters; it is capped by a limit
            derived from the terminal size.
        verbose: Output mode, see above.
        start: When true the timer starts at construction time.  Otherwise
            it starts on :meth:`start` or, failing that, on the first
            :meth:`update` call.
        file: Stream used for the tty check and flushed by :meth:`start`.
        name: Label of one unit of progress (printed as ``<name> i/n`` and
            in the speed suffix).

    Raises:
        TypeError: If ``num`` is an ``int`` that is not positive.
    """

    # NOTE(review): progress text is written to ``sys.stdout`` while ``file``
    # is only consulted for ``isatty()`` and flushed in ``start()``.  This is
    # kept as-is for backward compatibility -- confirm whether output was
    # meant to go to ``self.file``.

    def __init__(
        self,
        num=None,
        width=30,
        verbose=1,
        start=True,
        file=sys.stdout,
        name='step',
    ):
        self._num = num
        if isinstance(num, int) and num <= 0:
            raise TypeError('num should be None or integer (> 0)')
        max_width = self._get_max_width()
        self._width = width if width <= max_width else max_width
        self._total_width = 0
        self._verbose = verbose
        self.file = file
        self._values = {}
        self._values_order = []
        # Always define ``_start`` so that ``update()`` never hits a missing
        # attribute when the bar was created with ``start=False``.
        self._start = time.time() if start else None
        self._last_update = 0
        self.name = name

        # In-place redraw (backspaces + carriage return) is used when any of
        # these hints is present; otherwise each update starts on a new line.
        self._dynamic_display = (
            (hasattr(self.file, 'isatty') and self.file.isatty())
            or 'ipykernel' in sys.modules
            or 'posix' in sys.modules
            or 'PYCHARM_HOSTED' in os.environ
        )

    def _get_max_width(self):
        """Return the widest bar allowed for the current terminal size."""
        if sys.version_info > (3, 3):
            from shutil import get_terminal_size
        else:
            try:
                from backports.shutil_get_terminal_size import get_terminal_size
            except ImportError:
                # No backport available: assume a classic 80x24 terminal.
                def get_terminal_size():
                    terminal_size = namedtuple("terminal_size", "columns lines")
                    return terminal_size(80, 24)

        terminal_width, _ = get_terminal_size()
        terminal_width = terminal_width if terminal_width > 0 else 80
        # Use at most 60% of the line and keep 50 columns free for the text
        # printed after the bar.
        max_width = min(int(terminal_width * 0.6), terminal_width - 50)
        return max_width

    def start(self):
        """Flush ``self.file`` and (re)start the timer."""
        self.file.flush()
        self._start = time.time()

    def update(self, current_num, values=None):
        """Print the progress for step ``current_num``.

        Args:
            current_num: Number of units finished so far.
            values: Optional sequence of ``(name, value)`` pairs to display.
                ``value`` may be a scalar or a list of scalars.  A ``"loss"``
                entry whose first element is ``np.uint16`` is decoded as
                bfloat16 bit patterns before printing.  The argument itself
                is never modified.
        """
        now = time.time()
        if getattr(self, '_start', None) is None:
            # Timer was never started (``start=False`` and no ``start()``
            # call): start it now instead of failing.
            self._start = now

        values = self._decode_bf16_loss(values)

        if current_num:
            time_per_unit = (now - self._start) / current_num
        else:
            time_per_unit = 0
        fps = self._format_speed(time_per_unit)

        if self._verbose == 1:
            self._render_bar(current_num, values, time_per_unit, fps, now)
        elif self._verbose in (2, 3):
            self._render_line(current_num, values, fps)

    @staticmethod
    def _bf16_bits_to_float32(bits):
        """Reinterpret uint16 bfloat16 bit patterns as float32 values."""
        arr = np.asarray(bits)
        # A bfloat16 is the upper half of a float32: widen, shift the bits
        # into the high 16 bits and reinterpret the 32-bit words as floats.
        return (arr.astype(np.uint32) << np.uint32(16)).view(np.float32)

    def _decode_bf16_loss(self, values):
        """Return a private list copy of ``values`` with bf16 losses decoded."""
        if not values:
            return []
        decoded = list(values)  # copy: the caller's container stays untouched
        for idx, (key, val) in enumerate(decoded):
            if key != "loss":
                continue
            seq = val if isinstance(val, (list, np.ndarray)) else [val]
            try:
                first = seq[0]
            except IndexError:
                # Empty list/array (or 0-d array): nothing to decode.
                continue
            if isinstance(first, np.uint16):
                decoded[idx] = ("loss", list(self._bf16_bits_to_float32(seq)))
        return decoded

    def _format_speed(self, time_per_unit):
        """Format the time per unit using s, ms or us as the unit."""
        if time_per_unit >= 1 or time_per_unit == 0:
            return ' - %.0fs/%s' % (time_per_unit, self.name)
        if time_per_unit >= 1e-3:
            return ' - %.0fms/%s' % (time_per_unit * 1e3, self.name)
        return ' - %.0fus/%s' % (time_per_unit * 1e6, self.name)

    @staticmethod
    def _format_metrics(values):
        """Format ``(name, value)`` pairs as `` - name: v1 v2 ...``."""
        parts = []
        for key, val in values:
            parts.append(' - %s:' % key)
            items = val if isinstance(val, list) else [val]
            for v in items:
                if (
                    isinstance(v, np.ndarray)
                    and v.size == 1
                    and v.dtype in (np.float32, np.float64)
                ):
                    # Single-element float array: print it like a scalar.
                    v = v.item()
                if isinstance(v, (float, np.float32, np.float64)):
                    # Fixed notation, switching to scientific for tiny values.
                    fmt = ' %.4f' if abs(v) > 1e-3 else ' %.4e'
                    parts.append(fmt % v)
                else:
                    parts.append(' %s' % v)
        return ''.join(parts)

    def _format_count(self, current_num, suffix=''):
        """Return ``<name> <current>/<num><suffix>`` padded to num's width."""
        numdigits = int(np.log10(self._num)) + 1
        # ``%s`` / ``%*d`` keep a '%' inside ``self.name`` from being parsed
        # as a format directive.
        return '%s %*d/%d%s' % (
            self.name,
            numdigits,
            current_num,
            self._num,
            suffix,
        )

    def _render_bar(self, current_num, values, time_per_unit, fps, now):
        """Draw the bar (verbose == 1), overwriting the previous one."""
        out = sys.stdout
        prev_total_width = self._total_width

        if self._dynamic_display:
            out.write('\b' * prev_total_width)
            out.write('\r')
        else:
            out.write('\n')

        if self._num is not None:
            bar_chars = self._format_count(current_num, ' [')
            prog = float(current_num) / self._num
            # Clamp so that overshooting ``num`` cannot overflow the bar.
            prog_width = min(int(self._width * prog), self._width)
            if prog_width > 0:
                bar_chars += '=' * (prog_width - 1)
                bar_chars += '>' if current_num < self._num else '='
            bar_chars += '.' * (self._width - prog_width)
            bar_chars += ']'
        else:
            bar_chars = self.name + ' %3d' % current_num

        self._total_width = len(bar_chars)
        out.write(bar_chars)

        info = self._format_metrics(values)

        if self._num is not None and current_num < self._num:
            eta = time_per_unit * (self._num - current_num)
            if eta > 3600:
                eta_format = '%d:%02d:%02d' % (
                    eta // 3600,
                    (eta % 3600) // 60,
                    eta % 60,
                )
            elif eta > 60:
                eta_format = '%d:%02d' % (eta // 60, eta % 60)
            else:
                eta_format = '%ds' % eta
            info += ' - ETA: %s' % eta_format

        info += fps
        self._total_width += len(info)
        if prev_total_width > self._total_width:
            # Blank out leftovers of a longer previous line.
            info += ' ' * (prev_total_width - self._total_width)

        # End the line when the bar is complete or the total is unknown.
        if self._num is None or current_num >= self._num:
            info += '\n'

        out.write(info)
        out.flush()
        self._last_update = now

    def _render_line(self, current_num, values, fps):
        """Print one plain progress line (verbose == 2 or 3)."""
        if self._num:
            count = self._format_count(current_num)
        else:
            count = self.name + ' %3d' % current_num

        sys.stdout.write(count + self._format_metrics(values) + fps + '\n')
        sys.stdout.flush()