# train.py
from encoder.ui.visualizations import Visualizations
from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
from encoder.params_model import *
from encoder.model import SpeakerEncoder
from utils.profiler import Profiler
from pathlib import Path
import torch


def _synchronize(device: torch.device) -> None:
    """Wait for queued CUDA work on `device` to finish; do nothing for non-CUDA devices.

    Only used so that profiler timings account for finished GPU work. `device` falls back
    to the CPU when CUDA is unavailable, and `torch.cuda.synchronize` must not be called
    in that case, hence the guard.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _save_checkpoint(fpath: Path, step: int, model, optimizer) -> None:
    """Write a resumable training checkpoint for `step` to `fpath`.

    The stored "step" is `step + 1`, i.e. the step at which a resumed run continues.
    Missing parent directories of `fpath` are created.
    """
    fpath.parent.mkdir(parents=True, exist_ok=True)
    torch.save({
        "step": step + 1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, fpath)


def train(run_id: str, clean_data_root: Path, models_dir: Path, vis_every: int, save_every: int,
          backup_every: int, force_restart: bool, visdom_server: str, no_visdom: bool):
    """Train a SpeakerEncoder on the dataset under `clean_data_root`.

    The batch geometry (`speakers_per_batch`, `utterances_per_speaker`) and the learning
    rate (`learning_rate_init`) are module-level names from the `encoder.params_model`
    star import. Training iterates over the batches yielded by the data loader.

    Args:
        run_id: Name of the run. The latest state is kept in `models_dir/<run_id>.pt`.
            Backups and projection images go to `models_dir/<run_id>_backups/`. The name
            is also passed to the visualization environment.
        clean_data_root: Root directory handed to `SpeakerVerificationDataset`.
        models_dir: Directory holding the state file and the backup directory.
        vis_every: Draw and save embedding projections every this many steps (0 disables).
        save_every: Overwrite the state file every this many steps (0 disables).
        backup_every: Write a step-stamped backup checkpoint every this many steps
            (0 disables).
        force_restart: If True, ignore any existing state file and start from scratch.
        visdom_server: Server address passed to `Visualizations`.
        no_visdom: Passed to `Visualizations` as `disabled`.
    """
    # Create a dataset and a dataloader
    dataset = SpeakerVerificationDataset(clean_data_root)
    loader = SpeakerVerificationDataLoader(
        dataset,
        speakers_per_batch,
        utterances_per_speaker,
        num_workers=4,
    )

    # Set up the devices on which to run the forward pass and the loss. These can differ,
    # because the forward pass is faster on the GPU whereas the loss is often (depending on
    # your hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIXME: currently, the gradient is incorrect if loss_device is cuda
    loss_device = torch.device("cpu")

    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    init_step = 1

    # Configure file paths for the model
    state_fpath = models_dir.joinpath(run_id + ".pt")
    backup_dir = models_dir.joinpath(run_id + "_backups")

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print("Found existing model \"%s\", loading it and resuming training." % run_id)
            # map_location lets a checkpoint written on a GPU machine be resumed on a
            # CPU-only one; load_state_dict then copies tensors onto each parameter's device.
            checkpoint = torch.load(state_fpath, map_location=device)
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            # Deliberately reset the learning rate to the configured value instead of
            # keeping the one stored in the optimizer state.
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." % run_id)
    else:
        print("Starting the training from scratch.")
    model.train()

    # Initialize the visualization environment
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis = Visualizations(run_id, device_name=device_name, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(dataset)

    # Training loop
    profiler = Profiler(summarize_every=10)
    # Steps are numbered from init_step so that a resumed run continues its step count.
    for step, speaker_batch in enumerate(loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        _synchronize(device)
        profiler.tick("Data to %s" % device)
        embeds = model(inputs)
        _synchronize(device)
        profiler.tick("Forward pass")
        # Regroup the flat batch as (speakers_per_batch, utterances_per_speaker, -1) and move
        # it to the loss device before computing the loss and EER.
        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        profiler.tick("Loss")

        # Backward pass
        model.zero_grad()
        loss.backward()
        profiler.tick("Backward pass")
        model.do_gradient_ops()
        optimizer.step()
        profiler.tick("Parameter update")

        # Update visualizations
        learning_rate = optimizer.param_groups[0]["lr"]
        vis.update(loss.item(), eer, learning_rate, step)

        # Draw projections and save them to the backup folder
        if vis_every != 0 and step % vis_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            backup_dir.mkdir(parents=True, exist_ok=True)
            projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
            projection_embeds = embeds.detach().cpu().numpy()
            vis.draw_projections(projection_embeds, utterances_per_speaker, step, projection_fpath)
            vis.save()

        # Overwrite the latest version of the model
        if save_every != 0 and step % save_every == 0:
            print("Saving the model (step %d)" % step)
            _save_checkpoint(state_fpath, step, model, optimizer)

        # Make a backup
        if backup_every != 0 and step % backup_every == 0:
            print("Making a backup (step %d)" % step)
            backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
            _save_checkpoint(backup_fpath, step, model, optimizer)

        profiler.tick("Extras (visualizations, saving)")