#!/usr/bin/env python
# coding: utf8

"""
Python oneliner script usage.

USAGE: python -m spleeter {train,evaluate,separate} ...

Notes:
    All critical import involving TF, numpy or Pandas are deported to
    command function scope to avoid heavy import on CLI evaluation,
    leading to large bootstraping time.
"""

import json
from functools import partial
from glob import glob
from os.path import join
from pathlib import Path
from typing import Container, Dict, List, Optional

# pyright: reportMissingImports=false
# pylint: disable=import-error
from typer import Exit, Typer

from . import SpleeterError
from .options import *
from .utils.logging import configure_logger, logger

# pylint: enable=import-error

spleeter: Typer = Typer(add_completion=False, no_args_is_help=True, short_help="-h")
""" CLI application. """


@spleeter.callback()
def default(
    version: bool = VersionOption,
) -> None:
    # No-op callback: exists only so the eager --version option is
    # processed before any sub-command runs.
    pass


@spleeter.command(no_args_is_help=True)
def train(
    adapter: str = AudioAdapterOption,
    data: Path = TrainingDataDirectoryOption,
    params_filename: str = ModelParametersOption,
    verbose: bool = VerboseOption,
) -> None:
    """
    Train a source separation model.

    Parameters:
        adapter (str): Name of the audio adapter used to load samples.
        data (Path): Root directory of the training dataset.
        params_filename (str): Path or name of the model configuration.
        verbose (bool): Enable verbose logging.
    """
    # Heavy imports are kept function-scoped (see module docstring).
    import tensorflow as tf

    from .audio.adapter import AudioAdapter
    from .dataset import get_training_dataset, get_validation_dataset
    from .model import model_fn
    from .model.provider import ModelProvider
    from .utils.configuration import load_configuration

    configure_logger(verbose)
    audio_adapter = AudioAdapter.get(adapter)
    audio_path = str(data)
    params = load_configuration(params_filename)
    session_config = tf.compat.v1.ConfigProto()
    # Cap per-process GPU memory fraction — presumably to leave headroom
    # for the concurrent evaluation pass (TODO confirm rationale).
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=params["model_dir"],
        params=params,
        config=tf.estimator.RunConfig(
            save_checkpoints_steps=params["save_checkpoints_steps"],
            tf_random_seed=params["random_seed"],
            save_summary_steps=params["save_summary_steps"],
            session_config=session_config,
            log_step_count_steps=10,
            keep_checkpoint_max=2,
        ),
    )
    input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
    train_spec = tf.estimator.TrainSpec(
        input_fn=input_fn, max_steps=params["train_max_steps"]
    )
    input_fn = partial(get_validation_dataset, params, audio_adapter, audio_path)
    evaluation_spec = tf.estimator.EvalSpec(
        input_fn=input_fn, steps=None, throttle_secs=params["throttle_secs"]
    )
    logger.info("Start model training")
    tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
    # Mark the trained model directory as complete for the model provider.
    ModelProvider.writeProbe(params["model_dir"])
    logger.info("Model training done")


@spleeter.command(no_args_is_help=True)
def separate(
    deprecated_files: Optional[str] = AudioInputOption,
    files: List[Path] = AudioInputArgument,
    adapter: str = AudioAdapterOption,
    bitrate: str = AudioBitrateOption,
    codec: Codec = AudioCodecOption,
    duration: float = AudioDurationOption,
    offset: float = AudioOffsetOption,
    output_path: Path = AudioOutputOption,
    stft_backend: STFTBackend = AudioSTFTBackendOption,
    filename_format: str = FilenameFormatOption,
    params_filename: str = ModelParametersOption,
    mwf: bool = MWFOption,
    verbose: bool = VerboseOption,
) -> None:
    """
    Separate audio file(s).

    Parameters:
        deprecated_files (Optional[str]): Legacy -i option; rejected with
            exit code 20 when supplied.
        files (List[Path]): Audio files to separate.
        adapter (str): Name of the audio adapter used for I/O.
        bitrate (str): Output audio bitrate.
        codec (Codec): Output audio codec.
        duration (float): Maximum duration (seconds) to process per file.
        offset (float): Start offset (seconds) within each file.
        output_path (Path): Directory where separated stems are written.
        stft_backend (STFTBackend): STFT implementation to use.
        filename_format (str): Output filename template.
        params_filename (str): Path or name of the model configuration.
        mwf (bool): Enable multichannel Wiener filtering.
        verbose (bool): Enable verbose logging.

    Raises:
        Exit: With code 20 when the deprecated -i option is used.
    """
    from .audio.adapter import AudioAdapter
    from .separator import Separator

    configure_logger(verbose)
    if deprecated_files is not None:
        logger.error(
            "⚠️ -i option is not supported anymore, audio files must be supplied "
            "using input argument instead (see spleeter separate --help)"
        )
        raise Exit(20)
    audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
    separator: Separator = Separator(
        params_filename, MWF=mwf, stft_backend=stft_backend
    )
    for filename in files:
        # Separations are submitted asynchronously and joined below so
        # multiple files can be processed by the separator pool.
        separator.separate_to_file(
            str(filename),
            str(output_path),
            audio_adapter=audio_adapter,
            offset=offset,
            duration=duration,
            codec=codec,
            bitrate=bitrate,
            filename_format=filename_format,
            synchronous=False,
        )
    separator.join()


# musDB evaluation constants.
EVALUATION_SPLIT: str = "test"
EVALUATION_METRICS_DIRECTORY: str = "metrics"
EVALUATION_INSTRUMENTS: Container[str] = ("vocals", "drums", "bass", "other")
EVALUATION_METRICS: Container[str] = ("SDR", "SAR", "SIR", "ISR")
EVALUATION_MIXTURE: str = "mixture.wav"
EVALUATION_AUDIO_DIRECTORY: str = "audio"


def _compile_metrics(metrics_output_directory) -> Dict:
    """
    Compiles metrics from given directory and returns results as dict.

    For every instrument / metric pair, collects the per-song median of
    the frame-level values (NaN frames excluded) produced by museval.

    Parameters:
        metrics_output_directory (str): Directory to get metrics from.

    Returns:
        Dict: Compiled metrics as dict, keyed by instrument then metric,
        each value a list with one median per evaluated song.
    """
    import numpy as np

    # Use the split constant rather than a hard-coded "test" directory.
    songs = glob(join(metrics_output_directory, f"{EVALUATION_SPLIT}/*.json"))
    metrics = {
        instrument: {k: [] for k in EVALUATION_METRICS}
        for instrument in EVALUATION_INSTRUMENTS
    }
    for song in songs:
        with open(song, "r") as stream:
            data = json.load(stream)
        for target in data["targets"]:
            instrument = target["name"]
            for metric in EVALUATION_METRICS:
                sdr_med = np.median(
                    [
                        frame["metrics"][metric]
                        for frame in target["frames"]
                        if not np.isnan(frame["metrics"][metric])
                    ]
                )
                metrics[instrument][metric].append(sdr_med)
    return metrics


@spleeter.command(no_args_is_help=True)
def evaluate(
    adapter: str = AudioAdapterOption,
    output_path: Path = AudioOutputOption,
    stft_backend: STFTBackend = AudioSTFTBackendOption,
    params_filename: str = ModelParametersOption,
    mus_dir: Path = MUSDBDirectoryOption,
    mwf: bool = MWFOption,
    verbose: bool = VerboseOption,
) -> Dict:
    """
    Evaluate a model on the musDB test dataset.

    Parameters:
        adapter (str): Name of the audio adapter used for I/O.
        output_path (Path): Directory for separated audio and metrics.
        stft_backend (STFTBackend): STFT implementation to use.
        params_filename (str): Path or name of the model configuration.
        mus_dir (Path): Root directory of the musDB dataset (wav layout).
        mwf (bool): Enable multichannel Wiener filtering.
        verbose (bool): Enable verbose logging.

    Returns:
        Dict: Compiled metrics (see _compile_metrics).

    Raises:
        Exit: With code 10 when musdb / museval are not installed.
    """
    import numpy as np

    configure_logger(verbose)
    try:
        import musdb
        import museval
    except ImportError:
        logger.error("Extra dependencies musdb and museval not found")
        logger.error("Please install musdb and museval first, abort")
        raise Exit(10)
    # Separate musdb sources.
    songs = glob(join(mus_dir, EVALUATION_SPLIT, "*/"))
    mixtures = [join(song, EVALUATION_MIXTURE) for song in songs]
    audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
    separate(
        deprecated_files=None,
        files=mixtures,
        adapter=adapter,
        bitrate="128k",
        codec=Codec.WAV,
        duration=600.0,
        offset=0,
        output_path=join(audio_output_directory, EVALUATION_SPLIT),
        stft_backend=stft_backend,
        filename_format="{foldername}/{instrument}.{codec}",
        params_filename=params_filename,
        mwf=mwf,
        verbose=verbose,
    )
    # Compute metrics with musdb.
    metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
    logger.info("Starting musdb evaluation (this could be long) ...")
    dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
    museval.eval_mus_dir(
        dataset=dataset,
        estimates_dir=audio_output_directory,
        output_dir=metrics_output_directory,
    )
    logger.info("musdb evaluation done")
    # Compute and pretty print median metrics.
    metrics = _compile_metrics(metrics_output_directory)
    for instrument, instrument_metrics in metrics.items():
        logger.info(f"{instrument}:")
        for metric, values in instrument_metrics.items():
            logger.info(f"{metric}: {np.median(values):.3f}")
    return metrics


def entrypoint():
    """Application entrypoint: run the CLI, logging domain errors."""
    try:
        spleeter()
    except SpleeterError as e:
        # Known domain errors are reported without a traceback.
        logger.error(e)


if __name__ == "__main__":
    entrypoint()