visualizations.py 6.9 KB
Newer Older
1
from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2 3
from datetime import datetime
from time import perf_counter as clock
4
import matplotlib.pyplot as plt
5
import numpy as np
C
Corentin Jemine 已提交
6
import webbrowser
7
import visdom
8
import umap
C
Corentin Jemine 已提交
9

10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
# RGB palette used to color points by speaker in projection plots. Each row is
# one color; values are given as 0-255 integers and scaled to the [0, 1] range.
# NOTE: the builtin `float` is used as dtype because the `np.float` alias was
# deprecated in NumPy 1.20 and removed in NumPy 1.24.
colormap = np.array([
    [76, 255, 0],
    [0, 127, 70],
    [255, 0, 0],
    [255, 217, 38],
    [0, 135, 255],
    [165, 0, 165],
    [255, 167, 255],
    [0, 255, 255],
    [255, 96, 38],
    [142, 76, 0],
    [33, 0, 127],
    [0, 0, 0],
    [183, 183, 183],
], dtype=float) / 255

26 27

class Visualizations:
    """Training dashboard that reports progress to a visdom server.

    When ``disabled`` is true no visdom connection is made and every logging
    method becomes a no-op, except that ``update`` still maintains its running
    averages and prints a progress line, and ``draw_projections`` can still
    write the figure to a file.
    """

    def __init__(self, env_name=None, device_name=None, server="http://localhost", disabled=False):
        """Set up the running statistics and, unless disabled, connect to visdom.

        :param env_name: optional name of the visdom environment; the current
        date and time are always appended to it (or used alone if it is None).
        :param device_name: if given, it is shown in the "Training implementation"
        window as the "Device" entry.
        :param server: address passed to ``visdom.Visdom``.
        :param disabled: when true, no connection to visdom is attempted.
        :raises Exception: if the visdom server cannot be reached.
        """
        self.last_update_timestamp = clock()
        self.mean_time_per_step = -1    # -1 means "no step has been timed yet"
        self.loss_exp = None
        self.eer_exp = None
        self.disabled = disabled    # TODO: use a better paradigm for that

        # These are set before the early return below so that a disabled instance
        # exposes the same attributes as an enabled one.
        self.vis = None
        self.loss_win = None
        self.eer_win = None
        self.lr_win = None
        self.implementation_win = None
        self.projection_win = None
        self.implementation_string = ""

        now = datetime.now().strftime("%d-%m %Hh%M")
        self.env_name = now if env_name is None else "%s (%s)" % (env_name, now)

        if self.disabled:
            return

        try:
            self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
        except ConnectionError as err:
            # Chain the original error so that the cause is not lost.
            raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
                            "start it.") from err
        # webbrowser.open("http://localhost:8097/env/" + self.env_name)

        self.log_params()

        if device_name is not None:
            self.log_implementation({"Device": device_name})

    @staticmethod
    def _format_module_params(module):
        """Return one "<tab>name: value<br>" entry per non-dunder attribute of module."""
        return "".join(
            "\t%s: %s<br>" % (name, getattr(module, name))
            for name in dir(module) if not name.startswith("__")
        )

    def log_params(self):
        """Show the attributes of the model and data parameter modules in a text window."""
        if self.disabled:
            return
        # Imported here rather than at module level, as in the original code.
        from encoder import params_data
        from encoder import params_model
        param_string = "<b>Model parameters</b>:<br>"
        param_string += self._format_module_params(params_model)
        param_string += "<b>Data parameters</b>:<br>"
        param_string += self._format_module_params(params_data)
        self.vis.text(param_string, opts={"title": "Parameters"})

    def log_dataset(self, dataset: "SpeakerVerificationDataset"):
        """Show the number of speakers and the logs of the dataset in a text window.

        :param dataset: the dataset to describe; its ``speakers`` attribute and
        ``get_logs()`` method are used.
        """
        if self.disabled:
            return
        dataset_string = "<b>Speakers</b>: %s\n" % len(dataset.speakers)
        dataset_string += "\n" + dataset.get_logs()
        # The text is rendered as HTML, so newlines become <br> tags.
        dataset_string = dataset_string.replace("\n", "<br>")
        self.vis.text(dataset_string, opts={"title": "Dataset"})

    def log_implementation(self, params):
        """Show implementation details in the "Training implementation" window.

        The formatted text is kept in ``self.implementation_string`` so that
        ``update`` can redraw the same window with timing information appended.

        :param params: a mapping of names to values, one line is shown per entry.
        """
        if self.disabled:
            return
        implementation_string = "".join(
            "<b>%s</b>: %s\n" % (param, value) for param, value in params.items()
        ).replace("\n", "<br>")
        self.implementation_string = implementation_string
        self.implementation_win = self.vis.text(
            implementation_string,
            opts={"title": "Training implementation"}
        )

    def update(self, loss, eer, lr, step):
        """Record the metrics of one training step, plot them and print a summary.

        Exponential moving averages of the loss and of the EER are kept in
        ``self.loss_exp`` and ``self.eer_exp``. The time elapsed since the previous
        call (or since construction) is measured and averaged the same way. The
        summary line is printed whether or not the visualizations are disabled.

        :param loss: loss value of the step.
        :param eer: equal error rate of the step.
        :param lr: learning rate of the step.
        :param step: index of the step, used as the x coordinate of the plots.
        """
        # The first value initializes each average
        self.loss_exp = loss if self.loss_exp is None else 0.985 * self.loss_exp + 0.015 * loss
        self.eer_exp = eer if self.eer_exp is None else 0.985 * self.eer_exp + 0.015 * eer
        if not self.disabled:
            # A window is created on the first call and appended to afterwards
            self.loss_win = self.vis.line(
                [[loss, self.loss_exp]],
                [[step, step]],
                win=self.loss_win,
                update="append" if self.loss_win else None,
                opts=dict(
                    legend=["Loss", "Avg. loss"],
                    xlabel="Step",
                    ylabel="Loss",
                    title="Loss",
                )
            )
            self.eer_win = self.vis.line(
                [[eer, self.eer_exp]],
                [[step, step]],
                win=self.eer_win,
                update="append" if self.eer_win else None,
                opts=dict(
                    legend=["EER", "Avg. EER"],
                    xlabel="Step",
                    ylabel="EER",
                    title="Equal error rate"
                )
            )
            self.lr_win = self.vis.line(
                [lr],
                [step],
                win=self.lr_win,
                update="append" if self.lr_win else None,
                opts=dict(
                    xlabel="Step",
                    ylabel="Learning rate",
                    ytype="log",
                    title="Learning rate"
                )
            )

        now = clock()
        time_per_step = (now - self.last_update_timestamp)
        self.last_update_timestamp = now
        if self.mean_time_per_step == -1:
            self.mean_time_per_step = time_per_step
        else:
            self.mean_time_per_step = self.mean_time_per_step * 0.9 + time_per_step * 0.1

        # The timings are appended to the implementation window, if there is one
        if not self.disabled and self.implementation_win is not None:
            time_string = "<b>Mean time per step</b>: %dms" % int(1000 * self.mean_time_per_step)
            time_string += "<br><b>Last step time</b>: %dms" % int(1000 * time_per_step)
            self.vis.text(
                self.implementation_string + time_string,
                win=self.implementation_win,
                opts={"title": "Training implementation"},
            )

        print("Step %6d   Loss: %.4f   EER: %.4f   LR: %g   Mean step time: %5dms   "
              "Last step time: %5dms" %
              (step, self.loss_exp, self.eer_exp, lr, int(1000 * self.mean_time_per_step),
               int(1000 * time_per_step)))

    def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
                         max_speakers=10):
        """Plot a 2D UMAP projection of the embeddings, colored by speaker.

        Each consecutive group of ``utterances_per_speaker`` embeddings is treated
        as belonging to one speaker. The figure is sent to visdom unless the
        visualizations are disabled, and saved to ``out_fpath`` if it is given.

        :param embeds: the embeddings to project, indexed by utterance along the
        first axis. NOTE(review): presumably a 2D array of shape
        (n_utterances, embedding_size) - confirm against the caller.
        :param utterances_per_speaker: number of consecutive embeddings per speaker.
        :param step: training step, shown in the title of the plot.
        :param out_fpath: optional path where the figure is saved.
        :param max_speakers: maximum number of speakers to plot. It is capped by the
        number of colors available in ``colormap``.
        """
        max_speakers = min(max_speakers, len(colormap))
        n_speakers = min(max_speakers, len(embeds) // utterances_per_speaker)
        # Only whole speakers are kept: a trailing incomplete group would otherwise
        # leave more points than colors and make the scatter plot fail.
        embeds = embeds[:n_speakers * utterances_per_speaker]

        ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
        colors = colormap[ground_truth]

        reducer = umap.UMAP()
        projected = reducer.fit_transform(embeds)
        plt.scatter(projected[:, 0], projected[:, 1], c=colors)
        plt.gca().set_aspect("equal", "datalim")
        plt.title("UMAP projection (step %d)" % step)
        if not self.disabled:
            self.projection_win = self.vis.matplot(plt, win=self.projection_win)
        if out_fpath is not None:
            plt.savefig(out_fpath)
        # Clear the current figure so that the next call starts from a blank one
        plt.clf()

    def save(self):
        """Ask the visdom server to save the environment of this instance."""
        if not self.disabled:
            self.vis.save([self.env_name])