提交 d94996f2 编写于 作者: Y Yang Zhou

format audio

上级 b336ccfe
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import backends
from . import compliance
from . import datasets
from . import features
......@@ -18,4 +19,4 @@ from . import functional
from . import io
from . import metric
from . import sox_effects
from . import backends
from . import utils
......@@ -11,16 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import utils
from .soundfile_backend import depth_convert
from .soundfile_backend import soundfile_load
from .soundfile_backend import normalize
from .soundfile_backend import resample
from .soundfile_backend import soundfile_load
from .soundfile_backend import soundfile_save
from .soundfile_backend import to_mono
from . import utils
from .utils import get_audio_backend
from .utils import list_audio_backends
from .utils import set_audio_backend
utils._init_audio_backend()
\ No newline at end of file
utils._init_audio_backend()
......@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from typing import Optional
......@@ -204,6 +203,7 @@ def soundfile_save(y: np.ndarray, sr: int, file: os.PathLike) -> None:
wavfile.write(file, sr, y_out)
def soundfile_load(
file: os.PathLike,
sr: Optional[int]=None,
......@@ -256,9 +256,13 @@ def soundfile_load(
y = depth_convert(y, dtype)
return y, r
#the code below token form: https://github.com/pytorch/audio/blob/main/torchaudio/backend/soundfile_backend.py with modificaion.
def _get_subtype_for_wav(dtype: paddle.dtype, encoding: str, bits_per_sample: int):
def _get_subtype_for_wav(dtype: paddle.dtype,
encoding: str,
bits_per_sample: int):
if not encoding:
if not bits_per_sample:
subtype = {
......@@ -315,7 +319,10 @@ def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
raise ValueError(f"sph does not support {encoding}.")
def _get_subtype(dtype: paddle.dtype, format: str, encoding: str, bits_per_sample: int):
def _get_subtype(dtype: paddle.dtype,
format: str,
encoding: str,
bits_per_sample: int):
if format == "wav":
return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
if format == "flac":
......@@ -328,7 +335,8 @@ def _get_subtype(dtype: paddle.dtype, format: str, encoding: str, bits_per_sampl
return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}"
if format in ("ogg", "vorbis"):
if encoding or bits_per_sample:
raise ValueError("ogg/vorbis does not support encoding/bits_per_sample.")
raise ValueError(
"ogg/vorbis does not support encoding/bits_per_sample.")
return "VORBIS"
if format == "sph":
return _get_subtype_for_sphere(encoding, bits_per_sample)
......@@ -336,16 +344,16 @@ def _get_subtype(dtype: paddle.dtype, format: str, encoding: str, bits_per_sampl
return "PCM_16"
raise ValueError(f"Unsupported format: {format}")
def save(
filepath: str,
src: paddle.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
filepath: str,
src: paddle.Tensor,
sample_rate: int,
channels_first: bool=True,
compression: Optional[float]=None,
format: Optional[str]=None,
encoding: Optional[str]=None,
bits_per_sample: Optional[int]=None, ):
"""Save audio data to file.
Note:
......@@ -441,11 +449,11 @@ def save(
if compression is not None:
warnings.warn(
'`save` function of "soundfile" backend does not support "compression" parameter. '
"The argument is silently ignored."
)
"The argument is silently ignored.")
if hasattr(filepath, "write"):
if format is None:
raise RuntimeError("`format` is required when saving to file object.")
raise RuntimeError(
"`format` is required when saving to file object.")
ext = format.lower()
else:
ext = str(filepath).split(".")[-1].lower()
......@@ -455,8 +463,7 @@ def save(
if bits_per_sample == 24:
warnings.warn(
"Saving audio with 24 bits per sample might warp samples near -1. "
"Using 16 bits per sample might be able to avoid this."
)
"Using 16 bits per sample might be able to avoid this.")
subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)
# sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
......@@ -467,7 +474,13 @@ def save(
if channels_first:
src = src.t()
soundfile.write(file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format)
soundfile.write(
file=filepath,
data=src,
samplerate=sample_rate,
subtype=subtype,
format=format)
_SUBTYPE2DTYPE = {
"PCM_S8": "int8",
......@@ -478,14 +491,14 @@ _SUBTYPE2DTYPE = {
"DOUBLE": "float64",
}
def load(
filepath: str,
frame_offset: int = 0,
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[paddle.Tensor, int]:
filepath: str,
frame_offset: int=0,
num_frames: int=-1,
normalize: bool=True,
channels_first: bool=True,
format: Optional[str]=None, ) -> Tuple[paddle.Tensor, int]:
"""Load audio data from file.
Note:
......@@ -564,7 +577,7 @@ def load(
waveform = paddle.to_tensor(waveform)
if channels_first:
waveform = paddle.transpose(waveform, perm=[1,0])
waveform = paddle.transpose(waveform, perm=[1, 0])
return waveform, sample_rate
......@@ -588,7 +601,8 @@ _SUBTYPE_TO_BITS_PER_SAMPLE = {
"ALAW": 8, # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
"IMA_ADPCM": 0, # IMA ADPCM.
"MS_ADPCM": 0, # Microsoft ADPCM.
"GSM610": 0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
"GSM610":
0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
"VOX_ADPCM": 0, # OKI / Dialogix ADPCM
"G721_32": 0, # 32kbs G721 ADPCM encoding.
"G723_24": 0, # 24kbs G723 ADPCM encoding.
......@@ -606,16 +620,17 @@ _SUBTYPE_TO_BITS_PER_SAMPLE = {
"ALAC_32": 32, # Apple Lossless Audio Codec (32 bit).
}
def _get_bit_depth(subtype):
if subtype not in _SUBTYPE_TO_BITS_PER_SAMPLE:
warnings.warn(
f"The {subtype} subtype is unknown to PaddleAudio. As a result, the bits_per_sample "
"attribute will be set to 0. If you are seeing this warning, please "
"report by opening an issue on github (after checking for existing/closed ones). "
"You may otherwise ignore this warning."
)
"You may otherwise ignore this warning.")
return _SUBTYPE_TO_BITS_PER_SAMPLE.get(subtype, 0)
_SUBTYPE_TO_ENCODING = {
"PCM_S8": "PCM_S",
"PCM_16": "PCM_S",
......@@ -629,12 +644,14 @@ _SUBTYPE_TO_ENCODING = {
"VORBIS": "VORBIS",
}
def _get_encoding(format: str, subtype: str):
if format == "FLAC":
return "FLAC"
return _SUBTYPE_TO_ENCODING.get(subtype, "UNKNOWN")
def info(filepath: str, format: Optional[str] = None) -> AudioInfo:
def info(filepath: str, format: Optional[str]=None) -> AudioInfo:
"""Get signal information of an audio file.
Note:
......@@ -657,5 +674,4 @@ def info(filepath: str, format: Optional[str] = None) -> AudioInfo:
sinfo.frames,
sinfo.channels,
bits_per_sample=_get_bit_depth(sinfo.subtype),
encoding=_get_encoding(sinfo.format, sinfo.subtype),
)
\ No newline at end of file
encoding=_get_encoding(sinfo.format, sinfo.subtype), )
from pathlib import Path
from typing import Callable
from typing import Optional, Tuple, Union
import os
from typing import Optional
from typing import Tuple
import paddle
import paddleaudio
from paddle import Tensor
from .common import AudioInfo
import os
from paddleaudio._internal import module_utils as _mod_utils
from paddleaudio._internal import module_utils as _mod_utils
from .common import AudioInfo
#https://github.com/pytorch/audio/blob/main/torchaudio/backend/sox_io_backend.py
def _fail_info(filepath: str, format: Optional[str]) -> AudioInfo:
raise RuntimeError("Failed to fetch metadata from {}".format(filepath))
......@@ -22,73 +22,78 @@ def _fail_info_fileobj(fileobj, format: Optional[str]) -> AudioInfo:
# Note: need to comply TorchScript syntax -- need annotation and no f-string
def _fail_load(
filepath: str,
frame_offset: int = 0,
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[Tensor, int]:
filepath: str,
frame_offset: int=0,
num_frames: int=-1,
normalize: bool=True,
channels_first: bool=True,
format: Optional[str]=None, ) -> Tuple[Tensor, int]:
raise RuntimeError("Failed to load audio from {}".format(filepath))
def _fail_load_fileobj(fileobj, *args, **kwargs):
raise RuntimeError(f"Failed to load audio from {fileobj}")
_fallback_info = _fail_info
_fallback_info_fileobj = _fail_info_fileobj
_fallback_load = _fail_load
_fallback_load_filebj = _fail_load_fileobj
@_mod_utils.requires_sox()
def load(
filepath: str,
frame_offset: int = 0,
frame_offset: int=0,
num_frames: int=-1,
normalize: bool = True,
channels_first: bool = True,
normalize: bool=True,
channels_first: bool=True,
format: Optional[str]=None, ) -> Tuple[Tensor, int]:
if hasattr(filepath, "read"):
ret = paddleaudio._paddleaudio.load_audio_fileobj(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
filepath, frame_offset, num_frames, normalize, channels_first,
format)
if ret is not None:
audio_tensor = paddle.to_tensor(ret[0])
return (audio_tensor, ret[1])
return _fallback_load_fileobj(filepath, frame_offset, num_frames, normalize, channels_first, format)
return _fallback_load_fileobj(filepath, frame_offset, num_frames,
normalize, channels_first, format)
filepath = os.fspath(filepath)
ret = paddleaudio._paddleaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
filepath, frame_offset, num_frames, normalize, channels_first, format)
if ret is not None:
audio_tensor = paddle.to_tensor(ret[0])
return (audio_tensor, ret[1])
return _fallback_load(filepath, frame_offset, num_frames, normalize, channels_first, format)
return _fallback_load(filepath, frame_offset, num_frames, normalize,
channels_first, format)
@_mod_utils.requires_sox()
def save(filepath: str,
src: Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
def save(
filepath: str,
src: Tensor,
sample_rate: int,
channels_first: bool=True,
compression: Optional[float]=None,
format: Optional[str]=None,
encoding: Optional[str]=None,
bits_per_sample: Optional[int]=None, ):
src_arr = src.numpy()
if hasattr(filepath, "write"):
paddleaudio._paddleaudio.save_audio_fileobj(
filepath, src_arr, sample_rate, channels_first, compression, format, encoding, bits_per_sample
)
filepath, src_arr, sample_rate, channels_first, compression, format,
encoding, bits_per_sample)
return
filepath = os.fspath(filepath)
paddleaudio._paddleaudio.sox_io_save_audio_file(
filepath, src_arr, sample_rate, channels_first, compression, format, encoding, bits_per_sample
)
filepath, src_arr, sample_rate, channels_first, compression, format,
encoding, bits_per_sample)
@_mod_utils.requires_sox()
def info(filepath: str, format: Optional[str] = None,) -> AudioInfo:
def info(
filepath: str,
format: Optional[str]=None, ) -> AudioInfo:
if hasattr(filepath, "read"):
sinfo = paddleaudio._paddleaudio.get_info_fileobj(filepath, format)
if sinfo is not None:
......
"""Defines utilities for switching audio backends"""
#code is from: https://github.com/pytorch/audio/blob/main/torchaudio/backend/utils.py
import warnings
from typing import List
from typing import Optional
......@@ -8,7 +7,9 @@ from typing import Optional
import paddleaudio
from paddleaudio._internal import module_utils as _mod_utils
from . import no_backend, soundfile_backend, sox_io_backend
from . import no_backend
from . import soundfile_backend
from . import sox_io_backend
__all__ = [
"list_audio_backends",
......@@ -55,6 +56,7 @@ def set_audio_backend(backend: Optional[str]):
for func in ["save", "load", "info"]:
setattr(paddleaudio, func, getattr(module, func))
def _init_audio_backend():
backends = list_audio_backends()
if "soundfile" in backends:
......
......@@ -21,7 +21,7 @@ from .env import USER_HOME
from .error import ParameterError
from .log import Logger
from .log import logger
from .time import seconds_to_hms
from .time import Timer
from .numeric import depth_convert
from .numeric import pcm16to32
from .time import seconds_to_hms
from .time import Timer
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unility functions for Transformer."""
from typing import List
from typing import Tuple
import paddle
from .log import Logger
__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"]
logger = Logger(__name__)
def has_tensor(val):
if isinstance(val, (list, tuple)):
for item in val:
if has_tensor(item):
return True
elif isinstance(val, dict):
for k, v in val.items():
print(k)
if has_tensor(v):
return True
else:
return paddle.is_tensor(val)
def pad_sequence(sequences: List[paddle.Tensor],
batch_first: bool=False,
padding_value: float=0.0) -> paddle.Tensor:
r"""Pad a list of variable length Tensors with ``padding_value``
``pad_sequence`` stacks a list of Tensors along a new dimension,
and pads them to equal length. For example, if the input is list of
sequences with size ``L x *`` and if batch_first is False, and ``T x B x *``
otherwise.
`B` is batch size. It is equal to the number of elements in ``sequences``.
`T` is length of the longest sequence.
`L` is length of the sequence.
`*` is any number of trailing dimensions, including none.
Example:
>>> from paddle.nn.utils.rnn import pad_sequence
>>> a = paddle.ones(25, 300)
>>> b = paddle.ones(22, 300)
>>> c = paddle.ones(15, 300)
>>> pad_sequence([a, b, c]).shape
paddle.Tensor([25, 3, 300])
Note:
This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
where `T` is the length of the longest sequence. This function assumes
trailing dimensions and type of all the Tensors in sequences are same.
Args:
sequences (list[Tensor]): list of variable length sequences.
batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
``T x B x *`` otherwise
padding_value (float, optional): value for padded elements. Default: 0.
Returns:
Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
Tensor of size ``B x T x *`` otherwise
"""
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
max_size = paddle.shape(sequences[0])
# (TODO Hui Zhang): slice not supprot `end==start`
# trailing_dims = max_size[1:]
trailing_dims = tuple(
max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else ()
max_len = max([s.shape[0] for s in sequences])
if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims
else:
out_dims = (max_len, len(sequences)) + trailing_dims
out_tensor = paddle.full(out_dims, padding_value, sequences[0].dtype)
for i, tensor in enumerate(sequences):
length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor
if batch_first:
# TODO (Hui Zhang): set_value op not supprot `end==start`
# TODO (Hui Zhang): set_value op not support int16
# TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...]
# out_tensor[i, :length, ...] = tensor
if length != 0:
out_tensor[i, :length] = tensor
else:
out_tensor[i, length] = tensor
else:
# TODO (Hui Zhang): set_value op not supprot `end==start`
# out_tensor[:length, i, ...] = tensor
if length != 0:
out_tensor[:length, i] = tensor
else:
out_tensor[length, i] = tensor
return out_tensor
def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
ignore_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Add <sos> and <eos> labels.
Args:
ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)
sos (int): index of <sos>
eos (int): index of <eeos>
ignore_id (int): index of padding
Returns:
ys_in (paddle.Tensor) : (B, Lmax + 1)
ys_out (paddle.Tensor) : (B, Lmax + 1)
Examples:
>>> sos_id = 10
>>> eos_id = 11
>>> ignore_id = -1
>>> ys_pad
tensor([[ 1, 2, 3, 4, 5],
[ 4, 5, 6, -1, -1],
[ 7, 8, 9, -1, -1]], dtype=paddle.int32)
>>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
>>> ys_in
tensor([[10, 1, 2, 3, 4, 5],
[10, 4, 5, 6, 11, 11],
[10, 7, 8, 9, 11, 11]])
>>> ys_out
tensor([[ 1, 2, 3, 4, 5, 11],
[ 4, 5, 6, 11, -1, -1],
[ 7, 8, 9, 11, -1, -1]])
"""
# TODO(Hui Zhang): using comment code,
#_sos = paddle.to_tensor(
# [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#_eos = paddle.to_tensor(
# [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
ys_in = paddle.cat([_sos, ys_pad], dim=1)
mask_pad = (ys_in == ignore_id)
ys_in = ys_in.masked_fill(mask_pad, eos)
ys_out = paddle.cat([ys_pad, _eos], dim=1)
ys_out = ys_out.masked_fill(mask_pad, eos)
mask_eos = (ys_out == ignore_id)
ys_out = ys_out.masked_fill(mask_eos, eos)
ys_out = ys_out.masked_fill(mask_pad, ignore_id)
return ys_in, ys_out
def th_accuracy(pad_outputs: paddle.Tensor,
pad_targets: paddle.Tensor,
ignore_label: int) -> float:
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
pad_outputs.shape[1]).argmax(2)
mask = pad_targets != ignore_label
#TODO(Hui Zhang): sum not support bool type
# numerator = paddle.sum(
# pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = (
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = paddle.sum(numerator.type_as(pad_targets))
#TODO(Hui Zhang): sum not support bool type
# denominator = paddle.sum(mask)
denominator = paddle.sum(mask.type_as(pad_targets))
return float(numerator) / float(denominator)
import itertools
from unittest import skipIf
from parameterized import parameterized
from paddleaudio._internal.module_utils import is_module_available
from parameterized import parameterized
def name_func(func, _, params):
......@@ -31,7 +31,8 @@ def skipIfFormatNotSupported(fmt):
def parameterize(*params):
return parameterized.expand(list(itertools.product(*params)), name_func=name_func)
return parameterized.expand(
list(itertools.product(*params)), name_func=name_func)
def fetch_wav_subtype(dtype, encoding, bits_per_sample):
......@@ -54,4 +55,3 @@ def fetch_wav_subtype(dtype, encoding, bits_per_sample):
if subtype:
return subtype
raise ValueError(f"wav does not support ({encoding}, {bits_per_sample}).")
#this code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/backend/soundfile/info_test.py
import tarfile
import warnings
import unittest
import warnings
from unittest.mock import patch
import paddle
from paddleaudio._internal import module_utils as _mod_utils
import soundfile
from common import parameterize
from common import skipIfFormatNotSupported
from paddleaudio.backends import soundfile_backend
from tests.backends.common import get_bits_per_sample, get_encoding
from tests.common_utils import (
get_wav_data,
nested_params,
save_wav,
TempDirMixin,
)
from common import parameterize, skipIfFormatNotSupported
import soundfile
from tests.backends.common import get_bits_per_sample
from tests.backends.common import get_encoding
from tests.common_utils import get_wav_data
from tests.common_utils import nested_params
from tests.common_utils import save_wav
from tests.common_utils import TempDirMixin
class TestInfo(TempDirMixin, unittest.TestCase):
@parameterize(
["float32", "int32"],
[8000, 16000],
[1, 2],
)
[1, 2], )
def test_wav(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.info` can check wav file correctly"""
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
data = get_wav_data(
dtype,
num_channels,
normalize=False,
num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = soundfile_backend.info(path)
assert info.sample_rate == sample_rate
......@@ -62,32 +62,31 @@ class TestInfo(TempDirMixin, unittest.TestCase):
#@parameterize([8000, 16000], [1, 2])
#@skipIfFormatNotSupported("OGG")
#def test_ogg(self, sample_rate, num_channels):
#"""`soundfile_backend.info` can check ogg file correctly"""
#duration = 1
#num_frames = sample_rate * duration
##data = torch.randn(num_frames, num_channels).numpy()
#data = paddle.randn(shape=[num_frames, num_channels]).numpy()
#print(len(data))
#path = self.get_temp_path("data.ogg")
#soundfile.write(path, data, sample_rate)
#info = soundfile_backend.info(path)
#print(info)
#assert info.sample_rate == sample_rate
#print("info")
#print(info.num_frames)
#print("jiji")
#print(sample_rate*duration)
##assert info.num_frames == sample_rate * duration
#assert info.num_channels == num_channels
#assert info.bits_per_sample == 0
#assert info.encoding == "VORBIS"
#"""`soundfile_backend.info` can check ogg file correctly"""
#duration = 1
#num_frames = sample_rate * duration
##data = torch.randn(num_frames, num_channels).numpy()
#data = paddle.randn(shape=[num_frames, num_channels]).numpy()
#print(len(data))
#path = self.get_temp_path("data.ogg")
#soundfile.write(path, data, sample_rate)
#info = soundfile_backend.info(path)
#print(info)
#assert info.sample_rate == sample_rate
#print("info")
#print(info.num_frames)
#print("jiji")
#print(sample_rate*duration)
##assert info.num_frames == sample_rate * duration
#assert info.num_channels == num_channels
#assert info.bits_per_sample == 0
#assert info.encoding == "VORBIS"
@nested_params(
[8000, 16000],
[1, 2],
[("PCM_24", 24), ("PCM_32", 32)],
)
[("PCM_24", 24), ("PCM_32", 32)], )
@skipIfFormatNotSupported("NIST")
def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth):
"""`soundfile_backend.info` can check sph file correctly"""
......@@ -127,7 +126,8 @@ class TestInfo(TempDirMixin, unittest.TestCase):
with warnings.catch_warnings(record=True) as w:
info = soundfile_backend.info("foo")
assert len(w) == 1
assert "UNSEEN_SUBTYPE subtype is unknown to PaddleAudio" in str(w[-1].message)
assert "UNSEEN_SUBTYPE subtype is unknown to PaddleAudio" in str(
w[-1].message)
assert info.bits_per_sample == 0
......@@ -195,5 +195,6 @@ class TestFileObject(TempDirMixin, unittest.TestCase):
"""Query compressed audio via file-like object works"""
self._test_tarobj("flac", "PCM_16", 16)
if __name__ == '__main__':
unittest.main()
#this code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/backend/soundfile/load_test.py
import os
import tarfile
import unittest
from unittest.mock import patch
import numpy as np
from parameterized import parameterized
import numpy as np
import paddle
from paddleaudio._internal import module_utils as _mod_utils
import soundfile
from common import dtype2subtype
from common import parameterize
from common import skipIfFormatNotSupported
from paddleaudio.backends import soundfile_backend
from tests.backends.common import get_bits_per_sample, get_encoding
from tests.common_utils import (
get_wav_data,
load_wav,
nested_params,
normalize_wav,
save_wav,
TempDirMixin,
)
from common import dtype2subtype, parameterize, skipIfFormatNotSupported
from parameterized import parameterized
import soundfile
from tests.common_utils import get_wav_data
from tests.common_utils import load_wav
from tests.common_utils import normalize_wav
from tests.common_utils import save_wav
from tests.common_utils import TempDirMixin
def _get_mock_path(
ext: str,
dtype: str,
sample_rate: int,
num_channels: int,
num_frames: int,
):
ext: str,
dtype: str,
sample_rate: int,
num_channels: int,
num_frames: int, ):
return f"{dtype}_{sample_rate}_{num_channels}_{num_frames}.{ext}"
......@@ -87,9 +81,8 @@ class SoundFileMock:
self._params["num_channels"],
normalize=False,
num_frames=self._params["num_frames"],
channels_first=False,
).numpy()
return data[self._start : self._start + frames]
channels_first=False, ).numpy()
return data[self._start:self._start + frames]
def __enter__(self):
return self
......@@ -99,13 +92,17 @@ class SoundFileMock:
class MockedLoadTest(unittest.TestCase):
def assert_dtype(self, ext, dtype, sample_rate, num_channels, normalize, channels_first):
def assert_dtype(self, ext, dtype, sample_rate, num_channels, normalize,
channels_first):
"""When format is WAV or NIST, normalize=False will return the native dtype Tensor, otherwise float32"""
num_frames = 3 * sample_rate
path = _get_mock_path(ext, dtype, sample_rate, num_channels, num_frames)
expected_dtype = paddle.float32 if normalize or ext not in ["wav", "nist"] else getattr(paddle, dtype)
expected_dtype = paddle.float32 if normalize or ext not in [
"wav", "nist"
] else getattr(paddle, dtype)
with patch("soundfile.SoundFile", SoundFileMock):
found, sr = soundfile_backend.load(path, normalize=normalize, channels_first=channels_first)
found, sr = soundfile_backend.load(
path, normalize=normalize, channels_first=channels_first)
assert found.dtype == expected_dtype
assert sample_rate == sr
......@@ -114,44 +111,47 @@ class MockedLoadTest(unittest.TestCase):
[8000, 16000],
[1, 2],
[True, False],
[True, False],
)
def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
[True, False], )
def test_wav(self, dtype, sample_rate, num_channels, normalize,
channels_first):
"""Returns native dtype when normalize=False else float32"""
self.assert_dtype("wav", dtype, sample_rate, num_channels, normalize, channels_first)
self.assert_dtype("wav", dtype, sample_rate, num_channels, normalize,
channels_first)
@parameterize(
["int32"],
[8000, 16000],
[1, 2],
[True, False],
[True, False],
)
def test_sphere(self, dtype, sample_rate, num_channels, normalize, channels_first):
[True, False], )
def test_sphere(self, dtype, sample_rate, num_channels, normalize,
channels_first):
"""Returns float32 always"""
self.assert_dtype("sph", dtype, sample_rate, num_channels, normalize, channels_first)
self.assert_dtype("sph", dtype, sample_rate, num_channels, normalize,
channels_first)
@parameterize([8000, 16000], [1, 2], [True, False], [True, False])
def test_ogg(self, sample_rate, num_channels, normalize, channels_first):
"""Returns float32 always"""
self.assert_dtype("ogg", "int16", sample_rate, num_channels, normalize, channels_first)
self.assert_dtype("ogg", "int16", sample_rate, num_channels, normalize,
channels_first)
@parameterize([8000, 16000], [1, 2], [True, False], [True, False])
def test_flac(self, sample_rate, num_channels, normalize, channels_first):
"""`soundfile_backend.load` can load ogg format."""
self.assert_dtype("flac", "int16", sample_rate, num_channels, normalize, channels_first)
self.assert_dtype("flac", "int16", sample_rate, num_channels, normalize,
channels_first)
class LoadTestBase(TempDirMixin, unittest.TestCase):
def assert_wav(
self,
dtype,
sample_rate,
num_channels,
normalize,
channels_first=True,
duration=1,
):
self,
dtype,
sample_rate,
num_channels,
normalize,
channels_first=True,
duration=1, ):
"""`soundfile_backend.load` can load wav format correctly.
Wav data loaded with soundfile backend should match those with scipy
......@@ -163,22 +163,22 @@ class LoadTestBase(TempDirMixin, unittest.TestCase):
num_channels,
normalize=normalize,
num_frames=num_frames,
channels_first=channels_first,
)
channels_first=channels_first, )
save_wav(path, data, sample_rate, channels_first=channels_first)
expected = load_wav(path, normalize=normalize, channels_first=channels_first)[0]
data, sr = soundfile_backend.load(path, normalize=normalize, channels_first=channels_first)
expected = load_wav(
path, normalize=normalize, channels_first=channels_first)[0]
data, sr = soundfile_backend.load(
path, normalize=normalize, channels_first=channels_first)
assert sr == sample_rate
np.testing.assert_array_almost_equal(data.numpy(), expected.numpy())
def assert_sphere(
self,
dtype,
sample_rate,
num_channels,
channels_first=True,
duration=1,
):
self,
dtype,
sample_rate,
num_channels,
channels_first=True,
duration=1, ):
"""`soundfile_backend.load` can load SPHERE format correctly."""
path = self.get_temp_path("reference.sph")
num_frames = duration * sample_rate
......@@ -187,9 +187,9 @@ class LoadTestBase(TempDirMixin, unittest.TestCase):
num_channels,
num_frames=num_frames,
normalize=False,
channels_first=False,
)
soundfile.write(path, raw, sample_rate, subtype=dtype2subtype(dtype), format="NIST")
channels_first=False, )
soundfile.write(
path, raw, sample_rate, subtype=dtype2subtype(dtype), format="NIST")
expected = normalize_wav(raw.t() if channels_first else raw)
data, sr = soundfile_backend.load(path, channels_first=channels_first)
assert sr == sample_rate
......@@ -197,13 +197,12 @@ class LoadTestBase(TempDirMixin, unittest.TestCase):
np.testing.assert_array_almost_equal(data.numpy(), expected.numpy())
def assert_flac(
self,
dtype,
sample_rate,
num_channels,
channels_first=True,
duration=1,
):
self,
dtype,
sample_rate,
num_channels,
channels_first=True,
duration=1, ):
"""`soundfile_backend.load` can load FLAC format correctly."""
path = self.get_temp_path("reference.flac")
num_frames = duration * sample_rate
......@@ -212,15 +211,13 @@ class LoadTestBase(TempDirMixin, unittest.TestCase):
num_channels,
num_frames=num_frames,
normalize=False,
channels_first=False,
)
channels_first=False, )
soundfile.write(path, raw, sample_rate)
expected = normalize_wav(raw.t() if channels_first else raw)
data, sr = soundfile_backend.load(path, channels_first=channels_first)
assert sr == sample_rate
#self.assertEqual(data, expected, atol=1e-4, rtol=1e-8)
np.testing.assert_array_almost_equal(data.numpy(), expected.numpy())
class TestLoad(LoadTestBase):
......@@ -231,41 +228,43 @@ class TestLoad(LoadTestBase):
[8000, 16000],
[1, 2],
[False, True],
[False, True],
)
def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
[False, True], )
def test_wav(self, dtype, sample_rate, num_channels, normalize,
channels_first):
"""`soundfile_backend.load` can load wav format correctly."""
self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first)
self.assert_wav(dtype, sample_rate, num_channels, normalize,
channels_first)
@parameterize(
["int32"],
[16000],
[2],
[False],
)
[False], )
def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
"""`soundfile_backend.load` can load large wav file correctly."""
two_hours = 2 * 60 * 60
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=two_hours)
self.assert_wav(
dtype, sample_rate, num_channels, normalize, duration=two_hours)
@parameterize(["float32", "int32"], [4, 8, 16, 32], [False, True])
def test_multiple_channels(self, dtype, num_channels, channels_first):
"""`soundfile_backend.load` can load wav file with more than 2 channels."""
sample_rate = 8000
normalize = False
self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first)
self.assert_wav(dtype, sample_rate, num_channels, normalize,
channels_first)
#@parameterize(["int32"], [8000, 16000], [1, 2], [False, True])
#@skipIfFormatNotSupported("NIST")
#def test_sphere(self, dtype, sample_rate, num_channels, channels_first):
#"""`soundfile_backend.load` can load sphere format correctly."""
#self.assert_sphere(dtype, sample_rate, num_channels, channels_first)
#"""`soundfile_backend.load` can load sphere format correctly."""
#self.assert_sphere(dtype, sample_rate, num_channels, channels_first)
#@parameterize(["int32"], [8000, 16000], [1, 2], [False, True])
#@skipIfFormatNotSupported("FLAC")
#def test_flac(self, dtype, sample_rate, num_channels, channels_first):
#"""`soundfile_backend.load` can load flac format correctly."""
#self.assert_flac(dtype, sample_rate, num_channels, channels_first)
#"""`soundfile_backend.load` can load flac format correctly."""
#self.assert_flac(dtype, sample_rate, num_channels, channels_first)
class TestLoadFormat(TempDirMixin, unittest.TestCase):
......@@ -291,21 +290,17 @@ class TestLoadFormat(TempDirMixin, unittest.TestCase):
#self.assertEqual(found, expected)
np.testing.assert_array_almost_equal(found, expected)
@parameterized.expand(
[
("WAV",),
("wav",),
]
)
@parameterized.expand([
("WAV", ),
("wav", ),
])
def test_wav(self, format_):
self._test_format(format_)
@parameterized.expand(
[
("FLAC",),
("flac",),
]
)
@parameterized.expand([
("FLAC", ),
("flac", ),
])
@skipIfFormatNotSupported("FLAC")
def test_flac(self, format_):
self._test_format(format_)
......@@ -356,7 +351,6 @@ class TestFileObject(TempDirMixin, unittest.TestCase):
#self.assertEqual(expected, found)
np.testing.assert_array_almost_equal(found.numpy(), expected)
def test_tarfile_wav(self):
"""Loading audio via file-like object works"""
self._test_tarfile("wav")
......@@ -365,5 +359,6 @@ class TestFileObject(TempDirMixin, unittest.TestCase):
"""Loading audio via file-like object works"""
self._test_tarfile("flac")
if __name__ == '__main__':
unittest.main()
......@@ -2,23 +2,18 @@ import io
import unittest
from unittest.mock import patch
from paddleaudio._internal import module_utils as _mod_utils
from paddleaudio.backends import soundfile_backend
from tests.common_utils import (
get_wav_data,
load_wav,
nested_params,
normalize_wav,
save_wav,
TempDirMixin,
)
from common import fetch_wav_subtype, parameterize, skipIfFormatNotSupported
import paddle
import numpy as np
import paddle
import soundfile
from common import fetch_wav_subtype
from common import parameterize
from common import skipIfFormatNotSupported
from paddleaudio.backends import soundfile_backend
from tests.common_utils import get_wav_data
from tests.common_utils import load_wav
from tests.common_utils import nested_params
from tests.common_utils import TempDirMixin
class MockedSaveTest(unittest.TestCase):
......@@ -41,10 +36,10 @@ class MockedSaveTest(unittest.TestCase):
("ULAW", 8),
("ALAW", None),
("ALAW", 8),
],
)
], )
@patch("soundfile.write")
def test_wav(self, dtype, sample_rate, num_channels, channels_first, enc_params, mocked_write):
def test_wav(self, dtype, sample_rate, num_channels, channels_first,
enc_params, mocked_write):
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
filepath = "foo.wav"
input_tensor = get_wav_data(
......@@ -52,8 +47,7 @@ class MockedSaveTest(unittest.TestCase):
num_channels,
num_frames=3 * sample_rate,
normalize=dtype == "float32",
channels_first=channels_first,
)
channels_first=channels_first, )
input_tensor = paddle.transpose(input_tensor, [1, 0])
encoding, bits_per_sample = enc_params
......@@ -63,33 +57,32 @@ class MockedSaveTest(unittest.TestCase):
sample_rate,
channels_first=channels_first,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
bits_per_sample=bits_per_sample, )
# on +Py3.8 call_args.kwargs is more descreptive
args = mocked_write.call_args[1]
assert args["file"] == filepath
assert args["samplerate"] == sample_rate
assert args["subtype"] == fetch_wav_subtype(dtype, encoding, bits_per_sample)
assert args["subtype"] == fetch_wav_subtype(dtype, encoding,
bits_per_sample)
assert args["format"] is None
tensor_result = paddle.transpose(input_tensor, [1, 0]) if channels_first else input_tensor
tensor_result = paddle.transpose(
input_tensor, [1, 0]) if channels_first else input_tensor
#self.assertEqual(args["data"], tensor_result.numpy())
np.testing.assert_array_almost_equal(args["data"].numpy(), tensor_result.numpy())
np.testing.assert_array_almost_equal(args["data"].numpy(),
tensor_result.numpy())
@patch("soundfile.write")
def assert_non_wav(
self,
fmt,
dtype,
sample_rate,
num_channels,
channels_first,
mocked_write,
encoding=None,
bits_per_sample=None,
):
self,
fmt,
dtype,
sample_rate,
num_channels,
channels_first,
mocked_write,
encoding=None,
bits_per_sample=None, ):
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
filepath = f"foo.{fmt}"
input_tensor = get_wav_data(
......@@ -97,11 +90,11 @@ class MockedSaveTest(unittest.TestCase):
num_channels,
num_frames=3 * sample_rate,
normalize=False,
channels_first=channels_first,
)
channels_first=channels_first, )
input_tensor = paddle.transpose(input_tensor, [1, 0])
expected_data = paddle.transpose(input_tensor, [1, 0]) if channels_first else input_tensor
expected_data = paddle.transpose(
input_tensor, [1, 0]) if channels_first else input_tensor
soundfile_backend.save(
filepath,
......@@ -109,8 +102,7 @@ class MockedSaveTest(unittest.TestCase):
sample_rate,
channels_first,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
bits_per_sample=bits_per_sample, )
# on +Py3.8 call_args.kwargs is more descreptive
args = mocked_write.call_args[1]
......@@ -120,7 +112,8 @@ class MockedSaveTest(unittest.TestCase):
assert args["format"] == "NIST"
else:
assert args["format"] is None
np.testing.assert_array_almost_equal(args["data"].numpy(), expected_data.numpy())
np.testing.assert_array_almost_equal(args["data"].numpy(),
expected_data.numpy())
#self.assertEqual(args["data"], expected_data)
@nested_params(
......@@ -139,45 +132,57 @@ class MockedSaveTest(unittest.TestCase):
("ALAW", 16),
("ALAW", 24),
("ALAW", 32),
],
)
def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params):
], )
def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first,
enc_params):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
encoding, bits_per_sample = enc_params
self.assert_non_wav(
fmt, dtype, sample_rate, num_channels, channels_first, encoding=encoding, bits_per_sample=bits_per_sample
)
fmt,
dtype,
sample_rate,
num_channels,
channels_first,
encoding=encoding,
bits_per_sample=bits_per_sample)
@parameterize(
["int32"],
[8000, 16000],
[1, 2],
[False, True],
[8, 16, 24],
)
def test_flac(self, dtype, sample_rate, num_channels, channels_first, bits_per_sample):
[8, 16, 24], )
def test_flac(self, dtype, sample_rate, num_channels, channels_first,
bits_per_sample):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
self.assert_non_wav("flac", dtype, sample_rate, num_channels, channels_first, bits_per_sample=bits_per_sample)
self.assert_non_wav(
"flac",
dtype,
sample_rate,
num_channels,
channels_first,
bits_per_sample=bits_per_sample)
@parameterize(
["int32"],
[8000, 16000],
[1, 2],
[False, True],
)
[False, True], )
def test_ogg(self, dtype, sample_rate, num_channels, channels_first):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
self.assert_non_wav("ogg", dtype, sample_rate, num_channels, channels_first)
self.assert_non_wav("ogg", dtype, sample_rate, num_channels,
channels_first)
class SaveTestBase(TempDirMixin, unittest.TestCase):
def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
"""`soundfile_backend.save` can save wav format."""
path = self.get_temp_path("data.wav")
expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False)
expected = get_wav_data(
dtype, num_channels, num_frames=num_frames, normalize=False)
soundfile_backend.save(path, expected, sample_rate)
found, sr = load_wav(path, normalize=False)
assert sample_rate == sr
......@@ -192,7 +197,8 @@ class SaveTestBase(TempDirMixin, unittest.TestCase):
"""
num_frames = sample_rate * 3
path = self.get_temp_path(f"data.{fmt}")
expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False)
expected = get_wav_data(
dtype, num_channels, num_frames=num_frames, normalize=False)
soundfile_backend.save(path, expected, sample_rate)
sinfo = soundfile.info(path)
assert sinfo.format == fmt.upper()
......@@ -220,16 +226,14 @@ class TestSave(SaveTestBase):
@parameterize(
["float32", "int32"],
[8000, 16000],
[1, 2],
)
[1, 2], )
def test_wav(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save wav format."""
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
@parameterize(
["float32", "int32"],
[4, 8, 16, 32],
)
[4, 8, 16, 32], )
def test_multiple_channels(self, dtype, num_channels):
"""`soundfile_backend.save` can save wav with more than 2 channels."""
sample_rate = 8000
......@@ -238,8 +242,7 @@ class TestSave(SaveTestBase):
@parameterize(
["int32"],
[8000, 16000],
[1, 2],
)
[1, 2], )
@skipIfFormatNotSupported("NIST")
def test_sphere(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save sph format."""
......@@ -247,8 +250,7 @@ class TestSave(SaveTestBase):
@parameterize(
[8000, 16000],
[1, 2],
)
[1, 2], )
@skipIfFormatNotSupported("FLAC")
def test_flac(self, sample_rate, num_channels):
"""`soundfile_backend.save` can save flac format."""
......@@ -256,8 +258,7 @@ class TestSave(SaveTestBase):
@parameterize(
[8000, 16000],
[1, 2],
)
[1, 2], )
@skipIfFormatNotSupported("OGG")
def test_ogg(self, sample_rate, num_channels):
"""`soundfile_backend.save` can save ogg/vorbis format."""
......@@ -318,5 +319,6 @@ class TestFileObject(TempDirMixin, unittest.TestCase):
"""Saving audio via file-like object works"""
self._test_fileobj("OGG")
if __name__ == '__main__':
unittest.main()
from .wav_utils import get_wav_data, load_wav, save_wav, normalize_wav
from .parameterized_utils import nested_params
from .case_utils import (
TempDirMixin,
name_func
)
from .case_utils import name_func
from .case_utils import TempDirMixin
from .parameterized_utils import nested_params
from .wav_utils import get_wav_data
from .wav_utils import load_wav
from .wav_utils import normalize_wav
from .wav_utils import save_wav
__all__ = [
"get_wav_data",
"load_wav",
"save_wav",
"normalize_wav",
"get_sinusoid",
"name_func",
"nested_params",
"TempDirMixin"
"get_wav_data", "load_wav", "save_wav", "normalize_wav", "get_sinusoid",
"name_func", "nested_params", "TempDirMixin"
]
from typing import Optional
import scipy.io.wavfile
import paddle
import numpy as np
import scipy.io.wavfile
def normalize_wav(tensor: paddle.Tensor) -> paddle.Tensor:
if tensor.dtype == paddle.float32:
......@@ -23,13 +23,12 @@ def normalize_wav(tensor: paddle.Tensor) -> paddle.Tensor:
def get_wav_data(
dtype: str,
num_channels: int,
*,
num_frames: Optional[int] = None,
normalize: bool = True,
channels_first: bool = True,
):
dtype: str,
num_channels: int,
*,
num_frames: Optional[int]=None,
normalize: bool=True,
channels_first: bool=True, ):
"""Generate linear signal of the given dtype and num_channels
Data range is
......@@ -53,25 +52,26 @@ def get_wav_data(
# paddle linspace not support uint8, int8, int16
#if dtype == "uint8":
# base = paddle.linspace(0, 255, num_frames, dtype=dtype_)
#dtype_np = getattr(np, dtype)
#base_np = np.linspace(0, 255, num_frames, dtype_np)
#base = paddle.to_tensor(base_np, dtype=dtype_)
#dtype_np = getattr(np, dtype)
#base_np = np.linspace(0, 255, num_frames, dtype_np)
#base = paddle.to_tensor(base_np, dtype=dtype_)
#elif dtype == "int8":
# base = paddle.linspace(-128, 127, num_frames, dtype=dtype_)
#dtype_np = getattr(np, dtype)
#base_np = np.linspace(-128, 127, num_frames, dtype_np)
#base = paddle.to_tensor(base_np, dtype=dtype_)
#dtype_np = getattr(np, dtype)
#base_np = np.linspace(-128, 127, num_frames, dtype_np)
#base = paddle.to_tensor(base_np, dtype=dtype_)
if dtype == "float32":
base = paddle.linspace(-1.0, 1.0, num_frames, dtype=dtype_)
elif dtype == "float64":
base = paddle.linspace(-1.0, 1.0, num_frames, dtype=dtype_)
elif dtype == "int32":
base = paddle.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_)
base = paddle.linspace(
-2147483648, 2147483647, num_frames, dtype=dtype_)
#elif dtype == "int16":
# base = paddle.linspace(-32768, 32767, num_frames, dtype=dtype_)
#dtype_np = getattr(np, dtype)
#base_np = np.linspace(-32768, 32767, num_frames, dtype_np)
#base = paddle.to_tensor(base_np, dtype=dtype_)
#dtype_np = getattr(np, dtype)
#base_np = np.linspace(-32768, 32767, num_frames, dtype_np)
#base = paddle.to_tensor(base_np, dtype=dtype_)
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")
data = base.tile([num_channels, 1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册