Commit d2641184 authored by YangZhou

add save test && fix effects_chain bug

Parent 59d82c0c
from pathlib import Path
from typing import Callable
from typing import Optional
from typing import Tuple
from typing import Union
import paddle
from paddle import Tensor
from .common import AudioMetaData
import os
from paddlespeech.audio._internal import module_utils as _mod_utils
from paddlespeech.audio import _paddleaudio as paddleaudio
......@@ -48,31 +48,53 @@ def load(
normalize: bool = True,
channels_first: bool = True,
format: Optional[str]=None, ) -> Tuple[Tensor, int]:
if hasattr(filepath, "read"):
ret = paddleaudio.load_audio_fileobj(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
if ret is not None:
audio_tensor = paddle.to_tensor(ret[0])
return (audio_tensor, ret[1])
return _fallback_load_fileobj(filepath, frame_offset, num_frames, normalize, channels_first, format)
filepath = os.fspath(filepath)
ret = paddleaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
if ret is not None:
audio_tensor = paddle.to_tensor(ret[0])
return (audio_tensor, ret[1])
return _fallback_load(filepath, frame_offset, num_frames, normalize, channels_first, format)
@_mod_utils.requires_sox()
def save(filepath: str,
src: Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
src_arr = src.numpy()
if hasattr(filepath, "write"):
paddleaudio.save_audio_fileobj(
filepath, src_arr, sample_rate, channels_first, compression, format, encoding, bits_per_sample
)
return
filepath = os.fspath(filepath)
paddleaudio.sox_io_save_audio_file(
filepath, src_arr, sample_rate, channels_first, compression, format, encoding, bits_per_sample
)
@_mod_utils.requires_sox()
def info(filepath: str, format: Optional[str]) -> AudioMetaData:
if hasattr(filepath, "read"):
sinfo = paddleaudio.get_info_fileobj(filepath, format)
if sinfo is not None:
return AudioMetaData(*sinfo)
return _fallback_info_fileobj(filepath, format)
filepath = os.fspath(filepath)
sinfo = paddleaudio.get_info_file(filepath, format)
if sinfo is not None:
return AudioMetaData(*sinfo)
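For context, a minimal usage sketch of the three entry points above; the file names are placeholders and assume a PCM wav already exists on disk:

from paddlespeech.audio.backends import sox_io_backend

# By default the waveform comes back as a (channels, frames) float32 tensor.
waveform, sample_rate = sox_io_backend.load("example.wav")
meta = sox_io_backend.info("example.wav", None)  # AudioMetaData
print(meta.sample_rate, meta.num_channels, meta.bits_per_sample)

# Round-trip: write the tensor back out as 16-bit signed PCM wav.
sox_io_backend.save("copy.wav", waveform, sample_rate,
                    encoding="PCM_S", bits_per_sample=16)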
......
......@@ -21,7 +21,7 @@ PYBIND11_MODULE(_paddleaudio, m) {
&paddleaudio::sox_io::get_info_file,
"Get metadata of audio file.");
// support obj later
m.def("get_info_fileobj",
&paddleaudio::sox_io::get_info_fileobj,
"Get metadata of audio in file object.");
m.def("load_audio_fileobj",
......@@ -30,7 +30,7 @@ PYBIND11_MODULE(_paddleaudio, m) {
m.def("save_audio_fileobj",
&paddleaudio::sox_io::save_audio_fileobj,
"Save audio to file obj.");
// sox io
m.def("sox_io_get_info", &paddleaudio::sox_io::get_info_file);
m.def(
......
#include <sox.h>
#include <iostream>
#include <vector>
#include "paddlespeech/audio/src/pybind/sox/effects_chain.h"
#include "paddlespeech/audio/src/pybind/sox/utils.h"
......@@ -42,6 +43,7 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
if (index + *osamp > num_samples) {
*osamp = num_samples - index;
}
// Ensure that it's a multiple of the number of channels
*osamp -= *osamp % num_channels;
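// Example: with num_channels == 2 and *osamp == 7, the line above trims
// *osamp to 6 so that only whole frames are passed down the effects chain.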
......@@ -49,52 +51,80 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// refactor this module, chunk
auto i_frame = index / num_channels;
auto num_frames = *osamp / num_channels;
std::vector<int> chunk(num_frames*num_channels);
// Convert to sox_sample_t (int32_t)
switch (tensor.dtype().num()) {
//case c10::ScalarType::Float: {
case 11: {
// Need to convert to 64-bit precision so that
// values around INT32_MIN/MAX are handled correctly.
for (int idx = 0; idx < chunk.size(); ++idx) {
int frame_idx = (idx + index) / num_channels;
int channels_idx = (idx + index) % num_channels;
double elem = 0;
if (priv->channels_first) {
elem = *(float*)tensor.data(channels_idx, frame_idx);
} else {
elem = *(float*)tensor.data(frame_idx, channels_idx);
}
elem = elem * 2147483648.;
if (elem > INT32_MAX) {
chunk[idx] = INT32_MAX;
} else if (elem < INT32_MIN) {
chunk[idx] = INT32_MIN;
} else {
chunk[idx] = elem;
}
}
break;
}
//case c10::ScalarType::Int: {
case 5: {
for (int idx = 0; idx < chunk.size(); ++idx) {
int frame_idx = (idx + index) / num_channels;
int channels_idx = (idx + index) % num_channels;
int elem = 0;
if (priv->channels_first) {
elem = *(int*)tensor.data(channels_idx, frame_idx);
} else {
elem = *(int*)tensor.data(frame_idx, channels_idx);
}
chunk[idx] = elem;
}
break;
}
// case short
case 3: {
for (int idx = 0; idx < chunk.size(); ++idx) {
int frame_idx = (idx + index) / num_channels;
int channels_idx = (idx + index) % num_channels;
int16_t elem = 0;
if (priv->channels_first) {
elem = *(int16_t*)tensor.data(channels_idx, frame_idx);
} else {
elem = *(int16_t*)tensor.data(frame_idx, channels_idx);
}
chunk[idx] = elem * 65536;
}
break;
}
// case byte
case 1: {
for (int idx = 0; idx < chunk.size(); ++idx) {
int frame_idx = (idx + index) / num_channels;
int channels_idx = (idx + index) % num_channels;
int8_t elem = 0;
if (priv->channels_first) {
elem = *(int8_t*)tensor.data(channels_idx, frame_idx);
} else {
elem = *(int8_t*)tensor.data(frame_idx, channels_idx);
}
chunk[idx] = (elem - 128) * 16777216;
}
break;
}
......@@ -102,7 +132,7 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
throw std::runtime_error("Unexpected dtype.");
}
// Write to buffer
memcpy(obuf, chunk.data(), *osamp * 4);
priv->index += *osamp;
return (priv->index == num_samples) ? SOX_EOF : SOX_SUCCESS;
}
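For reference, the dtype cases above apply the usual PCM-to-int32 scaling that sox_sample_t expects. A standalone NumPy sketch of the same arithmetic (not part of the commit; the byte case is written for unsigned 8-bit PCM, as exercised by the uint8 wav test below):

import numpy as np

INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1

def to_sox_int32(x):
    """Scale one chunk of samples to the sox_sample_t (int32) range."""
    if x.dtype == np.float32:
        # float in [-1.0, 1.0]: scale by 2**31 in double precision, then clamp.
        return np.clip(x.astype(np.float64) * 2147483648.0, INT32_MIN, INT32_MAX).astype(np.int64)
    if x.dtype == np.int32:
        return x.astype(np.int64)                      # already full range
    if x.dtype == np.int16:
        return x.astype(np.int64) * 65536              # shift into the upper 16 bits
    if x.dtype == np.uint8:
        return (x.astype(np.int64) - 128) * 16777216   # center, then shift into the upper 8 bits
    raise ValueError("Unexpected dtype.")

print(to_sox_int32(np.array([-1.0, 0.5, 1.0], dtype=np.float32)))   # [-2147483648, 1073741824, 2147483647]
print(to_sox_int32(np.array([-32768, 32767], dtype=np.int16)))      # [-2147483648, 2147418112]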
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import unittest
import numpy as np
import paddle
from parameterized import parameterized
from paddlespeech.audio.backends import sox_io_backend
class TestInfo(unittest.TestCase):
def test_wav(self, dtype, sample_rate, num_channels, sample_size):
"""check wav file correctly """
path = 'testdata/test.wav'
info = sox_io_backend.get_info_file(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_size # duration*sample_rate
assert info.num_channels == num_channels
assert info.bits_per_sample == get_bit_depth(dtype)
assert info.encoding == get_encoding('wav', dtype)
from tests.unit.common_utils import (
get_wav_data,
load_wav,
save_wav,
nested_params,
TempDirMixin,
sox_utils
)
# code is from: https://github.com/pytorch/audio/blob/main/torchaudio/test/torchaudio_unittest/backend/sox_io/save_test.py
def _get_sox_encoding(encoding):
encodings = {
"PCM_F": "floating-point",
"PCM_S": "signed-integer",
"PCM_U": "unsigned-integer",
"ULAW": "u-law",
"ALAW": "a-law",
}
return encodings.get(encoding)
class TestSaveBase(TempDirMixin):
def assert_save_consistency(
self,
format: str,
*,
compression: float = None,
encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
num_channels: int = 2,
num_frames: float = 3 * 8000,
src_dtype: str = "int32",
test_mode: str = "path",
):
"""`save` function produces file that is comparable with `sox` command
To compare that the file produced by `save` function agains the file produced by
the equivalent `sox` command, we need to load both files.
But there are many formats that cannot be opened with common Python modules (like
SciPy).
So we use `sox` command to prepare the original data and convert the saved files
into a format that SciPy can read (PCM wav).
The following diagram illustrates this process. The difference is 2.1. and 3.1.
This assumes that
- loading data with SciPy preserves the data well.
- converting the resulting files into WAV format with `sox` preserve the data well.
                      x
                      | 1. Generate source wav file with SciPy
                      |
                      v
    ----------------- wav -----------------
    |                                     |
    | 2.1. load with scipy                | 3.1. Convert to the target
    |      then save it into the target   |      format depth with sox
    |      format with sox_io_backend     |
    v                                     v
  target format                     target format
    |                                     |
    | 2.2. Convert to wav with sox        | 3.2. Convert to wav with sox
    |                                     |
    v                                     v
   wav                                   wav
    |                                     |
    | 2.3. load with scipy                | 3.3. load with scipy
    |                                     |
    v                                     v
  tensor ----------> compare <---------- tensor
"""
cmp_encoding = "floating-point"
cmp_bit_depth = 32
src_path = self.get_temp_path("1.source.wav")
tgt_path = self.get_temp_path(f"2.1.torchaudio.{format}")
tst_path = self.get_temp_path("2.2.result.wav")
sox_path = self.get_temp_path(f"3.1.sox.{format}")
ref_path = self.get_temp_path("3.2.ref.wav")
# 1. Generate original wav
data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to target format with torchaudio
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
sox_io_backend.save(
tgt_path, data, sample_rate, compression=compression, encoding=encoding, bits_per_sample=bits_per_sample
)
elif test_mode == "fileobj":
with open(tgt_path, "bw") as file_:
sox_io_backend.save(
file_,
data,
sample_rate,
format=format,
compression=compression,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
elif test_mode == "bytesio":
file_ = io.BytesIO()
sox_io_backend.save(
file_,
data,
sample_rate,
format=format,
compression=compression,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
file_.seek(0)
with open(tgt_path, "bw") as f:
f.write(file_.read())
else:
raise ValueError(f"Unexpected test mode: {test_mode}")
# 2.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 2.3. Load with SciPy
found = load_wav(tst_path, normalize=False)[0]
# 3.1. Convert the original wav to target format with sox
sox_encoding = _get_sox_encoding(encoding)
sox_utils.convert_audio_file(
src_path, sox_path, compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample
)
# 3.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 3.3. Load with SciPy
expected = load_wav(ref_path, normalize=False)[0]
np.testing.assert_array_almost_equal(found, expected)
class TestSave(TestSaveBase, unittest.TestCase):
@nested_params(
["path",],
[
("PCM_U", 8),
("PCM_S", 16),
("PCM_S", 32),
("PCM_F", 32),
("PCM_F", 64),
("ULAW", 8),
("ALAW", 8),
],
)
def test_save_wav(self, test_mode, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency("wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params(
["path", ],
[
("float32",),
("int32",),
("int16",),
("uint8",),
],
)
def test_save_wav_dtype(self, test_mode, params):
(dtype,) = params
self.assert_save_consistency("wav", src_dtype=dtype, test_mode=test_mode)
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
from .wav_utils import get_wav_data, load_wav, save_wav, normalize_wav
from .parameterized_utils import load_params, nested_params
from .case_utils import (
TempDirMixin
)
__all__ = [
"get_wav_data",
"load_wav",
"save_wav",
"normalize_wav"
"normalize_wav",
"load_params",
"nested_params",
]
import functools
import os.path
import shutil
import subprocess
import sys
import tempfile
import time
import unittest
import paddle
from paddlespeech.audio._internal.module_utils import (
is_kaldi_available,
is_module_available,
is_sox_available,
)
class TempDirMixin:
"""Mixin to provide easy access to temp dir"""
temp_dir_ = None
@classmethod
def get_base_temp_dir(cls):
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of the default temporary directory.
# This is handy for debugging.
key = "TORCHAUDIO_TEST_TEMP_DIR"
if key in os.environ:
return os.environ[key]
if cls.temp_dir_ is None:
cls.temp_dir_ = tempfile.TemporaryDirectory()
return cls.temp_dir_.name
@classmethod
def tearDownClass(cls):
if cls.temp_dir_ is not None:
try:
cls.temp_dir_.cleanup()
cls.temp_dir_ = None
except PermissionError:
# On Windows there is a known issue with `shutil.rmtree`,
# which fails intermittently.
#
# https://github.com/python/cpython/issues/74168
#
# We observed this on CircleCI, where Windows job raises
# PermissionError.
#
# Following the above thread, we ignore it.
pass
super().tearDownClass()
def get_temp_path(self, *paths):
temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
path = os.path.join(temp_dir, *paths)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
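A minimal sketch of how the mixin is consumed, mirroring TestSaveBase in the save test above (the class name is illustrative):

import os
import unittest

class MyAudioTest(TempDirMixin, unittest.TestCase):
    def test_roundtrip(self):
        # get_temp_path() keys the directory by self.id(), so tests do not collide.
        path = self.get_temp_path("audio", "sample.wav")
        with open(path, "wb") as f:
            f.write(b"\x00" * 16)  # placeholder payload
        self.assertTrue(os.path.exists(path))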
import json
from itertools import product
from parameterized import param, parameterized
def get_asset_path(*paths):
"""Return full path of a test asset"""
return os.path.join(_TEST_DIR_PATH, "assets", *paths)
def load_params(*paths):
with open(get_asset_path(*paths), "r") as file:
return [param(json.loads(line)) for line in file]
def _name_func(func, _, params):
strs = []
for arg in params.args:
if isinstance(arg, tuple):
strs.append("_".join(str(a) for a in arg))
else:
strs.append(str(arg))
# sanitize the test name
name = "_".join(strs)
return parameterized.to_safe_name(f"{func.__name__}_{name}")
def nested_params(*params_set, name_func=_name_func):
"""Generate the cartesian product of the given list of parameters.
Args:
params_set (list of parameters): Parameters. When using the ``parameterized.param`` class,
all the parameters have to be specified with that class, using keyword arguments only.
"""
flatten = [p for params in params_set for p in params]
# Parameters to be nested are given as list of plain objects
if all(not isinstance(p, param) for p in flatten):
args = list(product(*params_set))
return parameterized.expand(args, name_func=name_func)
# Parameters to be nested are given as list of `parameterized.param`
if not all(isinstance(p, param) for p in flatten):
raise TypeError("When using ``parameterized.param``, " "all the parameters have to be of the ``param`` type.")
if any(p.args for p in flatten):
raise ValueError(
"When using ``parameterized.param``, " "all the parameters have to be provided as keyword argument."
)
args = [param()]
for params in params_set:
args = [param(**x.kwargs, **y.kwargs) for x in args for y in params]
return parameterized.expand(args)
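A sketch of how the decorator expands, mirroring TestSave above; each argument list is one axis of the cartesian product:

import unittest

class Example(unittest.TestCase):
    @nested_params(
        ["path", "fileobj"],                  # test_mode
        [("PCM_S", 16), ("PCM_F", 32)],       # (encoding, bits_per_sample)
    )
    def test_combo(self, test_mode, enc_params):
        # Expands to 2 x 2 = 4 test cases.
        encoding, bits_per_sample = enc_params
        self.assertIn(encoding, ("PCM_S", "PCM_F"))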
import subprocess
import sys
import warnings
def get_encoding(dtype):
encodings = {
"float32": "floating-point",
"int32": "signed-integer",
"int16": "signed-integer",
"uint8": "unsigned-integer",
}
return encodings[dtype]
def get_bit_depth(dtype):
bit_depths = {
"float32": 32,
"int32": 32,
"int16": 16,
"uint8": 8,
}
return bit_depths[dtype]
def gen_audio_file(
path,
sample_rate,
num_channels,
*,
encoding=None,
bit_depth=None,
compression=None,
attenuation=None,
duration=1,
comment_file=None,
):
"""Generate synthetic audio file with `sox` command."""
if path.endswith(".wav"):
warnings.warn("Use get_wav_data and save_wav to generate wav files for accurate results.")
command = [
"sox",
"-V3", # verbose
"--no-dither", # disable automatic dithering
"-R",
# -R is supposed to make the run repeatable, though the implementation looks suspicious:
# it does not set the seed to a fixed value.
# https://fossies.org/dox/sox-14.4.2/sox_8c_source.html
# search "sox_globals.repeatable"
]
if bit_depth is not None:
command += ["--bits", str(bit_depth)]
command += [
"--rate",
str(sample_rate),
"--null", # no input
"--channels",
str(num_channels),
]
if compression is not None:
command += ["--compression", str(compression)]
if bit_depth is not None:
command += ["--bits", str(bit_depth)]
if encoding is not None:
command += ["--encoding", str(encoding)]
if comment_file is not None:
command += ["--comment-file", str(comment_file)]
command += [
str(path),
"synth",
str(duration), # synthesizes for the given duration [sec]
"sawtooth",
"1",
# sawtooth covers both ends of the value range, which is a good property for tests.
# similar to linspace(-1., 1.)
# this introduces a bigger boundary effect than sine when converted to mp3
]
if attenuation is not None:
command += ["vol", f"-{attenuation}dB"]
print(" ".join(command), file=sys.stderr)
subprocess.run(command, check=True)
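A hedged usage example for the generator above; it only shells out to `sox`, so `sox` must be on PATH, and the output path here is made up:

# Two seconds of a stereo 8 kHz sawtooth, written as 16-bit FLAC and attenuated by 10 dB.
gen_audio_file(
    "/tmp/synth.flac",
    sample_rate=8000,
    num_channels=2,
    bit_depth=16,
    attenuation=10,
    duration=2,
)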
def convert_audio_file(src_path, dst_path, *, encoding=None, bit_depth=None, compression=None):
"""Convert audio file with `sox` command."""
command = ["sox", "-V3", "--no-dither", "-R", str(src_path)]
if encoding is not None:
command += ["--encoding", str(encoding)]
if bit_depth is not None:
command += ["--bits", str(bit_depth)]
if compression is not None:
command += ["--compression", str(compression)]
command += [dst_path]
print(" ".join(command), file=sys.stderr)
subprocess.run(command, check=True)
def _flattern(effects):
if not effects:
return effects
if isinstance(effects[0], str):
return effects
return [item for sublist in effects for item in sublist]
def run_sox_effect(input_file, output_file, effect, *, output_sample_rate=None, output_bitdepth=None):
"""Run sox effects"""
effect = _flattern(effect)
command = ["sox", "-V", "--no-dither", input_file]
if output_bitdepth:
command += ["--bits", str(output_bitdepth)]
command += [output_file] + effect
if output_sample_rate:
command += ["rate", str(output_sample_rate)]
print(" ".join(command))
subprocess.run(command, check=True)
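And a sketch of the effect runner; `vol` and `rate` are standard sox effects, and the paths are placeholders:

# Halve the volume and resample to 8 kHz, writing 16-bit output.
run_sox_effect(
    "/tmp/in.wav",
    "/tmp/out.wav",
    ["vol", "0.5"],
    output_sample_rate=8000,
    output_bitdepth=16,
)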