progressbar.py 5.4 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
import os
import sys
import time

import numpy as np


class ProgressBar(object):
    """Text progress bar that prints step count, metrics, ETA and speed.

    Args:
        num: Total number of steps, or None when the total is unknown.
            An ``int`` that is ``<= 0`` is rejected.
        width: Requested width of the bar in characters; it is capped by
            the value returned from ``_get_max_width()``.
        verbose: ``1`` draws an updating bar, ``2`` prints one plain line
            per ``update()`` call; any other value prints nothing.
        start: When true the timing clock starts at construction time,
            otherwise it starts on ``start()`` or on the first ``update()``.
        file: Stream that receives all output.

    Raises:
        TypeError: If ``num`` is an ``int`` that is not positive.
    """

    def __init__(self,
                 num=None,
                 width=30,
                 verbose=1,
                 start=True,
                 file=sys.stdout):
        if isinstance(num, int) and num <= 0:
            raise TypeError('num should be None or integer (> 0)')
        self._num = num
        max_width = self._get_max_width()
        self._width = width if width <= max_width else max_width
        self._total_width = 0
        self._verbose = verbose
        self.file = file
        self._values = {}
        self._values_order = []
        # None means "clock not started yet"; update() starts it lazily so
        # that start=False no longer leaves the attribute undefined.
        self._start = time.time() if start else None
        self._last_update = 0

        # In dynamic mode the previous line is overwritten in place;
        # otherwise every update starts on a fresh line.
        self._dynamic_display = (
            (hasattr(self.file, 'isatty') and
             self.file.isatty()) or 'ipykernel' in sys.modules or
            'posix' in sys.modules or 'PYCHARM_HOSTED' in os.environ)

    def _get_max_width(self):
        """Return the widest bar allowed for the current terminal size."""
        if sys.version_info > (3, 3):
            from shutil import get_terminal_size
        else:
            from backports.shutil_get_terminal_size import get_terminal_size
        terminal_width, _ = get_terminal_size()
        # Use at most 60% of the line and leave 50 columns for the text.
        max_width = min(int(terminal_width * 0.6), terminal_width - 50)
        return max_width

    def start(self):
        """Flush the output stream and (re)start the timing clock."""
        self.file.flush()
        self._start = time.time()

    @staticmethod
    def _format_speed(time_per_unit):
        """Return the ' - <n><unit>/step' suffix for a per-step duration."""
        if time_per_unit >= 1 or time_per_unit == 0:
            return ' - %.0fs/%s' % (time_per_unit, 'step')
        if time_per_unit >= 1e-3:
            return ' - %.0fms/%s' % (time_per_unit * 1e3, 'step')
        return ' - %.0fus/%s' % (time_per_unit * 1e6, 'step')

    @staticmethod
    def _format_eta(eta):
        """Format a remaining time in seconds as H:MM:SS, M:SS or Ns."""
        if eta > 3600:
            return '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60,
                                     eta % 60)
        if eta > 60:
            return '%d:%02d' % (eta // 60, eta % 60)
        return '%ds' % eta

    @staticmethod
    def _format_values(values):
        """Render ``(name, value)`` pairs as ' - name: value' segments."""
        parts = []
        for k, v in values:
            # A one-element float32/float64 array is shown like a scalar.
            if isinstance(v, np.ndarray) and v.size == 1 and \
                    v.dtype in (np.float32, np.float64):
                v = v.item()
            if isinstance(v, (float, np.float32, np.float64)):
                if abs(v) > 1e-3:
                    text = '%.4f' % v
                else:
                    text = '%.4e' % v
            else:
                # Wrap in a tuple so tuple-valued metrics format safely.
                text = '%s' % (v, )
            parts.append(' - %s: %s' % (k, text))
        return ''.join(parts)

    def update(self, current_num, values=None):
        """Print the progress for step ``current_num``.

        Args:
            current_num: Number of steps finished so far.
            values: Optional iterable of ``(name, value)`` pairs appended
                to the line; ``None`` is treated as no values.
        """
        now = time.time()
        if values is None:
            values = []
        if self._start is None:
            self._start = now

        if current_num:
            time_per_unit = (now - self._start) / current_num
        else:
            time_per_unit = 0
        fps = self._format_speed(time_per_unit)

        out = self.file
        if self._verbose == 1:
            prev_total_width = self._total_width

            if self._dynamic_display:
                # Rewind over the previously printed line.
                out.write('\b' * prev_total_width)
                out.write('\r')
            else:
                out.write('\n')

            if self._num is not None:
                numdigits = int(np.log10(self._num)) + 1

                bar_chars = ('step %' + str(numdigits) + 'd/%d [') % (
                    current_num, self._num)
                prog = float(current_num) / self._num
                prog_width = int(self._width * prog)

                if prog_width > 0:
                    bar_chars += ('=' * (prog_width - 1))
                    # '>' marks the head of an unfinished bar.
                    if current_num < self._num:
                        bar_chars += '>'
                    else:
                        bar_chars += '='
                bar_chars += ('.' * (self._width - prog_width))
                bar_chars += ']'
            else:
                bar_chars = 'step %3d' % current_num

            self._total_width = len(bar_chars)
            out.write(bar_chars)

            info = self._format_values(values)

            if self._num is not None and current_num < self._num:
                eta = time_per_unit * (self._num - current_num)
                info += ' - ETA: %s' % self._format_eta(eta)

            info += fps
            self._total_width += len(info)
            # Pad with spaces so leftovers of a longer previous line vanish.
            if prev_total_width > self._total_width:
                info += (' ' * (prev_total_width - self._total_width))

            # newline for another epoch, or when the total is unknown
            if self._num is None or current_num >= self._num:
                info += '\n'

            out.write(info)
            out.flush()
            self._last_update = now
        elif self._verbose == 2:
            if self._num:
                numdigits = int(np.log10(self._num)) + 1
                count = ('step %' + str(numdigits) + 'd/%d') % (current_num,
                                                                self._num)
            else:
                count = 'step %3d' % current_num

            info = count + self._format_values(values) + fps + '\n'
            out.write(info)
            out.flush()