__main__.py 8.1 KB
Newer Older
R
Romain 已提交
1 2 3
#!/usr/bin/env python
# coding: utf8

F
Faylixe 已提交
4 5
"""
    Python oneliner script usage.
R
Romain 已提交
6

F
Faylixe 已提交
7 8
    USAGE: python -m spleeter {train,evaluate,separate} ...

F
Faylixe 已提交
9 10 11 12 13
    Notes:
        All critical imports involving TF, numpy or Pandas are deferred to
        command function scope to avoid heavy imports on CLI evaluation,
        which would otherwise lead to a long bootstrapping time.
"""
F
Faylixe 已提交
14
import json
F
Faylixe 已提交
15
from functools import partial
F
Faylixe 已提交
16
from glob import glob
17
from itertools import product
F
Faylixe 已提交
18
from os.path import join
F
Faylixe 已提交
19
from pathlib import Path
20
from typing import Container, Dict, List, Optional
R
Romain 已提交
21

22 23 24 25
# pyright: reportMissingImports=false
# pylint: disable=import-error
from typer import Exit, Typer

F
Faylixe 已提交
26
from . import SpleeterError
F
Faylixe 已提交
27
from .options import *
F
Faylixe 已提交
28
from .utils.logging import configure_logger, logger
R
Romain 已提交
29

F
Faylixe 已提交
30
# pylint: enable=import-error
R
Romain 已提交
31

F
Faylixe 已提交
32
spleeter: Typer = Typer(add_completion=False)
""" Root CLI application; commands below are registered on it. """


# NOTE(review): deliberately documented with comments rather than a
# docstring -- Typer can surface a callback docstring as the application
# help text, so adding one would change the CLI output (confirm against
# the Typer documentation before adding any).
# `version` is not read in the body; presumably VersionOption (defined in
# .options) carries its own handling -- TODO confirm.
@spleeter.callback()
def default(version: bool = VersionOption) -> None:
    return None

F
Faylixe 已提交
42 43
@spleeter.command()
def train(
    adapter: str = AudioAdapterOption,
    data: Path = TrainingDataDirectoryOption,
    params_filename: str = ModelParametersOption,
    verbose: bool = VerboseOption,
) -> None:
    """
    Train a source separation model
    """
    # Heavy dependencies are imported lazily, at command scope only.
    import tensorflow as tf

    from .audio.adapter import AudioAdapter
    from .dataset import get_training_dataset, get_validation_dataset
    from .model import model_fn
    from .model.provider import ModelProvider
    from .utils.configuration import load_configuration

    configure_logger(verbose)
    audio_adapter = AudioAdapter.get(adapter)
    audio_path = str(data)
    params = load_configuration(params_filename)

    # Session configuration: cap the per-process GPU memory fraction.
    gpu_session = tf.compat.v1.ConfigProto()
    gpu_session.gpu_options.per_process_gpu_memory_fraction = 0.45

    # Looked up before the run configuration so that configuration keys are
    # read in the same order as a single nested constructor call would.
    model_directory = params["model_dir"]
    run_config = tf.estimator.RunConfig(
        save_checkpoints_steps=params["save_checkpoints_steps"],
        tf_random_seed=params["random_seed"],
        save_summary_steps=params["save_summary_steps"],
        session_config=gpu_session,
        log_step_count_steps=10,
        keep_checkpoint_max=2,
    )
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=model_directory,
        params=params,
        config=run_config,
    )

    # Both dataset factories are bound to the same leading arguments.
    dataset_arguments = (params, audio_adapter, audio_path)
    train_spec = tf.estimator.TrainSpec(
        input_fn=partial(get_training_dataset, *dataset_arguments),
        max_steps=params["train_max_steps"],
    )
    evaluation_spec = tf.estimator.EvalSpec(
        input_fn=partial(get_validation_dataset, *dataset_arguments),
        steps=None,
        throttle_secs=params["throttle_secs"],
    )

    logger.info("Start model training")
    tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
    ModelProvider.writeProbe(params["model_dir"])
    logger.info("Model training done")
F
Faylixe 已提交
91 92


F
Faylixe 已提交
93
@spleeter.command()
def separate(
    deprecated_files: Optional[str] = AudioInputOption,
    files: List[Path] = AudioInputArgument,
    adapter: str = AudioAdapterOption,
    bitrate: str = AudioBitrateOption,
    codec: Codec = AudioCodecOption,
    duration: float = AudioDurationOption,
    offset: float = AudioOffsetOption,
    output_path: Path = AudioOutputOption,
    stft_backend: STFTBackend = AudioSTFTBackendOption,
    filename_format: str = FilenameFormatOption,
    params_filename: str = ModelParametersOption,
    mwf: bool = MWFOption,
    verbose: bool = VerboseOption,
) -> None:
    """
    Separate audio file(s)
    """
    from .audio.adapter import AudioAdapter
    from .separator import Separator

    configure_logger(verbose)

    # Guard clause: the legacy input option is rejected with exit code 20.
    if deprecated_files is not None:
        rejection_message = (
            "⚠️ -i option is not supported anymore, audio files must be supplied "
            "using input argument instead (see spleeter separate --help)"
        )
        logger.error(rejection_message)
        raise Exit(20)

    audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
    separator: Separator = Separator(
        params_filename,
        MWF=mwf,
        stft_backend=stft_backend,
    )

    # Options shared by every file; each separation is queued with
    # synchronous=False and awaited once through join() below.
    export_options = dict(
        audio_adapter=audio_adapter,
        offset=offset,
        duration=duration,
        codec=codec,
        bitrate=bitrate,
        filename_format=filename_format,
        synchronous=False,
    )
    destination = str(output_path)
    for source in files:
        separator.separate_to_file(str(source), destination, **export_options)
    separator.join()
F
Faylixe 已提交
139 140


141 142 143 144 145 146
# Name of the dataset split that is separated and evaluated.
EVALUATION_SPLIT: str = "test"
# Sub-directory of the output path that receives the metric files.
EVALUATION_METRICS_DIRECTORY: str = "metrics"
# Instruments and metrics that are always present in the compiled result.
EVALUATION_INSTRUMENTS: Container[str] = ("vocals", "drums", "bass", "other")
EVALUATION_METRICS: Container[str] = ("SDR", "SAR", "SIR", "ISR")
# File name of the mixture looked up inside each song directory.
EVALUATION_MIXTURE: str = "mixture.wav"
# Sub-directory of the output path that receives the separated audio.
EVALUATION_AUDIO_DIRECTORY: str = "audio"


def _compile_metrics(metrics_output_directory) -> Dict:
    """
    Compiles metrics from given directory and returns results as dict.

    Every ``*.json`` file found in the ``EVALUATION_SPLIT`` sub-directory
    of the given directory is read. Each file is expected to hold a
    ``"targets"`` list whose items provide a ``"name"`` and a list of
    ``"frames"``, each frame mapping metric names to values under
    ``"metrics"``. For every target and every metric of
    ``EVALUATION_METRICS``, the median over non-NaN frame values is
    appended to the result (one value per file).

    Parameters:
        metrics_output_directory (str):
            Directory to get metrics from.

    Returns:
        Dict:
            Compiled metrics as dict, indexed by instrument name then by
            metric name, each leaf being the list of per-file medians.
            All ``EVALUATION_INSTRUMENTS`` are always present; a target
            name outside that list is added on the fly instead of
            raising ``KeyError``.
    """
    import numpy as np

    def _empty_scores() -> Dict:
        # Fresh lists per instrument so that no list is ever shared.
        return {metric: [] for metric in EVALUATION_METRICS}

    metrics = {instrument: _empty_scores() for instrument in EVALUATION_INSTRUMENTS}
    pattern = join(metrics_output_directory, EVALUATION_SPLIT, "*.json")
    # Sorted so that the order of the per-file values is reproducible.
    for song in sorted(glob(pattern)):
        with open(song, "r", encoding="utf-8") as stream:
            data = json.load(stream)
        for target in data["targets"]:
            scores = metrics.setdefault(target["name"], _empty_scores())
            for metric in EVALUATION_METRICS:
                values = [frame["metrics"][metric] for frame in target["frames"]]
                # NaN frames are ignored; if no frame is left the median
                # itself is NaN (numpy behaviour on empty input).
                valid_values = [value for value in values if not np.isnan(value)]
                scores[metric].append(np.median(valid_values))
    return metrics
F
Faylixe 已提交
189 190 191 192


@spleeter.command()
def evaluate(
    adapter: str = AudioAdapterOption,
    output_path: Path = AudioOutputOption,
    stft_backend: STFTBackend = AudioSTFTBackendOption,
    params_filename: str = ModelParametersOption,
    mus_dir: Path = MUSDBDirectoryOption,
    mwf: bool = MWFOption,
    verbose: bool = VerboseOption,
) -> Dict:
    """
    Evaluate a model on the musDB test dataset
    """
    import numpy as np

    configure_logger(verbose)

    # Optional extras: abort with exit code 10 when they are missing.
    try:
        import musdb
        import museval
    except ImportError:
        logger.error("Extra dependencies musdb and museval not found")
        logger.error("Please install musdb and museval first, abort")
        raise Exit(10)

    # Step 1: separate the mixture of every song directory of the split.
    song_directories = glob(join(mus_dir, EVALUATION_SPLIT, "*/"))
    mixtures = [
        join(directory, EVALUATION_MIXTURE) for directory in song_directories
    ]
    audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
    separation_options = dict(
        deprecated_files=None,
        files=mixtures,
        adapter=adapter,
        bitrate="128k",
        codec=Codec.WAV,
        duration=600.0,
        offset=0,
        output_path=join(audio_output_directory, EVALUATION_SPLIT),
        stft_backend=stft_backend,
        filename_format="{foldername}/{instrument}.{codec}",
        params_filename=params_filename,
        mwf=mwf,
        verbose=verbose,
    )
    separate(**separation_options)

    # Step 2: compute the metrics from the separated audio.
    metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
    logger.info("Starting musdb evaluation (this could be long) ...")
    dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
    museval.eval_mus_dir(
        dataset=dataset,
        estimates_dir=audio_output_directory,
        output_dir=metrics_output_directory,
    )
    logger.info("musdb evaluation done")

    # Step 3: compile the metrics and log the median of each one.
    metrics = _compile_metrics(metrics_output_directory)
    for instrument, scores in metrics.items():
        logger.info(f"{instrument}:")
        for metric_name, values in scores.items():
            logger.info(f"{metric_name}: {np.median(values):.3f}")
    return metrics
F
Faylixe 已提交
250 251


F
Faylixe 已提交
252 253
def entrypoint() -> None:
    """
    Application entrypoint.

    Runs the ``spleeter`` CLI application. A ``SpleeterError`` raised
    while running it is logged through the module logger and then turned
    into a failing exit status, so that shells and scripts invoking the
    CLI can detect the failure instead of seeing a successful exit.

    Raises:
        SystemExit:
            With status 1 when a ``SpleeterError`` has been caught.
    """
    try:
        spleeter()
    except SpleeterError as e:
        logger.error(e)
        # Logging alone would let the process finish with status 0.
        raise SystemExit(1) from e
F
Faylixe 已提交
258 259


260
# Script guard: run the CLI only when this module is executed directly
# (e.g. `python -m spleeter`), not when it is imported.
if __name__ == "__main__":
    entrypoint()