diff --git a/deep_speech_2/tools/parse_tuning_log.py b/deep_speech_2/tools/parse_tuning_log.py new file mode 100644 index 0000000000000000000000000000000000000000..fa7873ee9501784fdb67589b414dc62b96ed6f6d --- /dev/null +++ b/deep_speech_2/tools/parse_tuning_log.py @@ -0,0 +1,102 @@ +"""Parse the log for tuning and plot error surface.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import re +import numpy as np +import argparse +import functools +import _init_paths +from utils.utility import add_arguments, print_arguments +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +add_arg("log_path", str, '', "log path for parsing") +add_arg("fig_name", str, 'error_surface.png', "name of output figure") +args = parser.parse_args() + + +def plot_error_surface(num_alphas, alphas, betas, error_rate_type, err_ave): + fig = plt.figure(figsize=(8, 6)) + ax = Axes3D(fig) + + num_betas = len(alphas) // num_alphas + alphas_2d = np.reshape(alphas, (num_alphas, num_betas)) + betas_2d = np.reshape(betas, (num_alphas, num_betas)) + err_ave_2d = np.reshape(err_ave, (num_alphas, num_betas)) + + ax.plot_surface( + alphas_2d, + betas_2d, + err_ave_2d, + rstride=1, + cstride=1, + alpha=0.8, + cmap='rainbow') + z_label = 'WER' if error_rate_type == 'wer' else 'CER' + ax.set_xlabel('alpha', fontsize=12) + ax.set_ylabel('beta', fontsize=12) + ax.set_zlabel(z_label, fontsize=12) + plt.savefig(args.fig_name) + plt.show() + + +def parse_log(): + if not os.path.isfile(args.log_path): + raise IOError("Invaid model path: %s" % args.log_path) + + error_rate_type = None + num_alphas, num_betas = 0, 0 + alphas, betas, err_ave = [], [], [] + + err_rate_pat = re.compile( + '\(alpha, beta\) = ' + '\([-+]?\d+(?:\.\d+)?, [-+]?\d+(?:\.\d+)?\), \[[wcer]') + num_pat = re.compile(r'[-+]?\d+(?:\.\d+)?') + + with open(args.log_path, "r") as log_file: + line = log_file.readline() + while line: + if line.find("error_rate_type:") != -1: + error_rate_type = line.strip().split()[1] + elif line.find("num_alphas:") != -1: + num_alphas = int(line.strip().split()[1]) + elif line.find("num_betas:") != -1: + num_betas = int(line.strip().split()[1]) + elif err_rate_pat.match(line) is not None: + tuples = num_pat.findall(line) + alphas.append(float(tuples[0])) + betas.append(float(tuples[1])) + err_ave.append(float(tuples[2])) + line = log_file.readline() + + if error_rate_type == None: + raise ValueError("Illegal log format, cannot find error_rate_type") + + if num_alphas <= 0: + raise ValueError("Illegal log format, invalid num_alphas") + + if num_betas <= 0: + raise ValueError("Illegal log format, invalid num_betas") + + if alphas == []: + raise ValueError("Illegal log format, cannot find grid search result") + + if num_alphas * num_betas != len(alphas): + raise ValueError("Illegal log format, data's shape mismatches") + + return num_alphas, alphas, betas, error_rate_type, err_ave, + + +def main(): + print_arguments(args) + num_alphas, alphas, betas, error_rate_type, err_ave = parse_log() + plot_error_surface(num_alphas, alphas, betas, error_rate_type, err_ave) + + +if __name__ == '__main__': + main()