未验证 提交 a3911ab5 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #2089 from zh794390558/cpplint

[audio] format code
......@@ -5,7 +5,7 @@ repos:
- id: yapf
files: \.py$
exclude: (?=third_party).*(\.py)$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
......@@ -76,4 +76,4 @@ repos:
entry: bash .pre-commit-hooks/cpplint.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx)$
exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
\ No newline at end of file
exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|paddlespeech/audio/src/optional).*(\.cpp|\.cc|\.h|\.hpp)$
\ No newline at end of file
#include "pybind/sox/io.h"
// Python extension module `_paddleaudio`.
//
// Exposes the libsox-backed metadata readers from paddleaudio::sox_io.
// Each function is registered exactly once; registering the same name twice
// would only add a redundant, identical overload.
PYBIND11_MODULE(_paddleaudio, m) {
    m.def("get_info_file",
          &paddleaudio::sox_io::get_info_file,
          "Get metadata of audio file.");
    m.def("get_info_fileobj",
          &paddleaudio::sox_io::get_info_fileobj,
          "Get metadata of audio in file object.");
}
\ No newline at end of file
......@@ -8,51 +8,54 @@ namespace sox_io {
/// Opens the audio file at `path` with libsox and reports its metadata.
///
/// @param path    Path handed to sox_open_read.
/// @param format  File type hint; when empty, nullptr is passed as the
///                filetype argument.
/// @return Tuple of (sample rate, signal length divided by channel count,
///         channel count, bits per sample, encoding name from get_encoding).
///
/// The handle is checked with validate_input_file before any field is read.
auto get_info_file(const std::string &path, const std::string &format)
    -> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> {
    const char *filetype = format.empty() ? nullptr : format.data();
    SoxFormat sf(sox_open_read(path.data(),
                               /*signal=*/nullptr,
                               /*encoding=*/nullptr,
                               /*filetype=*/filetype));

    validate_input_file(sf, path);

    return std::make_tuple(
        static_cast<int64_t>(sf->signal.rate),
        static_cast<int64_t>(sf->signal.length / sf->signal.channels),
        static_cast<int64_t>(sf->signal.channels),
        static_cast<int64_t>(sf->encoding.bits_per_sample),
        get_encoding(sf->encoding.encoding));
}
/// Reads the head of a Python file-like object and reports audio metadata.
///
/// @param fileobj  Python object whose `read` attribute is called by
///                 read_fileobj.
/// @param format   File type hint; when empty, nullptr is passed as the
///                 filetype argument.
/// @return Tuple of (sample rate, signal length divided by channel count,
///         channel count, bits per sample, encoding name from get_encoding).
///
/// Only the first `capacity` bytes are fetched from `fileobj`; the metadata
/// is derived from that prefix.
auto get_info_fileobj(py::object fileobj, const std::string &format)
    -> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> {
    // Use the larger of the libsox buffer size and a 4 KiB default.
    const auto capacity = [&]() {
        const auto bufsiz = get_buffer_size();
        const int64_t kDefaultCapacityInBytes = 4096;
        return (bufsiz > kDefaultCapacityInBytes) ? bufsiz
                                                  : kDefaultCapacityInBytes;
    }();

    // Zero-filled so that any padding reported to libsox below is defined.
    std::string buffer(capacity, '\0');
    auto *buf = const_cast<char *>(buffer.data());
    auto num_read = read_fileobj(&fileobj, capacity, buf);

    // If the file is shorter than 256, then libsox cannot read the header.
    // The buffer holds at least 4096 bytes, so 256 never exceeds it.
    auto buf_size = (num_read > 256) ? num_read : 256;

    const char *filetype = format.empty() ? nullptr : format.data();
    SoxFormat sf(sox_open_mem_read(buf,
                                   buf_size,
                                   /*signal=*/nullptr,
                                   /*encoding=*/nullptr,
                                   /*filetype=*/filetype));

    // In case of streamed data, length can be 0
    validate_input_memfile(sf);

    return std::make_tuple(
        static_cast<int64_t>(sf->signal.rate),
        static_cast<int64_t>(sf->signal.length / sf->signal.channels),
        static_cast<int64_t>(sf->signal.channels),
        static_cast<int64_t>(sf->encoding.bits_per_sample),
        get_encoding(sf->encoding.encoding));
}
} // namespace sox_io
} // namespace paddleaudio
} // namespace sox_io
} // namespace paddleaudio
......@@ -12,7 +12,7 @@ auto get_info_file(const std::string &path, const std::string &format)
auto get_info_fileobj(py::object fileobj, const std::string &format)
-> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string>;
} // namespace sox_io
} // namespace paddleaudio
} // namespace sox_io
} // namespace paddleaudio
#endif
......@@ -12,86 +12,87 @@ sox_format_t *SoxFormat::operator->() const noexcept { return fd_; }
// Implicit conversion to the wrapped raw handle, so a SoxFormat can be passed
// wherever a sox_format_t* is expected. Ownership is not transferred.
SoxFormat::operator sox_format_t *() const noexcept {
    return fd_;
}
// Closes the wrapped handle with sox_close and clears it.
// Safe to call repeatedly: once fd_ is null the call is a no-op.
void SoxFormat::close() {
    if (fd_ == nullptr) {
        return;
    }
    sox_close(fd_);
    fd_ = nullptr;
}
/// Fills `buffer` with up to `size` bytes obtained from `fileobj.read(...)`.
///
/// @param fileobj  Python object providing a `read(n)` attribute that
///                 returns bytes.
/// @param size     Maximum number of bytes to fetch.
/// @param buffer   Destination; the caller must provide at least `size`
///                 bytes of storage.
/// @return Number of bytes actually copied (less than `size` when `read`
///         returns an empty result first).
/// @throws std::runtime_error if a single `read` call returns more bytes
///         than were requested.
auto read_fileobj(py::object *fileobj, const uint64_t size, char *buffer)
    -> uint64_t {
    uint64_t num_read = 0;
    while (num_read < size) {
        // Ask only for what is still missing.
        auto request = size - num_read;
        auto chunk = static_cast<std::string>(
            static_cast<py::bytes>(fileobj->attr("read")(request)));
        auto chunk_len = chunk.length();

        // An empty read marks the end of the stream.
        if (chunk_len == 0) {
            break;
        }

        // Refuse over-long chunks before copying, to avoid overrunning
        // `buffer`.
        if (chunk_len > request) {
            std::ostringstream message;
            message
                << "Requested up to " << request << " bytes but, "
                << "received " << chunk_len << " bytes. "
                << "The given object does not confirm to read protocol of file "
                   "object.";
            throw std::runtime_error(message.str());
        }

        memcpy(buffer, chunk.data(), chunk_len);
        buffer += chunk_len;
        num_read += chunk_len;
    }
    return num_read;
}
// Returns the `bufsiz` field of the libsox global settings as int64_t.
int64_t get_buffer_size() {
    const auto *globals = sox_get_globals();
    return globals->bufsiz;
}
/// Verifies that `sf` wraps a usable, recognised audio stream.
///
/// @param sf    Handle produced by one of the libsox open functions.
/// @param path  Only used to build the error message.
/// @throws std::runtime_error if the wrapped handle is null, or if its
///         encoding is SOX_ENCODING_UNKNOWN.
void validate_input_file(const SoxFormat &sf, const std::string &path) {
    // The null check must come first: the encoding check dereferences sf.
    if (static_cast<sox_format_t *>(sf) == nullptr) {
        throw std::runtime_error(
            "Error loading audio file: failed to open file " + path);
    }
    if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
        throw std::runtime_error("Error loading audio file: unknown encoding.");
    }
}
// Same checks as validate_input_file, for a stream opened from memory; the
// placeholder below stands in for the path in the error message.
void validate_input_memfile(const SoxFormat &sf) {
    validate_input_file(sf, "<in memory buffer>");
}
/// Maps a libsox encoding identifier to its short textual name.
///
/// @param encoding  Value taken from a sox encoding-info structure.
/// @return One of the names listed below; any value without an explicit
///         case (and SOX_ENCODING_UNKNOWN itself) yields "UNKNOWN".
std::string get_encoding(sox_encoding_t encoding) {
    switch (encoding) {
        case SOX_ENCODING_SIGN2:
            return "PCM_S";
        case SOX_ENCODING_UNSIGNED:
            return "PCM_U";
        case SOX_ENCODING_FLOAT:
            return "PCM_F";
        case SOX_ENCODING_FLAC:
            return "FLAC";
        case SOX_ENCODING_ULAW:
            return "ULAW";
        case SOX_ENCODING_ALAW:
            return "ALAW";
        case SOX_ENCODING_MP3:
            return "MP3";
        case SOX_ENCODING_VORBIS:
            return "VORBIS";
        case SOX_ENCODING_AMR_WB:
            return "AMR_WB";
        case SOX_ENCODING_AMR_NB:
            return "AMR_NB";
        case SOX_ENCODING_OPUS:
            return "OPUS";
        case SOX_ENCODING_GSM:
            return "GSM";
        case SOX_ENCODING_UNKNOWN:
        default:
            return "UNKNOWN";
    }
}
} // namespace paddleaudio
} // namespace sox_utils
} // namespace paddleaudio
} // namespace sox_utils
......@@ -11,19 +11,19 @@ namespace sox_utils {
/// Helper class to automatically close sox_format_t*.
///
/// Wraps a raw libsox handle. The type can be neither copied nor moved, so
/// each handle has exactly one wrapper.
struct SoxFormat {
    /// Takes the handle returned by a libsox open call (may be nullptr).
    explicit SoxFormat(sox_format_t *fd) noexcept;

    SoxFormat(const SoxFormat &other) = delete;
    SoxFormat(SoxFormat &&other) = delete;
    SoxFormat &operator=(const SoxFormat &other) = delete;
    SoxFormat &operator=(SoxFormat &&other) = delete;

    // NOTE(review): presumably releases the handle via close(); the
    // definition is not visible here.
    ~SoxFormat();

    /// Member access on the wrapped handle.
    sox_format_t *operator->() const noexcept;

    /// Implicit conversion to the wrapped raw handle.
    operator sox_format_t *() const noexcept;

    /// Closes the wrapped handle if it is still open.
    void close();

  private:
    sox_format_t *fd_;
};
auto read_fileobj(py::object *fileobj, uint64_t size, char *buffer) -> uint64_t;
......@@ -36,7 +36,7 @@ void validate_input_memfile(const SoxFormat &sf);
std::string get_encoding(sox_encoding_t encoding);
} // namespace paddleaudio
} // namespace sox_utils
} // namespace paddleaudio
} // namespace sox_utils
#endif
......@@ -14,5 +14,3 @@
import _locale
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])
......@@ -28,4 +28,4 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```
\ No newline at end of file
```
......@@ -19,4 +19,4 @@ from . import io
from . import metric
from . import sox_effects
from .backends import load
from .backends import save
\ No newline at end of file
from .backends import save
import types
class _ClassNamespace(types.ModuleType):
def __init__(self, name):
super(_ClassNamespace, self).__init__('paddlespeech.classes' + name)
......@@ -11,6 +12,7 @@ class _ClassNamespace(types.ModuleType):
raise RuntimeError(f'Class {self.name}.{attr} not registered!')
return proxy
class _Classes(types.ModuleType):
__file__ = '_classes.py'
......@@ -43,5 +45,6 @@ class _Classes(types.ModuleType):
"""
paddlespeech.ops.load_library(path)
# The classes "namespace"
classes = _Classes()
\ No newline at end of file
classes = _Classes()
......@@ -64,7 +64,8 @@ def _init_ffmpeg():
try:
_load_lib("libpaddlleaudio_ffmpeg")
except OSError as err:
raise ImportError("FFmpeg libraries are not found. Please install FFmpeg.") from err
raise ImportError(
"FFmpeg libraries are not found. Please install FFmpeg.") from err
import paddllespeech._paddlleaudio_ffmpeg # noqa
......@@ -95,4 +96,4 @@ def _init_extension():
pass
_init_extension()
\ No newline at end of file
_init_extension()
......@@ -3,6 +3,7 @@ import warnings
from functools import wraps
from typing import Optional
def is_module_available(*modules: str) -> bool:
r"""Returns if a top-level module with :attr:`name` exists *without**
importing it. This is generally safer than try-catch block around a
......@@ -26,19 +27,21 @@ def requires_module(*modules: str):
return func
else:
req = f"module: {missing[0]}" if len(missing) == 1 else f"modules: {missing}"
req = f"module: {missing[0]}" if len(
missing) == 1 else f"modules: {missing}"
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f"{func.__module__}.{func.__name__} requires {req}")
raise RuntimeError(
f"{func.__module__}.{func.__name__} requires {req}")
return wrapped
return decorator
def deprecated(direction: str, version: Optional[str] = None):
def deprecated(direction: str, version: Optional[str]=None):
"""Decorator to add deprecation message
Args:
direction (str): Migration steps to be given to users.
......@@ -51,8 +54,7 @@ def deprecated(direction: str, version: Optional[str] = None):
message = (
f"{func.__module__}.{func.__name__} has been deprecated "
f'and will be removed from {"future" if version is None else version} release. '
f"{direction}"
)
f"{direction}")
warnings.warn(message, stacklevel=2)
return func(*args, **kwargs)
......@@ -62,7 +64,7 @@ def deprecated(direction: str, version: Optional[str] = None):
def is_kaldi_available():
    """Report whether the ``paddlespeech.audio._paddleaudio`` extension exists.

    Returns:
        bool: result of ``is_module_available`` for the extension module.
    """
    # NOTE(review): this only checks that the extension module can be found,
    # not that it was built with Kaldi support - confirm that is intended.
    return is_module_available("paddlespeech.audio._paddleaudio")
def requires_kaldi():
......@@ -76,7 +78,8 @@ def requires_kaldi():
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f"{func.__module__}.{func.__name__} requires kaldi")
raise RuntimeError(
f"{func.__module__}.{func.__name__} requires kaldi")
return wrapped
......@@ -91,7 +94,8 @@ def _check_soundfile_importable():
return True
except Exception:
warnings.warn("Failed to import soundfile. 'soundfile' backend is not available.")
warnings.warn(
"Failed to import soundfile. 'soundfile' backend is not available.")
return False
......@@ -113,7 +117,8 @@ def requires_soundfile():
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f"{func.__module__}.{func.__name__} requires soundfile")
raise RuntimeError(
f"{func.__module__}.{func.__name__} requires soundfile")
return wrapped
......@@ -121,7 +126,7 @@ def requires_soundfile():
def is_sox_available():
    """Report whether the ``paddlespeech.audio._paddleaudio`` extension exists.

    Returns:
        bool: result of ``is_module_available`` for the extension module.
    """
    # NOTE(review): this only checks that the extension module can be found,
    # not that it was built with sox support - confirm that is intended.
    return is_module_available("paddlespeech.audio._paddleaudio")
def requires_sox():
......@@ -135,8 +140,9 @@ def requires_sox():
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f"{func.__module__}.{func.__name__} requires sox")
raise RuntimeError(
f"{func.__module__}.{func.__name__} requires sox")
return wrapped
return
\ No newline at end of file
return
import contextlib
import ctypes
import sys
import os
import sys
import types
# Query `hasattr` only once.
_SET_GLOBAL_FLAGS = hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags')
_SET_GLOBAL_FLAGS = hasattr(sys, 'getdlopenflags') and hasattr(sys,
'setdlopenflags')
@contextlib.contextmanager
......@@ -22,7 +23,7 @@ def dl_open_guard():
if _SET_GLOBAL_FLAGS:
sys.setdlopenflags(old_flags)
def resolve_library_path(path: str) -> str:
    """Return the canonical form of ``path`` as given by ``os.path.realpath``.

    Args:
        path: Library path to canonicalise.

    Returns:
        The canonical path string.
    """
    return os.path.realpath(path)
......@@ -59,4 +60,4 @@ class _Ops(types.ModuleType):
# The ops "namespace"
ops = _Ops()
\ No newline at end of file
ops = _Ops()
......@@ -14,9 +14,9 @@
#pragma once
#include "feat/feature-window.h"
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include "feat/feature-window.h"
namespace paddleaudio {
......@@ -27,18 +27,14 @@ class StreamingFeatureTpl {
public:
typedef typename F::Options Options;
StreamingFeatureTpl(const Options& opts);
bool ComputeFeature(const kaldi::VectorBase<kaldi::BaseFloat>& wav,
bool ComputeFeature(const kaldi::VectorBase<kaldi::BaseFloat>& wav,
kaldi::Vector<kaldi::BaseFloat>* feats);
void Reset() {
remained_wav_.Resize(0);
}
void Reset() { remained_wav_.Resize(0); }
int Dim() {
return computer_.Dim();
}
int Dim() { return computer_.Dim(); }
private:
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& waves,
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& waves,
kaldi::Vector<kaldi::BaseFloat>* feats);
Options opts_;
kaldi::FeatureWindowFunction window_function_;
......@@ -49,4 +45,3 @@ class StreamingFeatureTpl {
} // namespace paddleaudio
#include "feature_common_inl.h"
......@@ -17,16 +17,15 @@
namespace paddleaudio {
// Builds the feature computer and the window function from the same options.
template <class F>
StreamingFeatureTpl<F>::StreamingFeatureTpl(const Options& opts)
    : opts_(opts), computer_(opts), window_function_(opts.frame_opts) {
    // The window is built from opts.frame_opts on purpose. The original
    // author's note says the alternative,
    // window_function_(computer_.GetFrameOptions()), left the options set to
    // zero.
}
template <class F>
bool StreamingFeatureTpl<F>::ComputeFeature(const kaldi::VectorBase<kaldi::BaseFloat>& wav,
kaldi::Vector<kaldi::BaseFloat>* feats) {
bool StreamingFeatureTpl<F>::ComputeFeature(
const kaldi::VectorBase<kaldi::BaseFloat>& wav,
kaldi::Vector<kaldi::BaseFloat>* feats) {
// append remained waves
kaldi::int32 wav_len = wav.Dim();
if (wav_len == 0) return false;
......@@ -61,7 +60,7 @@ bool StreamingFeatureTpl<F>::Compute(
kaldi::int32 frame_length = frame_opts.WindowSize();
kaldi::int32 sample_rate = frame_opts.samp_freq;
if (num_samples < frame_length) {
return false;
return false;
}
kaldi::int32 num_frames = kaldi::NumFrames(num_samples, frame_opts);
......
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include "kaldi_feature_wrapper.h"
namespace py=pybind11;
namespace py = pybind11;
bool InitFbank(
float samp_freq, // frame opts
float frame_shift_ms,
float frame_length_ms,
float dither,
float preemph_coeff,
bool remove_dc_offset,
std::string window_type, // e.g. Hamming window
bool round_to_power_of_two,
float blackman_coeff,
bool snip_edges,
bool allow_downsample,
bool allow_upsample,
int max_feature_vectors,
int num_bins, // mel opts
float low_freq,
float high_freq,
float vtln_low,
float vtln_high,
bool debug_mel,
bool htk_mode,
bool use_energy, // fbank opts
float energy_floor,
bool raw_energy,
bool htk_compat,
bool use_log_fbank,
bool use_power) {
bool InitFbank(float samp_freq, // frame opts
float frame_shift_ms,
float frame_length_ms,
float dither,
float preemph_coeff,
bool remove_dc_offset,
std::string window_type, // e.g. Hamming window
bool round_to_power_of_two,
float blackman_coeff,
bool snip_edges,
bool allow_downsample,
bool allow_upsample,
int max_feature_vectors,
int num_bins, // mel opts
float low_freq,
float high_freq,
float vtln_low,
float vtln_high,
bool debug_mel,
bool htk_mode,
bool use_energy, // fbank opts
float energy_floor,
bool raw_energy,
bool htk_compat,
bool use_log_fbank,
bool use_power) {
kaldi::FbankOptions opts;
opts.frame_opts.samp_freq = samp_freq; // frame opts
opts.frame_opts.samp_freq = samp_freq; // frame opts
opts.frame_opts.frame_shift_ms = frame_shift_ms;
opts.frame_opts.frame_length_ms = frame_length_ms;
opts.frame_opts.dither = dither;
opts.frame_opts.preemph_coeff = preemph_coeff;
opts.frame_opts.remove_dc_offset = remove_dc_offset;
opts.frame_opts.window_type = window_type;
opts.frame_opts.window_type = window_type;
opts.frame_opts.round_to_power_of_two = round_to_power_of_two;
opts.frame_opts.blackman_coeff = blackman_coeff;
opts.frame_opts.snip_edges = snip_edges;
......@@ -48,7 +47,7 @@ bool InitFbank(
opts.frame_opts.allow_upsample = allow_upsample;
opts.frame_opts.max_feature_vectors = max_feature_vectors;
opts.mel_opts.num_bins = num_bins; // mel opts
opts.mel_opts.num_bins = num_bins; // mel opts
opts.mel_opts.low_freq = low_freq;
opts.mel_opts.high_freq = high_freq;
opts.mel_opts.vtln_low = vtln_low;
......@@ -56,7 +55,7 @@ bool InitFbank(
opts.mel_opts.debug_mel = debug_mel;
opts.mel_opts.htk_mode = htk_mode;
opts.use_energy = use_energy; // fbank opts
opts.use_energy = use_energy; // fbank opts
opts.energy_floor = energy_floor;
opts.raw_energy = raw_energy;
opts.htk_compat = htk_compat;
......@@ -67,71 +66,71 @@ bool InitFbank(
}
// Computes fbank features for `wav` with the shared KaldiFeatureWrapper
// instance. The wrapper is not reset afterwards (ResetFbank does that).
py::array_t<double> ComputeFbankStreaming(const py::array_t<double>& wav) {
    return paddleaudio::KaldiFeatureWrapper::GetInstance()->ComputeFbank(wav);
}
// One-shot fbank computation.
//
// Passes every option through to InitFbank, computes the features for `wav`
// via ComputeFbankStreaming, then resets the shared KaldiFeatureWrapper
// instance before returning the result.
py::array_t<double> ComputeFbank(
    float samp_freq,  // frame opts
    float frame_shift_ms,
    float frame_length_ms,
    float dither,
    float preemph_coeff,
    bool remove_dc_offset,
    std::string window_type,  // e.g. Hamming window
    bool round_to_power_of_two,
    float blackman_coeff,
    bool snip_edges,
    bool allow_downsample,
    bool allow_upsample,
    int max_feature_vectors,
    int num_bins,  // mel opts
    float low_freq,
    float high_freq,
    float vtln_low,
    float vtln_high,
    bool debug_mel,
    bool htk_mode,
    bool use_energy,  // fbank opts
    float energy_floor,
    bool raw_energy,
    bool htk_compat,
    bool use_log_fbank,
    bool use_power,
    const py::array_t<double>& wav) {
    InitFbank(samp_freq,  // frame opts
              frame_shift_ms,
              frame_length_ms,
              dither,
              preemph_coeff,
              remove_dc_offset,
              window_type,  // e.g. Hamming window
              round_to_power_of_two,
              blackman_coeff,
              snip_edges,
              allow_downsample,
              allow_upsample,
              max_feature_vectors,
              num_bins,  // mel opts
              low_freq,
              high_freq,
              vtln_low,
              vtln_high,
              debug_mel,
              htk_mode,
              use_energy,  // fbank opts
              energy_floor,
              raw_energy,
              htk_compat,
              use_log_fbank,
              use_power);
    py::array_t<double> result = ComputeFbankStreaming(wav);
    // Reset the shared instance after the features have been computed.
    paddleaudio::KaldiFeatureWrapper::GetInstance()->ResetFbank();
    return result;
}
// Resets the fbank state held by the shared KaldiFeatureWrapper instance.
void ResetFbank() {
    paddleaudio::KaldiFeatureWrapper::GetInstance()->ResetFbank();
}
PYBIND11_MODULE(kaldi_featurepy, m) {
......@@ -139,5 +138,7 @@ PYBIND11_MODULE(kaldi_featurepy, m) {
m.def("InitFbank", &InitFbank, "init fbank");
m.def("ResetFbank", &ResetFbank, "reset fbank");
m.def("ComputeFbank", &ComputeFbank, "compute fbank");
m.def("ComputeFbankStreaming", &ComputeFbankStreaming, "compute fbank streaming");
m.def("ComputeFbankStreaming",
&ComputeFbankStreaming,
"compute fbank streaming");
}
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include "kaldi_feature_wrapper.h"
namespace py=pybind11;
namespace py = pybind11;
bool InitFbank(
float samp_freq, // frame opts
float frame_shift_ms,
float frame_length_ms,
float dither,
float preemph_coeff,
bool remove_dc_offset,
std::string window_type, // e.g. Hamming window
bool round_to_power_of_two,
float blackman_coeff,
bool snip_edges,
bool allow_downsample,
bool allow_upsample,
int max_feature_vectors,
int num_bins, // mel opts
float low_freq,
float high_freq,
float vtln_low,
float vtln_high,
bool debug_mel,
bool htk_mode,
bool use_energy, // fbank opts
float energy_floor,
bool raw_energy,
bool htk_compat,
bool use_log_fbank,
bool use_power);
bool InitFbank(float samp_freq, // frame opts
float frame_shift_ms,
float frame_length_ms,
float dither,
float preemph_coeff,
bool remove_dc_offset,
std::string window_type, // e.g. Hamming window
bool round_to_power_of_two,
float blackman_coeff,
bool snip_edges,
bool allow_downsample,
bool allow_upsample,
int max_feature_vectors,
int num_bins, // mel opts
float low_freq,
float high_freq,
float vtln_low,
float vtln_high,
bool debug_mel,
bool htk_mode,
bool use_energy, // fbank opts
float energy_floor,
bool raw_energy,
bool htk_compat,
bool use_log_fbank,
bool use_power);
py::array_t<double> ComputeFbank(
float samp_freq, // frame opts
float samp_freq, // frame opts
float frame_shift_ms,
float frame_length_ms,
float dither,
float preemph_coeff,
bool remove_dc_offset,
std::string window_type, // e.g. Hamming window
std::string window_type, // e.g. Hamming window
bool round_to_power_of_two,
kaldi::BaseFloat blackman_coeff,
bool snip_edges,
bool allow_downsample,
bool allow_upsample,
int max_feature_vectors,
int num_bins, // mel opts
int num_bins, // mel opts
float low_freq,
float high_freq,
float vtln_low,
float vtln_high,
bool debug_mel,
bool htk_mode,
bool use_energy, // fbank opts
bool use_energy, // fbank opts
float energy_floor,
bool raw_energy,
bool htk_compat,
......
......@@ -8,17 +8,18 @@ KaldiFeatureWrapper* KaldiFeatureWrapper::GetInstance() {
}
// Replaces the held Fbank object with a new one built from `opts`.
// Always returns true.
bool KaldiFeatureWrapper::InitFbank(kaldi::FbankOptions opts) {
    fbank_.reset(new Fbank(opts));
    return true;
}
py::array_t<double> KaldiFeatureWrapper::ComputeFbank(const py::array_t<double> wav) {
py::array_t<double> KaldiFeatureWrapper::ComputeFbank(
const py::array_t<double> wav) {
py::buffer_info info = wav.request();
kaldi::Vector<kaldi::BaseFloat> input_wav(info.size);
double* wav_ptr = (double*)info.ptr;
for (int idx = 0; idx < info.size; ++idx) {
input_wav(idx) = *wav_ptr;
wav_ptr++;
input_wav(idx) = *wav_ptr;
wav_ptr++;
}
......@@ -28,8 +29,8 @@ py::array_t<double> KaldiFeatureWrapper::ComputeFbank(const py::array_t<double>
auto result = py::array_t<double>(feats.Dim());
py::buffer_info xs = result.request();
for (int idx = 0; idx < 10; ++idx) {
float val = feats(idx);
std::cout << val << " ";
float val = feats(idx);
std::cout << val << " ";
}
std::cout << std::endl;
double* res_ptr = (double*)xs.ptr;
......@@ -38,20 +39,21 @@ py::array_t<double> KaldiFeatureWrapper::ComputeFbank(const py::array_t<double>
res_ptr++;
}
return result.reshape({ feats.Dim() / Dim(), Dim()});
/*
py::buffer_info info = wav.request();
std::cout << info.size << std::endl;
auto result = py::array_t<double>(info.size);
//kaldi::Vector<kaldi::BaseFloat> input_wav(info.size);
kaldi::Vector<double> input_wav(info.size);
py::buffer_info info_re = result.request();
memcpy(input_wav.Data(), (double*)info.ptr, wav.nbytes());
memcpy((double*)info_re.ptr, input_wav.Data(), input_wav.Dim()* sizeof(double));
return result;
*/
return result.reshape({feats.Dim() / Dim(), Dim()});
/*
py::buffer_info info = wav.request();
std::cout << info.size << std::endl;
auto result = py::array_t<double>(info.size);
//kaldi::Vector<kaldi::BaseFloat> input_wav(info.size);
kaldi::Vector<double> input_wav(info.size);
py::buffer_info info_re = result.request();
memcpy(input_wav.Data(), (double*)info.ptr, wav.nbytes());
memcpy((double*)info_re.ptr, input_wav.Data(), input_wav.Dim()*
sizeof(double));
return result;
*/
}
} // namespace paddleaudio
} // namespace paddleaudio
#include "base/kaldi-common.h"
#include "feature_common.h"
#include "feat/feature-fbank.h"
#include "feature_common.h"
#pragma once
......@@ -14,12 +14,8 @@ class KaldiFeatureWrapper {
static KaldiFeatureWrapper* GetInstance();
bool InitFbank(kaldi::FbankOptions opts);
py::array_t<double> ComputeFbank(const py::array_t<double> wav);
int Dim() {
return fbank_->Dim();
}
void ResetFbank() {
fbank_->Reset();
}
int Dim() { return fbank_->Dim(); }
void ResetFbank() { fbank_->Reset(); }
private:
std::unique_ptr<paddleaudio::Fbank> fbank_;
......
//Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
//All rights reserved.
// Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
// All rights reserved.
#include "paddlespeech/audio/src/pybind/sox/io.h"
// Python extension module `_paddleaudio`.
//
// Exposes the libsox-backed metadata readers from paddleaudio::sox_io.
// Each function is registered exactly once; registering the same name twice
// would only add a redundant, identical overload.
PYBIND11_MODULE(_paddleaudio, m) {
    m.def("get_info_file",
          &paddleaudio::sox_io::get_info_file,
          "Get metadata of audio file.");
    m.def("get_info_fileobj",
          &paddleaudio::sox_io::get_info_fileobj,
          "Get metadata of audio in file object.");
}
\ No newline at end of file
//Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
//All rights reserved.
// Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
// All rights reserved.
#include "paddlespeech/audio/src/pybind/sox/io.h"
#include "paddlespeech/audio/src/pybind/sox/utils.h"
......@@ -11,51 +11,54 @@ namespace sox_io {
/// Opens the audio file at `path` with libsox and reports its metadata.
///
/// @param path    Path handed to sox_open_read.
/// @param format  File type hint; when empty, nullptr is passed as the
///                filetype argument.
/// @return Tuple of (sample rate, signal length divided by channel count,
///         channel count, bits per sample, encoding name from get_encoding).
///
/// The handle is checked with validate_input_file before any field is read.
auto get_info_file(const std::string &path, const std::string &format)
    -> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> {
    const char *filetype = format.empty() ? nullptr : format.data();
    SoxFormat sf(sox_open_read(path.data(),
                               /*signal=*/nullptr,
                               /*encoding=*/nullptr,
                               /*filetype=*/filetype));

    validate_input_file(sf, path);

    return std::make_tuple(
        static_cast<int64_t>(sf->signal.rate),
        static_cast<int64_t>(sf->signal.length / sf->signal.channels),
        static_cast<int64_t>(sf->signal.channels),
        static_cast<int64_t>(sf->encoding.bits_per_sample),
        get_encoding(sf->encoding.encoding));
}
/// Reads the head of a Python file-like object and reports audio metadata.
///
/// @param fileobj  Python object whose `read` attribute is called by
///                 read_fileobj.
/// @param format   File type hint; when empty, nullptr is passed as the
///                 filetype argument.
/// @return Tuple of (sample rate, signal length divided by channel count,
///         channel count, bits per sample, encoding name from get_encoding).
///
/// Only the first `capacity` bytes are fetched from `fileobj`; the metadata
/// is derived from that prefix.
auto get_info_fileobj(py::object fileobj, const std::string &format)
    -> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> {
    // Use the larger of the libsox buffer size and a 4 KiB default.
    const auto capacity = [&]() {
        const auto bufsiz = get_buffer_size();
        const int64_t kDefaultCapacityInBytes = 4096;
        return (bufsiz > kDefaultCapacityInBytes) ? bufsiz
                                                  : kDefaultCapacityInBytes;
    }();

    // Zero-filled so that any padding reported to libsox below is defined.
    std::string buffer(capacity, '\0');
    auto *buf = const_cast<char *>(buffer.data());
    auto num_read = read_fileobj(&fileobj, capacity, buf);

    // If the file is shorter than 256, then libsox cannot read the header.
    // The buffer holds at least 4096 bytes, so 256 never exceeds it.
    auto buf_size = (num_read > 256) ? num_read : 256;

    const char *filetype = format.empty() ? nullptr : format.data();
    SoxFormat sf(sox_open_mem_read(buf,
                                   buf_size,
                                   /*signal=*/nullptr,
                                   /*encoding=*/nullptr,
                                   /*filetype=*/filetype));

    // In case of streamed data, length can be 0
    validate_input_memfile(sf);

    return std::make_tuple(
        static_cast<int64_t>(sf->signal.rate),
        static_cast<int64_t>(sf->signal.length / sf->signal.channels),
        static_cast<int64_t>(sf->signal.channels),
        static_cast<int64_t>(sf->encoding.bits_per_sample),
        get_encoding(sf->encoding.encoding));
}
} // namespace sox_io
} // namespace paddleaudio
} // namespace sox_io
} // namespace paddleaudio
//Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
//All rights reserved.
// Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
// All rights reserved.
#ifndef PADDLEAUDIO_PYBIND_SOX_IO_H
#define PADDLEAUDIO_PYBIND_SOX_IO_H
......@@ -15,7 +15,7 @@ auto get_info_file(const std::string &path, const std::string &format)
auto get_info_fileobj(py::object fileobj, const std::string &format)
-> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string>;
} // namespace sox_io
} // namespace paddleaudio
} // namespace sox_io
} // namespace paddleaudio
#endif
//Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
//All rights reserved.
// Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
// All rights reserved.
#include "paddlespeech/audio/src/pybind/sox/utils.h"
......@@ -15,86 +15,87 @@ sox_format_t *SoxFormat::operator->() const noexcept { return fd_; }
/// Implicit conversion that exposes the wrapped `sox_format_t` pointer.
/// The pointer may be null; callers compare the converted value with nullptr.
SoxFormat::operator sox_format_t *() const noexcept {
    return fd_;
}
/// Closes the wrapped sox handle, if there is one, and resets the pointer
/// to null so that calling close() again does nothing.
void SoxFormat::close() {
    if (fd_ != nullptr) {
        sox_close(fd_);
        fd_ = nullptr;
    }
}
/// Reads up to `size` bytes from a Python file-like object into `buffer`.
///
/// Calls `fileobj.read(remaining)` repeatedly until `size` bytes have been
/// copied or a call returns an empty result.
///
/// @param fileobj Python object exposing a `read(n)` method whose result
///                converts to bytes.
/// @param size    Maximum number of bytes to copy; `buffer` must have room
///                for this many bytes.
/// @param buffer  Destination; successive chunks are written back to back.
/// @return Number of bytes copied into `buffer`.
/// @throws std::runtime_error if a single `read` call returns more bytes
///         than were requested.
auto read_fileobj(py::object *fileobj, const uint64_t size, char *buffer)
    -> uint64_t {
    uint64_t num_read = 0;
    while (num_read < size) {
        const auto request = size - num_read;
        const auto chunk = static_cast<std::string>(
            static_cast<py::bytes>(fileobj->attr("read")(request)));
        const auto chunk_len = chunk.length();
        if (chunk_len == 0) {
            // An empty read is treated as "no more data".
            break;
        }
        if (chunk_len > request) {
            // Copying an oversized chunk would overrun `buffer`.
            std::ostringstream message;
            message
                << "Requested up to " << request << " bytes but, "
                << "received " << chunk_len << " bytes. "
                << "The given object does not confirm to read protocol of file "
                   "object.";
            throw std::runtime_error(message.str());
        }
        memcpy(buffer, chunk.data(), chunk_len);
        buffer += chunk_len;
        num_read += chunk_len;
    }
    return num_read;
}
/// Returns the `bufsiz` field of the sox globals as a signed 64-bit integer.
int64_t get_buffer_size() {
    const auto *globals = sox_get_globals();
    return static_cast<int64_t>(globals->bufsiz);
}
/// Verifies that `sf` wraps a successfully opened input with a known encoding.
///
/// @param sf   Handle to check.
/// @param path Name included in the error message when the handle is null.
/// @throws std::runtime_error if the wrapped pointer is null or the encoding
///         is SOX_ENCODING_UNKNOWN.
void validate_input_file(const SoxFormat &sf, const std::string &path) {
    // The null check must come first: the encoding check dereferences `sf`.
    if (static_cast<sox_format_t *>(sf) == nullptr) {
        throw std::runtime_error(
            "Error loading audio file: failed to open file " + path);
    }
    if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
        throw std::runtime_error("Error loading audio file: unknown encoding.");
    }
}
/// Same checks as validate_input_file(), for input opened from memory; the
/// error message names the source as "<in memory buffer>".
void validate_input_memfile(const SoxFormat &sf) {
    validate_input_file(sf, "<in memory buffer>");
}
/// Maps a sox encoding enumerator to its short textual name.
///
/// @param encoding Encoding value to translate.
/// @return The name for the listed encodings; "UNKNOWN" for
///         SOX_ENCODING_UNKNOWN and for any value not listed here.
std::string get_encoding(sox_encoding_t encoding) {
    switch (encoding) {
        case SOX_ENCODING_UNKNOWN:
            return "UNKNOWN";
        case SOX_ENCODING_SIGN2:
            return "PCM_S";
        case SOX_ENCODING_UNSIGNED:
            return "PCM_U";
        case SOX_ENCODING_FLOAT:
            return "PCM_F";
        case SOX_ENCODING_FLAC:
            return "FLAC";
        case SOX_ENCODING_ULAW:
            return "ULAW";
        case SOX_ENCODING_ALAW:
            return "ALAW";
        case SOX_ENCODING_MP3:
            return "MP3";
        case SOX_ENCODING_VORBIS:
            return "VORBIS";
        case SOX_ENCODING_AMR_WB:
            return "AMR_WB";
        case SOX_ENCODING_AMR_NB:
            return "AMR_NB";
        case SOX_ENCODING_OPUS:
            return "OPUS";
        case SOX_ENCODING_GSM:
            return "GSM";
        default:
            return "UNKNOWN";
    }
}
} // namespace sox_utils
} // namespace paddleaudio
} // namespace sox_utils
} // namespace paddleaudio
//Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
//All rights reserved.
// Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
// All rights reserved.
#ifndef PADDLEAUDIO_PYBIND_SOX_UTILS_H
#define PADDLEAUDIO_PYBIND_SOX_UTILS_H
......@@ -14,19 +14,19 @@ namespace sox_utils {
/// Helper class to automatically close a `sox_format_t*`.
///
/// Holds a single handle; copying and moving are disabled, so each instance
/// keeps the handle it was constructed with.
struct SoxFormat {
    /// Wraps `fd`, which may be null when opening failed.
    explicit SoxFormat(sox_format_t *fd) noexcept;
    SoxFormat(const SoxFormat &other) = delete;
    SoxFormat(SoxFormat &&other) = delete;
    SoxFormat &operator=(const SoxFormat &other) = delete;
    SoxFormat &operator=(SoxFormat &&other) = delete;
    ~SoxFormat();
    /// Member access on the wrapped handle.
    sox_format_t *operator->() const noexcept;
    /// Implicit conversion to the wrapped handle.
    operator sox_format_t *() const noexcept;
    /// Closes the handle if it is still open and resets it to null.
    void close();

  private:
    sox_format_t *fd_;
};
/// Reads up to `size` bytes from `fileobj` (through its `read` method) into
/// `buffer` and returns the number of bytes copied.
auto read_fileobj(py::object *fileobj, uint64_t size, char *buffer) -> uint64_t;
......@@ -39,7 +39,7 @@ void validate_input_memfile(const SoxFormat &sf);
/// Returns the short textual name of a sox encoding ("UNKNOWN" when the
/// value has no dedicated name).
std::string get_encoding(sox_encoding_t encoding);
} // namespace sox_utils
} // namespace paddleaudio
} // namespace sox_utils
} // namespace paddleaudio
#endif
......@@ -11,54 +11,53 @@ namespace paddleaudio {
namespace sox_io {
/// Opens an audio file with sox and reports its metadata.
///
/// @param path   Path of the file to open.
/// @param format Optional file type hint; when absent no hint is passed.
/// @return An empty optional if the file could not be opened or its encoding
///         is SOX_ENCODING_UNKNOWN; otherwise a tuple of (sample rate,
///         signal length divided by channel count, channel count,
///         bits per sample, encoding name).
tl::optional<MetaDataTuple> get_info_file(
    const std::string& path, const tl::optional<std::string>& format) {
    SoxFormat sf(sox_open_read(
        path.c_str(),
        /*signal=*/nullptr,
        /*encoding=*/nullptr,
        /*filetype=*/format.has_value() ? format.value().c_str() : nullptr));

    // The null check short-circuits before `sf` is dereferenced.
    if (static_cast<sox_format_t*>(sf) == nullptr ||
        sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
        return {};
    }

    return std::forward_as_tuple(
        static_cast<int64_t>(sf->signal.rate),
        static_cast<int64_t>(sf->signal.length / sf->signal.channels),
        static_cast<int64_t>(sf->signal.channels),
        static_cast<int64_t>(sf->encoding.bits_per_sample),
        get_encoding(sf->encoding.encoding));
}
/// Builds the effect list that restricts loading to a frame range.
///
/// @param frame_offset First frame to keep; treated as 0 when absent.
/// @param num_frames   Number of frames to keep; treated as -1 when absent,
///                     in which case no length argument is emitted.
/// @return An empty list when no trimming is needed, otherwise a single
///         "trim" effect with `<offset>s` and, if a frame count was given,
///         `+<frames>s`.
/// @throws std::runtime_error if `frame_offset` is negative, or if
///         `num_frames` is 0 or less than -1.
std::vector<std::vector<std::string>> get_effects(
    const tl::optional<int64_t>& frame_offset,
    const tl::optional<int64_t>& num_frames) {
    const auto offset = frame_offset.value_or(0);
    if (offset < 0) {
        throw std::runtime_error(
            "Invalid argument: frame_offset must be non-negative.");
    }
    const auto frames = num_frames.value_or(-1);
    if (frames == 0 || frames < -1) {
        throw std::runtime_error(
            "Invalid argument: num_frames must be -1 or greater than 0.");
    }

    std::vector<std::vector<std::string>> effects;
    if (frames != -1) {
        std::ostringstream os_offset, os_frames;
        os_offset << offset << "s";
        os_frames << "+" << frames << "s";
        effects.emplace_back(
            std::vector<std::string>{"trim", os_offset.str(), os_frames.str()});
    } else if (offset != 0) {
        std::ostringstream os_offset;
        os_offset << offset << "s";
        effects.emplace_back(std::vector<std::string>{"trim", os_offset.str()});
    }
    return effects;
}
tl::optional<std::tuple<torch::Tensor, int64_t>> load_audio_file(
......@@ -68,79 +67,73 @@ tl::optional<std::tuple<torch::Tensor, int64_t>> load_audio_file(
tl::optional<bool> normalize,
tl::optional<bool> channels_first,
const tl::optional<std::string>& format) {
auto effects = get_effects(frame_offset, num_frames);
return paddleaudio::sox_effects::apply_effects_file(
path, effects, normalize, channels_first, format);
auto effects = get_effects(frame_offset, num_frames);
return paddleaudio::sox_effects::apply_effects_file(
path, effects, normalize, channels_first, format);
}
/// Writes a tensor to an audio file through sox.
///
/// @param path            Destination path; also used to pick the file type
///                        when `format` is absent.
/// @param tensor          Audio data to write.
/// @param sample_rate     Sample rate of `tensor`.
/// @param channels_first  When true, dimension 0 of `tensor` is read as the
///                        channel dimension; otherwise dimension 1.
/// @param compression     Optional value forwarded to
///                        get_encodinginfo_for_save.
/// @param format          Optional file type overriding the one derived from
///                        `path`.
/// @param encoding        Optional value forwarded to
///                        get_encodinginfo_for_save.
/// @param bits_per_sample Optional value forwarded to
///                        get_encodinginfo_for_save.
/// @throws std::runtime_error if the output file cannot be opened.
///         TORCH_CHECK fails when "amr-nb", "htk" or "gsm" output has more
///         than one channel, or "gsm" output is not sampled at 8000.
void save_audio_file(const std::string& path,
                     torch::Tensor tensor,
                     int64_t sample_rate,
                     bool channels_first,
                     tl::optional<double> compression,
                     tl::optional<std::string> format,
                     tl::optional<std::string> encoding,
                     tl::optional<int64_t> bits_per_sample) {
    validate_input_tensor(tensor);

    // An explicit format wins over the type derived from the path.
    const auto filetype = [&]() {
        if (format.has_value()) return format.value();
        return get_filetype(path);
    }();

    if (filetype == "amr-nb") {
        const auto num_channels = tensor.size(channels_first ? 0 : 1);
        TORCH_CHECK(num_channels == 1,
                    "amr-nb format only supports single channel audio.");
    } else if (filetype == "htk") {
        const auto num_channels = tensor.size(channels_first ? 0 : 1);
        TORCH_CHECK(num_channels == 1,
                    "htk format only supports single channel audio.");
    } else if (filetype == "gsm") {
        const auto num_channels = tensor.size(channels_first ? 0 : 1);
        TORCH_CHECK(num_channels == 1,
                    "gsm format only supports single channel audio.");
        TORCH_CHECK(sample_rate == 8000,
                    "gsm format only supports a sampling rate of 8kHz.");
    }

    const auto signal_info =
        get_signalinfo(&tensor, sample_rate, filetype, channels_first);
    const auto encoding_info = get_encodinginfo_for_save(
        filetype, tensor.dtype(), compression, encoding, bits_per_sample);

    SoxFormat sf(sox_open_write(path.c_str(),
                                &signal_info,
                                &encoding_info,
                                /*filetype=*/filetype.c_str(),
                                /*oob=*/nullptr,
                                /*overwrite_permitted=*/nullptr));

    if (static_cast<sox_format_t*>(sf) == nullptr) {
        throw std::runtime_error(
            "Error saving audio file: failed to open file " + path);
    }

    paddleaudio::sox_effects_chain::SoxEffectsChain chain(
        /*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
        /*output_encoding=*/sf->encoding);
    chain.addInputTensor(&tensor, sample_rate, channels_first);
    chain.addOutputFile(sf);
    chain.run();
}
// Registers the sox I/O entry points as operators of the `paddleaudio`
// library. Each operator name is registered exactly once.
TORCH_LIBRARY_FRAGMENT(paddleaudio, m) {
    m.def("paddleaudio::sox_io_get_info", &paddleaudio::sox_io::get_info_file);
    m.def("paddleaudio::sox_io_load_audio_file",
          &paddleaudio::sox_io::load_audio_file);
    m.def("paddleaudio::sox_io_save_audio_file",
          &paddleaudio::sox_io::save_audio_file);
}
} // namespace sox_io
} // namespace paddleaudio
\ No newline at end of file
} // namespace sox_io
} // namespace paddleaudio
\ No newline at end of file
//Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
//All rights reserved.
// Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
// All rights reserved.
#ifndef PADDLEAUDIO_SOX_IO_H
#define PADDLEAUDIO_SOX_IO_H
......@@ -11,17 +11,15 @@
namespace paddleaudio {
namespace sox_io {
/// Builds the list of sox effects selecting the requested frame range.
/// Each inner vector holds an effect name followed by its arguments.
auto get_effects(const tl::optional<int64_t>& frame_offset,
                 const tl::optional<int64_t>& num_frames)
    -> std::vector<std::vector<std::string>>;
using MetaDataTuple =
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string>;
/// Returns the metadata tuple for the audio file at `path`, using `format`
/// as an optional file type hint. The result is an optional and may be empty.
tl::optional<MetaDataTuple> get_info_file(
    const std::string& path, const tl::optional<std::string>& format);
tl::optional<std::tuple<torch::Tensor, int64_t>> load_audio_file(
const std::string& path,
......@@ -31,17 +29,16 @@ tl::optional<std::tuple<torch::Tensor, int64_t>> load_audio_file(
tl::optional<bool> channels_first,
const tl::optional<std::string>& format);
/// Writes `tensor` to the audio file at `path`.
/// `format`, when given, selects the file type; `compression`, `encoding`
/// and `bits_per_sample` are optional settings for the output encoding.
void save_audio_file(
const std::string& path,
torch::Tensor tensor,
int64_t sample_rate,
bool channels_first,
tl::optional<double> compression,
tl::optional<std::string> format,
tl::optional<std::string> encoding,
tl::optional<int64_t> bits_per_sample);
} // namespace sox_io
} // namespace paddleaudio
void save_audio_file(const std::string& path,
torch::Tensor tensor,
int64_t sample_rate,
bool channels_first,
tl::optional<double> compression,
tl::optional<std::string> format,
tl::optional<std::string> encoding,
tl::optional<int64_t> bits_per_sample);
} // namespace sox_io
} // namespace paddleaudio
#endif
\ No newline at end of file
......@@ -4,17 +4,17 @@ namespace {
// Reports whether this translation unit was compiled with INCLUDE_SOX
// defined.
bool is_sox_available() {
#ifdef INCLUDE_SOX
    return true;
#else
    return false;
#endif
}
// Reports whether this translation unit was compiled with INCLUDE_KALDI
// defined.
bool is_kaldi_available() {
#ifdef INCLUDE_KALDI
    return true;
#else
    return false;
#endif
}
......@@ -22,12 +22,12 @@ bool is_kaldi_available() {
// not the runtime availability.
// Reports whether this translation unit was compiled with USE_FFMPEG
// defined. This is a build-time flag only; nothing is probed at run time.
bool is_ffmpeg_available() {
#ifdef USE_FFMPEG
    return true;
#else
    return false;
#endif
}
} // namespace
} // namespace
} // namespace paddleaudio
\ No newline at end of file
} // namespace paddleaudio
\ No newline at end of file
此差异已折叠。
from .extension import *
\ No newline at end of file
from .extension import *
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册