parse_tuning_log.py 3.2 KB
Newer Older
Y
Yibing Liu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
"""Parse the log for tuning and plot error surface."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re
import numpy as np
import argparse
import functools
import _init_paths
from utils.utility import add_arguments, print_arguments
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("log_path", str, '', "log path for parsing")
add_arg("fig_name", str, 'error_surface.png', "name of output figure")
args = parser.parse_args()


def plot_error_surface(num_alphas, alphas, betas, error_rate_type, err_ave):
    fig = plt.figure(figsize=(8, 6))
    ax = Axes3D(fig)

    num_betas = len(alphas) // num_alphas
    alphas_2d = np.reshape(alphas, (num_alphas, num_betas))
    betas_2d = np.reshape(betas, (num_alphas, num_betas))
    err_ave_2d = np.reshape(err_ave, (num_alphas, num_betas))

    ax.plot_surface(
        alphas_2d,
        betas_2d,
        err_ave_2d,
        rstride=1,
        cstride=1,
        alpha=0.8,
        cmap='rainbow')
    z_label = 'WER' if error_rate_type == 'wer' else 'CER'
    ax.set_xlabel('alpha', fontsize=12)
    ax.set_ylabel('beta', fontsize=12)
    ax.set_zlabel(z_label, fontsize=12)
    plt.savefig(args.fig_name)
    plt.show()


def parse_log():
    if not os.path.isfile(args.log_path):
        raise IOError("Invaid model path: %s" % args.log_path)

    error_rate_type = None
    num_alphas, num_betas = 0, 0
    alphas, betas, err_ave = [], [], []

    err_rate_pat = re.compile(
        '\(alpha, beta\) = '
        '\([-+]?\d+(?:\.\d+)?, [-+]?\d+(?:\.\d+)?\), \[[wcer]')
    num_pat = re.compile(r'[-+]?\d+(?:\.\d+)?')

    with open(args.log_path, "r") as log_file:
        line = log_file.readline()
        while line:
            if line.find("error_rate_type:") != -1:
                error_rate_type = line.strip().split()[1]
            elif line.find("num_alphas:") != -1:
                num_alphas = int(line.strip().split()[1])
            elif line.find("num_betas:") != -1:
                num_betas = int(line.strip().split()[1])
            elif err_rate_pat.match(line) is not None:
                tuples = num_pat.findall(line)
                alphas.append(float(tuples[0]))
                betas.append(float(tuples[1]))
                err_ave.append(float(tuples[2]))
            line = log_file.readline()

    if error_rate_type == None:
        raise ValueError("Illegal log format, cannot find error_rate_type")

    if num_alphas <= 0:
        raise ValueError("Illegal log format, invalid num_alphas")

    if num_betas <= 0:
        raise ValueError("Illegal log format, invalid num_betas")

    if alphas == []:
        raise ValueError("Illegal log format, cannot find grid search result")

    if num_alphas * num_betas != len(alphas):
        raise ValueError("Illegal log format, data's shape mismatches")

    return num_alphas, alphas, betas, error_rate_type, err_ave,


def main():
    print_arguments(args)
    num_alphas, alphas, betas, error_rate_type, err_ave = parse_log()
    plot_error_surface(num_alphas, alphas, betas, error_rate_type, err_ave)


if __name__ == '__main__':
    main()