plotlog.py 2.2 KB
Newer Older
R
ranqiu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
#coding=utf-8

import sys
import argparse
import matplotlib.pyplot as plt


def parse_args():
    """Parse the command-line options of the log plotting tool.

    Recognised options:
      --file_path / -f    (str)   path of the log file; None when omitted.
      --sample_rate / -s  (float) rate at which samples are taken from the
                                  log; defaults to 1.0.
      --log_period / -p   (int)   period of the log; defaults to 1.

    Returns:
        The ``argparse.Namespace`` built from ``sys.argv``.
    """
    # (flags, keyword settings) for each option, in registration order.
    option_specs = (
        (('--file_path', '-f'),
         dict(type=str, help='the path of the log file')),
        (('--sample_rate', '-s'),
         dict(type=float, default=1.0,
              help='the rate to take samples from log')),
        (('--log_period', '-p'),
         dict(type=int, default=1, help='the period of log')),
    )

    cli = argparse.ArgumentParser(prog='Parse Log')
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)

    return cli.parse_args()


def parse_file(file_name):
    loss = []
    error = []
    with open(file_name) as f:
        for i, line in enumerate(f):
            line = line.strip()
            if not line.startswith('pass'):
                continue
            line_split = line.split(' ')
            if len(line_split) != 5:
                continue

            loss_str = line_split[2][:-1]
            cur_loss = float(loss_str.split('=')[-1])
            loss.append(cur_loss)

            err_str = line_split[3][:-1]
            cur_err = float(err_str.split('=')[-1])
            error.append(cur_err)

    accuracy = [1.0 - err for err in error]

    return loss, accuracy


def sample(metric, sample_rate):
    """Down-sample ``metric`` by keeping one element per fixed interval.

    The interval is ``int(1.0 / sample_rate)``; elements at indices
    ``0, interval, 2 * interval, ...`` are kept, for
    ``len(metric) // interval`` elements in total.

    Args:
        metric: sequence of values to sample from.
        sample_rate: fraction of values to keep, in the range (0, 1].

    Returns:
        A list with the sampled values. When the interval is larger than
        the sequence, ``metric[:1]`` is returned (the first element, or an
        empty sequence when ``metric`` is empty).

    Raises:
        ValueError: if ``sample_rate`` is not in the range (0, 1].
    """
    # A rate outside (0, 1] would give a zero or negative interval and
    # surface later as a confusing ZeroDivisionError.
    if not 0.0 < sample_rate <= 1.0:
        raise ValueError(
            'sample_rate must be in the range (0, 1], got {!r}'.format(
                sample_rate))

    interval = int(1.0 / sample_rate)
    if interval > len(metric):
        return metric[:1]

    # Floor division: on Python 3 `/` returns a float, which range() rejects.
    num = len(metric) // interval
    return [metric[interval * i] for i in range(num)]


def plot_metric(metric, batch_id, graph_title):
    """Plot ``metric`` against ``batch_id`` and save it as a JPEG image.

    The figure is titled ``graph_title``, its x axis is labelled 'batch'
    and its y axis ``graph_title``. The image is written to
    ``graph_title + '.jpg'``.

    Args:
        metric: y values to plot.
        batch_id: x values to plot, one per entry of ``metric``.
        graph_title: title of the graph; also the y-axis label and the
            stem of the output file name.
    """
    fig = plt.figure()
    try:
        plt.title(graph_title)
        plt.plot(batch_id, metric)
        plt.xlabel('batch')
        plt.ylabel(graph_title)
        plt.savefig(graph_title + '.jpg')
    finally:
        # Close this figure even when plotting or saving fails, so it does
        # not stay registered with pyplot.
        plt.close(fig)


def main():
    """Parse the command line, read the log and plot loss and accuracy.

    The loss and accuracy series read from the log are down-sampled with
    the requested sample rate and plotted against the batch number, which
    is the record index multiplied by the log period.

    Raises:
        ValueError: if the sample rate is outside (0, 1] or no log file
            path was given.
    """
    args = parse_args()
    # Explicit check rather than `assert`, which is removed when Python runs
    # with -O and would let an invalid rate through.
    if not 0. < args.sample_rate <= 1.0:
        raise ValueError("The sample rate should in the range (0, 1].")
    # --file_path has no default; fail with a clear message instead of
    # letting open(None) raise a TypeError.
    if args.file_path is None:
        raise ValueError("The path of the log file is required (--file_path).")

    loss, accuracy = parse_file(args.file_path)
    batch = [args.log_period * i for i in range(len(loss))]

    batch_sample = sample(batch, args.sample_rate)
    loss_sample = sample(loss, args.sample_rate)
    accuracy_sample = sample(accuracy, args.sample_rate)

    plot_metric(loss_sample, batch_sample, 'loss')
    plot_metric(accuracy_sample, batch_sample, 'accuracy')


# Run the tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()