未验证 提交 3bb904bd 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #2119 from SmileGoat/audio_add_pitch

[audio] paddlespeech add kaldi feature pitch
......@@ -65,4 +65,4 @@ add_subdirectory(paddlespeech/audio)
# Summary
include(cmake/summary.cmake)
onnx_print_configuration_summary()
\ No newline at end of file
onnx_print_configuration_summary()
......@@ -3,6 +3,7 @@ import warnings
from functools import wraps
from typing import Optional
#code is from https://github.com/pytorch/audio/blob/main/torchaudio/_internal/module_utils.py
def is_module_available(*modules: str) -> bool:
r"""Returns if a top-level module with :attr:`name` exists *without**
......
......@@ -6,6 +6,7 @@ from typing import Union
from paddle import Tensor
#code is from: https://github.com/pytorch/audio/blob/main/torchaudio/backend/no_backend.py
def load(
filepath: Union[str, Path],
......
......@@ -7,6 +7,7 @@ from typing import Union
from paddle import Tensor
#https://github.com/pytorch/audio/blob/main/torchaudio/backend/sox_io_backend.py
def load(
filepath: Union[str, Path],
......
"""Defines utilities for switching audio backends"""
#code is from: https://github.com/pytorch/audio/blob/main/torchaudio/backend/utils.py
import warnings
from typing import List
from typing import Optional
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import fbank
from . import pitch
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlespeech.audio._internal import module_utils
import paddlespeech.audio.ops.paddleaudio.ComputeFbank as ComputeFbank
import paddlespeech.audio.ops.paddleaudio.PitchExtractionOptions as PitchExtractionOptions
import paddlespeech.audio.ops.paddleaudio.FrameExtractionOptions as FrameExtractionOptions
import paddlespeech.audio.ops.paddleaudio.MelBanksOptions as MelBanksOptions
import paddlespeech.audio.ops.paddleaudio.FbankOptions as FbankOptions
import paddlespeech.audio.ops.paddleaudio.ComputeKaldiPitch as ComputeKaldiPitch
__all__ = [
'fbank',
'pitch',
]
@module_utils.requires_kaldi()
def fbank(wav,
samp_freq: int=16000,
frame_shift_ms: float=10.0,
frame_length_ms: float=25.0,
dither: float=0.0,
preemph_coeff: float=0.97,
remove_dc_offset: bool=True,
window_type: str='povey',
round_to_power_of_two: bool=True,
blackman_coeff: float=0.42,
snip_edges: bool=True,
allow_downsample: bool=False,
allow_upsample: bool=False,
max_feature_vectors: int=-1,
num_bins: int=23,
low_freq: float=20,
high_freq: float=0,
vtln_low: float=100,
vtln_high: float=-500,
debug_mel: bool=False,
htk_mode: bool=False,
use_energy: bool=False, # fbank opts
energy_floor: float=0.0,
raw_energy: bool=True,
htk_compat: bool=False,
use_log_fbank: bool=True,
use_power: bool=True):
frame_opts = FrameExtractionOptions()
mel_opts = MelBanksOptions()
fbank_opts = FbankOptions()
frame_opts.samp_freq = samp_freq
frame_opts.frame_shift_ms = frame_shift_ms
frame_opts.frame_length_ms = frame_length_ms
frame_opts.dither = dither
frame_opts.preemph_coeff = preemph_coeff
frame_opts.remove_dc_offset = remove_dc_offset
frame_opts.window_type = window_type
frame_opts.round_to_power_of_two = round_to_power_of_two
frame_opts.blackman_coeff = blackman_coeff
frame_opts.snip_edges = snip_edges
frame_opts.allow_downsample = allow_downsample
frame_opts.allow_upsample = allow_upsample
frame_opts.max_feature_vectors = max_feature_vectors
mel_opts.num_bins = num_bins
mel_opts.low_freq = low_freq
mel_opts.high_freq = high_freq
mel_opts.vtln_low = vtln_low
mel_opts.vtln_high = vtln_high
mel_opts.debug_mel = debug_mel
mel_opts.htk_mode = htk_mode
fbank_opts.use_energy = use_energy
fbank_opts.energy_floor = energy_floor
fbank_opts.raw_energy = raw_energy
fbank_opts.htk_compat = htk_compat
fbank_opts.use_log_fbank = use_log_fbank
fbank_opts.use_power = use_power
feat = ComputeFbank(frame_opts, mel_opts, fbank_opts, wav)
return feat
@module_utils.requires_kaldi()
def pitch(wav,
samp_freq: int=16000,
frame_shift_ms: float=10.0,
frame_length_ms: float=25.0,
preemph_coeff: float=0.0,
min_f0: int=50,
max_f0: int=400,
soft_min_f0: float=10.0,
penalty_factor: float=0.1,
lowpass_cutoff: int=1000,
resample_freq: int=4000,
delta_pitch: float=0.005,
nccf_ballast: int=7000,
lowpass_filter_width: int=1,
upsample_filter_width: int=5,
max_frames_latency: int=0,
frames_per_chunk: int=0,
simulate_first_pass_online: bool=False,
recompute_frame: int=500,
nccf_ballast_online: bool=False,
snip_edges: bool=True):
pitch_opts = PitchExtractionOptions()
pitch_opts.samp_freq = samp_freq
pitch_opts.frame_shift_ms = frame_shift_ms
pitch_opts.frame_length_ms = frame_length_ms
pitch_opts.preemph_coeff = preemph_coeff
pitch_opts.min_f0 = min_f0
pitch_opts.max_f0 = max_f0
pitch_opts.soft_min_f0 = soft_min_f0
pitch_opts.penalty_factor = penalty_factor
pitch_opts.lowpass_cutoff = lowpass_cutoff
pitch_opts.resample_freq = resample_freq
pitch_opts.delta_pitch = delta_pitch
pitch_opts.nccf_ballast = nccf_ballast
pitch_opts.lowpass_filter_width = lowpass_filter_width
pitch_opts.upsample_filter_width = upsample_filter_width
pitch_opts.max_frames_latency = max_frames_latency
pitch_opts.frames_per_chunk = frames_per_chunk
pitch_opts.simulate_first_pass_online = simulate_first_pass_online
pitch_opts.recompute_frame = recompute_frame
pitch_opts.nccf_ballast_online = nccf_ballast_online
pitch_opts.snip_edges = snip_edges
pitch = ComputeKaldiPitch(pitch_opts, wav)
return pitch
......@@ -13,137 +13,63 @@
// limitations under the License.
#include "paddlespeech/audio/src/pybind/kaldi/kaldi_feature.h"
#include "feat/pitch-functions.h"
namespace paddleaudio {
namespace kaldi {
bool InitFbank(float samp_freq, // frame opts
float frame_shift_ms,
float frame_length_ms,
float dither,
float preemph_coeff,
bool remove_dc_offset,
std::string window_type, // e.g. Hamming window
bool round_to_power_of_two,
float blackman_coeff,
bool snip_edges,
bool allow_downsample,
bool allow_upsample,
int max_feature_vectors,
int num_bins, // mel opts
float low_freq,
float high_freq,
float vtln_low,
float vtln_high,
bool debug_mel,
bool htk_mode,
bool use_energy, // fbank opts
float energy_floor,
bool raw_energy,
bool htk_compat,
bool use_log_fbank,
bool use_power) {
bool InitFbank(
::kaldi::FrameExtractionOptions frame_opts,
::kaldi::MelBanksOptions mel_opts,
FbankOptions fbank_opts) {
::kaldi::FbankOptions opts;
opts.frame_opts.samp_freq = samp_freq; // frame opts
opts.frame_opts.frame_shift_ms = frame_shift_ms;
opts.frame_opts.frame_length_ms = frame_length_ms;
opts.frame_opts.dither = dither;
opts.frame_opts.preemph_coeff = preemph_coeff;
opts.frame_opts.remove_dc_offset = remove_dc_offset;
opts.frame_opts.window_type = window_type;
opts.frame_opts.round_to_power_of_two = round_to_power_of_two;
opts.frame_opts.blackman_coeff = blackman_coeff;
opts.frame_opts.snip_edges = snip_edges;
opts.frame_opts.allow_downsample = allow_downsample;
opts.frame_opts.allow_upsample = allow_upsample;
opts.frame_opts.max_feature_vectors = max_feature_vectors;
opts.mel_opts.num_bins = num_bins; // mel opts
opts.mel_opts.low_freq = low_freq;
opts.mel_opts.high_freq = high_freq;
opts.mel_opts.vtln_low = vtln_low;
opts.mel_opts.vtln_high = vtln_high;
opts.mel_opts.debug_mel = debug_mel;
opts.mel_opts.htk_mode = htk_mode;
opts.use_energy = use_energy; // fbank opts
opts.energy_floor = energy_floor;
opts.raw_energy = raw_energy;
opts.htk_compat = htk_compat;
opts.use_log_fbank = use_log_fbank;
opts.use_power = use_power;
opts.frame_opts = frame_opts;
opts.mel_opts = mel_opts;
opts.use_energy = fbank_opts.use_energy;
opts.energy_floor = fbank_opts.energy_floor;
opts.raw_energy = fbank_opts.raw_energy;
opts.htk_compat = fbank_opts.htk_compat;
opts.use_log_fbank = fbank_opts.use_log_fbank;
opts.use_power = fbank_opts.use_power;
paddleaudio::kaldi::KaldiFeatureWrapper::GetInstance()->InitFbank(opts);
return true;
}
py::array_t<double> ComputeFbankStreaming(const py::array_t<double>& wav) {
py::array_t<float> ComputeFbankStreaming(const py::array_t<float>& wav) {
return paddleaudio::kaldi::KaldiFeatureWrapper::GetInstance()->ComputeFbank(
wav);
}
py::array_t<double> ComputeFbank(
float samp_freq, // frame opts
float frame_shift_ms,
float frame_length_ms,
float dither,
float preemph_coeff,
bool remove_dc_offset,
std::string window_type, // e.g. Hamming window
bool round_to_power_of_two,
float blackman_coeff,
bool snip_edges,
bool allow_downsample,
bool allow_upsample,
int max_feature_vectors,
int num_bins, // mel opts
float low_freq,
float high_freq,
float vtln_low,
float vtln_high,
bool debug_mel,
bool htk_mode,
bool use_energy, // fbank opts
float energy_floor,
bool raw_energy,
bool htk_compat,
bool use_log_fbank,
bool use_power,
const py::array_t<double>& wav) {
InitFbank(samp_freq, // frame opts
frame_shift_ms,
frame_length_ms,
dither,
preemph_coeff,
remove_dc_offset,
window_type, // e.g. Hamming window
round_to_power_of_two,
blackman_coeff,
snip_edges,
allow_downsample,
allow_upsample,
max_feature_vectors,
num_bins, // mel opts
low_freq,
high_freq,
vtln_low,
vtln_high,
debug_mel,
htk_mode,
use_energy, // fbank opts
energy_floor,
raw_energy,
htk_compat,
use_log_fbank,
use_power);
py::array_t<double> result = ComputeFbankStreaming(wav);
py::array_t<float> ComputeFbank(
::kaldi::FrameExtractionOptions frame_opts,
::kaldi::MelBanksOptions mel_opts,
FbankOptions fbank_opts,
const py::array_t<float>& wav) {
InitFbank(frame_opts, mel_opts, fbank_opts);
py::array_t<float> result = ComputeFbankStreaming(wav);
paddleaudio::kaldi::KaldiFeatureWrapper::GetInstance()->ResetFbank();
return result;
}
void ResetFbank() {
paddleaudio::kaldi::KaldiFeatureWrapper::GetInstance()->ResetFbank();
}
py::array_t<float> ComputeKaldiPitch(
const ::kaldi::PitchExtractionOptions& opts,
const py::array_t<float>& wav) {
py::buffer_info info = wav.request();
::kaldi::SubVector<::kaldi::BaseFloat> input_wav((float*)info.ptr, info.size);
::kaldi::Matrix<::kaldi::BaseFloat> features;
::kaldi::ComputeKaldiPitch(opts, input_wav, &features);
auto result = py::array_t<float>({features.NumRows(), features.NumCols()});
for (int row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
std::memcpy(result.mutable_data(row_idx), features.Row(row_idx).Data(),
sizeof(float)*features.NumCols());
}
return result;
}
} // namespace kaldi
} // namespace paddleaudio
......@@ -19,75 +19,46 @@
#include <string>
#include "paddlespeech/audio/src/pybind/kaldi/kaldi_feature_wrapper.h"
#include "feat/pitch-functions.h"
namespace py = pybind11;
namespace paddleaudio {
namespace kaldi {
bool InitFbank(float samp_freq, // frame opts
float frame_shift_ms,
float frame_length_ms,
float dither,
float preemph_coeff,
bool remove_dc_offset,
std::string window_type, // e.g. Hamming window
bool round_to_power_of_two,
float blackman_coeff,
bool snip_edges,
bool allow_downsample,
bool allow_upsample,
int max_feature_vectors,
int num_bins, // mel opts
float low_freq,
float high_freq,
float vtln_low,
float vtln_high,
bool debug_mel,
bool htk_mode,
bool use_energy, // fbank opts
float energy_floor,
bool raw_energy,
bool htk_compat,
bool use_log_fbank,
bool use_power);
struct FbankOptions{
bool use_energy; // append an extra dimension with energy to the filter banks
float energy_floor;
bool raw_energy; // If true, compute energy before preemphasis and windowing
bool htk_compat; // If true, put energy last (if using energy)
bool use_log_fbank; // if true (default), produce log-filterbank, else linear
bool use_power;
FbankOptions(): use_energy(false),
energy_floor(0.0),
raw_energy(true),
htk_compat(false),
use_log_fbank(true),
use_power(true) {}
};
py::array_t<double> ComputeFbank(
float samp_freq, // frame opts
float frame_shift_ms,
float frame_length_ms,
float dither,
float preemph_coeff,
bool remove_dc_offset,
std::string window_type, // e.g. Hamming window
bool round_to_power_of_two,
::kaldi::BaseFloat blackman_coeff,
bool snip_edges,
bool allow_downsample,
bool allow_upsample,
int max_feature_vectors,
int num_bins, // mel opts
float low_freq,
float high_freq,
float vtln_low,
float vtln_high,
bool debug_mel,
bool htk_mode,
bool use_energy, // fbank opts
float energy_floor,
bool raw_energy,
bool htk_compat,
bool use_log_fbank,
bool use_power,
const py::array_t<double>& wav);
bool InitFbank(
::kaldi::FrameExtractionOptions frame_opts,
::kaldi::MelBanksOptions mel_opts,
FbankOptions fbank_opts);
py::array_t<double> ComputeFbankStreaming(const py::array_t<double>& wav);
py::array_t<float> ComputeFbank(
::kaldi::FrameExtractionOptions frame_opts,
::kaldi::MelBanksOptions mel_opts,
FbankOptions fbank_opts,
const py::array_t<float>& wav);
void ResetFbank();
py::array_t<float> ComputeFbankStreaming(const py::array_t<float>& wav);
py::array_t<double> ComputeFbankStreaming(const py::array_t<double>& wav);
void ResetFbank();
py::array_t<double> TestFun(const py::array_t<double>& wav);
py::array_t<float> ComputeKaldiPitch(
const ::kaldi::PitchExtractionOptions& opts,
const py::array_t<float>& wav);
} // namespace kaldi
} // namespace paddleaudio
......@@ -27,47 +27,24 @@ bool KaldiFeatureWrapper::InitFbank(::kaldi::FbankOptions opts) {
return true;
}
py::array_t<double> KaldiFeatureWrapper::ComputeFbank(
const py::array_t<double> wav) {
py::array_t<float> KaldiFeatureWrapper::ComputeFbank(
const py::array_t<float> wav) {
py::buffer_info info = wav.request();
::kaldi::Vector<::kaldi::BaseFloat> input_wav(info.size);
double* wav_ptr = (double*)info.ptr;
for (int idx = 0; idx < info.size; ++idx) {
input_wav(idx) = *wav_ptr;
wav_ptr++;
}
::kaldi::SubVector<::kaldi::BaseFloat> input_wav((float*)info.ptr, info.size);
::kaldi::Vector<::kaldi::BaseFloat> feats;
bool flag = fbank_->ComputeFeature(input_wav, &feats);
if (flag == false || feats.Dim() == 0) return py::array_t<double>();
auto result = py::array_t<double>(feats.Dim());
if (flag == false || feats.Dim() == 0) return py::array_t<float>();
auto result = py::array_t<float>(feats.Dim());
py::buffer_info xs = result.request();
for (int idx = 0; idx < 10; ++idx) {
float val = feats(idx);
std::cout << val << " ";
}
std::cout << std::endl;
double* res_ptr = (double*)xs.ptr;
float* res_ptr = (float*)xs.ptr;
for (int idx = 0; idx < feats.Dim(); ++idx) {
*res_ptr = feats(idx);
res_ptr++;
}
return result.reshape({feats.Dim() / Dim(), Dim()});
/*
py::buffer_info info = wav.request();
std::cout << info.size << std::endl;
auto result = py::array_t<double>(info.size);
//::kaldi::Vector<::kaldi::BaseFloat> input_wav(info.size);
::kaldi::Vector<double> input_wav(info.size);
py::buffer_info info_re = result.request();
memcpy(input_wav.Data(), (double*)info.ptr, wav.nbytes());
memcpy((double*)info_re.ptr, input_wav.Data(), input_wav.Dim()*
sizeof(double));
return result;
*/
}
} // namesapce kaldi
......
......@@ -28,7 +28,7 @@ class KaldiFeatureWrapper {
public:
static KaldiFeatureWrapper* GetInstance();
bool InitFbank(::kaldi::FbankOptions opts);
py::array_t<double> ComputeFbank(const py::array_t<double> wav);
py::array_t<float> ComputeFbank(const py::array_t<float> wav);
int Dim() { return fbank_->Dim(); }
void ResetFbank() { fbank_->Reset(); }
......
......@@ -3,20 +3,76 @@
#include "paddlespeech/audio/src/pybind/kaldi/kaldi_feature.h"
#include "paddlespeech/audio/src/pybind/sox/io.h"
#include "paddlespeech/audio/third_party/kaldi/feat/feature-fbank.h"
// Sox
PYBIND11_MODULE(_paddleaudio, m) {
#ifdef INCLUDE_SOX
m.def("get_info_file",
&paddleaudio::sox_io::get_info_file,
"Get metadata of audio file.");
m.def("get_info_fileobj",
&paddleaudio::sox_io::get_info_fileobj,
"Get metadata of audio in file object.");
#endif
m.def("InitFbank", &paddleaudio::kaldi::InitFbank, "init fbank");
m.def("ResetFbank", &paddleaudio::kaldi::ResetFbank, "reset fbank");
#ifdef INCLUDE_KALDI
m.def("ComputeFbank", &paddleaudio::kaldi::ComputeFbank, "compute fbank");
m.def("ComputeFbankStreaming",
&paddleaudio::kaldi::ComputeFbankStreaming,
"compute fbank streaming");
py::class_<kaldi::PitchExtractionOptions>(m, "PitchExtractionOptions")
.def(py::init<>())
.def_readwrite("samp_freq", &kaldi::PitchExtractionOptions::samp_freq)
.def_readwrite("frame_shift_ms", &kaldi::PitchExtractionOptions::frame_shift_ms)
.def_readwrite("frame_length_ms", &kaldi::PitchExtractionOptions::frame_length_ms)
.def_readwrite("preemph_coeff", &kaldi::PitchExtractionOptions::preemph_coeff)
.def_readwrite("min_f0", &kaldi::PitchExtractionOptions::min_f0)
.def_readwrite("max_f0", &kaldi::PitchExtractionOptions::max_f0)
.def_readwrite("soft_min_f0", &kaldi::PitchExtractionOptions::soft_min_f0)
.def_readwrite("penalty_factor", &kaldi::PitchExtractionOptions::penalty_factor)
.def_readwrite("lowpass_cutoff", &kaldi::PitchExtractionOptions::lowpass_cutoff)
.def_readwrite("resample_freq", &kaldi::PitchExtractionOptions::resample_freq)
.def_readwrite("delta_pitch", &kaldi::PitchExtractionOptions::delta_pitch)
.def_readwrite("nccf_ballast", &kaldi::PitchExtractionOptions::nccf_ballast)
.def_readwrite("lowpass_filter_width", &kaldi::PitchExtractionOptions::lowpass_filter_width)
.def_readwrite("upsample_filter_width", &kaldi::PitchExtractionOptions::upsample_filter_width)
.def_readwrite("max_frames_latency", &kaldi::PitchExtractionOptions::max_frames_latency)
.def_readwrite("frames_per_chunk", &kaldi::PitchExtractionOptions::frames_per_chunk)
.def_readwrite("simulate_first_pass_online", &kaldi::PitchExtractionOptions::simulate_first_pass_online)
.def_readwrite("recompute_frame", &kaldi::PitchExtractionOptions::recompute_frame)
.def_readwrite("nccf_ballast_online", &kaldi::PitchExtractionOptions::nccf_ballast_online)
.def_readwrite("snip_edges", &kaldi::PitchExtractionOptions::snip_edges);
m.def("ComputeKaldiPitch", &paddleaudio::kaldi::ComputeKaldiPitch, "compute kaldi pitch");
py::class_<kaldi::FrameExtractionOptions>(m, "FrameExtractionOptions")
.def(py::init<>())
.def_readwrite("samp_freq", &kaldi::FrameExtractionOptions::samp_freq)
.def_readwrite("frame_shift_ms", &kaldi::FrameExtractionOptions::frame_shift_ms)
.def_readwrite("frame_length_ms", &kaldi::FrameExtractionOptions::frame_length_ms)
.def_readwrite("dither", &kaldi::FrameExtractionOptions::dither)
.def_readwrite("preemph_coeff", &kaldi::FrameExtractionOptions::preemph_coeff)
.def_readwrite("remove_dc_offset", &kaldi::FrameExtractionOptions::remove_dc_offset)
.def_readwrite("window_type", &kaldi::FrameExtractionOptions::window_type)
.def_readwrite("round_to_power_of_two", &kaldi::FrameExtractionOptions::round_to_power_of_two)
.def_readwrite("blackman_coeff", &kaldi::FrameExtractionOptions::blackman_coeff)
.def_readwrite("snip_edges", &kaldi::FrameExtractionOptions::snip_edges)
.def_readwrite("allow_downsample", &kaldi::FrameExtractionOptions::allow_downsample)
.def_readwrite("allow_upsample", &kaldi::FrameExtractionOptions::allow_upsample)
.def_readwrite("max_feature_vectors", &kaldi::FrameExtractionOptions::max_feature_vectors);
py::class_<kaldi::MelBanksOptions>(m, "MelBanksOptions")
.def(py::init<>())
.def_readwrite("num_bins", &kaldi::MelBanksOptions::num_bins)
.def_readwrite("low_freq", &kaldi::MelBanksOptions::low_freq)
.def_readwrite("high_freq", &kaldi::MelBanksOptions::high_freq)
.def_readwrite("vtln_low", &kaldi::MelBanksOptions::vtln_low)
.def_readwrite("vtln_high", &kaldi::MelBanksOptions::vtln_high)
.def_readwrite("debug_mel", &kaldi::MelBanksOptions::debug_mel)
.def_readwrite("htk_mode", &kaldi::MelBanksOptions::htk_mode);
py::class_<paddleaudio::kaldi::FbankOptions>(m, "FbankOptions")
.def(py::init<>())
.def_readwrite("use_energy", &paddleaudio::kaldi::FbankOptions::use_energy)
.def_readwrite("energy_floor", &paddleaudio::kaldi::FbankOptions::energy_floor)
.def_readwrite("raw_energy", &paddleaudio::kaldi::FbankOptions::raw_energy)
.def_readwrite("htk_compat", &paddleaudio::kaldi::FbankOptions::htk_compat)
.def_readwrite("use_log_fbank", &paddleaudio::kaldi::FbankOptions::use_log_fbank)
.def_readwrite("use_power", &paddleaudio::kaldi::FbankOptions::use_power);
#endif
}
......@@ -32,9 +32,9 @@ target_include_directories(kaldi-base PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
# kaldi-matrix
add_library(kaldi-matrix STATIC
matrix/compressed-matrix.cc
matrix/matrix-functions.cc
matrix/kaldi-matrix.cc
matrix/kaldi-vector.cc
matrix/matrix-functions.cc
matrix/optimization.cc
matrix/packed-matrix.cc
matrix/qr.cc
......@@ -65,13 +65,14 @@ target_link_libraries(kaldi-util PUBLIC kaldi-base kaldi-matrix)
# kaldi-feat-common
add_library(kaldi-feat-common STATIC
feat/wave-reader.cc
feat/signal.cc
feat/cmvn.cc
feat/feature-functions.cc
feat/feature-window.cc
feat/resample.cc
feat/mel-computations.cc
feat/cmvn.cc
feat/pitch-functions.cc
feat/resample.cc
feat/signal.cc
feat/wave-reader.cc
)
target_include_directories(kaldi-feat-common PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
target_link_libraries(kaldi-feat-common PUBLIC kaldi-base kaldi-matrix kaldi-util)
......@@ -113,4 +114,4 @@ target_link_libraries(libkaldi INTERFACE
gfortran
-Wl,--no-whole-archive -Wl,--end-group
)
target_compile_definitions(libkaldi INTERFACE "-DCOMPILE_WITHOUT_OPENFST")
\ No newline at end of file
target_compile_definitions(libkaldi INTERFACE "-DCOMPILE_WITHOUT_OPENFST")
......@@ -27,7 +27,7 @@
#include "feat/feature-functions.h"
#include "feat/feature-window.h"
#include "feat/mel-computations.h"
#include "itf/options-itf.h"
#include "util/options-itf.h"
namespace kaldi {
/// @addtogroup feat FeatureExtraction
......
// feat/online-feature-itf.h
// Copyright 2013 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FEAT_ONLINE_FEATURE_ITF_H_
#define KALDI_FEAT_ONLINE_FEATURE_ITF_H_ 1
#include "base/kaldi-common.h"
#include "matrix/matrix-lib.h"
namespace kaldi {
/// @ingroup Interfaces
/// @{
/**
OnlineFeatureInterface is an interface for online feature processing (it is
also usable in the offline setting, but currently we're not using it for
that). This is for use in the online2/ directory, and it supersedes the
interface in ../online/online-feat-input.h. We have a slightly different
model that puts more control in the hands of the calling thread, and won't
involve waiting on semaphores in the decoding thread.
This interface only specifies how the object *outputs* the features.
How it obtains the features, e.g. from a previous object or objects of type
OnlineFeatureInterface, is not specified in the interface and you will
likely define new constructors or methods in the derived type to do that.
You should appreciate that this interface is designed to allow random
access to features, as long as they are ready. That is, the user
can call GetFrame for any frame less than NumFramesReady(), and when
implementing a child class you must not make assumptions about the
order in which the user makes these calls.
*/
class OnlineFeatureInterface {
public:
virtual int32 Dim() const = 0; /// returns the feature dimension.
/// Returns the total number of frames, since the start of the utterance, that
/// are now available. In an online-decoding context, this will likely
/// increase with time as more data becomes available.
virtual int32 NumFramesReady() const = 0;
/// Returns true if this is the last frame. Frame indices are zero-based, so the
/// first frame is zero. IsLastFrame(-1) will return false, unless the file
/// is empty (which is a case that I'm not sure all the code will handle, so
/// be careful). This function may return false for some frame if
/// we haven't yet decided to terminate decoding, but later true if we decide
/// to terminate decoding. This function exists mainly to correctly handle
/// end effects in feature extraction, and is not a mechanism to determine how
/// many frames are in the decodable object (as it used to be, and for backward
/// compatibility, still is, in the Decodable interface).
virtual bool IsLastFrame(int32 frame) const = 0;
/// Gets the feature vector for this frame. Before calling this for a given
/// frame, it is assumed that you called NumFramesReady() and it returned a
/// number greater than "frame". Otherwise this call will likely crash with
/// an assert failure. This function is not declared const, in case there is
/// some kind of caching going on, but most of the time it shouldn't modify
/// the class.
virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat) = 0;
/// This is like GetFrame() but for a collection of frames. There is a
/// default implementation that just gets the frames one by one, but it
/// may be overridden for efficiency by child classes (since sometimes
/// it's more efficient to do things in a batch).
virtual void GetFrames(const std::vector<int32> &frames,
MatrixBase<BaseFloat> *feats) {
KALDI_ASSERT(static_cast<int32>(frames.size()) == feats->NumRows());
for (size_t i = 0; i < frames.size(); i++) {
SubVector<BaseFloat> feat(*feats, i);
GetFrame(frames[i], &feat);
}
}
// Returns frame shift in seconds. Helps to estimate duration from frame
// counts.
virtual BaseFloat FrameShiftInSeconds() const = 0;
/// Virtual destructor. Note: constructors that take another member of
/// type OnlineFeatureInterface are not expected to take ownership of
/// that pointer; the caller needs to keep track of that manually.
virtual ~OnlineFeatureInterface() { }
};
/// Add a virtual class for "source" features such as MFCC or PLP or pitch
/// features.
class OnlineBaseFeature: public OnlineFeatureInterface {
public:
/// This would be called from the application, when you get more wave data.
/// Note: the sampling_rate is typically only provided so the code can assert
/// that it matches the sampling rate expected in the options.
virtual void AcceptWaveform(BaseFloat sampling_rate,
const VectorBase<BaseFloat> &waveform) = 0;
/// InputFinished() tells the class you won't be providing any
/// more waveform. This will help flush out the last few frames
/// of delta or LDA features (it will typically affect the return value
/// of IsLastFrame.
virtual void InputFinished() = 0;
};
/// @}
} // namespace Kaldi
#endif // KALDI_ITF_ONLINE_FEATURE_ITF_H_
......@@ -34,7 +34,7 @@
#include "feat/feature-mfcc.h"
#include "feat/feature-plp.h"
#include "feat/feature-fbank.h"
#include "itf/online-feature-itf.h"
#include "feat/online-feature-itf.h"
namespace kaldi {
/// @addtogroup onlinefeat OnlineFeatureExtraction
......
......@@ -31,7 +31,7 @@
#include "base/kaldi-error.h"
#include "feat/mel-computations.h"
#include "itf/online-feature-itf.h"
#include "feat/online-feature-itf.h"
#include "matrix/matrix-lib.h"
#include "util/common-utils.h"
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddlespeech.audio.kaldi.fbank as fbank
import paddlespeech.audio.kaldi.pitch as pitch
from kaldiio import ReadHelper
# the groundtruth feats computed in kaldi command below.
#compute-fbank-feats --dither=0 scp:$wav_scp ark,t:fbank_feat.ark
#compute-kaldi-pitch-feats --sample-frequency=16000 scp:$wav_scp ark,t:pitch_feat.ark
class TestKaldiFbank(unittest.TestCase):
def test_fbank(self):
fbank_groundtruth = {}
with ReadHelper('ark:testdata/fbank_feat.ark') as reader:
for key, feat in reader:
fbank_groundtruth[key] = feat
with ReadHelper('ark:testdata/wav.ark') as reader:
for key, wav in reader:
fbank_feat = fbank(wav)
fbank_check = fbank_groundtruth[key]
np.testing.assert_array_almost_equal(
fbank_feat, fbank_check, decimal=4)
def test_pitch(self):
pitch_groundtruth = {}
with ReadHelper('ark:testdata/pitch_feat.ark') as reader:
for key, feat in reader:
pitch_groundtruth[key] = feat
with ReadHelper('ark:testdata/wav.ark') as reader:
for key, wav in reader:
pitch_feat = pitch(wav)
pitch_check = pitch_groundtruth[key]
np.testing.assert_array_almost_equal(
pitch_feat, pitch_check, decimal=4)
if __name__ == '__main__':
unittest.main()
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册