__main__.py 8.2 KB
Newer Older
R
Romain 已提交
1 2 3
#!/usr/bin/env python
# coding: utf8

F
Faylixe 已提交
4 5
"""
    Python oneliner script usage.
R
Romain 已提交
6

F
Faylixe 已提交
7 8
    USAGE: python -m spleeter {train,evaluate,separate} ...

F
Faylixe 已提交
9 10 11 12 13
    Notes:
        All critical imports involving TF, numpy or Pandas are deferred to
        command function scope to avoid heavy imports on CLI evaluation,
        which would otherwise lead to a long bootstrapping time.
"""
F
Faylixe 已提交
14
import json
F
Faylixe 已提交
15
from functools import partial
F
Faylixe 已提交
16
from glob import glob
17
from itertools import product
F
Faylixe 已提交
18
from os.path import join
F
Faylixe 已提交
19
from pathlib import Path
20
from typing import Container, Dict, List, Optional
R
Romain 已提交
21

22 23 24 25
# pyright: reportMissingImports=false
# pylint: disable=import-error
from typer import Exit, Typer

F
Faylixe 已提交
26
from . import SpleeterError
F
Faylixe 已提交
27
from .options import *
F
Faylixe 已提交
28
from .utils.logging import configure_logger, logger
R
Romain 已提交
29

F
Faylixe 已提交
30
# pylint: enable=import-error
R
Romain 已提交
31

32
# Typer application object; the functions below are registered on it through
# the ``@spleeter.callback()`` / ``@spleeter.command(...)`` decorators.
# NOTE(review): ``short_help="-h"`` is passed through to Typer verbatim. If the
# intent was to expose ``-h`` as a help flag, that is presumably configured
# another way (e.g. context settings) -- confirm against the Typer docs.
spleeter: Typer = Typer(
    add_completion=False,
    no_args_is_help=True,
    short_help="-h",
)
""" CLI application. """
R
Romain 已提交
34

F
Faylixe 已提交
35

36 37 38
@spleeter.callback()
def default(
    version: bool = VersionOption,
) -> None:
    # Application-level callback: it only declares the ``version`` option on
    # the root command and performs no work of its own.
    # (Kept as a comment rather than a docstring so the CLI help output is
    # left untouched.)
    return None

R
romi1502 已提交
42

43
@spleeter.command(no_args_is_help=True)
def train(
    adapter: str = AudioAdapterOption,
    data: Path = TrainingDataDirectoryOption,
    params_filename: str = ModelParametersOption,
    verbose: bool = VerboseOption,
) -> None:
    """
    Train a source separation model
    """
    # Heavy dependencies are imported at function scope so that merely
    # loading the CLI module stays cheap (see the module docstring).
    import tensorflow as tf

    from .audio.adapter import AudioAdapter
    from .dataset import get_training_dataset, get_validation_dataset
    from .model import model_fn
    from .model.provider import ModelProvider
    from .utils.configuration import load_configuration

    configure_logger(verbose)
    audio_adapter = AudioAdapter.get(adapter)
    audio_path = str(data)
    params = load_configuration(params_filename)

    # TensorFlow session setup: set the per-process GPU memory fraction.
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.45

    # "model_dir" is read before the run configuration keys so that
    # configuration entries are accessed in the same order as before.
    model_directory = params["model_dir"]
    run_config = tf.estimator.RunConfig(
        save_checkpoints_steps=params["save_checkpoints_steps"],
        tf_random_seed=params["random_seed"],
        save_summary_steps=params["save_summary_steps"],
        session_config=session_config,
        log_step_count_steps=10,
        keep_checkpoint_max=2,
    )
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=model_directory,
        params=params,
        config=run_config,
    )

    # Training and validation datasets are built from the same parameters,
    # audio adapter and data directory.
    training_input_fn = partial(
        get_training_dataset, params, audio_adapter, audio_path
    )
    validation_input_fn = partial(
        get_validation_dataset, params, audio_adapter, audio_path
    )
    train_spec = tf.estimator.TrainSpec(
        input_fn=training_input_fn,
        max_steps=params["train_max_steps"],
    )
    evaluation_spec = tf.estimator.EvalSpec(
        input_fn=validation_input_fn,
        steps=None,
        throttle_secs=params["throttle_secs"],
    )

    logger.info("Start model training")
    tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
    ModelProvider.writeProbe(params["model_dir"])
    logger.info("Model training done")
F
Faylixe 已提交
92 93


94
@spleeter.command(no_args_is_help=True)
def separate(
    deprecated_files: Optional[str] = AudioInputOption,
    files: List[Path] = AudioInputArgument,
    adapter: str = AudioAdapterOption,
    bitrate: str = AudioBitrateOption,
    codec: Codec = AudioCodecOption,
    duration: float = AudioDurationOption,
    offset: float = AudioOffsetOption,
    output_path: Path = AudioOutputOption,
    stft_backend: STFTBackend = AudioSTFTBackendOption,
    filename_format: str = FilenameFormatOption,
    params_filename: str = ModelParametersOption,
    mwf: bool = MWFOption,
    verbose: bool = VerboseOption,
) -> None:
    """
    Separate audio file(s)
    """
    # Imported lazily to keep CLI start-up light (see the module docstring).
    from .audio.adapter import AudioAdapter
    from .separator import Separator

    configure_logger(verbose)

    # Guard clause: the legacy input option is rejected with exit code 20.
    if deprecated_files is not None:
        logger.error(
            "⚠️ -i option is not supported anymore, audio files must be supplied "
            "using input argument instead (see spleeter separate --help)"
        )
        raise Exit(20)

    audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
    separator: Separator = Separator(
        params_filename,
        MWF=mwf,
        stft_backend=stft_backend,
    )

    # Everything except the input file is identical for each separation.
    destination = str(output_path)
    shared_options = dict(
        audio_adapter=audio_adapter,
        offset=offset,
        duration=duration,
        codec=codec,
        bitrate=bitrate,
        filename_format=filename_format,
        synchronous=False,
    )
    for audio_file in files:
        separator.separate_to_file(str(audio_file), destination, **shared_options)

    # Separations are submitted with synchronous=False; join() is called once
    # after all files have been submitted.
    separator.join()
F
Faylixe 已提交
140 141


142 143 144 145 146 147
# Sub-directory name used for the evaluated subset: under the musDB root,
# under the separated audio output and under the metrics output.
EVALUATION_SPLIT: str = "test"
# Sub-directory of the output path that receives the metric files.
EVALUATION_METRICS_DIRECTORY: str = "metrics"
# Instrument names for which metrics are compiled.
EVALUATION_INSTRUMENTS: Container[str] = ("vocals", "drums", "bass", "other")
# Metric names read from each frame of the metric files.
EVALUATION_METRICS: Container[str] = ("SDR", "SAR", "SIR", "ISR")
# File name of the mixture inside each song directory.
EVALUATION_MIXTURE: str = "mixture.wav"
# Sub-directory of the output path that receives the separated audio.
EVALUATION_AUDIO_DIRECTORY: str = "audio"


def _compile_metrics(metrics_output_directory: str) -> Dict:
    """
    Compiles metrics from given directory and returns results as dict.

    Every ``*.json`` file found in the ``EVALUATION_SPLIT`` sub-directory
    of the given directory is read. For each target of each file, and for
    each metric of ``EVALUATION_METRICS``, the median of the non-NaN frame
    values is appended to the result.

    Parameters:
        metrics_output_directory (str):
            Directory to get metrics from.

    Returns:
        Dict:
            Compiled metrics as dict, mapping each instrument of
            ``EVALUATION_INSTRUMENTS`` to a dict that maps each metric
            name to the list of per-file medians (one entry per file;
            empty lists when no file is found).

    Raises:
        KeyError:
            If a file contains a target whose name is not one of
            ``EVALUATION_INSTRUMENTS``, or lacks an expected key.
    """
    import numpy as np

    metrics = {
        instrument: {metric: [] for metric in EVALUATION_METRICS}
        for instrument in EVALUATION_INSTRUMENTS
    }
    # Use the shared split constant instead of a hard-coded "test" folder.
    songs = glob(join(metrics_output_directory, EVALUATION_SPLIT, "*.json"))
    for song in songs:
        with open(song, "r") as stream:
            data = json.load(stream)
        for target in data["targets"]:
            instrument = target["name"]
            for metric in EVALUATION_METRICS:
                # NaN frames are skipped; if no frame is left, np.median
                # yields NaN for that file (unchanged behavior).
                median = np.median(
                    [
                        frame["metrics"][metric]
                        for frame in target["frames"]
                        if not np.isnan(frame["metrics"][metric])
                    ]
                )
                metrics[instrument][metric].append(median)
    return metrics
F
Faylixe 已提交
190 191


192
@spleeter.command(no_args_is_help=True)
def evaluate(
    adapter: str = AudioAdapterOption,
    output_path: Path = AudioOutputOption,
    stft_backend: STFTBackend = AudioSTFTBackendOption,
    params_filename: str = ModelParametersOption,
    mus_dir: Path = MUSDBDirectoryOption,
    mwf: bool = MWFOption,
    verbose: bool = VerboseOption,
) -> Dict:
    """
    Evaluate a model on the musDB test dataset
    """
    import numpy as np

    configure_logger(verbose)

    # musdb and museval are optional extras: abort with exit code 10 when
    # either of them cannot be imported.
    try:
        import musdb
        import museval
    except ImportError:
        logger.error("Extra dependencies musdb and museval not found")
        logger.error("Please install musdb and museval first, abort")
        raise Exit(10)

    # Step 1: separate the mixture of every song directory of the split.
    song_directories = glob(join(mus_dir, EVALUATION_SPLIT, "*/"))
    mixtures = [
        join(song_directory, EVALUATION_MIXTURE)
        for song_directory in song_directories
    ]
    audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
    separate(
        deprecated_files=None,
        files=mixtures,
        adapter=adapter,
        bitrate="128k",
        codec=Codec.WAV,
        duration=600.0,
        offset=0,
        output_path=join(audio_output_directory, EVALUATION_SPLIT),
        stft_backend=stft_backend,
        filename_format="{foldername}/{instrument}.{codec}",
        params_filename=params_filename,
        mwf=mwf,
        verbose=verbose,
    )

    # Step 2: compute metrics with musdb / museval on the separated audio.
    metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
    logger.info("Starting musdb evaluation (this could be long) ...")
    dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
    museval.eval_mus_dir(
        dataset=dataset,
        estimates_dir=audio_output_directory,
        output_dir=metrics_output_directory,
    )
    logger.info("musdb evaluation done")

    # Step 3: compile the metrics and log the median of each one.
    metrics = _compile_metrics(metrics_output_directory)
    for instrument, instrument_metrics in metrics.items():
        logger.info(f"{instrument}:")
        for metric_name, values in instrument_metrics.items():
            logger.info(f"{metric_name}: {np.median(values):.3f}")
    return metrics
F
Faylixe 已提交
251 252


F
Faylixe 已提交
253 254
def entrypoint():
    """ Run the CLI application and log any SpleeterError it raises. """
    try:
        spleeter()
    except SpleeterError as error:
        # NOTE(review): the error is only logged, so this function returns
        # normally afterwards -- confirm whether a non-zero exit status is
        # wanted here.
        logger.error(error)


# Script execution (``python -m spleeter``) goes through the same entrypoint.
if __name__ == "__main__":
    entrypoint()