# metrics_tracker.py (2.9 KB) -- last committed by sserdoubleh
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MetricsTracker class
"""

from collections import defaultdict
import math


class MetricsTracker(object):
    """Track the latest value and a sample-weighted running average per metric.

    State:
        metrics_val: metric name -> value from the most recent ``update``.
        metrics_avg: metric name -> running average weighted by sample count.
        num_samples: total number of samples accumulated so far.

    NOTE: the running average of every metric is weighted by the tracker-wide
    ``num_samples``. A metric that first appears after other samples were
    already counted is therefore averaged as if its earlier values were 0.
    """

    def __init__(self):
        self.clear()

    def update(self, metrics, num_samples):
        """Fold one batch of metrics into the tracker.

        Args:
            metrics: mapping of metric name to a value accepted by ``float()``.
                Entries whose value is ``None`` are ignored.
            num_samples: number of samples the batch represents; used as the
                weight of this batch in the running average.
        """
        total = self.num_samples + num_samples
        for key, val in metrics.items():
            if val is None:
                continue
            val = float(val)
            self.metrics_val[key] = val
            prev_avg = self.metrics_avg.get(key, 0.0)
            if total == 0:
                # No samples at all yet: there is nothing to weight by, so
                # keep the previous average instead of dividing by zero.
                self.metrics_avg[key] = prev_avg
            else:
                self.metrics_avg[key] = (
                    prev_avg * self.num_samples + val * num_samples
                ) / total
        self.num_samples = total

    def clear(self):
        """Reset the tracker to its initial, empty state."""
        self.metrics_val = defaultdict(float)
        self.metrics_avg = defaultdict(float)
        self.num_samples = 0

    def items(self):
        """Return the ``(name, running average)`` items view."""
        return self.metrics_avg.items()

    def get(self, name):
        """Return the running average of ``name`` (``None`` if never recorded).

        Raises:
            ValueError: if no samples have been accumulated yet.
        """
        if self.num_samples == 0:
            raise ValueError("There is no data in Metrics.")
        return self.metrics_avg.get(name)

    def state_dict(self):
        """Return a snapshot of the tracker state.

        The mappings are copied so that later ``update`` calls do not alter
        a snapshot that has already been handed out.
        """
        return {
            "metrics_val": dict(self.metrics_val),
            "metrics_avg": dict(self.metrics_avg),
            "num_samples": self.num_samples,
        }

    def load_state_dict(self, state_dict):
        """Restore the tracker from a mapping produced by ``state_dict``.

        The mappings are copied into fresh ``defaultdict(float)`` objects so
        the tracker never mutates the caller's ``state_dict``.
        """
        self.metrics_val = defaultdict(float, state_dict["metrics_val"])
        self.metrics_avg = defaultdict(float, state_dict["metrics_avg"])
        self.num_samples = state_dict["num_samples"]

    @staticmethod
    def _format_metrics(metrics):
        """Render ``metrics`` as ``KEY-value`` pairs joined by three spaces.

        When ``token_nll`` is present, a derived ``TOKEN_PPL`` entry
        (``exp(token_nll)``) is appended.
        """
        parts = [f"{key.upper()}-{val:.3f}" for key, val in metrics.items()]
        if "token_nll" in metrics:
            try:
                ppl = math.exp(metrics["token_nll"])
            except OverflowError:
                # exp() overflows for very large NLL; report infinity rather
                # than crash while merely formatting a log line.
                ppl = float("inf")
            parts.append(f"TOKEN_PPL-{ppl:.3f}")
        return "   ".join(parts)

    def value(self):
        """Return the most recent metric values as a formatted string."""
        return self._format_metrics(self.metrics_val)

    def summary(self):
        """Return the running-average metric values as a formatted string."""
        return self._format_metrics(self.metrics_avg)