# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import datetime

import numpy as np

__all__ = ['TrainingStats', 'Time']


class SmoothedValue(object):
    """Track a series of values and expose their median over a sliding window.

    Only the most recent ``window_size`` values are kept; older values are
    evicted automatically by the bounded deque.
    """

    def __init__(self, window_size):
        self.deque = collections.deque(maxlen=window_size)

    def add_value(self, value):
        """Append ``value`` to the window, evicting the oldest one if full."""
        self.deque.append(value)

    def get_median_value(self):
        """Return the median of the values currently held in the window.

        For an empty window ``np.median`` yields ``nan`` (with a
        RuntimeWarning).
        """
        return np.median(self.deque)


def Time():
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS.ffffff'."""
    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')


class TrainingStats(object):
    """Keep windowed medians for a fixed set of named training statistics.

    Args:
        window_size: number of most recent values kept per statistic.
        stats_keys: iterable of statistic names to track.
    """

    def __init__(self, window_size, stats_keys):
        self.smoothed_losses_and_metrics = {
            key: SmoothedValue(window_size)
            for key in stats_keys
        }

    def update(self, stats):
        """Record one new value for every tracked statistic.

        Args:
            stats: mapping that contains (at least) every key given to the
                constructor. Values exposing a ``numpy()`` method (e.g.
                framework tensors) are converted through it; any other value
                (plain Python or NumPy numbers) is stored as-is.

        Raises:
            KeyError: if a tracked key is missing from ``stats``.
        """
        for key, smoothed in self.smoothed_losses_and_metrics.items():
            value = stats[key]
            # Tensors need converting to NumPy; plain numbers do not.
            if hasattr(value, 'numpy'):
                value = value.numpy()
            smoothed.add_value(value)

    def get(self, extras=None):
        """Build an ordered snapshot of the current statistics.

        Args:
            extras: optional mapping whose entries are placed first, unchanged.

        Returns:
            collections.OrderedDict: ``extras`` entries followed by each
            tracked statistic's windowed median formatted with 6 decimals
            (as a string).
        """
        stats = collections.OrderedDict()
        if extras:
            for key, value in extras.items():
                stats[key] = value
        for key, smoothed in self.smoothed_losses_and_metrics.items():
            stats[key] = format(smoothed.get_median_value(), '.6f')
        return stats

    def log(self, extras=None):
        """Render ``get(extras)`` as a "'key': value, ..." log string.

        Each entry is formatted as ``repr(key): repr(value)``, i.e. exactly what
        ``repr({key: value})`` shows between its outer braces.
        """
        stats = self.get(extras)
        # Format entries directly instead of str(dict).strip('{}'): strip()
        # also removed closing braces belonging to dict/set values.
        return ', '.join('{!r}: {!r}'.format(key, value)
                         for key, value in stats.items())