# progressbar.py (5.7 KB), extracted from a repository blame view
# ("Newer / Older" navigation). Lines 1-93 of the original file were
# committed by qingqing01.
import os
import sys
import time

import numpy as np


class ProgressBar(object):
    """Text progress bar for step-based loops (e.g. training epochs).

    Args:
        num (int|None): total number of steps. ``None`` means the total is
            unknown: only the current step is shown, with no bar and no ETA.
            A non-positive integer raises ``TypeError``.
        width (int): width of the ``[===>...]`` bar in characters; clipped
            to the space available in the terminal.
        verbose (int): ``1`` redraws one line per step (in place when the
            display supports it); ``2`` prints a new line on every
            ``update``; any other value prints nothing.
        start (bool): start the timer immediately. If False, the timer
            starts at ``start()`` or, failing that, at the first
            ``update()``.
        file: writable text stream the bar is drawn on. ``None`` (the
            default) means the ``sys.stdout`` in effect at construction.
    """

    def __init__(self,
                 num=None,
                 width=30,
                 verbose=1,
                 start=True,
                 file=None):
        if isinstance(num, int) and num <= 0:
            raise TypeError('num should be None or integer (> 0)')
        self._num = num
        self._width = min(width, self._get_max_width())
        # Length of the last line drawn in verbose=1 mode; used to blank out
        # leftovers when the next line is shorter.
        self._total_width = 0
        self._verbose = verbose
        # Resolve the default here: a ``file=sys.stdout`` default would be
        # bound at import time and miss later redirection of sys.stdout.
        self.file = sys.stdout if file is None else file
        # Not read anywhere in this class; kept for attribute compatibility.
        self._values = {}
        self._values_order = []
        self._start = time.time() if start else None
        self._last_update = 0

        # Redraw in place (backspaces + carriage return) instead of starting
        # a new line for every step.
        # NOTE(review): the ``posix`` module is loaded on every POSIX
        # platform, so this is effectively always True outside Windows --
        # confirm that is intended.
        self._dynamic_display = (
            (hasattr(self.file, 'isatty') and self.file.isatty()) or
            'ipykernel' in sys.modules or 'posix' in sys.modules or
            'PYCHARM_HOSTED' in os.environ)

    def _get_max_width(self):
        """Return the widest bar that still leaves room for the text around it.

        The bar may use 60% of the terminal, but at least 50 columns are
        reserved for the step counter, metrics and ETA. The result is never
        negative, even on very narrow terminals.
        """
        if sys.version_info > (3, 3):
            from shutil import get_terminal_size
        else:
            from backports.shutil_get_terminal_size import get_terminal_size
        terminal_width, _ = get_terminal_size()
        return max(0, min(int(terminal_width * 0.6), terminal_width - 50))

    def start(self):
        """Flush the output stream and (re)start the timer used for speed/ETA."""
        self.file.flush()
        self._start = time.time()

    def update(self, current_num, values=None):
        """Report progress for step ``current_num``.

        Args:
            current_num (int): number of steps completed so far.
            values (list|None): ``(name, value)`` pairs rendered as
                `` - name: value``. A value may be a scalar, a one-element
                numpy array, or a list of those. ``None`` means no metrics.
        """
        now = time.time()
        if self._start is None:
            # Timer never started (start=False and no start() call yet).
            self._start = now

        if current_num:
            time_per_unit = (now - self._start) / current_num
        else:
            time_per_unit = 0
        fps = self._format_speed(time_per_unit)

        if self._verbose == 1:
            self._write_progress_line(current_num, values, time_per_unit, fps)
            self._last_update = now
        elif self._verbose == 2:
            self._write_log_line(current_num, values, fps)

    @staticmethod
    def _format_speed(time_per_unit):
        """Format the per-step duration as `` - Ns/step``, ``ms`` or ``us``."""
        if time_per_unit >= 1 or time_per_unit == 0:
            return ' - %.0fs/%s' % (time_per_unit, 'step')
        if time_per_unit >= 1e-3:
            return ' - %.0fms/%s' % (time_per_unit * 1e3, 'step')
        return ' - %.0fus/%s' % (time_per_unit * 1e6, 'step')

    @staticmethod
    def _format_values(values):
        """Render ``(name, value)`` pairs as `` - name: v1 v2 ...``."""
        info = ''
        for k, val in values or ():
            info += ' - %s:' % k
            for v in (val if isinstance(val, list) else [val]):
                if (isinstance(v, np.ndarray) and v.size == 1 and
                        np.issubdtype(v.dtype, np.floating)):
                    # Unwrap one-element float arrays (any rank, including
                    # 0-d) so they print like plain numbers.
                    v = v.item()
                if isinstance(v, (float, np.floating)):
                    # Tiny magnitudes would round to 0.0000 in fixed notation.
                    fmt = ' %.4f' if abs(v) > 1e-3 else ' %.4e'
                    info += fmt % v
                else:
                    # Wrap in a 1-tuple so tuple values are not unpacked.
                    info += ' %s' % (v,)
        return info

    def _format_step(self, current_num):
        """Return ``step k/N`` (k padded to N's digit count) or ``step   k``."""
        if self._num is None:
            return 'step %3d' % current_num
        numdigits = int(np.log10(self._num)) + 1
        return ('step %' + str(numdigits) + 'd/%d') % (current_num, self._num)

    def _format_bar(self, current_num):
        """Return the ``[====>.....]`` bar; requires a known total."""
        prog = float(current_num) / self._num
        prog_width = int(self._width * prog)
        bar_chars = '['
        if prog_width > 0:
            bar_chars += '=' * (prog_width - 1)
            # The arrow head marks an unfinished run; a full bar is all '='.
            bar_chars += '>' if current_num < self._num else '='
        bar_chars += '.' * (self._width - prog_width)
        return bar_chars + ']'

    def _format_eta(self, time_per_unit, current_num):
        """Return `` - ETA: ...`` for an unfinished run with a known total, else ''."""
        if self._num is None or current_num >= self._num:
            return ''
        eta = time_per_unit * (self._num - current_num)
        if eta > 3600:
            eta_format = '%d:%02d:%02d' % (eta // 3600,
                                           (eta % 3600) // 60, eta % 60)
        elif eta > 60:
            eta_format = '%d:%02d' % (eta // 60, eta % 60)
        else:
            eta_format = '%ds' % eta
        return ' - ETA: %s' % eta_format

    def _write_progress_line(self, current_num, values, time_per_unit, fps):
        """verbose=1: draw the bar line, overwriting the previous one if possible."""
        prev_total_width = self._total_width
        if self._dynamic_display:
            # Move back over the previous line so it is redrawn in place.
            self.file.write('\b' * prev_total_width)
            self.file.write('\r')
        else:
            self.file.write('\n')

        bar_chars = self._format_step(current_num)
        if self._num is not None:
            bar_chars += ' ' + self._format_bar(current_num)
        self._total_width = len(bar_chars)
        self.file.write(bar_chars)

        info = self._format_values(values)
        info += self._format_eta(time_per_unit, current_num)
        info += fps
        self._total_width += len(info)
        if prev_total_width > self._total_width:
            # Pad so characters of a longer previous line do not linger.
            info += ' ' * (prev_total_width - self._total_width)

        # End the line after the final step, or after every step when the
        # total is unknown.
        if self._num is None or current_num >= self._num:
            info += '\n'

        self.file.write(info)
        self.file.flush()

    def _write_log_line(self, current_num, values, fps):
        """verbose=2: print one self-contained line per update."""
        info = self._format_step(current_num)
        info += self._format_values(values)
        info += fps
        info += '\n'
        self.file.write(info)
        self.file.flush()