#!/usr/bin/env python
# coding: utf8

"""
    Spleeter command line entrypoint module.

    USAGE: python -m spleeter {train,evaluate,separate} ...
"""

# NOTE: disable TF logging before import.
from .utils.logging import configure_logger, logger

import json

from functools import partial
from itertools import product
from glob import glob
from os.path import join
from pathlib import Path
from typing import Any, Container, Dict, List

from . import SpleeterError
from .audio import Codec
from .audio.adapter import AudioAdapter
from .options import *
from .dataset import get_training_dataset, get_validation_dataset
from .model import model_fn
from .model.provider import ModelProvider
from .separator import Separator
from .utils.configuration import load_configuration

# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import pandas as pd
import tensorflow as tf

from typer import Exit, Typer
# pylint: enable=import-error

spleeter: Typer = Typer(add_completion=False)
""" CLI application. """


@spleeter.command()
def train(
        adapter: str = AudioAdapterOption,
        data: Path = TrainingDataDirectoryOption,
        params_filename: str = ModelParametersOption,
        verbose: bool = VerboseOption) -> None:
    """
        Train a source separation model
    """
    configure_logger(verbose)
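    # Training flow: build a tf.estimator.Estimator from the loaded JSON
    # parameters, wrap the training / validation dataset builders from
    # .dataset as input_fn callables, run train_and_evaluate, then write the
    # model provider probe file for the trained model directory.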
    audio_adapter = AudioAdapter.get(adapter)
    audio_path = str(data)
    params = load_configuration(params_filename)
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=params['model_dir'],
        params=params,
        config=tf.estimator.RunConfig(
            save_checkpoints_steps=params['save_checkpoints_steps'],
            tf_random_seed=params['random_seed'],
            save_summary_steps=params['save_summary_steps'],
            session_config=session_config,
            log_step_count_steps=10,
            keep_checkpoint_max=2))
    input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
    train_spec = tf.estimator.TrainSpec(
        input_fn=input_fn,
        max_steps=params['train_max_steps'])
    input_fn = partial(
        get_validation_dataset,
        params,
        audio_adapter,
        audio_path)
    evaluation_spec = tf.estimator.EvalSpec(
        input_fn=input_fn,
        steps=None,
        throttle_secs=params['throttle_secs'])
    logger.info('Start model training')
    tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
    ModelProvider.writeProbe(params['model_dir'])
    logger.info('Model training done')


@spleeter.command()
def separate(
        files: List[Path] = AudioInputOptions,
        adapter: str = AudioAdapterOption,
        bitrate: str = AudioBitrateOption,
        codec: Codec = AudioCodecOption,
        duration: float = AudioDurationOption,
        offset: float = AudioOffsetOption,
        output_path: Path = AudioOutputOption,
        stft_backend: STFTBackend = AudioSTFTBackendOption,
        filename_format: str = FilenameFormatOption,
        params_filename: str = ModelParametersOption,
        mwf: bool = MWFOption,
        verbose: bool = VerboseOption) -> None:
    """
        Separate audio file(s)
    """
    configure_logger(verbose)
    audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
    separator: Separator = Separator(
        params_filename,
        MWF=mwf,
        stft_backend=stft_backend)
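    # Files are submitted for separation asynchronously (synchronous=False);
    # separator.join() below waits for all pending export tasks to complete.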
    for filename in files:
        separator.separate_to_file(
            str(filename),
            str(output_path),
            audio_adapter=audio_adapter,
            offset=offset,
            duration=duration,
            codec=codec,
            bitrate=bitrate,
            filename_format=filename_format,
            synchronous=False)
    separator.join()


EVALUATION_SPLIT: str = 'test'
EVALUATION_METRICS_DIRECTORY: str = 'metrics'
EVALUATION_INSTRUMENTS: Container[str] = ('vocals', 'drums', 'bass', 'other')
EVALUATION_METRICS: Container[str] = ('SDR', 'SAR', 'SIR', 'ISR')
EVALUATION_MIXTURE: str = 'mixture.wav'
EVALUATION_AUDIO_DIRECTORY: str = 'audio'


def _compile_metrics(metrics_output_directory: str) -> Dict:
    """
        Compile metrics from the given directory and return the results as a dict.

        Parameters:
            metrics_output_directory (str):
                Directory to get metrics from.

        Returns:
            Dict:
                Compiled metrics as dict.
    """
    songs = glob(join(metrics_output_directory, EVALUATION_SPLIT, '*.json'))
    metrics = {
        instrument: {k: [] for k in EVALUATION_METRICS}
        for instrument in EVALUATION_INSTRUMENTS}
    for song in songs:
        with open(song, 'r') as stream:
            data = json.load(stream)
        for target in data['targets']:
            instrument = target['name']
            for metric in EVALUATION_METRICS:
                metric_median = np.median([
                    frame['metrics'][metric]
                    for frame in target['frames']
                    if not np.isnan(frame['metrics'][metric])])
                metrics[instrument][metric].append(metric_median)
    return metrics


@spleeter.command()
def evaluate(
        adapter: str = AudioAdapterOption,
        output_path: Path = AudioOutputOption,
        stft_backend: STFTBackend = AudioSTFTBackendOption,
        params_filename: str = ModelParametersOption,
        mus_dir: Path = MUSDBDirectoryOption,
        mwf: bool = MWFOption,
        verbose: bool = VerboseOption) -> Dict:
    """
        Evaluate a model on the musDB test dataset
    """
    configure_logger(verbose)
    try:
        import musdb
        import museval
    except ImportError:
        logger.error('Extra dependencies musdb and museval not found')
        logger.error('Please install musdb and museval first, aborting')
        raise Exit(10)
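    # Evaluation runs in three stages: separate every musdb test mixture,
    # score the estimates with museval, then aggregate per-song medians.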
    # Separate musdb sources.
    songs = glob(join(mus_dir, EVALUATION_SPLIT, '*/'))
    mixtures = [join(song, EVALUATION_MIXTURE) for song in songs]
    audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
    separate(
        files=mixtures,
        adapter=adapter,
        bitrate='128k',
        codec=Codec.WAV,
        duration=600.,
        offset=0,
        output_path=join(audio_output_directory, EVALUATION_SPLIT),
        stft_backend=stft_backend,
        filename_format='{foldername}/{instrument}.{codec}',
        params_filename=params_filename,
        mwf=mwf,
        verbose=verbose)
    # Compute metrics with musdb.
    metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
    logger.info('Starting musdb evaluation (this could be long) ...')
    dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
    museval.eval_mus_dir(
        dataset=dataset,
        estimates_dir=audio_output_directory,
        output_dir=metrics_output_directory)
    logger.info('musdb evaluation done')
    # Compute and pretty print median metrics.
    metrics = _compile_metrics(metrics_output_directory)
    for instrument, instrument_metrics in metrics.items():
        logger.info(f'{instrument}:')
        for metric, values in instrument_metrics.items():
            logger.info(f'{metric}: {np.median(values):.3f}')
    return metrics


def entrypoint():
    """ Application entrypoint. """
    try:
        spleeter()
    except SpleeterError as e:
        logger.error(e)


if __name__ == '__main__':
    entrypoint()