提交 c96897df 编写于 作者: F Faylixe

🐛 fix backend resolution

上级 232bf0d3
......@@ -12,6 +12,11 @@
from enum import Enum
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
......@@ -34,3 +39,12 @@ class STFTBackend(str, Enum):
AUTO: str = 'auto'
TENSORFLOW: str = 'tensorflow'
LIBROSA: str = 'librosa'
def resolve(cls: type, backend: str) -> str:
if backend not in cls.__members__.items():
raise ValueError(f'Unsupported backend {backend}')
if backend == cls.AUTO:
if len(tf.config.list_physical_devices('GPU')):
return cls.TENSORFLOW
return STFTBackend.LIBROSA
return backend
......@@ -129,8 +129,7 @@ class Separator(object):
else:
self._pool = None
self._tasks = []
# NOTE: provide type check here ?
self._params['stft_backend'] = stft_backend
self._params['stft_backend'] = STFTBackend.resolve(stft_backend)
self._data_generator = DataGenerator()
def __del__(self) -> None:
......@@ -333,11 +332,6 @@ class Separator(object):
(Optional) string describing the waveform (e.g. filename).
"""
backend: str = self._params['stft_backend']
if backend == STFTBackend.AUTO:
if len(tf.config.list_physical_devices('GPU')):
backend = STFTBackend.TENSORFLOW
else:
backend = STFTBackend.LIBROSA
if backend == STFTBackend.TENSORFLOW:
return self._separate_tensorflow(waveform, audio_descriptor)
elif backend == STFTBackend.LIBROSA:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册